Add spmd and mix versions of qwen3 14B decode#330
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughExisting Qwen3-14B decode.py is refactored to extract RMSNorm logic into SPMD kernels and simplify MLP paths; three new complete single-layer decode implementations are added as variants using different kernel construction strategies (fused online-softmax, fused matmul-softmax, and SPMD/InCore primitives). ChangesQwen3-14B Decode SPMD Refactoring
Qwen3-14B Fused QK-Softmax-SV Online Softmax Decode
Qwen3-14B Fused QK Matmul-Softmax Decode
Qwen3-14B SPMD/InCore Decode with RMSNorm Kernels
Estimated code review effort🎯 4 (Complex) | ⏱️ ~75 minutes Possibly related PRs
Suggested labels
Poem
🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request implements the Qwen3-14B single-layer decode forward pass, covering RMSNorm, Q/K/V projections, RoPE, paged KV cache updates, and MLP components using a custom DSL. The review feedback identifies several performance and correctness improvements. Key recommendations include fusing the softmax and SV matmul blocks to reduce global memory round-trips, calculating the hidden size inverse locally within the builder to support dynamic configurations, and avoiding unnecessary global memory writes for raw scores by using L1-resident tensors directly.
| with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.UP_DOWN)], name_hint="qk_softmax"): | ||
| for sb in pl.range(ctx_blocks): | ||
| block_table_idx = block_table_base + sb | ||
| pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX) | ||
|
|
||
| cache_row0 = (pbid * num_kv_heads + kvh0) * BLOCK_SIZE | ||
| k_tile0 = k_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :] | ||
| raw_scores0 = pl.matmul(q_padded0, k_tile0, b_trans=True, out_dtype=pl.FP32) | ||
| all_raw_scores0 = pl.assemble(all_raw_scores0, raw_scores0, [sb * Q_HEAD_PAD, 0]) | ||
|
|
||
| cache_row1 = (pbid * num_kv_heads + kvh1) * BLOCK_SIZE | ||
| k_tile1 = k_cache[cache_row1 : cache_row1 + BLOCK_SIZE, :] | ||
| raw_scores1 = pl.matmul(q_padded1, k_tile1, b_trans=True, out_dtype=pl.FP32) | ||
| all_raw_scores1 = pl.assemble(all_raw_scores1, raw_scores1, [sb * Q_HEAD_PAD, 0]) | ||
|
|
||
| s0 = sb * BLOCK_SIZE | ||
| valid_len = pl.min(BLOCK_SIZE, ctx_len - s0) | ||
|
|
||
| scores_valid0 = pl.slice( | ||
| all_raw_scores0, | ||
| [Q_HEAD_PAD, BLOCK_SIZE], | ||
| [sb * Q_HEAD_PAD, 0], | ||
| valid_shape=[Q_HEAD_PAD, valid_len], | ||
| ) | ||
| scores_padded0 = pl.fillpad(scores_valid0, pad_value=pl.PadValue.min) | ||
| scores0 = pl.mul(scores_padded0, attn_scale) | ||
| cur_mi0 = pl.row_max(scores0) | ||
| exp_scores0 = pl.exp(pl.row_expand_sub(scores0, cur_mi0)) | ||
| exp_scores_bf16_0 = pl.cast(exp_scores0, target_type=pl.BF16) | ||
| exp_scores_fp32_0 = pl.cast(exp_scores_bf16_0, target_type=pl.FP32) | ||
| cur_li0 = pl.row_sum(exp_scores_fp32_0) | ||
| all_exp_padded0 = pl.assemble(all_exp_padded0, exp_scores_bf16_0, [sb * Q_HEAD_PAD, 0]) | ||
| all_cur_mi0 = pl.assemble(all_cur_mi0, cur_mi0, [sb * Q_HEAD_PAD, 0]) | ||
| all_cur_li0 = pl.assemble(all_cur_li0, cur_li0, [sb * Q_HEAD_PAD, 0]) | ||
|
|
||
| scores_valid1 = pl.slice( | ||
| all_raw_scores1, | ||
| [Q_HEAD_PAD, BLOCK_SIZE], | ||
| [sb * Q_HEAD_PAD, 0], | ||
| valid_shape=[Q_HEAD_PAD, valid_len], | ||
| ) | ||
| scores_padded1 = pl.fillpad(scores_valid1, pad_value=pl.PadValue.min) | ||
| scores1 = pl.mul(scores_padded1, attn_scale) | ||
| cur_mi1 = pl.row_max(scores1) | ||
| exp_scores1 = pl.exp(pl.row_expand_sub(scores1, cur_mi1)) | ||
| exp_scores_bf16_1 = pl.cast(exp_scores1, target_type=pl.BF16) | ||
| exp_scores_fp32_1 = pl.cast(exp_scores_bf16_1, target_type=pl.FP32) | ||
| cur_li1 = pl.row_sum(exp_scores_fp32_1) | ||
| all_exp_padded1 = pl.assemble(all_exp_padded1, exp_scores_bf16_1, [sb * Q_HEAD_PAD, 0]) | ||
| all_cur_mi1 = pl.assemble(all_cur_mi1, cur_mi1, [sb * Q_HEAD_PAD, 0]) | ||
| all_cur_li1 = pl.assemble(all_cur_li1, cur_li1, [sb * Q_HEAD_PAD, 0]) | ||
|
|
||
| all_oi_tmp0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32) | ||
| all_oi_tmp1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32) | ||
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="sv_matmul"): | ||
| for sb in pl.range(ctx_blocks): | ||
| block_table_idx = block_table_base + sb | ||
| pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX) | ||
|
|
||
| cache_row0 = (pbid * num_kv_heads + kvh0) * BLOCK_SIZE | ||
| exp_tile0 = all_exp_padded0[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] | ||
| v_tile0 = v_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :] | ||
| oi_tmp0 = pl.matmul(exp_tile0, v_tile0, out_dtype=pl.FP32) | ||
| all_oi_tmp0 = pl.assemble(all_oi_tmp0, oi_tmp0, [sb * Q_HEAD_PAD, 0]) | ||
|
|
||
| cache_row1 = (pbid * num_kv_heads + kvh1) * BLOCK_SIZE | ||
| exp_tile1 = all_exp_padded1[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] | ||
| v_tile1 = v_cache[cache_row1 : cache_row1 + BLOCK_SIZE, :] | ||
| oi_tmp1 = pl.matmul(exp_tile1, v_tile1, out_dtype=pl.FP32) | ||
| all_oi_tmp1 = pl.assemble(all_oi_tmp1, oi_tmp1, [sb * Q_HEAD_PAD, 0]) | ||
|
|
There was a problem hiding this comment.
The attention implementation uses separate pl.at blocks for qk_softmax and sv_matmul, each iterating over sequence blocks (sb). This results in unnecessary GM round-trips for intermediate tensors like all_exp_padded0/1 and all_raw_scores0/1. Fusing these operations into a single sb loop within one pl.at block would allow keeping intermediate results in L1 or registers, significantly reducing GM traffic and improving performance. This fusion is a standard optimization for attention kernels (e.g., Flash Attention) and is especially beneficial for decode workloads.
| head_dim_inv = 1.0 / head_dim | ||
| q_per_kv = num_heads // num_kv_heads | ||
| q_groups = q_per_kv // Q_HEAD_BATCH | ||
| total_q_groups = num_kv_heads * q_groups | ||
| attn_scale = 1.0 / (head_dim ** 0.5) | ||
| max_ctx_blocks = max_blocks_per_seq |
There was a problem hiding this comment.
The program uses a global constant HIDDEN_INV for RMSNorm calculations (lines 193 and 573). This makes the kernel incorrect if a different hidden_size is passed to build_qwen3_decode_program. It is recommended to calculate hidden_inv locally within the builder function to ensure the program remains correct for any hidden size configuration.
| head_dim_inv = 1.0 / head_dim | |
| q_per_kv = num_heads // num_kv_heads | |
| q_groups = q_per_kv // Q_HEAD_BATCH | |
| total_q_groups = num_kv_heads * q_groups | |
| attn_scale = 1.0 / (head_dim ** 0.5) | |
| max_ctx_blocks = max_blocks_per_seq | |
| half_dim = head_dim // 2 | |
| head_dim_inv = 1.0 / head_dim | |
| hidden_inv = 1.0 / hidden | |
| q_per_kv = num_heads // num_kv_heads | |
| q_groups = q_per_kv // Q_HEAD_BATCH | |
| total_q_groups = num_kv_heads * q_groups |
| all_raw_scores0 = pl.assemble(all_raw_scores0, raw_scores0, [sb * Q_HEAD_PAD, 0]) | ||
|
|
||
| cache_row1 = (pbid * num_kv_heads + kvh1) * BLOCK_SIZE | ||
| k_tile1 = k_cache[cache_row1 : cache_row1 + BLOCK_SIZE, :] | ||
| raw_scores1 = pl.matmul(q_padded1, k_tile1, b_trans=True, out_dtype=pl.FP32) | ||
| all_raw_scores1 = pl.assemble(all_raw_scores1, raw_scores1, [sb * Q_HEAD_PAD, 0]) | ||
|
|
||
| s0 = sb * BLOCK_SIZE | ||
| valid_len = pl.min(BLOCK_SIZE, ctx_len - s0) | ||
|
|
||
| scores_valid0 = pl.slice( | ||
| all_raw_scores0, | ||
| [Q_HEAD_PAD, BLOCK_SIZE], | ||
| [sb * Q_HEAD_PAD, 0], | ||
| valid_shape=[Q_HEAD_PAD, valid_len], | ||
| ) | ||
| scores_padded0 = pl.fillpad(scores_valid0, pad_value=pl.PadValue.min) |
There was a problem hiding this comment.
The all_raw_scores0 tensor is used for a GM round-trip within the same pl.at block. The matmul result raw_scores0 is already in L1 and has the required shape [Q_HEAD_PAD, BLOCK_SIZE]. It should be used directly in pl.fillpad to avoid unnecessary GM writes and reads. The same optimization applies to all_raw_scores1.
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In
`@models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py`:
- Around line 100-113: The builder assumes the 14B geometry but silently drops
channels for other sizes; add upfront validation (e.g., in build_tensor_specs or
the caller initializing these constants) that enforces the exact invariants
required by the kernel: check hidden == num_heads * head_dim, inter == num_heads
* mlp_factor (or the expected MLP layout), and that hidden, inter, and head_dim
are divisible by
K_CHUNK/INPUT_PROJ_K_CHUNK/KV_PROJ_K_CHUNK/OUT_PROJ_K_CHUNK/DOWN_OUT_CHUNK/OUT_PROJ_N_CHUNK/MLP_OUT_CHUNK/DOWN_MLP_CHUNK
respectively; also ensure q_per_kv % Q_HEAD_BATCH == 0 and that q_groups and
total_q_groups match the kernel's expected grouping (q_groups = q_per_kv //
Q_HEAD_BATCH, total_q_groups = num_kv_heads * q_groups) so the RoPE/qk_norm and
the attention loop won’t drop or zero groups; if validations fail, raise an
explicit error describing the mismatch.
- Around line 92-96: Currently temporary tensors (q_proj, k_proj, v_proj,
q_proj_norm, k_proj_norm, attn_out, all_q_padded) are allocated using a
runtime-derived batch_padded computed from user_batch
(pl.tensor.dim(hidden_states, 0)); instead compute a compile-time padded
capacity from the builder `batch` parameter: batch_padded_static = ((batch +
BATCH_TILE - 1) // BATCH_TILE) * BATCH_TILE and use that for pl.create_tensor
dimensions. Then iterate/allocate loops up to batch_padded_static and apply
valid_shape when slicing inputs/outputs so runtime user_batch <= batch is
enforced only at data-movement boundaries; update any references to batch_padded
and remove reliance on pl.tensor.dim(hidden_states, 0) for allocation.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 28429264-1061-43ae-a999-8bc04600f9b4
📒 Files selected for processing (1)
models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py
| # The `batch` parameter is only used by build_tensor_specs to size | ||
| # host buffers; it is no longer baked into the program. Every | ||
| # batch-dependent kernel signature dim is a pl.dynamic() variable | ||
| # (USER_BATCH_DYN / BLOCK_TABLE_FLAT_DYN / KV_CACHE_ROWS_DYN), so a | ||
| # single compiled program serves any user_batch <= host capacity. |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, check if the file exists
find . -name "qwen3_14b_decode_qk_softmax_fused.py" -type fRepository: hw-native-sys/pypto-lib
Length of output: 150
🏁 Script executed:
# Get the file size to see if we can read it entirely
wc -l models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py 2>/dev/null || echo "File not found at that path"Repository: hw-native-sys/pypto-lib
Length of output: 153
🏁 Script executed:
# Search for the file more broadly
find . -name "*qwen3*14b*fused*" -type f | head -20Repository: hw-native-sys/pypto-lib
Length of output: 150
🏁 Script executed:
# Read lines 92-96 to verify the comment about batch
sed -n '92,96p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.pyRepository: hw-native-sys/pypto-lib
Length of output: 416
🏁 Script executed:
# Read lines around 150 to see batch_padded derivation and tensor allocations
sed -n '145,165p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.pyRepository: hw-native-sys/pypto-lib
Length of output: 1441
🏁 Script executed:
# Read lines around 293-296 for attn_out and all_q_padded allocations
sed -n '288,300p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.pyRepository: hw-native-sys/pypto-lib
Length of output: 881
🏁 Script executed:
# Search for where batch_padded is defined
rg "batch_padded\s*=" models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py -n -B 2 -A 2Repository: hw-native-sys/pypto-lib
Length of output: 334
🏁 Script executed:
# Look for pl.create_tensor usage to understand the pattern
rg "pl\.create_tensor" models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py -n -A 3Repository: hw-native-sys/pypto-lib
Length of output: 6645
🏁 Script executed:
# Search for PyPTO documentation or API constraints
find . -name "*.md" -o -name "*.py" | xargs grep -l "pl.create_tensor" | head -10Repository: hw-native-sys/pypto-lib
Length of output: 498
🏁 Script executed:
# Look at builder pattern to understand compile-time vs runtime expectations
sed -n '1,100p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.pyRepository: hw-native-sys/pypto-lib
Length of output: 3409
🏁 Script executed:
# Search for build_tensor_specs to understand host buffer allocation
rg "build_tensor_specs|class.*Builder" models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py -n -A 10 | head -50Repository: hw-native-sys/pypto-lib
Length of output: 1550
🏁 Script executed:
# Check if there are other files in codebase that show pl.create_tensor patterns
find . -name "*.py" -path "*/pypto*" | xargs rg "batch_padded|pl.create_tensor" -l | head -10Repository: hw-native-sys/pypto-lib
Length of output: 150
🏁 Script executed:
# Check the PyPTO coding style guide for constraints on pl.create_tensor()
cat docs/pypto-coding-style.mdRepository: hw-native-sys/pypto-lib
Length of output: 14775
🏁 Script executed:
# Look at another example to see how pl.create_tensor is typically used
sed -n '1,300p' examples/intermediate/rms_norm.py | grep -A 5 "pl.create_tensor"Repository: hw-native-sys/pypto-lib
Length of output: 443
🏁 Script executed:
# Check deepseek models for similar patterns
rg "batch_padded|user_batch.*pl.tensor.dim" models/deepseek/ -A 5 -B 2 | head -60Repository: hw-native-sys/pypto-lib
Length of output: 49
🏁 Script executed:
# Verify the builder's batch parameter is available in the function scope where batch_padded is used
sed -n '85,155p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py | grep -E "def|batch|USER_BATCH"Repository: hw-native-sys/pypto-lib
Length of output: 1300
🏁 Script executed:
# Check the overall scope to understand the function parameters
sed -n '75,110p' models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.pyRepository: hw-native-sys/pypto-lib
Length of output: 1406
Use batch parameter for compile-time tensor allocation instead of runtime-derived batch_padded.
Tensors allocated via pl.create_tensor() require compile-time static dimensions. Currently, batch_padded is derived from the runtime dimension user_batch = pl.tensor.dim(hidden_states, 0) and then used to size q_proj, k_proj, v_proj, q_proj_norm, k_proj_norm (lines 156–160) and attn_out, all_q_padded (lines 293–296).
Instead, allocate these temporaries from a compile-time padded capacity derived from the builder's batch parameter:
batch_padded_static = ((batch + BATCH_TILE - 1) // BATCH_TILE) * BATCH_TILE
Then loop over batch_padded_static and use valid_shape on input/output slices to enforce the runtime user_batch <= batch contract, keeping padding and trimming localized to data movement rather than allocation.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In
`@models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py`
around lines 92 - 96, Currently temporary tensors (q_proj, k_proj, v_proj,
q_proj_norm, k_proj_norm, attn_out, all_q_padded) are allocated using a
runtime-derived batch_padded computed from user_batch
(pl.tensor.dim(hidden_states, 0)); instead compute a compile-time padded
capacity from the builder `batch` parameter: batch_padded_static = ((batch +
BATCH_TILE - 1) // BATCH_TILE) * BATCH_TILE and use that for pl.create_tensor
dimensions. Then iterate/allocate loops up to batch_padded_static and apply
valid_shape when slicing inputs/outputs so runtime user_batch <= batch is
enforced only at data-movement boundaries; update any references to batch_padded
and remove reliance on pl.tensor.dim(hidden_states, 0) for allocation.
| input_proj_k_blocks = hidden // INPUT_PROJ_K_CHUNK | ||
| kv_proj_k_blocks = hidden // KV_PROJ_K_CHUNK | ||
| out_proj_k_blocks = hidden // OUT_PROJ_K_CHUNK | ||
| hidden_blocks = hidden // K_CHUNK | ||
| down_out_blocks = hidden // DOWN_OUT_CHUNK | ||
| out_proj_n_blocks = hidden // OUT_PROJ_N_CHUNK | ||
| mlp_out_blocks = inter // MLP_OUT_CHUNK | ||
| down_mlp_blocks = inter // DOWN_MLP_CHUNK | ||
| max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) // BLOCK_SIZE | ||
| half_dim = head_dim // 2 | ||
| head_dim_inv = 1.0 / head_dim | ||
| q_per_kv = num_heads // num_kv_heads | ||
| q_groups = q_per_kv // Q_HEAD_BATCH | ||
| total_q_groups = num_kv_heads * q_groups |
There was a problem hiding this comment.
Fail fast on non-Qwen3-14B geometry.
These builders look generic, but the kernel is still hard-wired to the 14B layout: block counts use floor division, qk_norm and RoPE only materialize one Q_HEAD_BATCH query group per KV head, and the attention loop consumes groups two at a time. Any non-default hidden_size/intermediate_size/num_heads/num_kv_heads can silently drop channels or leave whole query groups zeroed. build_tensor_specs() already exposes this by sizing out from num_heads * head_dim instead of hidden_size. Please either validate the exact supported invariants up front or remove the exposed hyperparameters from the public API.
Suggested guardrail
def build_qwen3_decode_program(
batch: int = BATCH,
max_seq: int = MAX_SEQ,
hidden_size: int = HIDDEN,
intermediate_size: int = INTERMEDIATE,
num_heads: int = NUM_HEADS,
num_kv_heads: int = NUM_KV_HEADS,
head_dim: int = HEAD_DIM,
):
+ if hidden_size != num_heads * head_dim:
+ raise ValueError("hidden_size must equal num_heads * head_dim")
+ if num_heads % num_kv_heads != 0:
+ raise ValueError("num_heads must be divisible by num_kv_heads")
+ if num_heads // num_kv_heads != Q_HEAD_BATCH:
+ raise ValueError(
+ f"This kernel currently assumes exactly {Q_HEAD_BATCH} query heads per KV head",
+ )
+ if num_kv_heads % 2 != 0:
+ raise ValueError("num_kv_heads must be even")
+ if hidden_size % INPUT_PROJ_K_CHUNK != 0 or hidden_size % K_CHUNK != 0:
+ raise ValueError("hidden_size must align with kernel chunk sizes")
+ if intermediate_size % MLP_OUT_CHUNK != 0 or intermediate_size % DOWN_MLP_CHUNK != 0:
+ raise ValueError("intermediate_size must align with MLP chunk sizes")
+
def build_tensor_specs(
batch: int = BATCH,
max_seq: int = MAX_SEQ,
hidden_size: int = HIDDEN,
intermediate_size: int = INTERMEDIATE,
@@
- hidden = num_heads * head_dim
+ if hidden_size != num_heads * head_dim:
+ raise ValueError("hidden_size must equal num_heads * head_dim")
+ hidden = hidden_sizeAlso applies to: 263-275, 349-380, 384-398, 690-695, 805-805, 842-844, 928-973
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In
`@models/qwen3/14b/pypto-lib/models/qwen3/14b/qwen3_14b_decode_qk_softmax_fused.py`
around lines 100 - 113, The builder assumes the 14B geometry but silently drops
channels for other sizes; add upfront validation (e.g., in build_tensor_specs or
the caller initializing these constants) that enforces the exact invariants
required by the kernel: check hidden == num_heads * head_dim, inter == num_heads
* mlp_factor (or the expected MLP layout), and that hidden, inter, and head_dim
are divisible by
K_CHUNK/INPUT_PROJ_K_CHUNK/KV_PROJ_K_CHUNK/OUT_PROJ_K_CHUNK/DOWN_OUT_CHUNK/OUT_PROJ_N_CHUNK/MLP_OUT_CHUNK/DOWN_MLP_CHUNK
respectively; also ensure q_per_kv % Q_HEAD_BATCH == 0 and that q_groups and
total_q_groups match the kernel's expected grouping (q_groups = q_per_kv //
Q_HEAD_BATCH, total_q_groups = num_kv_heads * q_groups) so the RoPE/qk_norm and
the attention loop won’t drop or zero groups; if validations fail, raise an
explicit error describing the mismatch.
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py (1)
25-25: 💤 Low valueNitpick: Replace ambiguous multiplication sign character.
The docstring uses
×(Unicode multiplication sign) which is flagged by static analysis (RUF002). Consider replacing withxfor consistency.Proposed fix
- 1. Output projection: attn_out × wo + 1. Output projection: attn_out x wo🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py` at line 25, Docstring uses the Unicode multiplication sign '×' in the bullet "Output projection: attn_out × wo" which triggers static-lint RUF002; update that occurrence in the module qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py (the docstring text containing "attn_out × wo") to use a plain ASCII 'x' (i.e., "attn_out x wo") so the docstring remains readable and lint-clean.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/qwen3/14b/qwen3_14b_decode.py`:
- Around line 178-180: The computed local_valid can be negative here (same bug
as in rmsnorm_kernel); change the assignment in qwen3_14b_decode.py so that
local_valid is clamped to be non-negative after computing local_valid =
pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start) — e.g., replace with a
non-negative clamp using pl.max(..., 0) or pl.clamp(..., 0, None) so local_valid
never goes below zero; keep the same block_idx, RMSNORM_SPMD_ROWS and cur_valid
variables intact.
- Around line 130-132: The computed local_valid can become negative when
cur_valid < row_start; update the calculation around pl.tile.get_block_idx() so
local_valid is clamped to a non-negative value before it is used in valid_shape
(e.g., compute local_valid = max(0, min(RMSNORM_SPMD_ROWS, cur_valid -
row_start)) or use an equivalent pl.clip/pl.max/pl.min combination). Ensure this
change touches the block where block_idx, row_start, and local_valid are
computed so downstream uses (valid_shape, padding) never receive negative
dimensions.
---
Nitpick comments:
In `@models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py`:
- Line 25: Docstring uses the Unicode multiplication sign '×' in the bullet
"Output projection: attn_out × wo" which triggers static-lint RUF002; update
that occurrence in the module
qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py (the docstring text
containing "attn_out × wo") to use a plain ASCII 'x' (i.e., "attn_out x wo") so
the docstring remains readable and lint-clean.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 8e4e14da-730b-4494-9c17-03f71d127fff
📒 Files selected for processing (4)
models/qwen3/14b/qwen3_14b_decode.pymodels/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.pymodels/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.pymodels/qwen3/14b/qwen3_14b_rmsnorm_spmd2.py
💤 Files with no reviewable changes (1)
- models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py
| block_idx = pl.tile.get_block_idx() | ||
| row_start = block_idx * RMSNORM_SPMD_ROWS | ||
| local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start) |
There was a problem hiding this comment.
Potential negative local_valid when cur_valid < row_start.
When cur_valid is less than row_start (e.g., cur_valid=4 with block_idx=1 where row_start=8), the expression cur_valid - row_start yields a negative value. pl.min(RMSNORM_SPMD_ROWS, -4) would produce -4, and passing a negative dimension to valid_shape could cause undefined behavior or incorrect zero-padding.
Consider clamping local_valid to be non-negative:
🛡️ Proposed fix
row_start = block_idx * RMSNORM_SPMD_ROWS
- local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
+ local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start))📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| block_idx = pl.tile.get_block_idx() | |
| row_start = block_idx * RMSNORM_SPMD_ROWS | |
| local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start) | |
| block_idx = pl.tile.get_block_idx() | |
| row_start = block_idx * RMSNORM_SPMD_ROWS | |
| local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/qwen3/14b/qwen3_14b_decode.py` around lines 130 - 132, The computed
local_valid can become negative when cur_valid < row_start; update the
calculation around pl.tile.get_block_idx() so local_valid is clamped to a
non-negative value before it is used in valid_shape (e.g., compute local_valid =
max(0, min(RMSNORM_SPMD_ROWS, cur_valid - row_start)) or use an equivalent
pl.clip/pl.max/pl.min combination). Ensure this change touches the block where
block_idx, row_start, and local_valid are computed so downstream uses
(valid_shape, padding) never receive negative dimensions.
| block_idx = pl.tile.get_block_idx() | ||
| row_start = block_idx * RMSNORM_SPMD_ROWS | ||
| local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start) |
There was a problem hiding this comment.
Same negative local_valid issue as in rmsnorm_kernel.
Apply the same fix to clamp local_valid to be non-negative.
🛡️ Proposed fix
row_start = block_idx * RMSNORM_SPMD_ROWS
- local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
+ local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start))📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| block_idx = pl.tile.get_block_idx() | |
| row_start = block_idx * RMSNORM_SPMD_ROWS | |
| local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start) | |
| block_idx = pl.tile.get_block_idx() | |
| row_start = block_idx * RMSNORM_SPMD_ROWS | |
| local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/qwen3/14b/qwen3_14b_decode.py` around lines 178 - 180, The computed
local_valid can be negative here (same bug as in rmsnorm_kernel); change the
assignment in qwen3_14b_decode.py so that local_valid is clamped to be
non-negative after computing local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid -
row_start) — e.g., replace with a non-negative clamp using pl.max(..., 0) or
pl.clamp(..., 0, None) so local_valid never goes below zero; keep the same
block_idx, RMSNORM_SPMD_ROWS and cur_valid variables intact.
No description provided.