Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions models/deepseek/v4/decode_sparse_attn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) PyPTO Contributors.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
Expand Down Expand Up @@ -659,11 +659,32 @@

kv_b = torch.stack(gathered, dim=0)
q_t = q[t]
scores = (q_t @ kv_b.T) * SOFTMAX_SCALE
score_max = scores.max(dim=-1, keepdim=True).values
exp_scores = torch.exp(scores - score_max)
oi_num = exp_scores @ kv_b
li = exp_scores.sum(dim=-1, keepdim=True)

block_mi = []
block_li = []
block_oi = []
for tile_start in range(0, kv_b.shape[0], SPARSE_ATTN_TILE):
kv_tile = kv_b[tile_start:tile_start + SPARSE_ATTN_TILE]
scores = (q_t @ kv_tile.T) * SOFTMAX_SCALE
mi = scores.max(dim=-1, keepdim=True).values
exp_scores = torch.exp(scores - mi).to(torch.bfloat16).float()
li = exp_scores.sum(dim=-1, keepdim=True)
oi = exp_scores @ kv_tile.to(torch.bfloat16).float()
block_mi.append(mi)
block_li.append(li)
block_oi.append(oi)
Comment on lines +663 to +675
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To fully align the golden path with the optimized runtime implementation, the inputs to the QK matmul (q_t and kv_tile) should also be cast to bfloat16 to simulate the precision loss that occurs in the JIT kernel (where q and sparse_kv are pl.BF16). Currently, only the exponential weights and the PV matmul inputs are being cast.

Suggested change
block_mi = []
block_li = []
block_oi = []
for tile_start in range(0, kv_b.shape[0], SPARSE_ATTN_TILE):
kv_tile = kv_b[tile_start:tile_start + SPARSE_ATTN_TILE]
scores = (q_t @ kv_tile.T) * SOFTMAX_SCALE
mi = scores.max(dim=-1, keepdim=True).values
exp_scores = torch.exp(scores - mi).to(torch.bfloat16).float()
li = exp_scores.sum(dim=-1, keepdim=True)
oi = exp_scores @ kv_tile.to(torch.bfloat16).float()
block_mi.append(mi)
block_li.append(li)
block_oi.append(oi)
q_t = q_t.to(torch.bfloat16).float()
block_mi = []
block_li = []
block_oi = []
for tile_start in range(0, kv_b.shape[0], SPARSE_ATTN_TILE):
kv_tile = kv_b[tile_start:tile_start + SPARSE_ATTN_TILE].to(torch.bfloat16).float()
scores = (q_t @ kv_tile.T) * SOFTMAX_SCALE
mi = scores.max(dim=-1, keepdim=True).values
exp_scores = torch.exp(scores - mi).to(torch.bfloat16).float()
li = exp_scores.sum(dim=-1, keepdim=True)
oi = exp_scores @ kv_tile
block_mi.append(mi)
block_li.append(li)
block_oi.append(oi)


score_max = block_mi[0]
li = block_li[0]
oi_num = block_oi[0]
for mi_cur, li_cur, oi_cur in zip(block_mi[1:], block_li[1:], block_oi[1:]):
score_max_new = torch.maximum(score_max, mi_cur)
alpha = torch.exp(score_max - score_max_new)
beta = torch.exp(mi_cur - score_max_new)
li = alpha * li + beta * li_cur
oi_num = alpha * oi_num + beta * oi_cur
score_max = score_max_new

denom = li + torch.exp(attn_sink.unsqueeze(-1) - score_max)
o[t] = oi_num / denom

Expand Down
2 changes: 1 addition & 1 deletion models/deepseek/v4/qkv_proj_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def qkv_proj_rope(
d0 = h0 + db * HEAD_CHUNK
q_head_chunk = q_proj_fp32[:, d0 : d0 + HEAD_CHUNK]
q_head_sq_sum = pl.add(q_head_sq_sum, pl.reshape(pl.row_sum(pl.mul(q_head_chunk, q_head_chunk)), [1, T]))
q_head_inv_rms = pl.rsqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))
q_head_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS)))
q_head_inv_rms_t = pl.reshape(q_head_inv_rms, [T, 1])
q_head_inv_rms_all[h : h + 1, :] = q_head_inv_rms

Expand Down
Loading