Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/deepseek/v4/attention_csa.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import pypto.language as pl

from config import DEMO as M, DECODE_BATCH, DECODE_SEQ, BLOCK_SIZE, INT8_SCALE_MAX, INT8_AMAX_EPS
from config import FLASH as M, DECODE_BATCH, DECODE_SEQ, BLOCK_SIZE, INT8_SCALE_MAX, INT8_AMAX_EPS
from compressor_ratio4 import compressor
from hc_post import hc_post
from hc_pre import hc_pre
Expand Down
6 changes: 3 additions & 3 deletions models/deepseek/v4/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pypto.language as pl

from config import FLASH as M, DECODE_BATCH, DECODE_SEQ, FP32_NEG_INF, INT8_SCALE_MAX, INT8_AMAX_EPS
from indexer_compressor import compressor
from indexer_compressor import indexer_compressor

# model config
B = DECODE_BATCH
Expand All @@ -32,7 +32,7 @@

# kernel-local
COMPRESS_RATIO = 4 # the indexer only runs on ratio-4 layers
IDX_TOPK = 16 # standalone-test scale; model value is M.index_topk (512 flash / 1024 pro)
IDX_TOPK = M.index_topk


INNER_ROTATE = True
Expand Down Expand Up @@ -238,7 +238,7 @@ def indexer(
weights_scale = pl.mul(weights_acc, WEIGHTS_SCALE)
weights = pl.assemble(weights, weights_scale, [0, h0])

inner_kv, inner_kv_state, inner_score_state, idx_kv_cache = compressor(
inner_kv, inner_kv_state, inner_score_state, idx_kv_cache = indexer_compressor(
x,
inner_kv,
inner_kv_state,
Expand Down
4 changes: 2 additions & 2 deletions models/deepseek/v4/indexer_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@


@pl.jit.inline
def compressor(
def indexer_compressor(
x: pl.Tensor[[B, S, D], pl.BF16],
kv: pl.Tensor[[B, S, HEAD_DIM], pl.FP32],
kv_state: pl.Tensor[[B, STATE_LEN, OUT_DIM], pl.FP32],
Expand Down Expand Up @@ -290,7 +290,7 @@ def compressor_test(
start_pos: pl.Scalar[pl.INT32],
rotate: pl.Scalar[pl.BOOL],
):
kv, kv_state, score_state, kv_cache = compressor(
kv, kv_state, score_state, kv_cache = indexer_compressor(
x, kv, kv_state, score_state, wkv, wgate, ape, norm_w, cos, sin, even_select, odd_select, hadamard, kv_cache, start_pos, rotate
)
return kv, kv_state, score_state, kv_cache
Expand Down
15 changes: 8 additions & 7 deletions models/deepseek/v4/sparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,12 @@ def sparse_attn(
[window_valid + kk, 0],
)

for h in pl.parallel(0, H, 1):
attn_head_row = b * H + h
for h0 in pl.parallel(0, H, MATMUL_ROW_PAD):
attn_head_row = b * H + h0
with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_sparse_attn_init"):
q_batch = pl.col_expand(
pl.full([MATMUL_ROW_PAD, HEAD_DIM], dtype=pl.FP32, value=0.0),
pl.cast(q_flat[attn_head_row : attn_head_row + 1, 0 : HEAD_DIM], target_type=pl.FP32),
q_batch = pl.cast(
q_flat[attn_head_row : attn_head_row + MATMUL_ROW_PAD, 0 : HEAD_DIM],
target_type=pl.FP32,
)

kv_batch = pl.col_expand(
Expand Down Expand Up @@ -212,11 +212,12 @@ def sparse_attn(
mi = mi_new

with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_sparse_attn_norm"):
sink_tile = pl.add(pl.sub(mi, mi), pl.read(attn_sink, [h]))
sink_bias = pl.reshape(attn_sink[h0 : h0 + MATMUL_ROW_PAD], [MATMUL_ROW_PAD, 1])
sink_tile = pl.add(pl.sub(mi, mi), sink_bias)
denom = pl.add(li, pl.exp(pl.sub(sink_tile, mi)))
oi_out = pl.row_expand_div(oi, denom)
attn_stage_row = pl.cast(
oi_out[0 : 1, 0 : HEAD_DIM],
oi_out[0 : MATMUL_ROW_PAD, 0 : HEAD_DIM],
target_type=pl.BF16,
)
attn_stage = pl.assemble(
Expand Down
Loading