Fix sparse attention seqused semantics #333
Walkthrough
This PR refactors the sparse attention contract across DeepSeek v4 decode orchestration by changing
Changes
Per-Batch seqused_kv Contract Refactoring
Estimated code review effort: 3 (Moderate), ~25 minutes
Pre-merge checks: 4 passed, 1 failed (1 warning)
Code Review
This pull request refactors the seqused_kv tensor representation across several attention models (CSA, HCA, SWA) and their corresponding test harnesses. The tensor shape is changed from [B, S] to [B], and the per-token causal length calculation is updated to derive values dynamically from the batch-level sequence length. Additionally, a regression fixture for causal testing is introduced in the sparse attention harness. I have provided feedback suggesting the extraction of the duplicated sequence length calculation logic into a reusable inline helper function to improve maintainability.
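As a rough illustration of the new contract, the per-token causal lengths can be rebuilt from the batch-level tensor using the formula that appears in the diff (`seq_final - S + 1 + s`). The sketch below is plain Python, not the `pl` kernel code; the function name `expand_seqused_kv` is hypothetical, and it assumes that the old `[B, S]` layout stored exactly these per-token values.

```python
def expand_seqused_kv(seqused_kv, S):
    """Expand per-batch final KV lengths [B] into per-token causal lengths [B][S].

    Hypothetical helper for illustration only. Token s of batch b is assumed
    to see seqused_kv[b] - S + 1 + s keys, so the last token (s = S - 1)
    sees the full seqused_kv[b].
    """
    return [
        [seq_final - S + 1 + s for s in range(S)]
        for seq_final in seqused_kv
    ]


# Two batch entries with final lengths 10 and 7, decoding S = 3 tokens each.
per_token = expand_seqused_kv([10, 7], S=3)
print(per_token)  # [[8, 9, 10], [5, 6, 7]]
```

Under that assumption, passing only the final length per batch loses no information, since every per-token length is a fixed offset from it.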
gather_s = gather_t - gather_b * S
gather_seq_final = pl.read(seqused_kv, [gather_b])
gather_seq_used = gather_seq_final - S + 1 + gather_s
gather_window_valid = pl.min(WIN, gather_seq_used)
gather_cmp_valid = gather_seq_used - gather_window_valid
gather_cmp_topk_valid = pl.min(IDX_TOPK, gather_cmp_valid)
This block of logic to calculate sequence lengths is duplicated in four places within this function (in the gather, qk, pv, and merge stages). This repetition makes the code harder to read and maintain. Any future changes to this logic would need to be applied in all four locations.
To improve code reuse and maintainability, consider extracting this logic into a @pl.jit.inline helper function. You could define it at the module level, for example:
    @pl.jit.inline
    def _get_sparse_lengths(t, seqused_kv):
        b = t // S
        s = t - b * S
        seq_final = pl.read(seqused_kv, [b])
        seq_used = seq_final - S + 1 + s
        window_valid = pl.min(WIN, seq_used)
        cmp_valid = seq_used - window_valid
        cmp_topk_valid = pl.min(IDX_TOPK, cmp_valid)
        return window_valid, cmp_topk_valid

Then, you can replace the duplicated blocks with a single call to this function.
- gather_s = gather_t - gather_b * S
- gather_seq_final = pl.read(seqused_kv, [gather_b])
- gather_seq_used = gather_seq_final - S + 1 + gather_s
- gather_window_valid = pl.min(WIN, gather_seq_used)
- gather_cmp_valid = gather_seq_used - gather_window_valid
- gather_cmp_topk_valid = pl.min(IDX_TOPK, gather_cmp_valid)
+ gather_window_valid, gather_cmp_topk_valid = _get_sparse_lengths(gather_t, seqused_kv)
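To make the arithmetic in the suggested helper easier to check by hand, here is a plain-Python model of it. This is only a sketch: `pl.read` and `pl.min` are replaced with list indexing and the built-in `min`, the name `sparse_lengths` is hypothetical, and `S`, `WIN`, and `IDX_TOPK` are passed as arguments with made-up values rather than taken from module-level constants.

```python
def sparse_lengths(t, seqused_kv, S, WIN, IDX_TOPK):
    """Plain-Python model of the suggested helper (illustrative only)."""
    b = t // S                       # batch index of flattened token t
    s = t - b * S                    # position of the token within its batch
    seq_final = seqused_kv[b]        # stands in for pl.read(seqused_kv, [b])
    seq_used = seq_final - S + 1 + s # causal length visible to this token
    window_valid = min(WIN, seq_used)
    cmp_valid = seq_used - window_valid
    cmp_topk_valid = min(IDX_TOPK, cmp_valid)
    return window_valid, cmp_topk_valid


# Assumed toy sizes: 2 tokens per batch, window of 4, top-k of 3.
S, WIN, IDX_TOPK = 2, 4, 3
seqused_kv = [3, 10]  # final KV length per batch entry, shape [B]

results = [
    sparse_lengths(t, seqused_kv, S, WIN, IDX_TOPK)
    for t in range(len(seqused_kv) * S)
]
print(results)  # [(2, 0), (3, 0), (4, 3), (4, 3)]
```

The short first sequence fits entirely inside the window, so nothing is left over for the top-k part, while the longer second sequence saturates both `WIN` and `IDX_TOPK`.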
18428fb to 5d5768d
5d5768d to 168f251
Summary
Related Issues