Skip to content

Add spmd and mix versions of qwen3 14B decode#330

Closed
xzhxzhxzh123 wants to merge 10 commits into
hw-native-sys:mainfrom
xzhxzhxzh123:main
Closed

Add spmd and mix versions of qwen3 14B decode#330
xzhxzhxzh123 wants to merge 10 commits into
hw-native-sys:mainfrom
xzhxzhxzh123:main

Conversation

@xzhxzhxzh123
Copy link
Copy Markdown
Collaborator

No description provided.

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 20, 2026

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Existing Qwen3-14B decode.py is refactored to extract RMSNorm logic into SPMD kernels and simplify MLP paths; three new complete single-layer decode implementations are added as variants using different kernel construction strategies (fused online-softmax, fused matmul-softmax, and SPMD/InCore primitives).

Changes

Qwen3-14B Decode SPMD Refactoring

Layer / File(s) Summary
SPMD RMSNorm kernel extraction
models/qwen3/14b/qwen3_14b_decode.py
New RMSNORM_SPMD_CORES and RMSNORM_SPMD_ROWS constants define SPMD tiling. Two InCore kernel functions (rmsnorm_kernel and post_rmsnorm_kernel) implement per-tile RMSNorm with two-pass variance/inv-rms computation and BF16 assembly.
RMSNorm kernel integration
models/qwen3/14b/qwen3_14b_decode.py
Scope 1 input RMSNorm and Scope 3 post-attention RMSNorm replace inline logic with SPMD kernel invocations, removing original per-core-group accumulation.
MLP down-projection simplification
models/qwen3/14b/qwen3_14b_decode.py
The fast-path/tail-path split for down-projection and residual is removed; computation always uses FP32, residual is added in FP32, then cast to BF16 with valid_shape trimming before assembly.
RunConfig harness update
models/qwen3/14b/qwen3_14b_decode.py
Test harness imports RunConfig and updates the run(...) call to use a single config=RunConfig(...) object instead of separate compile_cfg, runtime_cfg, rtol, and atol arguments.

Qwen3-14B Fused QK-Softmax-SV Online Softmax Decode

Layer / File(s) Summary
Configuration and constants
models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py
Module-level configuration defines model dimensions, tiling factors, dynamic dimension placeholders (USER_BATCH_DYN, KV_CACHE_ROWS_DYN, BLOCK_TABLE_FLAT_DYN), and numerical parameters used throughout the kernel.
Fused kernel implementation
models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py
build_qwen3_decode_program() builds a PyPTO kernel fusing all three scopes: input RMSNorm and Q/K/V projections with BATCH_TILE padding, per-head Q/K normalization and RoPE with paged KV-cache writes, grouped attention with online-softmax running max/exp-sum accumulation, output projection, post-attention RMSNorm, gated MLP with SiLU, and BF16 output trimming/assembly.
Test tensor specs builder
models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py
build_tensor_specs() computes cache sizes, initializes weights and KV-caches with randomized shape-consistent data, and constructs block-table and slot-mapping tensors for paged cache addressing.
PyTorch golden reference
models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py
golden_qwen3_decode() provides a PyTorch reference that mirrors the three-scope kernel path with input RMSNorm, projections, RoPE, grouped QK matmul with online-softmax accumulation, output projection, post-RMSNorm, and gated MLP.
CLI runner and execution
models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py
__main__ parses CLI flags (platform, device, batch, max-seq, PMU, kernel-insight export), runs the compiled kernel via run() with golden reference, checks pass/fail, and optionally exports kernel insight traces.

Qwen3-14B Fused QK Matmul-Softmax Decode

