Fix DeepSeek decode attention precision#361
Conversation
|
Caution Review failedThe pull request is closed. ℹ️ Recent review info⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (2)
📝 WalkthroughWalkthroughThis PR refines the DeepSeek V4 model implementation in two independent ways: the sparse attention golden reference now computes softmax via tiled/blocked online merge across KV rows to align with kernel behavior, and the Q projection per-head RMSNorm switches from ChangesDeepSeek V4 Attention and Normalization Refinements
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly related issues
Possibly related PRs
Suggested labels
Poem
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request updates the golden_sparse_attn function to implement a tiled approach with online softmax reduction, better reflecting optimized runtime behavior. Additionally, it modifies the RMS calculation in qkv_proj_rope.py by replacing pl.rsqrt with pl.recip(pl.sqrt). A review comment suggests casting the QK matmul inputs to bfloat16 in the golden path to accurately simulate the precision loss occurring in the JIT kernel.
| block_mi = [] | ||
| block_li = [] | ||
| block_oi = [] | ||
| for tile_start in range(0, kv_b.shape[0], SPARSE_ATTN_TILE): | ||
| kv_tile = kv_b[tile_start:tile_start + SPARSE_ATTN_TILE] | ||
| scores = (q_t @ kv_tile.T) * SOFTMAX_SCALE | ||
| mi = scores.max(dim=-1, keepdim=True).values | ||
| exp_scores = torch.exp(scores - mi).to(torch.bfloat16).float() | ||
| li = exp_scores.sum(dim=-1, keepdim=True) | ||
| oi = exp_scores @ kv_tile.to(torch.bfloat16).float() | ||
| block_mi.append(mi) | ||
| block_li.append(li) | ||
| block_oi.append(oi) |
There was a problem hiding this comment.
To fully align the golden path with the optimized runtime implementation, the inputs to the QK matmul (q_t and kv_tile) should also be cast to bfloat16 to simulate the precision loss that occurs in the JIT kernel (where q and sparse_kv are pl.BF16). Currently, only the exponential weights and the PV matmul inputs are being cast.
| block_mi = [] | |
| block_li = [] | |
| block_oi = [] | |
| for tile_start in range(0, kv_b.shape[0], SPARSE_ATTN_TILE): | |
| kv_tile = kv_b[tile_start:tile_start + SPARSE_ATTN_TILE] | |
| scores = (q_t @ kv_tile.T) * SOFTMAX_SCALE | |
| mi = scores.max(dim=-1, keepdim=True).values | |
| exp_scores = torch.exp(scores - mi).to(torch.bfloat16).float() | |
| li = exp_scores.sum(dim=-1, keepdim=True) | |
| oi = exp_scores @ kv_tile.to(torch.bfloat16).float() | |
| block_mi.append(mi) | |
| block_li.append(li) | |
| block_oi.append(oi) | |
| q_t = q_t.to(torch.bfloat16).float() | |
| block_mi = [] | |
| block_li = [] | |
| block_oi = [] | |
| for tile_start in range(0, kv_b.shape[0], SPARSE_ATTN_TILE): | |
| kv_tile = kv_b[tile_start:tile_start + SPARSE_ATTN_TILE].to(torch.bfloat16).float() | |
| scores = (q_t @ kv_tile.T) * SOFTMAX_SCALE | |
| mi = scores.max(dim=-1, keepdim=True).values | |
| exp_scores = torch.exp(scores - mi).to(torch.bfloat16).float() | |
| li = exp_scores.sum(dim=-1, keepdim=True) | |
| oi = exp_scores @ kv_tile | |
| block_mi.append(mi) | |
| block_li.append(li) | |
| block_oi.append(oi) |
Summary
recip(sqrt(...))instead ofrsqrt(...)to avoid a large BF16 exact mismatch increase.Root Cause
pl.rsqrt(...)in the q path is mathematically equivalent topl.recip(pl.sqrt(...)), but it is not numerically equivalent on this backend. In exact comparison, the q output mismatch ratio increased from0.0031%to20.6860%, and the full decode SWA path failed withx_outratio0.9199%.The sparse attention golden path also still modeled the old full-FP32 softmax/PV reference, while the optimized implementation computes tiled softmax blocks and materializes BF16 exp weights before PV.
Validation
decode_attention_swa.pyseed0 on NPU: PASSqkv_proj_rope.pystandalone on NPU: PASSCommand used through
task-submit --device autowithTASK_DEVICEforwarded to the model scripts.