Split Qwen3 final RMS LM head #331
📝 Walkthrough
This PR extracts the final RMSNorm+LM-head projection into a new …
Changes: Final-Head Extraction and Integration
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request refactors the Qwen3-14B model by extracting the final RMSNorm and LM head projection logic from decode_layer.py into a new dedicated module, rms_lm_head.py. Consequently, the decode_layer function signature has been updated to remove the final head parameters and now accepts a next_hidden tensor. Feedback indicates that the allocation of next_hidden inside the layer loop in decode_fwd.py should be moved outside the loop to avoid potential memory pressure caused by repeated allocations across multiple layers.
        current_hidden = pl.assemble(current_hidden, hidden_chunk, [b0, copy_k0])

    for layer_idx in pl.range(num_layers_actual):
        next_hidden = pl.create_tensor([BATCH, HIDDEN], dtype=pl.BF16)
Allocating next_hidden inside the pl.range loop can lead to multiple tensor allocations if the compiler does not perform buffer reuse optimization. For a model with many layers, this could significantly increase memory pressure. Consider pre-allocating the necessary buffers outside the loop to ensure constant memory usage regardless of the number of layers.
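One way to act on this suggestion is to allocate two hidden-state buffers before the loop and swap their roles after each layer. The sketch below is only an illustration under stated assumptions: plain Python lists stand in for `pl` tensors, and `make_buffer`, `run_layer`, `decode`, and the allocation counter are hypothetical names that do not come from the PR or from the `pl` API.

```python
# Ping-pong buffering sketch: plain Python lists stand in for pl tensors.
# All names here are hypothetical and are not taken from the PR.

allocations = 0


def make_buffer(size):
    """Allocate a zero-filled buffer and count the allocation."""
    global allocations
    allocations += 1
    return [0.0] * size


def run_layer(src, dst, layer_idx):
    """Toy layer: write src + (layer_idx + 1) into dst in place."""
    for i, value in enumerate(src):
        dst[i] = value + (layer_idx + 1)


def decode(hidden_size, num_layers):
    # Both buffers are allocated once, before the layer loop.
    current_hidden = make_buffer(hidden_size)
    next_hidden = make_buffer(hidden_size)
    for layer_idx in range(num_layers):
        run_layer(current_hidden, next_hidden, layer_idx)
        # Swap roles instead of allocating a fresh output buffer.
        current_hidden, next_hidden = next_hidden, current_hidden
    return current_hidden


result = decode(hidden_size=4, num_layers=3)
```

With this layout the number of allocations stays at two regardless of the layer count. Whether the real compiler already reuses the in-loop buffer is the open question the comment raises; the sketch does not settle it.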
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/qwen3/14b/rms_lm_head.py`:
- Around line 75-103: The code computes lm_valid_rows = pl.min(BATCH_TILE,
user_batch - b0) which can be negative for padded tiles and is then passed into
valid_shape causing runtime errors; clamp lm_valid_rows to be non‑negative or
skip the store when it is zero. Locate the LM head tile loop (symbols:
pl.parallel over b0, lm_valid_rows, lm_acc_gm, lm_acc_chunk, lm_acc_trimmed,
out) and either replace lm_valid_rows with a non‑negative value (e.g.,
lm_valid_rows_clamped = max(0, lm_valid_rows)) before using it in the
valid_shape slice, or guard the store/assemble of out with a conditional that
skips assembling lm_acc_trimmed into out when lm_valid_rows <= 0. Ensure the
valid_shape argument never receives a negative row count.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 24a07697-24d8-4125-b7e9-7cb08b093e8b
📒 Files selected for processing (3)
models/qwen3/14b/decode_fwd.py
models/qwen3/14b/decode_layer.py
models/qwen3/14b/rms_lm_head.py
    for b0 in pl.parallel(0, BATCH, BATCH_TILE):
        lm_valid_rows = pl.min(BATCH_TILE, user_batch - b0)
        for ob in pl.parallel(VOCAB // VOCAB_CHUNK):
            lm_o0 = ob * VOCAB_CHUNK
            lm_acc_gm = pl.create_tensor([BATCH_TILE, VOCAB_CHUNK], dtype=pl.FP32)
            with pl.at(level=pl.Level.CORE_GROUP, name_hint="lm_head"):
                lm_hidden_chunk = pl.slice(final_normed, [BATCH_TILE, LM_HEAD_K_CHUNK], [b0, 0])
                lm_weight_chunk = pl.slice(lm_head_weight, [VOCAB_CHUNK, LM_HEAD_K_CHUNK], [lm_o0, 0])
                lm_acc = pl.matmul(lm_hidden_chunk, lm_weight_chunk, out_dtype=pl.FP32, b_trans=True)
                for kb in pl.range(1, HIDDEN // LM_HEAD_K_CHUNK):
                    lm_k0 = kb * LM_HEAD_K_CHUNK
                    lm_hidden_chunk = pl.slice(final_normed, [BATCH_TILE, LM_HEAD_K_CHUNK], [b0, lm_k0])
                    lm_weight_chunk = pl.slice(
                        lm_head_weight,
                        [VOCAB_CHUNK, LM_HEAD_K_CHUNK],
                        [lm_o0, lm_k0],
                    )
                    lm_acc = pl.matmul_acc(lm_acc, lm_hidden_chunk, lm_weight_chunk, b_trans=True)
                lm_acc_gm = pl.assemble(lm_acc_gm, lm_acc, [0, 0])

            with pl.at(level=pl.Level.CORE_GROUP, name_hint="lm_head_store"):
                lm_acc_chunk = pl.slice(lm_acc_gm, [BATCH_TILE, VOCAB_CHUNK], [0, 0])
                lm_acc_trimmed = pl.slice(
                    lm_acc_chunk,
                    [BATCH_TILE, VOCAB_CHUNK],
                    [0, 0],
                    valid_shape=[lm_valid_rows, VOCAB_CHUNK],
                )
                out = pl.assemble(out, lm_acc_trimmed, [b0, lm_o0])
Clamp or skip empty batch tiles before trimming out.
lm_valid_rows = pl.min(BATCH_TILE, user_batch - b0) becomes negative once b0 >= user_batch. That negative value is then passed into valid_shape, so smaller runtime batches can break on the padded tiles even though this kernel is typed for USER_BATCH_DYN. Clamp the row count at zero or skip the store when the tile is empty.
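The arithmetic behind this finding can be checked with ordinary integers. The sketch below is a plain-Python illustration, not `pl` code: `valid_rows` is a hypothetical helper, and the sizes (a padded batch of 16, tiles of 4, a runtime batch of 5) are assumed values chosen only to make the padded tiles visible.

```python
def valid_rows(batch_tile, user_batch, b0):
    """Rows of the tile starting at b0 that hold real (non-padded) data."""
    unclamped = min(batch_tile, user_batch - b0)
    # Clamp at zero so fully padded tiles report an empty row count.
    return max(0, unclamped)


# Assumed illustrative sizes; none of these values come from the PR.
BATCH, BATCH_TILE, user_batch = 16, 4, 5

tile_starts = list(range(0, BATCH, BATCH_TILE))
unclamped_per_tile = [min(BATCH_TILE, user_batch - b0) for b0 in tile_starts]
rows_per_tile = [valid_rows(BATCH_TILE, user_batch, b0) for b0 in tile_starts]

print(unclamped_per_tile)  # [4, 1, -3, -7]
print(rows_per_tile)       # [4, 1, 0, 0]
```

The third and fourth tiles start at or beyond `user_batch`, so the unclamped expression goes negative there. Clamping yields zero rows for those tiles; the alternative mentioned in the comment is to skip the store entirely when the count is zero.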
3dcc285 to e365e4a
Extract the final RMSNorm and LM head projection from decode_layer into rms_lm_head.py. Keep decode_layer as the layer-only inline decode kernel and call rms_lm_head from the single-layer test and once after all layers in decode_fwd.
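The resulting call structure can be sketched in plain Python. This is a toy model under stated assumptions, not the PR's kernels: `toy_layer` is a hypothetical stand-in for `decode_layer`, the `eps` value is assumed, and the LM head weight is taken to have shape `[vocab, hidden]` and to be applied transposed.

```python
import math


def rms_norm(hidden, weight, eps=1e-6):
    """RMSNorm over one hidden vector: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(x * x for x in hidden) / len(hidden)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [x * inv_rms * w for x, w in zip(hidden, weight)]


def lm_head(normed, lm_head_weight):
    """Project to logits; each weight row is one vocabulary entry."""
    return [sum(n * w for n, w in zip(normed, row)) for row in lm_head_weight]


def toy_layer(hidden, layer_idx):
    """Hypothetical stand-in for decode_layer: returns only the next hidden state."""
    return [x + 1.0 for x in hidden]


def decode_forward(hidden, num_layers, norm_weight, lm_head_weight):
    # The layer loop only updates the hidden state.
    for layer_idx in range(num_layers):
        hidden = toy_layer(hidden, layer_idx)
    # The final RMSNorm + LM head runs once, after all layers.
    return lm_head(rms_norm(hidden, norm_weight), lm_head_weight)


logits = decode_forward(
    hidden=[1.0, 2.0],
    num_layers=2,
    norm_weight=[1.0, 1.0],
    lm_head_weight=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
)
```

Keeping the head out of the per-layer function means the layer signature carries no final-norm or LM-head weights, which matches the split described above.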
Summary
- Extract the final RMSNorm and LM head projection from decode_layer.py into inline rms_lm_head.py.
- Keep decode_layer.py as the layer-only inline decode kernel using a typed output hidden-state buffer.
- Call rms_lm_head from the single-layer test and once after all layers complete in decode_fwd.py.
- … test_decode_layer, golden_decode_layer, decode_fwd, and golden_decode_fwd.
Validation
- python -m py_compile models/qwen3/14b/decode_fwd.py models/qwen3/14b/decode_layer.py models/qwen3/14b/rms_lm_head.py
- ruff check --config ruff.toml models/qwen3/14b/decode_fwd.py models/qwen3/14b/decode_layer.py models/qwen3/14b/rms_lm_head.py
- git diff --check -- models/qwen3/14b/decode_fwd.py models/qwen3/14b/decode_layer.py
- python tests/lint/check_headers.py
- python tests/lint/check_english_only.py
Related Issues
N/A