Skip to content
253 changes: 135 additions & 118 deletions models/qwen3/14b/qwen3_14b_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
Q_OUT_CHUNK = 256
KV_OUT_CHUNK = 256
BATCH_TILE = 16
RMSNORM_SPMD_CORES = 2
RMSNORM_SPMD_ROWS = BATCH_TILE // RMSNORM_SPMD_CORES

# Scope 2 tiling constants.
# Qwen3-14B uses 40 Q heads and 8 KV heads, so q_per_kv = 5.
Expand Down Expand Up @@ -116,6 +118,96 @@ def build_qwen3_decode_program(

@pl.program
class Qwen3Decode:
@pl.function(type=pl.FunctionType.InCore)
def rmsnorm_kernel(
self,
hidden_states: pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16],
b0: pl.Scalar[pl.INDEX],
cur_valid: pl.Scalar[pl.INDEX],
input_rms_weight: pl.Tensor[[1, hidden], pl.FP32],
normed_tile: pl.InOut[pl.Tensor[[BATCH_TILE, hidden], pl.BF16]],
) -> pl.Tensor[[BATCH_TILE, hidden], pl.BF16]:
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
Comment on lines +130 to +132
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Potential negative local_valid when cur_valid < row_start.

When cur_valid is less than row_start (e.g., cur_valid=4 with block_idx=1 where row_start=8), the expression cur_valid - row_start yields a negative value. pl.min(RMSNORM_SPMD_ROWS, -4) would produce -4, and passing a negative dimension to valid_shape could cause undefined behavior or incorrect zero-padding.

Consider clamping local_valid to be non-negative:

🛡️ Proposed fix
             row_start = block_idx * RMSNORM_SPMD_ROWS
-            local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
+            local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/qwen3_14b_decode.py` around lines 130 - 132, The computed
local_valid can become negative when cur_valid < row_start; update the
calculation around pl.tile.get_block_idx() so local_valid is clamped to a
non-negative value before it is used in valid_shape (e.g., compute local_valid =
max(0, min(RMSNORM_SPMD_ROWS, cur_valid - row_start)) or use an equivalent
pl.clip/pl.max/pl.min combination). Ensure this change touches the block where
block_idx, row_start, and local_valid are computed so downstream uses
(valid_shape, padding) never receive negative dimensions.


partial_sq = pl.full([1, RMSNORM_SPMD_ROWS], dtype=pl.FP32, value=0.0)
for kb in pl.pipeline(input_proj_k_blocks, stage=4):
k0 = kb * INPUT_PROJ_K_CHUNK
x_chunk = pl.cast(
pl.slice(
hidden_states,
[RMSNORM_SPMD_ROWS, INPUT_PROJ_K_CHUNK],
[b0 + row_start, k0],
valid_shape=[local_valid, INPUT_PROJ_K_CHUNK],
),
target_type=pl.FP32,
)
partial_sq = pl.add(partial_sq, pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, RMSNORM_SPMD_ROWS]))
variance = pl.reshape(pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS), [RMSNORM_SPMD_ROWS, 1])
inv_rms = pl.recip(pl.sqrt(variance))

for kb in pl.pipeline(input_proj_k_blocks, stage=4):
k0 = kb * INPUT_PROJ_K_CHUNK
x_chunk = pl.cast(
pl.slice(
hidden_states,
[RMSNORM_SPMD_ROWS, INPUT_PROJ_K_CHUNK],
[b0 + row_start, k0],
valid_shape=[local_valid, INPUT_PROJ_K_CHUNK],
),
target_type=pl.FP32,
)
gamma = input_rms_weight[:, k0 : k0 + INPUT_PROJ_K_CHUNK]
normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma)
normed_tile = pl.assemble(
normed_tile,
pl.cast(normed, target_type=pl.BF16),
[row_start, k0],
)
return normed_tile

@pl.function(type=pl.FunctionType.InCore)
def post_rmsnorm_kernel(
self,
resid1_tile: pl.Tensor[[BATCH_TILE, hidden], pl.FP32],
cur_valid: pl.Scalar[pl.INDEX],
post_rms_weight: pl.Tensor[[1, hidden], pl.FP32],
post_norm_tile: pl.InOut[pl.Tensor[[BATCH_TILE, hidden], pl.BF16]],
) -> pl.Tensor[[BATCH_TILE, hidden], pl.BF16]:
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
Comment on lines +178 to +180
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Same negative local_valid issue as in rmsnorm_kernel.

Apply the same fix to clamp local_valid to be non-negative.

🛡️ Proposed fix
             row_start = block_idx * RMSNORM_SPMD_ROWS
-            local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
+            local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start)
block_idx = pl.tile.get_block_idx()
row_start = block_idx * RMSNORM_SPMD_ROWS
local_valid = pl.max(0, pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/qwen3_14b_decode.py` around lines 178 - 180, The computed
local_valid can be negative here (same bug as in rmsnorm_kernel); change the
assignment in qwen3_14b_decode.py so that local_valid is clamped to be
non-negative after computing local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid -
row_start) — e.g., replace with a non-negative clamp using pl.max(..., 0) or
pl.clamp(..., 0, None) so local_valid never goes below zero; keep the same
block_idx, RMSNORM_SPMD_ROWS and cur_valid variables intact.


