Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions models/deepseek/v4/attention_csa.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def attention_csa(
cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32],
idx_kv_cache: pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16],
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
wo_b_scale: pl.Tensor[[D], pl.FP32],
Expand Down Expand Up @@ -399,7 +399,7 @@ def attention_csa_test_refresh(
cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32],
idx_kv_cache: pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16],
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
wo_b_scale: pl.Tensor[[D], pl.FP32],
Expand Down Expand Up @@ -524,7 +524,8 @@ def golden_sparse_attn_online(local_tensors):
# Per-token gather + attention; token t belongs to batch t // S.
for t in range(T):
b = t // S
seq_used = int(seqused_kv.view(T)[t].item())
s = t - b * S
seq_used = int(seqused_kv[b].item()) - S + 1 + s
window_valid = min(WIN, seq_used)
cmp_valid = max(seq_used - window_valid, 0)
cmp_topk_valid = min(IDX_TOPK, cmp_valid)
Expand Down Expand Up @@ -723,7 +724,7 @@ def golden_sparse_attn_online(local_tensors):
"cmp_block_table": cmp_block_table,
"cmp_sparse_indices": sparse_topk,
"attn_sink": tensors["attn_sink"],
"seqused_kv": tensors["seqused_kv"].view(B, S),
"seqused_kv": tensors["seqused_kv"],
"freqs_cos": rope_cos_t,
"freqs_sin": rope_sin_t,
"even_select_local": tensors["even_select_local"],
Expand Down Expand Up @@ -928,9 +929,9 @@ def init_cmp_sparse_indices():
return torch.full((T, SPARSE_TOPK), -1, dtype=torch.int32)

