From af1562474aabe44add550ac1b4f44bbe525a14ba Mon Sep 17 00:00:00 2001 From: Xu Zhenghao <115647249+xzhxzhxzh123@users.noreply.github.com> Date: Wed, 20 May 2026 15:33:42 +0800 Subject: [PATCH] Update decode_layer.py --- models/qwen3/14b/decode_layer.py | 1218 +++++++++++++++--------------- 1 file changed, 630 insertions(+), 588 deletions(-) diff --git a/models/qwen3/14b/decode_layer.py b/models/qwen3/14b/decode_layer.py index 960e010..b7c6a91 100644 --- a/models/qwen3/14b/decode_layer.py +++ b/models/qwen3/14b/decode_layer.py @@ -6,29 +6,11 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. # ----------------------------------------------------------------------------------------------------------- -"""Qwen3-14B single-layer decode forward. +"""Qwen3-14B single-layer decode forward (SPMD RMSNorm variant). -Scope 1: - 1. RMSNorm of input hidden states - 2. Q/K/V projection via matmul - -Per-head q_norm / k_norm - -Scope 2: - 1. K RoPE + paged cache write, V paged cache write, Q RoPE + pad - 2. QK matmul - 3. Softmax - 4. SV matmul - 5. Online-softmax accumulation + final normalisation - -Scope 3: - 1. Output projection: attn_out × wo - 2. Residual addition with hidden_states - 3. Post-attention RMSNorm - 4. MLP: gate/up projections, SiLU activation, down projection - 5. Final residual addition - -The final RMSNorm and LM head projection live in rms_lm_head.py. +SPMD diff vs decode_layer.py: + 1. Two InCore kernels: rmsnorm_kernel / post_rmsnorm_kernel + 2. pl.spmd(2) replaces pl.at(CORE_GROUP, "rmsnorm") / pl.at(CORE_GROUP, "post_rmsnorm") """ # pyright: reportUndefinedVariable=false @@ -57,9 +39,6 @@ KV_HIDDEN, KV_OUT_CHUNK, KV_PROJ_K_CHUNK, - LAYER_DYN, - LAYER_HIDDEN_ROWS_DYN, - LAYER_INTER_ROWS_DYN, LM_HEAD_K_CHUNK, MAX_BLOCKS_PER_SEQ, MAX_SEQ, @@ -79,580 +58,652 @@ VOCAB, VOCAB_CHUNK, ) -from rms_lm_head import rms_lm_head - - -@pl.jit.inline -def decode_layer( - current_hidden: pl.Tensor[[BATCH, HIDDEN], pl.BF16], - input_rms_weight: pl.Tensor[[LAYER_DYN, HIDDEN], pl.FP32], - wq: pl.Tensor[[LAYER_HIDDEN_ROWS_DYN, HIDDEN], pl.BF16], - wk: pl.Tensor[[LAYER_HIDDEN_ROWS_DYN, KV_HIDDEN], pl.BF16], - wv: pl.Tensor[[LAYER_HIDDEN_ROWS_DYN, KV_HIDDEN], pl.BF16], - q_norm_weight: pl.Tensor[[LAYER_DYN, HEAD_DIM], pl.FP32], - k_norm_weight: pl.Tensor[[LAYER_DYN, HEAD_DIM], pl.FP32], - seq_lens: pl.Tensor[[USER_BATCH_DYN], pl.INT32], - block_table: pl.Tensor[[BLOCK_TABLE_FLAT_DYN], pl.INT32], - slot_mapping: pl.Tensor[[USER_BATCH_DYN], pl.INT32], - rope_cos: pl.Tensor[[ROPE_SEQ_DYN, HEAD_DIM], pl.FP32], - rope_sin: pl.Tensor[[ROPE_SEQ_DYN, HEAD_DIM], pl.FP32], - k_cache: pl.Tensor[[KV_CACHE_ROWS_DYN, HEAD_DIM], pl.BF16], - v_cache: pl.Tensor[[KV_CACHE_ROWS_DYN, HEAD_DIM], pl.BF16], - wo: pl.Tensor[[LAYER_HIDDEN_ROWS_DYN, HIDDEN], pl.BF16], - post_rms_weight: pl.Tensor[[LAYER_DYN, HIDDEN], pl.FP32], - w_gate: pl.Tensor[[LAYER_HIDDEN_ROWS_DYN, INTERMEDIATE], pl.BF16], - w_up: pl.Tensor[[LAYER_HIDDEN_ROWS_DYN, INTERMEDIATE], pl.BF16], - w_down: pl.Tensor[[LAYER_INTER_ROWS_DYN, HIDDEN], pl.BF16], - next_hidden: pl.Tensor[[BATCH, HIDDEN], pl.BF16], - layer_idx: pl.Scalar[pl.INT32], -) -> pl.Tensor[[BATCH, HIDDEN], pl.BF16]: - decode_scope1_hidden_blocks = HIDDEN // INPUT_PROJ_K_CHUNK - hidden_blocks = HIDDEN // K_CHUNK - decode_q_out_blocks = HIDDEN // Q_OUT_CHUNK - decode_mlp_out_blocks = INTERMEDIATE // MLP_OUT_CHUNK + +# SPMD tiling constants +RMSNORM_SPMD_CORES = 2 +RMSNORM_SPMD_ROWS = BATCH_TILE // RMSNORM_SPMD_CORES + + +def build_qwen3_decode_program( + batch: int = BATCH, + max_seq: int = MAX_SEQ, + hidden_size: int = HIDDEN, + intermediate_size: int = INTERMEDIATE, + num_heads: int = NUM_HEADS, + num_kv_heads: int = NUM_KV_HEADS, + head_dim: int = HEAD_DIM, + vocab_size: int = VOCAB, +): + hidden = hidden_size + kv_hidden = num_kv_heads * head_dim + inter = intermediate_size + vocab = vocab_size + decode_scope1_hidden_blocks = hidden // INPUT_PROJ_K_CHUNK + hidden_blocks = hidden // K_CHUNK + decode_q_out_blocks = hidden // Q_OUT_CHUNK + decode_mlp_out_blocks = inter // MLP_OUT_CHUNK + max_blocks_per_seq = (max_seq + BLOCK_SIZE - 1) // BLOCK_SIZE + half_dim = head_dim // 2 head_dim_inv = HEAD_DIM_INV decode_attn_scale = ATTN_SCALE - num_layers_actual = pl.tensor.dim(input_rms_weight, 0) - decode_layer_cache_rows = pl.tensor.dim(k_cache, 0) // num_layers_actual - user_batch = pl.tensor.dim(seq_lens, 0) - batch_padded = BATCH - layer_hidden_base = layer_idx * HIDDEN - layer_inter_base = layer_idx * INTERMEDIATE - layer_cache_base = layer_idx * decode_layer_cache_rows - - # Intermediate FP32 tensors between scope 1 and scope 2. - q_proj = pl.create_tensor([BATCH, HIDDEN], dtype=pl.FP32) - k_proj = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32) - v_proj = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32) - q_proj_norm = pl.create_tensor([BATCH, HIDDEN], dtype=pl.FP32) - k_proj_norm = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32) - - # Scope 1: input RMSNorm + Q/K/V projection. - # The JIT inline path follows the fixed-BATCH single-layer kernel - # contract, so every matmul tile has a static M dim of BATCH_TILE. - for b0 in pl.parallel(0, batch_padded, BATCH_TILE): - normed_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.BF16) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="rmsnorm"): - partial_sq = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) - for kb in pl.range(decode_scope1_hidden_blocks): - sq_k0 = kb * INPUT_PROJ_K_CHUNK - sq_chunk = pl.cast( + q_per_kv = num_heads // num_kv_heads + q_groups = q_per_kv // Q_HEAD_BATCH + total_q_groups = num_kv_heads * q_groups + max_ctx_blocks = max_blocks_per_seq + + @pl.program + class Qwen3Decode: + @pl.function(type=pl.FunctionType.InCore) + def rmsnorm_kernel( + self, + current_hidden: pl.Tensor[[BATCH, hidden], pl.BF16], + b0: pl.Scalar[pl.INDEX], + cur_valid: pl.Scalar[pl.INDEX], + input_rms_weight: pl.Tensor[[1, hidden], pl.FP32], + normed_tile: pl.InOut[pl.Tensor[[BATCH_TILE, hidden], pl.BF16]], + ) -> pl.Tensor[[BATCH_TILE, hidden], pl.BF16]: + block_idx = pl.tile.get_block_idx() + row_start = block_idx * RMSNORM_SPMD_ROWS + local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start) + + partial_sq = pl.full([1, RMSNORM_SPMD_ROWS], dtype=pl.FP32, value=0.0) + for kb in pl.pipeline(decode_scope1_hidden_blocks, stage=4): + k0 = kb * INPUT_PROJ_K_CHUNK + x_chunk = pl.cast( pl.slice( current_hidden, - [BATCH_TILE, INPUT_PROJ_K_CHUNK], - [b0, sq_k0], + [RMSNORM_SPMD_ROWS, INPUT_PROJ_K_CHUNK], + [b0 + row_start, k0], + valid_shape=[local_valid, INPUT_PROJ_K_CHUNK], ), target_type=pl.FP32, ) partial_sq = pl.add( partial_sq, - pl.reshape(pl.row_sum(pl.mul(sq_chunk, sq_chunk)), [1, BATCH_TILE]), + pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, RMSNORM_SPMD_ROWS]), ) variance = pl.reshape( pl.add(pl.mul(partial_sq, HIDDEN_INV), EPS), - [BATCH_TILE, 1], + [RMSNORM_SPMD_ROWS, 1], ) inv_rms = pl.recip(pl.sqrt(variance)) - for kb in pl.range(decode_scope1_hidden_blocks): - norm_k0 = kb * INPUT_PROJ_K_CHUNK - norm_chunk = pl.cast( + for kb in pl.pipeline(decode_scope1_hidden_blocks, stage=4): + k0 = kb * INPUT_PROJ_K_CHUNK + x_chunk = pl.cast( pl.slice( current_hidden, - [BATCH_TILE, INPUT_PROJ_K_CHUNK], - [b0, norm_k0], + [RMSNORM_SPMD_ROWS, INPUT_PROJ_K_CHUNK], + [b0 + row_start, k0], + valid_shape=[local_valid, INPUT_PROJ_K_CHUNK], ), target_type=pl.FP32, ) - gamma = pl.slice(input_rms_weight, [1, INPUT_PROJ_K_CHUNK], [layer_idx, norm_k0]) - normed = pl.col_expand_mul(pl.row_expand_mul(norm_chunk, inv_rms), gamma) + gamma = pl.slice(input_rms_weight, [1, INPUT_PROJ_K_CHUNK], [0, k0]) + normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma) normed_tile = pl.assemble( normed_tile, pl.cast(normed, target_type=pl.BF16), - [0, norm_k0], - ) - - for q0 in pl.parallel(0, HIDDEN, Q_OUT_CHUNK): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_proj"): - q_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) - for kb in pl.range(decode_scope1_hidden_blocks): - q_k0 = kb * INPUT_PROJ_K_CHUNK - q_tile_a = pl.slice(normed_tile, [BATCH_TILE, INPUT_PROJ_K_CHUNK], [0, q_k0]) - q_tile_b = pl.slice(wq, [INPUT_PROJ_K_CHUNK, Q_OUT_CHUNK], [layer_hidden_base + q_k0, q0]) - if q_k0 == 0: - q_acc = pl.matmul(q_tile_a, q_tile_b, out_dtype=pl.FP32) - else: - q_acc = pl.matmul_acc(q_acc, q_tile_a, q_tile_b) - q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) - - for kv0 in pl.parallel(0, KV_HIDDEN, KV_OUT_CHUNK): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="k_proj"): - k_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) - for kb in pl.range(decode_scope1_hidden_blocks): - k_k0 = kb * INPUT_PROJ_K_CHUNK - k_tile_a = pl.slice(normed_tile, [BATCH_TILE, INPUT_PROJ_K_CHUNK], [0, k_k0]) - k_tile_b = pl.slice(wk, [INPUT_PROJ_K_CHUNK, KV_OUT_CHUNK], [layer_hidden_base + k_k0, kv0]) - if k_k0 == 0: - k_acc = pl.matmul(k_tile_a, k_tile_b, out_dtype=pl.FP32) - else: - k_acc = pl.matmul_acc(k_acc, k_tile_a, k_tile_b) - k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="v_proj"): - v_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) - for kb in pl.range(decode_scope1_hidden_blocks): - v_k0 = kb * INPUT_PROJ_K_CHUNK - v_tile_a = pl.slice(normed_tile, [BATCH_TILE, INPUT_PROJ_K_CHUNK], [0, v_k0]) - v_tile_b = pl.slice(wv, [INPUT_PROJ_K_CHUNK, KV_OUT_CHUNK], [layer_hidden_base + v_k0, kv0]) - if v_k0 == 0: - v_acc = pl.matmul(v_tile_a, v_tile_b, out_dtype=pl.FP32) - else: - v_acc = pl.matmul_acc(v_acc, v_tile_a, v_tile_b) - v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) - - # HF-style per-head q_norm / k_norm before RoPE, matching the original - # single-layer qwen3_decode grouping by KV head. - for b0 in pl.parallel(0, batch_padded, BATCH_TILE): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qk_norm"): - for h in pl.range(NUM_KV_HEADS): - q0 = h * Q_PER_KV * HEAD_DIM - q_chunk = pl.reshape( - pl.slice(q_proj, [BATCH_TILE, Q_HEAD_BATCH * HEAD_DIM], [b0, q0]), - [BATCH_TILE * Q_HEAD_BATCH, HEAD_DIM], - ) - q_sq_sum = pl.row_sum(pl.mul(q_chunk, q_chunk)) - q_inv_rms = pl.rsqrt(pl.add(pl.mul(q_sq_sum, head_dim_inv), EPS)) - q_chunk_norm = pl.col_expand_mul( - pl.row_expand_mul(q_chunk, q_inv_rms), - pl.slice(q_norm_weight, [1, HEAD_DIM], [layer_idx, 0]), - ) - q_chunk_norm_flat = pl.reshape(q_chunk_norm, [BATCH_TILE, Q_HEAD_BATCH * HEAD_DIM]) - q_proj_norm = pl.assemble(q_proj_norm, q_chunk_norm_flat, [b0, q0]) - - k0 = h * HEAD_DIM - k_chunk = pl.slice(k_proj, [BATCH_TILE, HEAD_DIM], [b0, k0]) - k_sq_sum = pl.row_sum(pl.mul(k_chunk, k_chunk)) - k_inv_rms = pl.rsqrt(pl.add(pl.mul(k_sq_sum, head_dim_inv), EPS)) - k_chunk_norm = pl.col_expand_mul( - pl.row_expand_mul(k_chunk, k_inv_rms), - pl.slice(k_norm_weight, [1, HEAD_DIM], [layer_idx, 0]), - ) - k_proj_norm = pl.assemble(k_proj_norm, k_chunk_norm, [b0, k0]) - - # Scope 2: RoPE + KV cache update + grouped decode attention. - # This follows the original single-layer qwen3_decode paired-gi - # attention structure, with layer_cache_base added for full-model caches. - attn_out = pl.create_tensor([BATCH, HIDDEN], dtype=pl.BF16) - all_q_padded = pl.create_tensor( - [BATCH * TOTAL_Q_GROUPS * Q_HEAD_PAD, HEAD_DIM], dtype=pl.BF16, - ) - - # Scope 2 only touches runtime-visible rows; padded rows stay zero. - for b in pl.parallel(user_batch): - ctx_len = pl.tensor.read(seq_lens, [b]) - pos = ctx_len - 1 - ctx_blocks = (ctx_len + BLOCK_SIZE - 1) // BLOCK_SIZE - block_table_base = b * MAX_BLOCKS_PER_SEQ - slot = pl.tensor.read(slot_mapping, [b]) - slot_block = slot // BLOCK_SIZE - slot_offset = slot - slot_block * BLOCK_SIZE - cos_row = pl.slice(rope_cos, [1, HEAD_DIM], [pos, 0]) - sin_row = pl.slice(rope_sin, [1, HEAD_DIM], [pos, 0]) - cos_lo = pl.slice(cos_row, [1, HALF_DIM], [0, 0]) - cos_hi = pl.slice(cos_row, [1, HALF_DIM], [0, HALF_DIM]) - sin_lo = pl.slice(sin_row, [1, HALF_DIM], [0, 0]) - sin_hi = pl.slice(sin_row, [1, HALF_DIM], [0, HALF_DIM]) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_kv_cache"): - for ki in pl.range(NUM_KV_HEADS): - kv_col = ki * HEAD_DIM - cache_row = layer_cache_base + (slot_block * NUM_KV_HEADS + ki) * BLOCK_SIZE + slot_offset - k_lo = pl.slice(k_proj_norm, [1, HALF_DIM], [b, kv_col]) - k_hi = pl.slice(k_proj_norm, [1, HALF_DIM], [b, kv_col + HALF_DIM]) - rot_lo = pl.sub( - pl.col_expand_mul(k_lo, cos_lo), - pl.col_expand_mul(k_hi, sin_lo), - ) - rot_hi = pl.add( - pl.col_expand_mul(k_hi, cos_hi), - pl.col_expand_mul(k_lo, sin_hi), - ) - k_cache = pl.assemble( - k_cache, - pl.cast(rot_lo, target_type=pl.BF16), - [cache_row, 0], - ) - k_cache = pl.assemble( - k_cache, - pl.cast(rot_hi, target_type=pl.BF16), - [cache_row, HALF_DIM], - ) - v_cache = pl.assemble( - v_cache, - pl.cast( - pl.slice(v_proj, [1, HEAD_DIM], [b, kv_col]), - target_type=pl.BF16, - ), - [cache_row, 0], - ) - q_base = ki * Q_PER_KV - q_block = pl.reshape( - pl.slice(q_proj_norm, [1, Q_HEAD_BATCH * HEAD_DIM], [b, q_base * HEAD_DIM]), - [Q_HEAD_BATCH, HEAD_DIM], - ) - q_lo = pl.slice(q_block, [Q_HEAD_BATCH, HALF_DIM], [0, 0]) - q_hi = pl.slice(q_block, [Q_HEAD_BATCH, HALF_DIM], [0, HALF_DIM]) - rot_lo_bf16 = pl.cast( - pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), - target_type=pl.BF16, - ) - rot_hi_bf16 = pl.cast( - pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), - target_type=pl.BF16, + [row_start, k0], ) - all_q_padded = pl.assemble( - all_q_padded, - rot_lo_bf16, - [b * TOTAL_Q_GROUPS * Q_HEAD_PAD + ki * Q_HEAD_PAD, 0], - ) - all_q_padded = pl.assemble( - all_q_padded, - rot_hi_bf16, - [b * TOTAL_Q_GROUPS * Q_HEAD_PAD + ki * Q_HEAD_PAD, HALF_DIM], - ) - all_q_padded = pl.assemble( - all_q_padded, - pl.cast( - pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, HEAD_DIM], dtype=pl.FP32, value=0.0), - target_type=pl.BF16, - ), - [b * TOTAL_Q_GROUPS * Q_HEAD_PAD + ki * Q_HEAD_PAD + Q_HEAD_BATCH, 0], - ) - - attn_row = pl.create_tensor([1, HIDDEN], dtype=pl.BF16) - for gi in pl.parallel(0, TOTAL_Q_GROUPS, 2): - gi0 = gi - gi1 = gi + 1 - - kvh0 = gi0 // Q_GROUPS - qg0 = gi0 - kvh0 * Q_GROUPS - q_base0 = kvh0 * Q_PER_KV + qg0 * Q_HEAD_BATCH - q_padded_row0 = b * TOTAL_Q_GROUPS * Q_HEAD_PAD + gi0 * Q_HEAD_PAD - q_padded0 = pl.slice(all_q_padded, [Q_HEAD_PAD, HEAD_DIM], [q_padded_row0, 0]) - - kvh1 = gi1 // Q_GROUPS - qg1 = gi1 - kvh1 * Q_GROUPS - q_base1 = kvh1 * Q_PER_KV + qg1 * Q_HEAD_BATCH - q_padded_row1 = b * TOTAL_Q_GROUPS * Q_HEAD_PAD + gi1 * Q_HEAD_PAD - q_padded1 = pl.slice(all_q_padded, [Q_HEAD_PAD, HEAD_DIM], [q_padded_row1, 0]) - - all_raw_scores0 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.FP32) - all_raw_scores1 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qk_matmul"): - for sb in pl.range(ctx_blocks): - qk_block_table_idx = block_table_base + sb - qk_pbid = pl.cast(pl.tensor.read(block_table, [qk_block_table_idx]), pl.INDEX) - - qk_cache_row0 = layer_cache_base + (qk_pbid * NUM_KV_HEADS + kvh0) * BLOCK_SIZE - k_tile0 = pl.slice(k_cache, [BLOCK_SIZE, HEAD_DIM], [qk_cache_row0, 0]) - raw_scores0 = pl.matmul(q_padded0, k_tile0, b_trans=True, out_dtype=pl.FP32) - all_raw_scores0 = pl.assemble(all_raw_scores0, raw_scores0, [sb * Q_HEAD_PAD, 0]) - - qk_cache_row1 = layer_cache_base + (qk_pbid * NUM_KV_HEADS + kvh1) * BLOCK_SIZE - k_tile1 = pl.slice(k_cache, [BLOCK_SIZE, HEAD_DIM], [qk_cache_row1, 0]) - raw_scores1 = pl.matmul(q_padded1, k_tile1, b_trans=True, out_dtype=pl.FP32) - all_raw_scores1 = pl.assemble(all_raw_scores1, raw_scores1, [sb * Q_HEAD_PAD, 0]) - - all_exp_padded0 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.BF16) - all_exp_padded1 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.BF16) - all_cur_mi0 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, 1], dtype=pl.FP32) - all_cur_mi1 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, 1], dtype=pl.FP32) - all_cur_li0 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, 1], dtype=pl.FP32) - all_cur_li1 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, 1], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="softmax"): - for sb in pl.range(ctx_blocks): - s0 = sb * BLOCK_SIZE - valid_len = pl.min(BLOCK_SIZE, ctx_len - s0) - - scores_valid0 = pl.slice( - all_raw_scores0, - [Q_HEAD_PAD, BLOCK_SIZE], - [sb * Q_HEAD_PAD, 0], - valid_shape=[Q_HEAD_PAD, valid_len], - ) - scores_padded0 = pl.fillpad(scores_valid0, pad_value=pl.PadValue.min) - scores0 = pl.mul(scores_padded0, decode_attn_scale) - softmax_cur_mi0 = pl.row_max(scores0) - exp_scores0 = pl.exp(pl.row_expand_sub(scores0, softmax_cur_mi0)) - exp_scores_bf16_0 = pl.cast(exp_scores0, target_type=pl.BF16) - exp_scores_fp32_0 = pl.cast(exp_scores_bf16_0, target_type=pl.FP32) - softmax_cur_li0 = pl.row_sum(exp_scores_fp32_0) - all_exp_padded0 = pl.assemble(all_exp_padded0, exp_scores_bf16_0, [sb * Q_HEAD_PAD, 0]) - all_cur_mi0 = pl.assemble(all_cur_mi0, softmax_cur_mi0, [sb * Q_HEAD_PAD, 0]) - all_cur_li0 = pl.assemble(all_cur_li0, softmax_cur_li0, [sb * Q_HEAD_PAD, 0]) - - scores_valid1 = pl.slice( - all_raw_scores1, - [Q_HEAD_PAD, BLOCK_SIZE], - [sb * Q_HEAD_PAD, 0], - valid_shape=[Q_HEAD_PAD, valid_len], - ) - scores_padded1 = pl.fillpad(scores_valid1, pad_value=pl.PadValue.min) - scores1 = pl.mul(scores_padded1, decode_attn_scale) - softmax_cur_mi1 = pl.row_max(scores1) - exp_scores1 = pl.exp(pl.row_expand_sub(scores1, softmax_cur_mi1)) - exp_scores_bf16_1 = pl.cast(exp_scores1, target_type=pl.BF16) - exp_scores_fp32_1 = pl.cast(exp_scores_bf16_1, target_type=pl.FP32) - softmax_cur_li1 = pl.row_sum(exp_scores_fp32_1) - all_exp_padded1 = pl.assemble(all_exp_padded1, exp_scores_bf16_1, [sb * Q_HEAD_PAD, 0]) - all_cur_mi1 = pl.assemble(all_cur_mi1, softmax_cur_mi1, [sb * Q_HEAD_PAD, 0]) - all_cur_li1 = pl.assemble(all_cur_li1, softmax_cur_li1, [sb * Q_HEAD_PAD, 0]) - - all_oi_tmp0 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, HEAD_DIM], dtype=pl.FP32) - all_oi_tmp1 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, HEAD_DIM], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="sv_matmul"): - for sb in pl.range(ctx_blocks): - sv_block_table_idx = block_table_base + sb - sv_pbid = pl.cast(pl.tensor.read(block_table, [sv_block_table_idx]), pl.INDEX) - - sv_cache_row0 = layer_cache_base + (sv_pbid * NUM_KV_HEADS + kvh0) * BLOCK_SIZE - exp_tile0 = pl.slice(all_exp_padded0, [Q_HEAD_PAD, BLOCK_SIZE], [sb * Q_HEAD_PAD, 0]) - v_tile0 = pl.slice(v_cache, [BLOCK_SIZE, HEAD_DIM], [sv_cache_row0, 0]) - oi_tmp0 = pl.matmul(exp_tile0, v_tile0, out_dtype=pl.FP32) - all_oi_tmp0 = pl.assemble(all_oi_tmp0, oi_tmp0, [sb * Q_HEAD_PAD, 0]) - - sv_cache_row1 = layer_cache_base + (sv_pbid * NUM_KV_HEADS + kvh1) * BLOCK_SIZE - exp_tile1 = pl.slice(all_exp_padded1, [Q_HEAD_PAD, BLOCK_SIZE], [sb * Q_HEAD_PAD, 0]) - v_tile1 = pl.slice(v_cache, [BLOCK_SIZE, HEAD_DIM], [sv_cache_row1, 0]) - oi_tmp1 = pl.matmul(exp_tile1, v_tile1, out_dtype=pl.FP32) - all_oi_tmp1 = pl.assemble(all_oi_tmp1, oi_tmp1, [sb * Q_HEAD_PAD, 0]) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="online_softmax"): - oi0 = pl.slice(all_oi_tmp0, [Q_HEAD_PAD, HEAD_DIM], [0, 0]) - mi0 = pl.slice(all_cur_mi0, [Q_HEAD_PAD, 1], [0, 0]) - li0 = pl.slice(all_cur_li0, [Q_HEAD_PAD, 1], [0, 0]) - oi1 = pl.slice(all_oi_tmp1, [Q_HEAD_PAD, HEAD_DIM], [0, 0]) - mi1 = pl.slice(all_cur_mi1, [Q_HEAD_PAD, 1], [0, 0]) - li1 = pl.slice(all_cur_li1, [Q_HEAD_PAD, 1], [0, 0]) - for sb in pl.range(1, ctx_blocks): - oi_tmp_valid0 = pl.slice(all_oi_tmp0, [Q_HEAD_PAD, HEAD_DIM], [sb * Q_HEAD_PAD, 0]) - online_cur_mi0 = pl.slice(all_cur_mi0, [Q_HEAD_PAD, 1], [sb * Q_HEAD_PAD, 0]) - online_cur_li0 = pl.slice(all_cur_li0, [Q_HEAD_PAD, 1], [sb * Q_HEAD_PAD, 0]) - mi_new0 = pl.maximum(mi0, online_cur_mi0) - alpha0 = pl.exp(pl.sub(mi0, mi_new0)) - beta0 = pl.exp(pl.sub(online_cur_mi0, mi_new0)) - li0 = pl.add(pl.mul(alpha0, li0), pl.mul(beta0, online_cur_li0)) - oi0 = pl.add(pl.row_expand_mul(oi0, alpha0), pl.row_expand_mul(oi_tmp_valid0, beta0)) - mi0 = mi_new0 - - oi_tmp_valid1 = pl.slice(all_oi_tmp1, [Q_HEAD_PAD, HEAD_DIM], [sb * Q_HEAD_PAD, 0]) - online_cur_mi1 = pl.slice(all_cur_mi1, [Q_HEAD_PAD, 1], [sb * Q_HEAD_PAD, 0]) - online_cur_li1 = pl.slice(all_cur_li1, [Q_HEAD_PAD, 1], [sb * Q_HEAD_PAD, 0]) - mi_new1 = pl.maximum(mi1, online_cur_mi1) - alpha1 = pl.exp(pl.sub(mi1, mi_new1)) - beta1 = pl.exp(pl.sub(online_cur_mi1, mi_new1)) - li1 = pl.add(pl.mul(alpha1, li1), pl.mul(beta1, online_cur_li1)) - oi1 = pl.add(pl.row_expand_mul(oi1, alpha1), pl.row_expand_mul(oi_tmp_valid1, beta1)) - mi1 = mi_new1 - - ctx0 = pl.row_expand_div(oi0, li0) - ctx_valid0 = pl.slice(ctx0, [Q_HEAD_BATCH, HEAD_DIM], [0, 0]) - ctx_flat_bf16_0 = pl.cast(pl.reshape(ctx_valid0, [1, Q_HEAD_BATCH * HEAD_DIM]), target_type=pl.BF16) - attn_row = pl.assemble(attn_row, ctx_flat_bf16_0, [0, q_base0 * HEAD_DIM]) - - ctx1 = pl.row_expand_div(oi1, li1) - ctx_valid1 = pl.slice(ctx1, [Q_HEAD_BATCH, HEAD_DIM], [0, 0]) - ctx_flat_bf16_1 = pl.cast(pl.reshape(ctx_valid1, [1, Q_HEAD_BATCH * HEAD_DIM]), target_type=pl.BF16) - attn_row = pl.assemble(attn_row, ctx_flat_bf16_1, [0, q_base1 * HEAD_DIM]) - - attn_out = pl.assemble(attn_out, attn_row, [b, 0]) - - # Scope 3: output projection + residual + post RMSNorm + MLP + residual. - # Loops over batch_padded so every iteration processes a full - # [BATCH_TILE, *] tile (a2a3 matmul M-tile constraint). - # Final down-proj + residual + cast uses the two-incore pattern - # validated in dynamic_batch_pad_repro: - # cube incore : matmul_acc -> FP32 -> assemble to GM scratch - # vec incore : tload FP32 chunk -> add FP32 resid -> cast BF16 - # (preserves ND layout) -> assemble to out - for b0 in pl.parallel(0, batch_padded, BATCH_TILE): - resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.FP32) - - for ob in pl.range(decode_q_out_blocks): - o0 = ob * Q_OUT_CHUNK - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj"): - a_chunk_0 = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, 0]) - w_chunk_0 = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [layer_hidden_base, o0]) - o_acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - a_chunk = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, k0]) - w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [layer_hidden_base + k0, o0]) - o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj_residual"): - resid = pl.cast( - pl.slice(current_hidden, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]), - target_type=pl.FP32, + return normed_tile + + @pl.function(type=pl.FunctionType.InCore) + def post_rmsnorm_kernel( + self, + resid1_tile: pl.Tensor[[BATCH_TILE, hidden], pl.FP32], + cur_valid: pl.Scalar[pl.INDEX], + post_rms_weight: pl.Tensor[[1, hidden], pl.FP32], + post_norm_tile: pl.InOut[pl.Tensor[[BATCH_TILE, hidden], pl.BF16]], + ) -> pl.Tensor[[BATCH_TILE, hidden], pl.BF16]: + block_idx = pl.tile.get_block_idx() + row_start = block_idx * RMSNORM_SPMD_ROWS + local_valid = pl.min(RMSNORM_SPMD_ROWS, cur_valid - row_start) + + sq_sum = pl.full([1, RMSNORM_SPMD_ROWS], dtype=pl.FP32, value=0.0) + for kb in pl.pipeline(hidden_blocks, stage=2): + k0 = kb * K_CHUNK + resid_chunk = pl.slice( + resid1_tile, + [RMSNORM_SPMD_ROWS, K_CHUNK], + [row_start, k0], + valid_shape=[local_valid, K_CHUNK], ) - resid_sum = pl.add(o_acc, resid) - resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0]) - - post_norm_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.BF16) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="post_rmsnorm"): - sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) - for kb in pl.range(hidden_blocks): - post_sq_k0 = kb * K_CHUNK - post_sq_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, post_sq_k0]) sq_sum = pl.add( sq_sum, - pl.reshape(pl.row_sum(pl.mul(post_sq_chunk, post_sq_chunk)), [1, BATCH_TILE]), + pl.reshape(pl.row_sum(pl.mul(resid_chunk, resid_chunk)), [1, RMSNORM_SPMD_ROWS]), ) inv_rms_s3 = pl.recip(pl.sqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS))) - for kb in pl.range(hidden_blocks): - post_norm_k0 = kb * K_CHUNK - post_norm_chunk = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, post_norm_k0]) - post_gamma = pl.slice(post_rms_weight, [1, K_CHUNK], [layer_idx, post_norm_k0]) + for kb in pl.pipeline(hidden_blocks, stage=2): + k0 = kb * K_CHUNK + resid_chunk = pl.slice( + resid1_tile, + [RMSNORM_SPMD_ROWS, K_CHUNK], + [row_start, k0], + valid_shape=[local_valid, K_CHUNK], + ) + post_gamma = pl.slice(post_rms_weight, [1, K_CHUNK], [0, k0]) post_normed = pl.col_expand_mul( - pl.row_expand_mul(post_norm_chunk, pl.reshape(inv_rms_s3, [BATCH_TILE, 1])), + pl.row_expand_mul(resid_chunk, pl.reshape(inv_rms_s3, [RMSNORM_SPMD_ROWS, 1])), post_gamma, ) normed_bf16 = pl.cast(post_normed, target_type=pl.BF16) - post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [0, post_norm_k0]) - - mlp_tile = pl.create_tensor([BATCH_TILE, INTERMEDIATE], dtype=pl.BF16) - for ob in pl.range(decode_mlp_out_blocks): - mlp_o0 = ob * MLP_OUT_CHUNK - with pl.at(level=pl.Level.CORE_GROUP, name_hint="gate_proj"): - post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) - wg_0 = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [layer_hidden_base, mlp_o0]) - gate_acc = pl.matmul(post_chunk_0, wg_0, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - gate_k0 = kb * K_CHUNK - gate_post_chunk = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, gate_k0]) - wg = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [layer_hidden_base + gate_k0, mlp_o0]) - gate_acc = pl.matmul_acc(gate_acc, gate_post_chunk, wg) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="up_proj"): - post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) - wu_0 = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [layer_hidden_base, mlp_o0]) - up_acc = pl.matmul(post_chunk_0, wu_0, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - up_k0 = kb * K_CHUNK - up_post_chunk = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, up_k0]) - wu = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [layer_hidden_base + up_k0, mlp_o0]) - up_acc = pl.matmul_acc(up_acc, up_post_chunk, wu) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="silu"): - sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_acc)), 1.0)) - mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc) - mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) - mlp_tile = pl.assemble(mlp_tile, mlp_chunk_bf16, [0, mlp_o0]) - - for dob in pl.range(hidden_blocks): - d0 = dob * K_CHUNK - # FP32 GM scratch chunk used as the cube -> vec bridge. - # Per-iter [BATCH_TILE, K_CHUNK] is small (16*256*4 = - # 8 KiB) and avoids a large pre-allocated scratch. - fp32_chunk_gm = pl.create_tensor([BATCH_TILE, K_CHUNK], dtype=pl.FP32) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj"): - mlp_chunk_0 = pl.slice(mlp_tile, [BATCH_TILE, MLP_OUT_CHUNK], [0, 0]) - w_down_chunk_0 = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [layer_inter_base, d0]) - down_acc = pl.matmul(mlp_chunk_0, w_down_chunk_0, out_dtype=pl.FP32) - for ob in pl.range(1, decode_mlp_out_blocks): - down_o0 = ob * MLP_OUT_CHUNK - down_mlp_chunk_bf16 = pl.slice( - mlp_tile, - [BATCH_TILE, MLP_OUT_CHUNK], - [0, down_o0], + post_norm_tile = pl.assemble(post_norm_tile, normed_bf16, [row_start, k0]) + return post_norm_tile + + @pl.function(type=pl.FunctionType.Opaque) + def qwen3_decode( + self, + hidden_states: pl.Tensor[[USER_BATCH_DYN, hidden], pl.BF16], + input_rms_weight: pl.Tensor[[1, hidden], pl.FP32], + wq: pl.Tensor[[hidden, hidden], pl.BF16], + wk: pl.Tensor[[hidden, kv_hidden], pl.BF16], + wv: pl.Tensor[[hidden, kv_hidden], pl.BF16], + q_norm_weight: pl.Tensor[[1, head_dim], pl.FP32], + k_norm_weight: pl.Tensor[[1, head_dim], pl.FP32], + seq_lens: pl.Tensor[[USER_BATCH_DYN], pl.INT32], + block_table: pl.Tensor[[BLOCK_TABLE_FLAT_DYN], pl.INT32], + slot_mapping: pl.Tensor[[USER_BATCH_DYN], pl.INT32], + rope_cos: pl.Tensor[[ROPE_SEQ_DYN, head_dim], pl.FP32], + rope_sin: pl.Tensor[[ROPE_SEQ_DYN, head_dim], pl.FP32], + k_cache: pl.Tensor[[KV_CACHE_ROWS_DYN, head_dim], pl.BF16], + v_cache: pl.Tensor[[KV_CACHE_ROWS_DYN, head_dim], pl.BF16], + wo: pl.Tensor[[hidden, hidden], pl.BF16], + post_rms_weight: pl.Tensor[[1, hidden], pl.FP32], + w_gate: pl.Tensor[[hidden, inter], pl.BF16], + w_up: pl.Tensor[[hidden, inter], pl.BF16], + w_down: pl.Tensor[[inter, hidden], pl.BF16], + final_norm_weight: pl.Tensor[[1, hidden], pl.FP32], + lm_head_weight: pl.Tensor[[vocab, hidden], pl.BF16], + out: pl.Out[pl.Tensor[[USER_BATCH_DYN, vocab], pl.FP32]], + ) -> pl.Tensor[[USER_BATCH_DYN, vocab], pl.FP32]: + # === copy_hidden (from qwen3_decode_test) === # + user_batch = pl.tensor.dim(hidden_states, 0) + current_hidden = pl.create_tensor([BATCH, HIDDEN], dtype=pl.BF16) + for b0 in pl.parallel(0, BATCH, BATCH_TILE): + cur_valid = pl.min(BATCH_TILE, user_batch - b0) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="copy_hidden"): + for kb in pl.range(HIDDEN // K_CHUNK): + copy_k0 = kb * K_CHUNK + hidden_chunk = pl.slice( + hidden_states, + [BATCH_TILE, K_CHUNK], + [b0, copy_k0], + valid_shape=[cur_valid, K_CHUNK], + ) + current_hidden = pl.assemble(current_hidden, hidden_chunk, [b0, copy_k0]) + + # === decode_layer body (layer_idx=0 inlined, weight offsets simplified) === # + batch_padded = BATCH + next_hidden = pl.create_tensor([BATCH, HIDDEN], dtype=pl.BF16) + + # Intermediate FP32 tensors between scope 1 and scope 2. + q_proj = pl.create_tensor([BATCH, HIDDEN], dtype=pl.FP32) + k_proj = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32) + v_proj = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32) + q_proj_norm = pl.create_tensor([BATCH, HIDDEN], dtype=pl.FP32) + k_proj_norm = pl.create_tensor([BATCH, KV_HIDDEN], dtype=pl.FP32) + + # Scope 1: input RMSNorm + Q/K/V projection. + for b0 in pl.parallel(0, batch_padded, BATCH_TILE): + normed_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.BF16) + cur_valid = pl.min(BATCH_TILE, user_batch - b0) + + # SPMD rmsnorm (change 1/2) + with pl.spmd(RMSNORM_SPMD_CORES): + normed_tile = self.rmsnorm_kernel( + current_hidden, b0, cur_valid, input_rms_weight, normed_tile, ) - w_down_chunk = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [layer_inter_base + down_o0, d0]) - down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk) - fp32_chunk_gm = pl.assemble(fp32_chunk_gm, down_acc, [0, 0]) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj_residual"): - # Vec-only incore: tload FP32 cube output as ND vec - # tile, add FP32 residual (also ND vec), cast to - # BF16 (vec-to-vec cast preserves ND layout). - down_chunk_fp32 = pl.slice(fp32_chunk_gm, [BATCH_TILE, K_CHUNK], [0, 0]) - resid_chunk_fp32 = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, d0]) - out_chunk = pl.add(down_chunk_fp32, resid_chunk_fp32) - out_chunk_cast = pl.cast(out_chunk, target_type=pl.BF16) - next_hidden = pl.assemble(next_hidden, out_chunk_cast, [b0, d0]) - - return next_hidden - - -@pl.jit -def test_decode_layer( - hidden_states: pl.Tensor[[USER_BATCH_DYN, HIDDEN], pl.BF16], - input_rms_weight: pl.Tensor[[1, HIDDEN], pl.FP32], - wq: pl.Tensor[[HIDDEN, HIDDEN], pl.BF16], - wk: pl.Tensor[[HIDDEN, KV_HIDDEN], pl.BF16], - wv: pl.Tensor[[HIDDEN, KV_HIDDEN], pl.BF16], - q_norm_weight: pl.Tensor[[1, HEAD_DIM], pl.FP32], - k_norm_weight: pl.Tensor[[1, HEAD_DIM], pl.FP32], - seq_lens: pl.Tensor[[USER_BATCH_DYN], pl.INT32], - block_table: pl.Tensor[[BLOCK_TABLE_FLAT_DYN], pl.INT32], - slot_mapping: pl.Tensor[[USER_BATCH_DYN], pl.INT32], - rope_cos: pl.Tensor[[ROPE_SEQ_DYN, HEAD_DIM], pl.FP32], - rope_sin: pl.Tensor[[ROPE_SEQ_DYN, HEAD_DIM], pl.FP32], - k_cache: pl.Tensor[[KV_CACHE_ROWS_DYN, HEAD_DIM], pl.BF16], - v_cache: pl.Tensor[[KV_CACHE_ROWS_DYN, HEAD_DIM], pl.BF16], - wo: pl.Tensor[[HIDDEN, HIDDEN], pl.BF16], - post_rms_weight: pl.Tensor[[1, HIDDEN], pl.FP32], - w_gate: pl.Tensor[[HIDDEN, INTERMEDIATE], pl.BF16], - w_up: pl.Tensor[[HIDDEN, INTERMEDIATE], pl.BF16], - w_down: pl.Tensor[[INTERMEDIATE, HIDDEN], pl.BF16], - final_norm_weight: pl.Tensor[[1, HIDDEN], pl.FP32], - lm_head_weight: pl.Tensor[[VOCAB, HIDDEN], pl.BF16], - out: pl.Out[pl.Tensor[[USER_BATCH_DYN, VOCAB], pl.FP32]], -) -> pl.Tensor[[USER_BATCH_DYN, VOCAB], pl.FP32]: - user_batch = pl.tensor.dim(hidden_states, 0) - current_hidden = pl.create_tensor([BATCH, HIDDEN], dtype=pl.BF16) - for b0 in pl.parallel(0, BATCH, BATCH_TILE): - cur_valid = pl.min(BATCH_TILE, user_batch - b0) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="copy_hidden"): - for kb in pl.range(HIDDEN // K_CHUNK): - copy_k0 = kb * K_CHUNK - hidden_chunk = pl.slice( - hidden_states, - [BATCH_TILE, K_CHUNK], - [b0, copy_k0], - valid_shape=[cur_valid, K_CHUNK], - ) - current_hidden = pl.assemble(current_hidden, hidden_chunk, [b0, copy_k0]) - - next_hidden = pl.create_tensor([BATCH, HIDDEN], dtype=pl.BF16) - current_hidden = decode_layer( - current_hidden, - input_rms_weight, - wq, - wk, - wv, - q_norm_weight, - k_norm_weight, - seq_lens, - block_table, - slot_mapping, - rope_cos, - rope_sin, - k_cache, - v_cache, - wo, - post_rms_weight, - w_gate, - w_up, - w_down, - next_hidden, - 0, - ) - out = rms_lm_head(current_hidden, final_norm_weight, lm_head_weight, seq_lens, out) - return out + + for q0 in pl.parallel(0, HIDDEN, Q_OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_proj"): + q_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.range(decode_scope1_hidden_blocks): + q_k0 = kb * INPUT_PROJ_K_CHUNK + q_tile_a = pl.slice(normed_tile, [BATCH_TILE, INPUT_PROJ_K_CHUNK], [0, q_k0]) + q_tile_b = pl.slice(wq, [INPUT_PROJ_K_CHUNK, Q_OUT_CHUNK], [q_k0, q0]) + if q_k0 == 0: + q_acc = pl.matmul(q_tile_a, q_tile_b, out_dtype=pl.FP32) + else: + q_acc = pl.matmul_acc(q_acc, q_tile_a, q_tile_b) + q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) + + for kv0 in pl.parallel(0, KV_HIDDEN, KV_OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="k_proj"): + k_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.range(decode_scope1_hidden_blocks): + k_k0 = kb * INPUT_PROJ_K_CHUNK + k_tile_a = pl.slice(normed_tile, [BATCH_TILE, INPUT_PROJ_K_CHUNK], [0, k_k0]) + k_tile_b = pl.slice(wk, [INPUT_PROJ_K_CHUNK, KV_OUT_CHUNK], [k_k0, kv0]) + if k_k0 == 0: + k_acc = pl.matmul(k_tile_a, k_tile_b, out_dtype=pl.FP32) + else: + k_acc = pl.matmul_acc(k_acc, k_tile_a, k_tile_b) + k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="v_proj"): + v_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) + for kb in pl.range(decode_scope1_hidden_blocks): + v_k0 = kb * INPUT_PROJ_K_CHUNK + v_tile_a = pl.slice(normed_tile, [BATCH_TILE, INPUT_PROJ_K_CHUNK], [0, v_k0]) + v_tile_b = pl.slice(wv, [INPUT_PROJ_K_CHUNK, KV_OUT_CHUNK], [v_k0, kv0]) + if v_k0 == 0: + v_acc = pl.matmul(v_tile_a, v_tile_b, out_dtype=pl.FP32) + else: + v_acc = pl.matmul_acc(v_acc, v_tile_a, v_tile_b) + v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) + + # HF-style per-head q_norm / k_norm before RoPE. + for b0 in pl.parallel(0, batch_padded, BATCH_TILE): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qk_norm"): + for h in pl.range(NUM_KV_HEADS): + q0 = h * Q_PER_KV * HEAD_DIM + q_chunk = pl.reshape( + pl.slice(q_proj, [BATCH_TILE, Q_HEAD_BATCH * HEAD_DIM], [b0, q0]), + [BATCH_TILE * Q_HEAD_BATCH, HEAD_DIM], + ) + q_sq_sum = pl.row_sum(pl.mul(q_chunk, q_chunk)) + q_inv_rms = pl.rsqrt(pl.add(pl.mul(q_sq_sum, head_dim_inv), EPS)) + q_chunk_norm = pl.col_expand_mul( + pl.row_expand_mul(q_chunk, q_inv_rms), + pl.slice(q_norm_weight, [1, HEAD_DIM], [0, 0]), + ) + q_chunk_norm_flat = pl.reshape(q_chunk_norm, [BATCH_TILE, Q_HEAD_BATCH * HEAD_DIM]) + q_proj_norm = pl.assemble(q_proj_norm, q_chunk_norm_flat, [b0, q0]) + + k0 = h * HEAD_DIM + k_chunk = pl.slice(k_proj, [BATCH_TILE, HEAD_DIM], [b0, k0]) + k_sq_sum = pl.row_sum(pl.mul(k_chunk, k_chunk)) + k_inv_rms = pl.rsqrt(pl.add(pl.mul(k_sq_sum, head_dim_inv), EPS)) + k_chunk_norm = pl.col_expand_mul( + pl.row_expand_mul(k_chunk, k_inv_rms), + pl.slice(k_norm_weight, [1, HEAD_DIM], [0, 0]), + ) + k_proj_norm = pl.assemble(k_proj_norm, k_chunk_norm, [b0, k0]) + + # Scope 2: RoPE + KV cache update + grouped decode attention. + attn_out = pl.create_tensor([BATCH, HIDDEN], dtype=pl.BF16) + all_q_padded = pl.create_tensor( + [BATCH * TOTAL_Q_GROUPS * Q_HEAD_PAD, HEAD_DIM], dtype=pl.BF16, + ) + + for b in pl.parallel(user_batch): + ctx_len = pl.tensor.read(seq_lens, [b]) + pos = ctx_len - 1 + ctx_blocks = (ctx_len + BLOCK_SIZE - 1) // BLOCK_SIZE + block_table_base = b * MAX_BLOCKS_PER_SEQ + slot = pl.tensor.read(slot_mapping, [b]) + slot_block = slot // BLOCK_SIZE + slot_offset = slot - slot_block * BLOCK_SIZE + cos_row = pl.slice(rope_cos, [1, HEAD_DIM], [pos, 0]) + sin_row = pl.slice(rope_sin, [1, HEAD_DIM], [pos, 0]) + cos_lo = pl.slice(cos_row, [1, HALF_DIM], [0, 0]) + cos_hi = pl.slice(cos_row, [1, HALF_DIM], [0, HALF_DIM]) + sin_lo = pl.slice(sin_row, [1, HALF_DIM], [0, 0]) + sin_hi = pl.slice(sin_row, [1, HALF_DIM], [0, HALF_DIM]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rope_kv_cache"): + for ki in pl.range(NUM_KV_HEADS): + kv_col = ki * HEAD_DIM + cache_row = (slot_block * NUM_KV_HEADS + ki) * BLOCK_SIZE + slot_offset + k_lo = pl.slice(k_proj_norm, [1, HALF_DIM], [b, kv_col]) + k_hi = pl.slice(k_proj_norm, [1, HALF_DIM], [b, kv_col + HALF_DIM]) + rot_lo = pl.sub( + pl.col_expand_mul(k_lo, cos_lo), + pl.col_expand_mul(k_hi, sin_lo), + ) + rot_hi = pl.add( + pl.col_expand_mul(k_hi, cos_hi), + pl.col_expand_mul(k_lo, sin_hi), + ) + k_cache = pl.assemble( + k_cache, + pl.cast(rot_lo, target_type=pl.BF16), + [cache_row, 0], + ) + k_cache = pl.assemble( + k_cache, + pl.cast(rot_hi, target_type=pl.BF16), + [cache_row, HALF_DIM], + ) + v_cache = pl.assemble( + v_cache, + pl.cast( + pl.slice(v_proj, [1, HEAD_DIM], [b, kv_col]), + target_type=pl.BF16, + ), + [cache_row, 0], + ) + q_base = ki * Q_PER_KV + q_block = pl.reshape( + pl.slice(q_proj_norm, [1, Q_HEAD_BATCH * HEAD_DIM], [b, q_base * HEAD_DIM]), + [Q_HEAD_BATCH, HEAD_DIM], + ) + q_lo = pl.slice(q_block, [Q_HEAD_BATCH, HALF_DIM], [0, 0]) + q_hi = pl.slice(q_block, [Q_HEAD_BATCH, HALF_DIM], [0, HALF_DIM]) + rot_lo_bf16 = pl.cast( + pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), + target_type=pl.BF16, + ) + rot_hi_bf16 = pl.cast( + pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), + target_type=pl.BF16, + ) + all_q_padded = pl.assemble( + all_q_padded, + rot_lo_bf16, + [b * TOTAL_Q_GROUPS * Q_HEAD_PAD + ki * Q_HEAD_PAD, 0], + ) + all_q_padded = pl.assemble( + all_q_padded, + rot_hi_bf16, + [b * TOTAL_Q_GROUPS * Q_HEAD_PAD + ki * Q_HEAD_PAD, HALF_DIM], + ) + all_q_padded = pl.assemble( + all_q_padded, + pl.cast( + pl.full([Q_HEAD_PAD - Q_HEAD_BATCH, HEAD_DIM], dtype=pl.FP32, value=0.0), + target_type=pl.BF16, + ), + [b * TOTAL_Q_GROUPS * Q_HEAD_PAD + ki * Q_HEAD_PAD + Q_HEAD_BATCH, 0], + ) + + attn_row = pl.create_tensor([1, HIDDEN], dtype=pl.BF16) + for gi in pl.parallel(0, TOTAL_Q_GROUPS, 2): + gi0 = gi + gi1 = gi + 1 + + kvh0 = gi0 // Q_GROUPS + qg0 = gi0 - kvh0 * Q_GROUPS + q_base0 = kvh0 * Q_PER_KV + qg0 * Q_HEAD_BATCH + q_padded_row0 = b * TOTAL_Q_GROUPS * Q_HEAD_PAD + gi0 * Q_HEAD_PAD + q_padded0 = pl.slice(all_q_padded, [Q_HEAD_PAD, HEAD_DIM], [q_padded_row0, 0]) + + kvh1 = gi1 // Q_GROUPS + qg1 = gi1 - kvh1 * Q_GROUPS + q_base1 = kvh1 * Q_PER_KV + qg1 * Q_HEAD_BATCH + q_padded_row1 = b * TOTAL_Q_GROUPS * Q_HEAD_PAD + gi1 * Q_HEAD_PAD + q_padded1 = pl.slice(all_q_padded, [Q_HEAD_PAD, HEAD_DIM], [q_padded_row1, 0]) + + all_raw_scores0 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.FP32) + all_raw_scores1 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qk_matmul"): + for sb in pl.range(ctx_blocks): + qk_block_table_idx = block_table_base + sb + qk_pbid = pl.cast(pl.tensor.read(block_table, [qk_block_table_idx]), pl.INDEX) + + qk_cache_row0 = (qk_pbid * NUM_KV_HEADS + kvh0) * BLOCK_SIZE + k_tile0 = pl.slice(k_cache, [BLOCK_SIZE, HEAD_DIM], [qk_cache_row0, 0]) + raw_scores0 = pl.matmul(q_padded0, k_tile0, b_trans=True, out_dtype=pl.FP32) + all_raw_scores0 = pl.assemble(all_raw_scores0, raw_scores0, [sb * Q_HEAD_PAD, 0]) + + qk_cache_row1 = (qk_pbid * NUM_KV_HEADS + kvh1) * BLOCK_SIZE + k_tile1 = pl.slice(k_cache, [BLOCK_SIZE, HEAD_DIM], [qk_cache_row1, 0]) + raw_scores1 = pl.matmul(q_padded1, k_tile1, b_trans=True, out_dtype=pl.FP32) + all_raw_scores1 = pl.assemble(all_raw_scores1, raw_scores1, [sb * Q_HEAD_PAD, 0]) + + all_exp_padded0 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.BF16) + all_exp_padded1 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, BLOCK_SIZE], dtype=pl.BF16) + all_cur_mi0 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, 1], dtype=pl.FP32) + all_cur_mi1 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, 1], dtype=pl.FP32) + all_cur_li0 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, 1], dtype=pl.FP32) + all_cur_li1 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, 1], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="softmax"): + for sb in pl.range(ctx_blocks): + s0 = sb * BLOCK_SIZE + valid_len = pl.min(BLOCK_SIZE, ctx_len - s0) + + scores_valid0 = pl.slice( + all_raw_scores0, + [Q_HEAD_PAD, BLOCK_SIZE], + [sb * Q_HEAD_PAD, 0], + valid_shape=[Q_HEAD_PAD, valid_len], + ) + scores_padded0 = pl.fillpad(scores_valid0, pad_value=pl.PadValue.min) + scores0 = pl.mul(scores_padded0, decode_attn_scale) + softmax_cur_mi0 = pl.row_max(scores0) + exp_scores0 = pl.exp(pl.row_expand_sub(scores0, softmax_cur_mi0)) + exp_scores_bf16_0 = pl.cast(exp_scores0, target_type=pl.BF16) + exp_scores_fp32_0 = pl.cast(exp_scores_bf16_0, target_type=pl.FP32) + softmax_cur_li0 = pl.row_sum(exp_scores_fp32_0) + all_exp_padded0 = pl.assemble(all_exp_padded0, exp_scores_bf16_0, [sb * Q_HEAD_PAD, 0]) + all_cur_mi0 = pl.assemble(all_cur_mi0, softmax_cur_mi0, [sb * Q_HEAD_PAD, 0]) + all_cur_li0 = pl.assemble(all_cur_li0, softmax_cur_li0, [sb * Q_HEAD_PAD, 0]) + + scores_valid1 = pl.slice( + all_raw_scores1, + [Q_HEAD_PAD, BLOCK_SIZE], + [sb * Q_HEAD_PAD, 0], + valid_shape=[Q_HEAD_PAD, valid_len], + ) + scores_padded1 = pl.fillpad(scores_valid1, pad_value=pl.PadValue.min) + scores1 = pl.mul(scores_padded1, decode_attn_scale) + softmax_cur_mi1 = pl.row_max(scores1) + exp_scores1 = pl.exp(pl.row_expand_sub(scores1, softmax_cur_mi1)) + exp_scores_bf16_1 = pl.cast(exp_scores1, target_type=pl.BF16) + exp_scores_fp32_1 = pl.cast(exp_scores_bf16_1, target_type=pl.FP32) + softmax_cur_li1 = pl.row_sum(exp_scores_fp32_1) + all_exp_padded1 = pl.assemble(all_exp_padded1, exp_scores_bf16_1, [sb * Q_HEAD_PAD, 0]) + all_cur_mi1 = pl.assemble(all_cur_mi1, softmax_cur_mi1, [sb * Q_HEAD_PAD, 0]) + all_cur_li1 = pl.assemble(all_cur_li1, softmax_cur_li1, [sb * Q_HEAD_PAD, 0]) + + all_oi_tmp0 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, HEAD_DIM], dtype=pl.FP32) + all_oi_tmp1 = pl.create_tensor([MAX_BLOCKS_PER_SEQ * Q_HEAD_PAD, HEAD_DIM], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="sv_matmul"): + for sb in pl.range(ctx_blocks): + sv_block_table_idx = block_table_base + sb + sv_pbid = pl.cast(pl.tensor.read(block_table, [sv_block_table_idx]), pl.INDEX) + + sv_cache_row0 = (sv_pbid * NUM_KV_HEADS + kvh0) * BLOCK_SIZE + exp_tile0 = pl.slice(all_exp_padded0, [Q_HEAD_PAD, BLOCK_SIZE], [sb * Q_HEAD_PAD, 0]) + v_tile0 = pl.slice(v_cache, [BLOCK_SIZE, HEAD_DIM], [sv_cache_row0, 0]) + oi_tmp0 = pl.matmul(exp_tile0, v_tile0, out_dtype=pl.FP32) + all_oi_tmp0 = pl.assemble(all_oi_tmp0, oi_tmp0, [sb * Q_HEAD_PAD, 0]) + + sv_cache_row1 = (sv_pbid * NUM_KV_HEADS + kvh1) * BLOCK_SIZE + exp_tile1 = pl.slice(all_exp_padded1, [Q_HEAD_PAD, BLOCK_SIZE], [sb * Q_HEAD_PAD, 0]) + v_tile1 = pl.slice(v_cache, [BLOCK_SIZE, HEAD_DIM], [sv_cache_row1, 0]) + oi_tmp1 = pl.matmul(exp_tile1, v_tile1, out_dtype=pl.FP32) + all_oi_tmp1 = pl.assemble(all_oi_tmp1, oi_tmp1, [sb * Q_HEAD_PAD, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="online_softmax"): + oi0 = pl.slice(all_oi_tmp0, [Q_HEAD_PAD, HEAD_DIM], [0, 0]) + mi0 = pl.slice(all_cur_mi0, [Q_HEAD_PAD, 1], [0, 0]) + li0 = pl.slice(all_cur_li0, [Q_HEAD_PAD, 1], [0, 0]) + oi1 = pl.slice(all_oi_tmp1, [Q_HEAD_PAD, HEAD_DIM], [0, 0]) + mi1 = pl.slice(all_cur_mi1, [Q_HEAD_PAD, 1], [0, 0]) + li1 = pl.slice(all_cur_li1, [Q_HEAD_PAD, 1], [0, 0]) + for sb in pl.range(1, ctx_blocks): + oi_tmp_valid0 = pl.slice(all_oi_tmp0, [Q_HEAD_PAD, HEAD_DIM], [sb * Q_HEAD_PAD, 0]) + online_cur_mi0 = pl.slice(all_cur_mi0, [Q_HEAD_PAD, 1], [sb * Q_HEAD_PAD, 0]) + online_cur_li0 = pl.slice(all_cur_li0, [Q_HEAD_PAD, 1], [sb * Q_HEAD_PAD, 0]) + mi_new0 = pl.maximum(mi0, online_cur_mi0) + alpha0 = pl.exp(pl.sub(mi0, mi_new0)) + beta0 = pl.exp(pl.sub(online_cur_mi0, mi_new0)) + li0 = pl.add(pl.mul(alpha0, li0), pl.mul(beta0, online_cur_li0)) + oi0 = pl.add(pl.row_expand_mul(oi0, alpha0), pl.row_expand_mul(oi_tmp_valid0, beta0)) + mi0 = mi_new0 + + oi_tmp_valid1 = pl.slice(all_oi_tmp1, [Q_HEAD_PAD, HEAD_DIM], [sb * Q_HEAD_PAD, 0]) + online_cur_mi1 = pl.slice(all_cur_mi1, [Q_HEAD_PAD, 1], [sb * Q_HEAD_PAD, 0]) + online_cur_li1 = pl.slice(all_cur_li1, [Q_HEAD_PAD, 1], [sb * Q_HEAD_PAD, 0]) + mi_new1 = pl.maximum(mi1, online_cur_mi1) + alpha1 = pl.exp(pl.sub(mi1, mi_new1)) + beta1 = pl.exp(pl.sub(online_cur_mi1, mi_new1)) + li1 = pl.add(pl.mul(alpha1, li1), pl.mul(beta1, online_cur_li1)) + oi1 = pl.add(pl.row_expand_mul(oi1, alpha1), pl.row_expand_mul(oi_tmp_valid1, beta1)) + mi1 = mi_new1 + + ctx0 = pl.row_expand_div(oi0, li0) + ctx_valid0 = pl.slice(ctx0, [Q_HEAD_BATCH, HEAD_DIM], [0, 0]) + ctx_flat_bf16_0 = pl.cast(pl.reshape(ctx_valid0, [1, Q_HEAD_BATCH * HEAD_DIM]), target_type=pl.BF16) + attn_row = pl.assemble(attn_row, ctx_flat_bf16_0, [0, q_base0 * HEAD_DIM]) + + ctx1 = pl.row_expand_div(oi1, li1) + ctx_valid1 = pl.slice(ctx1, [Q_HEAD_BATCH, HEAD_DIM], [0, 0]) + ctx_flat_bf16_1 = pl.cast(pl.reshape(ctx_valid1, [1, Q_HEAD_BATCH * HEAD_DIM]), target_type=pl.BF16) + attn_row = pl.assemble(attn_row, ctx_flat_bf16_1, [0, q_base1 * HEAD_DIM]) + + attn_out = pl.assemble(attn_out, attn_row, [b, 0]) + + # Scope 3: output projection + residual + post RMSNorm + MLP + residual. + for b0 in pl.parallel(0, batch_padded, BATCH_TILE): + resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.FP32) + + for ob in pl.range(decode_q_out_blocks): + o0 = ob * Q_OUT_CHUNK + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj"): + a_chunk_0 = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, 0]) + w_chunk_0 = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [0, o0]) + o_acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32) + for kb in pl.range(1, hidden_blocks): + k0 = kb * K_CHUNK + a_chunk = pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, k0]) + w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0]) + o_acc = pl.matmul_acc(o_acc, a_chunk, w_chunk) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="out_proj_residual"): + resid = pl.cast( + pl.slice(current_hidden, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]), + target_type=pl.FP32, + ) + resid_sum = pl.add(o_acc, resid) + resid1_tile = pl.assemble(resid1_tile, resid_sum, [0, o0]) + + post_norm_tile = pl.create_tensor([BATCH_TILE, HIDDEN], dtype=pl.BF16) + # SPMD post_rmsnorm (change 2/2) + with pl.spmd(RMSNORM_SPMD_CORES): + post_norm_tile = self.post_rmsnorm_kernel( + resid1_tile, BATCH_TILE, post_rms_weight, post_norm_tile, + ) + + mlp_tile = pl.create_tensor([BATCH_TILE, INTERMEDIATE], dtype=pl.BF16) + for ob in pl.range(decode_mlp_out_blocks): + mlp_o0 = ob * MLP_OUT_CHUNK + with pl.at(level=pl.Level.CORE_GROUP, name_hint="gate_proj"): + post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) + wg_0 = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [0, mlp_o0]) + gate_acc = pl.matmul(post_chunk_0, wg_0, out_dtype=pl.FP32) + for kb in pl.range(1, hidden_blocks): + gate_k0 = kb * K_CHUNK + gate_post_chunk = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, gate_k0]) + wg = pl.slice(w_gate, [K_CHUNK, MLP_OUT_CHUNK], [gate_k0, mlp_o0]) + gate_acc = pl.matmul_acc(gate_acc, gate_post_chunk, wg) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="up_proj"): + post_chunk_0 = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, 0]) + wu_0 = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [0, mlp_o0]) + up_acc = pl.matmul(post_chunk_0, wu_0, out_dtype=pl.FP32) + for kb in pl.range(1, hidden_blocks): + up_k0 = kb * K_CHUNK + up_post_chunk = pl.slice(post_norm_tile, [BATCH_TILE, K_CHUNK], [0, up_k0]) + wu = pl.slice(w_up, [K_CHUNK, MLP_OUT_CHUNK], [up_k0, mlp_o0]) + up_acc = pl.matmul_acc(up_acc, up_post_chunk, wu) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="silu"): + sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_acc)), 1.0)) + mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc) + mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) + mlp_tile = pl.assemble(mlp_tile, mlp_chunk_bf16, [0, mlp_o0]) + + for dob in pl.range(hidden_blocks): + d0 = dob * K_CHUNK + # FP32 GM scratch chunk used as the cube -> vec bridge. + fp32_chunk_gm = pl.create_tensor([BATCH_TILE, K_CHUNK], dtype=pl.FP32) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj"): + mlp_chunk_0 = pl.slice(mlp_tile, [BATCH_TILE, MLP_OUT_CHUNK], [0, 0]) + w_down_chunk_0 = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [0, d0]) + down_acc = pl.matmul(mlp_chunk_0, w_down_chunk_0, out_dtype=pl.FP32) + for ob in pl.range(1, decode_mlp_out_blocks): + down_o0 = ob * MLP_OUT_CHUNK + down_mlp_chunk_bf16 = pl.slice( + mlp_tile, + [BATCH_TILE, MLP_OUT_CHUNK], + [0, down_o0], + ) + w_down_chunk = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [down_o0, d0]) + down_acc = pl.matmul_acc(down_acc, down_mlp_chunk_bf16, w_down_chunk) + fp32_chunk_gm = pl.assemble(fp32_chunk_gm, down_acc, [0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="down_proj_residual"): + down_chunk_fp32 = pl.slice(fp32_chunk_gm, [BATCH_TILE, K_CHUNK], [0, 0]) + resid_chunk_fp32 = pl.slice(resid1_tile, [BATCH_TILE, K_CHUNK], [0, d0]) + out_chunk = pl.add(down_chunk_fp32, resid_chunk_fp32) + out_chunk_cast = pl.cast(out_chunk, target_type=pl.BF16) + next_hidden = pl.assemble(next_hidden, out_chunk_cast, [b0, d0]) + + final_normed = pl.create_tensor([BATCH, HIDDEN], dtype=pl.BF16) + for b0 in pl.parallel(0, BATCH, BATCH_TILE): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="final_rmsnorm"): + sq_sum = pl.full([1, BATCH_TILE], dtype=pl.FP32, value=0.0) + for kb in pl.range(HIDDEN // FINAL_RMS_K_CHUNK): + final_sq_k0 = kb * FINAL_RMS_K_CHUNK + final_sq_chunk = pl.cast( + pl.slice(next_hidden, [BATCH_TILE, FINAL_RMS_K_CHUNK], [b0, final_sq_k0]), + target_type=pl.FP32, + ) + sq_sum = pl.add( + sq_sum, + pl.reshape(pl.row_sum(pl.mul(final_sq_chunk, final_sq_chunk)), [1, BATCH_TILE]), + ) + inv_rms_final = pl.reshape( + pl.rsqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)), + [BATCH_TILE, 1], + ) + + for kb in pl.range(HIDDEN // FINAL_RMS_K_CHUNK): + final_norm_k0 = kb * FINAL_RMS_K_CHUNK + final_hidden_chunk = pl.cast( + pl.slice(next_hidden, [BATCH_TILE, FINAL_RMS_K_CHUNK], [b0, final_norm_k0]), + target_type=pl.FP32, + ) + final_gamma = pl.slice(final_norm_weight, [1, FINAL_RMS_K_CHUNK], [0, final_norm_k0]) + final_normed_chunk = pl.col_expand_mul( + pl.row_expand_mul(final_hidden_chunk, inv_rms_final), + final_gamma, + ) + final_normed = pl.assemble( + final_normed, + pl.cast(final_normed_chunk, target_type=pl.BF16), + [b0, final_norm_k0], + ) + + for b0 in pl.parallel(0, BATCH, BATCH_TILE): + lm_valid_rows = pl.min(BATCH_TILE, user_batch - b0) + for ob in pl.parallel(VOCAB // VOCAB_CHUNK): + lm_o0 = ob * VOCAB_CHUNK + lm_acc_gm = pl.create_tensor([BATCH_TILE, VOCAB_CHUNK], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="lm_head"): + lm_hidden_chunk = pl.slice(final_normed, [BATCH_TILE, LM_HEAD_K_CHUNK], [b0, 0]) + lm_weight_chunk = pl.slice(lm_head_weight, [VOCAB_CHUNK, LM_HEAD_K_CHUNK], [lm_o0, 0]) + lm_acc = pl.matmul(lm_hidden_chunk, lm_weight_chunk, out_dtype=pl.FP32, b_trans=True) + for kb in pl.range(1, HIDDEN // LM_HEAD_K_CHUNK): + lm_k0 = kb * LM_HEAD_K_CHUNK + lm_hidden_chunk = pl.slice(final_normed, [BATCH_TILE, LM_HEAD_K_CHUNK], [b0, lm_k0]) + lm_weight_chunk = pl.slice( + lm_head_weight, + [VOCAB_CHUNK, LM_HEAD_K_CHUNK], + [lm_o0, lm_k0], + ) + lm_acc = pl.matmul_acc(lm_acc, lm_hidden_chunk, lm_weight_chunk, b_trans=True) + lm_acc_gm = pl.assemble(lm_acc_gm, lm_acc, [0, 0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="lm_head_store"): + lm_acc_chunk = pl.slice(lm_acc_gm, [BATCH_TILE, VOCAB_CHUNK], [0, 0]) + lm_acc_trimmed = pl.slice( + lm_acc_chunk, + [BATCH_TILE, VOCAB_CHUNK], + [0, 0], + valid_shape=[lm_valid_rows, VOCAB_CHUNK], + ) + out = pl.assemble(out, lm_acc_trimmed, [b0, lm_o0]) + + return out + + return Qwen3Decode def build_tensor_specs( @@ -669,12 +720,6 @@ def build_tensor_specs( import torch from golden import TensorSpec - # Host allocates every batch-dependent tensor at the user-visible - # batch (no host pad / no host trim). The kernel internally rounds - # up to BATCH_TILE, zero-pads via valid_shape on input loads, and - # trims via vec-to-vec textract on the BF16 output. A single - # compiled program serves any batch <= host capacity (USER_BATCH_DYN - # / KV_CACHE_ROWS_DYN / BLOCK_TABLE_FLAT_DYN are pl.dynamic dims). hidden = num_heads * head_dim kv_hidden = num_kv_heads * head_dim inter = intermediate_size @@ -804,7 +849,7 @@ def init_lm_head_weight(): ] -def golden_decode_layer(tensors): +def golden_qwen3_decode(tensors): """PyTorch reference: scope1 (RMSNorm + projection), scope2 (attention), scope3 (output + MLP).""" import math @@ -989,27 +1034,23 @@ def tiled_lm_head(lhs, rhs_t, k_chunk, vocab_chunk): gate = tiled_matmul(normed_bf16, w_gate, K_CHUNK, MLP_OUT_CHUNK) up = tiled_matmul(normed_bf16, w_up, K_CHUNK, MLP_OUT_CHUNK) - mlp_bf16 = (gate * torch.sigmoid(gate) * up).bfloat16() - down = tiled_matmul(mlp_bf16, w_down, DOWN_MLP_CHUNK, DOWN_OUT_CHUNK) - final_hidden = (down + resid1).bfloat16() + silu = gate.float().sigmoid() * gate * up + down = tiled_matmul(silu.bfloat16(), w_down, MLP_OUT_CHUNK, DOWN_OUT_CHUNK) + output = (down + resid1).bfloat16() - variance = chunked_row_sq_sum(final_hidden.float(), FINAL_RMS_K_CHUNK) / hidden_size + variance = chunked_row_sq_sum(output.float(), FINAL_RMS_K_CHUNK) / hidden_size inv_rms = torch.rsqrt(variance + eps) - final_normed = (final_hidden.float() * inv_rms * final_norm_weight.float()).bfloat16() + final_normed = (output.float() * inv_rms * final_norm_weight).bfloat16() - tensors["out"][:] = tiled_lm_head( - final_normed, - lm_head_weight, - LM_HEAD_K_CHUNK, - VOCAB_CHUNK, - ) + logits = tiled_lm_head(final_normed, lm_head_weight, LM_HEAD_K_CHUNK, VOCAB_CHUNK) + tensors["out"][:] = logits if __name__ == "__main__": import argparse import sys - from golden import run_jit + from golden import run parser = argparse.ArgumentParser() parser.add_argument("-p", "--platform", type=str, default="a2a3", @@ -1044,18 +1085,19 @@ def tiled_lm_head(lhs, rhs_t, k_chunk, vocab_chunk): ) args = parser.parse_args() - result = run_jit( - fn=test_decode_layer, + result = run( + program=build_qwen3_decode_program(batch=args.batch), specs=build_tensor_specs(batch=args.batch, use_max_seq=args.max_seq), - golden_fn=golden_decode_layer, + golden_fn=golden_qwen3_decode, + rtol=3e-3, + atol=3e-3, + compile_cfg=dict(dump_passes=True), runtime_cfg=dict( platform=args.platform, device_id=args.device, enable_l2_swimlane=args.enable_l2_swimlane, enable_pmu=args.enable_pmu, ), - rtol=3e-3, - atol=3e-3, ) if not result.passed: if result.error: