Update DeepSeek V4 indexer and ratio-4 compressor (#260)
Conversation
## Summary

- Promote the DeepSeek V4 indexer draft into `models/deepseek/v4/indexer.py` and remove the stale draft entry point.
- Implement the decode indexer path: QR projection, selector-based RoPE, Hadamard rotation, weighted score reduction, and score/top-k outputs with torch golden coverage.
- Preserve SORT_LEN-padded top-k sorting while using concat padding to avoid the A2/A3 non-mat tmov shape mismatch from assembling a SCORE_LEN row into a SORT_LEN row.
- Update compressor_ratio4 for the rotating compressed-KV path, including projected kv/score state updates, RMSNorm/RoPE/Hadamard flow, KV-cache writes, and matching tensor specs.

## Related Issues

N/A
## Walkthrough

The PR refactors the DeepSeek-V4 decode path by rewriting the KV compressor with overlapping-state pooling, RoPE-based rotation, and KV-cache output handling, and by replacing a draft indexer placeholder with a full implementation including score computation, topk selection, and golden reference.

## Changes

DeepSeek-V4 Decode Compressor and Indexer Redesign
## Sequence Diagram

```mermaid
sequenceDiagram
    participant Input as Input (x, kv_state, score_state)
    participant Accum as Accumulation (K_CHUNK tiles)
    participant Pool as Overlapped Pool (conditional softmax)
    participant Norm as RMSNorm (FP32 path)
    participant RoPE as RoPE (even/odd select)
    participant Rotate as Hadamard (optional)
    participant Cache as KV Cache (write)
    participant Output as Output (kv, state, cache)
    Input->>Accum: Project & accumulate kv, score via wkv, wgate
    Accum->>Accum: Scatter into flattened state at overlap slot
    Accum->>Pool: Check compress_rem == 0?
    Pool->>Pool: Conditional softmax + pool over score_state_flat
    Pool->>Norm: Shift overlap layout, output pooled_kv
    Norm->>Norm: Normalize pooled BF16 via FP32 accumulation
    Norm->>RoPE: Apply RMSNorm result
    RoPE->>RoPE: Slice/rotate even/odd dims using selectors & cos/sin
    RoPE->>Rotate: Assemble rotated output
    Rotate->>Rotate: Conditionally apply hadamard transform
    Rotate->>Cache: Pack rotated kv to BF16
    Cache->>Cache: Write to kv_cache at start_pos // COMPRESS_RATIO
    Cache->>Output: Return updated kv, state, cache
```
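The "conditional softmax + pool" step in the diagram can be sketched in NumPy; the window size, tensor shapes, and variable names below are illustrative assumptions, not the kernel's actual specs:

```python
import numpy as np

B_S, WINDOW, HEAD_DIM = 4, 4, 8  # hypothetical sizes for illustration
rng = np.random.default_rng(0)
kv_state = rng.standard_normal((B_S, WINDOW, HEAD_DIM)).astype(np.float32)
score_state = rng.standard_normal((B_S, WINDOW)).astype(np.float32)
compress_rem = 0  # position within the compression window

if compress_rem == 0:
    # Softmax the per-slot scores, then pool kv over the window.
    w = np.exp(score_state - score_state.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    pooled_kv = (kv_state * w[:, :, None]).sum(axis=1)  # [B_S, HEAD_DIM]
```

The pooled result then feeds the RMSNorm/RoPE stages shown in the diagram.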
## Estimated Code Review Effort

🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ Passed checks (5 passed)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/deepseek/v4/indexer.py`:
- Around line 9-11: Update the stale comments in models/deepseek/v4/indexer.py:
change the module docstring so it no longer claims "The inner Compressor is
invoked via golden_compressor (placeholder)" and instead accurately describes
that golden_indexer consumes idx_kv_cache directly; remove or correct the
misleading START_POS comment next to the START_POS constant (it currently
asserts an inner-compressor condition that is not used and is incorrect for
START_POS = 256); and delete the redundant "# TODO: kernel implementation"
comment near the kernel code since the kernel is implemented directly below
(ensure any remaining comments accurately describe the implemented behavior in
golden_indexer and the kernel block).
ℹ️ Review info

- Configuration used: Organization UI
- Review profile: CHILL
- Plan: Pro
- Run ID: 49b9ffef-68ad-4e95-8dd9-3f7d8164537d
📒 Files selected for processing (3)

- models/deepseek/v4/compressor_ratio4.py
- models/deepseek/v4/indexer.py
- models/deepseek/v4/indexer_draft.py
💤 Files with no reviewable changes (1)
- models/deepseek/v4/indexer_draft.py
| """DeepSeek-V4 Indexer (decode). Mirrors model.py Indexer (line 380-433); | ||
| golden is a port of forward's decode branch (prefill `start_pos == 0` path is omitted). | ||
| The inner Compressor is invoked via golden_compressor (placeholder).""" |
Clean up stale draft comments/TODO.
Several comments are leftovers from the indexer_draft.py lineage and no longer match this file:
- Lines 9-11: the docstring still claims "The inner Compressor is invoked via golden_compressor (placeholder)", but `golden_indexer` here does not invoke any compressor; it consumes `idx_kv_cache` directly.
- Line 41: the `START_POS` comment asserts `(START_POS+1) % COMPRESS_RATIO == 0` "to cover the full inner-compressor path". With `START_POS = 256`, `(256+1) % 4 == 1`, and the indexer has no inner-compressor path in any case, so the guidance is both unmet by the value and unreachable by the code.
- Line 68: `# TODO: kernel implementation` is stale; the kernel is implemented immediately below.
📝 Proposed cleanup
-"""DeepSeek-V4 Indexer (decode). Mirrors model.py Indexer (line 380-433);
-golden is a port of forward's decode branch (prefill `start_pos == 0` path is omitted).
-The inner Compressor is invoked via golden_compressor (placeholder)."""
+"""DeepSeek-V4 Indexer (decode). Mirrors model.py Indexer (line 380-433);
+golden is a port of forward's decode branch (prefill `start_pos == 0` path is omitted).
+Consumes the compressed KV cache produced by compressor_ratio4."""-START_POS = 256 # default for ScalarSpec; >0 (decode) and (START_POS+1)%COMPRESS_RATIO==0 to cover the full inner-compressor path
+START_POS = 256 # default for ScalarSpec; >0 selects the decode path in golden_indexer
OFFSET = 128 # default for ScalarSpec; = win in attention orch; added to topk_idxs (model.py:432)- # TODO: kernel implementation
cache_len = (start_pos + S) // COMPRESS_RATIOAlso applies to: 41-41, 68-68
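The modulo claim behind this cleanup is easy to verify directly:

```python
START_POS = 256       # value from indexer.py
COMPRESS_RATIO = 4    # ratio-4 compressor
# The stale comment requires (START_POS + 1) % COMPRESS_RATIO == 0,
# but the actual remainder is 1, so the stated condition is not met.
remainder = (START_POS + 1) % COMPRESS_RATIO
print(remainder)  # → 1
```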
Code Review
This pull request refactors the compressor in compressor_ratio4.py to incorporate RoPE, Hadamard rotation, and KV cache management, while also tightening test tolerances. Additionally, it introduces a new indexer.py implementation for the DeepSeek-V4 decode phase, replacing a previous draft. Feedback highlights opportunities to improve performance by parallelizing operations like RMSNorm, RoPE, and weight projections over the batch dimension in both the compressor and indexer modules to better utilize hardware resources.
```python
with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"):
    partial_sq = pl.full([1, B * S], dtype=pl.FP32, value=0.0)
    for k0 in pl.range(0, HEAD_DIM, HEAD_CHUNK):
        # Golden applies rmsnorm to kv.to(torch.bfloat16), then casts to FP32 inside rmsnorm.
        kv_rms_chunk = pl.cast(
            pl.cast(pooled_kv[:, k0 : k0 + HEAD_CHUNK], target_type=pl.BF16),
            target_type=pl.FP32,
        )
        partial_sq = pl.add(
            partial_sq,
            pl.reshape(pl.row_sum(pl.mul(kv_rms_chunk, kv_rms_chunk)), [1, B * S]),
        )

    variance = pl.reshape(pl.add(pl.mul(partial_sq, HEAD_DIM_INV), EPS), [B * S, 1])
    inv_rms = pl.recip(pl.sqrt(variance))
    for k0 in pl.range(0, HEAD_DIM, HEAD_CHUNK):
        kv_norm_chunk = pl.cast(
            pl.cast(pooled_kv[:, k0 : k0 + HEAD_CHUNK], target_type=pl.BF16),
            target_type=pl.FP32,
        )
        gamma = norm_w_2d[:, k0 : k0 + HEAD_CHUNK]
        normed_chunk = pl.col_expand_mul(pl.row_expand_mul(kv_norm_chunk, inv_rms), gamma)
        normed_kv = pl.assemble(normed_kv, pl.cast(normed_chunk, target_type=pl.BF16), [0, k0])

kv_rope = pl.create_tensor([B * S, ROPE_HEAD_DIM], dtype=pl.BF16)
kv_proj_even = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.FP32)
kv_proj_odd = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.FP32)
rope_even = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.BF16)
rope_odd = pl.create_tensor([B * S, ROPE_HEAD_DIM // 2], dtype=pl.BF16)

with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_slice"):
    kv_rope = pl.assemble(kv_rope, normed_kv[:, NOPE_HEAD_DIM : HEAD_DIM], [0, 0])

with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_slice"):
    kv_rope_tile = kv_rope[:, 0 : ROPE_CHUCK]
    even_select_tile = even_select[0 : ROPE_CHUCK, :]
    odd_select_tile = odd_select[0 : ROPE_CHUCK, :]
    even_acc = pl.matmul(kv_rope_tile, even_select_tile, out_dtype=pl.FP32)
    odd_acc = pl.matmul(kv_rope_tile, odd_select_tile, out_dtype=pl.FP32)

    for r0 in pl.range(ROPE_CHUCK, ROPE_HEAD_DIM, ROPE_CHUCK):
        kv_rope_tile = kv_rope[:, r0 : r0 + ROPE_CHUCK]
        even_select_tile = even_select[r0 : r0 + ROPE_CHUCK, :]
        odd_select_tile = odd_select[r0 : r0 + ROPE_CHUCK, :]
        even_acc = pl.matmul_acc(even_acc, kv_rope_tile, even_select_tile)
        odd_acc = pl.matmul_acc(odd_acc, kv_rope_tile, odd_select_tile)
    kv_proj_even = pl.assemble(kv_proj_even, even_acc, [0, 0])
    kv_proj_odd = pl.assemble(kv_proj_odd, odd_acc, [0, 0])

with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_apply"):
    even_tile = kv_proj_even[:, :]
    odd_tile = kv_proj_odd[:, :]
    rope_even_acc = pl.cast(pl.sub(pl.col_expand_mul(even_tile, cos), pl.col_expand_mul(odd_tile, sin)), target_type=pl.BF16)
    rope_odd_acc = pl.cast(pl.add(pl.col_expand_mul(even_tile, sin), pl.col_expand_mul(odd_tile, cos)), target_type=pl.BF16)
    rope_even = pl.assemble(rope_even, rope_even_acc, [0, 0])
    rope_odd = pl.assemble(rope_odd, rope_odd_acc, [0, 0])

with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_assemble"):
    rope_even_tile = rope_even[:, 0 : ROPE_CHUCK]
    rope_odd_tile = rope_odd[:, 0 : ROPE_CHUCK]
    even_select_tile_t = even_select[:, 0 : ROPE_CHUCK]
    odd_select_tile_t = odd_select[:, 0 : ROPE_CHUCK]
    rope_acc = pl.matmul(rope_even_tile, even_select_tile_t, out_dtype=pl.FP32, b_trans=True)
    rope_acc = pl.matmul_acc(rope_acc, rope_odd_tile, odd_select_tile_t, b_trans=True)

    for r0 in pl.range(ROPE_CHUCK, ROPE_HEAD_DIM // 2, ROPE_CHUCK):
        rope_even_tile = rope_even[:, r0 : r0 + ROPE_CHUCK]
        rope_odd_tile = rope_odd[:, r0 : r0 + ROPE_CHUCK]
        even_select_tile_t = even_select[:, r0 : r0 + ROPE_CHUCK]
        odd_select_tile_t = odd_select[:, r0 : r0 + ROPE_CHUCK]
        rope_acc = pl.matmul_acc(rope_acc, rope_even_tile, even_select_tile_t, b_trans=True)
        rope_acc = pl.matmul_acc(rope_acc, rope_odd_tile, odd_select_tile_t, b_trans=True)

with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_write"):
    normed_kv = pl.assemble(normed_kv, pl.cast(rope_acc, target_type=pl.BF16), [0, NOPE_HEAD_DIM])

if rotate:
    for o0 in pl.range(0, HEAD_DIM, OUT_CHUNK):
        with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_hadamard"):
            kv_proj_tile = normed_kv[:, 0 : HEAD_DIM]
            hadamard_tile = hadamard[0 : HEAD_DIM, o0 : o0 + OUT_CHUNK]
            kv_hadamard_acc = pl.matmul(kv_proj_tile, hadamard_tile, out_dtype=pl.FP32)
            kv_flat = pl.assemble(kv_flat, kv_hadamard_acc, [0, o0])
else:
    for o0 in pl.parallel(0, HEAD_DIM, OUT_CHUNK):
        with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_write"):
            kv_out_tile = normed_kv[:, o0 : o0 + OUT_CHUNK]
            kv_flat = pl.assemble(kv_flat, pl.cast(kv_out_tile, target_type=pl.FP32), [0, o0])
```
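As a sanity check on the selector-based even/odd RoPE path above, the gather/rotate/scatter sequence can be compared against a direct interleaved rotation in NumPy. The one-hot selector matrices, per-pair cos/sin layout, and all sizes here are illustrative assumptions:

```python
import numpy as np

BS, ROPE_DIM = 3, 8  # hypothetical sizes
rng = np.random.default_rng(1)
x = rng.standard_normal((BS, ROPE_DIM)).astype(np.float32)
theta = rng.standard_normal((BS, ROPE_DIM // 2)).astype(np.float32)
cos, sin = np.cos(theta), np.sin(theta)

# One-hot selector matrices gather interleaved lanes: x @ even_select == x[:, 0::2].
even_select = np.zeros((ROPE_DIM, ROPE_DIM // 2), np.float32)
odd_select = np.zeros((ROPE_DIM, ROPE_DIM // 2), np.float32)
even_select[np.arange(0, ROPE_DIM, 2), np.arange(ROPE_DIM // 2)] = 1.0
odd_select[np.arange(1, ROPE_DIM, 2), np.arange(ROPE_DIM // 2)] = 1.0

ev, od = x @ even_select, x @ odd_select          # "rope_slice" analogue
rot_ev = ev * cos - od * sin                      # "rope_apply" analogue
rot_od = ev * sin + od * cos
# Scatter back via the transposed selectors (the b_trans=True analogue).
out = rot_ev @ even_select.T + rot_od @ odd_select.T

# Direct interleaved rotation for comparison.
ref = np.empty_like(x)
ref[:, 0::2] = x[:, 0::2] * cos - x[:, 1::2] * sin
ref[:, 1::2] = x[:, 0::2] * sin + x[:, 1::2] * cos
```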
The RMSNorm, RoPE, and Hadamard rotation sections are currently executed sequentially for the entire batch. This results in significant under-utilization of hardware resources, especially as the batch size increases. These operations should be parallelized over the batch dimension B * S to improve performance.
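The suggestion rests on RMSNorm being row-independent. A small NumPy sketch (hypothetical shapes and a plain float32 gamma) shows that slicing the batch does not change the result, so the B * S rows can safely be processed in parallel:

```python
import numpy as np

BS, HEAD_DIM, EPS = 6, 16, 1e-6  # illustrative sizes
rng = np.random.default_rng(2)
x = rng.standard_normal((BS, HEAD_DIM)).astype(np.float32)
gamma = rng.standard_normal(HEAD_DIM).astype(np.float32)

def rmsnorm(v):
    # v: [rows, HEAD_DIM]; every row normalizes independently of the others.
    inv_rms = 1.0 / np.sqrt((v * v).mean(axis=-1, keepdims=True) + EPS)
    return v * inv_rms * gamma

full = rmsnorm(x)
# Processing independent batch slices yields identical results,
# which is what makes batch-dimension parallelization valid.
split = np.concatenate([rmsnorm(x[i : i + 2]) for i in range(0, BS, 2)])
```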
```python
for h0 in pl.parallel(0, IDX_N_HEADS, HEAD_CHUCK):
    with pl.at(level=pl.Level.CORE_GROUP, name_hint="weights_proj"):
        x_tile = x_flat[:, 0 : D_CHUCK]
        weights_proj_tile = weights_proj[0 : D_CHUCK, h0 : h0 + HEAD_CHUCK]
        weights_acc = pl.matmul(x_tile, weights_proj_tile, out_dtype=pl.FP32)

        for d0 in pl.range(D_CHUCK, D, D_CHUCK):
            x_tile = x_flat[:, d0 : d0 + D_CHUCK]
            weights_proj_tile = weights_proj[d0 : d0 + D_CHUCK, h0 : h0 + HEAD_CHUCK]
            weights_acc = pl.matmul_acc(weights_acc, x_tile, weights_proj_tile)

    with pl.at(level=pl.Level.CORE_GROUP, name_hint="weights_write"):
        weights_scale = pl.mul(weights_acc, WEIGHTS_SCALE)
        weights = pl.assemble(weights, weights_scale, [0, h0])
```
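The matmul/matmul_acc loop above accumulates partial products over D_CHUCK-wide slices of the inner dimension. A NumPy sketch with illustrative sizes and scale confirms that chunked accumulation matches a single full matmul:

```python
import numpy as np

BS, D, H, D_CHUNK = 2, 32, 8, 8  # hypothetical sizes
rng = np.random.default_rng(3)
x_flat = rng.standard_normal((BS, D)).astype(np.float32)
weights_proj = rng.standard_normal((D, H)).astype(np.float32)
WEIGHTS_SCALE = 0.5  # illustrative scale factor

# Accumulate partial products over D_CHUNK slices of the inner dimension,
# mirroring the matmul / matmul_acc loop in the kernel.
acc = x_flat[:, 0:D_CHUNK] @ weights_proj[0:D_CHUNK, :]
for d0 in range(D_CHUNK, D, D_CHUNK):
    acc += x_flat[:, d0 : d0 + D_CHUNK] @ weights_proj[d0 : d0 + D_CHUNK, :]
weights = acc * WEIGHTS_SCALE
```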