Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 25 additions & 24 deletions models/deepseek/v3_2/deepseek_v3_2_decode_front.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -199,13 +199,14 @@ def deepseek_v3_2_decode_front_scope1234(
for b in pl.parallel(BATCH):
with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope"):
ctx_len = pl.read(seq_lens, [b])
cos_lo = rope_cos[ctx_len - 1, 0 : QK_ROPE_HEAD_DIM // 2]
cos_hi = rope_cos[ctx_len - 1, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
sin_lo = rope_sin[ctx_len - 1, 0 : QK_ROPE_HEAD_DIM // 2]
sin_hi = rope_sin[ctx_len - 1, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
pos = ctx_len - 1
cos_lo = rope_cos[pos : pos + 1, 0 : QK_ROPE_HEAD_DIM // 2]
cos_hi = rope_cos[pos : pos + 1, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
sin_lo = rope_sin[pos : pos + 1, 0 : QK_ROPE_HEAD_DIM // 2]
sin_hi = rope_sin[pos : pos + 1, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
for q_col in pl.range(QK_NOPE_HEAD_DIM, NUM_HEADS * QK_HEAD_DIM, QK_HEAD_DIM):
q_lo = pl.cast(q_proj[b, q_col : q_col + QK_ROPE_HEAD_DIM // 2], target_type=pl.FP32)
q_hi = pl.cast(q_proj[b, q_col + QK_ROPE_HEAD_DIM // 2 : q_col + QK_ROPE_HEAD_DIM], target_type=pl.FP32)
q_lo = pl.cast(q_proj[b : b + 1, q_col : q_col + QK_ROPE_HEAD_DIM // 2], target_type=pl.FP32)
q_hi = pl.cast(q_proj[b : b + 1, q_col + QK_ROPE_HEAD_DIM // 2 : q_col + QK_ROPE_HEAD_DIM], target_type=pl.FP32)
q_rot_lo = pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo))
q_rot_hi = pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi))
q_proj = pl.assemble(q_proj, pl.cast(q_rot_lo, target_type=pl.BF16), [b, q_col])
Expand All @@ -230,13 +231,13 @@ def deepseek_v3_2_decode_front_scope1234(
ctx_len = pl.read(seq_lens, [b])
pos = ctx_len - 1
cache_row = b * MAX_SEQ + pos
cos_lo = rope_cos[pos, 0 : QK_ROPE_HEAD_DIM // 2]
cos_hi = rope_cos[pos, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
sin_lo = rope_sin[pos, 0 : QK_ROPE_HEAD_DIM // 2]
sin_hi = rope_sin[pos, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
kv_normed_row = kv_normed_out[b, 0 : KV_LORA_RANK]
pe_lo = pl.cast(kv_a_out[b, KV_LORA_RANK : KV_LORA_RANK + QK_ROPE_HEAD_DIM // 2], target_type=pl.FP32)
pe_hi = pl.cast(kv_a_out[b, KV_LORA_RANK + QK_ROPE_HEAD_DIM // 2 : KV_LORA_RANK + QK_ROPE_HEAD_DIM], target_type=pl.FP32)
cos_lo = rope_cos[pos : pos + 1, 0 : QK_ROPE_HEAD_DIM // 2]
cos_hi = rope_cos[pos : pos + 1, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
sin_lo = rope_sin[pos : pos + 1, 0 : QK_ROPE_HEAD_DIM // 2]
sin_hi = rope_sin[pos : pos + 1, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
kv_normed_row = kv_normed_out[b : b + 1, 0 : KV_LORA_RANK]
pe_lo = pl.cast(kv_a_out[b : b + 1, KV_LORA_RANK : KV_LORA_RANK + QK_ROPE_HEAD_DIM // 2], target_type=pl.FP32)
pe_hi = pl.cast(kv_a_out[b : b + 1, KV_LORA_RANK + QK_ROPE_HEAD_DIM // 2 : KV_LORA_RANK + QK_ROPE_HEAD_DIM], target_type=pl.FP32)
pe_rot_lo = pl.sub(pl.col_expand_mul(pe_lo, cos_lo), pl.col_expand_mul(pe_hi, sin_lo))
pe_rot_hi = pl.add(pl.col_expand_mul(pe_hi, cos_hi), pl.col_expand_mul(pe_lo, sin_hi))
kv_cache = pl.assemble(kv_cache, kv_normed_row, [cache_row, 0])
Expand Down Expand Up @@ -303,19 +304,19 @@ def deepseek_v3_2_decode_front_scope1234(
for b in pl.parallel(BATCH):
with pl.at(level=pl.Level.CORE_GROUP, name_hint="s2_idx_rope"):
pos = pl.read(seq_lens, [b]) - 1
cos_lo = rope_cos[pos, 0 : QK_ROPE_HEAD_DIM // 2]
cos_hi = rope_cos[pos, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
sin_lo = rope_sin[pos, 0 : QK_ROPE_HEAD_DIM // 2]
sin_hi = rope_sin[pos, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
cos_lo = rope_cos[pos : pos + 1, 0 : QK_ROPE_HEAD_DIM // 2]
cos_hi = rope_cos[pos : pos + 1, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
sin_lo = rope_sin[pos : pos + 1, 0 : QK_ROPE_HEAD_DIM // 2]
sin_hi = rope_sin[pos : pos + 1, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM]
for q_col in pl.range(0, INDEX_Q_OUT, INDEX_HEAD_DIM):
s2_q_lo = pl.cast(q_idx_full[b, q_col : q_col + QK_ROPE_HEAD_DIM // 2], target_type=pl.FP32)
s2_q_hi = pl.cast(q_idx_full[b, q_col + QK_ROPE_HEAD_DIM // 2 : q_col + QK_ROPE_HEAD_DIM], target_type=pl.FP32)
s2_q_lo = pl.cast(q_idx_full[b : b + 1, q_col : q_col + QK_ROPE_HEAD_DIM // 2], target_type=pl.FP32)
s2_q_hi = pl.cast(q_idx_full[b : b + 1, q_col + QK_ROPE_HEAD_DIM // 2 : q_col + QK_ROPE_HEAD_DIM], target_type=pl.FP32)
s2_q_rot_lo = pl.sub(pl.col_expand_mul(s2_q_lo, cos_lo), pl.col_expand_mul(s2_q_hi, sin_lo))
s2_q_rot_hi = pl.add(pl.col_expand_mul(s2_q_hi, cos_hi), pl.col_expand_mul(s2_q_lo, sin_hi))
q_idx_full = pl.assemble(q_idx_full, pl.cast(s2_q_rot_lo, target_type=pl.BF16), [b, q_col])
q_idx_full = pl.assemble(q_idx_full, pl.cast(s2_q_rot_hi, target_type=pl.BF16), [b, q_col + QK_ROPE_HEAD_DIM // 2])
s2_k_lo = pl.cast(k_idx[b, 0 : QK_ROPE_HEAD_DIM // 2], target_type=pl.FP32)
s2_k_hi = pl.cast(k_idx[b, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM], target_type=pl.FP32)
s2_k_lo = pl.cast(k_idx[b : b + 1, 0 : QK_ROPE_HEAD_DIM // 2], target_type=pl.FP32)
s2_k_hi = pl.cast(k_idx[b : b + 1, QK_ROPE_HEAD_DIM // 2 : QK_ROPE_HEAD_DIM], target_type=pl.FP32)
s2_k_rot_lo = pl.sub(pl.col_expand_mul(s2_k_lo, cos_lo), pl.col_expand_mul(s2_k_hi, sin_lo))
s2_k_rot_hi = pl.add(pl.col_expand_mul(s2_k_hi, cos_hi), pl.col_expand_mul(s2_k_lo, sin_hi))
k_idx = pl.assemble(k_idx, pl.cast(s2_k_rot_lo, target_type=pl.BF16), [b, 0])
Expand Down Expand Up @@ -435,7 +436,7 @@ def deepseek_v3_2_decode_front_scope1234(

# Stage 3.3: Run sort32 + mrgsort.
with pl.at(level=pl.Level.CORE_GROUP, name_hint="s3_sort"):
s3_score_row = scores[b, :]
s3_score_row = scores[b : b + 1, :]
idx_init = pl.tensor.arange(0, [1, SORT_LEN], dtype=pl.UINT32)
s3_sorted_t = pl.tensor.sort32(s3_score_row, idx_init)
s3_sorted_t = pl.tensor.mrgsort(s3_sorted_t, block_len=64)
Expand Down Expand Up @@ -472,7 +473,7 @@ def deepseek_v3_2_decode_front_scope1234(

# Stage 4.1: Load q_pe and project q_nope into latent space.
with pl.at(level=pl.Level.CORE_GROUP, name_hint="s4_q_pe_load"):
q_pe = pl.cast(q_proj[b, q_col + QK_NOPE_HEAD_DIM : q_col + QK_HEAD_DIM], target_type=pl.FP32)
q_pe = pl.cast(q_proj[b : b + 1, q_col + QK_NOPE_HEAD_DIM : q_col + QK_HEAD_DIM], target_type=pl.FP32)
q_pe_batch = pl.col_expand(
pl.full([MATMUL_ROW_PAD, QK_ROPE_HEAD_DIM], dtype=pl.FP32, value=0.0),
q_pe,
Expand All @@ -481,7 +482,7 @@ def deepseek_v3_2_decode_front_scope1234(
pl.full([MATMUL_ROW_PAD, QK_NOPE_HEAD_DIM], dtype=pl.FP32, value=0.0),
target_type=pl.BF16,
)
q_nope_padded = pl.col_expand(q_nope_padded, q_proj[b, q_col : q_col + QK_NOPE_HEAD_DIM])
q_nope_padded = pl.col_expand(q_nope_padded, q_proj[b : b + 1, q_col : q_col + QK_NOPE_HEAD_DIM])
q_nope_latent_batch = pl.full([MATMUL_ROW_PAD, KV_LORA_RANK], dtype=pl.FP32, value=0.0)

# Stage 4.2: Project q_nope to latent space chunk-by-chunk.
Expand Down
30 changes: 23 additions & 7 deletions models/deepseek/v4/sparse_attn.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -129,6 +129,8 @@ def sparse_attn(
o_proj_odd = pl.create_tensor([T * H, HALF_ROPE], dtype=pl.FP32)
rope_even = pl.create_tensor([T * H, HALF_ROPE], dtype=pl.BF16)
rope_odd = pl.create_tensor([T * H, HALF_ROPE], dtype=pl.BF16)
rope_even_interleave_buf = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32)
rope_odd_interleave_buf = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32)
o_rope_interleave = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.BF16)
o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16)

Expand Down Expand Up @@ -254,7 +256,7 @@ def sparse_attn(
rope_even = pl.assemble(rope_even, pl.cast(rope_even_acc, target_type=pl.BF16, mode="rint"), [rope_head_row, 0])
rope_odd = pl.assemble(rope_odd, pl.cast(rope_odd_acc, target_type=pl.BF16, mode="rint"), [rope_head_row, 0])

with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_rope_assemble"):
with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_rope_assemble_matmul"):
for r0 in pl.range(0, HALF_ROPE, ROPE_CHUNK):
rope_even_chunk = rope_even[rope_head_row : rope_head_row + H, r0 : r0 + ROPE_CHUNK]
rope_odd_chunk = rope_odd[rope_head_row : rope_head_row + H, r0 : r0 + ROPE_CHUNK]
Expand All @@ -270,16 +272,30 @@ def sparse_attn(
b_trans=True,
out_dtype=pl.FP32,
)
rope_chunk = pl.cast(
pl.add(rope_even_interleave, rope_odd_interleave),
target_type=pl.BF16,
rope_even_interleave_buf = pl.assemble(
rope_even_interleave_buf,
rope_even_interleave,
[rope_head_row, 2 * r0],
)
o_rope_interleave = pl.assemble(
o_rope_interleave,
rope_chunk,
rope_odd_interleave_buf = pl.assemble(
rope_odd_interleave_buf,
rope_odd_interleave,
[rope_head_row, 2 * r0],
)

with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_rope_assemble_combine"):
rope_even_tile = rope_even_interleave_buf[rope_head_row : rope_head_row + H, 0 : ROPE_DIM]
rope_odd_tile = rope_odd_interleave_buf[rope_head_row : rope_head_row + H, 0 : ROPE_DIM]
rope_full = pl.cast(
pl.add(rope_even_tile, rope_odd_tile),
target_type=pl.BF16,
)
o_rope_interleave = pl.assemble(
o_rope_interleave,
rope_full,
[rope_head_row, 0],
)

for b in pl.range(B):
for h in pl.parallel(0, H, 1):
pack_head_row = b * H + h
Expand Down
8 changes: 4 additions & 4 deletions models/qwen3/32b/qwen3_32b_decode.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -168,10 +168,10 @@ def qwen3_decode(
ctx_len = pl.read(seq_lens, [b])
pos = ctx_len - 1
ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE
cos_lo = rope_cos[pos, 0 : HALF_DIM]
cos_hi = rope_cos[pos, HALF_DIM : HEAD_DIM]
sin_lo = rope_sin[pos, 0 : HALF_DIM]
sin_hi = rope_sin[pos, HALF_DIM : HEAD_DIM]
cos_lo = rope_cos[pos : pos + 1, 0 : HALF_DIM]
cos_hi = rope_cos[pos : pos + 1, HALF_DIM : HEAD_DIM]
sin_lo = rope_sin[pos : pos + 1, 0 : HALF_DIM]
sin_hi = rope_sin[pos : pos + 1, HALF_DIM : HEAD_DIM]

# Stage 1: K RoPE + cache update + V cache + Q RoPE + pad.
with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_kv_cache"):
Expand Down
Loading