def init_seqused_kv():
s = torch.arange(1, S + 1, dtype=torch.int32) + START_POS
seq = torch.where(s <= WIN, s, WIN + s // COMPRESS_RATIO)
return seq.expand(B, S).clone()
seq = START_POS + S
sparse_len = seq if seq <= WIN else WIN + seq // COMPRESS_RATIO
return torch.full((B,), sparse_len, dtype=torch.int32)

def init_wo_a():
return torch.randn(O_GROUPS, O_LORA, O_GROUP_IN) / O_GROUP_IN ** 0.5
Expand Down Expand Up @@ -1014,7 +1015,7 @@ def init_wo_b():
TensorSpec("cmp_block_table", [B, CMP_MAX_BLOCKS], torch.int32, init_value=init_cmp_block_table),
TensorSpec("idx_kv_cache", [B, IDX_KV_LEN, IDX_HEAD_DIM], torch.bfloat16, init_value=lambda: shared_idx_kv_cache.clone()),
TensorSpec("attn_sink", [H], torch.float32, init_value=init_attn_sink),
TensorSpec("seqused_kv", [B, S], torch.int32, init_value=init_seqused_kv),
TensorSpec("seqused_kv", [B], torch.int32, init_value=init_seqused_kv),
TensorSpec("wo_a", [O_GROUPS, O_LORA, O_GROUP_IN], torch.bfloat16, init_value=init_wo_a),
TensorSpec("wo_b", [D, O_GROUPS * O_LORA], torch.int8, init_value=lambda: wo_b_i8),
TensorSpec("wo_b_scale", [D], torch.float32, init_value=lambda: wo_b_scale),
Expand Down
16 changes: 8 additions & 8 deletions models/deepseek/v4/attention_hca.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def attention_hca(
cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32],
# sparse_attn
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
# o_proj (fused into sparse_attn)
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
Expand Down Expand Up @@ -352,7 +352,7 @@ def attention_hca_test(
cmp_kv: pl.Tensor[[CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16],
cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32],
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
wo_b_scale: pl.Tensor[[D], pl.FP32],
Expand Down Expand Up @@ -516,7 +516,7 @@ def golden_attention_hca(tensors):
"cmp_block_table": cmp_block_table,
"cmp_sparse_indices": topk_idxs,
"attn_sink": tensors["attn_sink"],
"seqused_kv": tensors["seqused_kv"].view(B, S),
"seqused_kv": tensors["seqused_kv"],
"freqs_cos": rope_cos_T,
"freqs_sin": rope_sin_T,
"even_select_local": tensors["even_select_local"],
Expand Down Expand Up @@ -656,10 +656,10 @@ def init_cmp_block_table():
def init_attn_sink():
return torch.zeros(H)
def init_seqused_kv():
# sparse_attn uses: window_valid = min(WIN, seq_used); cmp_valid = seq_used - window_valid.
s = torch.arange(1, S + 1, dtype=torch.int32) + START_POS
seq = torch.where(s <= WIN, s, WIN + s // COMPRESS_RATIO)
return seq.expand(B, S).clone()
# sparse_attn derives per-token causal lengths from this final sparse length.
seq = START_POS + S
sparse_len = seq if seq <= WIN else WIN + seq // COMPRESS_RATIO
return torch.full((B,), sparse_len, dtype=torch.int32)
def init_wo_a():
return torch.randn(O_GROUPS, O_LORA, O_GROUP_IN) / O_GROUP_IN ** 0.5
def init_wo_b():
Expand Down Expand Up @@ -702,7 +702,7 @@ def init_wo_b():
TensorSpec("cmp_kv", [CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], torch.bfloat16, init_value=init_cmp_kv),
TensorSpec("cmp_block_table", [B, CMP_MAX_BLOCKS], torch.int32, init_value=init_cmp_block_table),
TensorSpec("attn_sink", [H], torch.float32, init_value=init_attn_sink),
TensorSpec("seqused_kv", [B, S], torch.int32, init_value=init_seqused_kv),
TensorSpec("seqused_kv", [B], torch.int32, init_value=init_seqused_kv),
TensorSpec("wo_a", [O_GROUPS, O_LORA, O_GROUP_IN], torch.bfloat16, init_value=init_wo_a),
TensorSpec("wo_b", [D, O_GROUPS * O_LORA], torch.int8, init_value=lambda: wo_b_i8),
TensorSpec("wo_b_scale", [D], torch.float32, init_value=lambda: wo_b_scale),
Expand Down
12 changes: 5 additions & 7 deletions models/deepseek/v4/attention_swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def attention_swa(
block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32],
# sparse_attn
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
# o_proj
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
Expand Down Expand Up @@ -249,7 +249,7 @@ def attention_swa_test(
block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32],
# sparse_attn
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
# o_proj
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
Expand Down Expand Up @@ -367,7 +367,7 @@ def golden_attention_swa(tensors):
"cmp_block_table": cmp_block_table_dummy,
"cmp_sparse_indices": sparse_topk,
"attn_sink": tensors["attn_sink"],
"seqused_kv": seqused_kv.view(B, S),
"seqused_kv": seqused_kv,
"freqs_cos": rope_cos_T,
"freqs_sin": rope_sin_T,
"even_select_local": tensors["even_select_local"],
Expand Down Expand Up @@ -476,9 +476,7 @@ def init_block_table():
def init_attn_sink():
return torch.zeros(H)
def init_seqused_kv():
s = torch.arange(1, S + 1, dtype=torch.int32) + START_POS
seq = torch.minimum(torch.full_like(s, WIN), s)
return seq.expand(B, S).clone()
return torch.full((B,), min(WIN, START_POS + S), dtype=torch.int32)
def init_wo_a():
return torch.randn(O_GROUPS, O_LORA, O_GROUP_IN) / O_GROUP_IN ** 0.5
def init_wo_b():
Expand Down Expand Up @@ -511,7 +509,7 @@ def init_wo_b():
TensorSpec("kv_cache", [BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], torch.bfloat16, init_value=init_kv_cache),
TensorSpec("block_table", [B, MAX_BLOCKS], torch.int32, init_value=init_block_table),
TensorSpec("attn_sink", [H], torch.float32, init_value=init_attn_sink),
TensorSpec("seqused_kv", [B, S], torch.int32, init_value=init_seqused_kv),
TensorSpec("seqused_kv", [B], torch.int32, init_value=init_seqused_kv),
TensorSpec("wo_a", [O_GROUPS, O_LORA, O_GROUP_IN], torch.bfloat16, init_value=init_wo_a),
TensorSpec("wo_b", [D, O_GROUPS * O_LORA], torch.int8, init_value=lambda: wo_b_i8),
TensorSpec("wo_b_scale", [D], torch.float32, init_value=lambda: wo_b_scale),
Expand Down
4 changes: 2 additions & 2 deletions models/deepseek/v4/decode_csa.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def decode_csa(
idx_kv_cache: pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16],
# ---- sparse_attn ----
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
# ---- o_proj weights ----
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
Expand Down Expand Up @@ -263,7 +263,7 @@ def decode_csa_test(
cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32],
idx_kv_cache: pl.Tensor[[B, IDX_KV_LEN, IDX_HEAD_DIM], pl.BF16],
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
wo_b_scale: pl.Tensor[[D], pl.FP32],
Expand Down
4 changes: 2 additions & 2 deletions models/deepseek/v4/decode_hca.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def decode_hca(
cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32],
# ---- sparse_attn ----
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
# ---- o_proj weights ----
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
Expand Down Expand Up @@ -224,7 +224,7 @@ def decode_hca_test(
cmp_kv: pl.Tensor[[CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16],
cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32],
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
wo_b_scale: pl.Tensor[[D], pl.FP32],
Expand Down
4 changes: 2 additions & 2 deletions models/deepseek/v4/decode_swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def decode_swa(
block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32],
# ---- sparse_attn ----
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
# ---- o_proj weights ----
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
Expand Down Expand Up @@ -181,7 +181,7 @@ def decode_swa_test(
kv_cache: pl.Tensor[[BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16],
block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32],
attn_sink: pl.Tensor[[H], pl.FP32],
seqused_kv: pl.Tensor[[B, S], pl.INT32],
seqused_kv: pl.Tensor[[B], pl.INT32],
wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16],
wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8],
wo_b_scale: pl.Tensor[[D], pl.FP32],
Expand Down
7 changes: 4 additions & 3 deletions models/deepseek/v4/deepseek_v4_decode_single_layer.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,16 @@ sub-kernels below; the variant determines whether `compressor` and
║ model.py:533-534, 537-542 ║
║ NOTE: outer loop is `for t in pl.range(T)` — per-query-token attention. ║
║ Token t belongs to batch b = t // S and step s = t % S; each token ║
║ reads seqused_kv[b, s]. Intra-query causal is enforced upstream by ║
║ the topk index set the indexer produces per token. ║
║ derives its sparse length from final seqused_kv[b] as ║
║ `seqused_kv[b] - S + 1 + s`. Intra-query causal is enforced by this ║
║ per-token derived length plus the topk index set. ║
║ ║
║ IN : q [T, H, HEAD_DIM] ║
║ ori_kv (PA) — always ║
║ cmp_kv (PA) — ratio>0 only ║
║ topk_idxs [T, *] — per-token; ratio-dependent, see § below ║
║ attn_sink [H] fp32 ║
║ seqused_kv [B, S] — per-token valid sparse KV length
║ seqused_kv [B] — final valid sparse KV length per batch
║ freqs_cos/sin [T, ROPE_DIM] ║
║ wo_a [O_GROUPS=8, O_LORA=1024, 4096] bf16 (grouped output LoRA) ║
║ wo_b [D=4096, O_GROUPS*O_LORA=8192] int8 ║
Expand Down
Loading
Loading