sq_sum = pl.full([1, RMSNORM_SPMD_ROWS], dtype=pl.FP32, value=0.0)
for kb in pl.pipeline(hidden_blocks, stage=2):
k0 = kb * K_CHUNK
resid_chunk = pl.slice(
resid1_tile,
[RMSNORM_SPMD_ROWS, K_CHUNK],
[row_start, k0],
valid_shape=[local_valid, K_CHUNK],
)
sq_sum = pl.add(sq_sum, pl.reshape(pl.row_sum(pl.mul(resid_chunk, resid_chunk)), [1, RMSNORM_SPMD_ROWS]))
inv_rms_s3 = pl.recip(pl.sqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)))

for kb in pl.pipeline(hidden_blocks, stage=2):
k0 = kb * K_CHUNK
resid_chunk = pl.slice(
resid1_tile,
[RMSNORM_SPMD_ROWS, K_CHUNK],
[row_start, k0],
valid_shape=[local_valid, K_CHUNK],
)
post_gamma = post_rms_weight[:, k0 : k0 + K_CHUNK]
post_normed = pl.col_expand_mul(
pl.row_expand_mul(resid_chunk, pl.reshape(inv_rms_s3, [RMSNORM_SPMD_ROWS, 1])),
post_gamma,
)
normed_bf16 = pl.cast(post_normed, target_type=pl.BF16)
post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [row_start, k0])
return post_norm_tile

@pl.function(type=pl.FunctionType.Opaque)
def qwen3_decode(
self,
Expand Down Expand Up @@ -172,47 +264,8 @@ def qwen3_decode(
cur_valid = pl.min(BATCH_TILE, user_batch - b0)
normed_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16)

with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"):
partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0)
for kb in pl.pipeline(input_proj_k_blocks, stage=4):
k0 = kb * INPUT_PROJ_K_CHUNK
x_chunk = pl.cast(
pl.slice(
hidden_states,
[BATCH_TILE, INPUT_PROJ_K_CHUNK],
[b0, k0],
valid_shape=[cur_valid, INPUT_PROJ_K_CHUNK],
),
target_type=pl.FP32,
)
partial_sq = pl.add(
partial_sq,
pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, BATCH_TILE]),
)
variance = pl.reshape(
pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS),
[BATCH_TILE, 1],
)
inv_rms = pl.recip(pl.sqrt(variance))

for kb in pl.pipeline(input_proj_k_blocks, stage=4):
k0 = kb * INPUT_PROJ_K_CHUNK
x_chunk = pl.cast(
pl.slice(
hidden_states,
[BATCH_TILE, INPUT_PROJ_K_CHUNK],
[b0, k0],
valid_shape=[cur_valid, INPUT_PROJ_K_CHUNK],
),
target_type=pl.FP32,
)
gamma = input_rms_weight[:, k0 : k0 + INPUT_PROJ_K_CHUNK]
normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma)
normed_tile = pl.assemble(
normed_tile,
pl.cast(normed, target_type=pl.BF16),
[0, k0],
)
with pl.spmd(RMSNORM_SPMD_CORES):
normed_tile = self.rmsnorm_kernel(hidden_states, b0, cur_valid, input_rms_weight, normed_tile)

for q0 in pl.parallel(0, hidden, Q_OUT_CHUNK):
with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_proj"):
Expand Down Expand Up @@ -563,27 +616,9 @@ def qwen3_decode(
resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0])

post_norm_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16)
with pl.at(level=pl.Level.CORE_GROUP, name_hint="post_rmsnorm"):
sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0)
for kb in pl.pipeline(hidden_blocks, stage=2):
k0 = kb * K_CHUNK
resid_chunk = resid1_tile[:, k0 : k0 + K_CHUNK]
sq_sum = pl.add(
sq_sum,
pl.reshape(pl.row_sum(pl.mul(resid_chunk, resid_chunk)), [1, BATCH_TILE]),
)
inv_rms_s3 = pl.recip(pl.sqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)))

for kb in pl.pipeline(hidden_blocks, stage=2):
k0 = kb * K_CHUNK
resid_chunk = resid1_tile[:, k0 : k0 + K_CHUNK]
post_gamma = post_rms_weight[:, k0 : k0 + K_CHUNK]
post_normed = pl.col_expand_mul(
pl.row_expand_mul(resid_chunk, pl.reshape(inv_rms_s3, [BATCH_TILE, 1])),
post_gamma,
)
normed_bf16 = pl.cast(post_normed, target_type=pl.BF16)
post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [0, k0])

with pl.spmd(RMSNORM_SPMD_CORES):
post_norm_tile = self.post_rmsnorm_kernel(resid1_tile, cur_valid, post_rms_weight, post_norm_tile)

mlp_tile = pl.create_tensor([BATCH_TILE, inter], dtype=pl.BF16)
for ob in pl.range(mlp_out_blocks):
Expand Down Expand Up @@ -616,54 +651,34 @@ def qwen3_decode(
mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16)
mlp_tile = pl.assemble(mlp_tile, mlp_chunk_bf16, [0, o0])

if cur_valid == BATCH_TILE:
with pl.at(
level=pl.Level.CORE_GROUP,
optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.LEFT_RIGHT)],
name_hint="down_proj_residual",
):
for dob in pl.parallel(down_out_blocks, chunk=1):
d0 = dob * DOWN_OUT_CHUNK
down_mlp_chunk_bf16 = mlp_tile[:, 0:DOWN_MLP_CHUNK]
w_down_chunk = w_down[0:DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK]
down_acc = pl.matmul(down_mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32)
for ob in pl.pipeline(1, down_mlp_blocks, stage=2):
o0 = ob * DOWN_MLP_CHUNK
down_mlp_chunk_bf16 = mlp_tile[:, o0 : o0 + DOWN_MLP_CHUNK]
w_down_chunk = w_down[o0 : o0 + DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK]
for dob in pl.range(down_out_blocks):
d0 = dob * DOWN_OUT_CHUNK
fp32_chunk_gm = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32)

with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj"):
down_acc = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32)
for ob in pl.pipeline(0, down_mlp_blocks, stage=2):
o0 = ob * DOWN_MLP_CHUNK
down_mlp_chunk_bf16 = mlp_tile[:, o0 : o0 + DOWN_MLP_CHUNK]
w_down_chunk = w_down[o0 : o0 + DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK]
if o0 == 0:
down_acc = pl.matmul(down_mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32)
else:
down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk)
resid_chunk_fp32 = resid1_tile[:, d0 : d0 + DOWN_OUT_CHUNK]
out_chunk = pl.add(down_acc, resid_chunk_fp32)
out = pl.assemble(out, pl.cast(out_chunk, target_type=pl.BF16), [b0, d0])
else:
for dob in pl.parallel(0, down_out_blocks, 1):
d0 = dob * DOWN_OUT_CHUNK
fp32_chunk_gm = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32)

with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj"):
down_acc = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32)
for ob in pl.pipeline(0, down_mlp_blocks, stage=2):
o0 = ob * DOWN_MLP_CHUNK
down_mlp_chunk_bf16 = mlp_tile[:, o0 : o0 + DOWN_MLP_CHUNK]
w_down_chunk = w_down[o0 : o0 + DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK]
if o0 == 0:
down_acc = pl.matmul(down_mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32)
else:
down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk)
fp32_chunk_gm = pl.assemble(fp32_chunk_gm, down_acc, [0, 0])

with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj_residual_tail"):
down_chunk_fp32 = fp32_chunk_gm[:, 0:DOWN_OUT_CHUNK]
resid_chunk_fp32 = resid1_tile[:, d0 : d0 + DOWN_OUT_CHUNK]
out_chunk = pl.add(down_chunk_fp32, resid_chunk_fp32)
out_chunk_cast = pl.cast(out_chunk, target_type=pl.BF16)
out_chunk_trimmed = pl.slice(
out_chunk_cast,
[BATCH_TILE, DOWN_OUT_CHUNK],
[0, 0],
valid_shape=[cur_valid, DOWN_OUT_CHUNK],
)
out = pl.assemble(out, out_chunk_trimmed, [b0, d0])
fp32_chunk_gm = pl.assemble(fp32_chunk_gm, down_acc, [0, 0])

with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj_residual"):
down_chunk_fp32 = fp32_chunk_gm[:, 0:DOWN_OUT_CHUNK]
resid_chunk_fp32 = resid1_tile[:, d0 : d0 + DOWN_OUT_CHUNK]
out_chunk = pl.add(down_chunk_fp32, resid_chunk_fp32)
out_chunk_cast = pl.cast(out_chunk, target_type=pl.BF16)
out_chunk_trimmed = pl.slice(
out_chunk_cast,
[BATCH_TILE, DOWN_OUT_CHUNK],
[0, 0],
valid_shape=[cur_valid, DOWN_OUT_CHUNK],
)
out = pl.assemble(out, out_chunk_trimmed, [b0, d0])

return out

Expand Down Expand Up @@ -994,7 +1009,7 @@ def chunked_row_sq_sum(x, k_chunk):
if __name__ == "__main__":
import argparse
import sys
from golden import run
from golden import RunConfig, run

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--platform", type=str, default="a2a3",
Expand Down Expand Up @@ -1033,15 +1048,17 @@ def chunked_row_sq_sum(x, k_chunk):
program=build_qwen3_decode_program(batch=args.batch),
specs=build_tensor_specs(batch=args.batch, use_max_seq=args.max_seq),
golden_fn=golden_qwen3_decode,
compile_cfg=dict(dump_passes=True),
runtime_cfg=dict(
platform=args.platform,
device_id=args.device,
enable_l2_swimlane=args.enable_l2_swimlane,
enable_pmu=args.enable_pmu,
config=RunConfig(
rtol=3e-3,
atol=3e-3,
compile=dict(dump_passes=True),
runtime=dict(
platform=args.platform,
device_id=args.device,
enable_l2_swimlane=args.enable_l2_swimlane,
enable_pmu=args.enable_pmu,
),
),
rtol=3e-3,
atol=3e-3,
)
if not result.passed:
if result.error:
Expand Down
Loading