Layer / File(s) Summary
Configuration and constants
models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py
Module-level constants define model dimensions, numerical epsilon, tiling factors, and dynamic dimension placeholders for user batch and KV-cache sizing.
Program builder and kernel signature
models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py
build_qwen3_decode_program() sets up derived block/grid sizing, the Qwen3Decode class, and qwen3_decode function signature with all input tensor types and BF16 output.
Scope 1: Input RMSNorm and projections
models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py
Allocates batch-padded staging tensors and implements per-tile RMSNorm followed by Q/K/V projections using tiled matmuls with valid_shape handling on input loads for tail padding.
Scope 2: Attention with online softmax
models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py
Implements per-head Q/K RMS normalization, RoPE rotation, KV-cache and V-cache writes, then grouped decode attention using max-subtracted softmax with online softmax accumulation (running max/sum tracking and rescaling partial outputs).
Scope 3: Output projection, RMSNorm, and MLP
models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py
Performs output projection with residual addition, post-attention RMSNorm, gated MLP (gate/up + SiLU + down projection), and includes conditional tail-safe paths with valid_shape trimming for BF16 assembly.
Test tensor specs builder
models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py
build_tensor_specs() computes host tensor shapes, initializes all input and weight tensors, and constructs block-table and slot-mapping tensors.
PyTorch golden reference
models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py
golden_qwen3_decode() provides a PyTorch reference that mirrors the three scopes with RMSNorm, Q/K/V projections, RoPE with K/V cache updates, grouped attention softmax with causal masking and online accumulation, post-RMSNorm, and gated MLP.
CLI runner and execution
models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py
__main__ parses CLI flags, calls golden.run() with compilation and runtime configuration, checks results, and optionally exports kernel insight traces.

Qwen3-14B SPMD/InCore Decode with RMSNorm Kernels

Layer / File(s) Summary
Configuration and dynamic dimensions
models/qwen3/14b/qwen3_14b_rmsnorm_spmd2.py
Module-level constants define model dimensions, tiling parameters, and runtime-dynamic dimension placeholders for user batch, KV-cache rows, and block-table addressing.
SPMD kernel program and forward path
models/qwen3/14b/qwen3_14b_rmsnorm_spmd2.py
build_qwen3_decode_program() defines nested RMSNorm and post-RMSNorm InCore kernel helpers and the main qwen3_decode forward. Forward implements all three scopes: (1) input RMSNorm and Q/K/V projections into FP32 staging, (2) per-head RMS scaling, RoPE rotation, KV-cache writes, grouped attention with online-softmax accumulation and SV matmul, (3) output projection, post-RMSNorm, MLP with SiLU gating, and final residuals, with internal padding/trim handling driven by dynamic batch dimensions.
Test tensor specs builder
models/qwen3/14b/qwen3_14b_rmsnorm_spmd2.py
build_tensor_specs() generates host-side input/weight/cache tensor specifications with initialization routines for sequence lengths, slot/block mappings, RoPE cos/sin tables, KV caches, and projection/MLP weights.
PyTorch golden reference
models/qwen3/14b/qwen3_14b_rmsnorm_spmd2.py
golden_qwen3_decode() provides a PyTorch reference that mirrors the three scopes, including RMSNorm, Q/K/V projections, RoPE rotation with cache writes, grouped attention with numerically-stable online-softmax accumulation, output projection, post-RMSNorm, and MLP with final residuals.
CLI runner and execution
models/qwen3/14b/qwen3_14b_rmsnorm_spmd2.py
__main__ parses CLI arguments (platform, device, batch, max-seq, PMU, kernel-insight export), runs the compiled program via run() with configured tolerances and compilation options, and optionally exports kernel insight traces.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#319: Modifies Qwen3-14B qwen3_14b_decode.py scope-3 post-RMSNorm stage and output path, overlapping with main PR's SPMD kernel refactoring and down-projection simplification.
  • hw-native-sys/pypto-lib#231: Refactors Scope 1/Scope 3 RMSNorm computations in qwen3_14b_decode.py with focus on tiling and golden_qwen3_decode numerically-stable accumulation behavior.
  • hw-native-sys/pypto-lib#99: Restructures Qwen3-14B decode Scope 1/2/3 architecture boundaries, matching main PR's refactoring of RMSNorm logic into SPMD helper kernels.

Suggested labels

enhancement

Poem

