Optimize DeepSeek V4 qkv_proj_rope decode (S=2): partial-sum reduces, amax fold, K-tile/stage tuning #339
Restructure scopes to amortise per-task launch overhead and lift Vec/Cube utilisation on decode:
- Split fused attn_norm into serial RMS reduce + parallel ATTN_NORM_GROUP apply; drop the FP32 norm intermediate, write token_x_bf16 directly.
- Split qr_rms apply from RMS reduce and chunk by QR_NORM_GROUP.
- Decouple qr_quant into amax-reduce + parallel apply over Q_LORA chunks.
- Chunk qproj_matmul by Q_PROJ_GROUP and decouple qproj_dequant via a global INT32 col_acc_all staging buffer, letting dequant run at a larger Q_PROJ_DEQUANT_GROUP without slowing matmul.
- Split per-head fused RMS+NOPE+RoPE into q_head_rms_nope + q_head_rope so the RoPE scope stays within the 192KB Vec UB budget at T=128 (S=2).
- Pull q_rope_reassemble/q_rope_write out of the per-head loop and chunk by HEAD_GROUP via a [H*T, ROPE_DIM] pair staging buffer.
- Chunk kv_proj_matmul by KV_PROJ_GROUP.

Tuning constants: ROPE_CHUNK 32->64, Q_PROJ_CHUNK 128->256, QUANT_APPLY_CHUNK 256.

Adds divisibility asserts for each new group.
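The matmul/dequant decoupling described above can be illustrated with a plain-Python sketch. This is not the pypto kernel: the sizes, the scale layout (one scale per token, one per output column) and the names OUT_CHUNK, MATMUL_GROUP and DEQUANT_GROUP are assumptions chosen for illustration. It only shows that staging the INT32 accumulators in a shared buffer lets the dequant pass use a different, larger grouping than the matmul pass without changing the result.

```python
# Hypothetical tiny sizes; the real kernel dimensions are not shown in this PR text.
T, K, N = 2, 8, 16
OUT_CHUNK = 2        # output columns per matmul tile
MATMUL_GROUP = 2     # tiles handled by one matmul task
DEQUANT_GROUP = 4    # tiles handled by one dequant task (larger group)

assert N % OUT_CHUNK == 0
N_BLOCKS = N // OUT_CHUNK
assert N_BLOCKS % MATMUL_GROUP == 0
assert N_BLOCKS % DEQUANT_GROUP == 0

# Deterministic small-integer inputs and made-up scales (assumed layout).
x = [[(3 * t + 5 * k) % 17 - 8 for k in range(K)] for t in range(T)]
w = [[(7 * k + 11 * n) % 13 - 6 for n in range(N)] for k in range(K)]
x_scale = [0.5, 0.25]                            # one scale per token
w_scale = [0.125 * (n + 1) for n in range(N)]    # one scale per output column


def dot_col(t, n):
    """Integer dot product of token row t with weight column n."""
    return sum(x[t][k] * w[k][n] for k in range(K))


# Stage 1: matmul tasks write integer accumulators into the staging buffer.
col_acc_all = [[0] * N for _ in range(T)]
matmul_tasks = 0
for b0 in range(0, N_BLOCKS, MATMUL_GROUP):
    matmul_tasks += 1
    for b in range(b0, b0 + MATMUL_GROUP):
        for n in range(b * OUT_CHUNK, (b + 1) * OUT_CHUNK):
            for t in range(T):
                col_acc_all[t][n] = dot_col(t, n)

# Stage 2: dequant tasks read the staging buffer with their own, larger grouping.
out = [[0.0] * N for _ in range(T)]
dequant_tasks = 0
for b0 in range(0, N_BLOCKS, DEQUANT_GROUP):
    dequant_tasks += 1
    for n in range(b0 * OUT_CHUNK, (b0 + DEQUANT_GROUP) * OUT_CHUNK):
        for t in range(T):
            out[t][n] = col_acc_all[t][n] * x_scale[t] * w_scale[n]

# Reference: the fused matmul + dequant computed in one pass.
fused = [[dot_col(t, n) * x_scale[t] * w_scale[n] for n in range(N)] for t in range(T)]
```

With these toy sizes the matmul pass issues four tasks and the dequant pass two, and both paths produce the same values.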
📝 Walkthrough
This PR refactors the DeepSeek-V4 fused single-token QKV attention pipeline by decoupling compute stages, updating tiling/grouping parameters, and reorganizing RMSNorm and RoPE into separate grouped passes with intermediate staging buffers. Quantization and Q-projection matmul/dequant are split into independent stages; per-head RoPE is reassembled via cross-head staging; and KV RoPE reassembly uses a full FP32 buffer.
Changes: DeepSeek-V4 QKV Projection and RoPE Pipeline Refactoring
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/deepseek/v4/qkv_proj_rope.py`:
- Around line 49-52: Add direct guards that H is divisible by grouping factors:
insert assert H % HEAD_GROUP == 0 with a clear message near the existing
HEAD_BLOCKS check so the grouped q-RoPE loops (pl.parallel(0, H, HEAD_GROUP) /
pl.range(HEAD_GROUP)) cannot index past H; likewise add assert H % Q_PROJ_GROUP
== 0 next to the Q_PROJ_GROUP/Q_PROJ_OUT_CHUNK assertions (and repeat the same
explicit H % GROUP checks in the other similar block around lines 277-318) so
grouped loop stride invariants are enforced explicitly.
- Line 48: Add explicit divisibility guards for the new Q_LORA tiling
assumptions: assert Q_LORA % Q_PROJ_CHUNK == 0 (so Q_PROJ_BLOCKS = Q_LORA //
Q_PROJ_CHUNK won't drop a tail), assert Q_LORA % QUANT_CHUNK == 0 (so inner
quant tiles don't run past range), and assert (Q_LORA // QUANT_CHUNK) %
QUANT_APPLY_CHUNK == 0 (so each outer group contains a full QUANT_APPLY_CHUNK of
QUANT_CHUNK tiles). Place these grouped asserts near the top of the module
(around the existing QUANT_APPLY_CHUNK definition) and/or at the start of
qproj_matmul and qr_quant_apply to clearly document and enforce the contracts
referenced in lines ~48, 174-181 and 202-217.
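A small sketch of why these guards matter, written as ordinary Python rather than module-level asserts. The helper name `tiling_violations` and the numeric values are hypothetical; the three checks mirror the contracts listed in the comment above (with QUANT_APPLY_CHUNK counted in QUANT_CHUNK tiles, as the comment phrases it).

```python
def tiling_violations(q_lora, q_proj_chunk, quant_chunk, quant_apply_chunk):
    """Return a list of human-readable tiling contract violations (empty if none)."""
    problems = []
    if q_lora % q_proj_chunk != 0:
        # Floor division silently drops the columns past the last full chunk.
        covered = (q_lora // q_proj_chunk) * q_proj_chunk
        problems.append(f"Q_PROJ_CHUNK drops a tail of {q_lora - covered} columns")
    if q_lora % quant_chunk != 0:
        problems.append("QUANT_CHUNK does not divide Q_LORA")
    elif (q_lora // quant_chunk) % quant_apply_chunk != 0:
        problems.append("QUANT_APPLY_CHUNK does not divide the QUANT_CHUNK tile count")
    return problems


# Hypothetical configurations, not the values used in qkv_proj_rope.py.
good = tiling_violations(q_lora=1024, q_proj_chunk=256, quant_chunk=32, quant_apply_chunk=8)
bad = tiling_violations(q_lora=1000, q_proj_chunk=256, quant_chunk=32, quant_apply_chunk=8)
```

In the second call, `1000 // 256` is 3, so a loop over three blocks would cover only 768 columns and skip the remaining 232 without raising any error; that silent tail is what an explicit assert turns into a load-time failure.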
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 9d300a81-7e87-4c00-b76e-18e22fc6e355
📒 Files selected for processing (1)
models/deepseek/v4/qkv_proj_rope.py
assert (H * HEAD_DIM) % (HEAD_CHUNK * HEAD_GROUP) == 0, \
    "HEAD_BLOCKS must be divisible by HEAD_GROUP"
assert ((H * HEAD_DIM) // Q_PROJ_OUT_CHUNK) % Q_PROJ_GROUP == 0, \
    "Q_PROJ_HEAD_BLOCKS must be divisible by Q_PROJ_GROUP"
🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win
Guard HEAD_GROUP against H directly.
The grouped q-RoPE loops step in head units (pl.parallel(0, H, HEAD_GROUP) + pl.range(HEAD_GROUP)), so the invariant they need is H % HEAD_GROUP == 0. The current HEAD_BLOCKS assert is only an indirect proxy and can pass for configs where the last h_inner still indexes past H.
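The gap between the two invariants can be checked with a toy configuration. The values below are hypothetical, and the grouped loop is modeled with Python's `range`, on the assumption that `pl.parallel(0, H, HEAD_GROUP)` steps the same way.

```python
def heads_touched(h, head_group):
    """Head indices visited by an outer loop stepping by head_group
    with an inner loop of head_group iterations."""
    return [h0 + h_inner
            for h0 in range(0, h, head_group)
            for h_inner in range(head_group)]


# Hypothetical config: the indirect HEAD_BLOCKS check passes, the direct one fails.
H, HEAD_DIM, HEAD_CHUNK, HEAD_GROUP = 6, 512, 128, 4

indirect_ok = (H * HEAD_DIM) % (HEAD_CHUNK * HEAD_GROUP) == 0   # existing assert
direct_ok = H % HEAD_GROUP == 0                                 # proposed assert
out_of_range = [h for h in heads_touched(H, HEAD_GROUP) if h >= H]
```

Here the existing check accepts the config while the last group still visits heads 6 and 7, which do not exist when H is 6.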
Suggested guard addition
+assert H % HEAD_GROUP == 0, \
+ "H must be divisible by HEAD_GROUP"
 assert (H * HEAD_DIM) % (HEAD_CHUNK * HEAD_GROUP) == 0, \
     "HEAD_BLOCKS must be divisible by HEAD_GROUP"
Also applies to: 277-318
Code Review
This pull request introduces significant performance optimizations to the qkv_proj_rope module by increasing chunk sizes and implementing task-grouping for normalization, projection, and quantization stages. Key changes include the decoupling of query projection dequantization into a separate parallel pass and the use of staging buffers to manage memory constraints and reduce launch overhead. Review feedback suggests further refining the placement of tensor initializations within loops to prevent potential compiler-induced state dependencies and simplifying redundant loops where chunk sizes now match full dimensions.
# Pre-declared outside pl.range to satisfy pypto's loop-carried init_values
# threading. The dummy value is overwritten by pl.matmul on the first iter.
col_acc = pl.create_tensor([T, Q_PROJ_OUT_CHUNK], dtype=pl.INT32)
The col_acc tensor is created outside the h_inner loop but inside the pl.parallel scope. While the logic correctly resets col_acc on the first iteration of the qb loop (when qr_proj_col0 == 0), reusing the same tensor object across serial iterations of h_inner can sometimes lead to unexpected loop-carried state dependencies in the PyPTO compiler if not explicitly handled. Since this is intended as a per-head-block scratch buffer, consider moving the creation inside the h_inner loop or ensuring the compiler correctly hoists it as a task-local scratch.
wkv_chunk = wkv[d0 : d0 + D_CHUNK, kv_col0 : kv_col0 + KV_CHUNK]
kv_acc = pl.matmul_acc(kv_acc, kv_x_chunk_bf16, wkv_chunk)
kv_fp32[:, kv_col0 : kv_col0 + KV_CHUNK] = kv_acc
kv_acc = pl.create_tensor([T, KV_CHUNK], dtype=pl.FP32)
Similar to the qproj_matmul scope, kv_acc is created outside the k_inner loop. Although it is correctly initialized via pl.matmul when d0 == 0, reusing the tensor across serial iterations of k_inner might trigger loop-carried state analysis. Moving the creation inside the k_inner loop would make the task-local nature of the buffer more explicit to the optimizer.
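The pattern both comments discuss, an accumulator declared once with a dummy value, overwritten on the first K iteration and accumulated afterwards, can be sketched in plain Python. The `matmul` and `matmul_acc` helpers below are list-based stand-ins written for this example, not the pypto `pl.matmul` / `pl.matmul_acc` calls, and the sizes are hypothetical.

```python
def matmul(a, b):
    """Plain list-of-lists matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def matmul_acc(acc, a, b):
    """Return acc + a @ b, leaving the inputs untouched."""
    prod = matmul(a, b)
    return [[acc[i][j] + prod[i][j] for j in range(len(prod[0]))]
            for i in range(len(prod))]


def cols(m, c0, c1):
    """Column slice m[:, c0:c1]."""
    return [row[c0:c1] for row in m]


T, D, D_CHUNK, N, N_CHUNK = 2, 8, 4, 4, 2      # hypothetical sizes
x = [[t + d for d in range(D)] for t in range(T)]
w = [[d - n for n in range(N)] for d in range(D)]

out = [[None] * N for _ in range(T)]
acc = [[999] * N_CHUNK for _ in range(T)]      # dummy value, declared once outside the loops
for n0 in range(0, N, N_CHUNK):                # serial loop over output column blocks
    for d0 in range(0, D, D_CHUNK):            # K-loop
        a = cols(x, d0, d0 + D_CHUNK)
        b = cols(w[d0:d0 + D_CHUNK], n0, n0 + N_CHUNK)
        if d0 == 0:
            acc = matmul(a, b)                 # first K tile overwrites stale contents
        else:
            acc = matmul_acc(acc, a, b)        # later K tiles accumulate
    for t in range(T):
        out[t][n0:n0 + N_CHUNK] = acc[t]
```

Numerically the result matches a full matmul because the first K tile always overwrites whatever the previous column block left behind; the review's concern is about how the compiler models that reuse, which this sketch does not capture.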
kv_rope_full = pl.create_tensor([T, ROPE_DIM], dtype=pl.FP32)
with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_reassemble"):
    for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2):
    for rope_col in pl.range(0, ROPE_DIM, ROPE_CHUNK):
The kv_rope_reassemble scope still uses a pl.range loop over ROPE_CHUNK, even though ROPE_CHUNK is now equal to ROPE_DIM (64). This makes the loop run only once. For consistency with the optimized q_rope_reassemble (lines 299-313), which processes the full ROPE_DIM at once, you could remove this loop and the associated slicing to simplify the code.
bf43f45 to 6cfd357
…es, amax fold, K-tile/stage tuning

S=2 (T=128) wall-clock 3-run median: 624us -> 545us (-12.6%). Cumulative since pre-tuning baseline: 1868us -> 545us (-70.8%).

Five accepted optimizations on top of the existing grouped-chunking endpoint:
- Opt S: attn_norm_rms serial -> 2-way parallel partial sum + final reduce (chunked_loop_optimizer required on the partial scope to fit Vec UB; PARTIALS=2 keeps FP32 add deterministic across devices)
- Opt T: fold qr_quant_amax into qr_norm_apply via a per-task partial-amax buffer; the residual qr_quant_amax scope shrinks 30us -> 1.7us. Partial amax is computed on qr_normed_bf16 (bit-identical to the original GM round-trip) so the INT8 quant scale is unchanged
- Opt U: qr_rms 2-way parallel partial sum (mirrors Opt S; qr_fp32 is already FP32 so no chunked_loop_optimizer cast-split cost)
- Opt V: Q_PROJ_CHUNK 256 -> 512 (K-tile only; N-tile already known to trigger ACL_ERROR_RT_AICORE_TIMEOUT per prior Opt B). qproj_matmul per-task Exec 74us -> 56us (-25%)
- Opt X: qr_proj_matmul / kv_proj_matmul K-loop pl.pipeline stage 2 -> 4 (D_BLOCKS=32 has enough iters for 4-deep ping-pong)

Side effects from renaming a few Python local names (d0 -> rms_d0 / apply_d0 / qr_d0, qr_chunk -> qr_norm_chunk) to break the implicit pypto AST init_values chain that links same-named locals across sibling scopes.

Validation PASS on q / kv / qr / qr_scale across 3 runs on 3 different devices.
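One plausible reading of the PARTIALS=2 determinism note is that a final reduce over two partials is a single FP32 add, which is commutative, whereas three or more partials make the result depend on the order in which they are combined. The sketch below emulates FP32 rounding with `struct`; the row data, the helper names and the interpretation itself are assumptions, not taken from the kernel.

```python
import struct


def f32(x):
    """Round a Python float to the nearest IEEE-754 binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def f32_sum(values):
    """Serial left-to-right sum with FP32 rounding after every add."""
    acc = 0.0
    for v in values:
        acc = f32(acc + v)
    return acc


def partial_sums_of_squares(row, partials):
    """Split the row into equal slices and reduce each slice independently."""
    assert len(row) % partials == 0
    step = len(row) // partials
    return [f32_sum(f32(v * v) for v in row[i:i + step])
            for i in range(0, len(row), step)]


row = [f32(0.1 * (i % 7) - 0.3) for i in range(64)]   # made-up activations

# Two partials: the final reduce is one add, so operand order cannot matter.
p0, p1 = partial_sums_of_squares(row, 2)
two_way = f32(p0 + p1)
swapped = f32(p1 + p0)

# Three values: FP32 addition is not associative, so grouping changes the result.
a, b, c = f32(1e8), f32(-1e8), f32(1.0)
left_first = f32(f32(a + b) + c)
right_first = f32(a + f32(b + c))
```

The two-way result is identical under either operand order, while the three-value case gives different answers depending on grouping, so a reduce over more partials would need a fixed combine order to stay reproducible.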
Summary
Continues the grouped-chunking series (Opt J-P) with five new optimizations targeting residual serial AIV reduces and cube K-loop tuning. S=2 (T=128) wall-clock 3-run median 624us → 545us (−12.6%); cumulative vs pre-tuning baseline 1868us → 545us (−70.8%).
New optimizations (on top of the prior Opt A+B+E+G+J+K+L+M-revert+N+O+P endpoint)
- Opt S: attn_norm_rms serial reduce → 2-way parallel partial-sum + final reduce. chunked_loop_optimizer is required on the partial scope to fit the 192KB Vec UB at S=2. PARTIALS=2 (not 4+) is intentional: it keeps the FP32 add deterministic, preserving q validation across devices.
- Opt T: fold qr_quant_amax into qr_norm_apply. Each qr_norm_apply task additionally writes a per-task partial amax to qr_amax_partial[Q_BLOCKS, T]; the residual qr_quant_amax scope shrinks from 30us → 1.7us. Partial amax is computed on qr_normed_bf16, which is bit-identical to the value the original scope would have re-read via GM, so the INT8 quant scale is unchanged (qr validation atol=1 holds).
- Opt U: qr_rms 2-way parallel partial-sum (mirrors Opt S). qr_fp32 is already FP32 so the inner loop is cast-free; the chunked_loop_optimizer cost from Opt S doesn't apply.
- Opt V: Q_PROJ_CHUNK 256 → 512 (K-tile). qproj_matmul per-task Exec 74us → 56us (−25%). N-tile (Q_PROJ_OUT_CHUNK) is intentionally left at 128: the prior Opt B run on this kernel confirmed Q_PROJ_OUT_CHUNK=256 triggers ACL_ERROR_RT_AICORE_TIMEOUT (CANN template limit).
- Opt X: qr_proj_matmul / kv_proj_matmul K-loop pl.pipeline(stage=2 → stage=4) per docs/performance-tuning.md Part 2 §1. Both have D_BLOCKS=32, enough iter count for 4-deep ping-pong. qproj_matmul skipped (post-Opt V it has Q_PROJ_BLOCKS=2, too few iters).

Incidental cleanup
- Renamed Python locals to break the implicit pypto AST init_values chain across sibling scopes: d0 → rms_d0 / apply_d0 / qr_d0, qr_chunk → qr_norm_chunk. Without this, removing the original serial attn_norm_rms scope broke downstream scopes with "Variable 'd0_inlineNN' used outside its defining scope" SSA errors.

Test plan
- q / kv / qr / qr_scale validation PASS at documented tolerances
- python models/deepseek/v4/qkv_proj_rope.py -p a2a3 --enable-l2-swimlane (S=2, T=128, FLASH config)
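The amax fold from the summary relies on max being exactly decomposable: a per-token maximum over per-block maxima equals the maximum over the whole row, with no rounding involved. A minimal plain-Python sketch follows, with hypothetical sizes and made-up data; only the buffer name and its [Q_BLOCKS, T] shape come from the PR text.

```python
T, Q_LORA, CHUNK = 3, 12, 4            # hypothetical sizes
assert Q_LORA % CHUNK == 0
Q_BLOCKS = Q_LORA // CHUNK

# Stand-in for the normalized activations, one row per token.
normed = [[((5 * t + 3 * c) % 11 - 5) / 4 for c in range(Q_LORA)] for t in range(T)]

# Fused pass: each apply task also records its per-token partial amax.
qr_amax_partial = [[0.0] * T for _ in range(Q_BLOCKS)]
for qb in range(Q_BLOCKS):
    c0 = qb * CHUNK
    for t in range(T):
        qr_amax_partial[qb][t] = max(abs(v) for v in normed[t][c0:c0 + CHUNK])

# Residual reduce: only Q_BLOCKS values per token remain to be combined.
amax = [max(qr_amax_partial[qb][t] for qb in range(Q_BLOCKS)) for t in range(T)]

# Reference: a separate full pass over every column.
reference = [max(abs(v) for v in normed[t]) for t in range(T)]
```

The residual reduce reads Q_BLOCKS values per token instead of Q_LORA, which is consistent with the much smaller remaining qr_quant_amax scope reported above.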