diff --git a/models/qwen3/14b/qwen3_14b_decode.py b/models/qwen3/14b/qwen3_14b_decode.py index cf98a76..16f5a71 100644 --- a/models/qwen3/14b/qwen3_14b_decode.py +++ b/models/qwen3/14b/qwen3_14b_decode.py @@ -62,6 +62,8 @@ Q_OUT_CHUNK = 256 KV_OUT_CHUNK = 256 BATCH_TILE = 16 +RMSNORM_SPMD_CORES = 2 +RMSNORM_SPMD_ROWS = BATCH_TILE // RMSNORM_SPMD_CORES # Scope 2 tiling constants. # Qwen3-14B uses 40 Q heads and 8 KV heads, so q_per_kv = 5. @@ -116,6 +118,96 @@ def build_qwen3_decode_program( @pl.program class Qwen3Decode: + @pl.function(type=pl.FunctionType.InCore) + def rmsnorm_kernel( + self, + hidden_states: pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16], + b0: pl.Scalar[pl.INDEX], + cur_valid: pl.Scalar[pl.INDEX], + input_rms_weight: pl.Tensor[[1, hidden], pl.FP32], + normed_tile: pl.InOut[pl.Tensor[[BATCH_TILE, hidden], pl.BF16]], + ) -> pl.Tensor[[BATCH_TILE, hidden], pl.BF16]: + block_idx = pl.tile.get_block_idx() + row_start = block_idx * RMSNORM_SPMD_ROWS + local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start) + + partial_sq = pl.full([1, RMSNORM_SPMD_ROWS], dtype=pl.FP32, value=0.0) + for kb in pl.pipeline(input_proj_k_blocks, stage=4): + k0 = kb * INPUT_PROJ_K_CHUNK + x_chunk = pl.cast( + pl.slice( + hidden_states, + [RMSNORM_SPMD_ROWS, INPUT_PROJ_K_CHUNK], + [b0 + row_start, k0], + valid_shape=[local_valid, INPUT_PROJ_K_CHUNK], + ), + target_type=pl.FP32, + ) + partial_sq = pl.add(partial_sq, pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, RMSNORM_SPMD_ROWS])) + variance = pl.reshape(pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS), [RMSNORM_SPMD_ROWS, 1]) + inv_rms = pl.recip(pl.sqrt(variance)) + + for kb in pl.pipeline(input_proj_k_blocks, stage=4): + k0 = kb * INPUT_PROJ_K_CHUNK + x_chunk = pl.cast( + pl.slice( + hidden_states, + [RMSNORM_SPMD_ROWS, INPUT_PROJ_K_CHUNK], + [b0 + row_start, k0], + valid_shape=[local_valid, INPUT_PROJ_K_CHUNK], + ), + target_type=pl.FP32, + ) + gamma = input_rms_weight[:, k0 : k0 + INPUT_PROJ_K_CHUNK] + normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma) + normed_tile = pl.assemble( + normed_tile, + pl.cast(normed, target_type=pl.BF16), + [row_start, k0], + ) + return normed_tile + + @pl.function(type=pl.FunctionType.InCore) + def post_rmsnorm_kernel( + self, + resid1_tile: pl.Tensor[[BATCH_TILE, hidden], pl.FP32], + cur_valid: pl.Scalar[pl.INDEX], + post_rms_weight: pl.Tensor[[1, hidden], pl.FP32], + post_norm_tile: pl.InOut[pl.Tensor[[BATCH_TILE, hidden], pl.BF16]], + ) -> pl.Tensor[[BATCH_TILE, hidden], pl.BF16]: + block_idx = pl.tile.get_block_idx() + row_start = block_idx * RMSNORM_SPMD_ROWS + local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start) + + sq_sum = pl.full([1, RMSNORM_SPMD_ROWS], dtype=pl.FP32, value=0.0) + for kb in pl.pipeline(hidden_blocks, stage=2): + k0 = kb * K_CHUNK + resid_chunk = pl.slice( + resid1_tile, + [RMSNORM_SPMD_ROWS, K_CHUNK], + [row_start, k0], + valid_shape=[local_valid, K_CHUNK], + ) + sq_sum = pl.add(sq_sum, pl.reshape(pl.row_sum(pl.mul(resid_chunk, resid_chunk)), [1, RMSNORM_SPMD_ROWS])) + inv_rms_s3 = pl.recip(pl.sqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS))) + + for kb in pl.pipeline(hidden_blocks, stage=2): + k0 = kb * K_CHUNK + resid_chunk = pl.slice( + resid1_tile, + [RMSNORM_SPMD_ROWS, K_CHUNK], + [row_start, k0], + valid_shape=[local_valid, K_CHUNK], + ) + post_gamma = post_rms_weight[:, k0 : k0 + K_CHUNK] + post_normed = pl.col_expand_mul( + pl.row_expand_mul(resid_chunk, pl.reshape(inv_rms_s3, [RMSNORM_SPMD_ROWS, 1])), + post_gamma, + ) + normed_bf16 = pl.cast(post_normed, target_type=pl.BF16) + post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [row_start, k0]) + return post_norm_tile + @pl.function(type=pl.FunctionType.Opaque) def qwen3_decode( self, @@ -172,47 +264,8 @@ def qwen3_decode( cur_valid = pl.min(BATCH_TILE, user_batch - b0) normed_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"): - partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) - for kb in pl.pipeline(input_proj_k_blocks, stage=4): - k0 = kb * INPUT_PROJ_K_CHUNK - x_chunk = pl.cast( - pl.slice( - hidden_states, - [BATCH_TILE, INPUT_PROJ_K_CHUNK], - [b0, k0], - valid_shape=[cur_valid, INPUT_PROJ_K_CHUNK], - ), - target_type=pl.FP32, - ) - partial_sq = pl.add( - partial_sq, - pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, BATCH_TILE]), - ) - variance = pl.reshape( - pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS), - [BATCH_TILE, 1], - ) - inv_rms = pl.recip(pl.sqrt(variance)) - - for kb in pl.pipeline(input_proj_k_blocks, stage=4): - k0 = kb * INPUT_PROJ_K_CHUNK - x_chunk = pl.cast( - pl.slice( - hidden_states, - [BATCH_TILE, INPUT_PROJ_K_CHUNK], - [b0, k0], - valid_shape=[cur_valid, INPUT_PROJ_K_CHUNK], - ), - target_type=pl.FP32, - ) - gamma = input_rms_weight[:, k0 : k0 + INPUT_PROJ_K_CHUNK] - normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma) - normed_tile = pl.assemble( - normed_tile, - pl.cast(normed, target_type=pl.BF16), - [0, k0], - ) + with pl.spmd(RMSNORM_SPMD_CORES): + normed_tile = self.rmsnorm_kernel(hidden_states, b0, cur_valid, input_rms_weight, normed_tile) for q0 in pl.parallel(0, hidden, Q_OUT_CHUNK): with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_proj"): @@ -563,27 +616,9 @@ def qwen3_decode( resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0]) post_norm_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="post_rmsnorm"): - sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) - for kb in pl.pipeline(hidden_blocks, stage=2): - k0 = kb * K_CHUNK - resid_chunk = resid1_tile[:, k0 : k0 + K_CHUNK] - sq_sum = pl.add( - sq_sum, - pl.reshape(pl.row_sum(pl.mul(resid_chunk, resid_chunk)), [1, BATCH_TILE]), - ) - inv_rms_s3 = pl.recip(pl.sqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS))) - - for kb in pl.pipeline(hidden_blocks, stage=2): - k0 = kb * K_CHUNK - resid_chunk = resid1_tile[:, k0 : k0 + K_CHUNK] - post_gamma = post_rms_weight[:, k0 : k0 + K_CHUNK] - post_normed = pl.col_expand_mul( - pl.row_expand_mul(resid_chunk, pl.reshape(inv_rms_s3, [BATCH_TILE, 1])), - post_gamma, - ) - normed_bf16 = pl.cast(post_normed, target_type=pl.BF16) - post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [0, k0]) + + with pl.spmd(RMSNORM_SPMD_CORES): + post_norm_tile = self.post_rmsnorm_kernel(resid1_tile, cur_valid, post_rms_weight, post_norm_tile) mlp_tile = pl.create_tensor([BATCH_TILE, inter], dtype=pl.BF16) for ob in pl.range(mlp_out_blocks): @@ -616,54 +651,34 @@ def qwen3_decode( mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) mlp_tile = pl.assemble(mlp_tile, mlp_chunk_bf16, [0, o0]) - if cur_valid == BATCH_TILE: - with pl.at( - level=pl.Level.CORE_GROUP, - optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.LEFT_RIGHT)], - name_hint="down_proj_residual", - ): - for dob in pl.parallel(down_out_blocks, chunk=1): - d0 = dob * DOWN_OUT_CHUNK - down_mlp_chunk_bf16 = mlp_tile[:, 0:DOWN_MLP_CHUNK] - w_down_chunk = w_down[0:DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK] - down_acc = pl.matmul(down_mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32) - for ob in pl.pipeline(1, down_mlp_blocks, stage=2): - o0 = ob * DOWN_MLP_CHUNK - down_mlp_chunk_bf16 = mlp_tile[:, o0 : o0 + DOWN_MLP_CHUNK] - w_down_chunk = w_down[o0 : o0 + DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK] + for dob in pl.range(down_out_blocks): + d0 = dob * DOWN_OUT_CHUNK + fp32_chunk_gm = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj"): + down_acc = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32) + for ob in pl.pipeline(0, down_mlp_blocks, stage=2): + o0 = ob * DOWN_MLP_CHUNK + down_mlp_chunk_bf16 = mlp_tile[:, o0 : o0 + DOWN_MLP_CHUNK] + w_down_chunk = w_down[o0 : o0 + DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK] + if o0 == 0: + down_acc = pl.matmul(down_mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32) + else: down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk) - resid_chunk_fp32 = resid1_tile[:, d0 : d0 + DOWN_OUT_CHUNK] - out_chunk = pl.add(down_acc, resid_chunk_fp32) - out = pl.assemble(out, pl.cast(out_chunk, target_type=pl.BF16), [b0, d0]) - else: - for dob in pl.parallel(0, down_out_blocks, 1): - d0 = dob * DOWN_OUT_CHUNK - fp32_chunk_gm = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj"): - down_acc = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32) - for ob in pl.pipeline(0, down_mlp_blocks, stage=2): - o0 = ob * DOWN_MLP_CHUNK - down_mlp_chunk_bf16 = mlp_tile[:, o0 : o0 + DOWN_MLP_CHUNK] - w_down_chunk = w_down[o0 : o0 + DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK] - if o0 == 0: - down_acc = pl.matmul(down_mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32) - else: - down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk) - fp32_chunk_gm = pl.assemble(fp32_chunk_gm, down_acc, [0, 0]) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj_residual_tail"): - down_chunk_fp32 = fp32_chunk_gm[:, 0:DOWN_OUT_CHUNK] - resid_chunk_fp32 = resid1_tile[:, d0 : d0 + DOWN_OUT_CHUNK] - out_chunk = pl.add(down_chunk_fp32, resid_chunk_fp32) - out_chunk_cast = pl.cast(out_chunk, target_type=pl.BF16) - out_chunk_trimmed = pl.slice( - out_chunk_cast, - [BATCH_TILE, DOWN_OUT_CHUNK], - [0, 0], - valid_shape=[cur_valid, DOWN_OUT_CHUNK], - ) - out = pl.assemble(out, out_chunk_trimmed, [b0, d0]) + fp32_chunk_gm = pl.assemble(fp32_chunk_gm, down_acc, [0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj_residual"): + down_chunk_fp32 = fp32_chunk_gm[:, 0:DOWN_OUT_CHUNK] + resid_chunk_fp32 = resid1_tile[:, d0 : d0 + DOWN_OUT_CHUNK] + out_chunk = pl.add(down_chunk_fp32, resid_chunk_fp32) + out_chunk_cast = pl.cast(out_chunk, target_type=pl.BF16) + out_chunk_trimmed = pl.slice( + out_chunk_cast, + [BATCH_TILE, DOWN_OUT_CHUNK], + [0, 0], + valid_shape=[cur_valid, DOWN_OUT_CHUNK], + ) + out = pl.assemble(out, out_chunk_trimmed, [b0, d0]) return out @@ -994,7 +1009,7 @@ def chunked_row_sq_sum(x, k_chunk): if __name__ == "__main__": import argparse import sys - from golden import run + from golden import RunConfig, run parser = argparse.ArgumentParser() parser.add_argument("-p", "--platform", type=str, default="a2a3", @@ -1033,15 +1048,17 @@ def chunked_row_sq_sum(x, k_chunk): program=build_qwen3_decode_program(batch=args.batch), specs=build_tensor_specs(batch=args.batch, use_max_seq=args.max_seq), golden_fn=golden_qwen3_decode, - compile_cfg=dict(dump_passes=True), - runtime_cfg=dict( - platform=args.platform, - device_id=args.device, - enable_l2_swimlane=args.enable_l2_swimlane, - enable_pmu=args.enable_pmu, + config=RunConfig( + rtol=3e-3, + atol=3e-3, + compile=dict(dump_passes=True), + runtime=dict( + platform=args.platform, + device_id=args.device, + enable_l2_swimlane=args.enable_l2_swimlane, + enable_pmu=args.enable_pmu, + ), ), - rtol=3e-3, - atol=3e-3, ) if not result.passed: if result.error: diff --git a/models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py b/models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py new file mode 100644 index 0000000..1cf3c48 --- /dev/null +++ b/models/qwen3/14b/qwen3_14b_decode_qk_softmax_sv_online_softmax_fused.py @@ -0,0 +1,1065 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Qwen3-14B single-layer decode forward. + +Scope 1: + 1. RMSNorm of input hidden states + 2. Q/K/V projection via matmul + +Per-head q_norm / k_norm + +Scope 2: + 1. K RoPE + paged cache write, V paged cache write, Q RoPE + pad + 2. QK matmul + 3. Softmax + 4. SV matmul + 5. Online-softmax accumulation + final normalisation + +Scope 3: + 1. Output projection: attn_out × wo + 2. Residual addition with hidden_states + 3. Post-attention RMSNorm + 4. MLP: gate/up projections, SiLU activation, down projection + 5. Final residual addition +""" + +# pyright: reportUndefinedVariable=false + +import pypto.language as pl + +# Dynamic dims for arbitrary user_batch support. Host allocates every +# batch-dependent tensor at the user-visible batch (no host pad / no +# host trim); the kernel internally rounds up to BATCH_TILE, zero-pads +# trailing rows of every input via valid_shape on the load slice, and +# trims the BF16 ND output via vec-to-vec textract before tstore. A +# single compiled program serves any user_batch <= host KV-cache +# capacity. +USER_BATCH_DYN = pl.dynamic("USER_BATCH_DYN") +KV_CACHE_ROWS_DYN = pl.dynamic("KV_CACHE_ROWS_DYN") +BLOCK_TABLE_FLAT_DYN = pl.dynamic("BLOCK_TABLE_FLAT_DYN") + +BATCH = 16 +MAX_SEQ = 4096 +NUM_HEADS = 40 +NUM_KV_HEADS = 8 +HEAD_DIM = 128 +HIDDEN = NUM_HEADS * HEAD_DIM # 5120 +INTERMEDIATE = 17408 +KV_HIDDEN = NUM_KV_HEADS * HEAD_DIM + +EPS = 1e-6 +HIDDEN_INV = 1.0 / HIDDEN + +# Scope 1 tiling constants. +INPUT_PROJ_K_CHUNK = 256 +KV_PROJ_K_CHUNK = 256 +Q_OUT_CHUNK = 256 +KV_OUT_CHUNK = 256 +BATCH_TILE = 16 + +# Scope 2 tiling constants. +# Qwen3-14B uses 40 Q heads and 8 KV heads, so q_per_kv = 5. +Q_HEAD_BATCH = 5 +Q_HEAD_PAD = 16 +SEQ_TILE = 256 +SB_BATCH = 128 +BLOCK_SIZE = SEQ_TILE + +# Scope 3 tiling constants. +K_CHUNK = 256 +OUT_PROJ_K_CHUNK = 256 +OUT_PROJ_N_CHUNK = 256 +MLP_OUT_CHUNK = 256 +DOWN_MLP_CHUNK = 256 +DOWN_OUT_CHUNK = 256 + + +def build_qwen3_decode_program( + batch: int = BATCH, + max_seq: int = MAX_SEQ, + hidden_size: int = HIDDEN, + intermediate_size: int = INTERMEDIATE, + num_heads: int = NUM_HEADS, + num_kv_heads: int = NUM_KV_HEADS, + head_dim: int = HEAD_DIM, +): + # The `batch` parameter is only used by build_tensor_specs to size + # host buffers; it is no longer baked into the program. Every + # batch-dependent kernel signature dim is a pl.dynamic() variable + # (USER_BATCH_DYN / BLOCK_TABLE_FLAT_DYN / KV_CACHE_ROWS_DYN), so a + # single compiled program serves any user_batch <= host capacity. + hidden = hidden_size + kv_hidden = num_kv_heads * head_dim + inter = intermediate_size + input_proj_k_blocks = hidden // INPUT_PROJ_K_CHUNK + kv_proj_k_blocks = hidden // KV_PROJ_K_CHUNK + out_proj_k_blocks = hidden // OUT_PROJ_K_CHUNK + hidden_blocks = hidden // K_CHUNK + down_out_blocks = hidden // DOWN_OUT_CHUNK + out_proj_n_blocks = hidden // OUT_PROJ_N_CHUNK + mlp_out_blocks = inter // MLP_OUT_CHUNK + down_mlp_blocks = inter // DOWN_MLP_CHUNK + max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) // BLOCK_SIZE + half_dim = head_dim // 2 + head_dim_inv = 1.0 / head_dim + q_per_kv = num_heads // num_kv_heads + q_groups = q_per_kv // Q_HEAD_BATCH + total_q_groups = num_kv_heads * q_groups + attn_scale = 1.0 / (head_dim ** 0.5) + max_ctx_blocks = max_blocks_per_seq + + @pl.program + class Qwen3Decode: + @pl.function(type=pl.FunctionType.Opaque) + def qwen3_decode( + self, + hidden_states: pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16], + input_rms_weight: pl.Tensor[[1, hidden], pl.FP32], + wq: pl.Tensor[[hidden, hidden], pl.BF16], + wk: pl.Tensor[[hidden, kv_hidden], pl.BF16], + wv: pl.Tensor[[hidden, kv_hidden], pl.BF16], + q_norm_weight: pl.Tensor[[1, head_dim], pl.FP32], + k_norm_weight: pl.Tensor[[1, head_dim], pl.FP32], + seq_lens: pl.Tensor[[USER_BATCH_DYN], pl.INT32], + block_table: pl.Tensor[[BLOCK_TABLE_FLAT_DYN], pl.INT32], + slot_mapping: pl.Tensor[[USER_BATCH_DYN], pl.INT32], + rope_cos: pl.Tensor[[max_seq, head_dim], pl.FP32], + rope_sin: pl.Tensor[[max_seq, head_dim], pl.FP32], + k_cache: pl.Tensor[[KV_CACHE_ROWS_DYN, head_dim], pl.BF16], + v_cache: pl.Tensor[[KV_CACHE_ROWS_DYN, head_dim], pl.BF16], + wo: pl.Tensor[[hidden, hidden], pl.BF16], + post_rms_weight: pl.Tensor[[1, hidden], pl.FP32], + w_gate: pl.Tensor[[hidden, inter], pl.BF16], + w_up: pl.Tensor[[hidden, inter], pl.BF16], + w_down: pl.Tensor[[inter, hidden], pl.BF16], + out: pl.Out[pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16]], + ) -> pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16]: + # Runtime user_batch (host-visible batch) and BATCH_TILE-aligned + # internal batch_padded. All scope-1/scope-3 batch loops iterate + # over batch_padded and zero-pad/trim using valid_shape on + # input/output slices. Scope-2 iterates over user_batch directly + # (its outer loop is sequential per request, no per-tile pad + # gymnastics needed). + user_batch = pl.tensor.dim(hidden_states, 0) + batch_padded = ((user_batch + BATCH_TILE - 1) // BATCH_TILE) * BATCH_TILE + + # Intermediate FP32 tensors between scope 1 and scope 2. + # Allocated at runtime batch_padded; pl.create_tensor zero-inits + # so trailing (batch_padded - user_batch) padded rows are 0, + # which is the invariant relied on by Q/K-norm and scope-3. + q_proj = pl.create_tensor([batch_padded, hidden], dtype=pl.FP32) + k_proj = pl.create_tensor([batch_padded, kv_hidden], dtype=pl.FP32) + v_proj = pl.create_tensor([batch_padded, kv_hidden], dtype=pl.FP32) + q_proj_norm = pl.create_tensor([batch_padded, hidden], dtype=pl.FP32) + k_proj_norm = pl.create_tensor([batch_padded, kv_hidden], dtype=pl.FP32) + + # Scope 1: input RMSNorm + Q/K/V projection. + # Loop iterates over batch_padded (BATCH_TILE-aligned) so every + # matmul tile has a static known M dim of BATCH_TILE (a2a3 + # requirement). Trailing rows in the tail iter are zero-padded + # at load time via valid_shape on the hidden_states slice. + # RMSNorm of zero rows yields 0 (x=0 -> normed = 0 * rsqrt(EPS) + # * gamma = 0), so normed_tile padded rows stay 0. Subsequent + # q/k/v matmul reads from this in-kernel staging only, so + # padded q_proj/k_proj/v_proj rows are 0 acc, harmless. + for b0 in pl.parallel(0, batch_padded, BATCH_TILE): + cur_valid = pl.min(BATCH_TILE, user_batch - b0) + normed_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"): + partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) + for kb in pl.pipeline(input_proj_k_blocks, stage=4): + k0 = kb * INPUT_PROJ_K_CHUNK + x_chunk = pl.cast( + pl.slice( + hidden_states, + [BATCH_TILE, INPUT_PROJ_K_CHUNK], + [b0, k0], + valid_shape=[cur_valid, INPUT_PROJ_K_CHUNK], + ), + target_type=pl.FP32, + ) + partial_sq = pl.add( + partial_sq, + pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, BATCH_TILE]), + ) + variance = pl.reshape( + pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS), + [BATCH_TILE, 1], + ) + inv_rms = pl.recip(pl.sqrt(variance)) + + for kb in pl.pipeline(input_proj_k_blocks, stage=4): + k0 = kb * INPUT_PROJ_K_CHUNK + x_chunk = pl.cast( + pl.slice( + hidden_states, + [BATCH_TILE, INPUT_PROJ_K_CHUNK], + [b0, k0], + valid_shape=[cur_valid, INPUT_PROJ_K_CHUNK], + ), + target_type=pl.FP32, + ) + gamma = input_rms_weight[:, k0 : k0 + INPUT_PROJ_K_CHUNK] + normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma) + normed_tile = pl.assemble( + normed_tile, + pl.cast(normed, target_type=pl.BF16), + [0, k0], + ) + + for q0 in pl.parallel(0, hidden, Q_OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_proj"): + q_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, input_proj_k_blocks, stage=2): + k0 = kb * INPUT_PROJ_K_CHUNK + tile_a_i = normed_tile[:, k0 : k0 + INPUT_PROJ_K_CHUNK] + tile_b_i = wq[k0 : k0 + INPUT_PROJ_K_CHUNK, q0 : q0 + Q_OUT_CHUNK] + if k0 == 0: + q_acc = pl.matmul(tile_a_i, tile_b_i, out_dtype=pl.FP32) + else: + q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i) + q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) + + for kv0 in pl.parallel(0, kv_hidden, KV_OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="k_proj"): + k_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, kv_proj_k_blocks, stage=2): + k0 = kb * KV_PROJ_K_CHUNK + k_tile_a_i = normed_tile[:, k0 : k0 + KV_PROJ_K_CHUNK] + k_tile_w_i = wk[k0 : k0 + KV_PROJ_K_CHUNK, kv0 : kv0 + KV_OUT_CHUNK] + if k0 == 0: + k_acc = pl.matmul(k_tile_a_i, k_tile_w_i, out_dtype=pl.FP32) + else: + k_acc = pl.matmul_acc(k_acc, k_tile_a_i, k_tile_w_i) + k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="v_proj"): + v_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, kv_proj_k_blocks, stage=2): + k0 = kb * KV_PROJ_K_CHUNK + v_tile_a_i = normed_tile[:, k0 : k0 + KV_PROJ_K_CHUNK] + v_tile_w_i = wv[k0 : k0 + KV_PROJ_K_CHUNK, kv0 : kv0 + KV_OUT_CHUNK] + if k0 == 0: + v_acc = pl.matmul(v_tile_a_i, v_tile_w_i, out_dtype=pl.FP32) + else: + v_acc = pl.matmul_acc(v_acc, v_tile_a_i, v_tile_w_i) + v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) + + # HF-style per-head q_norm / k_norm before RoPE, batched to avoid + # generating unsupported 1x1 vec-tile scalar ops on A2/A3. + # Loops over batch_padded; q_proj/k_proj are kernel-internal + # staging with zero-init padded rows (RMSNorm of 0 stays 0), + # so no valid_shape is needed here. + for b0 in pl.parallel(0, batch_padded, BATCH_TILE): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qk_norm"): + for h in pl.range(num_kv_heads): + q0 = h * q_per_kv * head_dim + q_chunk = pl.reshape( + q_proj[b0 : b0 + BATCH_TILE, q0 : q0 + Q_HEAD_BATCH * head_dim], + [BATCH_TILE * Q_HEAD_BATCH, head_dim], + ) + q_sq_sum = pl.row_sum(pl.mul(q_chunk, q_chunk)) + q_inv_rms = pl.rsqrt(pl.add(pl.mul(q_sq_sum, head_dim_inv), EPS)) + q_chunk_norm = pl.col_expand_mul( + pl.row_expand_mul(q_chunk, q_inv_rms), + q_norm_weight, + ) + q_chunk_norm_flat = pl.reshape(q_chunk_norm, [BATCH_TILE, Q_HEAD_BATCH * head_dim]) + q_proj_norm = pl.assemble(q_proj_norm, q_chunk_norm_flat, [b0, q0]) + + k0 = h * head_dim + k_chunk = k_proj[b0 : b0 + BATCH_TILE, k0 : k0 + head_dim] + k_sq_sum = pl.row_sum(pl.mul(k_chunk, k_chunk)) + k_inv_rms = pl.rsqrt(pl.add(pl.mul(k_sq_sum, head_dim_inv), EPS)) + k_chunk_norm = pl.col_expand_mul( + pl.row_expand_mul(k_chunk, k_inv_rms), + k_norm_weight, + ) + k_proj_norm = pl.assemble(k_proj_norm, k_chunk_norm, [b0, k0]) + + # Scope 2: RoPE + KV cache update + grouped decode attention. + # attn_out is allocated at batch_padded so scope-3 (which loops + # over batch_padded) can read full BATCH_TILE rows in every + # iteration; padded rows are zero-init and stay 0 (scope-2 only + # writes valid rows). all_q_padded is sized similarly; each + # Q_HEAD_PAD block is padded inside rope_kv_cache. + attn_out = pl.create_tensor([batch_padded, hidden], dtype=pl.BF16) + all_q_padded = pl.create_tensor( + [batch_padded * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16, + ) + + # Outer loop iterates user_batch sequentially (one row per iter). + # seq_lens / slot_mapping are sized [USER_BATCH_DYN] so reading + # b in [0, user_batch) is in-bounds. Padded b rows do not need + # attention; their attn_out rows stay zero (zero-init). + for b in pl.parallel(user_batch): + ctx_len = pl.tensor.read(seq_lens, [b]) + pos = ctx_len - 1 + ctx_blocks = (ctx_len + BLOCK_SIZE - 1) // BLOCK_SIZE + block_table_base = b * max_blocks_per_seq + slot = pl.tensor.read(slot_mapping, [b]) + slot_block = slot // BLOCK_SIZE + slot_offset = slot - slot_block * BLOCK_SIZE + cos_row = rope_cos[pos : pos + 1, :] + sin_row = rope_sin[pos : pos + 1, :] + cos_lo = cos_row[:, 0:half_dim] + cos_hi = cos_row[:, half_dim:head_dim] + sin_lo = sin_row[:, 0:half_dim] + sin_hi = sin_row[:, half_dim:head_dim] + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_kv_cache"): + for ki in pl.range(num_kv_heads): + kv_col = ki * head_dim + cache_row = (slot_block * num_kv_heads + ki) * BLOCK_SIZE + slot_offset + k_lo = k_proj_norm[b : b + 1, kv_col : kv_col + half_dim] + k_hi = k_proj_norm[b : b + 1, kv_col + half_dim : kv_col + head_dim] + rot_lo = pl.sub( + pl.col_expand_mul(k_lo, cos_lo), + pl.col_expand_mul(k_hi, sin_lo), + ) + rot_hi = pl.add( + pl.col_expand_mul(k_hi, cos_hi), + pl.col_expand_mul(k_lo, sin_hi), + ) + k_cache = pl.assemble( + k_cache, + pl.cast(rot_lo, target_type=pl.BF16), + [cache_row, 0], + ) + k_cache = pl.assemble( + k_cache, + pl.cast(rot_hi, target_type=pl.BF16), + [cache_row, half_dim], + ) + v_cache = pl.assemble( + v_cache, + pl.cast( + v_proj[b : b + 1, kv_col : kv_col + head_dim], + target_type=pl.BF16, + ), + [cache_row, 0], + ) + q_base = ki * q_per_kv + q_block = pl.reshape( + q_proj_norm[b : b + 1, q_base * head_dim : (q_base + Q_HEAD_BATCH) * head_dim], + [Q_HEAD_BATCH, head_dim], + ) + q_lo = q_block[:, 0:half_dim] + q_hi = q_block[:, half_dim:head_dim] + rot_lo_bf16 = pl.cast( + pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), + target_type=pl.BF16, + ) + rot_hi_bf16 = pl.cast( + pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), + target_type=pl.BF16, + ) + all_q_padded = pl.assemble( + all_q_padded, + rot_lo_bf16, + [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD, 0], + ) + all_q_padded = pl.assemble( + all_q_padded, + rot_hi_bf16, + [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD, half_dim], + ) + all_q_padded = pl.assemble( + all_q_padded, + pl.cast( + pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), + target_type=pl.BF16, + ), + [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + Q_HEAD_BATCH, 0], + ) + + attn_row = pl.create_tensor([1, hidden], dtype=pl.BF16) + for gi in pl.parallel(0, total_q_groups, 2): + gi0 = gi + gi1 = gi + 1 + + kvh0 = gi0 // q_groups + qg0 = gi0 - kvh0 * q_groups + q_base0 = kvh0 * q_per_kv + qg0 * Q_HEAD_BATCH + q_padded_row0 = b * total_q_groups * Q_HEAD_PAD + gi0 * Q_HEAD_PAD + q_padded0 = all_q_padded[q_padded_row0 : q_padded_row0 + Q_HEAD_PAD, :] + + kvh1 = gi1 // q_groups + qg1 = gi1 - kvh1 * q_groups + q_base1 = kvh1 * q_per_kv + qg1 * Q_HEAD_BATCH + q_padded_row1 = b * total_q_groups * Q_HEAD_PAD + gi1 * Q_HEAD_PAD + q_padded1 = all_q_padded[q_padded_row1 : q_padded_row1 + Q_HEAD_PAD, :] + + all_raw_scores0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.FP32) + all_raw_scores1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.FP32) + all_exp_padded0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.BF16) + all_exp_padded1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.BF16) + all_cur_mi0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32) + all_cur_mi1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32) + all_cur_li0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32) + all_cur_li1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.UP_DOWN)], name_hint="qk_softmax"): + for sb in pl.range(ctx_blocks): + block_table_idx = block_table_base + sb + pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX) + + cache_row0 = (pbid * num_kv_heads + kvh0) * BLOCK_SIZE + k_tile0 = k_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :] + raw_scores0 = pl.matmul(q_padded0, k_tile0, b_trans=True, out_dtype=pl.FP32) + all_raw_scores0 = pl.assemble(all_raw_scores0, raw_scores0, [sb * Q_HEAD_PAD, 0]) + + cache_row1 = (pbid * num_kv_heads + kvh1) * BLOCK_SIZE + k_tile1 = k_cache[cache_row1 : cache_row1 + BLOCK_SIZE, :] + raw_scores1 = pl.matmul(q_padded1, k_tile1, b_trans=True, out_dtype=pl.FP32) + all_raw_scores1 = pl.assemble(all_raw_scores1, raw_scores1, [sb * Q_HEAD_PAD, 0]) + + s0 = sb * BLOCK_SIZE + valid_len = pl.min(BLOCK_SIZE, ctx_len - s0) + + scores_valid0 = pl.slice( + all_raw_scores0, + [Q_HEAD_PAD, BLOCK_SIZE], + [sb * Q_HEAD_PAD, 0], + valid_shape=[Q_HEAD_PAD, valid_len], + ) + scores_padded0 = pl.fillpad(scores_valid0, pad_value=pl.PadValue.min) + scores0 = pl.mul(scores_padded0, attn_scale) + cur_mi0 = pl.row_max(scores0) + exp_scores0 = pl.exp(pl.row_expand_sub(scores0, cur_mi0)) + exp_scores_bf16_0 = pl.cast(exp_scores0, target_type=pl.BF16) + exp_scores_fp32_0 = pl.cast(exp_scores_bf16_0, target_type=pl.FP32) + cur_li0 = pl.row_sum(exp_scores_fp32_0) + all_exp_padded0 = pl.assemble(all_exp_padded0, exp_scores_bf16_0, [sb * Q_HEAD_PAD, 0]) + all_cur_mi0 = pl.assemble(all_cur_mi0, cur_mi0, [sb * Q_HEAD_PAD, 0]) + all_cur_li0 = pl.assemble(all_cur_li0, cur_li0, [sb * Q_HEAD_PAD, 0]) + + scores_valid1 = pl.slice( + all_raw_scores1, + [Q_HEAD_PAD, BLOCK_SIZE], + [sb * Q_HEAD_PAD, 0], + valid_shape=[Q_HEAD_PAD, valid_len], + ) + scores_padded1 = pl.fillpad(scores_valid1, pad_value=pl.PadValue.min) + scores1 = pl.mul(scores_padded1, attn_scale) + cur_mi1 = pl.row_max(scores1) + exp_scores1 = pl.exp(pl.row_expand_sub(scores1, cur_mi1)) + exp_scores_bf16_1 = pl.cast(exp_scores1, target_type=pl.BF16) + exp_scores_fp32_1 = pl.cast(exp_scores_bf16_1, target_type=pl.FP32) + cur_li1 = pl.row_sum(exp_scores_fp32_1) + all_exp_padded1 = pl.assemble(all_exp_padded1, exp_scores_bf16_1, [sb * Q_HEAD_PAD, 0]) + all_cur_mi1 = pl.assemble(all_cur_mi1, cur_mi1, [sb * Q_HEAD_PAD, 0]) + all_cur_li1 = pl.assemble(all_cur_li1, cur_li1, [sb * Q_HEAD_PAD, 0]) + + all_oi_tmp0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32) + all_oi_tmp1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.UP_DOWN)], name_hint="sv_online_softmax"): + for sb in pl.range(ctx_blocks): + block_table_idx = block_table_base + sb + pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX) + + cache_row0 = (pbid * num_kv_heads + kvh0) * BLOCK_SIZE + exp_tile0 = all_exp_padded0[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + v_tile0 = v_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :] + oi_tmp0 = pl.matmul(exp_tile0, v_tile0, out_dtype=pl.FP32) + all_oi_tmp0 = pl.assemble(all_oi_tmp0, oi_tmp0, [sb * Q_HEAD_PAD, 0]) + + cache_row1 = (pbid * num_kv_heads + kvh1) * BLOCK_SIZE + exp_tile1 = all_exp_padded1[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + v_tile1 = v_cache[cache_row1 : cache_row1 + BLOCK_SIZE, :] + oi_tmp1 = pl.matmul(exp_tile1, v_tile1, out_dtype=pl.FP32) + all_oi_tmp1 = pl.assemble(all_oi_tmp1, oi_tmp1, [sb * Q_HEAD_PAD, 0]) + + oi0 = all_oi_tmp0[0:Q_HEAD_PAD, :] + mi0 = all_cur_mi0[0:Q_HEAD_PAD, :] + li0 = all_cur_li0[0:Q_HEAD_PAD, :] + oi1 = all_oi_tmp1[0:Q_HEAD_PAD, :] + mi1 = all_cur_mi1[0:Q_HEAD_PAD, :] + li1 = all_cur_li1[0:Q_HEAD_PAD, :] + for sb in pl.range(1, ctx_blocks): + oi_tmp_valid0 = all_oi_tmp0[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + cur_mi0 = all_cur_mi0[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + cur_li0 = all_cur_li0[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + mi_new0 = pl.maximum(mi0, cur_mi0) + alpha0 = pl.exp(pl.sub(mi0, mi_new0)) + beta0 = pl.exp(pl.sub(cur_mi0, mi_new0)) + li0 = pl.add(pl.mul(alpha0, li0), pl.mul(beta0, cur_li0)) + oi0 = pl.add(pl.row_expand_mul(oi0, alpha0), pl.row_expand_mul(oi_tmp_valid0, beta0)) + mi0 = mi_new0 + + oi_tmp_valid1 = all_oi_tmp1[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + cur_mi1 = all_cur_mi1[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + cur_li1 = all_cur_li1[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + mi_new1 = pl.maximum(mi1, cur_mi1) + alpha1 = pl.exp(pl.sub(mi1, mi_new1)) + beta1 = pl.exp(pl.sub(cur_mi1, mi_new1)) + li1 = pl.add(pl.mul(alpha1, li1), pl.mul(beta1, cur_li1)) + oi1 = pl.add(pl.row_expand_mul(oi1, alpha1), pl.row_expand_mul(oi_tmp_valid1, beta1)) + mi1 = mi_new1 + + ctx0 = pl.row_expand_div(oi0, li0) + ctx_valid0 = ctx0[0:Q_HEAD_BATCH, :] + ctx_flat_bf16_0 = pl.cast(pl.reshape(ctx_valid0, [1, Q_HEAD_BATCH * head_dim]), target_type=pl.BF16) + attn_row = pl.assemble(attn_row, ctx_flat_bf16_0, [0, q_base0 * head_dim]) + + ctx1 = pl.row_expand_div(oi1, li1) + ctx_valid1 = ctx1[0:Q_HEAD_BATCH, :] + ctx_flat_bf16_1 = pl.cast(pl.reshape(ctx_valid1, [1, Q_HEAD_BATCH * head_dim]), target_type=pl.BF16) + attn_row = pl.assemble(attn_row, ctx_flat_bf16_1, [0, q_base1 * head_dim]) + + attn_out = pl.assemble(attn_out, attn_row, [b, 0]) + + # Scope 3: output projection + residual + post RMSNorm + MLP + residual. + # Loops over batch_padded so every iteration processes a full + # [BATCH_TILE, *] tile (a2a3 matmul M-tile constraint). + # `cur_valid` clamps the user-visible row count for input load + # (hidden_states valid_shape) and final out store (vec-to-vec + # textract trim). When user_batch is BATCH_TILE-aligned, + # cur_valid == BATCH_TILE every iter and trim is a no-op. + # + # Final down-proj + residual + cast uses the two-incore pattern + # validated in dynamic_batch_pad_repro: + # cube incore : matmul_acc -> FP32 -> assemble to GM scratch + # vec incore : tload FP32 chunk -> add FP32 resid -> cast BF16 + # (preserves ND layout) -> slice(valid_shape) + # -> assemble to out + # ND -> ND vec-to-vec textract is supported on ptoas 0.31. + for b0 in pl.parallel(0, batch_padded, BATCH_TILE): + cur_valid = pl.min(BATCH_TILE, user_batch - b0) + resid1_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.FP32) + + with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.UP_DOWN)], name_hint="out_proj_residual"): + for ob in pl.parallel(0, out_proj_n_blocks, chunk=1): + o0 = ob * OUT_PROJ_N_CHUNK + + hidden_chunk = pl.slice( + hidden_states, + [BATCH_TILE, OUT_PROJ_N_CHUNK], + [b0, o0], + valid_shape=[cur_valid, OUT_PROJ_N_CHUNK], + ) + + o_acc = pl.create_tensor([BATCH_TILE, OUT_PROJ_N_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, out_proj_k_blocks, stage=2): + k0 = kb * OUT_PROJ_K_CHUNK + a_chunk = attn_out[b0 : b0 + BATCH_TILE, k0 : k0 + OUT_PROJ_K_CHUNK] + w_chunk = wo[k0 : k0 + OUT_PROJ_K_CHUNK, o0 : o0 + OUT_PROJ_N_CHUNK] + if k0 == 0: + o_acc = pl.matmul(a_chunk, w_chunk, out_dtype=pl.FP32) + else: + o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk) + + resid = pl.cast(hidden_chunk, target_type=pl.FP32) + resid_sum = pl.add(o_acc, resid) + resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0]) + + post_norm_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="post_rmsnorm"): + sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) + for kb in pl.pipeline(hidden_blocks, stage=2): + k0 = kb * K_CHUNK + resid_chunk = resid1_tile[:, k0 : k0 + K_CHUNK] + sq_sum = pl.add( + sq_sum, + pl.reshape(pl.row_sum(pl.mul(resid_chunk, resid_chunk)), [1, BATCH_TILE]), + ) + inv_rms_s3 = pl.recip(pl.sqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS))) + + for kb in pl.pipeline(hidden_blocks, stage=2): + k0 = kb * K_CHUNK + resid_chunk = resid1_tile[:, k0 : k0 + K_CHUNK] + post_gamma = post_rms_weight[:, k0 : k0 + K_CHUNK] + post_normed = pl.col_expand_mul( + pl.row_expand_mul(resid_chunk, pl.reshape(inv_rms_s3, [BATCH_TILE, 1])), + post_gamma, + ) + normed_bf16 = pl.cast(post_normed, target_type=pl.BF16) + post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [0, k0]) + + mlp_tile = pl.create_tensor([BATCH_TILE, inter], dtype=pl.BF16) + for ob in pl.range(mlp_out_blocks): + o0 = ob * MLP_OUT_CHUNK + with pl.at(level=pl.Level.CORE_GROUP, name_hint="gate_proj"): + gate_acc = pl.create_tensor([BATCH_TILE, MLP_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, hidden_blocks, stage=2): + k0 = kb * K_CHUNK + post_chunk = post_norm_tile[:, k0 : k0 + K_CHUNK] + wg = w_gate[k0 : k0 + K_CHUNK, o0 : o0 + MLP_OUT_CHUNK] + if k0 == 0: + gate_acc = pl.matmul(post_chunk, wg, out_dtype=pl.FP32) + else: + gate_acc = pl.matmul_acc(gate_acc, post_chunk, wg) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="up_proj"): + up_acc = pl.create_tensor([BATCH_TILE, MLP_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, hidden_blocks, stage=2): + k0 = kb * K_CHUNK + post_chunk = post_norm_tile[:, k0 : k0 + K_CHUNK] + wu = w_up[k0 : k0 + K_CHUNK, o0 : o0 + MLP_OUT_CHUNK] + if k0 == 0: + up_acc = pl.matmul(post_chunk, wu, out_dtype=pl.FP32) + else: + up_acc = pl.matmul_acc(up_acc, post_chunk, wu) + + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="silu"): + sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_acc)), 1.0)) + mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc) + mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) + mlp_tile = pl.assemble(mlp_tile, mlp_chunk_bf16, [0, o0]) + + if cur_valid == BATCH_TILE: + with pl.at( + level=pl.Level.CORE_GROUP, + optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.LEFT_RIGHT)], + name_hint="down_proj_residual", + ): + for dob in pl.parallel(down_out_blocks, chunk=1): + d0 = dob * DOWN_OUT_CHUNK + down_mlp_chunk_bf16 = mlp_tile[:, 0:DOWN_MLP_CHUNK] + w_down_chunk = w_down[0:DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK] + down_acc = pl.matmul(down_mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32) + for ob in pl.pipeline(1, down_mlp_blocks, stage=2): + o0 = ob * DOWN_MLP_CHUNK + down_mlp_chunk_bf16 = mlp_tile[:, o0 : o0 + DOWN_MLP_CHUNK] + w_down_chunk = w_down[o0 : o0 + DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK] + down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk) + resid_chunk_fp32 = resid1_tile[:, d0 : d0 + DOWN_OUT_CHUNK] + out_chunk = pl.add(down_acc, resid_chunk_fp32) + out = pl.assemble(out, pl.cast(out_chunk, target_type=pl.BF16), [b0, d0]) + else: + for dob in pl.parallel(0, down_out_blocks, 1): + d0 = dob * DOWN_OUT_CHUNK + fp32_chunk_gm = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj"): + down_acc = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32) + for ob in pl.pipeline(0, down_mlp_blocks, stage=2): + o0 = ob * DOWN_MLP_CHUNK + down_mlp_chunk_bf16 = mlp_tile[:, o0 : o0 + DOWN_MLP_CHUNK] + w_down_chunk = w_down[o0 : o0 + DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK] + if o0 == 0: + down_acc = pl.matmul(down_mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32) + else: + down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk) + fp32_chunk_gm = pl.assemble(fp32_chunk_gm, down_acc, [0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj_residual_tail"): + down_chunk_fp32 = fp32_chunk_gm[:, 0:DOWN_OUT_CHUNK] + resid_chunk_fp32 = resid1_tile[:, d0 : d0 + DOWN_OUT_CHUNK] + out_chunk = pl.add(down_chunk_fp32, resid_chunk_fp32) + out_chunk_cast = pl.cast(out_chunk, target_type=pl.BF16) + out_chunk_trimmed = pl.slice( + out_chunk_cast, + [BATCH_TILE, DOWN_OUT_CHUNK], + [0, 0], + valid_shape=[cur_valid, DOWN_OUT_CHUNK], + ) + out = pl.assemble(out, out_chunk_trimmed, [b0, d0]) + + return out + + return Qwen3Decode + + +def build_tensor_specs( + batch: int = BATCH, + max_seq: int = MAX_SEQ, + hidden_size: int = HIDDEN, + intermediate_size: int = INTERMEDIATE, + num_heads: int = NUM_HEADS, + num_kv_heads: int = NUM_KV_HEADS, + head_dim: int = HEAD_DIM, + use_max_seq: bool = False, +): + import torch + from golden import TensorSpec + + # Host allocates every batch-dependent tensor at the user-visible + # batch (no host pad / no host trim). The kernel internally rounds + # up to BATCH_TILE, zero-pads via valid_shape on input loads, and + # trims via vec-to-vec textract on the BF16 output. A single + # compiled program serves any batch <= host capacity (USER_BATCH_DYN + # / KV_CACHE_ROWS_DYN / BLOCK_TABLE_FLAT_DYN are pl.dynamic dims). + hidden = num_heads * head_dim + kv_hidden = num_kv_heads * head_dim + inter = intermediate_size + max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) // BLOCK_SIZE + num_blocks = batch * max_blocks_per_seq + cache_rows = num_blocks * num_kv_heads * BLOCK_SIZE + synthetic_proj_scale = 0.5 + + if use_max_seq: + seq_lens_seed = torch.full((batch,), max_seq, dtype=torch.int32) + else: + seq_lens_seed = torch.randint(1, max_seq + 1, (batch,), dtype=torch.int32) + + def init_hidden_states(): + return torch.rand(batch, hidden_size) - 0.5 + + def init_rms_weight(): + return torch.rand(1, hidden_size) - 0.5 + + def init_wq(): + return torch.rand(hidden_size, hidden_size) / hidden_size ** 0.5 + + def init_wk(): + return torch.rand(hidden_size, kv_hidden) / hidden_size ** 0.5 + + def init_wv(): + return synthetic_proj_scale * torch.rand(hidden_size, kv_hidden) / hidden_size ** 0.5 + + def init_q_norm_weight(): + return torch.ones(1, head_dim) + + def init_k_norm_weight(): + return torch.ones(1, head_dim) + + def init_seq_lens(): + return seq_lens_seed.clone() + + def init_block_table(): + return torch.arange(num_blocks, dtype=torch.int32) + + def init_slot_mapping(): + slots = torch.empty(batch, dtype=torch.int32) + for b in range(batch): + pos = int(seq_lens_seed[b].item()) - 1 + logical_block = pos // BLOCK_SIZE + page_offset = pos % BLOCK_SIZE + phys_block = b * max_blocks_per_seq + logical_block + slots[b] = phys_block * BLOCK_SIZE + page_offset + return slots + + def init_rope_cos(): + return torch.rand(max_seq, head_dim) - 0.5 + + def init_rope_sin(): + return torch.rand(max_seq, head_dim) - 0.5 + + def init_k_cache(): + return torch.rand(cache_rows, head_dim) - 0.5 + + def init_v_cache(): + return synthetic_proj_scale * (torch.rand(cache_rows, head_dim) - 0.5) + + def init_wo(): + return synthetic_proj_scale * (torch.rand(hidden_size, hidden_size) - 0.5) / hidden_size ** 0.5 + + def init_post_rms_weight(): + return torch.ones(1, hidden_size) + + def init_w_gate(): + return synthetic_proj_scale * (torch.rand(hidden_size, inter) - 0.5) / hidden_size ** 0.5 + + def init_w_up(): + return synthetic_proj_scale * (torch.rand(hidden_size, inter) - 0.5) / hidden_size ** 0.5 + + def init_w_down(): + return synthetic_proj_scale * (torch.rand(inter, hidden_size) - 0.5) / inter ** 0.5 + + return [ + TensorSpec("hidden_states", [batch, hidden_size], torch.bfloat16, + init_value=init_hidden_states), + TensorSpec("input_rms_weight", [1, hidden_size], torch.float32, + init_value=init_rms_weight), + TensorSpec("wq", [hidden_size, hidden_size], torch.bfloat16, + init_value=init_wq), + TensorSpec("wk", [hidden_size, kv_hidden], torch.bfloat16, + init_value=init_wk), + TensorSpec("wv", [hidden_size, kv_hidden], torch.bfloat16, + init_value=init_wv), + TensorSpec("q_norm_weight", [1, head_dim], torch.float32, + init_value=init_q_norm_weight), + TensorSpec("k_norm_weight", [1, head_dim], torch.float32, + init_value=init_k_norm_weight), + TensorSpec("seq_lens", [batch], torch.int32, init_value=init_seq_lens), + TensorSpec("block_table", [batch * max_blocks_per_seq], torch.int32, + init_value=init_block_table), + TensorSpec("slot_mapping", [batch], torch.int32, + init_value=init_slot_mapping), + TensorSpec("rope_cos", [max_seq, head_dim], torch.float32, + init_value=init_rope_cos), + TensorSpec("rope_sin", [max_seq, head_dim], torch.float32, + init_value=init_rope_sin), + TensorSpec("k_cache", [cache_rows, head_dim], torch.bfloat16, + init_value=init_k_cache), + TensorSpec("v_cache", [cache_rows, head_dim], torch.bfloat16, + init_value=init_v_cache), + TensorSpec("wo", [hidden_size, hidden_size], torch.bfloat16, + init_value=init_wo), + TensorSpec("post_rms_weight", [1, hidden_size], torch.float32, + init_value=init_post_rms_weight), + TensorSpec("w_gate", [hidden_size, inter], torch.bfloat16, + init_value=init_w_gate), + TensorSpec("w_up", [hidden_size, inter], torch.bfloat16, + init_value=init_w_up), + TensorSpec("w_down", [inter, hidden_size], torch.bfloat16, + init_value=init_w_down), + TensorSpec("out", [batch, hidden], torch.bfloat16, is_output=True), + ] + + +def golden_qwen3_decode(tensors): + """PyTorch reference: scope1 (RMSNorm + projection), scope2 (attention), scope3 (output + MLP).""" + import math + + import torch + + hidden_states = tensors["hidden_states"] + input_rms_weight = tensors["input_rms_weight"] + wq = tensors["wq"] + wk = tensors["wk"] + wv = tensors["wv"] + q_norm_weight = tensors["q_norm_weight"] + k_norm_weight = tensors["k_norm_weight"] + seq_lens = tensors["seq_lens"] + block_table = tensors["block_table"] + slot_mapping = tensors["slot_mapping"] + rope_cos = tensors["rope_cos"] + rope_sin = tensors["rope_sin"] + k_cache = tensors["k_cache"].clone() + v_cache = tensors["v_cache"].clone() + wo = tensors["wo"] + post_rms_weight = tensors["post_rms_weight"] + w_gate = tensors["w_gate"] + w_up = tensors["w_up"] + w_down = tensors["w_down"] + + batch = hidden_states.shape[0] + hidden_size = hidden_states.shape[1] + kv_hidden = wk.shape[1] + head_dim = rope_cos.shape[1] + max_seq = rope_cos.shape[0] + num_kv_heads = kv_hidden // head_dim + num_heads = hidden_size // head_dim + q_per_kv = num_heads // num_kv_heads + q_groups = q_per_kv // Q_HEAD_BATCH + total_q_groups = num_kv_heads * q_groups + half = head_dim // 2 + scale = 1.0 / math.sqrt(head_dim) + eps = 1e-6 + max_ctx_blocks = (max_seq + BLOCK_SIZE - 1) // BLOCK_SIZE + + def tiled_matmul(lhs, rhs, k_chunk, n_chunk): + out = torch.zeros(lhs.shape[0], rhs.shape[1], dtype=torch.float32) + for n0 in range(0, rhs.shape[1], n_chunk): + acc = torch.zeros(lhs.shape[0], n_chunk, dtype=torch.float32) + for k0 in range(0, lhs.shape[1], k_chunk): + acc = acc + lhs[:, k0 : k0 + k_chunk].float() @ rhs[ + k0 : k0 + k_chunk, + n0 : n0 + n_chunk, + ].float() + out[:, n0 : n0 + n_chunk] = acc + return out + + def chunked_row_sq_sum(x, k_chunk): + acc = torch.zeros(x.shape[0], 1, dtype=torch.float32) + for k0 in range(0, x.shape[1], k_chunk): + x_chunk = x[:, k0 : k0 + k_chunk] + acc = acc + (x_chunk * x_chunk).sum(dim=-1, keepdim=True) + return acc + + q_proj = torch.zeros(batch, hidden_size, dtype=torch.float32) + k_proj = torch.zeros(batch, kv_hidden, dtype=torch.float32) + v_proj = torch.zeros(batch, kv_hidden, dtype=torch.float32) + + for b0 in range(0, batch, BATCH_TILE): + b_end = min(b0 + BATCH_TILE, batch) + x_tile = hidden_states[b0:b_end, :].float() + + sq_sum = torch.zeros(b_end - b0, 1, dtype=torch.float32) + for k0 in range(0, hidden_size, INPUT_PROJ_K_CHUNK): + x_chunk = x_tile[:, k0:k0 + INPUT_PROJ_K_CHUNK] + sq_sum = sq_sum + (x_chunk ** 2).sum(dim=-1, keepdim=True) + variance = sq_sum / hidden_size + EPS + rms = torch.sqrt(variance) + normed = (x_tile / rms * input_rms_weight.float()).bfloat16() + + q_proj[b0:b_end, :] = tiled_matmul(normed, wq, INPUT_PROJ_K_CHUNK, Q_OUT_CHUNK) + k_proj[b0:b_end, :] = tiled_matmul(normed, wk, KV_PROJ_K_CHUNK, KV_OUT_CHUNK) + v_proj[b0:b_end, :] = tiled_matmul(normed, wv, KV_PROJ_K_CHUNK, KV_OUT_CHUNK) + + attn_out = torch.zeros(batch, hidden_size, dtype=torch.bfloat16) + + for b in range(batch): + ctx_len = seq_lens[b].item() + pos = ctx_len - 1 + ctx_blocks = (ctx_len + BLOCK_SIZE - 1) // BLOCK_SIZE + + cos_row = rope_cos[pos : pos + 1, :] + sin_row = rope_sin[pos : pos + 1, :] + cos_lo, cos_hi = cos_row[:, :half], cos_row[:, half:] + sin_lo, sin_hi = sin_row[:, :half], sin_row[:, half:] + + k_heads = k_proj[b].view(num_kv_heads, head_dim) + k_variance = k_heads.pow(2).mean(dim=-1, keepdim=True) + k_heads = k_heads * torch.rsqrt(k_variance + eps) * k_norm_weight.float() + k_lo_h, k_hi_h = k_heads[:, :half], k_heads[:, half:] + k_rot = torch.cat( + [k_lo_h * cos_lo - k_hi_h * sin_lo, k_hi_h * cos_hi + k_lo_h * sin_hi], + dim=-1, + ) + slot = int(slot_mapping[b].item()) + slot_block = slot // BLOCK_SIZE + slot_offset = slot % BLOCK_SIZE + + for ki in range(num_kv_heads): + cache_row = (slot_block * num_kv_heads + ki) * BLOCK_SIZE + slot_offset + k_cache[cache_row, :] = k_rot[ki].to(torch.bfloat16) + v_cache[cache_row, :] = v_proj[b, ki * head_dim : (ki + 1) * head_dim].to(torch.bfloat16) + + q_heads = q_proj[b].view(num_heads, head_dim) + q_variance = q_heads.pow(2).mean(dim=-1, keepdim=True) + q_heads = q_heads * torch.rsqrt(q_variance + eps) * q_norm_weight.float() + q_lo_h, q_hi_h = q_heads[:, :half], q_heads[:, half:] + q_rot = torch.cat( + [q_lo_h * cos_lo - q_hi_h * sin_lo, q_hi_h * cos_hi + q_lo_h * sin_hi], + dim=-1, + ) + + attn_row = torch.zeros(1, hidden_size, dtype=torch.bfloat16) + for kvh in range(num_kv_heads): + for qg in range(q_groups): + gi = kvh * q_groups + qg + q_base = kvh * q_per_kv + qg * Q_HEAD_BATCH + q_grp_bf16 = q_rot[q_base : q_base + Q_HEAD_BATCH, :].to(torch.bfloat16) + + oi = torch.zeros(Q_HEAD_BATCH, head_dim, dtype=torch.float32) + li = torch.zeros(Q_HEAD_BATCH, 1, dtype=torch.float32) + mi = torch.zeros(Q_HEAD_BATCH, 1, dtype=torch.float32) + + for sb in range(ctx_blocks): + s0 = sb * BLOCK_SIZE + valid_len = min(BLOCK_SIZE, ctx_len - s0) + pbid = int(block_table[b * max_ctx_blocks + sb].item()) + cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE + k_tile = k_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :] + v_tile = v_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :] + + raw_scores = q_grp_bf16.float() @ k_tile.float().T + if valid_len < BLOCK_SIZE: + raw_scores[:, valid_len:] = torch.finfo(torch.float32).min + scores = raw_scores * scale + cur_mi = scores.max(dim=-1, keepdim=True).values + exp_scores = torch.exp(scores - cur_mi) + exp_scores_bf16 = exp_scores.to(torch.bfloat16) + cur_li = exp_scores_bf16.float().sum(dim=-1, keepdim=True) + oi_tmp = exp_scores_bf16.float() @ v_tile.float() + + if sb == 0: + oi = oi_tmp + li = cur_li + mi = cur_mi + else: + mi_new = torch.maximum(mi, cur_mi) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(cur_mi - mi_new) + li = alpha * li + beta * cur_li + oi = oi * alpha + oi_tmp * beta + mi = mi_new + + ctx = oi / li + ctx_flat_bf16 = ctx.reshape(1, -1).to(torch.bfloat16) + attn_row[ + :, + q_base * head_dim : (q_base + Q_HEAD_BATCH) * head_dim, + ] = ctx_flat_bf16 + + attn_out[b : b + 1, :] = attn_row + + o_proj = tiled_matmul(attn_out, wo, OUT_PROJ_K_CHUNK, OUT_PROJ_N_CHUNK) + resid1 = o_proj + hidden_states.float() + + variance = chunked_row_sq_sum(resid1, K_CHUNK) / hidden_size + inv_rms = torch.rsqrt(variance + eps) + normed_bf16 = (resid1 * inv_rms * post_rms_weight).bfloat16() + + gate = tiled_matmul(normed_bf16, w_gate, K_CHUNK, MLP_OUT_CHUNK) + up = tiled_matmul(normed_bf16, w_up, K_CHUNK, MLP_OUT_CHUNK) + mlp_bf16 = (gate * torch.sigmoid(gate) * up).bfloat16() + down = tiled_matmul(mlp_bf16, w_down, DOWN_MLP_CHUNK, DOWN_OUT_CHUNK) + + tensors["out"][:] = (down + resid1).bfloat16() + + +if __name__ == "__main__": + import argparse + import sys + from golden import RunConfig, run + + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--platform", type=str, default="a2a3", + choices=["a2a3", "a2a3sim", "a5", "a5sim"]) + parser.add_argument("-d", "--device", type=int, default=0) + parser.add_argument("-b", "--batch", type=int, default=BATCH, + help=("User-visible batch size. Host allocates every " + "batch-dependent tensor at exactly this size; " + "the kernel internally rounds up to BATCH_TILE " + "(%d), zero-pads input loads via valid_shape, " + "and trims the BF16 output via vec-to-vec " + "textract. A single compiled program serves " + "any batch <= host KV-cache capacity. Default: " + "%%(default)s" % BATCH_TILE)) + parser.add_argument("--enable-l2-swimlane", action="store_true", default=False) + parser.add_argument("--max-seq", action="store_true", default=False) + parser.add_argument("--enable-pmu", nargs="?", const=2, default=0, type=int, choices=[0, 1, 2, 4]) + parser.add_argument( + "--export-kernel-insight", + action="store_true", + default=False, + help=( + "After a successful run, export msprof op-simulator Insight traces " + "for all generated InCore kernels under the same build_output dir." + ), + ) + parser.add_argument( + "--kernel-insight-func", + action="append", + default=[], + help="Only export this generated kernel function; can be repeated.", + ) + args = parser.parse_args() + + result = run( + program=build_qwen3_decode_program(batch=args.batch), + specs=build_tensor_specs(batch=args.batch, use_max_seq=args.max_seq), + golden_fn=golden_qwen3_decode, + config=RunConfig( + rtol=3e-3, + atol=3e-3, + compile=dict(dump_passes=True), + runtime=dict( + platform=args.platform, + device_id=args.device, + enable_l2_swimlane=args.enable_l2_swimlane, + enable_pmu=args.enable_pmu, + ), + ), + ) + if not result.passed: + if result.error: + print(result.error) + raise SystemExit(1) + + if args.export_kernel_insight: + if result.work_dir is None: + print("kernel insight export failed: run result has no build_output directory", file=sys.stderr) + raise SystemExit(1) + from tools.export_all_kernel_insight import StepError, main as export_kernel_insight + + export_args = ["--build-dir", str(result.work_dir)] + for func in args.kernel_insight_func: + export_args.extend(["--func", func]) + try: + export_rc = export_kernel_insight(export_args) + except StepError as exc: + print(f"kernel insight export failed: {exc}", file=sys.stderr) + raise SystemExit(1) from exc + if export_rc != 0: + raise SystemExit(export_rc) diff --git a/models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py b/models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py new file mode 100644 index 0000000..1bf6fe8 --- /dev/null +++ b/models/qwen3/14b/qwen3_14b_qk_matmul_softmax_fused.py @@ -0,0 +1,1066 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Qwen3-14B single-layer decode forward. + +Scope 1: + 1. RMSNorm of input hidden states + 2. Q/K/V projection via matmul + +Per-head q_norm / k_norm + +Scope 2: + 1. K RoPE + paged cache write, V paged cache write, Q RoPE + pad + 2. QK matmul + 3. Softmax + 4. SV matmul + 5. Online-softmax accumulation + final normalisation + +Scope 3: + 1. Output projection: attn_out × wo + 2. Residual addition with hidden_states + 3. Post-attention RMSNorm + 4. MLP: gate/up projections, SiLU activation, down projection + 5. Final residual addition +""" + +# pyright: reportUndefinedVariable=false + +import pypto.language as pl + +# Dynamic dims for arbitrary user_batch support. Host allocates every +# batch-dependent tensor at the user-visible batch (no host pad / no +# host trim); the kernel internally rounds up to BATCH_TILE, zero-pads +# trailing rows of every input via valid_shape on the load slice, and +# trims the BF16 ND output via vec-to-vec textract before tstore. A +# single compiled program serves any user_batch <= host KV-cache +# capacity. +USER_BATCH_DYN = pl.dynamic("USER_BATCH_DYN") +KV_CACHE_ROWS_DYN = pl.dynamic("KV_CACHE_ROWS_DYN") +BLOCK_TABLE_FLAT_DYN = pl.dynamic("BLOCK_TABLE_FLAT_DYN") + +BATCH = 16 +MAX_SEQ = 4096 +NUM_HEADS = 40 +NUM_KV_HEADS = 8 +HEAD_DIM = 128 +HIDDEN = NUM_HEADS * HEAD_DIM # 5120 +INTERMEDIATE = 17408 +KV_HIDDEN = NUM_KV_HEADS * HEAD_DIM + +EPS = 1e-6 +HIDDEN_INV = 1.0 / HIDDEN + +# Scope 1 tiling constants. +INPUT_PROJ_K_CHUNK = 256 +KV_PROJ_K_CHUNK = 256 +Q_OUT_CHUNK = 256 +KV_OUT_CHUNK = 256 +BATCH_TILE = 16 + +# Scope 2 tiling constants. +# Qwen3-14B uses 40 Q heads and 8 KV heads, so q_per_kv = 5. +Q_HEAD_BATCH = 5 +Q_HEAD_PAD = 16 +SEQ_TILE = 256 +SB_BATCH = 128 +BLOCK_SIZE = SEQ_TILE + +# Scope 3 tiling constants. +K_CHUNK = 256 +OUT_PROJ_K_CHUNK = 256 +OUT_PROJ_N_CHUNK = 256 +MLP_OUT_CHUNK = 256 +DOWN_MLP_CHUNK = 256 +DOWN_OUT_CHUNK = 256 + + +def build_qwen3_decode_program( + batch: int = BATCH, + max_seq: int = MAX_SEQ, + hidden_size: int = HIDDEN, + intermediate_size: int = INTERMEDIATE, + num_heads: int = NUM_HEADS, + num_kv_heads: int = NUM_KV_HEADS, + head_dim: int = HEAD_DIM, +): + # The `batch` parameter is only used by build_tensor_specs to size + # host buffers; it is no longer baked into the program. Every + # batch-dependent kernel signature dim is a pl.dynamic() variable + # (USER_BATCH_DYN / BLOCK_TABLE_FLAT_DYN / KV_CACHE_ROWS_DYN), so a + # single compiled program serves any user_batch <= host capacity. + hidden = hidden_size + kv_hidden = num_kv_heads * head_dim + inter = intermediate_size + input_proj_k_blocks = hidden // INPUT_PROJ_K_CHUNK + kv_proj_k_blocks = hidden // KV_PROJ_K_CHUNK + out_proj_k_blocks = hidden // OUT_PROJ_K_CHUNK + hidden_blocks = hidden // K_CHUNK + down_out_blocks = hidden // DOWN_OUT_CHUNK + out_proj_n_blocks = hidden // OUT_PROJ_N_CHUNK + mlp_out_blocks = inter // MLP_OUT_CHUNK + down_mlp_blocks = inter // DOWN_MLP_CHUNK + max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) // BLOCK_SIZE + half_dim = head_dim // 2 + head_dim_inv = 1.0 / head_dim + q_per_kv = num_heads // num_kv_heads + q_groups = q_per_kv // Q_HEAD_BATCH + total_q_groups = num_kv_heads * q_groups + attn_scale = 1.0 / (head_dim ** 0.5) + max_ctx_blocks = max_blocks_per_seq + + @pl.program + class Qwen3Decode: + @pl.function(type=pl.FunctionType.Opaque) + def qwen3_decode( + self, + hidden_states: pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16], + input_rms_weight: pl.Tensor[[1, hidden], pl.FP32], + wq: pl.Tensor[[hidden, hidden], pl.BF16], + wk: pl.Tensor[[hidden, kv_hidden], pl.BF16], + wv: pl.Tensor[[hidden, kv_hidden], pl.BF16], + q_norm_weight: pl.Tensor[[1, head_dim], pl.FP32], + k_norm_weight: pl.Tensor[[1, head_dim], pl.FP32], + seq_lens: pl.Tensor[[USER_BATCH_DYN], pl.INT32], + block_table: pl.Tensor[[BLOCK_TABLE_FLAT_DYN], pl.INT32], + slot_mapping: pl.Tensor[[USER_BATCH_DYN], pl.INT32], + rope_cos: pl.Tensor[[max_seq, head_dim], pl.FP32], + rope_sin: pl.Tensor[[max_seq, head_dim], pl.FP32], + k_cache: pl.Tensor[[KV_CACHE_ROWS_DYN, head_dim], pl.BF16], + v_cache: pl.Tensor[[KV_CACHE_ROWS_DYN, head_dim], pl.BF16], + wo: pl.Tensor[[hidden, hidden], pl.BF16], + post_rms_weight: pl.Tensor[[1, hidden], pl.FP32], + w_gate: pl.Tensor[[hidden, inter], pl.BF16], + w_up: pl.Tensor[[hidden, inter], pl.BF16], + w_down: pl.Tensor[[inter, hidden], pl.BF16], + out: pl.Out[pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16]], + ) -> pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16]: + # Runtime user_batch (host-visible batch) and BATCH_TILE-aligned + # internal batch_padded. All scope-1/scope-3 batch loops iterate + # over batch_padded and zero-pad/trim using valid_shape on + # input/output slices. Scope-2 iterates over user_batch directly + # (its outer loop is sequential per request, no per-tile pad + # gymnastics needed). + user_batch = pl.tensor.dim(hidden_states, 0) + batch_padded = ((user_batch + BATCH_TILE - 1) // BATCH_TILE) * BATCH_TILE + + # Intermediate FP32 tensors between scope 1 and scope 2. + # Allocated at runtime batch_padded; pl.create_tensor zero-inits + # so trailing (batch_padded - user_batch) padded rows are 0, + # which is the invariant relied on by Q/K-norm and scope-3. + q_proj = pl.create_tensor([batch_padded, hidden], dtype=pl.FP32) + k_proj = pl.create_tensor([batch_padded, kv_hidden], dtype=pl.FP32) + v_proj = pl.create_tensor([batch_padded, kv_hidden], dtype=pl.FP32) + q_proj_norm = pl.create_tensor([batch_padded, hidden], dtype=pl.FP32) + k_proj_norm = pl.create_tensor([batch_padded, kv_hidden], dtype=pl.FP32) + + # Scope 1: input RMSNorm + Q/K/V projection. + # Loop iterates over batch_padded (BATCH_TILE-aligned) so every + # matmul tile has a static known M dim of BATCH_TILE (a2a3 + # requirement). Trailing rows in the tail iter are zero-padded + # at load time via valid_shape on the hidden_states slice. + # RMSNorm of zero rows yields 0 (x=0 -> normed = 0 * rsqrt(EPS) + # * gamma = 0), so normed_tile padded rows stay 0. Subsequent + # q/k/v matmul reads from this in-kernel staging only, so + # padded q_proj/k_proj/v_proj rows are 0 acc, harmless. + for b0 in pl.parallel(0, batch_padded, BATCH_TILE): + cur_valid = pl.min(BATCH_TILE, user_batch - b0) + normed_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"): + partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) + for kb in pl.pipeline(input_proj_k_blocks, stage=4): + k0 = kb * INPUT_PROJ_K_CHUNK + x_chunk = pl.cast( + pl.slice( + hidden_states, + [BATCH_TILE, INPUT_PROJ_K_CHUNK], + [b0, k0], + valid_shape=[cur_valid, INPUT_PROJ_K_CHUNK], + ), + target_type=pl.FP32, + ) + partial_sq = pl.add( + partial_sq, + pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, BATCH_TILE]), + ) + variance = pl.reshape( + pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS), + [BATCH_TILE, 1], + ) + inv_rms = pl.recip(pl.sqrt(variance)) + + for kb in pl.pipeline(input_proj_k_blocks, stage=4): + k0 = kb * INPUT_PROJ_K_CHUNK + x_chunk = pl.cast( + pl.slice( + hidden_states, + [BATCH_TILE, INPUT_PROJ_K_CHUNK], + [b0, k0], + valid_shape=[cur_valid, INPUT_PROJ_K_CHUNK], + ), + target_type=pl.FP32, + ) + gamma = input_rms_weight[:, k0 : k0 + INPUT_PROJ_K_CHUNK] + normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma) + normed_tile = pl.assemble( + normed_tile, + pl.cast(normed, target_type=pl.BF16), + [0, k0], + ) + + for q0 in pl.parallel(0, hidden, Q_OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_proj"): + q_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, input_proj_k_blocks, stage=2): + k0 = kb * INPUT_PROJ_K_CHUNK + tile_a_i = normed_tile[:, k0 : k0 + INPUT_PROJ_K_CHUNK] + tile_b_i = wq[k0 : k0 + INPUT_PROJ_K_CHUNK, q0 : q0 + Q_OUT_CHUNK] + if k0 == 0: + q_acc = pl.matmul(tile_a_i, tile_b_i, out_dtype=pl.FP32) + else: + q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i) + q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) + + for kv0 in pl.parallel(0, kv_hidden, KV_OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="k_proj"): + k_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, kv_proj_k_blocks, stage=2): + k0 = kb * KV_PROJ_K_CHUNK + k_tile_a_i = normed_tile[:, k0 : k0 + KV_PROJ_K_CHUNK] + k_tile_w_i = wk[k0 : k0 + KV_PROJ_K_CHUNK, kv0 : kv0 + KV_OUT_CHUNK] + if k0 == 0: + k_acc = pl.matmul(k_tile_a_i, k_tile_w_i, out_dtype=pl.FP32) + else: + k_acc = pl.matmul_acc(k_acc, k_tile_a_i, k_tile_w_i) + k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="v_proj"): + v_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, kv_proj_k_blocks, stage=2): + k0 = kb * KV_PROJ_K_CHUNK + v_tile_a_i = normed_tile[:, k0 : k0 + KV_PROJ_K_CHUNK] + v_tile_w_i = wv[k0 : k0 + KV_PROJ_K_CHUNK, kv0 : kv0 + KV_OUT_CHUNK] + if k0 == 0: + v_acc = pl.matmul(v_tile_a_i, v_tile_w_i, out_dtype=pl.FP32) + else: + v_acc = pl.matmul_acc(v_acc, v_tile_a_i, v_tile_w_i) + v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) + + # HF-style per-head q_norm / k_norm before RoPE, batched to avoid + # generating unsupported 1x1 vec-tile scalar ops on A2/A3. + # Loops over batch_padded; q_proj/k_proj are kernel-internal + # staging with zero-init padded rows (RMSNorm of 0 stays 0), + # so no valid_shape is needed here. + for b0 in pl.parallel(0, batch_padded, BATCH_TILE): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qk_norm"): + for h in pl.range(num_kv_heads): + q0 = h * q_per_kv * head_dim + q_chunk = pl.reshape( + q_proj[b0 : b0 + BATCH_TILE, q0 : q0 + Q_HEAD_BATCH * head_dim], + [BATCH_TILE * Q_HEAD_BATCH, head_dim], + ) + q_sq_sum = pl.row_sum(pl.mul(q_chunk, q_chunk)) + q_inv_rms = pl.rsqrt(pl.add(pl.mul(q_sq_sum, head_dim_inv), EPS)) + q_chunk_norm = pl.col_expand_mul( + pl.row_expand_mul(q_chunk, q_inv_rms), + q_norm_weight, + ) + q_chunk_norm_flat = pl.reshape(q_chunk_norm, [BATCH_TILE, Q_HEAD_BATCH * head_dim]) + q_proj_norm = pl.assemble(q_proj_norm, q_chunk_norm_flat, [b0, q0]) + + k0 = h * head_dim + k_chunk = k_proj[b0 : b0 + BATCH_TILE, k0 : k0 + head_dim] + k_sq_sum = pl.row_sum(pl.mul(k_chunk, k_chunk)) + k_inv_rms = pl.rsqrt(pl.add(pl.mul(k_sq_sum, head_dim_inv), EPS)) + k_chunk_norm = pl.col_expand_mul( + pl.row_expand_mul(k_chunk, k_inv_rms), + k_norm_weight, + ) + k_proj_norm = pl.assemble(k_proj_norm, k_chunk_norm, [b0, k0]) + + # Scope 2: RoPE + KV cache update + grouped decode attention. + # attn_out is allocated at batch_padded so scope-3 (which loops + # over batch_padded) can read full BATCH_TILE rows in every + # iteration; padded rows are zero-init and stay 0 (scope-2 only + # writes valid rows). all_q_padded is sized similarly; each + # Q_HEAD_PAD block is padded inside rope_kv_cache. + attn_out = pl.create_tensor([batch_padded, hidden], dtype=pl.BF16) + all_q_padded = pl.create_tensor( + [batch_padded * total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16, + ) + + # Outer loop iterates user_batch sequentially (one row per iter). + # seq_lens / slot_mapping are sized [USER_BATCH_DYN] so reading + # b in [0, user_batch) is in-bounds. Padded b rows do not need + # attention; their attn_out rows stay zero (zero-init). + for b in pl.parallel(user_batch): + ctx_len = pl.tensor.read(seq_lens, [b]) + pos = ctx_len - 1 + ctx_blocks = (ctx_len + BLOCK_SIZE - 1) // BLOCK_SIZE + block_table_base = b * max_blocks_per_seq + slot = pl.tensor.read(slot_mapping, [b]) + slot_block = slot // BLOCK_SIZE + slot_offset = slot - slot_block * BLOCK_SIZE + cos_row = rope_cos[pos : pos + 1, :] + sin_row = rope_sin[pos : pos + 1, :] + cos_lo = cos_row[:, 0:half_dim] + cos_hi = cos_row[:, half_dim:head_dim] + sin_lo = sin_row[:, 0:half_dim] + sin_hi = sin_row[:, half_dim:head_dim] + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_kv_cache"): + for ki in pl.range(num_kv_heads): + kv_col = ki * head_dim + cache_row = (slot_block * num_kv_heads + ki) * BLOCK_SIZE + slot_offset + k_lo = k_proj_norm[b : b + 1, kv_col : kv_col + half_dim] + k_hi = k_proj_norm[b : b + 1, kv_col + half_dim : kv_col + head_dim] + rot_lo = pl.sub( + pl.col_expand_mul(k_lo, cos_lo), + pl.col_expand_mul(k_hi, sin_lo), + ) + rot_hi = pl.add( + pl.col_expand_mul(k_hi, cos_hi), + pl.col_expand_mul(k_lo, sin_hi), + ) + k_cache = pl.assemble( + k_cache, + pl.cast(rot_lo, target_type=pl.BF16), + [cache_row, 0], + ) + k_cache = pl.assemble( + k_cache, + pl.cast(rot_hi, target_type=pl.BF16), + [cache_row, half_dim], + ) + v_cache = pl.assemble( + v_cache, + pl.cast( + v_proj[b : b + 1, kv_col : kv_col + head_dim], + target_type=pl.BF16, + ), + [cache_row, 0], + ) + q_base = ki * q_per_kv + q_block = pl.reshape( + q_proj_norm[b : b + 1, q_base * head_dim : (q_base + Q_HEAD_BATCH) * head_dim], + [Q_HEAD_BATCH, head_dim], + ) + q_lo = q_block[:, 0:half_dim] + q_hi = q_block[:, half_dim:head_dim] + rot_lo_bf16 = pl.cast( + pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), + target_type=pl.BF16, + ) + rot_hi_bf16 = pl.cast( + pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), + target_type=pl.BF16, + ) + all_q_padded = pl.assemble( + all_q_padded, + rot_lo_bf16, + [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD, 0], + ) + all_q_padded = pl.assemble( + all_q_padded, + rot_hi_bf16, + [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD, half_dim], + ) + all_q_padded = pl.assemble( + all_q_padded, + pl.cast( + pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], dtype=pl.FP32, value=0.0), + target_type=pl.BF16, + ), + [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + Q_HEAD_BATCH, 0], + ) + + attn_row = pl.create_tensor([1, hidden], dtype=pl.BF16) + for gi in pl.parallel(0, total_q_groups, 2): + gi0 = gi + gi1 = gi + 1 + + kvh0 = gi0 // q_groups + qg0 = gi0 - kvh0 * q_groups + q_base0 = kvh0 * q_per_kv + qg0 * Q_HEAD_BATCH + q_padded_row0 = b * total_q_groups * Q_HEAD_PAD + gi0 * Q_HEAD_PAD + q_padded0 = all_q_padded[q_padded_row0 : q_padded_row0 + Q_HEAD_PAD, :] + + kvh1 = gi1 // q_groups + qg1 = gi1 - kvh1 * q_groups + q_base1 = kvh1 * q_per_kv + qg1 * Q_HEAD_BATCH + q_padded_row1 = b * total_q_groups * Q_HEAD_PAD + gi1 * Q_HEAD_PAD + q_padded1 = all_q_padded[q_padded_row1 : q_padded_row1 + Q_HEAD_PAD, :] + + all_raw_scores0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.FP32) + all_raw_scores1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.FP32) + all_exp_padded0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.BF16) + all_exp_padded1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.BF16) + all_cur_mi0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32) + all_cur_mi1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32) + all_cur_li0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32) + all_cur_li1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.UP_DOWN)], name_hint="qk_softmax"): + for sb in pl.range(ctx_blocks): + block_table_idx = block_table_base + sb + pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX) + + cache_row0 = (pbid * num_kv_heads + kvh0) * BLOCK_SIZE + k_tile0 = k_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :] + raw_scores0 = pl.matmul(q_padded0, k_tile0, b_trans=True, out_dtype=pl.FP32) + all_raw_scores0 = pl.assemble(all_raw_scores0, raw_scores0, [sb * Q_HEAD_PAD, 0]) + + cache_row1 = (pbid * num_kv_heads + kvh1) * BLOCK_SIZE + k_tile1 = k_cache[cache_row1 : cache_row1 + BLOCK_SIZE, :] + raw_scores1 = pl.matmul(q_padded1, k_tile1, b_trans=True, out_dtype=pl.FP32) + all_raw_scores1 = pl.assemble(all_raw_scores1, raw_scores1, [sb * Q_HEAD_PAD, 0]) + + s0 = sb * BLOCK_SIZE + valid_len = pl.min(BLOCK_SIZE, ctx_len - s0) + + scores_valid0 = pl.slice( + all_raw_scores0, + [Q_HEAD_PAD, BLOCK_SIZE], + [sb * Q_HEAD_PAD, 0], + valid_shape=[Q_HEAD_PAD, valid_len], + ) + scores_padded0 = pl.fillpad(scores_valid0, pad_value=pl.PadValue.min) + scores0 = pl.mul(scores_padded0, attn_scale) + cur_mi0 = pl.row_max(scores0) + exp_scores0 = pl.exp(pl.row_expand_sub(scores0, cur_mi0)) + exp_scores_bf16_0 = pl.cast(exp_scores0, target_type=pl.BF16) + exp_scores_fp32_0 = pl.cast(exp_scores_bf16_0, target_type=pl.FP32) + cur_li0 = pl.row_sum(exp_scores_fp32_0) + all_exp_padded0 = pl.assemble(all_exp_padded0, exp_scores_bf16_0, [sb * Q_HEAD_PAD, 0]) + all_cur_mi0 = pl.assemble(all_cur_mi0, cur_mi0, [sb * Q_HEAD_PAD, 0]) + all_cur_li0 = pl.assemble(all_cur_li0, cur_li0, [sb * Q_HEAD_PAD, 0]) + + scores_valid1 = pl.slice( + all_raw_scores1, + [Q_HEAD_PAD, BLOCK_SIZE], + [sb * Q_HEAD_PAD, 0], + valid_shape=[Q_HEAD_PAD, valid_len], + ) + scores_padded1 = pl.fillpad(scores_valid1, pad_value=pl.PadValue.min) + scores1 = pl.mul(scores_padded1, attn_scale) + cur_mi1 = pl.row_max(scores1) + exp_scores1 = pl.exp(pl.row_expand_sub(scores1, cur_mi1)) + exp_scores_bf16_1 = pl.cast(exp_scores1, target_type=pl.BF16) + exp_scores_fp32_1 = pl.cast(exp_scores_bf16_1, target_type=pl.FP32) + cur_li1 = pl.row_sum(exp_scores_fp32_1) + all_exp_padded1 = pl.assemble(all_exp_padded1, exp_scores_bf16_1, [sb * Q_HEAD_PAD, 0]) + all_cur_mi1 = pl.assemble(all_cur_mi1, cur_mi1, [sb * Q_HEAD_PAD, 0]) + all_cur_li1 = pl.assemble(all_cur_li1, cur_li1, [sb * Q_HEAD_PAD, 0]) + + all_oi_tmp0 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32) + all_oi_tmp1 = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, head_dim], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="sv_matmul"): + for sb in pl.range(ctx_blocks): + block_table_idx = block_table_base + sb + pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX) + + cache_row0 = (pbid * num_kv_heads + kvh0) * BLOCK_SIZE + exp_tile0 = all_exp_padded0[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + v_tile0 = v_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :] + oi_tmp0 = pl.matmul(exp_tile0, v_tile0, out_dtype=pl.FP32) + all_oi_tmp0 = pl.assemble(all_oi_tmp0, oi_tmp0, [sb * Q_HEAD_PAD, 0]) + + cache_row1 = (pbid * num_kv_heads + kvh1) * BLOCK_SIZE + exp_tile1 = all_exp_padded1[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + v_tile1 = v_cache[cache_row1 : cache_row1 + BLOCK_SIZE, :] + oi_tmp1 = pl.matmul(exp_tile1, v_tile1, out_dtype=pl.FP32) + all_oi_tmp1 = pl.assemble(all_oi_tmp1, oi_tmp1, [sb * Q_HEAD_PAD, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="online_softmax"): + oi0 = all_oi_tmp0[0:Q_HEAD_PAD, :] + mi0 = all_cur_mi0[0:Q_HEAD_PAD, :] + li0 = all_cur_li0[0:Q_HEAD_PAD, :] + oi1 = all_oi_tmp1[0:Q_HEAD_PAD, :] + mi1 = all_cur_mi1[0:Q_HEAD_PAD, :] + li1 = all_cur_li1[0:Q_HEAD_PAD, :] + for sb in pl.range(1, ctx_blocks): + oi_tmp_valid0 = all_oi_tmp0[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + cur_mi0 = all_cur_mi0[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + cur_li0 = all_cur_li0[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + mi_new0 = pl.maximum(mi0, cur_mi0) + alpha0 = pl.exp(pl.sub(mi0, mi_new0)) + beta0 = pl.exp(pl.sub(cur_mi0, mi_new0)) + li0 = pl.add(pl.mul(alpha0, li0), pl.mul(beta0, cur_li0)) + oi0 = pl.add(pl.row_expand_mul(oi0, alpha0), pl.row_expand_mul(oi_tmp_valid0, beta0)) + mi0 = mi_new0 + + oi_tmp_valid1 = all_oi_tmp1[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + cur_mi1 = all_cur_mi1[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + cur_li1 = all_cur_li1[sb * Q_HEAD_PAD : (sb + 1) * Q_HEAD_PAD, :] + mi_new1 = pl.maximum(mi1, cur_mi1) + alpha1 = pl.exp(pl.sub(mi1, mi_new1)) + beta1 = pl.exp(pl.sub(cur_mi1, mi_new1)) + li1 = pl.add(pl.mul(alpha1, li1), pl.mul(beta1, cur_li1)) + oi1 = pl.add(pl.row_expand_mul(oi1, alpha1), pl.row_expand_mul(oi_tmp_valid1, beta1)) + mi1 = mi_new1 + + ctx0 = pl.row_expand_div(oi0, li0) + ctx_valid0 = ctx0[0:Q_HEAD_BATCH, :] + ctx_flat_bf16_0 = pl.cast(pl.reshape(ctx_valid0, [1, Q_HEAD_BATCH * head_dim]), target_type=pl.BF16) + attn_row = pl.assemble(attn_row, ctx_flat_bf16_0, [0, q_base0 * head_dim]) + + ctx1 = pl.row_expand_div(oi1, li1) + ctx_valid1 = ctx1[0:Q_HEAD_BATCH, :] + ctx_flat_bf16_1 = pl.cast(pl.reshape(ctx_valid1, [1, Q_HEAD_BATCH * head_dim]), target_type=pl.BF16) + attn_row = pl.assemble(attn_row, ctx_flat_bf16_1, [0, q_base1 * head_dim]) + + attn_out = pl.assemble(attn_out, attn_row, [b, 0]) + + # Scope 3: output projection + residual + post RMSNorm + MLP + residual. + # Loops over batch_padded so every iteration processes a full + # [BATCH_TILE, *] tile (a2a3 matmul M-tile constraint). + # `cur_valid` clamps the user-visible row count for input load + # (hidden_states valid_shape) and final out store (vec-to-vec + # textract trim). When user_batch is BATCH_TILE-aligned, + # cur_valid == BATCH_TILE every iter and trim is a no-op. + # + # Final down-proj + residual + cast uses the two-incore pattern + # validated in dynamic_batch_pad_repro: + # cube incore : matmul_acc -> FP32 -> assemble to GM scratch + # vec incore : tload FP32 chunk -> add FP32 resid -> cast BF16 + # (preserves ND layout) -> slice(valid_shape) + # -> assemble to out + # ND -> ND vec-to-vec textract is supported on ptoas 0.31. + for b0 in pl.parallel(0, batch_padded, BATCH_TILE): + cur_valid = pl.min(BATCH_TILE, user_batch - b0) + resid1_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.FP32) + + with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.UP_DOWN)], name_hint="out_proj_residual"): + for ob in pl.parallel(0, out_proj_n_blocks, chunk=1): + o0 = ob * OUT_PROJ_N_CHUNK + + hidden_chunk = pl.slice( + hidden_states, + [BATCH_TILE, OUT_PROJ_N_CHUNK], + [b0, o0], + valid_shape=[cur_valid, OUT_PROJ_N_CHUNK], + ) + + o_acc = pl.create_tensor([BATCH_TILE, OUT_PROJ_N_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, out_proj_k_blocks, stage=2): + k0 = kb * OUT_PROJ_K_CHUNK + a_chunk = attn_out[b0 : b0 + BATCH_TILE, k0 : k0 + OUT_PROJ_K_CHUNK] + w_chunk = wo[k0 : k0 + OUT_PROJ_K_CHUNK, o0 : o0 + OUT_PROJ_N_CHUNK] + if k0 == 0: + o_acc = pl.matmul(a_chunk, w_chunk, out_dtype=pl.FP32) + else: + o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk) + + resid = pl.cast(hidden_chunk, target_type=pl.FP32) + resid_sum = pl.add(o_acc, resid) + resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0]) + + post_norm_tile = pl.create_tensor([BATCH_TILE, hidden], dtype=pl.BF16) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="post_rmsnorm"): + sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) + for kb in pl.pipeline(hidden_blocks, stage=2): + k0 = kb * K_CHUNK + resid_chunk = resid1_tile[:, k0 : k0 + K_CHUNK] + sq_sum = pl.add( + sq_sum, + pl.reshape(pl.row_sum(pl.mul(resid_chunk, resid_chunk)), [1, BATCH_TILE]), + ) + inv_rms_s3 = pl.recip(pl.sqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS))) + + for kb in pl.pipeline(hidden_blocks, stage=2): + k0 = kb * K_CHUNK + resid_chunk = resid1_tile[:, k0 : k0 + K_CHUNK] + post_gamma = post_rms_weight[:, k0 : k0 + K_CHUNK] + post_normed = pl.col_expand_mul( + pl.row_expand_mul(resid_chunk, pl.reshape(inv_rms_s3, [BATCH_TILE, 1])), + post_gamma, + ) + normed_bf16 = pl.cast(post_normed, target_type=pl.BF16) + post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [0, k0]) + + mlp_tile = pl.create_tensor([BATCH_TILE, inter], dtype=pl.BF16) + for ob in pl.range(mlp_out_blocks): + o0 = ob * MLP_OUT_CHUNK + with pl.at(level=pl.Level.CORE_GROUP, name_hint="gate_proj"): + gate_acc = pl.create_tensor([BATCH_TILE, MLP_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, hidden_blocks, stage=2): + k0 = kb * K_CHUNK + post_chunk = post_norm_tile[:, k0 : k0 + K_CHUNK] + wg = w_gate[k0 : k0 + K_CHUNK, o0 : o0 + MLP_OUT_CHUNK] + if k0 == 0: + gate_acc = pl.matmul(post_chunk, wg, out_dtype=pl.FP32) + else: + gate_acc = pl.matmul_acc(gate_acc, post_chunk, wg) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="up_proj"): + up_acc = pl.create_tensor([BATCH_TILE, MLP_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.pipeline(0, hidden_blocks, stage=2): + k0 = kb * K_CHUNK + post_chunk = post_norm_tile[:, k0 : k0 + K_CHUNK] + wu = w_up[k0 : k0 + K_CHUNK, o0 : o0 + MLP_OUT_CHUNK] + if k0 == 0: + up_acc = pl.matmul(post_chunk, wu, out_dtype=pl.FP32) + else: + up_acc = pl.matmul_acc(up_acc, post_chunk, wu) + + with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="silu"): + sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_acc)), 1.0)) + mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc) + mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) + mlp_tile = pl.assemble(mlp_tile, mlp_chunk_bf16, [0, o0]) + + if cur_valid == BATCH_TILE: + with pl.at( + level=pl.Level.CORE_GROUP, + optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.LEFT_RIGHT)], + name_hint="down_proj_residual", + ): + for dob in pl.parallel(down_out_blocks, chunk=1): + d0 = dob * DOWN_OUT_CHUNK + down_mlp_chunk_bf16 = mlp_tile[:, 0:DOWN_MLP_CHUNK] + w_down_chunk = w_down[0:DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK] + down_acc = pl.matmul(down_mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32) + for ob in pl.pipeline(1, down_mlp_blocks, stage=2): + o0 = ob * DOWN_MLP_CHUNK + down_mlp_chunk_bf16 = mlp_tile[:, o0 : o0 + DOWN_MLP_CHUNK] + w_down_chunk = w_down[o0 : o0 + DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK] + down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk) + resid_chunk_fp32 = resid1_tile[:, d0 : d0 + DOWN_OUT_CHUNK] + out_chunk = pl.add(down_acc, resid_chunk_fp32) + out = pl.assemble(out, pl.cast(out_chunk, target_type=pl.BF16), [b0, d0]) + else: + for dob in pl.parallel(0, down_out_blocks, 1): + d0 = dob * DOWN_OUT_CHUNK + fp32_chunk_gm = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj"): + down_acc = pl.create_tensor([BATCH_TILE, DOWN_OUT_CHUNK], dtype=pl.FP32) + for ob in pl.pipeline(0, down_mlp_blocks, stage=2): + o0 = ob * DOWN_MLP_CHUNK + down_mlp_chunk_bf16 = mlp_tile[:, o0 : o0 + DOWN_MLP_CHUNK] + w_down_chunk = w_down[o0 : o0 + DOWN_MLP_CHUNK, d0 : d0 + DOWN_OUT_CHUNK] + if o0 == 0: + down_acc = pl.matmul(down_mlp_chunk_bf16, w_down_chunk, out_dtype=pl.FP32) + else: + down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk) + fp32_chunk_gm = pl.assemble(fp32_chunk_gm, down_acc, [0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj_residual_tail"): + down_chunk_fp32 = fp32_chunk_gm[:, 0:DOWN_OUT_CHUNK] + resid_chunk_fp32 = resid1_tile[:, d0 : d0 + DOWN_OUT_CHUNK] + out_chunk = pl.add(down_chunk_fp32, resid_chunk_fp32) + out_chunk_cast = pl.cast(out_chunk, target_type=pl.BF16) + out_chunk_trimmed = pl.slice( + out_chunk_cast, + [BATCH_TILE, DOWN_OUT_CHUNK], + [0, 0], + valid_shape=[cur_valid, DOWN_OUT_CHUNK], + ) + out = pl.assemble(out, out_chunk_trimmed, [b0, d0]) + + return out + + return Qwen3Decode + + +def build_tensor_specs( + batch: int = BATCH, + max_seq: int = MAX_SEQ, + hidden_size: int = HIDDEN, + intermediate_size: int = INTERMEDIATE, + num_heads: int = NUM_HEADS, + num_kv_heads: int = NUM_KV_HEADS, + head_dim: int = HEAD_DIM, + use_max_seq: bool = False, +): + import torch + from golden import TensorSpec + + # Host allocates every batch-dependent tensor at the user-visible + # batch (no host pad / no host trim). The kernel internally rounds + # up to BATCH_TILE, zero-pads via valid_shape on input loads, and + # trims via vec-to-vec textract on the BF16 output. A single + # compiled program serves any batch <= host capacity (USER_BATCH_DYN + # / KV_CACHE_ROWS_DYN / BLOCK_TABLE_FLAT_DYN are pl.dynamic dims). + hidden = num_heads * head_dim + kv_hidden = num_kv_heads * head_dim + inter = intermediate_size + max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) // BLOCK_SIZE + num_blocks = batch * max_blocks_per_seq + cache_rows = num_blocks * num_kv_heads * BLOCK_SIZE + synthetic_proj_scale = 0.5 + + if use_max_seq: + seq_lens_seed = torch.full((batch,), max_seq, dtype=torch.int32) + else: + seq_lens_seed = torch.randint(1, max_seq + 1, (batch,), dtype=torch.int32) + + def init_hidden_states(): + return torch.rand(batch, hidden_size) - 0.5 + + def init_rms_weight(): + return torch.rand(1, hidden_size) - 0.5 + + def init_wq(): + return torch.rand(hidden_size, hidden_size) / hidden_size ** 0.5 + + def init_wk(): + return torch.rand(hidden_size, kv_hidden) / hidden_size ** 0.5 + + def init_wv(): + return synthetic_proj_scale * torch.rand(hidden_size, kv_hidden) / hidden_size ** 0.5 + + def init_q_norm_weight(): + return torch.ones(1, head_dim) + + def init_k_norm_weight(): + return torch.ones(1, head_dim) + + def init_seq_lens(): + return seq_lens_seed.clone() + + def init_block_table(): + return torch.arange(num_blocks, dtype=torch.int32) + + def init_slot_mapping(): + slots = torch.empty(batch, dtype=torch.int32) + for b in range(batch): + pos = int(seq_lens_seed[b].item()) - 1 + logical_block = pos // BLOCK_SIZE + page_offset = pos % BLOCK_SIZE + phys_block = b * max_blocks_per_seq + logical_block + slots[b] = phys_block * BLOCK_SIZE + page_offset + return slots + + def init_rope_cos(): + return torch.rand(max_seq, head_dim) - 0.5 + + def init_rope_sin(): + return torch.rand(max_seq, head_dim) - 0.5 + + def init_k_cache(): + return torch.rand(cache_rows, head_dim) - 0.5 + + def init_v_cache(): + return synthetic_proj_scale * (torch.rand(cache_rows, head_dim) - 0.5) + + def init_wo(): + return synthetic_proj_scale * (torch.rand(hidden_size, hidden_size) - 0.5) / hidden_size ** 0.5 + + def init_post_rms_weight(): + return torch.ones(1, hidden_size) + + def init_w_gate(): + return synthetic_proj_scale * (torch.rand(hidden_size, inter) - 0.5) / hidden_size ** 0.5 + + def init_w_up(): + return synthetic_proj_scale * (torch.rand(hidden_size, inter) - 0.5) / hidden_size ** 0.5 + + def init_w_down(): + return synthetic_proj_scale * (torch.rand(inter, hidden_size) - 0.5) / inter ** 0.5 + + return [ + TensorSpec("hidden_states", [batch, hidden_size], torch.bfloat16, + init_value=init_hidden_states), + TensorSpec("input_rms_weight", [1, hidden_size], torch.float32, + init_value=init_rms_weight), + TensorSpec("wq", [hidden_size, hidden_size], torch.bfloat16, + init_value=init_wq), + TensorSpec("wk", [hidden_size, kv_hidden], torch.bfloat16, + init_value=init_wk), + TensorSpec("wv", [hidden_size, kv_hidden], torch.bfloat16, + init_value=init_wv), + TensorSpec("q_norm_weight", [1, head_dim], torch.float32, + init_value=init_q_norm_weight), + TensorSpec("k_norm_weight", [1, head_dim], torch.float32, + init_value=init_k_norm_weight), + TensorSpec("seq_lens", [batch], torch.int32, init_value=init_seq_lens), + TensorSpec("block_table", [batch * max_blocks_per_seq], torch.int32, + init_value=init_block_table), + TensorSpec("slot_mapping", [batch], torch.int32, + init_value=init_slot_mapping), + TensorSpec("rope_cos", [max_seq, head_dim], torch.float32, + init_value=init_rope_cos), + TensorSpec("rope_sin", [max_seq, head_dim], torch.float32, + init_value=init_rope_sin), + TensorSpec("k_cache", [cache_rows, head_dim], torch.bfloat16, + init_value=init_k_cache), + TensorSpec("v_cache", [cache_rows, head_dim], torch.bfloat16, + init_value=init_v_cache), + TensorSpec("wo", [hidden_size, hidden_size], torch.bfloat16, + init_value=init_wo), + TensorSpec("post_rms_weight", [1, hidden_size], torch.float32, + init_value=init_post_rms_weight), + TensorSpec("w_gate", [hidden_size, inter], torch.bfloat16, + init_value=init_w_gate), + TensorSpec("w_up", [hidden_size, inter], torch.bfloat16, + init_value=init_w_up), + TensorSpec("w_down", [inter, hidden_size], torch.bfloat16, + init_value=init_w_down), + TensorSpec("out", [batch, hidden], torch.bfloat16, is_output=True), + ] + + +def golden_qwen3_decode(tensors): + """PyTorch reference: scope1 (RMSNorm + projection), scope2 (attention), scope3 (output + MLP).""" + import math + + import torch + + hidden_states = tensors["hidden_states"] + input_rms_weight = tensors["input_rms_weight"] + wq = tensors["wq"] + wk = tensors["wk"] + wv = tensors["wv"] + q_norm_weight = tensors["q_norm_weight"] + k_norm_weight = tensors["k_norm_weight"] + seq_lens = tensors["seq_lens"] + block_table = tensors["block_table"] + slot_mapping = tensors["slot_mapping"] + rope_cos = tensors["rope_cos"] + rope_sin = tensors["rope_sin"] + k_cache = tensors["k_cache"].clone() + v_cache = tensors["v_cache"].clone() + wo = tensors["wo"] + post_rms_weight = tensors["post_rms_weight"] + w_gate = tensors["w_gate"] + w_up = tensors["w_up"] + w_down = tensors["w_down"] + + batch = hidden_states.shape[0] + hidden_size = hidden_states.shape[1] + kv_hidden = wk.shape[1] + head_dim = rope_cos.shape[1] + max_seq = rope_cos.shape[0] + num_kv_heads = kv_hidden // head_dim + num_heads = hidden_size // head_dim + q_per_kv = num_heads // num_kv_heads + q_groups = q_per_kv // Q_HEAD_BATCH + total_q_groups = num_kv_heads * q_groups + half = head_dim // 2 + scale = 1.0 / math.sqrt(head_dim) + eps = 1e-6 + max_ctx_blocks = (max_seq + BLOCK_SIZE - 1) // BLOCK_SIZE + + def tiled_matmul(lhs, rhs, k_chunk, n_chunk): + out = torch.zeros(lhs.shape[0], rhs.shape[1], dtype=torch.float32) + for n0 in range(0, rhs.shape[1], n_chunk): + acc = torch.zeros(lhs.shape[0], n_chunk, dtype=torch.float32) + for k0 in range(0, lhs.shape[1], k_chunk): + acc = acc + lhs[:, k0 : k0 + k_chunk].float() @ rhs[ + k0 : k0 + k_chunk, + n0 : n0 + n_chunk, + ].float() + out[:, n0 : n0 + n_chunk] = acc + return out + + def chunked_row_sq_sum(x, k_chunk): + acc = torch.zeros(x.shape[0], 1, dtype=torch.float32) + for k0 in range(0, x.shape[1], k_chunk): + x_chunk = x[:, k0 : k0 + k_chunk] + acc = acc + (x_chunk * x_chunk).sum(dim=-1, keepdim=True) + return acc + + q_proj = torch.zeros(batch, hidden_size, dtype=torch.float32) + k_proj = torch.zeros(batch, kv_hidden, dtype=torch.float32) + v_proj = torch.zeros(batch, kv_hidden, dtype=torch.float32) + + for b0 in range(0, batch, BATCH_TILE): + b_end = min(b0 + BATCH_TILE, batch) + x_tile = hidden_states[b0:b_end, :].float() + + sq_sum = torch.zeros(b_end - b0, 1, dtype=torch.float32) + for k0 in range(0, hidden_size, INPUT_PROJ_K_CHUNK): + x_chunk = x_tile[:, k0:k0 + INPUT_PROJ_K_CHUNK] + sq_sum = sq_sum + (x_chunk ** 2).sum(dim=-1, keepdim=True) + variance = sq_sum / hidden_size + EPS + rms = torch.sqrt(variance) + normed = (x_tile / rms * input_rms_weight.float()).bfloat16() + + q_proj[b0:b_end, :] = tiled_matmul(normed, wq, INPUT_PROJ_K_CHUNK, Q_OUT_CHUNK) + k_proj[b0:b_end, :] = tiled_matmul(normed, wk, KV_PROJ_K_CHUNK, KV_OUT_CHUNK) + v_proj[b0:b_end, :] = tiled_matmul(normed, wv, KV_PROJ_K_CHUNK, KV_OUT_CHUNK) + + attn_out = torch.zeros(batch, hidden_size, dtype=torch.bfloat16) + + for b in range(batch): + ctx_len = seq_lens[b].item() + pos = ctx_len - 1 + ctx_blocks = (ctx_len + BLOCK_SIZE - 1) // BLOCK_SIZE + + cos_row = rope_cos[pos : pos + 1, :] + sin_row = rope_sin[pos : pos + 1, :] + cos_lo, cos_hi = cos_row[:, :half], cos_row[:, half:] + sin_lo, sin_hi = sin_row[:, :half], sin_row[:, half:] + + k_heads = k_proj[b].view(num_kv_heads, head_dim) + k_variance = k_heads.pow(2).mean(dim=-1, keepdim=True) + k_heads = k_heads * torch.rsqrt(k_variance + eps) * k_norm_weight.float() + k_lo_h, k_hi_h = k_heads[:, :half], k_heads[:, half:] + k_rot = torch.cat( + [k_lo_h * cos_lo - k_hi_h * sin_lo, k_hi_h * cos_hi + k_lo_h * sin_hi], + dim=-1, + ) + slot = int(slot_mapping[b].item()) + slot_block = slot // BLOCK_SIZE + slot_offset = slot % BLOCK_SIZE + + for ki in range(num_kv_heads): + cache_row = (slot_block * num_kv_heads + ki) * BLOCK_SIZE + slot_offset + k_cache[cache_row, :] = k_rot[ki].to(torch.bfloat16) + v_cache[cache_row, :] = v_proj[b, ki * head_dim : (ki + 1) * head_dim].to(torch.bfloat16) + + q_heads = q_proj[b].view(num_heads, head_dim) + q_variance = q_heads.pow(2).mean(dim=-1, keepdim=True) + q_heads = q_heads * torch.rsqrt(q_variance + eps) * q_norm_weight.float() + q_lo_h, q_hi_h = q_heads[:, :half], q_heads[:, half:] + q_rot = torch.cat( + [q_lo_h * cos_lo - q_hi_h * sin_lo, q_hi_h * cos_hi + q_lo_h * sin_hi], + dim=-1, + ) + + attn_row = torch.zeros(1, hidden_size, dtype=torch.bfloat16) + for kvh in range(num_kv_heads): + for qg in range(q_groups): + gi = kvh * q_groups + qg + q_base = kvh * q_per_kv + qg * Q_HEAD_BATCH + q_grp_bf16 = q_rot[q_base : q_base + Q_HEAD_BATCH, :].to(torch.bfloat16) + + oi = torch.zeros(Q_HEAD_BATCH, head_dim, dtype=torch.float32) + li = torch.zeros(Q_HEAD_BATCH, 1, dtype=torch.float32) + mi = torch.zeros(Q_HEAD_BATCH, 1, dtype=torch.float32) + + for sb in range(ctx_blocks): + s0 = sb * BLOCK_SIZE + valid_len = min(BLOCK_SIZE, ctx_len - s0) + pbid = int(block_table[b * max_ctx_blocks + sb].item()) + cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE + k_tile = k_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :] + v_tile = v_cache[cache_row0 : cache_row0 + BLOCK_SIZE, :] + + raw_scores = q_grp_bf16.float() @ k_tile.float().T + if valid_len < BLOCK_SIZE: + raw_scores[:, valid_len:] = torch.finfo(torch.float32).min + scores = raw_scores * scale + cur_mi = scores.max(dim=-1, keepdim=True).values + exp_scores = torch.exp(scores - cur_mi) + exp_scores_bf16 = exp_scores.to(torch.bfloat16) + cur_li = exp_scores_bf16.float().sum(dim=-1, keepdim=True) + oi_tmp = exp_scores_bf16.float() @ v_tile.float() + + if sb == 0: + oi = oi_tmp + li = cur_li + mi = cur_mi + else: + mi_new = torch.maximum(mi, cur_mi) + alpha = torch.exp(mi - mi_new) + beta = torch.exp(cur_mi - mi_new) + li = alpha * li + beta * cur_li + oi = oi * alpha + oi_tmp * beta + mi = mi_new + + ctx = oi / li + ctx_flat_bf16 = ctx.reshape(1, -1).to(torch.bfloat16) + attn_row[ + :, + q_base * head_dim : (q_base + Q_HEAD_BATCH) * head_dim, + ] = ctx_flat_bf16 + + attn_out[b : b + 1, :] = attn_row + + o_proj = tiled_matmul(attn_out, wo, OUT_PROJ_K_CHUNK, OUT_PROJ_N_CHUNK) + resid1 = o_proj + hidden_states.float() + + variance = chunked_row_sq_sum(resid1, K_CHUNK) / hidden_size + inv_rms = torch.rsqrt(variance + eps) + normed_bf16 = (resid1 * inv_rms * post_rms_weight).bfloat16() + + gate = tiled_matmul(normed_bf16, w_gate, K_CHUNK, MLP_OUT_CHUNK) + up = tiled_matmul(normed_bf16, w_up, K_CHUNK, MLP_OUT_CHUNK) + mlp_bf16 = (gate * torch.sigmoid(gate) * up).bfloat16() + down = tiled_matmul(mlp_bf16, w_down, DOWN_MLP_CHUNK, DOWN_OUT_CHUNK) + + tensors["out"][:] = (down + resid1).bfloat16() + + +if __name__ == "__main__": + import argparse + import sys + from golden import RunConfig, run + + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--platform", type=str, default="a2a3", + choices=["a2a3", "a2a3sim", "a5", "a5sim"]) + parser.add_argument("-d", "--device", type=int, default=0) + parser.add_argument("-b", "--batch", type=int, default=BATCH, + help=("User-visible batch size. Host allocates every " + "batch-dependent tensor at exactly this size; " + "the kernel internally rounds up to BATCH_TILE " + "(%d), zero-pads input loads via valid_shape, " + "and trims the BF16 output via vec-to-vec " + "textract. A single compiled program serves " + "any batch <= host KV-cache capacity. Default: " + "%%(default)s" % BATCH_TILE)) + parser.add_argument("--enable-l2-swimlane", action="store_true", default=False) + parser.add_argument("--max-seq", action="store_true", default=False) + parser.add_argument("--enable-pmu", nargs="?", const=2, default=0, type=int, choices=[0, 1, 2, 4]) + parser.add_argument( + "--export-kernel-insight", + action="store_true", + default=False, + help=( + "After a successful run, export msprof op-simulator Insight traces " + "for all generated InCore kernels under the same build_output dir." + ), + ) + parser.add_argument( + "--kernel-insight-func", + action="append", + default=[], + help="Only export this generated kernel function; can be repeated.", + ) + args = parser.parse_args() + + result = run( + program=build_qwen3_decode_program(batch=args.batch), + specs=build_tensor_specs(batch=args.batch, use_max_seq=args.max_seq), + golden_fn=golden_qwen3_decode, + config=RunConfig( + rtol=3e-3, + atol=3e-3, + compile=dict(dump_passes=True), + runtime=dict( + platform=args.platform, + device_id=args.device, + enable_l2_swimlane=args.enable_l2_swimlane, + enable_pmu=args.enable_pmu, + ), + ), + ) + if not result.passed: + if result.error: + print(result.error) + raise SystemExit(1) + + if args.export_kernel_insight: + if result.work_dir is None: + print("kernel insight export failed: run result has no build_output directory", file=sys.stderr) + raise SystemExit(1) + from tools.export_all_kernel_insight import StepError, main as export_kernel_insight + + export_args = ["--build-dir", str(result.work_dir)] + for func in args.kernel_insight_func: + export_args.extend(["--func", func]) + try: + export_rc = export_kernel_insight(export_args) + except StepError as exc: + print(f"kernel insight export failed: {exc}", file=sys.stderr) + raise SystemExit(1) from exc + if export_rc != 0: + raise SystemExit(export_rc)