Skip to content

Fix DeepSeek decode attention precision#361

Merged
zhangqi-chen merged 1 commit into
hw-native-sys:mainfrom
high-cloud:codex/attention-precision-combined
May 23, 2026
Merged

Fix DeepSeek decode attention precision#361
zhangqi-chen merged 1 commit into
hw-native-sys:mainfrom
high-cloud:codex/attention-precision-combined

Conversation

@high-cloud
Copy link
Copy Markdown
Contributor

Summary

  • Align the decode sparse attention golden path with the tiled BF16 exp/PV behavior used by the optimized runtime implementation.
  • Restore the DeepSeek V4 q-path per-head RMS inverse to recip(sqrt(...)) instead of rsqrt(...) to avoid a large BF16 exact mismatch increase.

Root Cause

pl.rsqrt(...) in the q path is mathematically equivalent to pl.recip(pl.sqrt(...)), but it is not numerically equivalent on this backend. In exact comparison, the q output mismatch ratio increased from 0.0031% to 20.6860%, and the full decode SWA path failed with x_out ratio 0.9199%.

The sparse attention golden path also still modeled the old full-FP32 softmax/PV reference, while the optimized implementation computes tiled softmax blocks and materializes BF16 exp weights before PV.

Validation

  • decode_attention_swa.py seed0 on NPU: PASS
  • qkv_proj_rope.py standalone on NPU: PASS

Command used through task-submit --device auto with TASK_DEVICE forwarded to the model scripts.

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 23, 2026

Review Change Stack

Caution

Review failed

The pull request is closed.

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 8a10295e-e2ea-4963-8320-7068fd5f2e34

📥 Commits

Reviewing files that changed from the base of the PR and between be3c794 and 8af446f.

📒 Files selected for processing (2)
  • models/deepseek/v4/decode_sparse_attn.py
  • models/deepseek/v4/qkv_proj_rope.py

📝 Walkthrough

Walkthrough

This PR refines the DeepSeek V4 model implementation in two independent ways: the sparse attention golden reference now computes softmax via tiled/blocked online merge across KV rows to align with kernel behavior, and the Q projection per-head RMSNorm switches from rsqrt to an equivalent recip(sqrt(...)) form.

Changes

DeepSeek V4 Attention and Normalization Refinements

Layer / File(s) Summary
Sparse attention tiled softmax merge
models/deepseek/v4/decode_sparse_attn.py
The golden_sparse_attn reference now computes per-tile max, sum-exp, and numerator components and merges them via running-statistics updates instead of computing softmax over the full gathered KV matrix at once, aligning the reference with the kernel's staged merge behavior.
Q projection per-head RMSNorm form
models/deepseek/v4/qkv_proj_rope.py
Per-head inverse RMS computation for q_head_inv_rms changes from pl.rsqrt(...) to the equivalent pl.recip(pl.sqrt(...)) form.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related issues

Possibly related PRs

  • hw-native-sys/pypto-lib#327: Also updates the DeepSeek-V4 sparse attention golden reference to use tiled/blocked online-softmax merge structure, aligning reference behavior with the kernel's staged computation.
  • hw-native-sys/pypto-lib#251: Adjusts sparse attention online softmax staging/merge logic during DeepSeek decode, complementing the tiled merge approach in this PR.
  • hw-native-sys/pypto-lib#332: Modifies models/deepseek/v4/qkv_proj_rope.py at the same per-head RMSNorm computation stage in the Q projection pipeline.

Suggested labels

bug

Poem

🐰 With tiles and merges, softmax flows,
Per-head RMS in reciprocal glows,
Golden reference aligned at last,
DeepSeek V4's future's now fast!


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the golden_sparse_attn function to implement a tiled approach with online softmax reduction, better reflecting optimized runtime behavior. Additionally, it modifies the RMS calculation in qkv_proj_rope.py by replacing pl.rsqrt with pl.recip(pl.sqrt). A review comment suggests casting the QK matmul inputs to bfloat16 in the golden path to accurately simulate the precision loss occurring in the JIT kernel.

Comment on lines +663 to +675
block_mi = []
block_li = []
block_oi = []
for tile_start in range(0, kv_b.shape[0], SPARSE_ATTN_TILE):
kv_tile = kv_b[tile_start:tile_start + SPARSE_ATTN_TILE]
scores = (q_t @ kv_tile.T) * SOFTMAX_SCALE
mi = scores.max(dim=-1, keepdim=True).values
exp_scores = torch.exp(scores - mi).to(torch.bfloat16).float()
li = exp_scores.sum(dim=-1, keepdim=True)
oi = exp_scores @ kv_tile.to(torch.bfloat16).float()
block_mi.append(mi)
block_li.append(li)
block_oi.append(oi)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To fully align the golden path with the optimized runtime implementation, the inputs to the QK matmul (q_t and kv_tile) should also be cast to bfloat16 to simulate the precision loss that occurs in the JIT kernel (where q and sparse_kv are pl.BF16). Currently, only the exponential weights and the PV matmul inputs are being cast.

Suggested change
block_mi = []
block_li = []
block_oi = []
for tile_start in range(0, kv_b.shape[0], SPARSE_ATTN_TILE):
kv_tile = kv_b[tile_start:tile_start + SPARSE_ATTN_TILE]
scores = (q_t @ kv_tile.T) * SOFTMAX_SCALE
mi = scores.max(dim=-1, keepdim=True).values
exp_scores = torch.exp(scores - mi).to(torch.bfloat16).float()
li = exp_scores.sum(dim=-1, keepdim=True)
oi = exp_scores @ kv_tile.to(torch.bfloat16).float()
block_mi.append(mi)
block_li.append(li)
block_oi.append(oi)
q_t = q_t.to(torch.bfloat16).float()
block_mi = []
block_li = []
block_oi = []
for tile_start in range(0, kv_b.shape[0], SPARSE_ATTN_TILE):
kv_tile = kv_b[tile_start:tile_start + SPARSE_ATTN_TILE].to(torch.bfloat16).float()
scores = (q_t @ kv_tile.T) * SOFTMAX_SCALE
mi = scores.max(dim=-1, keepdim=True).values
exp_scores = torch.exp(scores - mi).to(torch.bfloat16).float()
li = exp_scores.sum(dim=-1, keepdim=True)
oi = exp_scores @ kv_tile
block_mi.append(mi)
block_li.append(li)
block_oi.append(oi)

@high-cloud high-cloud changed the title [codex] Fix DeepSeek decode attention precision Fix DeepSeek decode attention precision May 23, 2026
@zhangqi-chen zhangqi-chen marked this pull request as ready for review May 23, 2026 04:05
@zhangqi-chen zhangqi-chen merged commit 01a2525 into hw-native-sys:main May 23, 2026
4 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants