Update spmd2 version of decode_layer.py #335

Open
xzhxzhxzh123 wants to merge 1 commit into hw-native-sys:main from xzhxzhxzh123:main1

Conversation

@xzhxzhxzh123
Collaborator

No description provided.

@coderabbitai

coderabbitai Bot commented May 20, 2026

📝 Walkthrough

This PR refactors the Qwen3-14B single-layer decode implementation from a dynamic JIT-compiled approach to a fixed-shape SPMD program. The new build_qwen3_decode_program() builder defines a complete decode pipeline with embedded RMSNorm kernels, inlining the prior multi-call structure into a single opaque kernel. The PyTorch golden reference and test harness are updated to match the new program structure and parameters.

Changes

Qwen3-14B Decode Program Refactoring

| Layer / File(s) | Summary |
| --- | --- |
| Decode program builder and kernel definitions (models/qwen3/14b/decode_layer.py) | New build_qwen3_decode_program() factory creates a pl.program class with two RMSNorm InCore kernels and a full qwen3_decode opaque kernel. The kernel consolidates input RMSNorm + QKV projection (scope1), grouped attention with RoPE and online-softmax accumulation (scope2), and output projection + post-RMSNorm + MLP with SiLU + final-RMSNorm + LM-head (scope3) into a single compute unit. Module docstring updated to describe the SPMD structure. |
| Golden reference implementation (models/qwen3/14b/decode_layer.py) | PyTorch reference function renamed from golden_decode_layer() to golden_qwen3_decode() with rewritten computation flow matching the new program: input assembly and RMSNorm, QKV + per-head normalization, RoPE + grouped attention with online softmax and KV cache operations, output projection + residual, post-RMSNorm, MLP using F.silu directly, final residual + final RMSNorm + LM head projection. |
| Test harness integration (models/qwen3/14b/decode_layer.py) | __main__ runner switched from run_jit(fn=test_decode_layer, golden_fn=golden_decode_layer, ...) to run(program=build_qwen3_decode_program(...), specs=build_tensor_specs(...), golden_fn=golden_qwen3_decode, ...), removing the old JIT wrapper and wiring constant program parameters directly. |
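
The "online-softmax accumulation" named for scope2 can be illustrated on its own. The sketch below is plain Python rather than the PR's `pl` kernel code; the function names and the use of scalar values (instead of per-head vectors) are assumptions made to keep it short. It shows the general technique: stream over blocks of attention scores while keeping a running maximum, a running denominator, and an accumulator that is rescaled whenever the maximum changes.

```python
import math


def online_softmax_weighted_sum(score_blocks, value_blocks):
    """Compute sum(softmax(scores) * values) one block at a time.

    Hypothetical helper for illustration; not part of decode_layer.py.
    """
    running_max = -math.inf   # largest score seen so far
    running_sum = 0.0         # softmax denominator relative to running_max
    acc = 0.0                 # weighted-value numerator relative to running_max
    for scores, values in zip(score_blocks, value_blocks):
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new reference maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc *= correction
        for s, v in zip(scores, values):
            p = math.exp(s - new_max)
            running_sum += p
            acc += p * v
        running_max = new_max
    return acc / running_sum


def two_pass_reference(scores, values):
    """Ordinary softmax-weighted sum over the full score list."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


score_blocks = [[0.5, -1.0], [2.0, 0.25], [-3.0, 1.5]]
value_blocks = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
flat_scores = [s for block in score_blocks for s in block]
flat_values = [v for block in value_blocks for v in block]

streamed = online_softmax_weighted_sum(score_blocks, value_blocks)
reference = two_pass_reference(flat_scores, flat_values)
```

Because the accumulator is rescaled per block, the streamed result should agree with the two-pass reference up to floating-point rounding, which is the property a golden function would check against the kernel.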

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#318: Implements Qwen3-14B single-layer decode forward path with build_qwen3_decode_program and golden_qwen3_decode, directly connected at the decode-program level.
  • hw-native-sys/pypto-lib#319: Refactors Qwen3-14B decode to fuse final RMSNorm and LM head into decode output path using final_norm_weight and lm_head_weight.
  • hw-native-sys/pypto-lib#331: Restructures decode pipeline by moving final RMSNorm+LM-head logic from decode_layer.py into separate rms_lm_head.py, contrasting with this PR's inline integration.

Suggested labels

enhancement

Poem

🐰 The decode pipeline hops through SPMD's door,
RMSNorm and attention dance once more,
SiLU gates glow as logits take flight,
From JIT's embrace to fixed-shape might!
One kernel now holds the whole forward light ✨

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 16.67% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Description check | ❓ Inconclusive | No description was provided by the author, making it impossible to assess relevance to the changeset. | Add a description explaining the motivation, changes, and impact of the decode_layer.py update for better context. |
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately describes the main change: updating the SPMD2 version of decode_layer.py with new program builder and kernel implementations. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
models/qwen3/14b/decode_layer.py (1)

577-580: ⚡ Quick win

Inconsistent cur_valid handling between RMSNorm calls.

rmsnorm_kernel is called with actual cur_valid (line 256), but post_rmsnorm_kernel is called with hardcoded BATCH_TILE (line 579). This means post_rmsnorm_kernel always processes the full tile, including padding rows with potentially uninitialized data.

Consider passing the actual valid row count for consistency:

                with pl.spmd(RMSNORM_SPMD_CORES):
                    post_norm_tile = self.post_rmsnorm_kernel(
-                       resid1_tile, BATCH_TILE, post_rms_weight, post_norm_tile,
+                       resid1_tile, cur_valid, post_rms_weight, post_norm_tile,
                    )

Note: This requires adding cur_valid = pl.min(BATCH_TILE, user_batch - b0) inside the scope 3 loop or using the existing calculation.
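
To make the padding concern concrete, here is a minimal plain-Python sketch of the valid-row arithmetic. It does not use the `pl` API; `tile_valid_rows`, `rmsnorm_rows`, the tile size of 4, and the `eps` value are all hypothetical stand-ins. The clamp to zero is an extra safeguard added here for tiles that contain only padding, and is not part of the suggested change.

```python
import math

BATCH_TILE = 4  # hypothetical tile size, chosen only for this example


def tile_valid_rows(user_batch, batch_padded, batch_tile=BATCH_TILE):
    """Return (b0, cur_valid) for each batch tile.

    cur_valid mirrors min(BATCH_TILE, user_batch - b0) from the review,
    clamped at zero (an assumption) for tiles that hold only padding.
    """
    return [
        (b0, max(0, min(batch_tile, user_batch - b0)))
        for b0 in range(0, batch_padded, batch_tile)
    ]


def rmsnorm_rows(tile, cur_valid, weight, eps=1e-6):
    """RMSNorm over the first cur_valid rows; padded rows stay zero."""
    out = [[0.0] * len(weight) for _ in tile]
    for r in range(cur_valid):
        row = tile[r]
        inv_rms = 1.0 / math.sqrt(sum(x * x for x in row) / len(row) + eps)
        out[r] = [x * inv_rms * w for x, w in zip(row, weight)]
    return out


# Six real rows padded to eight: the second tile has only two valid rows.
tiles = tile_valid_rows(user_batch=6, batch_padded=8)

# A padded row holding garbage (NaN here) is never read when cur_valid = 1.
normed = rmsnorm_rows(
    [[3.0, 4.0], [float("nan"), float("nan")]], cur_valid=1, weight=[1.0, 1.0]
)
```

Passing the full tile size instead of `cur_valid` would make the second row's NaN values flow into the output, which is the behavior the nitpick is warning about.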

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/decode_layer.py` around lines 577 - 580, The post_rmsnorm
call uses a fixed BATCH_TILE instead of the actual number of valid rows, causing
padded rows to be processed; update the call site that invokes
post_rmsnorm_kernel (near the scope 3 loop) to pass the actual cur_valid
(computed as pl.min(BATCH_TILE, user_batch - b0) or reuse the existing cur_valid
used for rmsnorm_kernel) instead of BATCH_TILE so both rmsnorm_kernel and
post_rmsnorm_kernel receive the same valid-row count.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 85413065-92fc-4373-a758-49d205d3c365

📥 Commits

Reviewing files that changed from the base of the PR and between f0d8ce8 and af15624.

📒 Files selected for processing (1)
  • models/qwen3/14b/decode_layer.py

Comment on lines 87 to +92
            head_dim_inv = HEAD_DIM_INV
            decode_attn_scale = ATTN_SCALE
            num_layers_actual = pl.tensor.dim(input_rms_weight, 0)
            decode_layer_cache_rows = pl.tensor.dim(k_cache, 0) // num_layers_actual
            user_batch = pl.tensor.dim(seq_lens, 0)
            batch_padded = BATCH
            layer_hidden_base = layer_idx * HIDDEN
            layer_inter_base = layer_idx * INTERMEDIATE
            layer_cache_base = layer_idx * decode_layer_cache_rows

            # Intermediate FP32 tensors between scope 1 and scope 2.
            q_proj = pl.create_tensor([BATCH, HIDDEN], dtype=pl.FP32)
            k_proj = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32)
            v_proj = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32)
            q_proj_norm = pl.create_tensor([BATCH, HIDDEN], dtype=pl.FP32)
            k_proj_norm = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32)

            # Scope 1: input RMSNorm + Q/K/V projection.
            # The JIT inline path follows the fixed-BATCH single-layer kernel
            # contract, so every matmul tile has a static M dim of BATCH_TILE.
            for b0 in pl.parallel(0, batch_padded, BATCH_TILE):
                normed_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.BF16)

                with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"):
                    partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0)
                    for kb in pl.range(decode_scope1_hidden_blocks):
                        sq_k0 = kb * INPUT_PROJ_K_CHUNK
                        sq_chunk = pl.cast(

            q_per_kv = num_heads // num_kv_heads
            q_groups = q_per_kv // Q_HEAD_BATCH
            total_q_groups = num_kv_heads * q_groups
            max_ctx_blocks = max_blocks_per_seq
⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Parameter-derived values are bypassed; global constants are used instead.

head_dim_inv and decode_attn_scale are assigned from global constants (HEAD_DIM_INV, ATTN_SCALE) rather than computed from the head_dim parameter. Similarly, max_ctx_blocks is computed from max_blocks_per_seq, but the kernel uses MAX_BLOCKS_PER_SEQ directly (for example at lines 335 and 429). If non-default parameters are passed, these mismatches would cause incorrect attention scaling and block-table indexing.

Consider computing these from parameters:

head_dim_inv = 1.0 / head_dim
decode_attn_scale = 1.0 / math.sqrt(head_dim)

Or document that only default parameter values are supported.
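
The size of the mismatch is easy to see outside the kernel. The following is a minimal sketch in plain Python; `attention_scales`, `stale_scale_ratio`, and the default head dimension of 128 are assumptions for illustration, not values read from decode_layer.py.

```python
import math


def attention_scales(head_dim):
    """Derive (head_dim_inv, attn_scale) from the head_dim parameter."""
    if head_dim <= 0:
        raise ValueError("head_dim must be positive")
    return 1.0 / head_dim, 1.0 / math.sqrt(head_dim)


# Assumed default, standing in for the module-level globals.
DEFAULT_HEAD_DIM = 128
HEAD_DIM_INV, ATTN_SCALE = attention_scales(DEFAULT_HEAD_DIM)


def stale_scale_ratio(head_dim):
    """Ratio of the global ATTN_SCALE to the scale head_dim actually needs."""
    _, correct_scale = attention_scales(head_dim)
    return ATTN_SCALE / correct_scale


ratio_default = stale_scale_ratio(128)  # globals match the parameter
ratio_small = stale_scale_ratio(64)     # globals are stale for head_dim=64
```

With the assumed default the ratio is 1.0, so nothing looks wrong in the default test. For `head_dim=64` the global scale is too small by a factor of `sqrt(64 / 128)`, so every attention logit would be under-scaled before the softmax.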

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/decode_layer.py` around lines 87 - 92, Replace the
hard-coded globals with parameter-derived values: compute head_dim_inv as 1.0 /
head_dim and decode_attn_scale as 1.0 / sqrt(head_dim) (use math.sqrt) instead
of using HEAD_DIM_INV and ATTN_SCALE, and set max_ctx_blocks =
max_blocks_per_seq and ensure the kernel calls/indices use max_ctx_blocks rather
than MAX_BLOCKS_PER_SEQ; update any references in this module (e.g., where
head_dim_inv, decode_attn_scale, max_ctx_blocks are passed into or referenced by
the attention/kernel logic) so the computed values are used consistently when
non-default head_dim or max_blocks_per_seq are provided.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the Qwen3 decode layer to utilize SPMD kernels for RMSNorm operations, replacing previous implementations. It also updates the test infrastructure to use a new 'run' function and refines the golden reference implementation. A high-severity issue was identified regarding the incorrect handling of partial batches in the 'post_rmsnorm_kernel' call, which must be updated to use the correctly calculated 'cur_valid' value instead of 'BATCH_TILE'.

Comment on lines +551 to +580
            for b0 in pl.parallel(0, batch_padded, BATCH_TILE):
                resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.FP32)

                for ob in pl.range(decode_q_out_blocks):
                    o0 = ob * Q_OUT_CHUNK

                    with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj"):
                        a_chunk_0 = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, 0])
                        w_chunk_0 = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [0, o0])
                        o_acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32)
                        for kb in pl.range(1, hidden_blocks):
                            k0 = kb * K_CHUNK
                            a_chunk = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, k0])
                            w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0])
                            o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk)

                    with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj_residual"):
                        resid = pl.cast(
                            pl.slice(current_hidden, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]),
                            target_type=pl.FP32,
                        )
                        resid_sum = pl.add(o_acc, resid)
                        resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0])

                post_norm_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.BF16)
                # SPMD post_rmsnorm (change 2/2)
                with pl.spmd(RMSNORM_SPMD_CORES):
                    post_norm_tile = self.post_rmsnorm_kernel(
                        resid1_tile, BATCH_TILE, post_rms_weight, post_norm_tile,
                    )

high

The post_rmsnorm_kernel is called with BATCH_TILE as the cur_valid argument, which is incorrect when the number of valid rows in the batch is less than BATCH_TILE. This should be calculated as pl.min(BATCH_TILE, user_batch - b0) to correctly handle partial batches.

            for b0 in pl.parallel(0, batch_padded, BATCH_TILE):
                cur_valid = pl.min(BATCH_TILE, user_batch - b0)
                resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.FP32)

                for ob in pl.range(decode_q_out_blocks):
                    o0 = ob * Q_OUT_CHUNK

                    with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj"):
                        a_chunk_0 = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, 0])
                        w_chunk_0 = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [0, o0])
                        o_acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32)
                        for kb in pl.range(1, hidden_blocks):
                            k0 = kb * K_CHUNK
                            a_chunk = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, k0])
                            w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0])
                            o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk)

                    with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj_residual"):
                        resid = pl.cast(
                            pl.slice(current_hidden, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]),
                            target_type=pl.FP32,
                        )
                        resid_sum = pl.add(o_acc, resid)
                        resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0])

                post_norm_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.BF16)
                # SPMD post_rmsnorm (change 2/2)
                with pl.spmd(RMSNORM_SPMD_CORES):
                    post_norm_tile = self.post_rmsnorm_kernel(
                        resid1_tile, cur_valid, post_rms_weight, post_norm_tile,
                    )
