Skip to content

Optimize Qwen3-14B decode: manual KV-cache/silu deps, pipeline, wider LM head#337

Closed
lyfne123 wants to merge 1 commit into
hw-native-sys:mainfrom
lyfne123:qwen3-decode-manual-dep-pipeline
Closed

Optimize Qwen3-14B decode: manual KV-cache/silu deps, pipeline, wider LM head#337
lyfne123 wants to merge 1 commit into
hw-native-sys:mainfrom
lyfne123:qwen3-decode-manual-dep-pipeline

Conversation

@lyfne123
Copy link
Copy Markdown
Contributor

Summary

  • rope_kv_cache: opt k_cache / v_cache / all_q_padded out of OverlapMap via pl.at(no_dep_args=...) and re-establish the producer fence on qk_matmul / sv_matmul via deps=[rope_tid]. The paged-attention slot_mapping guarantees disjoint per-batch writes that the compiler cannot prove statically; without this hint the 16-batch fan-out serialises into a 16-long chain.
  • silu: mlp_tile uses manual_dep=True so its disjoint-slice writes fan out instead of serialising WAW onto one core; the silu producer TaskIds are collected (silu_tids) and down_proj fences on them via deps=. silu spreads 1 core -> ~14 cores.
  • Pipeline matmul K-reduction loops (stage=2): down_proj and lm_head.
  • VOCAB_CHUNK 64 -> 512: cuts the LM-head projection from 2376 to 297 output blocks (~8x fewer tasks in the scheduler-bound regime) and widens the matmul N dim.

All changes validated on decode_fwd.py --num-layers 2/40 (pass_rate >= 0.98). Builds on the upstream final-head split (#331).

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request optimizes task scheduling in the decode_layer and rms_lm_head modules by explicitly managing task dependencies using no_dep_args, deps, and pl.pipeline to prevent unnecessary serialization. It also increases the VOCAB_CHUNK size in config.py to improve LM-head performance and introduces a module-level MLP_OUT_BLOCKS constant. I have one suggestion to further clean up the code by using this new constant consistently.

Comment thread models/qwen3/14b/decode_layer.py Outdated
w_down_chunk_0 = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [layer_inter_base, d0])
down_acc = pl.matmul(mlp_chunk_0, w_down_chunk_0, out_dtype=pl.FP32)
for ob in pl.range(1, decode_mlp_out_blocks):
for ob in pl.pipeline(1, decode_mlp_out_blocks, stage=2):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency and to reduce redundancy, consider using the module-level constant MLP_OUT_BLOCKS here instead of the local variable decode_mlp_out_blocks.

Using the global constant would allow you to remove the redundant local variable decode_mlp_out_blocks from the decode_layer function, making the code cleaner.

Suggested change
for ob in pl.pipeline(1, decode_mlp_out_blocks, stage=2):
for ob in pl.pipeline(1, MLP_OUT_BLOCKS, stage=2):

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 20, 2026

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR optimizes Qwen3 14B decode kernel task scheduling by expanding the LM-head vocab chunk to 512 and introducing explicit task-level dependencies in attention and MLP stages. The RoPE/KV cache now produces a task handle that gates downstream matmuls, and MLP SiLU computations are individually tracked so that down-projection waits only on their completion.

Changes

Decode kernel task dependency and pipelining

Layer / File(s) Summary
Configuration constants and MLP block sizing
models/qwen3/14b/config.py, models/qwen3/14b/decode_layer.py
Updates VOCAB_CHUNK from 64 to 512 for LM-head vocab tiling, and adds MLP_OUT_BLOCKS compile-time constant to provide a statically acceptable size for pl.array.create in the JIT kernel.
RoPE and attention matmul explicit dependencies
models/qwen3/14b/decode_layer.py
Captures RoPE + KV cache update task completion as rope_tid via no_dep_args on cached buffers, then makes subsequent qk_matmul and sv_matmul operations explicitly depend on that handle to enforce strict ordering.
MLP block staging with SiLU task tracking and gated down-projection
models/qwen3/14b/decode_layer.py
Enables manual dependency control for mlp_tile, captures per-block SiLU task IDs in silu_tids array with no_dep_args regions, and updates down_proj to wait on all SiLU completions before consuming mlp_tile.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#310: Both PRs modify the Qwen3 14B MLP decode "down_proj" stage—main PR changes the task/dependency wiring around MLP/SiLU and down_proj, while the retrieved PR optimizes the down-projection + residual compute paths.
  • hw-native-sys/pypto-lib#231: Both PRs modify the Qwen3 14B decode path's MLP/out-projection and down-projection computation/tiling—main PR by changing task dependencies and adding MLP_OUT_BLOCKS, and retrieved PR by updating decode tiling constants and the PyTorch golden reference.
  • hw-native-sys/pypto-lib#51: Both PRs refactor Qwen3 decode attention dataflow around the RoPE/KV-cache update and downstream matmul/softmax execution order, though the main PR additionally introduces explicit JIT dependency handles.

Suggested labels

enhancement

Poem

🐰 A faster decode path hops along,
With task IDs strong and dependencies long,
RoPE and SiLU now march in place—
No waiting without reason's grace,
Pipelined stages at perfect pace! 🎯

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately summarizes the main changes: manual KV-cache/silu dependency optimization, pipeline improvements, and wider LM head vocab chunking (64→512).
Description check ✅ Passed The description provides detailed technical context for all major changes in the pull request, explaining the rationale and impact of each optimization with specific metrics.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
models/qwen3/14b/decode_layer.py (1)

84-87: ⚡ Quick win

Fail fast if the MLP tiling stops dividing evenly.

MLP_OUT_BLOCKS, decode_mlp_out_blocks, and silu_tids all assume exact INTERMEDIATE / MLP_OUT_CHUNK tiling. If that invariant changes later, the tail channels get skipped rather than rejected.

Suggested guard
 MLP_OUT_BLOCKS = INTERMEDIATE // MLP_OUT_CHUNK
+if INTERMEDIATE % MLP_OUT_CHUNK != 0:
+    raise ValueError("INTERMEDIATE must be divisible by MLP_OUT_CHUNK")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/decode_layer.py` around lines 84 - 87, The tiling math
assumes INTERMEDIATE is divisible by MLP_OUT_CHUNK, but currently silently drops
tail channels; add a fail-fast check: assert INTERMEDIATE % MLP_OUT_CHUNK == 0
(or raise ValueError) before computing the module-level MLP_OUT_BLOCKS and also
add the same sanity check inside the pl.jit function where decode_mlp_out_blocks
and silu_tids are computed so the JITed code will raise if the invariant breaks;
reference the constants MLP_OUT_BLOCKS, decode_mlp_out_blocks, silu_tids,
INTERMEDIATE, and MLP_OUT_CHUNK when adding these guards.
models/qwen3/14b/config.py (1)

64-68: ⚡ Quick win

Guard the new LM-head tiling contract.

rms_lm_head() iterates VOCAB // VOCAB_CHUNK blocks and never handles a remainder, so a future vocab change here would silently drop tail logits. A module-level check would fail fast instead.

Suggested guard
 VOCAB_CHUNK = 512
+if VOCAB % VOCAB_CHUNK != 0:
+    raise ValueError("VOCAB must be divisible by VOCAB_CHUNK")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/config.py` around lines 64 - 68, Add a module-level guard to
ensure the VOCAB and VOCAB_CHUNK tiling contract holds so rms_lm_head() won't
silently drop tail logits: check that VOCAB % VOCAB_CHUNK == 0 at import time
(and raise/abort with a clear message if not). Place this check near the
VOCAB_CHUNK definition in config.py and reference the symbols VOCAB and
VOCAB_CHUNK so future changes fail fast if the remainder would be non-zero.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@models/qwen3/14b/config.py`:
- Around line 64-68: Add a module-level guard to ensure the VOCAB and
VOCAB_CHUNK tiling contract holds so rms_lm_head() won't silently drop tail
logits: check that VOCAB % VOCAB_CHUNK == 0 at import time (and raise/abort with
a clear message if not). Place this check near the VOCAB_CHUNK definition in
config.py and reference the symbols VOCAB and VOCAB_CHUNK so future changes fail
fast if the remainder would be non-zero.

In `@models/qwen3/14b/decode_layer.py`:
- Around line 84-87: The tiling math assumes INTERMEDIATE is divisible by
MLP_OUT_CHUNK, but currently silently drops tail channels; add a fail-fast
check: assert INTERMEDIATE % MLP_OUT_CHUNK == 0 (or raise ValueError) before
computing the module-level MLP_OUT_BLOCKS and also add the same sanity check
inside the pl.jit function where decode_mlp_out_blocks and silu_tids are
computed so the JITed code will raise if the invariant breaks; reference the
constants MLP_OUT_BLOCKS, decode_mlp_out_blocks, silu_tids, INTERMEDIATE, and
MLP_OUT_CHUNK when adding these guards.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: e4dec897-0d80-4d75-8c99-bc2502ab90b0

📥 Commits

Reviewing files that changed from the base of the PR and between 38cd262 and 771e465.

📒 Files selected for processing (3)
  • models/qwen3/14b/config.py
  • models/qwen3/14b/decode_layer.py
  • models/qwen3/14b/rms_lm_head.py

@lyfne123 lyfne123 force-pushed the qwen3-decode-manual-dep-pipeline branch from 771e465 to 68a2460 Compare May 21, 2026 06:36
Per-layer scheduling fixes in decode_layer.py (orchestration-only; no kernel
numerics change):

1. rope_kv_cache: opt k_cache / v_cache / all_q_padded out of OverlapMap via
   pl.at(no_dep_args=...) and re-establish the producer fence on
   qk_matmul / sv_matmul via deps=[rope_tid]. The paged-attention
   slot_mapping guarantees disjoint per-batch writes the compiler cannot
   prove statically; without this the 16-batch fan-out serialises into a
   16-long chain.

2. silu: mlp_tile uses manual_dep=True so its disjoint-slice writes fan out
   instead of serialising WAW onto one core; the silu producer TaskIds are
   collected (silu_tids) and down_proj fences on them via deps=. silu spread
   1 core -> ~14 cores.

Validated on decode_fwd.py --num-layers 2/40 (pass_rate >= 0.98) on a2a3 /
a2a3sim.

Co-Authored-By: Claude Opus 4.7 <[email protected]>
@lyfne123 lyfne123 force-pushed the qwen3-decode-manual-dep-pipeline branch from 68a2460 to b5b02ee Compare May 21, 2026 06:51
@lyfne123 lyfne123 closed this May 21, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant