Split Qwen3 final RMS LM head #331

Merged: zhangqi-chen merged 1 commit into hw-native-sys:main from ndleslx:qwen3-split-rms-lm-head on May 20, 2026
@ndleslx ndleslx commented May 20, 2026

Summary

  • Extract Qwen3 14B final RMSNorm and LM head projection from decode_layer.py into inline rms_lm_head.py.
  • Keep decode_layer.py as the layer-only inline decode kernel using a typed output hidden-state buffer.
  • Call rms_lm_head from the single-layer test and once after all layers complete in decode_fwd.py.
  • Rename entrypoints to test_decode_layer, golden_decode_layer, decode_fwd, and golden_decode_fwd.

Validation

  • python -m py_compile models/qwen3/14b/decode_fwd.py models/qwen3/14b/decode_layer.py models/qwen3/14b/rms_lm_head.py
  • ruff check --config ruff.toml models/qwen3/14b/decode_fwd.py models/qwen3/14b/decode_layer.py models/qwen3/14b/rms_lm_head.py
  • git diff --check -- models/qwen3/14b/decode_fwd.py models/qwen3/14b/decode_layer.py
  • python tests/lint/check_headers.py
  • python tests/lint/check_english_only.py

Related Issues

N/A

coderabbitai Bot commented May 20, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

This PR extracts the final RMSNorm+LM-head projection into a new rms_lm_head kernel, refactors decode_layer to write layer outputs into a caller-provided next_hidden, and updates decode_fwd to call rms_lm_head once after all decoder layers. Tests and exported JIT/golden names are updated accordingly.

Changes

Final-Head Extraction and Integration

| Layer / File(s) | Summary |
| --- | --- |
| New rms_lm_head kernel implementation: models/qwen3/14b/rms_lm_head.py | New module implementing RMSNorm normalization in FP32 over BF16 hidden states, then chunked-matmul LM-head projection with FP32 accumulation, assembling results with sequence-length-based trimming to avoid writing invalid batch rows. |
| decode_layer refactoring to omit final-head logic: models/qwen3/14b/decode_layer.py | decode_layer signature updated to accept caller-provided next_hidden and return it; internal allocation and final RMSNorm+LM-head computation removed; imports, tests, and golden reference names adjusted to the new contract. |
| decode_fwd decode-loop and final-head orchestration: models/qwen3/14b/decode_fwd.py | Decode loop modified to allocate and pass next_hidden to each decode_layer call; rms_lm_head imported and invoked after the loop to compute final logits; JIT entrypoint and golden reference exports renamed. |
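
As a rough illustration of the computation summarized in the table, the following NumPy sketch normalizes in FP32 and then accumulates the LM-head projection over chunks of the hidden dimension. It is not the kernel from rms_lm_head.py: the function name, the `eps` value, the `k_chunk` size, and the use of float32 arrays in place of BF16 are all assumptions made for the example.

```python
import numpy as np


def rms_lm_head_reference(hidden, norm_weight, lm_head_weight, eps=1e-6, k_chunk=4):
    """Toy RMSNorm + LM head. Name, eps and k_chunk are illustrative assumptions."""
    x = hidden.astype(np.float32)  # upcast before normalizing
    inv_rms = 1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    normed = x * inv_rms * norm_weight.astype(np.float32)

    batch, hidden_size = normed.shape
    vocab = lm_head_weight.shape[0]  # weight assumed stored as [vocab, hidden]
    logits = np.zeros((batch, vocab), dtype=np.float32)
    for k0 in range(0, hidden_size, k_chunk):  # accumulate over K chunks in FP32
        a = normed[:, k0:k0 + k_chunk]
        w = lm_head_weight[:, k0:k0 + k_chunk].astype(np.float32)
        logits += a @ w.T  # transposed weight, the role b_trans=True plays in a kernel
    return logits


rng = np.random.default_rng(0)
hidden = rng.standard_normal((2, 8)).astype(np.float32)
norm_weight = rng.standard_normal(8).astype(np.float32)
lm_head_weight = rng.standard_normal((16, 8)).astype(np.float32)
logits = rms_lm_head_reference(hidden, norm_weight, lm_head_weight)
```

Summing the per-chunk products gives the same result as a single full matmul up to floating-point rounding, which is why the chunk size can be tuned independently of correctness.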

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#319: Both PRs refactor the Qwen3-14B decode final-head logic (final RMSNorm + LM-head projection) to apply once after decoder layers rather than per-layer.
  • hw-native-sys/pypto-lib#258: Also introduces/uses an rms_lm_head-style fused path for final-norm/LM-head computation in the decode/generation flow.
  • hw-native-sys/pypto-lib#309: Related work adding standalone final-RMS and LM-head kernel builders that overlap the same final-head computation area.

Suggested labels

enhancement

Poem

🐰 I hopped through layers, tidy and fleet,

Pulled the final head out — clean and neat.
Hidden states marched, then logits unfurled,
One RMSNorm, one head, to light up the world. ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 25.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title 'Split Qwen3 final RMS LM head' clearly and concisely summarizes the main change: extracting the final RMSNorm and LM head projection from decode_layer.py into a new rms_lm_head.py module. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The pull request description clearly explains the changes: extracting final RMSNorm and LM head projection into a new module, keeping decode_layer as layer-only, and calling rms_lm_head appropriately. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the Qwen3-14B model by extracting the final RMSNorm and LM head projection logic from decode_layer.py into a new dedicated module, rms_lm_head.py. Consequently, the decode_layer function signature has been updated to remove the final head parameters and now accepts a next_hidden tensor. Feedback indicates that the allocation of next_hidden inside the layer loop in decode_fwd.py should be moved outside the loop to avoid potential memory pressure caused by repeated allocations across multiple layers.

current_hidden = pl.assemble(current_hidden, hidden_chunk, [b0, copy_k0])

for layer_idx in pl.range(num_layers_actual):
    next_hidden = pl.create_tensor([BATCH, HIDDEN], dtype=pl.BF16)
medium

Allocating next_hidden inside the pl.range loop can lead to multiple tensor allocations if the compiler does not perform buffer reuse optimization. For a model with many layers, this could significantly increase memory pressure. Consider pre-allocating the necessary buffers outside the loop to ensure constant memory usage regardless of the number of layers.
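
One way to picture the suggested pre-allocation is a two-buffer ping-pong scheme, sketched below in plain NumPy. The `toy_layer` function and the toy sizes are hypothetical stand-ins, and whether the same swap pattern can be expressed inside a `pl.range` loop is an assumption not confirmed by this PR.

```python
import numpy as np

BATCH, HIDDEN, NUM_LAYERS = 2, 4, 5  # toy sizes, not the Qwen3-14B values


def toy_layer(src, dst, layer_idx):
    """Hypothetical stand-in for decode_layer: reads src, writes into dst."""
    np.add(src, float(layer_idx + 1), out=dst)
    return dst


# Allocate both buffers once, before the layer loop.
current_hidden = np.zeros((BATCH, HIDDEN), dtype=np.float32)
next_hidden = np.empty((BATCH, HIDDEN), dtype=np.float32)
buffer_ids = {id(current_hidden), id(next_hidden)}

for layer_idx in range(NUM_LAYERS):
    toy_layer(current_hidden, next_hidden, layer_idx)
    # Swap roles: this layer's output becomes the next layer's input.
    current_hidden, next_hidden = next_hidden, current_hidden

final_hidden = current_hidden
```

With this layout the number of live hidden-state buffers stays at two regardless of the layer count, at the cost of requiring that no layer reads its input after it starts overwriting the other buffer.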

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/qwen3/14b/rms_lm_head.py`:
- Around line 75-103: The code computes lm_valid_rows = pl.min(BATCH_TILE,
user_batch - b0) which can be negative for padded tiles and is then passed into
valid_shape causing runtime errors; clamp lm_valid_rows to be non‑negative or
skip the store when it is zero. Locate the LM head tile loop (symbols:
pl.parallel over b0, lm_valid_rows, lm_acc_gm, lm_acc_chunk, lm_acc_trimmed,
out) and either replace lm_valid_rows with a non‑negative value (e.g.,
lm_valid_rows_clamped = max(0, lm_valid_rows)) before using it in the
valid_shape slice, or guard the store/assemble of out with a conditional that
skips assembling lm_acc_trimmed into out when lm_valid_rows <= 0. Ensure the
valid_shape argument never receives a negative row count.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 24a07697-24d8-4125-b7e9-7cb08b093e8b

📥 Commits

Reviewing files that changed from the base of the PR and between 5ed1797 and f03d9a4.

📒 Files selected for processing (3)
  • models/qwen3/14b/decode_fwd.py
  • models/qwen3/14b/decode_layer.py
  • models/qwen3/14b/rms_lm_head.py

Comment on lines +75 to +103
for b0 in pl.parallel(0, BATCH, BATCH_TILE):
    lm_valid_rows = pl.min(BATCH_TILE, user_batch - b0)
    for ob in pl.parallel(VOCAB // VOCAB_CHUNK):
        lm_o0 = ob * VOCAB_CHUNK
        lm_acc_gm = pl.create_tensor([BATCH_TILE, VOCAB_CHUNK], dtype=pl.FP32)
        with pl.at(level=pl.Level.CORE_GROUP, name_hint="lm_head"):
            lm_hidden_chunk = pl.slice(final_normed, [BATCH_TILE, LM_HEAD_K_CHUNK], [b0, 0])
            lm_weight_chunk = pl.slice(lm_head_weight, [VOCAB_CHUNK, LM_HEAD_K_CHUNK], [lm_o0, 0])
            lm_acc = pl.matmul(lm_hidden_chunk, lm_weight_chunk, out_dtype=pl.FP32, b_trans=True)
            for kb in pl.range(1, HIDDEN // LM_HEAD_K_CHUNK):
                lm_k0 = kb * LM_HEAD_K_CHUNK
                lm_hidden_chunk = pl.slice(final_normed, [BATCH_TILE, LM_HEAD_K_CHUNK], [b0, lm_k0])
                lm_weight_chunk = pl.slice(
                    lm_head_weight,
                    [VOCAB_CHUNK, LM_HEAD_K_CHUNK],
                    [lm_o0, lm_k0],
                )
                lm_acc = pl.matmul_acc(lm_acc, lm_hidden_chunk, lm_weight_chunk, b_trans=True)
            lm_acc_gm = pl.assemble(lm_acc_gm, lm_acc, [0, 0])

        with pl.at(level=pl.Level.CORE_GROUP, name_hint="lm_head_store"):
            lm_acc_chunk = pl.slice(lm_acc_gm, [BATCH_TILE, VOCAB_CHUNK], [0, 0])
            lm_acc_trimmed = pl.slice(
                lm_acc_chunk,
                [BATCH_TILE, VOCAB_CHUNK],
                [0, 0],
                valid_shape=[lm_valid_rows, VOCAB_CHUNK],
            )
            out = pl.assemble(out, lm_acc_trimmed, [b0, lm_o0])
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Clamp or skip empty batch tiles before trimming out.

lm_valid_rows = pl.min(BATCH_TILE, user_batch - b0) becomes negative once b0 >= user_batch. That negative value is then passed into valid_shape, so smaller runtime batches can break on the padded tiles even though this kernel is typed for USER_BATCH_DYN. Clamp the row count at zero or skip the store when the tile is empty.
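
The arithmetic behind this finding can be checked in plain Python. The sketch below uses illustrative sizes and a hypothetical helper named `valid_rows`; it only demonstrates the clamp and does not assume any particular API for expressing it in the kernel.

```python
def valid_rows(batch_tile, user_batch, b0):
    """Rows of the tile starting at b0 that hold real requests (never negative)."""
    return max(0, min(batch_tile, user_batch - b0))


BATCH, BATCH_TILE, user_batch = 16, 4, 6  # illustrative sizes only

tile_starts = range(0, BATCH, BATCH_TILE)
unclamped = [min(BATCH_TILE, user_batch - b0) for b0 in tile_starts]
clamped = [valid_rows(BATCH_TILE, user_batch, b0) for b0 in tile_starts]

print(unclamped)  # [4, 2, -2, -6]
print(clamped)    # [4, 2, 0, 0]
```

With a runtime batch of 6 and tiles of 4 rows, the tiles starting at rows 8 and 12 are entirely padding, so the unclamped expression yields negative row counts for them while the clamped version yields zero.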

@ndleslx force-pushed the qwen3-split-rms-lm-head branch 2 times, most recently from 3dcc285 to e365e4a, May 20, 2026 06:40
Extract the final RMSNorm and LM head projection from decode_layer into rms_lm_head.py.

Keep decode_layer as the layer-only inline decode kernel and call rms_lm_head from the single-layer test and once after all layers in decode_fwd.
@zhangqi-chen zhangqi-chen merged commit a83c8ef into hw-native-sys:main May 20, 2026
6 checks passed