Update spmd2 version of decode_layer.py #335
📝 Walkthrough
This PR refactors the Qwen3-14B single-layer decode implementation from a dynamic JIT-compiled approach to a fixed-shape SPMD program. The new […]
Changes
Qwen3-14B Decode Program Refactoring
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
models/qwen3/14b/decode_layer.py (1)
577-580: ⚡ Quick win. Inconsistent `cur_valid` handling between RMSNorm calls.
`rmsnorm_kernel` is called with the actual `cur_valid` (line 256), but `post_rmsnorm_kernel` is called with a hardcoded `BATCH_TILE` (line 579). This means `post_rmsnorm_kernel` always processes the full tile, including padding rows with potentially uninitialized data.
Consider passing the actual valid row count for consistency:
  with pl.spmd(RMSNORM_SPMD_CORES):
      post_norm_tile = self.post_rmsnorm_kernel(
-         resid1_tile, BATCH_TILE, post_rms_weight, post_norm_tile,
+         resid1_tile, cur_valid, post_rms_weight, post_norm_tile,
      )
Note: This requires adding `cur_valid = pl.min(BATCH_TILE, user_batch - b0)` inside the scope 3 loop or using the existing calculation.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/qwen3/14b/decode_layer.py` around lines 577 - 580, The post_rmsnorm call uses a fixed BATCH_TILE instead of the actual number of valid rows, causing padded rows to be processed; update the call site that invokes post_rmsnorm_kernel (near the scope 3 loop) to pass the actual cur_valid (computed as pl.min(BATCH_TILE, user_batch - b0) or reuse the existing cur_valid used for rmsnorm_kernel) instead of BATCH_TILE so both rmsnorm_kernel and post_rmsnorm_kernel receive the same valid-row count.
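To illustrate what the valid-row count changes, here is a minimal pure-Python sketch of a row-wise RMSNorm that stops at `cur_valid`. It is not the `pl` kernel from this PR: the helper `rmsnorm_rows`, the tile height of 4, the `eps` value, and the sample data are all assumptions made for illustration.

```python
import math


def rmsnorm_rows(tile, weight, cur_valid, eps=1e-6):
    """Normalize only the first `cur_valid` rows; padding rows stay zero."""
    out = [[0.0] * len(weight) for _ in tile]
    for r in range(min(cur_valid, len(tile))):
        row = tile[r]
        mean_sq = sum(x * x for x in row) / len(row)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        out[r] = [x * inv_rms * w for x, w in zip(row, weight)]
    return out


BATCH_TILE = 4  # hypothetical tile height, not the value used in the PR
weight = [1.0, 1.0]
# Two valid rows followed by two padding rows holding stale values.
tile = [[3.0, 4.0], [1.0, 1.0], [float("nan"), 7.0], [9.0, 9.0]]

full = rmsnorm_rows(tile, weight, BATCH_TILE)  # also normalizes the padding rows
valid_only = rmsnorm_rows(tile, weight, 2)     # stops after the valid rows
```

Because RMSNorm is computed per row, the valid rows come out the same either way in this sketch. The difference is in the padding rows: the full-tile call spends work on them and can propagate stale values such as NaN into the output tile, while the `cur_valid` call leaves them untouched.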
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/qwen3/14b/decode_layer.py`:
- Around line 87-92: Replace the hard-coded globals with parameter-derived
values: compute head_dim_inv as 1.0 / head_dim and decode_attn_scale as 1.0 /
sqrt(head_dim) (use math.sqrt) instead of using HEAD_DIM_INV and ATTN_SCALE, and
set max_ctx_blocks = max_blocks_per_seq and ensure the kernel calls/indices use
max_ctx_blocks rather than MAX_BLOCKS_PER_SEQ; update any references in this
module (e.g., where head_dim_inv, decode_attn_scale, max_ctx_blocks are passed
into or referenced by the attention/kernel logic) so the computed values are
used consistently when non-default head_dim or max_blocks_per_seq are provided.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 85413065-92fc-4373-a758-49d205d3c365
📒 Files selected for processing (1)
models/qwen3/14b/decode_layer.py
head_dim_inv = HEAD_DIM_INV
decode_attn_scale = ATTN_SCALE
num_layers_actual = pl.tensor.dim(input_rms_weight, 0)
decode_layer_cache_rows = pl.tensor.dim(k_cache, 0) // num_layers_actual
user_batch = pl.tensor.dim(seq_lens, 0)
batch_padded = BATCH
layer_hidden_base = layer_idx * HIDDEN
layer_inter_base = layer_idx * INTERMEDIATE
layer_cache_base = layer_idx * decode_layer_cache_rows

# Intermediate FP32 tensors between scope 1 and scope 2.
q_proj = pl.create_tensor([BATCH, HIDDEN], dtype=pl.FP32)
k_proj = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32)
v_proj = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32)
q_proj_norm = pl.create_tensor([BATCH, HIDDEN], dtype=pl.FP32)
k_proj_norm = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32)

# Scope 1: input RMSNorm + Q/K/V projection.
# The JIT inline path follows the fixed-BATCH single-layer kernel
# contract, so every matmul tile has a static M dim of BATCH_TILE.
for b0 in pl.parallel(0, batch_padded, BATCH_TILE):
    normed_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.BF16)

    with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"):
        partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0)
        for kb in pl.range(decode_scope1_hidden_blocks):
            sq_k0 = kb * INPUT_PROJ_K_CHUNK
            sq_chunk = pl.cast(

q_per_kv = num_heads // num_kv_heads
q_groups = q_per_kv // Q_HEAD_BATCH
total_q_groups = num_kv_heads * q_groups
max_ctx_blocks = max_blocks_per_seq
Parameter-derived values are computed but not used; globals used instead.
`head_dim_inv` and `decode_attn_scale` are assigned from global constants (`HEAD_DIM_INV`, `ATTN_SCALE`) rather than computed from the `head_dim` parameter. Similarly, `max_ctx_blocks` is computed from parameters, but `MAX_BLOCKS_PER_SEQ` is used in the kernel (e.g., lines 335 and 429). If non-default parameters are passed, these mismatches would cause incorrect attention scaling and block table indexing.
Consider computing these from parameters:
    head_dim_inv = 1.0 / head_dim
    decode_attn_scale = 1.0 / math.sqrt(head_dim)
Or document that only default parameter values are supported.
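A small standalone check can make the mismatch concrete. The sketch below is plain Python, not code from `decode_layer.py`: the default `HEAD_DIM = 128` and the helper `derived_scales` are assumptions for illustration, and only the two formulas come from the comment above.

```python
import math

# Hypothetical defaults standing in for the module-level globals.
HEAD_DIM = 128
HEAD_DIM_INV = 1.0 / HEAD_DIM
ATTN_SCALE = 1.0 / math.sqrt(HEAD_DIM)


def derived_scales(head_dim):
    """Return (head_dim_inv, decode_attn_scale) computed from the argument."""
    return 1.0 / head_dim, 1.0 / math.sqrt(head_dim)


# With the default head_dim, the globals and the derived values agree.
inv_default, scale_default = derived_scales(HEAD_DIM)

# With a non-default head_dim, the globals no longer match.
inv_64, scale_64 = derived_scales(64)
print(scale_64, ATTN_SCALE)
```

With the default `head_dim` the two paths agree, which is presumably why the issue does not show up in the existing tests. Passing `head_dim=64` gives a scale of `0.125`, while the global stays at `1 / sqrt(128)`.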
Code Review
This pull request refactors the Qwen3 decode layer to utilize SPMD kernels for RMSNorm operations, replacing previous implementations. It also updates the test infrastructure to use a new 'run' function and refines the golden reference implementation. A high-severity issue was identified regarding the incorrect handling of partial batches in the 'post_rmsnorm_kernel' call, which must be updated to use the correctly calculated 'cur_valid' value instead of 'BATCH_TILE'.
for b0 in pl.parallel(0, batch_padded, BATCH_TILE):
    resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.FP32)

    for ob in pl.range(decode_q_out_blocks):
        o0 = ob * Q_OUT_CHUNK

        with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj"):
            a_chunk_0 = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, 0])
            w_chunk_0 = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [0, o0])
            o_acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32)
            for kb in pl.range(1, hidden_blocks):
                k0 = kb * K_CHUNK
                a_chunk = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, k0])
                w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0])
                o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk)

        with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj_residual"):
            resid = pl.cast(
                pl.slice(current_hidden, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]),
                target_type=pl.FP32,
            )
            resid_sum = pl.add(o_acc, resid)
            resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0])

    post_norm_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.BF16)
    # SPMD post_rmsnorm (change 2/2)
    with pl.spmd(RMSNORM_SPMD_CORES):
        post_norm_tile = self.post_rmsnorm_kernel(
            resid1_tile, BATCH_TILE, post_rms_weight, post_norm_tile,
        )
The post_rmsnorm_kernel is called with BATCH_TILE as the cur_valid argument, which is incorrect when the number of valid rows in the batch is less than BATCH_TILE. This should be calculated as pl.min(BATCH_TILE, user_batch - b0) to correctly handle partial batches.
for b0 in pl.parallel(0, batch_padded, BATCH_TILE):
    cur_valid = pl.min(BATCH_TILE, user_batch - b0)
    resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.FP32)
    for ob in pl.range(decode_q_out_blocks):
        o0 = ob * Q_OUT_CHUNK
        with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj"):
            a_chunk_0 = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, 0])
            w_chunk_0 = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [0, o0])
            o_acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32)
            for kb in pl.range(1, hidden_blocks):
                k0 = kb * K_CHUNK
                a_chunk = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, k0])
                w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0])
                o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk)
        with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj_residual"):
            resid = pl.cast(
                pl.slice(current_hidden, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]),
                target_type=pl.FP32,
            )
            resid_sum = pl.add(o_acc, resid)
            resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0])
    post_norm_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.BF16)
    # SPMD post_rmsnorm (change 2/2)
    with pl.spmd(RMSNORM_SPMD_CORES):
        post_norm_tile = self.post_rmsnorm_kernel(
            resid1_tile, cur_valid, post_rms_weight, post_norm_tile,
        )
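For reference, the arithmetic behind `min(BATCH_TILE, user_batch - b0)` can be worked through in plain Python. This is only a sketch: the helper `valid_rows_per_tile` and the sizes (6 user rows, padded batch of 16, tile of 4) are made up for illustration, and the clamp at zero is an assumption of the sketch rather than something the suggestion above includes.

```python
def valid_rows_per_tile(user_batch, batch_padded, batch_tile):
    """Return (b0, cur_valid) for each tile of the padded batch."""
    tiles = []
    for b0 in range(0, batch_padded, batch_tile):
        cur_valid = min(batch_tile, user_batch - b0)
        # Tiles that start past user_batch yield a negative value here;
        # clamping to zero is an assumption made by this sketch.
        tiles.append((b0, max(0, cur_valid)))
    return tiles


# Hypothetical sizes: 6 real rows, padded batch of 16, tile height of 4.
tiles = valid_rows_per_tile(user_batch=6, batch_padded=16, batch_tile=4)
for b0, cur_valid in tiles:
    print(f"b0={b0:2d} cur_valid={cur_valid}")
```

The first tile is full, the second is partial with 2 valid rows, and the last two start beyond `user_batch`. Whether the kernel needs an explicit clamp for those tiles depends on how it treats a non-positive row count, which is worth checking when applying the suggestion.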
No description provided.