Skip to content

Fix sparse attention seqused semantics#333

Merged
zhangqi-chen merged 1 commit into
hw-native-sys:mainfrom
high-cloud:fix/official-seqused-kv
May 20, 2026
Merged

Fix sparse attention seqused semantics#333
zhangqi-chen merged 1 commit into
hw-native-sys:mainfrom
high-cloud:fix/official-seqused-kv

Conversation

@high-cloud
Copy link
Copy Markdown
Contributor

Summary

  • Align sparse_attn seqused_kv with the official per-batch final sparse length contract.
  • Derive each query token's causal sparse length inside sparse_attn from seqused_kv[b] and s.
  • Update DeepSeek V4 attention/decode wrappers, fixtures, docs, and standalone causal regression coverage.

Related Issues

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 20, 2026

Review Change Stack

Warning

Rate limit exceeded

@high-cloud has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 14 minutes and 45 seconds before requesting another review.

You’ve run out of usage credits. Purchase more in the billing tab.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 9a7b1a6e-0020-4e03-9d93-f6188feaa0d5

📥 Commits

Reviewing files that changed from the base of the PR and between 18428fb and 168f251.

📒 Files selected for processing (8)
  • models/deepseek/v4/attention_csa.py
  • models/deepseek/v4/attention_hca.py
  • models/deepseek/v4/attention_swa.py
  • models/deepseek/v4/decode_csa.py
  • models/deepseek/v4/decode_hca.py
  • models/deepseek/v4/decode_swa.py
  • models/deepseek/v4/deepseek_v4_decode_single_layer.md
  • models/deepseek/v4/sparse_attn.py
📝 Walkthrough

Walkthrough

This PR refactors the sparse attention contract across DeepSeek v4 decode orchestration by changing seqused_kv from a per-token-per-batch tensor [B, S] to a per-batch scalar [B]. Per-token sequence-used values are now derived in-kernel using the formula seqused_kv[b] - S + 1 + s. The change propagates through sparse_attn, three attention pathways (CSA, HCA, SWA), and their corresponding decode entrypoints.

Changes

Per-Batch seqused_kv Contract Refactoring

Layer / File(s) Summary
sparse_attn Kernel Refactoring
models/deepseek/v4/sparse_attn.py
Function signatures changed to accept seqused_kv [B] instead of [B, S]. Per-token seq_used is computed from the per-batch value using seqused_kv[b] - S + 1 + s in all kernel stages (gather, QK/softmax, PV, merge). Golden reference updated to match. Fixture generation produces per-batch scalars. New --causal-regression-fixture CLI flag added with supporting fixture logic.
CSA, HCA, SWA Attention Orchestration
models/deepseek/v4/attention_csa.py, models/deepseek/v4/attention_hca.py, models/deepseek/v4/attention_swa.py
Each attention pathway updates seqused_kv parameter from [B, S] to [B] in production and test function signatures. Golden references pass seqused_kv directly to sparse_attn without reshape/view. Fixture generators (init_seqused_kv) compute per-batch scalars using pathway-specific formulas (e.g., START_POS + S for CSA/HCA, min(WIN, START_POS + S) for SWA). TensorSpec updated to match [B] shape.
Decode Entrypoints
models/deepseek/v4/decode_csa.py, models/deepseek/v4/decode_hca.py, models/deepseek/v4/decode_swa.py
Decode functions and test variants (decode_csa, decode_hca, decode_swa) update seqused_kv parameters from [B, S] to [B], forwarding the updated shape to their respective attention orchestration calls.
Architecture Documentation
models/deepseek/v4/deepseek_v4_decode_single_layer.md
ATTENTION section clarified: seqused_kv is now documented as per-batch [B], with per-token sparse KV length derived as seqused_kv[b] - S + 1 + s, and causality enforced by the derived length combined with topk_idxs indices.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#327: Directly related refactoring of sparse_attn.py to use per-batch [B] seqused_kv shape and updated per-token derivation logic.
  • hw-native-sys/pypto-lib#234: Updates SWA/decode sparse-attention wiring to change seqused_kv from sequence-shaped to per-batch scalar.
  • hw-native-sys/pypto-lib#305: Modifies CSA decode entrypoints in related orchestration layers, though focused on indexer parameter changes alongside seqused_kv shape refactoring.

Poem

🐰 Per-batch scalars, once a token-wide dance,
Now whisper their length to each position's chance.
The sparse kernel blooms with a formula true:
seqused_kv[b] - S + 1 + s shines through!
From CSA to HCA to SWA's swift stride,
The refactored contract flows with causality's guide.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 32.50% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately summarizes the main change: fixing sparse attention seqused semantics from per-token to per-batch representation across the codebase.
Description check ✅ Passed The description is directly related to the changeset, explaining the main objectives of aligning seqused_kv contract, deriving causal sparse length, and updating wrappers/fixtures.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the seqused_kv tensor representation across several attention models (CSA, HCA, SWA) and their corresponding test harnesses. The tensor shape is changed from [B, S] to [B], and the per-token causal length calculation is updated to derive values dynamically from the batch-level sequence length. Additionally, a regression fixture for causal testing is introduced in the sparse attention harness. I have provided feedback suggesting the extraction of the duplicated sequence length calculation logic into a reusable inline helper function to improve maintainability.

Comment on lines +160 to 165
gather_s = gather_t - gather_b * S
gather_seq_final = pl.read(seqused_kv, [gather_b])
gather_seq_used = gather_seq_final - S + 1 + gather_s
gather_window_valid = pl.min(WIN, gather_seq_used)
gather_cmp_valid = gather_seq_used - gather_window_valid
gather_cmp_topk_valid = pl.min(IDX_TOPK, gather_cmp_valid)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of logic to calculate sequence lengths is duplicated in four places within this function (in the gather, qk, pv, and merge stages). This repetition makes the code harder to read and maintain. Any future changes to this logic would need to be applied in all four locations.

To improve code reuse and maintainability, consider extracting this logic into a @pl.jit.inline helper function. You could define it at the module level, for example:

@pl.jit.inline
def _get_sparse_lengths(t, seqused_kv):
    b = t // S
    s = t - b * S
    seq_final = pl.read(seqused_kv, [b])
    seq_used = seq_final - S + 1 + s
    window_valid = pl.min(WIN, seq_used)
    cmp_valid = seq_used - window_valid
    cmp_topk_valid = pl.min(IDX_TOPK, cmp_valid)
    return window_valid, cmp_topk_valid

Then, you can replace the duplicated blocks with a single call to this function.

Suggested change
gather_s = gather_t - gather_b * S
gather_seq_final = pl.read(seqused_kv, [gather_b])
gather_seq_used = gather_seq_final - S + 1 + gather_s
gather_window_valid = pl.min(WIN, gather_seq_used)
gather_cmp_valid = gather_seq_used - gather_window_valid
gather_cmp_topk_valid = pl.min(IDX_TOPK, gather_cmp_valid)
gather_window_valid, gather_cmp_topk_valid = _get_sparse_lengths(gather_t, seqused_kv)

@high-cloud high-cloud force-pushed the fix/official-seqused-kv branch from 18428fb to 5d5768d Compare May 20, 2026 07:57
@high-cloud high-cloud force-pushed the fix/official-seqused-kv branch from 5d5768d to 168f251 Compare May 20, 2026 08:32
@zhangqi-chen zhangqi-chen merged commit 058894f into hw-native-sys:main May 20, 2026
7 of 13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants