Refactor: fuse and simplify DeepSeek V4 qkv_proj_rope scopes#332
Conversation
- Fuse attn_norm rms/apply/bf16-cast into one scope; drop the intermediate token_x_fp32 GM tensor - Fuse qr rms/apply/bf16-cast and kv rms/nope-apply/rope-apply into single scopes each, dropping the dead fp32 write-backs - Rewrite qr_proj_matmul in the qwen3 q_proj style (create_tensor acc + pl.pipeline K-loop with if-first matmul/matmul_acc) - Drop all chunked_loop_optimizer hints - Fold the q/kv RoPE reassemble+write loops inside their pl.at scopes via an intermediate fp32 tensor; use pl.pipeline(stage=2) - Convert all pl.slice/pl.assemble to subscript [:] sugar - Add in-core T-dim range(2) tiling to the fused kv_rms and q_head_rms_rope scopes so they stay within the Vec buffer at B=64
📝 WalkthroughWalkthroughThis PR refactors the DeepSeek-V4 QKV projection and rotary embedding pipeline to eliminate assemble-based storage and consolidate BF16 computation using tiled RMSNorm and direct indexing across all projection paths. ChangesQKV Projection and RoPE Optimization
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/deepseek/v4/qkv_proj_rope.py`:
- Around line 52-53: The tiling assumes T divides evenly by RMS_T_TILES, but
T_TILE = T // RMS_T_TILES can be zero or skip tail rows; add a guard: either
assert T >= RMS_T_TILES and T % RMS_T_TILES == 0 immediately after
RMS_T_TILES/T_TILE definition, or (preferable) compute tiles so tails are
handled by replacing T_TILE = T // RMS_T_TILES with a safe scheme (e.g., compute
tile_count = RMS_T_TILES; compute start = i * T_TILE and end = min(T, (i+1) *
T_TILE) for each tile, or use ceil division for T_TILE and clamp end indices)
and update the tiled loops that use RMS_T_TILES and T_TILE (the loops referenced
in the comment) to use start/end bounds so the final partial tile is processed.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 07e4c9c1-3d64-4f45-8600-c1709225f6ef
📒 Files selected for processing (1)
models/deepseek/v4/qkv_proj_rope.py
| RMS_T_TILES = 2 | ||
| T_TILE = T // RMS_T_TILES |
There was a problem hiding this comment.
Guard the 2-way T tiling contract.
T_TILE = T // RMS_T_TILES makes the tiled loops at Lines 178-179 and 251-252 cover exactly RMS_T_TILES * T_TILE rows. If T is odd, or ever smaller than RMS_T_TILES, the tail rows are skipped or the tile size becomes zero, so the last tokens never get RMS/RoPE applied. Please either assert that contract here or handle the final partial tile explicitly.
Suggested minimal guard
RMS_T_TILES = 2
+assert T >= RMS_T_TILES and T % RMS_T_TILES == 0, \
+ "RMS_T_TILES must evenly divide T"
T_TILE = T // RMS_T_TILES📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| RMS_T_TILES = 2 | |
| T_TILE = T // RMS_T_TILES | |
| RMS_T_TILES = 2 | |
| assert T >= RMS_T_TILES and T % RMS_T_TILES == 0, \ | |
| "RMS_T_TILES must evenly divide T" | |
| T_TILE = T // RMS_T_TILES |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/deepseek/v4/qkv_proj_rope.py` around lines 52 - 53, The tiling assumes
T divides evenly by RMS_T_TILES, but T_TILE = T // RMS_T_TILES can be zero or
skip tail rows; add a guard: either assert T >= RMS_T_TILES and T % RMS_T_TILES
== 0 immediately after RMS_T_TILES/T_TILE definition, or (preferable) compute
tiles so tails are handled by replacing T_TILE = T // RMS_T_TILES with a safe
scheme (e.g., compute tile_count = RMS_T_TILES; compute start = i * T_TILE and
end = min(T, (i+1) * T_TILE) for each tile, or use ceil division for T_TILE and
clamp end indices) and update the tiled loops that use RMS_T_TILES and T_TILE
(the loops referenced in the comment) to use start/end bounds so the final
partial tile is processed.
There was a problem hiding this comment.
Code Review
This pull request refactors the qkv_proj_rope function in models/deepseek/v4/qkv_proj_rope.py to optimize performance and improve code structure, notably by updating parallel loop patterns, replacing slice operations with direct indexing, and optimizing tensor assembly. The review identified a critical race condition where tensors allocated in global memory lack a head dimension, causing data corruption during parallel execution. Additionally, suggestions were provided to fuse specific scopes to reduce global memory bandwidth usage and improve overall efficiency.
| q_rot_even_bf16 = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16) | ||
| q_rot_odd_bf16 = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16) |
There was a problem hiding this comment.
The tensors q_rot_even_bf16 and q_rot_odd_bf16 are created inside the pl.parallel(H) loop but lack a head dimension (e.g., [H, T, ROPE_HALF]). Since these tensors cross pl.at scopes, they are allocated in Global Memory (GM). Without indexing by the head index h, all parallel iterations will overwrite the same memory locations, leading to a critical race condition.
To fix this, either:
- Move the declarations outside the
pl.parallelloop and add anHdimension, then index them with[h, ...]throughout the loop. - Fuse the
q_head_rms_ropeandq_rope_reassemblescopes so these can be local variables (though this is difficult due to the tiling inversion betweenTandROPE_DIM).
| q_rope_fp32 = pl.create_tensor([T, ROPE_DIM], dtype=pl.FP32) | ||
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_reassemble"): | ||
| for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): | ||
| pair_col = rope_col // 2 | ||
| q_rot_even_chunk = pl.slice(q_rot_even_bf16, [T, ROPE_PAIR_CHUNK], [0, pair_col]) | ||
| q_rot_odd_chunk = pl.slice(q_rot_odd_bf16, [T, ROPE_PAIR_CHUNK], [0, pair_col]) | ||
| q_rot_even_chunk = q_rot_even_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | ||
| q_rot_odd_chunk = q_rot_odd_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | ||
| q_rot_chunk = pl.matmul( | ||
| q_rot_even_chunk, | ||
| pl.slice(even_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), | ||
| even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | ||
| out_dtype=pl.FP32, | ||
| ) | ||
| q_rot_chunk = pl.matmul_acc( | ||
| q_rot_chunk, | ||
| q_rot_odd_chunk, | ||
| pl.slice(odd_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), | ||
| odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | ||
| ) | ||
| q_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] = q_rot_chunk | ||
|
|
||
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_write"): | ||
| h0 = h * HEAD_DIM | ||
| q_flat = pl.assemble(q_flat, pl.cast(q_rot_chunk, target_type=pl.BF16, mode="rint"), [0, h0 + NOPE_DIM + rope_col]) | ||
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_write"): | ||
| h0 = h * HEAD_DIM | ||
| for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): | ||
| q_rope_chunk = q_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] | ||
| q_flat[:, h0 + NOPE_DIM + rope_col : h0 + NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(q_rope_chunk, target_type=pl.BF16, mode="rint") |
There was a problem hiding this comment.
The q_rope_reassemble and q_rope_write scopes can be fused into a single pl.at scope. This would eliminate the need for the intermediate GM tensor q_rope_fp32 (which also suffers from the race condition mentioned above), saving GM bandwidth and improving performance. You can write the results of the matmuls directly to q_flat after casting.
| q_rope_fp32 = pl.create_tensor([T, ROPE_DIM], dtype=pl.FP32) | |
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_reassemble"): | |
| for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): | |
| pair_col = rope_col // 2 | |
| q_rot_even_chunk = pl.slice(q_rot_even_bf16, [T, ROPE_PAIR_CHUNK], [0, pair_col]) | |
| q_rot_odd_chunk = pl.slice(q_rot_odd_bf16, [T, ROPE_PAIR_CHUNK], [0, pair_col]) | |
| q_rot_even_chunk = q_rot_even_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | |
| q_rot_odd_chunk = q_rot_odd_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | |
| q_rot_chunk = pl.matmul( | |
| q_rot_even_chunk, | |
| pl.slice(even_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), | |
| even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | |
| out_dtype=pl.FP32, | |
| ) | |
| q_rot_chunk = pl.matmul_acc( | |
| q_rot_chunk, | |
| q_rot_odd_chunk, | |
| pl.slice(odd_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), | |
| odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | |
| ) | |
| q_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] = q_rot_chunk | |
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_write"): | |
| h0 = h * HEAD_DIM | |
| q_flat = pl.assemble(q_flat, pl.cast(q_rot_chunk, target_type=pl.BF16, mode="rint"), [0, h0 + NOPE_DIM + rope_col]) | |
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_write"): | |
| h0 = h * HEAD_DIM | |
| for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): | |
| q_rope_chunk = q_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] | |
| q_flat[:, h0 + NOPE_DIM + rope_col : h0 + NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(q_rope_chunk, target_type=pl.BF16, mode="rint") | |
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_reassemble_write"): | |
| h0 = h * HEAD_DIM | |
| for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): | |
| pair_col = rope_col // 2 | |
| # Note: q_rot_even_bf16 should be indexed by h if fixed as suggested above | |
| q_rot_even_chunk = q_rot_even_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | |
| q_rot_odd_chunk = q_rot_odd_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | |
| q_rot_chunk = pl.matmul( | |
| q_rot_even_chunk, | |
| even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | |
| out_dtype=pl.FP32, | |
| ) | |
| q_rot_chunk = pl.matmul_acc( | |
| q_rot_chunk, | |
| q_rot_odd_chunk, | |
| odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | |
| ) | |
| q_flat[:, h0 + NOPE_DIM + rope_col : h0 + NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(q_rot_chunk, target_type=pl.BF16, mode="rint") |
| kv_rope_fp32 = pl.create_tensor([T, ROPE_DIM], dtype=pl.FP32) | ||
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_reassemble"): | ||
| for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): | ||
| pair_col = rope_col // 2 | ||
| kv_rot_even_chunk = pl.slice(kv_rot_even_tmp, [T, ROPE_PAIR_CHUNK], [0, pair_col]) | ||
| kv_rot_odd_chunk = pl.slice(kv_rot_odd_tmp, [T, ROPE_PAIR_CHUNK], [0, pair_col]) | ||
| kv_rot_even_chunk = kv_rot_even_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | ||
| kv_rot_odd_chunk = kv_rot_odd_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | ||
| kv_rot_chunk = pl.matmul( | ||
| kv_rot_even_chunk, | ||
| pl.slice(even_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), | ||
| even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | ||
| out_dtype=pl.FP32, | ||
| ) | ||
| kv_rot_chunk = pl.matmul_acc( | ||
| kv_rot_chunk, | ||
| kv_rot_odd_chunk, | ||
| pl.slice(odd_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), | ||
| odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | ||
| ) | ||
| kv_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] = kv_rot_chunk | ||
|
|
||
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_write"): | ||
| kv = pl.assemble(kv, pl.cast(kv_rot_chunk, target_type=pl.BF16, mode="rint"), [0, NOPE_DIM + rope_col]) | ||
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_write"): | ||
| for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): | ||
| kv_rope_chunk = kv_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] | ||
| kv[:, NOPE_DIM + rope_col : NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(kv_rope_chunk, target_type=pl.BF16, mode="rint") |
There was a problem hiding this comment.
Similar to the Q path, the kv_rope_reassemble and kv_rope_write scopes should be fused to eliminate the intermediate GM tensor kv_rope_fp32 and save GM round-trips.
| kv_rope_fp32 = pl.create_tensor([T, ROPE_DIM], dtype=pl.FP32) | |
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_reassemble"): | |
| for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): | |
| pair_col = rope_col // 2 | |
| kv_rot_even_chunk = pl.slice(kv_rot_even_tmp, [T, ROPE_PAIR_CHUNK], [0, pair_col]) | |
| kv_rot_odd_chunk = pl.slice(kv_rot_odd_tmp, [T, ROPE_PAIR_CHUNK], [0, pair_col]) | |
| kv_rot_even_chunk = kv_rot_even_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | |
| kv_rot_odd_chunk = kv_rot_odd_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | |
| kv_rot_chunk = pl.matmul( | |
| kv_rot_even_chunk, | |
| pl.slice(even_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), | |
| even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | |
| out_dtype=pl.FP32, | |
| ) | |
| kv_rot_chunk = pl.matmul_acc( | |
| kv_rot_chunk, | |
| kv_rot_odd_chunk, | |
| pl.slice(odd_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), | |
| odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | |
| ) | |
| kv_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] = kv_rot_chunk | |
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_write"): | |
| kv = pl.assemble(kv, pl.cast(kv_rot_chunk, target_type=pl.BF16, mode="rint"), [0, NOPE_DIM + rope_col]) | |
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_write"): | |
| for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): | |
| kv_rope_chunk = kv_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] | |
| kv[:, NOPE_DIM + rope_col : NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(kv_rope_chunk, target_type=pl.BF16, mode="rint") | |
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_reassemble_write"): | |
| for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): | |
| pair_col = rope_col // 2 | |
| kv_rot_even_chunk = kv_rot_even_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | |
| kv_rot_odd_chunk = kv_rot_odd_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK] | |
| kv_rot_chunk = pl.matmul( | |
| kv_rot_even_chunk, | |
| even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | |
| out_dtype=pl.FP32, | |
| ) | |
| kv_rot_chunk = pl.matmul_acc( | |
| kv_rot_chunk, | |
| kv_rot_odd_chunk, | |
| odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], | |
| ) | |
| kv[:, NOPE_DIM + rope_col : NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(kv_rot_chunk, target_type=pl.BF16, mode="rint") |
Summary
token_x_fp32GM tensorqr_proj_matmulin the qwen3q_projstyle (create_tensoraccumulator +pl.pipelineK-loop withif-firstmatmul/matmul_acc)chunked_loop_optimizerhintspl.atscopes via an intermediate fp32 tensor; usepl.pipeline(stage=2)pl.slice/pl.assembleto subscript[:]sugarrange(2)tiling to the fusedkv_rmsandq_head_rms_ropescopes so they stay within the Vec buffer at B=64Validated on a2a3 (B=64, T=128): q / kv / qr / qr_scale all PASS.