39 changes: 35 additions & 4 deletions models/deepseek/v4/attention_hca.py
@@ -32,6 +32,8 @@
H = 64 # flash:64 pro:128
HEAD_DIM = 512
ROPE_HEAD_DIM = 64
SPARSE_ROPE_CHUNK = 16
SPARSE_ROPE_INTERLEAVE_CHUNK = 2 * SPARSE_ROPE_CHUNK
NOPE_HEAD_DIM = HEAD_DIM - ROPE_HEAD_DIM
Q_LORA = 1024 # flash:1024 pro:1536
Q_PROJ_OUT_CHUNK = 128
@@ -105,6 +107,8 @@ def attention_hca(
freqs_sin: pl.Tensor[[MAX_SEQ_LEN, ROPE_HEAD_DIM], pl.BF16],
even_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
odd_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
even_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
odd_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
# main compressor (rotate=False, head_dim=HEAD_DIM, ratio=128, overlap=False)
cmp_wkv: pl.Tensor[[MAIN_OUT_DIM, D], pl.BF16],
cmp_wgate: pl.Tensor[[MAIN_OUT_DIM, D], pl.BF16],
@@ -289,6 +293,8 @@ def attention_hca(
seqused_kv,
rope_cos_t,
rope_sin_t,
even_select_local,
odd_select_local,
wo_a,
wo_b,
wo_b_scale,
@@ -324,6 +330,8 @@ def attention_hca_test(
freqs_sin: pl.Tensor[[MAX_SEQ_LEN, ROPE_HEAD_DIM], pl.BF16],
even_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
odd_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
even_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
odd_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
cmp_wkv: pl.Tensor[[MAIN_OUT_DIM, D], pl.BF16],
cmp_wgate: pl.Tensor[[MAIN_OUT_DIM, D], pl.BF16],
cmp_ape: pl.Tensor[[COMPRESS_RATIO, MAIN_OUT_DIM], pl.FP32],
@@ -347,6 +355,7 @@ def attention_hca_test(
hc_attn_fn, hc_attn_scale, hc_attn_base,
attn_norm_w, wq_a, wq_b, wq_b_scale, wkv, gamma_cq, gamma_ckv,
freqs_cos, freqs_sin, even_select_t, odd_select_t,
even_select_local, odd_select_local,
cmp_wkv, cmp_wgate, cmp_ape, cmp_norm_w,
cmp_kv_state, cmp_score_state,
kv_cache, ori_block_table, cmp_kv, cmp_block_table,
@@ -485,6 +494,8 @@ def golden_attention_hca(tensors):
"seqused_kv": tensors["seqused_kv"].view(B),
"freqs_cos": rope_cos_T,
"freqs_sin": rope_sin_T,
"even_select_local": tensors["even_select_local"],
"odd_select_local": tensors["odd_select_local"],
"wo_a": tensors["wo_a"],
"wo_b": tensors["wo_b"],
"wo_b_scale": tensors["wo_b_scale"],
@@ -563,6 +574,22 @@ def init_odd_select_t():
for i in range(ROPE_HEAD_DIM // 2):
m[i, 2 * i + 1] = 1
return m
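# One-hot scatter matrices for the sparse-RoPE interleave: chunk element i of a
# SPARSE_ROPE_CHUNK-long half is placed at interleaved position 2*i (even) or
# 2*i + 1 (odd), so even_select_local @ x_even + odd_select_local @ x_odd
# interleaves two chunk-long vectors into SPARSE_ROPE_INTERLEAVE_CHUNK slots.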
def init_even_select_local():
m = torch.zeros((SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK))
for i in range(SPARSE_ROPE_CHUNK):
m[2 * i, i] = 1
return m
def init_odd_select_local():
m = torch.zeros((SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK))
for i in range(SPARSE_ROPE_CHUNK):
m[2 * i + 1, i] = 1
return m

def init_normalized_cache(shape):
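    # Normalize a random cache to unit RMS along the last dim so history tokens
    # carry realistic non-zero values (the previous fixture was all zeros).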
cache = torch.randn(*shape)
denom = cache.float().pow(2).mean(dim=-1, keepdim=True).sqrt().clamp_min(EPS)
return (cache / denom).to(torch.bfloat16)

def init_cmp_wkv():
return torch.randn(MAIN_OUT_DIM, D) / D ** 0.5
def init_cmp_wgate():
@@ -576,9 +603,9 @@ def init_cmp_kv_state():
def init_cmp_score_state():
return torch.full((B, MAIN_STATE_LEN, MAIN_OUT_DIM), float("-inf"))
def init_kv_cache():
return torch.zeros(ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM)
return init_normalized_cache((ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM))
def init_cmp_kv():
return torch.zeros(CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM)
return init_normalized_cache((CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM))

def init_ori_block_table():
tbl = torch.full((B, ORI_MAX_BLOCKS), -1, dtype=torch.int32)
@@ -630,6 +657,8 @@ def init_wo_b():
TensorSpec("freqs_sin", [MAX_SEQ_LEN, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_freqs_sin),
TensorSpec("even_select_t", [ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_even_select_t),
TensorSpec("odd_select_t", [ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_odd_select_t),
TensorSpec("even_select_local", [SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], torch.bfloat16, init_value=init_even_select_local),
TensorSpec("odd_select_local", [SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], torch.bfloat16, init_value=init_odd_select_local),
TensorSpec("cmp_wkv", [MAIN_OUT_DIM, D], torch.bfloat16, init_value=init_cmp_wkv),
TensorSpec("cmp_wgate", [MAIN_OUT_DIM, D], torch.bfloat16, init_value=init_cmp_wgate),
TensorSpec("cmp_ape", [COMPRESS_RATIO, MAIN_OUT_DIM], torch.float32, init_value=init_cmp_ape),
@@ -666,8 +695,10 @@ def init_wo_b():
specs=build_tensor_specs(),
golden_fn=golden_attention_hca,
config=RunConfig(
rtol=3e-3,
atol=3e-3,
# Random ori/cmp cache fixtures exercise non-zero history values
# instead of the previous all-zero cache.
rtol=1e-2,
atol=1e-2,
compile=dict(dump_passes=True),
runtime=dict(
platform=args.platform,
41 changes: 36 additions & 5 deletions models/deepseek/v4/attention_swa.py
@@ -32,6 +32,8 @@
H = 64 # flash:64 pro:128
HEAD_DIM = 512
ROPE_HEAD_DIM = 64
SPARSE_ROPE_CHUNK = 16
SPARSE_ROPE_INTERLEAVE_CHUNK = 2 * SPARSE_ROPE_CHUNK
NOPE_HEAD_DIM = HEAD_DIM - ROPE_HEAD_DIM
Q_LORA = 1024 # flash:1024 pro:1536
Q_PROJ_OUT_CHUNK = 128
@@ -64,7 +66,7 @@
SPARSE_CMP_MAX_BLOCKS = 64
SPARSE_CMP_BLOCK_NUM = B * SPARSE_CMP_MAX_BLOCKS

START_POS = 3 # default for ScalarSpec; >0 (decode); SWA path has no compression-related constraint
START_POS = 127 # default for ScalarSpec; full-window decode fixture; SWA has no compression constraint


@pl.jit.inline
Expand All @@ -86,6 +88,8 @@ def attention_swa(
freqs_sin: pl.Tensor[[MAX_SEQ_LEN, ROPE_HEAD_DIM], pl.BF16],
even_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
odd_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
even_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
odd_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
# KV cache (sliding-window only: [0, WIN) ori; no cmp portion)
kv_cache: pl.Tensor[[BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16],
block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32],
@@ -193,6 +197,8 @@ def attention_swa(
seqused_kv,
rope_cos_t,
rope_sin_t,
even_select_local,
odd_select_local,
wo_a,
wo_b,
wo_b_scale,
@@ -230,6 +236,8 @@ def attention_swa_test(
freqs_sin: pl.Tensor[[MAX_SEQ_LEN, ROPE_HEAD_DIM], pl.BF16],
even_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
odd_select_t: pl.Tensor[[ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], pl.BF16],
even_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
odd_select_local: pl.Tensor[[SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], pl.BF16],
# KV cache (sliding-window only: [0, WIN) ori; no cmp portion)
kv_cache: pl.Tensor[[BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16],
block_table: pl.Tensor[[B, MAX_BLOCKS], pl.INT32],
@@ -249,6 +257,7 @@ def attention_swa_test(
attn_norm_w, wq_a, wq_b, wq_b_scale, wkv,
gamma_cq, gamma_ckv,
freqs_cos, freqs_sin, even_select_t, odd_select_t,
even_select_local, odd_select_local,
kv_cache, block_table,
attn_sink, seqused_kv,
wo_a, wo_b, wo_b_scale,
@@ -352,6 +361,8 @@ def golden_attention_swa(tensors):
"seqused_kv": seqused_kv.view(B),
"freqs_cos": rope_cos_T,
"freqs_sin": rope_sin_T,
"even_select_local": tensors["even_select_local"],
"odd_select_local": tensors["odd_select_local"],
"wo_a": tensors["wo_a"],
"wo_b": tensors["wo_b"],
"wo_b_scale": tensors["wo_b_scale"],
@@ -430,8 +441,24 @@ def init_odd_select_t():
for i in range(ROPE_HEAD_DIM // 2):
m[i, 2 * i + 1] = 1
return m
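# One-hot scatter matrices for the sparse-RoPE interleave: chunk element i of a
# SPARSE_ROPE_CHUNK-long half is placed at interleaved position 2*i (even) or
# 2*i + 1 (odd), so even_select_local @ x_even + odd_select_local @ x_odd
# interleaves two chunk-long vectors into SPARSE_ROPE_INTERLEAVE_CHUNK slots.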
def init_even_select_local():
m = torch.zeros((SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK))
for i in range(SPARSE_ROPE_CHUNK):
m[2 * i, i] = 1
return m
def init_odd_select_local():
m = torch.zeros((SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK))
for i in range(SPARSE_ROPE_CHUNK):
m[2 * i + 1, i] = 1
return m

def init_normalized_cache(shape):
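    # Normalize a random cache to unit RMS along the last dim so history tokens
    # carry realistic non-zero values (the previous fixture was all zeros).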
cache = torch.randn(*shape)
denom = cache.float().pow(2).mean(dim=-1, keepdim=True).sqrt().clamp_min(EPS)
return (cache / denom).to(torch.bfloat16)

def init_kv_cache():
return torch.zeros(BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM)
return init_normalized_cache((BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM))

def init_block_table():
tbl = torch.full((B, MAX_BLOCKS), -1, dtype=torch.int32)
@@ -471,6 +498,8 @@ def init_wo_b():
TensorSpec("freqs_sin", [MAX_SEQ_LEN, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_freqs_sin),
TensorSpec("even_select_t", [ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_even_select_t),
TensorSpec("odd_select_t", [ROPE_HEAD_DIM // 2, ROPE_HEAD_DIM], torch.bfloat16, init_value=init_odd_select_t),
TensorSpec("even_select_local", [SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], torch.bfloat16, init_value=init_even_select_local),
TensorSpec("odd_select_local", [SPARSE_ROPE_INTERLEAVE_CHUNK, SPARSE_ROPE_CHUNK], torch.bfloat16, init_value=init_odd_select_local),
TensorSpec("kv_cache", [BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], torch.bfloat16, init_value=init_kv_cache),
TensorSpec("block_table", [B, MAX_BLOCKS], torch.int32, init_value=init_block_table),
TensorSpec("attn_sink", [H], torch.float32, init_value=init_attn_sink),
@@ -499,9 +528,11 @@ def init_wo_b():
specs=build_tensor_specs(),
golden_fn=golden_attention_swa,
config=RunConfig(
# qkv_proj_rope and sparse_attn both use W8A8/BF16 stages; SWA carries that drift through attention/o_proj.
rtol=1.3e-2,
atol=1.3e-2,
# qkv_proj_rope and sparse_attn both use W8A8/BF16 stages; the
# random KV-cache fixture exercises a less diluted attention output
# than the previous all-zero cache.
rtol=1e-2,
atol=1e-2,
compile=dict(dump_passes=True),
runtime=dict(
platform=args.platform,