Add FLASH config support to DSv4 CSA attention #296
Conversation
- Change attention_csa.py to use FLASH config instead of DEMO
- Rename compressor to indexer_compressor in indexer_compressor.py
- Update indexer.py to use indexer_compressor and set IDX_TOPK from config
📝 Walkthrough

The PR consolidates configuration, naming, and computation changes across the DeepSeek v4 attention stack.

Changes: DeepSeek v4 Refactoring and Optimization
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
🧹 Nitpick comments (2)
models/deepseek/v4/sparse_attn.py (1)
179-220: ⚡ Quick win — Guard the head-block tail.

This loop now assumes `H` is a multiple of `MATMUL_ROW_PAD`. The `q_flat`/`attn_sink` slices and the final assemble are all unconditional `h0 : h0 + MATMUL_ROW_PAD`, so the last block will go out of bounds if that stops being true.

Suggested guard:

```diff
 MATMUL_ROW_PAD = 16
+assert H % MATMUL_ROW_PAD == 0, (
+    f"num_attention_heads={H} must be divisible by MATMUL_ROW_PAD={MATMUL_ROW_PAD}"
+)
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `models/deepseek/v4/sparse_attn.py` around lines 179-220: the loop over h0 assumes H is divisible by MATMUL_ROW_PAD, causing out-of-bounds slices on q_flat, attn_sink, and oi_out. In the parallel block that uses q_flat, kv_topk_batch, attn_sink, and oi_out (symbols: H, MATMUL_ROW_PAD, q_flat, kv_topk_batch, attn_sink, oi_out, attn_stage_row), compute tail_len = min(MATMUL_ROW_PAD, H - h0) and use that length for all slices and reshapes (or explicitly pad temporary buffers to MATMUL_ROW_PAD and mask results) so the final cast/assemble only indexes 0:tail_len where needed; ensure all per-block computations (q_batch, kv_batch, oi, li, mi, sink_bias, oi_out, and the final attn_stage_row write) respect tail_len to avoid OOB accesses.

models/deepseek/v4/indexer.py (1)
35-35: ⚡ Quick win — Assert `index_topk` fits the score buffer.

Now that `IDX_TOPK` comes from config, the later top-k path assumes `IDX_TOPK <= SCORE_LEN` without checking it. A larger value will overrun the sorted-pair slice contract and the fixed `[1, IDX_TOPK]` scratch shape.

Suggested guard:

```diff
 IDX_TOPK = M.index_topk
+assert 0 <= IDX_TOPK <= SCORE_LEN, (
+    f"index_topk={IDX_TOPK} must satisfy 0 <= index_topk <= SCORE_LEN={SCORE_LEN}"
+)
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `models/deepseek/v4/indexer.py` at line 35: the code sets IDX_TOPK = M.index_topk but never validates it against the fixed score buffer length, risking buffer overruns. Add a guard where IDX_TOPK is derived that checks IDX_TOPK <= SCORE_LEN and either clamp it (IDX_TOPK = min(M.index_topk, SCORE_LEN)) or raise a clear ValueError/assertion if M.index_topk > SCORE_LEN, and update any dependent assumptions about the sorted-pair slice/scratch shape ([1, IDX_TOPK]) accordingly so callers using IDX_TOPK cannot exceed the SCORE_LEN buffer.
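As a runnable illustration of the two options the indexer.py finding describes (clamp silently vs. fail fast), here is a minimal sketch. The `SCORE_LEN` value and the `SimpleNamespace` config object are stand-ins invented for the demo; only `IDX_TOPK`, `index_topk`, and `SCORE_LEN` come from the review itself.

```python
from types import SimpleNamespace

SCORE_LEN = 2048                      # stand-in for the real score buffer length
M = SimpleNamespace(index_topk=4096)  # deliberately misconfigured: larger than the buffer

# Option A: clamp silently to the buffer size.
IDX_TOPK = min(M.index_topk, SCORE_LEN)

# Option B: fail fast with a clear error.
def resolve_topk(index_topk: int, score_len: int) -> int:
    if not 0 <= index_topk <= score_len:
        raise ValueError(
            f"index_topk={index_topk} must satisfy 0 <= index_topk <= SCORE_LEN={score_len}"
        )
    return index_topk

print(IDX_TOPK)  # 2048 after clamping
try:
    resolve_topk(M.index_topk, SCORE_LEN)
except ValueError as err:
    print(err)
```

Clamping keeps the kernel running but hides the misconfiguration; the explicit error surfaces it at startup, which is usually preferable for a shape contract like this.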
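For the sparse_attn.py head-block finding, the tail-length alternative (handling a partial last block instead of asserting divisibility) can be sketched in plain Python. The head count, the one-value-per-head data, and the doubling that stands in for the per-block attention math are all invented for illustration; only `H`, `MATMUL_ROW_PAD`, `q_flat`, and the `h0 : h0 + tail_len` slicing pattern come from the review.

```python
MATMUL_ROW_PAD = 16
H = 40                      # deliberately NOT a multiple of 16
q_flat = list(range(H))     # one scalar per head, enough to show the pattern
out = [0] * H

for h0 in range(0, H, MATMUL_ROW_PAD):
    # Only the final block can be shorter than MATMUL_ROW_PAD (here: 8 wide).
    tail_len = min(MATMUL_ROW_PAD, H - h0)
    block = q_flat[h0 : h0 + tail_len]                # stays in bounds on the tail
    out[h0 : h0 + tail_len] = [x * 2 for x in block]  # stand-in for per-block math

print(out == [x * 2 for x in q_flat])  # True
```

The unconditional `h0 : h0 + MATMUL_ROW_PAD` slicing the review flags would instead read past the end of the 40-element buffer on the last iteration.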
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 09000a80-d75d-4290-97f7-467ffa40b819
📒 Files selected for processing (4)
- models/deepseek/v4/attention_csa.py
- models/deepseek/v4/indexer.py
- models/deepseek/v4/indexer_compressor.py
- models/deepseek/v4/sparse_attn.py
Summary
- Change attention_csa.py to use FLASH config instead of DEMO
- Rename compressor to indexer_compressor in indexer_compressor.py
- Update indexer.py to use indexer_compressor and set IDX_TOPK from config (M.index_topk)
- Update sparse_attn.py to align with FLASH config parameters

Dependencies:
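The change list can be pictured as a hypothetical config switch. This PR page does not show the real config classes, so every name below except FLASH, DEMO, IDX_TOPK, and index_topk is invented, including all field values.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class AttnConfig:
    name: str
    index_topk: int

# Invented values: the PR only states that attention_csa.py moves from DEMO
# to FLASH and that IDX_TOPK now comes from the config's index_topk field.
DEMO = AttnConfig(name="DEMO", index_topk=64)
FLASH = AttnConfig(name="FLASH", index_topk=2048)

M = FLASH                # was: M = DEMO
IDX_TOPK = M.index_topk  # indexer.py now reads top-k from the config
print(M.name, IDX_TOPK)
```

Centralizing IDX_TOPK in the config is what makes the SCORE_LEN guard from the review necessary: the value is no longer a local constant the kernel can assume is in range.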