🐇 Three kernels bloom where one did thrive,
SPMD helpers make norms alive,
Online-softmax sums as tokens dance,
Dynamic batch dims take their stance,
Decode gold shines in BF16's grace! ✨

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 4.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Description check ❓ Inconclusive No pull request description was provided by the author, making it impossible to assess whether a description exists that relates to the changeset. Add a pull request description explaining the purpose, scope, and testing approach for these new SPMD and mixed implementations of Qwen3 14B decode.
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title 'Add spmd and mix versions of qwen3 14B decode' clearly and specifically describes the main changes: additions of SPMD and mixed implementations for Qwen3 14B decode functionality across multiple files.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements the Qwen3-14B single-layer decode forward pass, covering RMSNorm, Q/K/V projections, RoPE, paged KV cache updates, and MLP components using a custom DSL. The review feedback identifies several performance and correctness improvements. Key recommendations include fusing the softmax and SV matmul blocks to reduce global memory round-trips, calculating the hidden size inverse locally within the builder to support dynamic configurations, and avoiding unnecessary global memory writes for raw scores by using L1-resident tensors directly.

Comment on lines +408 to +478
with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.UP_DOWN)], name_hint="qk_softmax"):
for sb in pl.range(ctx_blocks):
block_table_idx = block_table_base + sb
pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX)

cache_row0 = (pbid * num_kv_heads + kvh0) * BLOCK_SIZE
k_tile0 = k_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :]
raw_scores0 = pl.matmul(q_padded0, k_tile0, b_trans=True, out_dtype=pl.FP32)
all_raw_scores0 = pl.assemble(all_raw_scores0, raw_scores0, [sb * Q_HEAD_PAD, 0])

cache_row1 = (pbid * num_kv_heads + kvh1) * BLOCK_SIZE
k_tile1 = k_cache[cache_row1 : cache_row1 + BLOCK_SIZE, :]
raw_scores1 = pl.matmul(q_padded1, k_tile1, b_trans=True, out_dtype=pl.FP32)
all_raw_scores1 = pl.assemble(all_raw_scores1, raw_scores1, [sb * Q_HEAD_PAD, 0])

s0 = sb * BLOCK_SIZE
valid_len = pl.min(BLOCK_SIZE, ctx_len - s0)

scores_valid0 = pl.slice(
all_raw_scores0,
[Q_HEAD_PAD, BLOCK_SIZE],
[sb * Q_HEAD_PAD, 0],
valid_shape=[Q_HEAD_PAD, valid_len],
)
scores_padded0 = pl.fillpad(scores_valid0, pad_value=pl.PadValue.min)
scores0 = pl.mul(scores_padded0, attn_scale)
cur_mi0 = pl.row_max(scores0)
exp_scores0 = pl.exp(pl.row_expand_sub(scores0, cur_mi0))
exp_scores_bf16_0 = pl.cast(exp_scores0, target_type=pl.BF16)
exp_scores_fp32_0 = pl.cast(exp_scores_bf16_0, target_type=pl.FP32)
cur_li0 = pl.row_sum(exp_scores_fp32_0)
all_exp_padded0 = pl.assemble(all_exp_padded0, exp_scores_bf16_0, [sb * Q_HEAD_PAD, 0])
all_cur_mi0 = pl.assemble(all_cur_mi0, cur_mi0, [sb * Q_HEAD_PAD, 0])
all_cur_li0 = pl.assemble(all_cur_li0, cur_li0, [sb * Q_HEAD_PAD, 0])

scores_valid1 = pl.slice(
all_raw_scores1,
[Q_HEAD_PAD, BLOCK_SIZE],
[sb * Q_HEAD_PAD, 0],
valid_shape=[Q_HEAD_PAD, valid_len],
)
scores_padded1 = pl.fillpad(scores_valid1, pad_value=pl.PadValue.min)
scores1 = pl.mul(scores_padded1, attn_scale)
cur_mi1 = pl.row_max(scores1)
exp_scores1 = pl.exp(pl.row_expand_sub(scores1, cur_mi1))
exp_scores_bf16_1 = pl.cast(exp_scores1, target_type=pl.BF16)
exp_scores_fp32_1 = pl.cast(exp_scores_bf16_1, target_type=pl.FP32)
cur_li1 = pl.row_sum(exp_scores_fp32_1)
all_exp_padded1 = pl.assemble(all_exp_padded1, exp_scores_bf16_1, [sb * Q_HEAD_PAD, 0])
all_cur_mi1 = pl.assemble(all_cur_mi1, cur_mi1, [sb * Q_HEAD_PAD, 0])
all_cur_li1 = pl.assemble(all_cur_li1, cur_li1, [sb * Q_HEAD_PAD, 0])

all_oi_tmp0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
all_oi_tmp1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32)
with pl.at(level=pl.Level.CORE_GROUP, name_hint="sv_matmul"):
for sb in pl.range(ctx_blocks):
block_table_idx = block_table_base + sb
pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX)

cache_row0 = (pbid * num_kv_heads + kvh0) * BLOCK_SIZE
exp_tile0 = all_exp_padded0[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :]
v_tile0 = v_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :]
oi_tmp0 = pl.matmul(exp_tile0, v_tile0, out_dtype=pl.FP32)
all_oi_tmp0 = pl.assemble(all_oi_tmp0, oi_tmp0, [sb * Q_HEAD_PAD, 0])

cache_row1 = (pbid * num_kv_heads + kvh1) * BLOCK_SIZE
exp_tile1 = all_exp_padded1[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :]
v_tile1 = v_cache[cache_row1 : cache_row1 + BLOCK_SIZE, :]
oi_tmp1 = pl.matmul(exp_tile1, v_tile1, out_dtype=pl.FP32)
all_oi_tmp1 = pl.assemble(all_oi_tmp1, oi_tmp1, [sb * Q_HEAD_PAD, 0])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The attention implementation uses separate pl.at blocks for qk_softmax and sv_matmul, each iterating over sequence blocks (sb). This results in unnecessary GM round-trips for intermediate tensors like all_exp_padded0/1 and all_raw_scores0/1. Fusing these operations into a single sb loop within one pl.at block would allow keeping intermediate results in L1 or registers, significantly reducing GM traffic and improving performance. This fusion is a standard optimization for attention kernels (e.g., Flash Attention) and is especially beneficial for decode workloads.

Comment on lines +110 to +115
head_dim_inv = 1.0 / head_dim
q_per_kv = num_heads // num_kv_heads
q_groups = q_per_kv // Q_HEAD_BATCH
total_q_groups = num_kv_heads * q_groups
attn_scale = 1.0 / (head_dim ** 0.5)
max_ctx_blocks = max_blocks_per_seq
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The program uses a global constant HIDDEN_INV for RMSNorm calculations (lines 193 and 573). This makes the kernel incorrect if a different hidden_size is passed to build_qwen3_decode_program. It is recommended to calculate hidden_inv locally within the builder function to ensure the program remains correct for any hidden size configuration.

Suggested change
head_dim_inv = 1.0 / head_dim
q_per_kv = num_heads // num_kv_heads
q_groups = q_per_kv // Q_HEAD_BATCH
total_q_groups = num_kv_heads * q_groups
attn_scale = 1.0 / (head_dim ** 0.5)
max_ctx_blocks = max_blocks_per_seq
half_dim = head_dim // 2
head_dim_inv = 1.0 / head_dim
hidden_inv = 1.0 / hidden
q_per_kv = num_heads // num_kv_heads
q_groups = q_per_kv // Q_HEAD_BATCH
total_q_groups = num_kv_heads * q_groups

Comment on lines +416 to +432
all_raw_scores0 = pl.assemble(all_raw_scores0, raw_scores0, [sb * Q_HEAD_PAD, 0])

cache_row1 = (pbid * num_kv_heads + kvh1) * BLOCK_SIZE
k_tile1 = k_cache[cache_row1 : cache_row1 + BLOCK_SIZE, :]
raw_scores1 = pl.matmul(q_padded1, k_tile1, b_trans=True, out_dtype=pl.FP32)
all_raw_scores1 = pl.assemble(all_raw_scores1, raw_scores1, [sb * Q_HEAD_PAD, 0])

s0 = sb * BLOCK_SIZE
valid_len = pl.min(BLOCK_SIZE, ctx_len - s0)

