From 168f251c1c8e1c6292b2687a4ddeaf0a031539ea Mon Sep 17 00:00:00 2001 From: HighCloud Date: Wed, 20 May 2026 15:57:41 +0800 Subject: [PATCH] Fix: align sparse attention seqused semantics --- models/deepseek/v4/attention_csa.py | 17 ++--- models/deepseek/v4/attention_hca.py | 16 ++--- models/deepseek/v4/attention_swa.py | 12 ++-- models/deepseek/v4/decode_csa.py | 4 +- models/deepseek/v4/decode_hca.py | 4 +- models/deepseek/v4/decode_swa.py | 4 +- .../v4/deepseek_v4_decode_single_layer.md | 7 ++- models/deepseek/v4/sparse_attn.py | 63 +++++++++++++------ 8 files changed, 75 insertions(+), 52 deletions(-) diff --git a/models/deepseek/v4/attention_csa.py b/models/deepseek/v4/attention_csa.py index 581c2b5..ba5f383 100644 --- a/models/deepseek/v4/attention_csa.py +++ b/models/deepseek/v4/attention_csa.py @@ -149,7 +149,7 @@ def attention_csa( cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], idx_kv_cache: pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16], attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], wo_b_scale: pl.Tensor[[D], pl.FP32], @@ -399,7 +399,7 @@ def attention_csa_test_refresh( cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], idx_kv_cache: pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16], attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], wo_b_scale: pl.Tensor[[D], pl.FP32], @@ -524,7 +524,8 @@ def golden_sparse_attn_online(local_tensors): # Per-token gather + attention; token t belongs to batch t // S. for t in range(T): b = t // S - seq_used = int(seqused_kv.view(T)[t].item()) + s = t - b * S + seq_used = int(seqused_kv[b].item()) - S + 1 + s window_valid = min(WIN, seq_used) cmp_valid = max(seq_used - window_valid, 0) cmp_topk_valid = min(IDX_TOPK, cmp_valid) @@ -723,7 +724,7 @@ def golden_sparse_attn_online(local_tensors): "cmp_block_table": cmp_block_table, "cmp_sparse_indices": sparse_topk, "attn_sink": tensors["attn_sink"], - "seqused_kv": tensors["seqused_kv"].view(B, S), + "seqused_kv": tensors["seqused_kv"], "freqs_cos": rope_cos_t, "freqs_sin": rope_sin_t, "even_select_local": tensors["even_select_local"], @@ -928,9 +929,9 @@ def init_cmp_sparse_indices(): return torch.full((T, SPARSE_TOPK), -1, dtype=torch.int32) def init_seqused_kv(): - s = torch.arange(1, S + 1, dtype=torch.int32) + START_POS - seq = torch.where(s <= WIN, s, WIN + s // COMPRESS_RATIO) - return seq.expand(B, S).clone() + seq = START_POS + S + sparse_len = seq if seq <= WIN else WIN + seq // COMPRESS_RATIO + return torch.full((B,), sparse_len, dtype=torch.int32) def init_wo_a(): return torch.randn(O_GROUPS, O_LORA, O_GROUP_IN) / O_GROUP_IN ** 0.5 @@ -1014,7 +1015,7 @@ def init_wo_b(): TensorSpec("cmp_block_table", [B, CMP_MAX_BLOCKS], torch.int32, init_value=init_cmp_block_table), TensorSpec("idx_kv_cache", [B, IDX_KV_LEN, IDX_HEAD_DIM], torch.bfloat16, init_value=lambda: shared_idx_kv_cache.clone()), TensorSpec("attn_sink", [H], torch.float32, init_value=init_attn_sink), - TensorSpec("seqused_kv", [B, S], torch.int32, init_value=init_seqused_kv), + TensorSpec("seqused_kv", [B], torch.int32, init_value=init_seqused_kv), TensorSpec("wo_a", [O_GROUPS, O_LORA, O_GROUP_IN], torch.bfloat16, init_value=init_wo_a), TensorSpec("wo_b", [D, O_GROUPS * O_LORA], torch.int8, init_value=lambda: wo_b_i8), TensorSpec("wo_b_scale", [D], torch.float32, init_value=lambda: wo_b_scale), diff --git a/models/deepseek/v4/attention_hca.py b/models/deepseek/v4/attention_hca.py index 3f3f787..363ec84 100644 --- a/models/deepseek/v4/attention_hca.py +++ b/models/deepseek/v4/attention_hca.py @@ -111,7 +111,7 @@ def attention_hca( cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], # sparse_attn attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], # o_proj (fused into sparse_attn) wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], @@ -352,7 +352,7 @@ def attention_hca_test( cmp_kv: pl.Tensor[[CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], wo_b_scale: pl.Tensor[[D], pl.FP32], @@ -516,7 +516,7 @@ def golden_attention_hca(tensors): "cmp_block_table": cmp_block_table, "cmp_sparse_indices": topk_idxs, "attn_sink": tensors["attn_sink"], - "seqused_kv": tensors["seqused_kv"].view(B, S), + "seqused_kv": tensors["seqused_kv"], "freqs_cos": rope_cos_T, "freqs_sin": rope_sin_T, "even_select_local": tensors["even_select_local"], @@ -656,10 +656,10 @@ def init_cmp_block_table(): def init_attn_sink(): return torch.zeros(H) def init_seqused_kv(): - # sparse_attn uses: window_valid = min(WIN, seq_used); cmp_valid = seq_used - window_valid. - s = torch.arange(1, S + 1, dtype=torch.int32) + START_POS - seq = torch.where(s <= WIN, s, WIN + s // COMPRESS_RATIO) - return seq.expand(B, S).clone() + # sparse_attn derives per-token causal lengths from this final sparse length. + seq = START_POS + S + sparse_len = seq if seq <= WIN else WIN + seq // COMPRESS_RATIO + return torch.full((B,), sparse_len, dtype=torch.int32) def init_wo_a(): return torch.randn(O_GROUPS, O_LORA, O_GROUP_IN) / O_GROUP_IN ** 0.5 def init_wo_b(): @@ -702,7 +702,7 @@ def init_wo_b(): TensorSpec("cmp_kv", [CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], torch.bfloat16, init_value=init_cmp_kv), TensorSpec("cmp_block_table", [B, CMP_MAX_BLOCKS], torch.int32, init_value=init_cmp_block_table), TensorSpec("attn_sink", [H], torch.float32, init_value=init_attn_sink), - TensorSpec("seqused_kv", [B, S], torch.int32, init_value=init_seqused_kv), + TensorSpec("seqused_kv", [B], torch.int32, init_value=init_seqused_kv), TensorSpec("wo_a", [O_GROUPS, O_LORA, O_GROUP_IN], torch.bfloat16, init_value=init_wo_a), TensorSpec("wo_b", [D, O_GROUPS * O_LORA], torch.int8, init_value=lambda: wo_b_i8), TensorSpec("wo_b_scale", [D], torch.float32, init_value=lambda: wo_b_scale), diff --git a/models/deepseek/v4/attention_swa.py b/models/deepseek/v4/attention_swa.py index c0e5100..a43427f 100644 --- a/models/deepseek/v4/attention_swa.py +++ b/models/deepseek/v4/attention_swa.py @@ -92,7 +92,7 @@ def attention_swa( block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32], # sparse_attn attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], # o_proj wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], @@ -249,7 +249,7 @@ def attention_swa_test( block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32], # sparse_attn attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], # o_proj wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], @@ -367,7 +367,7 @@ def golden_attention_swa(tensors): "cmp_block_table": cmp_block_table_dummy, "cmp_sparse_indices": sparse_topk, "attn_sink": tensors["attn_sink"], - "seqused_kv": seqused_kv.view(B, S), + "seqused_kv": seqused_kv, "freqs_cos": rope_cos_T, "freqs_sin": rope_sin_T, "even_select_local": tensors["even_select_local"], @@ -476,9 +476,7 @@ def init_block_table(): def init_attn_sink(): return torch.zeros(H) def init_seqused_kv(): - s = torch.arange(1, S + 1, dtype=torch.int32) + START_POS - seq = torch.minimum(torch.full_like(s, WIN), s) - return seq.expand(B, S).clone() + return torch.full((B,), min(WIN, START_POS + S), dtype=torch.int32) def init_wo_a(): return torch.randn(O_GROUPS, O_LORA, O_GROUP_IN) / O_GROUP_IN ** 0.5 def init_wo_b(): @@ -511,7 +509,7 @@ def init_wo_b(): TensorSpec("kv_cache", [BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], torch.bfloat16, init_value=init_kv_cache), TensorSpec("block_table", [B, MAX_BLOCKS], torch.int32, init_value=init_block_table), TensorSpec("attn_sink", [H], torch.float32, init_value=init_attn_sink), - TensorSpec("seqused_kv", [B, S], torch.int32, init_value=init_seqused_kv), + TensorSpec("seqused_kv", [B], torch.int32, init_value=init_seqused_kv), TensorSpec("wo_a", [O_GROUPS, O_LORA, O_GROUP_IN], torch.bfloat16, init_value=init_wo_a), TensorSpec("wo_b", [D, O_GROUPS * O_LORA], torch.int8, init_value=lambda: wo_b_i8), TensorSpec("wo_b_scale", [D], torch.float32, init_value=lambda: wo_b_scale), diff --git a/models/deepseek/v4/decode_csa.py b/models/deepseek/v4/decode_csa.py index 0e2065a..9f13149 100644 --- a/models/deepseek/v4/decode_csa.py +++ b/models/deepseek/v4/decode_csa.py @@ -146,7 +146,7 @@ def decode_csa( idx_kv_cache: pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16], # ---- sparse_attn ---- attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], # ---- o_proj weights ---- wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], @@ -263,7 +263,7 @@ def decode_csa_test( cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], idx_kv_cache: pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16], attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], wo_b_scale: pl.Tensor[[D], pl.FP32], diff --git a/models/deepseek/v4/decode_hca.py b/models/deepseek/v4/decode_hca.py index aa26e74..f8d1aee 100644 --- a/models/deepseek/v4/decode_hca.py +++ b/models/deepseek/v4/decode_hca.py @@ -121,7 +121,7 @@ def decode_hca( cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], # ---- sparse_attn ---- attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], # ---- o_proj weights ---- wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], @@ -224,7 +224,7 @@ def decode_hca_test( cmp_kv: pl.Tensor[[CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], wo_b_scale: pl.Tensor[[D], pl.FP32], diff --git a/models/deepseek/v4/decode_swa.py b/models/deepseek/v4/decode_swa.py index e7fda53..5baedd0 100644 --- a/models/deepseek/v4/decode_swa.py +++ b/models/deepseek/v4/decode_swa.py @@ -92,7 +92,7 @@ def decode_swa( block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32], # ---- sparse_attn ---- attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], # ---- o_proj weights ---- wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], @@ -181,7 +181,7 @@ def decode_swa_test( kv_cache: pl.Tensor[[BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32], attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], wo_b_scale: pl.Tensor[[D], pl.FP32], diff --git a/models/deepseek/v4/deepseek_v4_decode_single_layer.md b/models/deepseek/v4/deepseek_v4_decode_single_layer.md index 59548cc..a1316be 100644 --- a/models/deepseek/v4/deepseek_v4_decode_single_layer.md +++ b/models/deepseek/v4/deepseek_v4_decode_single_layer.md @@ -143,15 +143,16 @@ sub-kernels below; the variant determines whether `compressor` and ║ model.py:533-534, 537-542 ║ ║ NOTE: outer loop is `for t in pl.range(T)` — per-query-token attention. ║ ║ Token t belongs to batch b = t // S and step s = t % S; each token ║ -║ reads seqused_kv[b, s]. Intra-query causal is enforced upstream by ║ -║ the topk index set the indexer produces per token. ║ +║ derives its sparse length from final seqused_kv[b] as ║ +║ `seqused_kv[b] - S + 1 + s`. Intra-query causal is enforced by this ║ +║ per-token derived length plus the topk index set. ║ ║ ║ ║ IN : q [T, H, HEAD_DIM] ║ ║ ori_kv (PA) — always ║ ║ cmp_kv (PA) — ratio>0 only ║ ║ topk_idxs [T, *] — per-token; ratio-dependent, see § below ║ ║ attn_sink [H] fp32 ║ -║ seqused_kv [B, S] — per-token valid sparse KV length ║ +║ seqused_kv [B] — final valid sparse KV length per batch ║ ║ freqs_cos/sin [T, ROPE_DIM] ║ ║ wo_a [O_GROUPS=8, O_LORA=1024, 4096] bf16 (grouped output LoRA) ║ ║ wo_b [D=4096, O_GROUPS*O_LORA=8192] int8 ║ diff --git a/models/deepseek/v4/sparse_attn.py b/models/deepseek/v4/sparse_attn.py index 65b6da8..75ae36f 100644 --- a/models/deepseek/v4/sparse_attn.py +++ b/models/deepseek/v4/sparse_attn.py @@ -30,7 +30,8 @@ - `cmp_sparse_indices[t, :]` may contain `-1` pads. - entries in `[0, WIN)` address the logical sliding-window ring slots. - entries in `[WIN, WIN + cmp_valid)` address compressed cache slots, where - `cmp_valid = max(seqused_kv[b, s] - min(WIN, seqused_kv[b, s]), 0)`. + `seq_used = seqused_kv[b] - S + 1 + s` + and `cmp_valid = max(seq_used - min(WIN, seq_used), 0)`. - the grouped projection layout matches: `o.view(bsz, seqlen, self.n_local_groups, -1)` with `self.n_local_groups == O_GROUPS` and `-1 == O_GROUP_IN`. @@ -110,7 +111,7 @@ def sparse_attn( cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], cmp_sparse_indices: pl.Tensor[[T, TOPK], pl.INT32], attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], freqs_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], freqs_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], even_select_local: pl.Tensor[[ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK], pl.BF16], @@ -132,7 +133,6 @@ def sparse_attn( cmp_kv_flat = pl.reshape(cmp_kv, [CMP_BLOCK_NUM * BLOCK_SIZE, HEAD_DIM]) cmp_block_table_flat = pl.reshape(cmp_block_table, [B * CMP_MAX_BLOCKS]) cmp_sparse_indices_flat = pl.reshape(cmp_sparse_indices, [T * TOPK]) - seqused_kv_flat = pl.reshape(seqused_kv, [T]) sparse_kv = pl.create_tensor([T * TOPK, HEAD_DIM], dtype=pl.BF16) attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.BF16) sparse_exp = pl.create_tensor([T * H * SPARSE_ATTN_BLOCKS, SPARSE_ATTN_TILE], dtype=pl.BF16) @@ -157,7 +157,9 @@ def sparse_attn( for gather_dt in pl.range(GATHER_TOKEN_TILE): gather_t = gather_t0 + gather_dt gather_b = gather_t // S - gather_seq_used = pl.read(seqused_kv_flat, [gather_t]) + gather_s = gather_t - gather_b * S + gather_seq_final = pl.read(seqused_kv, [gather_b]) + gather_seq_used = gather_seq_final - S + 1 + gather_s gather_window_valid = pl.min(WIN, gather_seq_used) gather_cmp_valid = gather_seq_used - gather_window_valid gather_cmp_topk_valid = pl.min(IDX_TOPK, gather_cmp_valid) @@ -207,7 +209,10 @@ def sparse_attn( with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_sparse_attn_qk_softmax_tile"): for qk_dt in pl.range(ATTN_TOKEN_TILE): qk_t = attn_t0 + qk_dt - qk_seq_used = pl.read(seqused_kv_flat, [qk_t]) + qk_b = qk_t // S + qk_s = qk_t - qk_b * S + qk_seq_final = pl.read(seqused_kv, [qk_b]) + qk_seq_used = qk_seq_final - S + 1 + qk_s qk_window_valid = pl.min(WIN, qk_seq_used) qk_cmp_valid = qk_seq_used - qk_window_valid qk_cmp_topk_valid = pl.min(IDX_TOPK, qk_cmp_valid) @@ -246,7 +251,10 @@ def sparse_attn( with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_sparse_attn_pv_tile"): for pv_dt in pl.range(ATTN_TOKEN_TILE): pv_t = attn_t0 + pv_dt - pv_seq_used = pl.read(seqused_kv_flat, [pv_t]) + pv_b = pv_t // S + pv_s = pv_t - pv_b * S + pv_seq_final = pl.read(seqused_kv, [pv_b]) + pv_seq_used = pv_seq_final - S + 1 + pv_s pv_window_valid = pl.min(WIN, pv_seq_used) pv_cmp_valid = pv_seq_used - pv_window_valid pv_cmp_topk_valid = pl.min(IDX_TOPK, pv_cmp_valid) @@ -268,7 +276,10 @@ def sparse_attn( with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_sparse_attn_merge_tile"): for merge_dt in pl.range(ATTN_TOKEN_TILE): merge_t = attn_t0 + merge_dt - merge_seq_used = pl.read(seqused_kv_flat, [merge_t]) + merge_b = merge_t // S + merge_s = merge_t - merge_b * S + merge_seq_final = pl.read(seqused_kv, [merge_b]) + merge_seq_used = merge_seq_final - S + 1 + merge_s merge_window_valid = pl.min(WIN, merge_seq_used) merge_cmp_valid = merge_seq_used - merge_window_valid merge_cmp_topk_valid = pl.min(IDX_TOPK, merge_cmp_valid) @@ -533,7 +544,7 @@ def sparse_attn_test( cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], cmp_sparse_indices: pl.Tensor[[T, TOPK], pl.INT32], attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B, S], pl.INT32], + seqused_kv: pl.Tensor[[B], pl.INT32], freqs_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], freqs_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], even_select_local: pl.Tensor[[ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK], pl.BF16], @@ -608,12 +619,12 @@ def golden_sparse_attn(tensors): o = torch.zeros(T, H, HEAD_DIM) - # Per-query-token attention. Each token has its own cmp_sparse_indices row - # and seqused_kv entry; block tables stay per-batch. + # Per-query-token attention. seqused_kv stores each batch's final sparse + # length for this decode chunk; token t derives its causal length from s. for t in range(T): b = t // S s = t - b * S - seq_used = int(seqused_kv[b, s].item()) + seq_used = int(seqused_kv[b].item()) - S + 1 + s window_valid = min(WIN, seq_used) cmp_valid = max(seq_used - window_valid, 0) gathered = [] @@ -670,7 +681,10 @@ def golden_sparse_attn(tensors): tensors["attn_out"][:] = out.to(torch.bfloat16) -def build_tensor_specs(compress_ratio: int = DEFAULT_COMPRESS_RATIO): +def build_tensor_specs( + compress_ratio: int = DEFAULT_COMPRESS_RATIO, + causal_regression_fixture: bool = False, +): """Build deterministic demo tensors for the merged standalone harness.""" import torch from golden import TensorSpec @@ -686,11 +700,17 @@ def seeded_uniform(shape, seed): def init_q(): """Initialize the query tensor used by the decode attention stage.""" - return seeded_uniform((T, H, HEAD_DIM), 1) + q = seeded_uniform((T, H, HEAD_DIM), 1) + if causal_regression_fixture: + q[0].fill_(1.0) + return q def init_ori_kv(): """Initialize the sliding-window KV cache pages.""" - return seeded_uniform((ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM), 2) + kv = seeded_uniform((ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM), 2) + if causal_regression_fixture: + kv[0, WIN - 1, 0].fill_(8.0) + return kv def init_cmp_kv(): """Initialize the compressed-cache KV pages.""" @@ -721,13 +741,14 @@ def init_cmp_sparse_indices(): win_part = torch.arange(WIN, dtype=torch.int32).unsqueeze(0).expand(T, -1) cmp_part = torch.full((T, IDX_TOPK), -1, dtype=torch.int32) cmp_part[:, :cmp_valid] = (torch.arange(cmp_valid, dtype=torch.int32) + WIN).unsqueeze(0).expand(T, -1) - return torch.cat([win_part, cmp_part], dim=-1).contiguous() + indices = torch.cat([win_part, cmp_part], dim=-1).contiguous() + if causal_regression_fixture: + indices[0, WIN - 1] = WIN - 1 + return indices def init_seqused_kv(): """Expose the demo sequence-used length that matches the chosen ratio mode.""" - token_s = torch.arange(S, dtype=torch.int32) - seq = torch.clamp(sparse_k - (S - 1 - token_s), min=1) - return seq.expand(B, S).clone() + return torch.full((B,), sparse_k, dtype=torch.int32) def init_cos(): """Build the split-half cosine table used by the inverse-RoPE reference.""" @@ -778,7 +799,7 @@ def init_wo_b_scale(): TensorSpec("cmp_block_table", [B, CMP_MAX_BLOCKS], torch.int32, init_value=init_cmp_block_table), TensorSpec("cmp_sparse_indices", [T, TOPK], torch.int32, init_value=init_cmp_sparse_indices), TensorSpec("attn_sink", [H], torch.float32, init_value=init_attn_sink), - TensorSpec("seqused_kv", [B, S], torch.int32, init_value=init_seqused_kv), + TensorSpec("seqused_kv", [B], torch.int32, init_value=init_seqused_kv), TensorSpec("freqs_cos", [T, ROPE_DIM], torch.bfloat16, init_value=init_cos), TensorSpec("freqs_sin", [T, ROPE_DIM], torch.bfloat16, init_value=init_sin), TensorSpec("even_select_local", [ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK], torch.bfloat16, init_value=init_even_select_local), @@ -800,13 +821,15 @@ def init_wo_b_scale(): parser.add_argument("-d", "--device", type=int, default=0) parser.add_argument("--compress-ratio", type=int, default=DEFAULT_COMPRESS_RATIO, choices=list(SUPPORTED_COMPRESS_RATIOS)) + parser.add_argument("--causal-regression-fixture", action="store_true", default=False, + help="Amplify the S=2 future-window-slot regression; use with --compress-ratio 0.") parser.add_argument("--enable-l2-swimlane", action="store_true", default=False) parser.add_argument("--enable-pmu", nargs="?", const=2, default=0, type=int, choices=[0, 1, 2, 4]) args = parser.parse_args() result = run_jit( fn=sparse_attn_test, - specs=build_tensor_specs(args.compress_ratio), + specs=build_tensor_specs(args.compress_ratio, args.causal_regression_fixture), golden_fn=golden_sparse_attn, runtime_cfg=dict( platform=args.platform,