From 7735545cf5c096b5819fb038f84bd323ddae2c40 Mon Sep 17 00:00:00 2001
From: HighCloud
Date: Tue, 12 May 2026 15:24:45 +0800
Subject: [PATCH 1/2] Update: adapt DeepSeek attention fixtures

- Pass sparse attention local RoPE selector tensors through the SWA and HCA examples
- Initialize KV cache fixtures with deterministic non-zero data
- Align SWA and HCA decode precision tolerances after NPU validation

---
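Note: the new local selector tensors are one-hot matrices, so the chunked
RoPE interleave becomes plain matmuls rather than a strided gather/scatter.
A minimal standalone sketch of the idea (plain PyTorch outside the pl
kernel; the batch size of 4 and the x_even/x_odd names are illustrative
only, and the kernel may equally use the transposed orientation to gather
instead of scatter):

    import torch

    SPARSE_ROPE_CHUNK = 16
    even = torch.zeros(2 * SPARSE_ROPE_CHUNK, SPARSE_ROPE_CHUNK)
    odd = torch.zeros(2 * SPARSE_ROPE_CHUNK, SPARSE_ROPE_CHUNK)
    for i in range(SPARSE_ROPE_CHUNK):
        even[2 * i, i] = 1      # chunk column i -> even slot 2*i
        odd[2 * i + 1, i] = 1   # chunk column i -> odd slot 2*i + 1

    x_even = torch.randn(4, SPARSE_ROPE_CHUNK)  # e.g. rotated even components
    x_odd = torch.randn(4, SPARSE_ROPE_CHUNK)   # e.g. rotated odd components
    interleaved = x_even @ even.T + x_odd @ odd.T  # shape [4, 32]

    # matches direct strided interleaving exactly (0/1 matmul is exact)
    ref = torch.empty(4, 2 * SPARSE_ROPE_CHUNK)
    ref[:, 0::2] = x_even
    ref[:, 1::2] = x_odd
    assert torch.equal(interleaved, ref)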
 models/deepseek/v4/attention_hca.py | 41 ++++++++++++++++++++++++---
 models/deepseek/v4/attention_swa.py | 43 +++++++++++++++++++++++++----
 2 files changed, 75 insertions(+), 9 deletions(-)

diff --git a/models/deepseek/v4/attention_hca.py b/models/deepseek/v4/attention_hca.py
index d3ee957..e113fc1 100644
--- a/models/deepseek/v4/attention_hca.py
+++ b/models/deepseek/v4/attention_hca.py
@@ -32,6 +32,8 @@
 H = 64 # flash:64 pro:128
 HEAD_DIM = 512
 ROPE_HEAD_DIM = 64
+SPARSE_ROPE_CHUNK = 16
+SPARSE_ROPE_INTERLEAVE_CHUNK = 2 * SPARSE_ROPE_CHUNK
 NOPE_HEAD_DIM = HEAD_DIM - ROPE_HEAD_DIM
 Q_LORA = 1024 # flash:1024 pro:1536
 Q_PROJ_OUT_CHUNK = 128
@@ -105,6 +107,8 @@ def attention_hca(
     freqs_sin: pl.Tensor[[MAX_SEQ_LEN, ROPE_HEAD_DIM], pl.BF16],
     even_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
     odd_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
+    even_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
+    odd_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
     # main compressor (rotate=False, head_dim=HEAD_DIM, ratio=128, overlap=False)
     cmp_wkv: pl.Tensor[[MAIN_OUT_DIM, D], pl.BF16],
     cmp_wgate: pl.Tensor[[MAIN_OUT_DIM, D], pl.BF16],
@@ -289,6 +293,8 @@ def attention_hca(
         seqused_kv,
         rope_cos_t,
         rope_sin_t,
+        even_select_local,
+        odd_select_local,
         wo_a,
         wo_b,
         wo_b_scale,
@@ -324,6 +330,8 @@ def attention_hca_test(
     freqs_sin: pl.Tensor[[MAX_SEQ_LEN, ROPE_HEAD_DIM], pl.BF16],
     even_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
     odd_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
+    even_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
+    odd_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
     cmp_wkv: pl.Tensor[[MAIN_OUT_DIM, D], pl.BF16],
     cmp_wgate: pl.Tensor[[MAIN_OUT_DIM, D], pl.BF16],
     cmp_ape: pl.Tensor[[COMPRESS_RATIO, MAIN_OUT_DIM], pl.FP32],
@@ -347,6 +355,7 @@ def attention_hca_test(
         hc_attn_fn, hc_attn_scale, hc_attn_base,
         attn_norm_w, wq_a, wq_b, wq_b_scale, wkv, gamma_cq, gamma_ckv,
         freqs_cos, freqs_sin, even_select_t, odd_select_t,
+        even_select_local, odd_select_local,
         cmp_wkv, cmp_wgate, cmp_ape, cmp_norm_w,
         cmp_kv_state, cmp_score_state,
         kv_cache, ori_block_table, cmp_kv, cmp_block_table,
@@ -485,6 +494,8 @@ def golden_attention_hca(tensors):
         "seqused_kv": tensors["seqused_kv"].view(B),
         "freqs_cos": rope_cos_T,
         "freqs_sin": rope_sin_T,
+        "even_select_local": tensors["even_select_local"],
+        "odd_select_local": tensors["odd_select_local"],
         "wo_a": tensors["wo_a"],
         "wo_b": tensors["wo_b"],
         "wo_b_scale": tensors["wo_b_scale"],
@@ -563,6 +574,24 @@ def init_odd_select_t():
         for i in range(ROPE_HEAD_DIM // 2):
             m[i, 2 * i + 1] = 1
         return m
+    def init_even_select_local():
+        m = torch.zeros((SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK))
+        for i in range(SPARSE_ROPE_CHUNK):
+            m[2 * i, i] = 1
+        return m
+    def init_odd_select_local():
+        m = torch.zeros((SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK))
+        for i in range(SPARSE_ROPE_CHUNK):
+            m[2 * i + 1, i] = 1
+        return m
+
+    def init_normalized_cache(shape, seed):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        cache = torch.randn(*shape, generator=generator)
+        denom = cache.float().pow(2).mean(dim=-1, keepdim=True).sqrt().clamp_min(EPS)
+        return (cache / denom).to(torch.bfloat16)
+
     def init_cmp_wkv():
         return torch.randn(MAIN_OUT_DIM, D) / D ** 0.5
     def init_cmp_wgate():
@@ -576,9 +605,9 @@ def init_cmp_kv_state():
     def init_cmp_score_state():
         return torch.full((B, MAIN_STATE_LEN, MAIN_OUT_DIM), float("-inf"))
     def init_kv_cache():
-        return torch.zeros(ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM)
+        return init_normalized_cache((ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM), seed=20260512)
     def init_cmp_kv():
-        return torch.zeros(CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM)
+        return init_normalized_cache((CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM), seed=20260513)
 
     def init_ori_block_table():
         tbl = torch.full((B, ORI_MAX_BLOCKS), -1, dtype=torch.int32)
@@ -630,6 +659,8 @@ def init_wo_b():
         TensorSpec("freqs_sin", [MAX_SEQ_LEN, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_freqs_sin),
         TensorSpec("even_select_t", [ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_even_select_t),
         TensorSpec("odd_select_t", [ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_odd_select_t),
+        TensorSpec("even_select_local", [SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], torch.bfloat16, init_value=init_even_select_local),
+        TensorSpec("odd_select_local", [SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], torch.bfloat16, init_value=init_odd_select_local),
         TensorSpec("cmp_wkv", [MAIN_OUT_DIM, D], torch.bfloat16, init_value=init_cmp_wkv),
         TensorSpec("cmp_wgate", [MAIN_OUT_DIM, D], torch.bfloat16, init_value=init_cmp_wgate),
         TensorSpec("cmp_ape", [COMPRESS_RATIO, MAIN_OUT_DIM], torch.float32, init_value=init_cmp_ape),
@@ -666,8 +697,10 @@
         specs=build_tensor_specs(),
         golden_fn=golden_attention_hca,
         config=RunConfig(
-            rtol=3e-3,
-            atol=3e-3,
+            # Random ori/cmp cache fixtures exercise non-zero history values
+            # instead of the previous all-zero cache.
+            rtol=1e-2,
+            atol=1e-2,
             compile=dict(dump_passes=True),
             runtime=dict(
                 platform=args.platform,
diff --git a/models/deepseek/v4/attention_swa.py b/models/deepseek/v4/attention_swa.py
index 0e1cb14..ac2e2d2 100644
--- a/models/deepseek/v4/attention_swa.py
+++ b/models/deepseek/v4/attention_swa.py
@@ -32,6 +32,8 @@
 H = 64 # flash:64 pro:128
 HEAD_DIM = 512
 ROPE_HEAD_DIM = 64
+SPARSE_ROPE_CHUNK = 16
+SPARSE_ROPE_INTERLEAVE_CHUNK = 2 * SPARSE_ROPE_CHUNK
 NOPE_HEAD_DIM = HEAD_DIM - ROPE_HEAD_DIM
 Q_LORA = 1024 # flash:1024 pro:1536
 Q_PROJ_OUT_CHUNK = 128
@@ -64,7 +66,7 @@
 SPARSE_CMP_MAX_BLOCKS = 64
 SPARSE_CMP_BLOCK_NUM = B * SPARSE_CMP_MAX_BLOCKS
 
-START_POS = 3 # default for ScalarSpec; >0 (decode); SWA path has no compression-related constraint
+START_POS = 127 # default for ScalarSpec; full-window decode fixture; SWA has no compression constraint
 
 
 @pl.jit.inline
@@ -86,6 +88,8 @@ def attention_swa(
     freqs_sin: pl.Tensor[[MAX_SEQ_LEN, ROPE_HEAD_DIM], pl.BF16],
     even_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
     odd_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
+    even_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
+    odd_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
     # KV cache (sliding-window only: [0, WIN) ori; no cmp portion)
     kv_cache: pl.Tensor[[BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16],
     block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32],
@@ -193,6 +197,8 @@ def attention_swa(
         seqused_kv,
         rope_cos_t,
         rope_sin_t,
+        even_select_local,
+        odd_select_local,
         wo_a,
         wo_b,
         wo_b_scale,
@@ -230,6 +236,8 @@ def attention_swa_test(
     freqs_sin: pl.Tensor[[MAX_SEQ_LEN, ROPE_HEAD_DIM], pl.BF16],
     even_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
     odd_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
+    even_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
+    odd_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
     # KV cache (sliding-window only: [0, WIN) ori; no cmp portion)
     kv_cache: pl.Tensor[[BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16],
     block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32],
@@ -249,6 +257,7 @@ def attention_swa_test(
         attn_norm_w, wq_a, wq_b, wq_b_scale, wkv, gamma_cq, gamma_ckv,
         freqs_cos, freqs_sin, even_select_t, odd_select_t,
+        even_select_local, odd_select_local,
         kv_cache, block_table,
         attn_sink, seqused_kv,
         wo_a, wo_b, wo_b_scale,
@@ -352,6 +361,8 @@ def golden_attention_swa(tensors):
         "seqused_kv": seqused_kv.view(B),
         "freqs_cos": rope_cos_T,
         "freqs_sin": rope_sin_T,
+        "even_select_local": tensors["even_select_local"],
+        "odd_select_local": tensors["odd_select_local"],
         "wo_a": tensors["wo_a"],
         "wo_b": tensors["wo_b"],
         "wo_b_scale": tensors["wo_b_scale"],
@@ -430,8 +441,26 @@ def init_odd_select_t():
         for i in range(ROPE_HEAD_DIM // 2):
             m[i, 2 * i + 1] = 1
         return m
+    def init_even_select_local():
+        m = torch.zeros((SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK))
+        for i in range(SPARSE_ROPE_CHUNK):
+            m[2 * i, i] = 1
+        return m
+    def init_odd_select_local():
+        m = torch.zeros((SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK))
+        for i in range(SPARSE_ROPE_CHUNK):
+            m[2 * i + 1, i] = 1
+        return m
+
+    def init_normalized_cache(shape, seed):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        cache = torch.randn(*shape, generator=generator)
+        denom = cache.float().pow(2).mean(dim=-1, keepdim=True).sqrt().clamp_min(EPS)
+        return (cache / denom).to(torch.bfloat16)
+
     def init_kv_cache():
-        return torch.zeros(BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM)
+        return init_normalized_cache((BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM), seed=20260512)
 
     def init_block_table():
         tbl = torch.full((B, MAX_BLOCKS), -1, dtype=torch.int32)
@@ -471,6 +500,8 @@ def init_wo_b():
         TensorSpec("freqs_sin", [MAX_SEQ_LEN, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_freqs_sin),
         TensorSpec("even_select_t", [ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_even_select_t),
         TensorSpec("odd_select_t", [ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_odd_select_t),
+        TensorSpec("even_select_local", [SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], torch.bfloat16, init_value=init_even_select_local),
+        TensorSpec("odd_select_local", [SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], torch.bfloat16, init_value=init_odd_select_local),
         TensorSpec("kv_cache", [BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], torch.bfloat16, init_value=init_kv_cache),
         TensorSpec("block_table", [B, MAX_BLOCKS], torch.int32, init_value=init_block_table),
         TensorSpec("attn_sink", [H], torch.float32, init_value=init_attn_sink),
@@ -499,9 +530,11 @@
         specs=build_tensor_specs(),
         golden_fn=golden_attention_swa,
         config=RunConfig(
-            # qkv_proj_rope and sparse_attn both use W8A8/BF16 stages; SWA carries that drift through attention/o_proj.
-            rtol=1.3e-2,
-            atol=1.3e-2,
+            # qkv_proj_rope and sparse_attn both use W8A8/BF16 stages; the
+            # random KV-cache fixture exercises a less diluted attention output
+            # than the previous all-zero cache.
+            rtol=1e-2,
+            atol=1e-2,
             compile=dict(dump_passes=True),
             runtime=dict(
                 platform=args.platform,

From cf59218793a227ea253f54500fc27a80680ea0e0 Mon Sep 17 00:00:00 2001
From: HighCloud
Date: Tue, 12 May 2026 17:51:50 +0800
Subject: [PATCH 2/2] Update: randomize DeepSeek KV cache fixtures

- Remove fixed seeds from SWA and HCA KV cache initialization
- Keep normalized non-zero cache values for decode attention precision coverage

---
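Note: init_normalized_cache now draws an unseeded cache but still pins its
scale, so the fixture stays non-degenerate from run to run without fixed
seeds. A small standalone check of that invariant (plain PyTorch; the EPS
value and the example shape here are stand-ins, not the module's own
constants):

    import torch

    EPS = 1e-6  # stand-in; the module under test defines its own EPS

    def init_normalized_cache(shape):
        cache = torch.randn(*shape)
        denom = cache.float().pow(2).mean(dim=-1, keepdim=True).sqrt().clamp_min(EPS)
        return (cache / denom).to(torch.bfloat16)

    kv = init_normalized_cache((8, 128, 1, 512))  # example block layout
    rms = kv.float().pow(2).mean(dim=-1).sqrt()
    # every cached head vector lands at RMS ~ 1 regardless of the draw, so
    # decode attention always sees realistic history at bf16 scale
    assert torch.allclose(rms, torch.ones_like(rms), atol=3e-2)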
 models/deepseek/v4/attention_hca.py | 10 ++++------
 models/deepseek/v4/attention_swa.py |  8 +++-----
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/models/deepseek/v4/attention_hca.py b/models/deepseek/v4/attention_hca.py
index e113fc1..44b27b6 100644
--- a/models/deepseek/v4/attention_hca.py
+++ b/models/deepseek/v4/attention_hca.py
@@ -585,10 +585,8 @@ def init_odd_select_local():
             m[2 * i + 1, i] = 1
         return m
 
-    def init_normalized_cache(shape, seed):
-        generator = torch.Generator()
-        generator.manual_seed(seed)
-        cache = torch.randn(*shape, generator=generator)
+    def init_normalized_cache(shape):
+        cache = torch.randn(*shape)
         denom = cache.float().pow(2).mean(dim=-1, keepdim=True).sqrt().clamp_min(EPS)
         return (cache / denom).to(torch.bfloat16)
 
@@ -605,9 +603,9 @@ def init_cmp_kv_state():
     def init_cmp_score_state():
         return torch.full((B, MAIN_STATE_LEN, MAIN_OUT_DIM), float("-inf"))
     def init_kv_cache():
-        return init_normalized_cache((ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM), seed=20260512)
+        return init_normalized_cache((ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM))
     def init_cmp_kv():
-        return init_normalized_cache((CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM), seed=20260513)
+        return init_normalized_cache((CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM))
 
     def init_ori_block_table():
         tbl = torch.full((B, ORI_MAX_BLOCKS), -1, dtype=torch.int32)
diff --git a/models/deepseek/v4/attention_swa.py b/models/deepseek/v4/attention_swa.py
index ac2e2d2..47f6fbf 100644
--- a/models/deepseek/v4/attention_swa.py
+++ b/models/deepseek/v4/attention_swa.py
@@ -452,15 +452,13 @@ def init_odd_select_local():
             m[2 * i + 1, i] = 1
         return m
 
-    def init_normalized_cache(shape, seed):
-        generator = torch.Generator()
-        generator.manual_seed(seed)
-        cache = torch.randn(*shape, generator=generator)
+    def init_normalized_cache(shape):
+        cache = torch.randn(*shape)
         denom = cache.float().pow(2).mean(dim=-1, keepdim=True).sqrt().clamp_min(EPS)
         return (cache / denom).to(torch.bfloat16)
 
     def init_kv_cache():
-        return init_normalized_cache((BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM), seed=20260512)
+        return init_normalized_cache((BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM))
 
     def init_block_table():
         tbl = torch.full((B, MAX_BLOCKS), -1, dtype=torch.int32)