scores_valid0 = pl.slice(
all_raw_scores0,
[Q_HEAD_PAD, BLOCK_SIZE],
[sb * Q_HEAD_PAD, 0],
valid_shape=[Q_HEAD_PAD, valid_len],
)
scores_padded0 = pl.fillpad(scores_valid0, pad_value=pl.PadValue.min)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The all_raw_scores0 tensor is used for a GM round-trip within the same pl.at block. The matmul result raw_scores0 is already in L1 and has the required shape [Q_HEAD_PAD, BLOCK_SIZE]. It should be used directly in pl.fillpad to avoid unnecessary GM writes and reads. The same optimization applies to all_raw_scores1.

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In
`@models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py`:
- Around line 100-113: The builder assumes the 14B geometry but silently drops
channels for other sizes; add upfront validation (e.g., in build_tensor_specs or
the caller initializing these constants) that enforces the exact invariants
required by the kernel: check hidden == num_heads * head_dim, inter == num_heads
* mlp_factor (or the expected MLP layout), and that hidden, inter, and head_dim
are divisible by
K_CHUNK/INPUT_PROJ_K_CHUNK/KV_PROJ_K_CHUNK/OUT_PROJ_K_CHUNK/DOWN_OUT_CHUNK/OUT_PROJ_N_CHUNK/MLP_OUT_CHUNK/DOWN_MLP_CHUNK
respectively; also ensure q_per_kv % Q_HEAD_BATCH == 0 and that q_groups and
total_q_groups match the kernel's expected grouping (q_groups = q_per_kv //
Q_HEAD_BATCH, total_q_groups = num_kv_heads * q_groups) so the RoPE/qk_norm and
the attention loop won’t drop or zero groups; if validations fail, raise an
explicit error describing the mismatch.
- Around line 92-96: Currently temporary tensors (q_proj, k_proj, v_proj,
q_proj_norm, k_proj_norm, attn_out, all_q_padded) are allocated using a
runtime-derived batch_padded computed from user_batch
(pl.tensor.dim(hidden_states, 0)); instead compute a compile-time padded
capacity from the builder `batch` parameter: batch_padded_static = ((batch +
BATCH_TILE - 1) // BATCH_TILE) * BATCH_TILE and use that for pl.create_tensor
dimensions. Then iterate/allocate loops up to batch_padded_static and apply
valid_shape when slicing inputs/outputs so runtime user_batch <= batch is
enforced only at data-movement boundaries; update any references to batch_padded
and remove reliance on pl.tensor.dim(hidden_states, 0) for allocation.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 28429264-1061-43ae-a999-8bc04600f9b4

📥 Commits

Reviewing files that changed from the base of the PR and between 58fee33 and 822e9ca.

📒 Files selected for processing (1)
  • models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py

Comment on lines +92 to +96
# The `batch` parameter is only used by build_tensor_specs to size
# host buffers; it is no longer baked into the program. Every
# batch-dependent kernel signature dim is a pl.dynamic() variable
# (USER_BATCH_DYN / BLOCK_TABLE_FLAT_DYN / KV_CACHE_ROWS_DYN), so a
# single compiled program serves any user_batch <= host capacity.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

🧩 Analysis chain

🏁 Script executed:

# First, check if the file exists
find . -name "qwen3_14b_decode_qk_softmax_fused.py" -type f

Repository: hw-native-sys/pypto-lib

Length of output: 150


🏁 Script executed:

# Get the file size to see if we can read it entirely
wc -l models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py 2>/dev/null || echo "File not found at that path"

Repository: hw-native-sys/pypto-lib

Length of output: 153


🏁 Script executed:

# Search for the file more broadly
find . -name "*qwen3*14b*fused*" -type f | head -20

Repository: hw-native-sys/pypto-lib

Length of output: 150


🏁 Script executed:

# Read lines 92-96 to verify the comment about batch
sed -n '92,96p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py

Repository: hw-native-sys/pypto-lib

Length of output: 416


🏁 Script executed:

# Read lines around 150 to see batch_padded derivation and tensor allocations
sed -n '145,165p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py

Repository: hw-native-sys/pypto-lib

Length of output: 1441


🏁 Script executed:

# Read lines around 293-296 for attn_out and all_q_padded allocations
sed -n '288,300p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py

Repository: hw-native-sys/pypto-lib

Length of output: 881


🏁 Script executed:

# Search for where batch_padded is defined
rg "batch_padded\s*=" models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py -n -B 2 -A 2

Repository: hw-native-sys/pypto-lib

Length of output: 334


🏁 Script executed:

# Look for pl.create_tensor usage to understand the pattern
rg "pl\.create_tensor" models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py -n -A 3

Repository: hw-native-sys/pypto-lib

Length of output: 6645


🏁 Script executed:

# Search for PyPTO documentation or API constraints
find . -name "*.md" -o -name "*.py" | xargs grep -l "pl.create_tensor" | head -10

Repository: hw-native-sys/pypto-lib

Length of output: 498


🏁 Script executed:

# Look at builder pattern to understand compile-time vs runtime expectations
sed -n '1,100p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py

Repository: hw-native-sys/pypto-lib

Length of output: 3409


🏁 Script executed:

# Search for build_tensor_specs to understand host buffer allocation
rg "build_tensor_specs|class.*Builder" models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py -n -A 10 | head -50

Repository: hw-native-sys/pypto-lib

Length of output: 1550


🏁 Script executed:

# Check if there are other files in codebase that show pl.create_tensor patterns
find . -name "*.py" -path "*/pypto*" | xargs rg "batch_padded|pl.create_tensor" -l | head -10

Repository: hw-native-sys/pypto-lib

Length of output: 150


🏁 Script executed:

# Check the PyPTO coding style guide for constraints on pl.create_tensor()
cat docs/pypto-coding-style.md

Repository: hw-native-sys/pypto-lib

Length of output: 14775


🏁 Script executed:

# Look at another example to see how pl.create_tensor is typically used
sed -n '1,300p' examples/intermediate/rms_norm.py | grep -A 5 "pl.create_tensor"

Repository: hw-native-sys/pypto-lib

Length of output: 443


🏁 Script executed:

# Check deepseek models for similar patterns
rg "batch_padded|user_batch.*pl.tensor.dim" models/deepseek/ -A 5 -B 2 | head -60

Repository: hw-native-sys/pypto-lib

Length of output: 49


🏁 Script executed:

# Verify the builder's batch parameter is available in the function scope where batch_padded is used
sed -n '85,155p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py | grep -E "def|batch|USER_BATCH"

Repository: hw-native-sys/pypto-lib

Length of output: 1300


🏁 Script executed:

# Check the overall scope to understand the function parameters
sed -n '75,110p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py

Repository: hw-native-sys/pypto-lib

Length of output: 1406


Use batch parameter for compile-time tensor allocation instead of runtime-derived batch_padded.

Tensors allocated via pl.create_tensor() require compile-time static dimensions. Currently, batch_padded is derived from the runtime dimension user_batch = pl.tensor.dim(hidden_states, 0) and then used to size q_proj, k_proj, v_proj, q_proj_norm, k_proj_norm (lines 156–160) and attn_out, all_q_padded (lines 293–296).

Instead, allocate these temporaries from a compile-time padded capacity derived from the builder's batch parameter:

batch_padded_static = ((batch + BATCH_TILE - 1) // BATCH_TILE) * BATCH_TILE

Then loop over batch_padded_static and use valid_shape on input/output slices to enforce the runtime user_batch <= batch contract, keeping padding and trimming localized to data movement rather than allocation.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py`
around lines 92 - 96, Currently temporary tensors (q_proj, k_proj, v_proj,
q_proj_norm, k_proj_norm, attn_out, all_q_padded) are allocated using a
runtime-derived batch_padded computed from user_batch
(pl.tensor.dim(hidden_states, 0)); instead compute a compile-time padded
capacity from the builder `batch` parameter: batch_padded_static = ((batch +
BATCH_TILE - 1) // BATCH_TILE) * BATCH_TILE and use that for pl.create_tensor
dimensions. Then iterate/allocate loops up to batch_padded_static and apply
valid_shape when slicing inputs/outputs so runtime user_batch <= batch is
enforced only at data-movement boundaries; update any references to batch_padded
and remove reliance on pl.tensor.dim(hidden_states, 0) for allocation.

Comment on lines +100 to +113
input_proj_k_blocks = hidden // INPUT_PROJ_K_CHUNK
kv_proj_k_blocks = hidden // KV_PROJ_K_CHUNK
out_proj_k_blocks = hidden // OUT_PROJ_K_CHUNK
hidden_blocks = hidden // K_CHUNK
down_out_blocks = hidden // DOWN_OUT_CHUNK
out_proj_n_blocks = hidden // OUT_PROJ_N_CHUNK
mlp_out_blocks = inter // MLP_OUT_CHUNK
down_mlp_blocks = inter // DOWN_MLP_CHUNK
max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) // BLOCK_SIZE
half_dim = head_dim // 2
head_dim_inv = 1.0 / head_dim
q_per_kv = num_heads // num_kv_heads
q_groups = q_per_kv // Q_HEAD_BATCH
total_q_groups = num_kv_heads * q_groups
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Fail fast on non-Qwen3-14B geometry.

These builders look generic, but the kernel is still hard-wired to the 14B layout: block counts use floor division, qk_norm and RoPE only materialize one Q_HEAD_BATCH query group per KV head, and the attention loop consumes groups two at a time. Any non-default hidden_size/intermediate_size/num_heads/num_kv_heads can silently drop channels or leave whole query groups zeroed. build_tensor_specs() already exposes this by sizing out from num_heads * head_dim instead of hidden_size. Please either validate the exact supported invariants up front or remove the exposed hyperparameters from the public API.

Suggested guardrail
 def build_qwen3_decode_program(
     batch: int = BATCH,
     max_seq: int = MAX_SEQ,
     hidden_size: int = HIDDEN,
     intermediate_size: int = INTERMEDIATE,
     num_heads: int = NUM_HEADS,
     num_kv_heads: int = NUM_KV_HEADS,
     head_dim: int = HEAD_DIM,
 ):
+    if hidden_size != num_heads * head_dim:
+        raise ValueError("hidden_size must equal num_heads * head_dim")
+    if num_heads % num_kv_heads != 0:
+        raise ValueError("num_heads must be divisible by num_kv_heads")
+    if num_heads // num_kv_heads != Q_HEAD_BATCH:
+        raise ValueError(
+            f"This kernel currently assumes exactly {Q_HEAD_BATCH} query heads per KV head",
+        )
+    if num_kv_heads % 2 != 0:
+        raise ValueError("num_kv_heads must be even")
+    if hidden_size % INPUT_PROJ_K_CHUNK != 0 or hidden_size % K_CHUNK != 0:
+        raise ValueError("hidden_size must align with kernel chunk sizes")
+    if intermediate_size % MLP_OUT_CHUNK != 0 or intermediate_size % DOWN_MLP_CHUNK != 0:
+        raise ValueError("intermediate_size must align with MLP chunk sizes")
+
 def build_tensor_specs(
     batch: int = BATCH,
     max_seq: int = MAX_SEQ,
     hidden_size: int = HIDDEN,
     intermediate_size: int = INTERMEDIATE,
@@
-    hidden = num_heads * head_dim
+    if hidden_size != num_heads * head_dim:
+        raise ValueError("hidden_size must equal num_heads * head_dim")
+    hidden = hidden_size

Also applies to: 263-275, 349-380, 384-398, 690-695, 805-805, 842-844, 928-973

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py`
around lines 100 - 113, The builder assumes the 14B geometry but silently drops
channels for other sizes; add upfront validation (e.g., in build_tensor_specs or
the caller initializing these constants) that enforces the exact invariants
required by the kernel: check hidden == num_heads * head_dim, inter == num_heads
* mlp_factor (or the expected MLP layout), and that hidden, inter, and head_dim
are divisible by
K_CHUNK/INPUT_PROJ_K_CHUNK/KV_PROJ_K_CHUNK/OUT_PROJ_K_CHUNK/DOWN_OUT_CHUNK/OUT_PROJ_N_CHUNK/MLP_OUT_CHUNK/DOWN_MLP_CHUNK
respectively; also ensure q_per_kv % Q_HEAD_BATCH == 0 and that q_groups and
total_q_groups match the kernel's expected grouping (q_groups = q_per_kv //
Q_HEAD_BATCH, total_q_groups = num_kv_heads * q_groups) so the RoPE/qk_norm and
the attention loop won’t drop or zero groups; if validations fail, raise an
explicit error describing the mismatch.

@xzhxzhxzh123 xzhxzhxzh123 changed the title Create qwen3_14b_decode_qk_softmax_fused.py Add spmd and mix versions of qwen3 14B decode May 20, 2026
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py (1)

25-25: 💤 Low value

Nitpick: Replace ambiguous multiplication sign character.

The docstring uses × (Unicode multiplication sign) which is flagged by static analysis (RUF002). Consider replacing with x for consistency.

Proposed fix
-  1. Output projection: attn_out × wo
+  1. Output projection: attn_out x wo
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py` at
line 25, Docstring uses the Unicode multiplication sign '×' in the bullet
"Output projection: attn_out × wo" which triggers static-lint RUF002; update
that occurrence in the module
qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py (the docstring text
containing "attn_out × wo") to use a plain ASCII 'x' (i.e., "attn_out x wo") so
the docstring remains readable and lint-clean.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/qwen3/14b/qwen3_14b_decode.py`:
- Around line 178-180: The computed local_valid can be negative here (same bug
as in rmsnorm_kernel); change the assignment in qwen3_14b_decode.py so that
local_valid is clamped to be non-negative after computing local_valid =
pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start) — e.g., replace with a
non-negative clamp using pl.max(..., 0) or pl.clamp(..., 0, None) so local_valid
never goes below zero; keep the same block_idx, RMSNORM_SPMD_ROWS and cur_valid
variables intact.
- Around line 130-132: The computed local_valid can become negative when
cur_valid < row_start; update the calculation around pl.tile.get_block_idx() so
local_valid is clamped to a non-negative value before it is used in valid_shape
(e.g., compute local_valid = max(0, min(RMSNORM_SPMD_ROWS, cur_valid -
row_start)) or use an equivalent pl.clip/pl.max/pl.min combination). Ensure this
change touches the block where block_idx, row_start, and local_valid are
computed so downstream uses (valid_shape, padding) never receive negative
dimensions.

---

Nitpick comments:
In `@models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py`:
- Line 25: Docstring uses the Unicode multiplication sign '×' in the bullet
"Output projection: attn_out × wo" which triggers static-lint RUF002; update
that occurrence in the module
qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py (the docstring text
containing "attn_out × wo") to use a plain ASCII 'x' (i.e., "attn_out x wo") so
the docstring remains readable and lint-clean.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 8e4e14da-730b-4494-9c17-03f71d127fff

📥 Commits

Reviewing files that changed from the base of the PR and between 822e9ca and 386a309.

📒 Files selected for processing (4)
  • models/qwen3/14b/qwen3_14b_decode.py
  • models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py
  • models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py
  • models/qwen3/14b/qwen3_14b_rmsnorm_spmd2.py
💤 Files with no reviewable changes (1)
  • models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py

Comment on lines +130 to +132
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Potential negative local_valid when cur_valid < row_start.

When cur_valid is less than row_start (e.g., cur_valid=4 with block_idx=1 where row_start=8), the expression cur_valid - row_start yields a negative value. pl.min(RMSNORM_SPMD_ROWS, -4) would produce -4, and passing a negative dimension to valid_shape could cause undefined behavior or incorrect zero-padding.

Consider clamping local_valid to be non-negative:

🛡️ Proposed fix
             row_start = block_idx * RMSNORM_SPMD_ROWS
