Optimize Qwen3-14B decode: manual KV-cache/silu deps, pipeline, wider LM head#337
lyfne123 wants to merge 1 commit into
Code Review
This pull request optimizes task scheduling in the decode_layer and rms_lm_head modules by explicitly managing task dependencies using no_dep_args, deps, and pl.pipeline to prevent unnecessary serialization. It also increases the VOCAB_CHUNK size in config.py to improve LM-head performance and introduces a module-level MLP_OUT_BLOCKS constant. I have one suggestion to further clean up the code by using this new constant consistently.
      w_down_chunk_0 = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [layer_inter_base, d0])
      down_acc = pl.matmul(mlp_chunk_0, w_down_chunk_0, out_dtype=pl.FP32)
    - for ob in pl.range(1, decode_mlp_out_blocks):
    + for ob in pl.pipeline(1, decode_mlp_out_blocks, stage=2):
For consistency and to reduce redundancy, consider using the module-level constant MLP_OUT_BLOCKS here instead of the local variable decode_mlp_out_blocks.
Using the global constant would allow you to remove the redundant local variable decode_mlp_out_blocks from the decode_layer function, making the code cleaner.
    - for ob in pl.pipeline(1, decode_mlp_out_blocks, stage=2):
    + for ob in pl.pipeline(1, MLP_OUT_BLOCKS, stage=2):
📝 Walkthrough
This PR optimizes Qwen3 14B decode kernel task scheduling by expanding the LM-head vocab chunk to 512 and introducing explicit task-level dependencies in the attention and MLP stages. The RoPE/KV cache step now produces a task handle that gates the downstream matmuls, and the MLP SiLU computations are tracked individually so that the down-projection waits only on their completion.
Changes
Decode kernel task dependency and pipelining
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
🧹 Nitpick comments (2)
models/qwen3/14b/decode_layer.py (1)
84-87: ⚡ Quick win
Fail fast if the MLP tiling stops dividing evenly.
`MLP_OUT_BLOCKS`, `decode_mlp_out_blocks`, and `silu_tids` all assume exact `INTERMEDIATE / MLP_OUT_CHUNK` tiling. If that invariant changes later, the tail channels get skipped rather than rejected.
Suggested guard
      MLP_OUT_BLOCKS = INTERMEDIATE // MLP_OUT_CHUNK
    + if INTERMEDIATE % MLP_OUT_CHUNK != 0:
    +     raise ValueError("INTERMEDIATE must be divisible by MLP_OUT_CHUNK")
models/qwen3/14b/config.py (1)
64-68: ⚡ Quick win
Guard the new LM-head tiling contract.
`rms_lm_head()` iterates `VOCAB // VOCAB_CHUNK` blocks and never handles a remainder, so a future vocab change here would silently drop tail logits. A module-level check would fail fast instead.
Suggested guard
      VOCAB_CHUNK = 512
    + if VOCAB % VOCAB_CHUNK != 0:
    +     raise ValueError("VOCAB must be divisible by VOCAB_CHUNK")
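Both nitpicks describe the same failure mode: a loop over `total // chunk` full blocks never visits the remainder. A minimal Python sketch of that behaviour and of the proposed fail-fast guard follows. The helper names `num_blocks` and `covered` and the sizes used are hypothetical, chosen for illustration only; they are not taken from `config.py` or `decode_layer.py`.

```python
def num_blocks(total: int, chunk: int) -> int:
    """Return total // chunk, rejecting sizes that would leave a partial tail block."""
    if chunk <= 0:
        raise ValueError("chunk must be positive")
    if total % chunk != 0:
        raise ValueError(f"{total} must be divisible by {chunk}")
    return total // chunk


def covered(total: int, chunk: int) -> int:
    """Elements visited by a loop over total // chunk full blocks (no tail handling)."""
    return (total // chunk) * chunk


# Hypothetical sizes, not the real VOCAB / INTERMEDIATE values.
TOTAL, CHUNK = 1000, 512

# Without a guard, floor division silently leaves the tail unvisited.
dropped = TOTAL - covered(TOTAL, CHUNK)

# With the guard, the same mismatch is rejected up front.
try:
    num_blocks(TOTAL, CHUNK)
    rejected = False
except ValueError:
    rejected = True
```

Placing such a check next to the constant definition means a bad chunk size fails at import time rather than producing truncated output later.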
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@models/qwen3/14b/config.py`:
- Around line 64-68: Add a module-level guard to ensure the VOCAB and
VOCAB_CHUNK tiling contract holds so rms_lm_head() won't silently drop tail
logits: check that VOCAB % VOCAB_CHUNK == 0 at import time (and raise/abort with
a clear message if not). Place this check near the VOCAB_CHUNK definition in
config.py and reference the symbols VOCAB and VOCAB_CHUNK so future changes fail
fast if the remainder would be non-zero.
In `@models/qwen3/14b/decode_layer.py`:
- Around line 84-87: The tiling math assumes INTERMEDIATE is divisible by
MLP_OUT_CHUNK, but currently silently drops tail channels; add a fail-fast
check: assert INTERMEDIATE % MLP_OUT_CHUNK == 0 (or raise ValueError) before
computing the module-level MLP_OUT_BLOCKS and also add the same sanity check
inside the pl.jit function where decode_mlp_out_blocks and silu_tids are
computed so the JITed code will raise if the invariant breaks; reference the
constants MLP_OUT_BLOCKS, decode_mlp_out_blocks, silu_tids, INTERMEDIATE, and
MLP_OUT_CHUNK when adding these guards.
📒 Files selected for processing (3)
models/qwen3/14b/config.py
models/qwen3/14b/decode_layer.py
models/qwen3/14b/rms_lm_head.py
771e465 to 68a2460
Per-layer scheduling fixes in decode_layer.py (orchestration-only; no kernel numerics change):

1. rope_kv_cache: opt k_cache / v_cache / all_q_padded out of OverlapMap via pl.at(no_dep_args=...) and re-establish the producer fence on qk_matmul / sv_matmul via deps=[rope_tid]. The paged-attention slot_mapping guarantees disjoint per-batch writes the compiler cannot prove statically; without this the 16-batch fan-out serialises into a 16-long chain.
2. silu: mlp_tile uses manual_dep=True so its disjoint-slice writes fan out instead of serialising WAW onto one core; the silu producer TaskIds are collected (silu_tids) and down_proj fences on them via deps=. silu spread 1 core -> ~14 cores.

Validated on decode_fwd.py --num-layers 2/40 (pass_rate >= 0.98) on a2a3 / a2a3sim.

Co-Authored-By: Claude Opus 4.7 <[email protected]>
68a2460 to b5b02ee
Summary
- `rope_kv_cache`: opt `k_cache` / `v_cache` / `all_q_padded` out of OverlapMap via `pl.at(no_dep_args=...)` and re-establish the producer fence on `qk_matmul` / `sv_matmul` via `deps=[rope_tid]`. The paged-attention `slot_mapping` guarantees disjoint per-batch writes that the compiler cannot prove statically; without this hint the 16-batch fan-out serialises into a 16-long chain.
- `silu`: `mlp_tile` uses `manual_dep=True` so its disjoint-slice writes fan out instead of serialising WAW onto one core; the silu producer TaskIds are collected (`silu_tids`) and `down_proj` fences on them via `deps=`. silu spreads 1 core -> ~14 cores.
- Pipelining (`stage=2`): `down_proj` and `lm_head`.

All changes validated on `decode_fwd.py --num-layers 2/40` (pass_rate >= 0.98). Builds on the upstream final-head split (#331).
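The serialisation effect described above can be illustrated with a toy dependency graph in plain Python. This is only a sketch of the scheduling idea, not the `pl` API: the task names, the `critical_path` helper, and the way dependencies are modelled are assumptions made for illustration. It compares the longest dependency chain when every writer to a shared buffer conservatively waits on the previous writer (WAW) against the case where the writers are declared independent and the consumer fences on all of them explicitly.

```python
from typing import Dict, List


def critical_path(deps: Dict[str, List[str]]) -> int:
    """Length, in tasks, of the longest dependency chain in a DAG."""
    memo: Dict[str, int] = {}

    def depth(task: str) -> int:
        if task not in memo:
            memo[task] = 1 + max((depth(d) for d in deps[task]), default=0)
        return memo[task]

    return max(depth(t) for t in deps)


BATCH = 16  # the 16-batch fan-out mentioned in the summary

# Conservative tracking: each writer waits on the previous writer to the
# same buffer, so the fan-out collapses into a chain.
conservative = {
    f"write_{b}": ([f"write_{b - 1}"] if b > 0 else []) for b in range(BATCH)
}
conservative["consumer"] = [f"write_{BATCH - 1}"]

# Manual dependencies: writers are independent (their slices are disjoint),
# and the consumer explicitly fences on every producer.
manual = {f"write_{b}": [] for b in range(BATCH)}
manual["consumer"] = [f"write_{b}" for b in range(BATCH)]

chain_len = critical_path(conservative)
fanout_len = critical_path(manual)
```

In this toy model the conservative graph has a critical path of 17 tasks (16 chained writers plus the consumer), while the manual graph has a critical path of 2, which is why the explicit fence is needed once automatic tracking is switched off.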