From 4e615d9ca80a2c8efe3c490ff73667b7daa69504 Mon Sep 17 00:00:00 2001
From: sjduan
Date: Fri, 15 May 2026 13:11:58 +0800
Subject: [PATCH] Add FLASH config support to DSv4 CSA attention

- Change attention_csa.py to use FLASH config instead of DEMO
- Rename compressor to indexer_compressor in indexer_compressor.py
- Update indexer.py to use indexer_compressor and set IDX_TOPK from config
---
 models/deepseek/v4/attention_csa.py      |  2 +-
 models/deepseek/v4/indexer.py            |  6 +++---
 models/deepseek/v4/indexer_compressor.py |  4 ++--
 models/deepseek/v4/sparse_attn.py        | 15 ++++++++-------
 4 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/models/deepseek/v4/attention_csa.py b/models/deepseek/v4/attention_csa.py
index df30062..9776908 100644
--- a/models/deepseek/v4/attention_csa.py
+++ b/models/deepseek/v4/attention_csa.py
@@ -29,7 +29,7 @@
 
 import pypto.language as pl
 
-from config import DEMO as M, DECODE_BATCH, DECODE_SEQ, BLOCK_SIZE, INT8_SCALE_MAX, INT8_AMAX_EPS
+from config import FLASH as M, DECODE_BATCH, DECODE_SEQ, BLOCK_SIZE, INT8_SCALE_MAX, INT8_AMAX_EPS
 from compressor_ratio4 import compressor
 from hc_post import hc_post
 from hc_pre import hc_pre
diff --git a/models/deepseek/v4/indexer.py b/models/deepseek/v4/indexer.py
index 87a17b6..8f9e01b 100644
--- a/models/deepseek/v4/indexer.py
+++ b/models/deepseek/v4/indexer.py
@@ -14,7 +14,7 @@
 import pypto.language as pl
 
 from config import FLASH as M, DECODE_BATCH, DECODE_SEQ, FP32_NEG_INF, INT8_SCALE_MAX, INT8_AMAX_EPS
-from indexer_compressor import compressor
+from indexer_compressor import indexer_compressor
 
 # model config
 B = DECODE_BATCH
@@ -32,7 +32,7 @@
 
 # kernel-local
 COMPRESS_RATIO = 4 # the indexer only runs on ratio-4 layers
-IDX_TOPK = 16 # standalone-test scale; model value is M.index_topk (512 flash / 1024 pro)
+IDX_TOPK = M.index_topk
 INNER_ROTATE = True
 
 
@@ -238,7 +238,7 @@ def indexer(
         weights_scale = pl.mul(weights_acc, WEIGHTS_SCALE)
         weights = pl.assemble(weights, weights_scale, [0, h0])
 
-    inner_kv, inner_kv_state, inner_score_state, idx_kv_cache = compressor(
+    inner_kv, inner_kv_state, inner_score_state, idx_kv_cache = indexer_compressor(
         x,
         inner_kv,
         inner_kv_state,
diff --git a/models/deepseek/v4/indexer_compressor.py b/models/deepseek/v4/indexer_compressor.py
index ba45001..c082a30 100644
--- a/models/deepseek/v4/indexer_compressor.py
+++ b/models/deepseek/v4/indexer_compressor.py
@@ -54,7 +54,7 @@
 
 
 @pl.jit.inline
-def compressor(
+def indexer_compressor(
     x: pl.Tensor[[B, S, D], pl.BF16],
     kv: pl.Tensor[[B, S, HEAD_DIM], pl.FP32],
     kv_state: pl.Tensor[[B, STATE_LEN, OUT_DIM], pl.FP32],
@@ -290,7 +290,7 @@ def compressor_test(
     start_pos: pl.Scalar[pl.INT32],
     rotate: pl.Scalar[pl.BOOL],
 ):
-    kv, kv_state, score_state, kv_cache = compressor(
+    kv, kv_state, score_state, kv_cache = indexer_compressor(
         x, kv, kv_state, score_state, wkv, wgate, ape, norm_w, cos, sin, even_select, odd_select, hadamard, kv_cache, start_pos, rotate
     )
     return kv, kv_state, score_state, kv_cache
diff --git a/models/deepseek/v4/sparse_attn.py b/models/deepseek/v4/sparse_attn.py
index e8fbc20..54bcd68 100644
--- a/models/deepseek/v4/sparse_attn.py
+++ b/models/deepseek/v4/sparse_attn.py
@@ -176,12 +176,12 @@ def sparse_attn(
                 [window_valid + kk, 0],
             )
 
-        for h in pl.parallel(0, H, 1):
-            attn_head_row = b * H + h
+        for h0 in pl.parallel(0, H, MATMUL_ROW_PAD):
+            attn_head_row = b * H + h0
             with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_sparse_attn_init"):
-                q_batch = pl.col_expand(
-                    pl.full([MATMUL_ROW_PAD, HEAD_DIM], dtype=pl.FP32, value=0.0),
-                    pl.cast(q_flat[attn_head_row : attn_head_row + 1, 0 : HEAD_DIM], target_type=pl.FP32),
+                q_batch = pl.cast(
+                    q_flat[attn_head_row : attn_head_row + MATMUL_ROW_PAD, 0 : HEAD_DIM],
+                    target_type=pl.FP32,
                 )
 
                 kv_batch = pl.col_expand(
@@ -212,11 +212,12 @@
                     mi = mi_new
 
             with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_sparse_attn_norm"):
-                sink_tile = pl.add(pl.sub(mi, mi), pl.read(attn_sink, [h]))
+                sink_bias = pl.reshape(attn_sink[h0 : h0 + MATMUL_ROW_PAD], [MATMUL_ROW_PAD, 1])
+                sink_tile = pl.add(pl.sub(mi, mi), sink_bias)
                 denom = pl.add(li, pl.exp(pl.sub(sink_tile, mi)))
                 oi_out = pl.row_expand_div(oi, denom)
                 attn_stage_row = pl.cast(
-                    oi_out[0 : 1, 0 : HEAD_DIM],
+                    oi_out[0 : MATMUL_ROW_PAD, 0 : HEAD_DIM],
                     target_type=pl.BF16,
                 )
                 attn_stage = pl.assemble(