-            local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
+            local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/qwen3_14b_decode.py` around lines 130 - 132, The computed
local_valid can become negative when cur_valid < row_start; update the
calculation around pl.tile.get_block_idx() so local_valid is clamped to a
non-negative value before it is used in valid_shape (e.g., compute local_valid =
max(0, min(RMSNORM_SPMD_ROWS, cur_valid - row_start)) or use an equivalent
pl.clip/pl.max/pl.min combination). Ensure this change touches the block where
block_idx, row_start, and local_valid are computed so downstream uses
(valid_shape, padding) never receive negative dimensions.

Comment on lines +178 to +180
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Same negative local_valid issue as in rmsnorm_kernel.

Apply the same fix to clamp local_valid to be non-negative.

🛡️ Proposed fix
             row_start = block_idx * RMSNORM_SPMD_ROWS
-            local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
+            local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/qwen3_14b_decode.py` around lines 178 - 180, The computed
local_valid can be negative here (same bug as in rmsnorm_kernel); change the
assignment in qwen3_14b_decode.py so that local_valid is clamped to be
non-negative after computing local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid -
row_start) — e.g., replace with a non-negative clamp using pl.max(..., 0) or
pl.clamp(..., 0, None) so local_valid never goes below zero; keep the same
block_idx, RMSNORM_SPMD_ROWS and cur_valid variables intact.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant