diff --git a/examples/advanced/gemm_eltwise.py b/examples/advanced/gemm_eltwise.py index b2d536c6..6814e20d 100644 --- a/examples/advanced/gemm_eltwise.py +++ b/examples/advanced/gemm_eltwise.py @@ -11,7 +11,7 @@ output = matmul(attn_out, wo) + hidden_states Stage 0 (matmul: attn_out x wo) and Stage 1 (residual add) can be: - - Fused: single pl.at block with chunked_loop_optimizer (mix mode) + - Fused: single pl.at block with auto_chunk (mix mode) - Split: separate pl.at blocks for each stage (split mode) Input and hidden_states are BF16; wo is BF16; output is FP32. @@ -36,7 +36,7 @@ def build_gemm_eltwise_mix_program( batch_tile: int = BATCH_TILE, chunk: int = 4, ): - """Build fused matmul + elementwise program with chunked_loop_optimizer.""" + """Build fused matmul + elementwise program with auto_chunk.""" k_blocks = hidden // k_chunk n_blocks = hidden // n_chunk @@ -50,26 +50,27 @@ def gemm_eltwise( wo: pl.Tensor[[hidden, hidden], pl.BF16], resid: pl.Out[pl.Tensor[[batch, hidden], pl.FP32]], ) -> pl.Tensor[[batch, hidden], pl.FP32]: - with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.UP_DOWN)]): - for nb in pl.parallel(0, n_blocks, chunk=chunk): - n0 = nb * n_chunk - # First K-tile: initialize accumulator via matmul - a_chunk_0 = pl.slice(attn_out, [batch_tile, k_chunk], [0, 0]) - w_chunk_0 = pl.slice(wo, [k_chunk, n_chunk], [0, n0]) - acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32) - - # Remaining K-tiles: accumulate via matmul_acc - for kb in pl.range(1, k_blocks): - k0 = kb * k_chunk - a_chunk = pl.slice(attn_out, [batch_tile, k_chunk], [0, k0]) - w_chunk = pl.slice(wo, [k_chunk, n_chunk], [k0, n0]) - acc = pl.matmul_acc(acc, a_chunk, w_chunk) - - # Elementwise residual addition - hidden_chunk = pl.slice(hidden_states, [batch_tile, n_chunk], [0, n0]) - hidden_chunk_f32 = pl.cast(hidden_chunk, target_type=pl.FP32) - resid_sum = pl.add(acc, hidden_chunk_f32) - resid = pl.assemble(resid, resid_sum, [0, n0]) + for nb_chunk in pl.parallel(0, n_blocks, 1 * chunk): + with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.split(pl.SplitMode.UP_DOWN)]): + for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1): + n0 = nb * n_chunk + # First K-tile: initialize accumulator via matmul + a_chunk_0 = pl.slice(attn_out, [batch_tile, k_chunk], [0, 0]) + w_chunk_0 = pl.slice(wo, [k_chunk, n_chunk], [0, n0]) + acc = pl.matmul(a_chunk_0, w_chunk_0, out_dtype=pl.FP32) + + # Remaining K-tiles: accumulate via matmul_acc + for kb in pl.range(1, k_blocks): + k0 = kb * k_chunk + a_chunk = pl.slice(attn_out, [batch_tile, k_chunk], [0, k0]) + w_chunk = pl.slice(wo, [k_chunk, n_chunk], [k0, n0]) + acc = pl.matmul_acc(acc, a_chunk, w_chunk) + + # Elementwise residual addition + hidden_chunk = pl.slice(hidden_states, [batch_tile, n_chunk], [0, n0]) + hidden_chunk_f32 = pl.cast(hidden_chunk, target_type=pl.FP32) + resid_sum = pl.add(acc, hidden_chunk_f32) + resid = pl.assemble(resid, resid_sum, [0, n0]) return resid diff --git a/examples/beginner/hello_world.py b/examples/beginner/hello_world.py index 31f12e8b..7e82a6b2 100644 --- a/examples/beginner/hello_world.py +++ b/examples/beginner/hello_world.py @@ -38,11 +38,12 @@ def add_scalar( a: pl.Scalar[pl.FP32], y: pl.Out[pl.Tensor[[rows, cols], pl.FP32]], ) -> pl.Tensor[[rows, cols], pl.FP32]: - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for r in pl.parallel(0, rows, 1, chunk=row_chunk): - tile_x = pl.slice(x, [1, cols], [r, 0]) - tile_y = pl.add(tile_x, a) - y = pl.assemble(y, tile_y, [r, 0]) + for r_chunk in pl.parallel(0, rows, 1 * row_chunk): + with pl.at(level=pl.Level.CORE_GROUP): + for r in pl.range(r_chunk, r_chunk + 1 * row_chunk, 1): + tile_x = pl.slice(x, [1, cols], [r, 0]) + tile_y = pl.add(tile_x, a) + y = pl.assemble(y, tile_y, [r, 0]) return y diff --git a/examples/beginner/matmul.py b/examples/beginner/matmul.py index 18997a47..9ef8bdce 100644 --- a/examples/beginner/matmul.py +++ b/examples/beginner/matmul.py @@ -46,13 +46,15 @@ def matmul( b: pl.Tensor[[k, n], pl.FP32], c: pl.Out[pl.Tensor[[m, n], pl.FP32]], ) -> pl.Tensor[[m, n], pl.FP32]: - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for mb in pl.parallel(0, m, m_tile, chunk=m_chunk): - for nb in pl.parallel(0, n, n_tile, chunk=n_chunk): - tile_a = pl.slice(a, [m_tile, k], [mb, 0]) - tile_b = pl.slice(b, [k, n_tile], [0, nb]) - tile_c = pl.matmul(tile_a, tile_b) - c = pl.assemble(c, tile_c, [mb, nb]) + for mb_chunk in pl.parallel(0, m, m_tile * m_chunk): + for nb_chunk in pl.parallel(0, n, n_tile * n_chunk): + with pl.at(level=pl.Level.CORE_GROUP): + for mb in pl.range(mb_chunk, mb_chunk + m_tile * m_chunk, m_tile): + for nb in pl.range(nb_chunk, nb_chunk + n_tile * n_chunk, n_tile): + tile_a = pl.slice(a, [m_tile, k], [mb, 0]) + tile_b = pl.slice(b, [k, n_tile], [0, nb]) + tile_c = pl.matmul(tile_a, tile_b) + c = pl.assemble(c, tile_c, [mb, nb]) return c diff --git a/examples/intermediate/gemm.py b/examples/intermediate/gemm.py index 1a3fb8bc..ca5eaddf 100644 --- a/examples/intermediate/gemm.py +++ b/examples/intermediate/gemm.py @@ -51,22 +51,24 @@ def gemm( b: pl.Tensor[[k, n], pl.FP32], c: pl.Out[pl.Tensor[[m, n], pl.FP32]], ) -> pl.Tensor[[m, n], pl.FP32]: - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for mb in pl.parallel(0, m, m_tile, chunk=m_chunk): - for nb in pl.parallel(0, n, n_tile, chunk=n_chunk): - # First K-tile: initialize accumulator via matmul - tile_a = pl.slice(a, [m_tile, k_tile], [mb, 0]) - tile_b = pl.slice(b, [k_tile, n_tile], [0, nb]) - acc = pl.matmul(tile_a, tile_b) - - # Remaining K-tiles: accumulate via matmul_acc - for kb in pl.range(1, k_blocks): - k0 = kb * k_tile - tile_a_i = pl.slice(a, [m_tile, k_tile], [mb, k0]) - tile_b_i = pl.slice(b, [k_tile, n_tile], [k0, nb]) - acc = pl.matmul_acc(acc, tile_a_i, tile_b_i) - - c = pl.assemble(c, acc, [mb, nb]) + for mb_chunk in pl.parallel(0, m, m_tile * m_chunk): + for nb_chunk in pl.parallel(0, n, n_tile * n_chunk): + with pl.at(level=pl.Level.CORE_GROUP): + for mb in pl.range(mb_chunk, mb_chunk + m_tile * m_chunk, m_tile): + for nb in pl.range(nb_chunk, nb_chunk + n_tile * n_chunk, n_tile): + # First K-tile: initialize accumulator via matmul + tile_a = pl.slice(a, [m_tile, k_tile], [mb, 0]) + tile_b = pl.slice(b, [k_tile, n_tile], [0, nb]) + acc = pl.matmul(tile_a, tile_b) + + # Remaining K-tiles: accumulate via matmul_acc + for kb in pl.range(1, k_blocks): + k0 = kb * k_tile + tile_a_i = pl.slice(a, [m_tile, k_tile], [mb, k0]) + tile_b_i = pl.slice(b, [k_tile, n_tile], [k0, nb]) + acc = pl.matmul_acc(acc, tile_a_i, tile_b_i) + + c = pl.assemble(c, acc, [mb, nb]) return c diff --git a/examples/intermediate/layer_norm.py b/examples/intermediate/layer_norm.py index a2d56073..c89cf6bc 100644 --- a/examples/intermediate/layer_norm.py +++ b/examples/intermediate/layer_norm.py @@ -42,33 +42,34 @@ def layer_norm( beta: pl.Tensor[[1, hidden], pl.FP32], y: pl.Out[pl.Tensor[[rows, hidden], pl.FP32]], ) -> pl.Tensor[[rows, hidden], pl.FP32]: - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for r in pl.parallel(0, rows, row_chunk, chunk=1): - tile_x = pl.slice(x, [row_chunk, hidden], [r, 0]) - gamma_tile = pl.slice(gamma, [1, hidden], [0, 0]) - beta_tile = pl.slice(beta, [1, hidden], [0, 0]) - - # Step 1: row mean — pre-scale before row_sum, no reshape - mean = pl.row_sum(pl.mul(tile_x, hidden_inv)) - - # Step 2: row variance + eps — pre-scale and pre-add - centred = pl.row_expand_sub(tile_x, mean) - var_eps = pl.row_sum( - pl.mul(pl.add(pl.mul(centred, centred), eps), hidden_inv) - ) - - # Step 3: normalise — single reshape pair for sqrt - std = pl.reshape( - pl.sqrt(pl.reshape(var_eps, [1, row_chunk])), - [row_chunk, 1], - ) - normed = pl.row_expand_div(centred, std) - - # Step 4: apply gamma scale and beta offset - scaled = pl.col_expand_mul(normed, gamma_tile) - ones = pl.add(pl.sub(tile_x, tile_x), 1.0) - result = pl.add(scaled, pl.col_expand_mul(ones, beta_tile)) - y = pl.assemble(y, result, [r, 0]) + for r_chunk in pl.parallel(0, rows, row_chunk * 1): + with pl.at(level=pl.Level.CORE_GROUP): + for r in pl.range(r_chunk, r_chunk + row_chunk * 1, row_chunk): + tile_x = pl.slice(x, [row_chunk, hidden], [r, 0]) + gamma_tile = pl.slice(gamma, [1, hidden], [0, 0]) + beta_tile = pl.slice(beta, [1, hidden], [0, 0]) + + # Step 1: row mean — pre-scale before row_sum, no reshape + mean = pl.row_sum(pl.mul(tile_x, hidden_inv)) + + # Step 2: row variance + eps — pre-scale and pre-add + centred = pl.row_expand_sub(tile_x, mean) + var_eps = pl.row_sum( + pl.mul(pl.add(pl.mul(centred, centred), eps), hidden_inv) + ) + + # Step 3: normalise — single reshape pair for sqrt + std = pl.reshape( + pl.sqrt(pl.reshape(var_eps, [1, row_chunk])), + [row_chunk, 1], + ) + normed = pl.row_expand_div(centred, std) + + # Step 4: apply gamma scale and beta offset + scaled = pl.col_expand_mul(normed, gamma_tile) + ones = pl.add(pl.sub(tile_x, tile_x), 1.0) + result = pl.add(scaled, pl.col_expand_mul(ones, beta_tile)) + y = pl.assemble(y, result, [r, 0]) return y diff --git a/examples/intermediate/rms_norm.py b/examples/intermediate/rms_norm.py index dc5ca28d..2a956577 100644 --- a/examples/intermediate/rms_norm.py +++ b/examples/intermediate/rms_norm.py @@ -48,32 +48,33 @@ def rms_norm( gamma: pl.Tensor[[1, hidden], pl.FP32], y: pl.Out[pl.Tensor[[rows, hidden], pl.FP32]], ) -> pl.Tensor[[rows, hidden], pl.FP32]: - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for r in pl.parallel(0, rows, row_chunk, chunk=1): - # Pass 1: accumulate sum(x^2) across hidden chunks - # row_sum produces [row_chunk, 1] col_major; scalar ops - # need row_major, so accumulate in [1, row_chunk] shape. - sq_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32) - sq_sum = pl.mul(sq_sum, 0.0) - for hb in pl.range(hidden_blocks): - h0 = hb * hidden_chunk - x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0]) - rs = pl.row_sum(pl.mul(x_chunk, x_chunk)) - sq_sum = pl.add(sq_sum, pl.reshape(rs, [1, row_chunk])) - - # inv_rms = 1 / sqrt(mean(x^2) + eps) - inv_rms_T = pl.rsqrt(pl.add(pl.mul(sq_sum, hidden_inv), eps)) - inv_rms = pl.reshape(inv_rms_T, [row_chunk, 1]) - - # Pass 2: normalise and apply gamma weight - for hb in pl.range(hidden_blocks): - h0 = hb * hidden_chunk - x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0]) - gamma_chunk = pl.slice(gamma, [1, hidden_chunk], [0, h0]) - normed = pl.col_expand_mul( - pl.row_expand_mul(x_chunk, inv_rms), gamma_chunk - ) - y = pl.assemble(y, normed, [r, h0]) + for r_chunk in pl.parallel(0, rows, row_chunk * 1): + with pl.at(level=pl.Level.CORE_GROUP): + for r in pl.range(r_chunk, r_chunk + row_chunk * 1, row_chunk): + # Pass 1: accumulate sum(x^2) across hidden chunks + # row_sum produces [row_chunk, 1] col_major; scalar ops + # need row_major, so accumulate in [1, row_chunk] shape. + sq_sum = pl.create_tensor([1, row_chunk], dtype=pl.FP32) + sq_sum = pl.mul(sq_sum, 0.0) + for hb in pl.range(hidden_blocks): + h0 = hb * hidden_chunk + x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0]) + rs = pl.row_sum(pl.mul(x_chunk, x_chunk)) + sq_sum = pl.add(sq_sum, pl.reshape(rs, [1, row_chunk])) + + # inv_rms = 1 / sqrt(mean(x^2) + eps) + inv_rms_T = pl.rsqrt(pl.add(pl.mul(sq_sum, hidden_inv), eps)) + inv_rms = pl.reshape(inv_rms_T, [row_chunk, 1]) + + # Pass 2: normalise and apply gamma weight + for hb in pl.range(hidden_blocks): + h0 = hb * hidden_chunk + x_chunk = pl.slice(x, [row_chunk, hidden_chunk], [r, h0]) + gamma_chunk = pl.slice(gamma, [1, hidden_chunk], [0, h0]) + normed = pl.col_expand_mul( + pl.row_expand_mul(x_chunk, inv_rms), gamma_chunk + ) + y = pl.assemble(y, normed, [r, h0]) return y diff --git a/examples/intermediate/rope.py b/examples/intermediate/rope.py index cb2ef82d..f4c5644b 100644 --- a/examples/intermediate/rope.py +++ b/examples/intermediate/rope.py @@ -57,29 +57,30 @@ def rope( sin: pl.Tensor[[1, head_dim], pl.FP32], y: pl.Out[pl.Tensor[[total_rows, head_dim], pl.FP32]], ) -> pl.Tensor[[total_rows, head_dim], pl.FP32]: - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for b in pl.parallel(0, batch, 1, chunk=batch_chunk): - # Slice cos/sin lo/hi halves directly from tensor - # so each becomes a separate tile.load (no textract). - cos_lo = pl.slice(cos, [1, half_dim], [0, 0]) - cos_hi = pl.slice(cos, [1, half_dim], [0, half_dim]) - sin_lo = pl.slice(sin, [1, half_dim], [0, 0]) - sin_hi = pl.slice(sin, [1, half_dim], [0, half_dim]) - - base = b * num_heads - x_lo = pl.slice(x, [num_heads, half_dim], [base, 0]) - x_hi = pl.slice(x, [num_heads, half_dim], [base, half_dim]) - - rot_lo = pl.sub( - pl.col_expand_mul(x_lo, cos_lo), - pl.col_expand_mul(x_hi, sin_lo), - ) - rot_hi = pl.add( - pl.col_expand_mul(x_hi, cos_hi), - pl.col_expand_mul(x_lo, sin_hi), - ) - y = pl.assemble(y, rot_lo, [base, 0]) - y = pl.assemble(y, rot_hi, [base, half_dim]) + for b_chunk in pl.parallel(0, batch, 1 * batch_chunk): + with pl.at(level=pl.Level.CORE_GROUP): + for b in pl.range(b_chunk, b_chunk + 1 * batch_chunk, 1): + # Slice cos/sin lo/hi halves directly from tensor + # so each becomes a separate tile.load (no textract). + cos_lo = pl.slice(cos, [1, half_dim], [0, 0]) + cos_hi = pl.slice(cos, [1, half_dim], [0, half_dim]) + sin_lo = pl.slice(sin, [1, half_dim], [0, 0]) + sin_hi = pl.slice(sin, [1, half_dim], [0, half_dim]) + + base = b * num_heads + x_lo = pl.slice(x, [num_heads, half_dim], [base, 0]) + x_hi = pl.slice(x, [num_heads, half_dim], [base, half_dim]) + + rot_lo = pl.sub( + pl.col_expand_mul(x_lo, cos_lo), + pl.col_expand_mul(x_hi, sin_lo), + ) + rot_hi = pl.add( + pl.col_expand_mul(x_hi, cos_hi), + pl.col_expand_mul(x_lo, sin_hi), + ) + y = pl.assemble(y, rot_lo, [base, 0]) + y = pl.assemble(y, rot_hi, [base, half_dim]) return y diff --git a/examples/intermediate/softmax.py b/examples/intermediate/softmax.py index 13429395..ee91d3c1 100644 --- a/examples/intermediate/softmax.py +++ b/examples/intermediate/softmax.py @@ -36,26 +36,27 @@ def softmax( x: pl.Tensor[[rows, cols], pl.FP32], y: pl.Out[pl.Tensor[[rows, cols], pl.FP32]], ) -> pl.Tensor[[rows, cols], pl.FP32]: - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for r in pl.parallel(0, rows, row_chunk, chunk=1): - tile_x = pl.slice(x, [row_chunk, cols], [r, 0]) + for r_chunk in pl.parallel(0, rows, row_chunk * 1): + with pl.at(level=pl.Level.CORE_GROUP): + for r in pl.range(r_chunk, r_chunk + row_chunk * 1, row_chunk): + tile_x = pl.slice(x, [row_chunk, cols], [r, 0]) - # Step 1: row-wise max for numerical stability - row_max = pl.row_max(tile_x) + # Step 1: row-wise max for numerical stability + row_max = pl.row_max(tile_x) - # Step 2: subtract row max: x - max(x) - shifted = pl.row_expand_sub(tile_x, row_max) + # Step 2: subtract row max: x - max(x) + shifted = pl.row_expand_sub(tile_x, row_max) - # Step 3: exp(x - max(x)) - exp_shifted = pl.exp(shifted) + # Step 3: exp(x - max(x)) + exp_shifted = pl.exp(shifted) - # Step 4: row-wise sum of exp values - row_sum = pl.row_sum(exp_shifted) + # Step 4: row-wise sum of exp values + row_sum = pl.row_sum(exp_shifted) - # Step 5: divide each row by its sum - result = pl.row_expand_div(exp_shifted, row_sum) + # Step 5: divide each row by its sum + result = pl.row_expand_div(exp_shifted, row_sum) - y = pl.assemble(y, result, [r, 0]) + y = pl.assemble(y, result, [r, 0]) return y diff --git a/models/deepseek/v3_2/deepseek_v3_2_decode_front.py b/models/deepseek/v3_2/deepseek_v3_2_decode_front.py index 82451b94..7b6a63a6 100644 --- a/models/deepseek/v3_2/deepseek_v3_2_decode_front.py +++ b/models/deepseek/v3_2/deepseek_v3_2_decode_front.py @@ -372,8 +372,8 @@ def deepseek_v3_2_decode_front_scope1234( # Stage 2.8: Reduce q_idx_full across heads with weights to get q_idx_out. q_idx_out = pl.create_tensor([BATCH, INDEX_HEAD_DIM], dtype=pl.BF16) weights_flat = pl.reshape(weights, [BATCH * INDEX_HEADS]) - with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk], name_hint="s2_q_reduce"): - for d0 in pl.parallel(0, INDEX_HEAD_DIM, QREDUCE_OUT_CHUNK): + for d0 in pl.parallel(0, INDEX_HEAD_DIM, QREDUCE_OUT_CHUNK): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="s2_q_reduce"): for b in pl.range(BATCH): s2_acc_b = pl.full([1, QREDUCE_OUT_CHUNK], dtype=pl.FP32, value=0.0) for h in pl.range(INDEX_HEADS): diff --git a/models/deepseek/v3_2/deepseek_v3_2_prefill_front_draft.py b/models/deepseek/v3_2/deepseek_v3_2_prefill_front_draft.py index 9fa27852..d61dd223 100644 --- a/models/deepseek/v3_2/deepseek_v3_2_prefill_front_draft.py +++ b/models/deepseek/v3_2/deepseek_v3_2_prefill_front_draft.py @@ -139,306 +139,306 @@ def deepseek_v3_2_prefill_front_layer( w_latent_to_v: pl.Tensor[[NUM_HEADS_CFG, KV_LORA_RANK_CFG, V_HEAD_DIM_CFG], pl.BF16], dispatch_buf: pl.Tensor[[EP_NODES_CFG, BATCH_CFG, MAX_SEQ_CFG, ATTN_OUT_CFG], pl.BF16], ) -> pl.Tensor[[EP_NODES_CFG, BATCH_CFG, MAX_SEQ_CFG, ATTN_OUT_CFG], pl.BF16]: - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - layer_id = pl.tensor.read(layer_id_t, [0]) - - for b in pl.parallel(0, BATCH_CFG, 1, chunk=4): - seq_len_b = pl.tensor.read(seq_lens, [b]) - tok_blocks = (seq_len_b + TOK_TILE - 1) // TOK_TILE - for p0_idx in pl.range(tok_blocks): - p0 = p0_idx * TOK_TILE - valid_tok = pl.min(TOK_TILE, seq_len_b - p0) - - # Scope 1: RMSNorm + Q/K/V projections. - sq_sum = pl.create_tensor([TOK_TILE, 1], dtype=pl.FP32) - sq_sum = pl.mul(sq_sum, 0.0) - # Keep an explicit local Vec pad tensor alive in this - # scope so AllocateMemoryAddr can reflect high occupancy. - usage_pad = pl.create_tensor([TOK_TILE, LOCAL_PAD_WIDTH], dtype=pl.BF16, valid_shape=[valid_tok, LOCAL_PAD_WIDTH]) - usage_pad = pl.mul(usage_pad, 0.0) - usage_pad_fp = pl.cast(usage_pad, target_type=pl.FP32) - usage_pad_sum = pl.row_sum(usage_pad_fp) - for kb in pl.range(HIDDEN_BLOCKS): - k0 = kb * K_CHUNK - x_chunk = pl.cast( - pl.slice(hidden_states, [TOK_TILE, K_CHUNK], [b, p0, k0], valid_shape=[valid_tok, K_CHUNK]), - target_type=pl.FP32, - ) - sq_sum = pl.add(sq_sum, pl.row_sum(pl.mul(x_chunk, x_chunk))) - inv_rms = pl.rsqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)) - inv_rms = pl.add(inv_rms, pl.mul(usage_pad_sum, 0.0)) - - q_proj_tile = pl.create_tensor([TOK_TILE, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], dtype=pl.BF16, valid_shape=[valid_tok, NUM_HEADS_CFG * QK_HEAD_DIM_CFG]) - kv_a_tile = pl.create_tensor([TOK_TILE, KV_A_OUT], dtype=pl.BF16, valid_shape=[valid_tok, KV_A_OUT]) - - # Fused Q path (local fusion trial for former incore_0/1): - # directly accumulates q_proj_tile from x -> wq_a -> q_norm -> wq_b - # without materializing full qr_tile. - for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8): - q0 = ob * Q_OUT_CHUNK - q_acc = pl.create_tensor([TOK_TILE, Q_OUT_CHUNK], dtype=pl.FP32) - q_acc = pl.mul(q_acc, 0.0) + for b_chunk in pl.parallel(0, BATCH_CFG, 4): + with pl.at(level=pl.Level.CORE_GROUP): + layer_id = pl.tensor.read(layer_id_t, [0]) + for b in pl.range(b_chunk, b_chunk + 4): + seq_len_b = pl.tensor.read(seq_lens, [b]) + tok_blocks = (seq_len_b + TOK_TILE - 1) // TOK_TILE + for p0_idx in pl.range(tok_blocks): + p0 = p0_idx * TOK_TILE + valid_tok = pl.min(TOK_TILE, seq_len_b - p0) + + # Scope 1: RMSNorm + Q/K/V projections. + sq_sum = pl.create_tensor([TOK_TILE, 1], dtype=pl.FP32) + sq_sum = pl.mul(sq_sum, 0.0) + # Keep an explicit local Vec pad tensor alive in this + # scope so AllocateMemoryAddr can reflect high occupancy. + usage_pad = pl.create_tensor([TOK_TILE, LOCAL_PAD_WIDTH], dtype=pl.BF16, valid_shape=[valid_tok, LOCAL_PAD_WIDTH]) + usage_pad = pl.mul(usage_pad, 0.0) + usage_pad_fp = pl.cast(usage_pad, target_type=pl.FP32) + usage_pad_sum = pl.row_sum(usage_pad_fp) for kb in pl.range(HIDDEN_BLOCKS): k0 = kb * K_CHUNK x_chunk = pl.cast( pl.slice(hidden_states, [TOK_TILE, K_CHUNK], [b, p0, k0], valid_shape=[valid_tok, K_CHUNK]), target_type=pl.FP32, ) - gamma_in = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) - normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma_in) - for rb in pl.range(QR_BLOCKS): - r0 = rb * LORA_CHUNK - wq_a_chunk = pl.slice(wq_a, [K_CHUNK, LORA_CHUNK], [k0, r0]) - qr_part = pl.matmul(pl.cast(normed, target_type=pl.BF16), wq_a_chunk) - gamma_q = pl.slice(q_norm_weight, [1, LORA_CHUNK], [0, r0]) - qn_part = pl.col_expand_mul(qr_part, gamma_q) - wq_b_chunk = pl.slice(wq_b, [LORA_CHUNK, Q_OUT_CHUNK], [r0, q0]) - q_acc = pl.add(q_acc, pl.matmul(pl.cast(qn_part, target_type=pl.BF16), wq_b_chunk)) - q_proj_tile = pl.assemble(q_proj_tile, pl.cast(q_acc, target_type=pl.BF16), [0, q0]) - - for ob in pl.parallel(0, KV_A_BLOCKS, 1, chunk=8): - kv0 = ob * KV_OUT_CHUNK - kv_acc = pl.create_tensor([TOK_TILE, KV_OUT_CHUNK], dtype=pl.FP32) - kv_acc = pl.mul(kv_acc, 0.0) - for kb in pl.range(HIDDEN_BLOCKS): - k0 = kb * K_CHUNK - x_chunk = pl.cast( - pl.slice(hidden_states, [TOK_TILE, K_CHUNK], [b, p0, k0], valid_shape=[valid_tok, K_CHUNK]), + sq_sum = pl.add(sq_sum, pl.row_sum(pl.mul(x_chunk, x_chunk))) + inv_rms = pl.rsqrt(pl.add(pl.mul(sq_sum, HIDDEN_INV), EPS)) + inv_rms = pl.add(inv_rms, pl.mul(usage_pad_sum, 0.0)) + + q_proj_tile = pl.create_tensor([TOK_TILE, NUM_HEADS_CFG * QK_HEAD_DIM_CFG], dtype=pl.BF16, valid_shape=[valid_tok, NUM_HEADS_CFG * QK_HEAD_DIM_CFG]) + kv_a_tile = pl.create_tensor([TOK_TILE, KV_A_OUT], dtype=pl.BF16, valid_shape=[valid_tok, KV_A_OUT]) + + # Fused Q path (local fusion trial for former incore_0/1): + # directly accumulates q_proj_tile from x -> wq_a -> q_norm -> wq_b + # without materializing full qr_tile. + for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8): + q0 = ob * Q_OUT_CHUNK + q_acc = pl.create_tensor([TOK_TILE, Q_OUT_CHUNK], dtype=pl.FP32) + q_acc = pl.mul(q_acc, 0.0) + for kb in pl.range(HIDDEN_BLOCKS): + k0 = kb * K_CHUNK + x_chunk = pl.cast( + pl.slice(hidden_states, [TOK_TILE, K_CHUNK], [b, p0, k0], valid_shape=[valid_tok, K_CHUNK]), + target_type=pl.FP32, + ) + gamma_in = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) + normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma_in) + for rb in pl.range(QR_BLOCKS): + r0 = rb * LORA_CHUNK + wq_a_chunk = pl.slice(wq_a, [K_CHUNK, LORA_CHUNK], [k0, r0]) + qr_part = pl.matmul(pl.cast(normed, target_type=pl.BF16), wq_a_chunk) + gamma_q = pl.slice(q_norm_weight, [1, LORA_CHUNK], [0, r0]) + qn_part = pl.col_expand_mul(qr_part, gamma_q) + wq_b_chunk = pl.slice(wq_b, [LORA_CHUNK, Q_OUT_CHUNK], [r0, q0]) + q_acc = pl.add(q_acc, pl.matmul(pl.cast(qn_part, target_type=pl.BF16), wq_b_chunk)) + q_proj_tile = pl.assemble(q_proj_tile, pl.cast(q_acc, target_type=pl.BF16), [0, q0]) + + for ob in pl.parallel(0, KV_A_BLOCKS, 1, chunk=8): + kv0 = ob * KV_OUT_CHUNK + kv_acc = pl.create_tensor([TOK_TILE, KV_OUT_CHUNK], dtype=pl.FP32) + kv_acc = pl.mul(kv_acc, 0.0) + for kb in pl.range(HIDDEN_BLOCKS): + k0 = kb * K_CHUNK + x_chunk = pl.cast( + pl.slice(hidden_states, [TOK_TILE, K_CHUNK], [b, p0, k0], valid_shape=[valid_tok, K_CHUNK]), + target_type=pl.FP32, + ) + gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) + normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma) + wkv_chunk = pl.slice(wkv_a, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) + kv_acc = pl.add(kv_acc, pl.matmul(pl.cast(normed, target_type=pl.BF16), wkv_chunk)) + kv_a_tile = pl.assemble(kv_a_tile, pl.cast(kv_acc, target_type=pl.BF16), [0, kv0]) + + # Scope 2: RoPE + cache update + indexer topk + sparse attention. + # Fusion policy (aligned with decode_front): + # - Stage A/B/C all stay in ONE auto_incore scope. + # - A: per-token cache write + # - B1/B2: two-stage topk (block-local then global merge) + # - C: sparse attention consumes merged topk immediately + # This avoids materializing topk intermediates across kernel boundaries. + attn_tile = pl.create_tensor([TOK_TILE, ATTN_OUT_CFG], dtype=pl.FP32, valid_shape=[valid_tok, ATTN_OUT_CFG]) + attn_tile = pl.mul(attn_tile, 0.0) + for ti in pl.range(valid_tok): + pos = p0 + ti + ctx_len = pos + 1 + cos_row = pl.slice(rope_cos, [1, QK_ROPE_HEAD_DIM_CFG], [pos, 0]) + sin_row = pl.slice(rope_sin, [1, QK_ROPE_HEAD_DIM_CFG], [pos, 0]) + cos_lo = pl.slice(cos_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, 0]) + cos_hi = pl.slice(cos_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, QK_ROPE_HEAD_DIM_CFG // 2]) + sin_lo = pl.slice(sin_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, 0]) + sin_hi = pl.slice(sin_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, QK_ROPE_HEAD_DIM_CFG // 2]) + + cache_row = b * MAX_SEQ_CFG + pos + kv_row = pl.cast(pl.slice(kv_a_tile, [1, KV_LORA_RANK_CFG], [ti, 0]), target_type=pl.FP32) + kv_gamma = pl.slice(kv_norm_weight, [1, KV_LORA_RANK_CFG], [0, 0]) + kv_normed = pl.col_expand_mul(kv_row, kv_gamma) + pe_row = pl.cast( + pl.slice(kv_a_tile, [1, QK_ROPE_HEAD_DIM_CFG], [ti, KV_LORA_RANK_CFG]), target_type=pl.FP32, ) - gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) - normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms), gamma) - wkv_chunk = pl.slice(wkv_a, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) - kv_acc = pl.add(kv_acc, pl.matmul(pl.cast(normed, target_type=pl.BF16), wkv_chunk)) - kv_a_tile = pl.assemble(kv_a_tile, pl.cast(kv_acc, target_type=pl.BF16), [0, kv0]) - - # Scope 2: RoPE + cache update + indexer topk + sparse attention. - # Fusion policy (aligned with decode_front): - # - Stage A/B/C all stay in ONE auto_incore scope. - # - A: per-token cache write - # - B1/B2: two-stage topk (block-local then global merge) - # - C: sparse attention consumes merged topk immediately - # This avoids materializing topk intermediates across kernel boundaries. - attn_tile = pl.create_tensor([TOK_TILE, ATTN_OUT_CFG], dtype=pl.FP32, valid_shape=[valid_tok, ATTN_OUT_CFG]) - attn_tile = pl.mul(attn_tile, 0.0) - for ti in pl.range(valid_tok): - pos = p0 + ti - ctx_len = pos + 1 - cos_row = pl.slice(rope_cos, [1, QK_ROPE_HEAD_DIM_CFG], [pos, 0]) - sin_row = pl.slice(rope_sin, [1, QK_ROPE_HEAD_DIM_CFG], [pos, 0]) - cos_lo = pl.slice(cos_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, 0]) - cos_hi = pl.slice(cos_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, QK_ROPE_HEAD_DIM_CFG // 2]) - sin_lo = pl.slice(sin_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, 0]) - sin_hi = pl.slice(sin_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, QK_ROPE_HEAD_DIM_CFG // 2]) - - cache_row = b * MAX_SEQ_CFG + pos - kv_row = pl.cast(pl.slice(kv_a_tile, [1, KV_LORA_RANK_CFG], [ti, 0]), target_type=pl.FP32) - kv_gamma = pl.slice(kv_norm_weight, [1, KV_LORA_RANK_CFG], [0, 0]) - kv_normed = pl.col_expand_mul(kv_row, kv_gamma) - pe_row = pl.cast( - pl.slice(kv_a_tile, [1, QK_ROPE_HEAD_DIM_CFG], [ti, KV_LORA_RANK_CFG]), - target_type=pl.FP32, - ) - pe_lo = pl.slice(pe_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, 0]) - pe_hi = pl.slice(pe_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, QK_ROPE_HEAD_DIM_CFG // 2]) - pe_rot = pl.create_tensor([1, QK_ROPE_HEAD_DIM_CFG], dtype=pl.FP32) - pe_rot = pl.assemble(pe_rot, pl.sub(pl.col_expand_mul(pe_lo, cos_lo), pl.col_expand_mul(pe_hi, sin_lo)), [0, 0]) - pe_rot = pl.assemble(pe_rot, pl.add(pl.col_expand_mul(pe_hi, cos_hi), pl.col_expand_mul(pe_lo, sin_hi)), [0, QK_ROPE_HEAD_DIM_CFG // 2]) - kv_cache = pl.assemble(kv_cache, pl.cast(kv_normed, target_type=pl.BF16), [cache_row, 0]) - pe_cache = pl.assemble(pe_cache, pl.cast(pe_rot, target_type=pl.BF16), [cache_row, 0]) - - # Stage B1: block-local topk (2 blocks, each 2K candidates). - topk_vals = pl.create_tensor([1, INDEX_TOPK_CFG], dtype=pl.FP32) - topk_idx = pl.create_tensor([1, INDEX_TOPK_CFG], dtype=pl.INT32) - blk_topk_vals = pl.create_tensor([2, INDEX_TOPK_CFG], dtype=pl.FP32) - blk_topk_idx = pl.create_tensor([2, INDEX_TOPK_CFG], dtype=pl.INT32) - topk_vals = pl.mul(topk_vals, -3.402823e38) - topk_idx = pl.mul(topk_idx, 0) - blk_topk_vals = pl.mul(blk_topk_vals, -3.402823e38) - blk_topk_idx = pl.mul(blk_topk_idx, 0) - for kk in pl.range(INDEX_TOPK_CFG): - neg_one = pl.create_tensor([1, 1], dtype=pl.INT32) - neg_one = pl.mul(neg_one, 0) - neg_one = pl.add(neg_one, -1) - topk_idx = pl.assemble(topk_idx, neg_one, [0, kk]) - blk_topk_idx = pl.assemble(blk_topk_idx, neg_one, [0, kk]) - blk_topk_idx = pl.assemble(blk_topk_idx, neg_one, [1, kk]) - - q_col0 = 0 - q_nope0 = pl.cast( - pl.slice(q_proj_tile, [1, QK_NOPE_HEAD_DIM_CFG], [ti, q_col0]), - target_type=pl.FP32, - ) - q_pe0 = pl.cast( - pl.slice(q_proj_tile, [1, QK_ROPE_HEAD_DIM_CFG], [ti, q_col0 + QK_NOPE_HEAD_DIM_CFG]), - target_type=pl.FP32, - ) - q0_lo = pl.slice(q_pe0, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, 0]) - q0_hi = pl.slice(q_pe0, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, QK_ROPE_HEAD_DIM_CFG // 2]) - q0_rot = pl.create_tensor([1, QK_ROPE_HEAD_DIM_CFG], dtype=pl.FP32) - q0_rot = pl.assemble(q0_rot, pl.sub(pl.col_expand_mul(q0_lo, cos_lo), pl.col_expand_mul(q0_hi, sin_lo)), [0, 0]) - q0_rot = pl.assemble(q0_rot, pl.add(pl.col_expand_mul(q0_hi, cos_hi), pl.col_expand_mul(q0_lo, sin_hi)), [0, QK_ROPE_HEAD_DIM_CFG // 2]) - q0_nope_latent = pl.matmul( - pl.cast(q_nope0, target_type=pl.BF16), - pl.slice(w_q_nope_to_latent, [QK_NOPE_HEAD_DIM_CFG, KV_LORA_RANK_CFG], [0, 0, 0]), - ) - - sparse_k_gen = pl.min(INDEX_TOPK_CFG, ctx_len) - for blk in pl.range(2): - blk_start = blk * INDEX_TOPK_CFG - blk_end = pl.min(ctx_len, blk_start + INDEX_TOPK_CFG) - for ss in pl.range(INDEX_TOPK_CFG): - s = blk_start + ss - if s < blk_end: - cache_s = b * MAX_SEQ_CFG + s - kv_s = pl.cast(pl.slice(kv_cache, [1, KV_LORA_RANK_CFG], [cache_s, 0]), target_type=pl.FP32) - pe_s = pl.cast(pl.slice(pe_cache, [1, QK_ROPE_HEAD_DIM_CFG], [cache_s, 0]), target_type=pl.FP32) - score_nope = pl.row_sum(pl.mul(q0_nope_latent, kv_s)) - score_pe = pl.row_sum(pl.mul(q0_rot, pe_s)) - score_fp32 = pl.mul(pl.add(score_nope, score_pe), ATTN_SCALE) - score_fp8 = pl.cast(score_fp32, target_type=pl.FP8E4M3FN) - score_a5 = pl.cast(score_fp8, target_type=pl.FP32) - cur_score = pl.tensor.read(score_a5, [0, 0]) - - inserted = pl.create_tensor([1, 1], dtype=pl.INT32) - inserted = pl.mul(inserted, 0) - for kk in pl.range(sparse_k_gen): - ins = pl.tensor.read(inserted, [0, 0]) - kth_val = pl.tensor.read(blk_topk_vals, [blk, kk]) - if ins == 0: - if cur_score > kth_val: - for sh in pl.range(sparse_k_gen - 1, kk, -1): - prev_val = pl.tensor.read(blk_topk_vals, [blk, sh - 1]) - prev_idx = pl.tensor.read(blk_topk_idx, [blk, sh - 1]) - prev_val_t = pl.create_tensor([1, 1], dtype=pl.FP32) - prev_idx_t = pl.create_tensor([1, 1], dtype=pl.INT32) - prev_val_t = pl.mul(prev_val_t, 0.0) - prev_idx_t = pl.mul(prev_idx_t, 0) - prev_val_t = pl.add(prev_val_t, prev_val) - prev_idx_t = pl.add(prev_idx_t, prev_idx) - blk_topk_vals = pl.assemble(blk_topk_vals, prev_val_t, [blk, sh]) - blk_topk_idx = pl.assemble(blk_topk_idx, prev_idx_t, [blk, sh]) - cur_score_t = pl.create_tensor([1, 1], dtype=pl.FP32) - cur_index_t = pl.create_tensor([1, 1], dtype=pl.INT32) - one_t = pl.create_tensor([1, 1], dtype=pl.INT32) - cur_score_t = pl.mul(cur_score_t, 0.0) - cur_index_t = pl.mul(cur_index_t, 0) - one_t = pl.mul(one_t, 0) - cur_score_t = pl.add(cur_score_t, cur_score) - cur_index_t = pl.add(cur_index_t, s) - one_t = pl.add(one_t, 1) - blk_topk_vals = pl.assemble(blk_topk_vals, cur_score_t, [blk, kk]) - blk_topk_idx = pl.assemble(blk_topk_idx, cur_index_t, [blk, kk]) - inserted = pl.assemble(inserted, one_t, [0, 0]) - - # Stage B2: global merge from 2x(local topk) -> final topk. - for blk in pl.range(2): - for kk in pl.range(sparse_k_gen): - cand_idx = pl.tensor.read(blk_topk_idx, [blk, kk]) - if cand_idx >= 0: - cand_val = pl.tensor.read(blk_topk_vals, [blk, kk]) - inserted = pl.create_tensor([1, 1], dtype=pl.INT32) - inserted = pl.mul(inserted, 0) - for tkk in pl.range(sparse_k_gen): - ins = pl.tensor.read(inserted, [0, 0]) - kth_val = pl.tensor.read(topk_vals, [0, tkk]) - if ins == 0: - if cand_val > kth_val: - for sh in pl.range(sparse_k_gen - 1, tkk, -1): - prev_val = pl.tensor.read(topk_vals, [0, sh - 1]) - prev_idx = pl.tensor.read(topk_idx, [0, sh - 1]) - prev_val_t = pl.create_tensor([1, 1], dtype=pl.FP32) - prev_idx_t = pl.create_tensor([1, 1], dtype=pl.INT32) - prev_val_t = pl.mul(prev_val_t, 0.0) - prev_idx_t = pl.mul(prev_idx_t, 0) - prev_val_t = pl.add(prev_val_t, prev_val) - prev_idx_t = pl.add(prev_idx_t, prev_idx) - topk_vals = pl.assemble(topk_vals, prev_val_t, [0, sh]) - topk_idx = pl.assemble(topk_idx, prev_idx_t, [0, sh]) - cand_val_t = pl.create_tensor([1, 1], dtype=pl.FP32) - cand_idx_t = pl.create_tensor([1, 1], dtype=pl.INT32) - one_t = pl.create_tensor([1, 1], dtype=pl.INT32) - cand_val_t = pl.mul(cand_val_t, 0.0) - cand_idx_t = pl.mul(cand_idx_t, 0) - one_t = pl.mul(one_t, 0) - cand_val_t = pl.add(cand_val_t, cand_val) - cand_idx_t = pl.add(cand_idx_t, cand_idx) - one_t = pl.add(one_t, 1) - topk_vals = pl.assemble(topk_vals, cand_val_t, [0, tkk]) - topk_idx = pl.assemble(topk_idx, cand_idx_t, [0, tkk]) - inserted = pl.assemble(inserted, one_t, [0, 0]) - - # Stage C: sparse attention directly consumes merged topk_idx. - attn_row = pl.create_tensor([1, ATTN_OUT_CFG], dtype=pl.FP32) - attn_row = pl.mul(attn_row, 0.0) - for h in pl.parallel(0, NUM_HEADS_CFG, 1, chunk=8): - q_col = h * QK_HEAD_DIM_CFG - q_nope = pl.cast( - pl.slice(q_proj_tile, [1, QK_NOPE_HEAD_DIM_CFG], [ti, q_col]), + pe_lo = pl.slice(pe_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, 0]) + pe_hi = pl.slice(pe_row, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, QK_ROPE_HEAD_DIM_CFG // 2]) + pe_rot = pl.create_tensor([1, QK_ROPE_HEAD_DIM_CFG], dtype=pl.FP32) + pe_rot = pl.assemble(pe_rot, pl.sub(pl.col_expand_mul(pe_lo, cos_lo), pl.col_expand_mul(pe_hi, sin_lo)), [0, 0]) + pe_rot = pl.assemble(pe_rot, pl.add(pl.col_expand_mul(pe_hi, cos_hi), pl.col_expand_mul(pe_lo, sin_hi)), [0, QK_ROPE_HEAD_DIM_CFG // 2]) + kv_cache = pl.assemble(kv_cache, pl.cast(kv_normed, target_type=pl.BF16), [cache_row, 0]) + pe_cache = pl.assemble(pe_cache, pl.cast(pe_rot, target_type=pl.BF16), [cache_row, 0]) + + # Stage B1: block-local topk (2 blocks, each 2K candidates). + topk_vals = pl.create_tensor([1, INDEX_TOPK_CFG], dtype=pl.FP32) + topk_idx = pl.create_tensor([1, INDEX_TOPK_CFG], dtype=pl.INT32) + blk_topk_vals = pl.create_tensor([2, INDEX_TOPK_CFG], dtype=pl.FP32) + blk_topk_idx = pl.create_tensor([2, INDEX_TOPK_CFG], dtype=pl.INT32) + topk_vals = pl.mul(topk_vals, -3.402823e38) + topk_idx = pl.mul(topk_idx, 0) + blk_topk_vals = pl.mul(blk_topk_vals, -3.402823e38) + blk_topk_idx = pl.mul(blk_topk_idx, 0) + for kk in pl.range(INDEX_TOPK_CFG): + neg_one = pl.create_tensor([1, 1], dtype=pl.INT32) + neg_one = pl.mul(neg_one, 0) + neg_one = pl.add(neg_one, -1) + topk_idx = pl.assemble(topk_idx, neg_one, [0, kk]) + blk_topk_idx = pl.assemble(blk_topk_idx, neg_one, [0, kk]) + blk_topk_idx = pl.assemble(blk_topk_idx, neg_one, [1, kk]) + + q_col0 = 0 + q_nope0 = pl.cast( + pl.slice(q_proj_tile, [1, QK_NOPE_HEAD_DIM_CFG], [ti, q_col0]), target_type=pl.FP32, ) - q_pe = pl.cast( - pl.slice(q_proj_tile, [1, QK_ROPE_HEAD_DIM_CFG], [ti, q_col + QK_NOPE_HEAD_DIM_CFG]), + q_pe0 = pl.cast( + pl.slice(q_proj_tile, [1, QK_ROPE_HEAD_DIM_CFG], [ti, q_col0 + QK_NOPE_HEAD_DIM_CFG]), target_type=pl.FP32, ) - q_lo = pl.slice(q_pe, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, 0]) - q_hi = pl.slice(q_pe, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, QK_ROPE_HEAD_DIM_CFG // 2]) - q_rot = pl.create_tensor([1, QK_ROPE_HEAD_DIM_CFG], dtype=pl.FP32) - q_rot = pl.assemble(q_rot, pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), [0, 0]) - q_rot = pl.assemble(q_rot, pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), [0, QK_ROPE_HEAD_DIM_CFG // 2]) - q_nope_latent = pl.matmul( - pl.cast(q_nope, target_type=pl.BF16), - pl.slice(w_q_nope_to_latent, [QK_NOPE_HEAD_DIM_CFG, KV_LORA_RANK_CFG], [h, 0, 0]), + q0_lo = pl.slice(q_pe0, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, 0]) + q0_hi = pl.slice(q_pe0, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, QK_ROPE_HEAD_DIM_CFG // 2]) + q0_rot = pl.create_tensor([1, QK_ROPE_HEAD_DIM_CFG], dtype=pl.FP32) + q0_rot = pl.assemble(q0_rot, pl.sub(pl.col_expand_mul(q0_lo, cos_lo), pl.col_expand_mul(q0_hi, sin_lo)), [0, 0]) + q0_rot = pl.assemble(q0_rot, pl.add(pl.col_expand_mul(q0_hi, cos_hi), pl.col_expand_mul(q0_lo, sin_hi)), [0, QK_ROPE_HEAD_DIM_CFG // 2]) + q0_nope_latent = pl.matmul( + pl.cast(q_nope0, target_type=pl.BF16), + pl.slice(w_q_nope_to_latent, [QK_NOPE_HEAD_DIM_CFG, KV_LORA_RANK_CFG], [0, 0, 0]), ) - oi = pl.create_tensor([1, KV_LORA_RANK_CFG], dtype=pl.FP32) - li = pl.create_tensor([1, 1], dtype=pl.FP32) - mi = pl.create_tensor([1, 1], dtype=pl.FP32) - oi = pl.mul(oi, 0.0) - li = pl.mul(li, 0.0) - mi = pl.mul(mi, 0.0) - sparse_k = pl.min(INDEX_TOPK_CFG, ctx_len) - for kk in pl.range(sparse_k): - s = pl.tensor.read(topk_idx, [0, kk]) - if s >= 0: - cache_s = b * MAX_SEQ_CFG + s - kv_s = pl.cast(pl.slice(kv_cache, [1, KV_LORA_RANK_CFG], [cache_s, 0]), target_type=pl.FP32) - pe_s = pl.cast(pl.slice(pe_cache, [1, QK_ROPE_HEAD_DIM_CFG], [cache_s, 0]), target_type=pl.FP32) - score_nope = pl.row_sum(pl.mul(q_nope_latent, kv_s)) - score_pe = pl.row_sum(pl.mul(q_rot, pe_s)) - score = pl.mul(pl.add(score_nope, score_pe), ATTN_SCALE) - cur_mi = score - cur_li = pl.exp(pl.sub(score, cur_mi)) - oi_tmp = pl.row_expand_mul(kv_s, cur_li) - if kk == 0: - oi = oi_tmp - li = cur_li - mi = cur_mi - else: - mi_new = pl.maximum(mi, cur_mi) - alpha = pl.exp(pl.sub(mi, mi_new)) - beta = pl.exp(pl.sub(cur_mi, mi_new)) - li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li)) - oi = pl.add(pl.row_expand_mul(oi, alpha), pl.row_expand_mul(oi_tmp, beta)) - mi = mi_new - ctx_latent = pl.row_expand_div(oi, li) - v_col = h * V_HEAD_DIM_CFG - ctx_v = pl.create_tensor([1, V_HEAD_DIM_CFG], dtype=pl.FP32) - ctx_v = pl.mul(ctx_v, 0.0) - for vb in pl.range(V_OUT_BLOCKS): - v0 = vb * V_OUT_CHUNK - wv_tile = pl.slice(w_latent_to_v, [KV_LORA_RANK_CFG, V_OUT_CHUNK], [h, 0, v0]) - v_part = pl.matmul(pl.cast(ctx_latent, target_type=pl.BF16), wv_tile, out_dtype=pl.FP32) - ctx_v = pl.assemble(ctx_v, v_part, [0, v0]) - attn_row = pl.assemble(attn_row, ctx_v, [0, v_col]) - attn_tile = pl.assemble(attn_tile, attn_row, [ti, 0]) - - # Scope 3: dispatch writes and return after dispatch. - for ti in pl.range(valid_tok): - pos = p0 + ti - target_node = (b + pos + layer_id) % EP_NODES_CFG - token_row = pl.cast(pl.slice(attn_tile, [1, ATTN_OUT_CFG], [ti, 0]), target_type=pl.BF16) - dispatch_buf = pl.assemble(dispatch_buf, token_row, [target_node, b, pos, 0]) + sparse_k_gen = pl.min(INDEX_TOPK_CFG, ctx_len) + for blk in pl.range(2): + blk_start = blk * INDEX_TOPK_CFG + blk_end = pl.min(ctx_len, blk_start + INDEX_TOPK_CFG) + for ss in pl.range(INDEX_TOPK_CFG): + s = blk_start + ss + if s < blk_end: + cache_s = b * MAX_SEQ_CFG + s + kv_s = pl.cast(pl.slice(kv_cache, [1, KV_LORA_RANK_CFG], [cache_s, 0]), target_type=pl.FP32) + pe_s = pl.cast(pl.slice(pe_cache, [1, QK_ROPE_HEAD_DIM_CFG], [cache_s, 0]), target_type=pl.FP32) + score_nope = pl.row_sum(pl.mul(q0_nope_latent, kv_s)) + score_pe = pl.row_sum(pl.mul(q0_rot, pe_s)) + score_fp32 = pl.mul(pl.add(score_nope, score_pe), ATTN_SCALE) + score_fp8 = pl.cast(score_fp32, target_type=pl.FP8E4M3FN) + score_a5 = pl.cast(score_fp8, target_type=pl.FP32) + cur_score = pl.tensor.read(score_a5, [0, 0]) + + inserted = pl.create_tensor([1, 1], dtype=pl.INT32) + inserted = pl.mul(inserted, 0) + for kk in pl.range(sparse_k_gen): + ins = pl.tensor.read(inserted, [0, 0]) + kth_val = pl.tensor.read(blk_topk_vals, [blk, kk]) + if ins == 0: + if cur_score > kth_val: + for sh in pl.range(sparse_k_gen - 1, kk, -1): + prev_val = pl.tensor.read(blk_topk_vals, [blk, sh - 1]) + prev_idx = pl.tensor.read(blk_topk_idx, [blk, sh - 1]) + prev_val_t = pl.create_tensor([1, 1], dtype=pl.FP32) + prev_idx_t = pl.create_tensor([1, 1], dtype=pl.INT32) + prev_val_t = pl.mul(prev_val_t, 0.0) + prev_idx_t = pl.mul(prev_idx_t, 0) + prev_val_t = pl.add(prev_val_t, prev_val) + prev_idx_t = pl.add(prev_idx_t, prev_idx) + blk_topk_vals = pl.assemble(blk_topk_vals, prev_val_t, [blk, sh]) + blk_topk_idx = pl.assemble(blk_topk_idx, prev_idx_t, [blk, sh]) + cur_score_t = pl.create_tensor([1, 1], dtype=pl.FP32) + cur_index_t = pl.create_tensor([1, 1], dtype=pl.INT32) + one_t = pl.create_tensor([1, 1], dtype=pl.INT32) + cur_score_t = pl.mul(cur_score_t, 0.0) + cur_index_t = pl.mul(cur_index_t, 0) + one_t = pl.mul(one_t, 0) + cur_score_t = pl.add(cur_score_t, cur_score) + cur_index_t = pl.add(cur_index_t, s) + one_t = pl.add(one_t, 1) + blk_topk_vals = pl.assemble(blk_topk_vals, cur_score_t, [blk, kk]) + blk_topk_idx = pl.assemble(blk_topk_idx, cur_index_t, [blk, kk]) + inserted = pl.assemble(inserted, one_t, [0, 0]) + + # Stage B2: global merge from 2x(local topk) -> final topk. + for blk in pl.range(2): + for kk in pl.range(sparse_k_gen): + cand_idx = pl.tensor.read(blk_topk_idx, [blk, kk]) + if cand_idx >= 0: + cand_val = pl.tensor.read(blk_topk_vals, [blk, kk]) + inserted = pl.create_tensor([1, 1], dtype=pl.INT32) + inserted = pl.mul(inserted, 0) + for tkk in pl.range(sparse_k_gen): + ins = pl.tensor.read(inserted, [0, 0]) + kth_val = pl.tensor.read(topk_vals, [0, tkk]) + if ins == 0: + if cand_val > kth_val: + for sh in pl.range(sparse_k_gen - 1, tkk, -1): + prev_val = pl.tensor.read(topk_vals, [0, sh - 1]) + prev_idx = pl.tensor.read(topk_idx, [0, sh - 1]) + prev_val_t = pl.create_tensor([1, 1], dtype=pl.FP32) + prev_idx_t = pl.create_tensor([1, 1], dtype=pl.INT32) + prev_val_t = pl.mul(prev_val_t, 0.0) + prev_idx_t = pl.mul(prev_idx_t, 0) + prev_val_t = pl.add(prev_val_t, prev_val) + prev_idx_t = pl.add(prev_idx_t, prev_idx) + topk_vals = pl.assemble(topk_vals, prev_val_t, [0, sh]) + topk_idx = pl.assemble(topk_idx, prev_idx_t, [0, sh]) + cand_val_t = pl.create_tensor([1, 1], dtype=pl.FP32) + cand_idx_t = pl.create_tensor([1, 1], dtype=pl.INT32) + one_t = pl.create_tensor([1, 1], dtype=pl.INT32) + cand_val_t = pl.mul(cand_val_t, 0.0) + cand_idx_t = pl.mul(cand_idx_t, 0) + one_t = pl.mul(one_t, 0) + cand_val_t = pl.add(cand_val_t, cand_val) + cand_idx_t = pl.add(cand_idx_t, cand_idx) + one_t = pl.add(one_t, 1) + topk_vals = pl.assemble(topk_vals, cand_val_t, [0, tkk]) + topk_idx = pl.assemble(topk_idx, cand_idx_t, [0, tkk]) + inserted = pl.assemble(inserted, one_t, [0, 0]) + + # Stage C: sparse attention directly consumes merged topk_idx. + attn_row = pl.create_tensor([1, ATTN_OUT_CFG], dtype=pl.FP32) + attn_row = pl.mul(attn_row, 0.0) + for h in pl.parallel(0, NUM_HEADS_CFG, 1, chunk=8): + q_col = h * QK_HEAD_DIM_CFG + q_nope = pl.cast( + pl.slice(q_proj_tile, [1, QK_NOPE_HEAD_DIM_CFG], [ti, q_col]), + target_type=pl.FP32, + ) + q_pe = pl.cast( + pl.slice(q_proj_tile, [1, QK_ROPE_HEAD_DIM_CFG], [ti, q_col + QK_NOPE_HEAD_DIM_CFG]), + target_type=pl.FP32, + ) + q_lo = pl.slice(q_pe, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, 0]) + q_hi = pl.slice(q_pe, [1, QK_ROPE_HEAD_DIM_CFG // 2], [0, QK_ROPE_HEAD_DIM_CFG // 2]) + q_rot = pl.create_tensor([1, QK_ROPE_HEAD_DIM_CFG], dtype=pl.FP32) + q_rot = pl.assemble(q_rot, pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), [0, 0]) + q_rot = pl.assemble(q_rot, pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), [0, QK_ROPE_HEAD_DIM_CFG // 2]) + q_nope_latent = pl.matmul( + pl.cast(q_nope, target_type=pl.BF16), + pl.slice(w_q_nope_to_latent, [QK_NOPE_HEAD_DIM_CFG, KV_LORA_RANK_CFG], [h, 0, 0]), + ) + + oi = pl.create_tensor([1, KV_LORA_RANK_CFG], dtype=pl.FP32) + li = pl.create_tensor([1, 1], dtype=pl.FP32) + mi = pl.create_tensor([1, 1], dtype=pl.FP32) + oi = pl.mul(oi, 0.0) + li = pl.mul(li, 0.0) + mi = pl.mul(mi, 0.0) + sparse_k = pl.min(INDEX_TOPK_CFG, ctx_len) + for kk in pl.range(sparse_k): + s = pl.tensor.read(topk_idx, [0, kk]) + if s >= 0: + cache_s = b * MAX_SEQ_CFG + s + kv_s = pl.cast(pl.slice(kv_cache, [1, KV_LORA_RANK_CFG], [cache_s, 0]), target_type=pl.FP32) + pe_s = pl.cast(pl.slice(pe_cache, [1, QK_ROPE_HEAD_DIM_CFG], [cache_s, 0]), target_type=pl.FP32) + score_nope = pl.row_sum(pl.mul(q_nope_latent, kv_s)) + score_pe = pl.row_sum(pl.mul(q_rot, pe_s)) + score = pl.mul(pl.add(score_nope, score_pe), ATTN_SCALE) + cur_mi = score + cur_li = pl.exp(pl.sub(score, cur_mi)) + oi_tmp = pl.row_expand_mul(kv_s, cur_li) + if kk == 0: + oi = oi_tmp + li = cur_li + mi = cur_mi + else: + mi_new = pl.maximum(mi, cur_mi) + alpha = pl.exp(pl.sub(mi, mi_new)) + beta = pl.exp(pl.sub(cur_mi, mi_new)) + li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li)) + oi = pl.add(pl.row_expand_mul(oi, alpha), pl.row_expand_mul(oi_tmp, beta)) + mi = mi_new + ctx_latent = pl.row_expand_div(oi, li) + v_col = h * V_HEAD_DIM_CFG + ctx_v = pl.create_tensor([1, V_HEAD_DIM_CFG], dtype=pl.FP32) + ctx_v = pl.mul(ctx_v, 0.0) + for vb in pl.range(V_OUT_BLOCKS): + v0 = vb * V_OUT_CHUNK + wv_tile = pl.slice(w_latent_to_v, [KV_LORA_RANK_CFG, V_OUT_CHUNK], [h, 0, v0]) + v_part = pl.matmul(pl.cast(ctx_latent, target_type=pl.BF16), wv_tile, out_dtype=pl.FP32) + ctx_v = pl.assemble(ctx_v, v_part, [0, v0]) + attn_row = pl.assemble(attn_row, ctx_v, [0, v_col]) + attn_tile = pl.assemble(attn_tile, attn_row, [ti, 0]) + + # Scope 3: dispatch writes and return after dispatch. + for ti in pl.range(valid_tok): + pos = p0 + ti + target_node = (b + pos + layer_id) % EP_NODES_CFG + token_row = pl.cast(pl.slice(attn_tile, [1, ATTN_OUT_CFG], [ti, 0]), target_type=pl.BF16) + dispatch_buf = pl.assemble(dispatch_buf, token_row, [target_node, b, pos, 0]) return dispatch_buf diff --git a/models/deepseek/v4/decode_attention_hca.py b/models/deepseek/v4/decode_attention_hca.py index 4541fe7a..25e25552 100644 --- a/models/deepseek/v4/decode_attention_hca.py +++ b/models/deepseek/v4/decode_attention_hca.py @@ -270,7 +270,7 @@ def attention_hca( # Chunk over T so [T, SPARSE_TOPK] INT32 (T=128 * 640 * 4 = 320KB) doesn't blow Vec budget. topk_idxs = pl.create_tensor([T, SPARSE_TOPK], dtype=pl.INT32) for t0 in pl.range(0, T, HCA_TOPK_CHUNK): - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="hca_topk"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="hca_topk"): win_idx = pl.arange(0, [1, WIN], dtype=pl.INT32) cmp_idx = pl.add( pl.arange(0, [1, CMP_TOPK], dtype=pl.INT32), diff --git a/models/deepseek/v4/decode_attention_swa.py b/models/deepseek/v4/decode_attention_swa.py index b915d2fa..c38b1268 100644 --- a/models/deepseek/v4/decode_attention_swa.py +++ b/models/deepseek/v4/decode_attention_swa.py @@ -160,21 +160,22 @@ def attention_swa( block_table_flat = pl.reshape(block_table, [B * MAX_BLOCKS]) # Per-batch per-token KV scatter: token s of batch b -> slot (start_pos + s) % WIN. for s_idx in pl.range(S): - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="swa_scatter_kv"): - ori_slot = (start_pos + s_idx) % WIN - for b in pl.parallel(0, B, 1, chunk=16): - blk_id = pl.cast(pl.read(block_table_flat, [b]), pl.INDEX) - dst_row = blk_id * BLOCK_SIZE + ori_slot - kv_cache_flat = pl.assemble( - kv_cache_flat, - kv[b * S + s_idx : b * S + s_idx + 1, 0:HEAD_DIM], - [dst_row, 0], - ) + ori_slot = (start_pos + s_idx) % WIN + for b_chunk in pl.parallel(0, B, 1 * 16): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="swa_scatter_kv"): + for b in pl.range(b_chunk, b_chunk + 1 * 16, 1): + blk_id = pl.cast(pl.read(block_table_flat, [b]), pl.INDEX) + dst_row = blk_id * BLOCK_SIZE + ori_slot + kv_cache_flat = pl.assemble( + kv_cache_flat, + kv[b * S + s_idx : b * S + s_idx + 1, 0:HEAD_DIM], + [dst_row, 0], + ) kv_cache = pl.reshape(kv_cache_flat, [BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM]) sparse_topk = pl.create_tensor([T, SPARSE_TOPK], dtype=pl.INT32) for b0 in pl.range(0, T, SWA_BATCH_CHUNK): - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="swa_topk"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="swa_topk"): idx_row = pl.arange(0, [1, WIN], dtype=pl.INT32) pad_row = pl.full([1, SPARSE_IDX_TOPK], dtype=pl.INT32, value=-1) sparse_topk_row = pl.concat(idx_row, pad_row) @@ -187,7 +188,7 @@ def attention_swa( cmp_kv_dummy = pl.create_tensor([SPARSE_CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], dtype=pl.BF16) cmp_block_table_dummy = pl.create_tensor([B, SPARSE_CMP_MAX_BLOCKS], dtype=pl.INT32) for b0 in pl.range(0, B, SWA_BATCH_CHUNK): - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="swa_cmp_dummy"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="swa_cmp_dummy"): cmp_block_table_dummy_tile = pl.full([SWA_BATCH_CHUNK, SPARSE_CMP_MAX_BLOCKS], dtype=pl.INT32, value=-1) cmp_block_table_dummy = pl.assemble(cmp_block_table_dummy, cmp_block_table_dummy_tile, [b0, 0]) diff --git a/models/deepseek/v4/decode_compressor_ratio128.py b/models/deepseek/v4/decode_compressor_ratio128.py index 65ebd81e..a7f272d6 100644 --- a/models/deepseek/v4/decode_compressor_ratio128.py +++ b/models/deepseek/v4/decode_compressor_ratio128.py @@ -238,16 +238,17 @@ def compressor( # Per-batch fan-out: write kv_final[b] to kv[b, 0, :] (row b*S of kv_flat). kv_cache_flat = pl.reshape(kv_cache, [B * IDX_KV_LEN, HEAD_DIM]) cache_col = start_pos // COMPRESS_RATIO - with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk], name_hint="kv_and_cache_write"): - for b_idx in pl.parallel(0, B, chunk=16): - kv_row_fp32 = kv_final[b_idx : b_idx + 1, 0 : HEAD_DIM] - kv_flat = pl.assemble(kv_flat, kv_row_fp32, [b_idx * S, 0]) - cache_row = b_idx * IDX_KV_LEN + cache_col - kv_cache_flat = pl.assemble( - kv_cache_flat, - pl.cast(kv_row_fp32, target_type=pl.BF16, mode="rint"), - [cache_row, 0], - ) + for b_chunk in pl.parallel(0, B, 1 * 16): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_and_cache_write"): + for b_idx in pl.range(b_chunk, b_chunk + 1 * 16, 1): + kv_row_fp32 = kv_final[b_idx : b_idx + 1, 0 : HEAD_DIM] + kv_flat = pl.assemble(kv_flat, kv_row_fp32, [b_idx * S, 0]) + cache_row = b_idx * IDX_KV_LEN + cache_col + kv_cache_flat = pl.assemble( + kv_cache_flat, + pl.cast(kv_row_fp32, target_type=pl.BF16, mode="rint"), + [cache_row, 0], + ) kv_cache = pl.reshape(kv_cache_flat, [B, IDX_KV_LEN, HEAD_DIM]) if pre_tokens < S: diff --git a/models/deepseek/v4/decode_indexer.py b/models/deepseek/v4/decode_indexer.py index cac4da75..5f456918 100644 --- a/models/deepseek/v4/decode_indexer.py +++ b/models/deepseek/v4/decode_indexer.py @@ -59,7 +59,7 @@ # Ops are per-row independent, so this just makes the tiles HEAD_GROUP-taller (no # GM intermediates, bit-identical numerics). Vec buffer scales with HEAD_ROWS; # qr_hadamard_quant Vec buffer caps single-loop GRP at 2: GRP=4 overflows even with -# chunked_loop_optimizer (199KB, INT8 store needs >=32-col chunk so can't shrink), +# auto_chunk (199KB, INT8 store needs >=32-col chunk so can't shrink), # and GRP=8 overflows L0C (qr_hadamard_acc [GRP*64,128] FP32 = 256KB). HEAD_GROUP = 2 if T >= 2 else 1 HEAD_ROWS = IDX_N_HEADS * HEAD_GROUP @@ -267,7 +267,7 @@ def indexer( cache0 = cb * CACHE_TILE valid_len = pl.min(CACHE_TILE, cache_len - cache0) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="score_quant"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="score_quant"): for bi in pl.range(SCORE_B_GROUP): b = bg + bi kv0 = b * IDX_KV_LEN diff --git a/models/deepseek/v4/hc_post.py b/models/deepseek/v4/hc_post.py index fe808cdd..fc7aa46e 100644 --- a/models/deepseek/v4/hc_post.py +++ b/models/deepseek/v4/hc_post.py @@ -43,31 +43,32 @@ def hc_post( y_flat = pl.reshape(y, [T, HC_DIM]) for out_h in pl.parallel(HC_MULT): - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="hc_post"): - for t in pl.parallel(0, T, 1, chunk=16): - post_w = pl.read(post_flat, [t * HC_MULT + out_h]) - for db in pl.range(D_BLOCKS): - d0 = db * D_CHUNK - x_row = pl.cast( - pl.slice(x_flat, [1, D_CHUNK], [t, d0]), - target_type=pl.FP32, - ) - y_row = pl.mul(x_row, post_w) - for in_h in pl.range(HC_MULT): - comb_w = pl.read( - comb_flat, - [t * HC_MULT * HC_MULT + in_h * HC_MULT + out_h], - ) - residual_row = pl.cast( - pl.slice(residual_flat, [1, D_CHUNK], [t, in_h * D + d0]), + for t_chunk in pl.parallel(0, T, 1 * 16): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="hc_post"): + for t in pl.range(t_chunk, t_chunk + 1 * 16, 1): + post_w = pl.read(post_flat, [t * HC_MULT + out_h]) + for db in pl.range(D_BLOCKS): + d0 = db * D_CHUNK + x_row = pl.cast( + pl.slice(x_flat, [1, D_CHUNK], [t, d0]), target_type=pl.FP32, ) - y_row = pl.add(y_row, pl.mul(residual_row, comb_w)) - y_flat = pl.assemble( - y_flat, - pl.cast(y_row, target_type=pl.BF16, mode="rint"), - [t, out_h * D + d0], - ) + y_row = pl.mul(x_row, post_w) + for in_h in pl.range(HC_MULT): + comb_w = pl.read( + comb_flat, + [t * HC_MULT * HC_MULT + in_h * HC_MULT + out_h], + ) + residual_row = pl.cast( + pl.slice(residual_flat, [1, D_CHUNK], [t, in_h * D + d0]), + target_type=pl.FP32, + ) + y_row = pl.add(y_row, pl.mul(residual_row, comb_w)) + y_flat = pl.assemble( + y_flat, + pl.cast(y_row, target_type=pl.BF16, mode="rint"), + [t, out_h * D + d0], + ) y = pl.reshape(y_flat, [B, S, HC_MULT, D]) return y diff --git a/models/deepseek/v4/qkv_proj_rope.py b/models/deepseek/v4/qkv_proj_rope.py index d38d33c1..fae4286e 100644 --- a/models/deepseek/v4/qkv_proj_rope.py +++ b/models/deepseek/v4/qkv_proj_rope.py @@ -99,7 +99,7 @@ def qkv_proj_rope( # Stage 0.1: attn_norm RMS — parallel partial sum (Opt S). # Single-task serial reduce was ~93us at S=2; split into ATTN_RMS_PARTIALS - # workers + a small final reduce. chunked_loop_optimizer is REQUIRED here: + # workers + a small final reduce. auto_chunk is REQUIRED here: # without it the inner pl.range tile allocations accumulate and exceed the # 192KB Vec UB at S=2/T=128 (verified by compile failure during tuning). # PARTIALS=2 (not 4+) keeps the FP32 add associativity-free, preserving `q` @@ -107,7 +107,7 @@ def qkv_proj_rope( D_BLOCKS_PER_PARTIAL = D_BLOCKS // ATTN_RMS_PARTIALS x_sq_partial = pl.create_tensor([ATTN_RMS_PARTIALS, T], dtype=pl.FP32) for wg in pl.parallel(0, ATTN_RMS_PARTIALS, 1): - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="attn_norm_rms_partial"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="attn_norm_rms_partial"): rms_d_base = wg * D_BLOCKS_PER_PARTIAL * D_CHUNK local_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) for rms_db in pl.range(D_BLOCKS_PER_PARTIAL): @@ -156,11 +156,11 @@ def qkv_proj_rope( # Stage 2.1: qr_rms — same partial-sum pattern as attn_norm_rms (Opt U). # Inner loop is cast-free (qr_fp32 is already FP32) so Vec pressure is lower - # than attn_norm_rms_partial, but chunked_loop_optimizer is kept for parity. + # than attn_norm_rms_partial, but auto_chunk is kept for parity. Q_BLOCKS_PER_QR_PARTIAL = Q_BLOCKS // QR_RMS_PARTIALS qr_sq_partial = pl.create_tensor([QR_RMS_PARTIALS, T], dtype=pl.FP32) for wgr in pl.parallel(0, QR_RMS_PARTIALS, 1): - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="qr_rms_partial"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_rms_partial"): qr_rms_q_base = wgr * Q_BLOCKS_PER_QR_PARTIAL * Q_LORA_CHUNK qr_local_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) for qr_rms_qb in pl.range(Q_BLOCKS_PER_QR_PARTIAL): diff --git a/models/kimi/kimi_k2_decode_draft.py b/models/kimi/kimi_k2_decode_draft.py index 750ff203..3899c660 100644 --- a/models/kimi/kimi_k2_decode_draft.py +++ b/models/kimi/kimi_k2_decode_draft.py @@ -137,7 +137,7 @@ def kimi_k2_decode_layer( # ========================================================================= # Scope 1: Input RMSNorm + QKV Projection # ========================================================================= - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + with pl.at(level=pl.Level.CORE_GROUP): sq_sum = pl.create_tensor([BATCH_CFG, 1], dtype=pl.FP32) sq_sum = pl.mul(sq_sum, 0.0) @@ -157,196 +157,208 @@ def kimi_k2_decode_layer( inv_rms_tile = pl.slice(inv_rms, [BATCH_TILE, 1], [b0, 0]) # Q projection - for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=4): - q0 = ob * Q_OUT_CHUNK - q_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) - q_acc = pl.mul(q_acc, 0.0) - for kb in pl.range(HIDDEN_BLOCKS): - k0 = kb * K_CHUNK - x_chunk_bf16 = pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]) - x_chunk = pl.cast(x_chunk_bf16, target_type=pl.FP32) - gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) - normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms_tile), gamma) - wq_chunk = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0]) - q_acc = pl.add(q_acc, pl.matmul(pl.cast(normed, target_type=pl.BF16), wq_chunk)) - q_proj = pl.assemble(q_proj, pl.cast(q_acc, target_type=pl.BF16), [b0, q0]) + for ob_chunk in pl.parallel(0, Q_OUT_BLOCKS, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 4): + q0 = ob * Q_OUT_CHUNK + q_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) + q_acc = pl.mul(q_acc, 0.0) + for kb in pl.range(HIDDEN_BLOCKS): + k0 = kb * K_CHUNK + x_chunk_bf16 = pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]) + x_chunk = pl.cast(x_chunk_bf16, target_type=pl.FP32) + gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) + normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms_tile), gamma) + wq_chunk = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0]) + q_acc = pl.add(q_acc, pl.matmul(pl.cast(normed, target_type=pl.BF16), wq_chunk)) + q_proj = pl.assemble(q_proj, pl.cast(q_acc, target_type=pl.BF16), [b0, q0]) # K/V projection - for ob in pl.parallel(0, KV_OUT_BLOCKS, 1, chunk=8): - kv0 = ob * KV_OUT_CHUNK - k_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) - v_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) - k_acc = pl.mul(k_acc, 0.0) - v_acc = pl.mul(v_acc, 0.0) - for kb in pl.range(HIDDEN_BLOCKS): - k0 = kb * K_CHUNK - x_chunk_bf16 = pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]) - x_chunk = pl.cast(x_chunk_bf16, target_type=pl.FP32) - gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) - normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms_tile), gamma) - normed_bf16 = pl.cast(normed, target_type=pl.BF16) - wk_chunk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) - wv_chunk = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) - k_acc = pl.add(k_acc, pl.matmul(normed_bf16, wk_chunk)) - v_acc = pl.add(v_acc, pl.matmul(normed_bf16, wv_chunk)) - k_proj = pl.assemble(k_proj, pl.cast(k_acc, target_type=pl.BF16), [b0, kv0]) - v_proj = pl.assemble(v_proj, pl.cast(v_acc, target_type=pl.BF16), [b0, kv0]) + for ob_chunk in pl.parallel(0, KV_OUT_BLOCKS, 8): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 8): + kv0 = ob * KV_OUT_CHUNK + k_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) + v_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) + k_acc = pl.mul(k_acc, 0.0) + v_acc = pl.mul(v_acc, 0.0) + for kb in pl.range(HIDDEN_BLOCKS): + k0 = kb * K_CHUNK + x_chunk_bf16 = pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]) + x_chunk = pl.cast(x_chunk_bf16, target_type=pl.FP32) + gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) + normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms_tile), gamma) + normed_bf16 = pl.cast(normed, target_type=pl.BF16) + wk_chunk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) + wv_chunk = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) + k_acc = pl.add(k_acc, pl.matmul(normed_bf16, wk_chunk)) + v_acc = pl.add(v_acc, pl.matmul(normed_bf16, wv_chunk)) + k_proj = pl.assemble(k_proj, pl.cast(k_acc, target_type=pl.BF16), [b0, kv0]) + v_proj = pl.assemble(v_proj, pl.cast(v_acc, target_type=pl.BF16), [b0, kv0]) # ========================================================================= # Scope 2: RoPE + KV Cache Update + Flash Decoding Attention # ========================================================================= - for b in pl.parallel(0, BATCH_CFG, 1, chunk=4): - pos = pl.tensor.read(cache_pos, [b]) - ctx_len = pos + 1 - ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE + for b_chunk in pl.parallel(0, BATCH_CFG, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for b in pl.range(b_chunk, b_chunk + 4): + pos = pl.tensor.read(cache_pos, [b]) + ctx_len = pos + 1 + ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE - # Load RoPE coefficients - cos_row = pl.slice(rope_cos, [1, HEAD_DIM_CFG], [pos, 0]) - sin_row = pl.slice(rope_sin, [1, HEAD_DIM_CFG], [pos, 0]) - cos_lo = pl.slice(cos_row, [1, HEAD_DIM_CFG // 2], [0, 0]) - cos_hi = pl.slice(cos_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) - sin_lo = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, 0]) - sin_hi = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) - - # RoPE for K/V and update cache - for kvh in pl.parallel(0, NUM_KV_HEADS_CFG, 1, chunk=4): - kv_col = kvh * HEAD_DIM_CFG - k_row = pl.cast( - pl.slice(k_proj, [1, HEAD_DIM_CFG], [b, kv_col]), - target_type=pl.FP32, - ) - k_lo = pl.slice(k_row, [1, HEAD_DIM_CFG // 2], [0, 0]) - k_hi = pl.slice(k_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) + # Load RoPE coefficients + cos_row = pl.slice(rope_cos, [1, HEAD_DIM_CFG], [pos, 0]) + sin_row = pl.slice(rope_sin, [1, HEAD_DIM_CFG], [pos, 0]) + cos_lo = pl.slice(cos_row, [1, HEAD_DIM_CFG // 2], [0, 0]) + cos_hi = pl.slice(cos_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) + sin_lo = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, 0]) + sin_hi = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) + + # RoPE for K/V and update cache + for kvh_chunk in pl.parallel(0, NUM_KV_HEADS_CFG, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for kvh in pl.range(kvh_chunk, kvh_chunk + 4): + kv_col = kvh * HEAD_DIM_CFG + k_row = pl.cast( + pl.slice(k_proj, [1, HEAD_DIM_CFG], [b, kv_col]), + target_type=pl.FP32, + ) + k_lo = pl.slice(k_row, [1, HEAD_DIM_CFG // 2], [0, 0]) + k_hi = pl.slice(k_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) - # Apply RoPE - k_rot = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) - k_rot = pl.assemble( - k_rot, - pl.sub(pl.col_expand_mul(k_lo, cos_lo), pl.col_expand_mul(k_hi, sin_lo)), - [0, 0], - ) - k_rot = pl.assemble( - k_rot, - pl.add(pl.col_expand_mul(k_hi, cos_hi), pl.col_expand_mul(k_lo, sin_hi)), - [0, HEAD_DIM_CFG // 2], - ) + # Apply RoPE + k_rot = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) + k_rot = pl.assemble( + k_rot, + pl.sub(pl.col_expand_mul(k_lo, cos_lo), pl.col_expand_mul(k_hi, sin_lo)), + [0, 0], + ) + k_rot = pl.assemble( + k_rot, + pl.add(pl.col_expand_mul(k_hi, cos_hi), pl.col_expand_mul(k_lo, sin_hi)), + [0, HEAD_DIM_CFG // 2], + ) - # Update KV cache - cache_row = b * NUM_KV_HEADS_CFG * MAX_SEQ_CFG + kvh * MAX_SEQ_CFG + pos - k_cache = pl.assemble(k_cache, pl.cast(k_rot, target_type=pl.BF16), [cache_row, 0]) - v_cache = pl.assemble( - v_cache, - pl.slice(v_proj, [1, HEAD_DIM_CFG], [b, kv_col]), - [cache_row, 0], - ) - - # Flash Decoding Attention per head - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - attn_row = pl.create_tensor([1, HIDDEN_CFG], dtype=pl.FP32) - attn_row = pl.mul(attn_row, 0.0) - - for h in pl.parallel(0, NUM_HEADS_CFG, 1, chunk=8): - kvh = h // Q_PER_KV_CFG - q_col = h * HEAD_DIM_CFG - - # RoPE for Q - q_row = pl.cast( - pl.slice(q_proj, [1, HEAD_DIM_CFG], [b, q_col]), - target_type=pl.FP32, - ) - q_lo = pl.slice(q_row, [1, HEAD_DIM_CFG // 2], [0, 0]) - q_hi = pl.slice(q_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) - q_rot = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) - q_rot = pl.assemble( - q_rot, - pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), - [0, 0], - ) - q_rot = pl.assemble( - q_rot, - pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), - [0, HEAD_DIM_CFG // 2], - ) - q_rot_bf16 = pl.cast(q_rot, target_type=pl.BF16) - - # Online softmax state - oi = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) - li = pl.create_tensor([1, 1], dtype=pl.FP32) - mi = pl.create_tensor([1, 1], dtype=pl.FP32) - oi = pl.mul(oi, 0.0) - li = pl.mul(li, 0.0) - mi = pl.mul(mi, 0.0) - - # Process KV cache in chunks (sliding window) - for sb in pl.range(ctx_blocks): - s0 = sb * SEQ_TILE - valid_len = pl.min(SEQ_TILE, ctx_len - s0) - cache_row0 = b * NUM_KV_HEADS_CFG * MAX_SEQ_CFG + kvh * MAX_SEQ_CFG + s0 + # Update KV cache + cache_row = b * NUM_KV_HEADS_CFG * MAX_SEQ_CFG + kvh * MAX_SEQ_CFG + pos + k_cache = pl.assemble(k_cache, pl.cast(k_rot, target_type=pl.BF16), [cache_row, 0]) + v_cache = pl.assemble( + v_cache, + pl.slice(v_proj, [1, HEAD_DIM_CFG], [b, kv_col]), + [cache_row, 0], + ) + + # Flash Decoding Attention per head + with pl.at(level=pl.Level.CORE_GROUP): + attn_row = pl.create_tensor([1, HIDDEN_CFG], dtype=pl.FP32) + attn_row = pl.mul(attn_row, 0.0) + + for h_chunk in pl.parallel(0, NUM_HEADS_CFG, 8): + with pl.at(level=pl.Level.CORE_GROUP): + for h in pl.range(h_chunk, h_chunk + 8): + kvh = h // Q_PER_KV_CFG + q_col = h * HEAD_DIM_CFG + + # RoPE for Q + q_row = pl.cast( + pl.slice(q_proj, [1, HEAD_DIM_CFG], [b, q_col]), + target_type=pl.FP32, + ) + q_lo = pl.slice(q_row, [1, HEAD_DIM_CFG // 2], [0, 0]) + q_hi = pl.slice(q_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) + q_rot = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) + q_rot = pl.assemble( + q_rot, + pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), + [0, 0], + ) + q_rot = pl.assemble( + q_rot, + pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), + [0, HEAD_DIM_CFG // 2], + ) + q_rot_bf16 = pl.cast(q_rot, target_type=pl.BF16) + + # Online softmax state + oi = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) + li = pl.create_tensor([1, 1], dtype=pl.FP32) + mi = pl.create_tensor([1, 1], dtype=pl.FP32) + oi = pl.mul(oi, 0.0) + li = pl.mul(li, 0.0) + mi = pl.mul(mi, 0.0) + + # Process KV cache in chunks (sliding window) + for sb in pl.range(ctx_blocks): + s0 = sb * SEQ_TILE + valid_len = pl.min(SEQ_TILE, ctx_len - s0) + cache_row0 = b * NUM_KV_HEADS_CFG * MAX_SEQ_CFG + kvh * MAX_SEQ_CFG + s0 - k_tile = pl.slice(k_cache, [SEQ_TILE, HEAD_DIM_CFG], [cache_row0, 0]) - v_tile = pl.slice(v_cache, [SEQ_TILE, HEAD_DIM_CFG], [cache_row0, 0]) + k_tile = pl.slice(k_cache, [SEQ_TILE, HEAD_DIM_CFG], [cache_row0, 0]) + v_tile = pl.slice(v_cache, [SEQ_TILE, HEAD_DIM_CFG], [cache_row0, 0]) - # Q @ K^T * scale - scores = pl.mul(pl.matmul(q_rot_bf16, k_tile, b_trans=True), ATTN_SCALE) - scores_valid = pl.slice(scores, [1, valid_len], [0, 0]) + # Q @ K^T * scale + scores = pl.mul(pl.matmul(q_rot_bf16, k_tile, b_trans=True), ATTN_SCALE) + scores_valid = pl.slice(scores, [1, valid_len], [0, 0]) - # Online softmax - cur_mi = pl.cast(pl.row_max(scores_valid), target_type=pl.FP32) - exp_scores = pl.exp(pl.row_expand_sub(scores_valid, cur_mi)) - cur_li = pl.cast(pl.row_sum(exp_scores), target_type=pl.FP32) + # Online softmax + cur_mi = pl.cast(pl.row_max(scores_valid), target_type=pl.FP32) + exp_scores = pl.exp(pl.row_expand_sub(scores_valid, cur_mi)) + cur_li = pl.cast(pl.row_sum(exp_scores), target_type=pl.FP32) - exp_pad = pl.create_tensor([1, SEQ_TILE], dtype=pl.FP32) - exp_pad = pl.mul(exp_pad, 0.0) - exp_pad = pl.assemble(exp_pad, exp_scores, [0, 0]) + exp_pad = pl.create_tensor([1, SEQ_TILE], dtype=pl.FP32) + exp_pad = pl.mul(exp_pad, 0.0) + exp_pad = pl.assemble(exp_pad, exp_scores, [0, 0]) - oi_tmp = pl.matmul( - pl.cast(exp_pad, target_type=pl.BF16), - v_tile, - out_dtype=pl.FP32, - ) - - if sb == 0: - oi = oi_tmp - li = cur_li - mi = cur_mi - else: - mi_new = pl.maximum(mi, cur_mi) - alpha = pl.exp(pl.sub(mi, mi_new)) - beta = pl.exp(pl.sub(cur_mi, mi_new)) - li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li)) - oi = pl.add(pl.row_expand_mul(oi, alpha), pl.row_expand_mul(oi_tmp, beta)) - mi = mi_new - - ctx = pl.row_expand_div(oi, li) - attn_row = pl.assemble(attn_row, ctx, [0, q_col]) - - attn_out = pl.assemble(attn_out, attn_row, [b, 0]) + oi_tmp = pl.matmul( + pl.cast(exp_pad, target_type=pl.BF16), + v_tile, + out_dtype=pl.FP32, + ) + + if sb == 0: + oi = oi_tmp + li = cur_li + mi = cur_mi + else: + mi_new = pl.maximum(mi, cur_mi) + alpha = pl.exp(pl.sub(mi, mi_new)) + beta = pl.exp(pl.sub(cur_mi, mi_new)) + li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li)) + oi = pl.add(pl.row_expand_mul(oi, alpha), pl.row_expand_mul(oi_tmp, beta)) + mi = mi_new + + ctx = pl.row_expand_div(oi, li) + attn_row = pl.assemble(attn_row, ctx, [0, q_col]) + + attn_out = pl.assemble(attn_out, attn_row, [b, 0]) # ========================================================================= # Scope 3: Output Projection + Residual + Post RMSNorm + MoE # ========================================================================= - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + with pl.at(level=pl.Level.CORE_GROUP): for b0 in pl.range(0, BATCH_CFG, BATCH_TILE): # Output projection + residual resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.FP32) - for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8): - o0 = ob * Q_OUT_CHUNK - o_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) - o_acc = pl.mul(o_acc, 0.0) - for kb in pl.range(HIDDEN_BLOCKS): - k0 = kb * K_CHUNK - a_chunk = pl.cast( - pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, k0]), - target_type=pl.BF16, - ) - w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0]) - o_acc = pl.add(o_acc, pl.matmul(a_chunk, w_chunk)) - resid = pl.cast( - pl.slice(hidden_states, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]), - target_type=pl.FP32, - ) - resid1_tile = pl.assemble(resid1_tile, pl.add(o_acc, resid), [0, o0]) + for ob_chunk in pl.parallel(0, Q_OUT_BLOCKS, 8): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 8): + o0 = ob * Q_OUT_CHUNK + o_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) + o_acc = pl.mul(o_acc, 0.0) + for kb in pl.range(HIDDEN_BLOCKS): + k0 = kb * K_CHUNK + a_chunk = pl.cast( + pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, k0]), + target_type=pl.BF16, + ) + w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0]) + o_acc = pl.add(o_acc, pl.matmul(a_chunk, w_chunk)) + resid = pl.cast( + pl.slice(hidden_states, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]), + target_type=pl.FP32, + ) + resid1_tile = pl.assemble(resid1_tile, pl.add(o_acc, resid), [0, o0]) # Post RMSNorm sq_sum = pl.create_tensor([BATCH_TILE, 1], dtype=pl.FP32) @@ -414,12 +426,14 @@ def kimi_k2_decode_layer( mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc) mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) - for dob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=4): - d0 = dob * Q_OUT_CHUNK - down_prev = pl.slice(shared_out, [BATCH_TILE, Q_OUT_CHUNK], [0, d0]) - w_down_chunk = pl.slice(w_down_shared, [MLP_OUT_CHUNK, Q_OUT_CHUNK], [o0, d0]) - down_next = pl.add(down_prev, pl.matmul(mlp_chunk_bf16, w_down_chunk)) - shared_out = pl.assemble(shared_out, down_next, [0, d0]) + for dob_chunk in pl.parallel(0, Q_OUT_BLOCKS, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for dob in pl.range(dob_chunk, dob_chunk + 4): + d0 = dob * Q_OUT_CHUNK + down_prev = pl.slice(shared_out, [BATCH_TILE, Q_OUT_CHUNK], [0, d0]) + w_down_chunk = pl.slice(w_down_shared, [MLP_OUT_CHUNK, Q_OUT_CHUNK], [o0, d0]) + down_next = pl.add(down_prev, pl.matmul(mlp_chunk_bf16, w_down_chunk)) + shared_out = pl.assemble(shared_out, down_next, [0, d0]) # Routed Experts (top-K selection) # Simplified: process all experts with gating weight @@ -450,14 +464,16 @@ def kimi_k2_decode_layer( mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc) mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) - for dob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=4): - d0 = dob * Q_OUT_CHUNK - down_prev = pl.slice(expert_out, [BATCH_TILE, Q_OUT_CHUNK], [0, d0]) - # Slice and reshape 3D down weights to 2D - w_down_chunk = pl.slice(w_down_experts, [NUM_EXPERTS_CFG, MLP_OUT_CHUNK, Q_OUT_CHUNK], [exp_idx, o0, d0]) - w_down_chunk = pl.reshape(w_down_chunk, [MLP_OUT_CHUNK, Q_OUT_CHUNK]) - down_next = pl.add(down_prev, pl.matmul(mlp_chunk_bf16, w_down_chunk)) - expert_out = pl.assemble(expert_out, down_next, [0, d0]) + for dob_chunk in pl.parallel(0, Q_OUT_BLOCKS, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for dob in pl.range(dob_chunk, dob_chunk + 4): + d0 = dob * Q_OUT_CHUNK + down_prev = pl.slice(expert_out, [BATCH_TILE, Q_OUT_CHUNK], [0, d0]) + # Slice and reshape 3D down weights to 2D + w_down_chunk = pl.slice(w_down_experts, [NUM_EXPERTS_CFG, MLP_OUT_CHUNK, Q_OUT_CHUNK], [exp_idx, o0, d0]) + w_down_chunk = pl.reshape(w_down_chunk, [MLP_OUT_CHUNK, Q_OUT_CHUNK]) + down_next = pl.add(down_prev, pl.matmul(mlp_chunk_bf16, w_down_chunk)) + expert_out = pl.assemble(expert_out, down_next, [0, d0]) # Weight by gating probability gate_weight = pl.slice(gate_prob, [BATCH_TILE, 1], [0, exp_idx]) diff --git a/models/milm/milm_decode_draft.py b/models/milm/milm_decode_draft.py index 77b19f54..6bf14d5e 100644 --- a/models/milm/milm_decode_draft.py +++ b/models/milm/milm_decode_draft.py @@ -128,7 +128,7 @@ def milm_decode_layer( # Scope 1: Input RMSNorm + QKV Projection # Optimized with chunked computation to reduce InCore pressure # ========================================================================= - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + with pl.at(level=pl.Level.CORE_GROUP): # Compute sum of squares for RMSNorm sq_sum = pl.create_tensor([BATCH_CFG, 1], dtype=pl.FP32) sq_sum = pl.mul(sq_sum, 0.0) @@ -149,197 +149,209 @@ def milm_decode_layer( inv_rms_tile = pl.slice(inv_rms, [BATCH_TILE, 1], [b0, 0]) # Q projection (parallel over output chunks) - for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=4): - q0 = ob * Q_OUT_CHUNK - q_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) - q_acc = pl.mul(q_acc, 0.0) - for kb in pl.range(HIDDEN_BLOCKS): - k0 = kb * K_CHUNK - x_chunk_bf16 = pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]) - x_chunk = pl.cast(x_chunk_bf16, target_type=pl.FP32) - gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) - # RMSNorm: (x / rms) * gamma - normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms_tile), gamma) - wq_chunk = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0]) - q_acc = pl.add(q_acc, pl.matmul(pl.cast(normed, target_type=pl.BF16), wq_chunk)) - q_proj = pl.assemble(q_proj, pl.cast(q_acc, target_type=pl.BF16), [b0, q0]) + for ob_chunk in pl.parallel(0, Q_OUT_BLOCKS, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 4): + q0 = ob * Q_OUT_CHUNK + q_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) + q_acc = pl.mul(q_acc, 0.0) + for kb in pl.range(HIDDEN_BLOCKS): + k0 = kb * K_CHUNK + x_chunk_bf16 = pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]) + x_chunk = pl.cast(x_chunk_bf16, target_type=pl.FP32) + gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) + # RMSNorm: (x / rms) * gamma + normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms_tile), gamma) + wq_chunk = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0]) + q_acc = pl.add(q_acc, pl.matmul(pl.cast(normed, target_type=pl.BF16), wq_chunk)) + q_proj = pl.assemble(q_proj, pl.cast(q_acc, target_type=pl.BF16), [b0, q0]) # K/V projection (parallel over output chunks) - for ob in pl.parallel(0, KV_OUT_BLOCKS, 1, chunk=8): - kv0 = ob * KV_OUT_CHUNK - k_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) - v_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) - k_acc = pl.mul(k_acc, 0.0) - v_acc = pl.mul(v_acc, 0.0) - for kb in pl.range(HIDDEN_BLOCKS): - k0 = kb * K_CHUNK - x_chunk_bf16 = pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]) - x_chunk = pl.cast(x_chunk_bf16, target_type=pl.FP32) - gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) - normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms_tile), gamma) - normed_bf16 = pl.cast(normed, target_type=pl.BF16) - wk_chunk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) - wv_chunk = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) - k_acc = pl.add(k_acc, pl.matmul(normed_bf16, wk_chunk)) - v_acc = pl.add(v_acc, pl.matmul(normed_bf16, wv_chunk)) - k_proj = pl.assemble(k_proj, pl.cast(k_acc, target_type=pl.BF16), [b0, kv0]) - v_proj = pl.assemble(v_proj, pl.cast(v_acc, target_type=pl.BF16), [b0, kv0]) + for ob_chunk in pl.parallel(0, KV_OUT_BLOCKS, 8): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 8): + kv0 = ob * KV_OUT_CHUNK + k_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) + v_acc = pl.create_tensor([BATCH_TILE, KV_OUT_CHUNK], dtype=pl.FP32) + k_acc = pl.mul(k_acc, 0.0) + v_acc = pl.mul(v_acc, 0.0) + for kb in pl.range(HIDDEN_BLOCKS): + k0 = kb * K_CHUNK + x_chunk_bf16 = pl.slice(hidden_states, [BATCH_TILE, K_CHUNK], [b0, k0]) + x_chunk = pl.cast(x_chunk_bf16, target_type=pl.FP32) + gamma = pl.slice(input_rms_weight, [1, K_CHUNK], [0, k0]) + normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, inv_rms_tile), gamma) + normed_bf16 = pl.cast(normed, target_type=pl.BF16) + wk_chunk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) + wv_chunk = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) + k_acc = pl.add(k_acc, pl.matmul(normed_bf16, wk_chunk)) + v_acc = pl.add(v_acc, pl.matmul(normed_bf16, wv_chunk)) + k_proj = pl.assemble(k_proj, pl.cast(k_acc, target_type=pl.BF16), [b0, kv0]) + v_proj = pl.assemble(v_proj, pl.cast(v_acc, target_type=pl.BF16), [b0, kv0]) # ========================================================================= # Scope 2: RoPE + KV Cache Update + Flash Decoding Attention # ========================================================================= - for b in pl.parallel(0, BATCH_CFG, 1, chunk=4): - pos = pl.tensor.read(cache_pos, [b]) - ctx_len = pos + 1 - ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE + for b_chunk in pl.parallel(0, BATCH_CFG, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for b in pl.range(b_chunk, b_chunk + 4): + pos = pl.tensor.read(cache_pos, [b]) + ctx_len = pos + 1 + ctx_blocks = (ctx_len + SEQ_TILE - 1) // SEQ_TILE - # Load RoPE coefficients for current position - cos_row = pl.slice(rope_cos, [1, HEAD_DIM_CFG], [pos, 0]) - sin_row = pl.slice(rope_sin, [1, HEAD_DIM_CFG], [pos, 0]) - cos_lo = pl.slice(cos_row, [1, HEAD_DIM_CFG // 2], [0, 0]) - cos_hi = pl.slice(cos_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) - sin_lo = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, 0]) - sin_hi = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) - - # Apply RoPE to K/V and update cache - for kvh in pl.parallel(0, NUM_KV_HEADS_CFG, 1, chunk=4): - kv_col = kvh * HEAD_DIM_CFG - k_row = pl.cast( - pl.slice(k_proj, [1, HEAD_DIM_CFG], [b, kv_col]), - target_type=pl.FP32, - ) - k_lo = pl.slice(k_row, [1, HEAD_DIM_CFG // 2], [0, 0]) - k_hi = pl.slice(k_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) + # Load RoPE coefficients for current position + cos_row = pl.slice(rope_cos, [1, HEAD_DIM_CFG], [pos, 0]) + sin_row = pl.slice(rope_sin, [1, HEAD_DIM_CFG], [pos, 0]) + cos_lo = pl.slice(cos_row, [1, HEAD_DIM_CFG // 2], [0, 0]) + cos_hi = pl.slice(cos_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) + sin_lo = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, 0]) + sin_hi = pl.slice(sin_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) + + # Apply RoPE to K/V and update cache + for kvh_chunk in pl.parallel(0, NUM_KV_HEADS_CFG, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for kvh in pl.range(kvh_chunk, kvh_chunk + 4): + kv_col = kvh * HEAD_DIM_CFG + k_row = pl.cast( + pl.slice(k_proj, [1, HEAD_DIM_CFG], [b, kv_col]), + target_type=pl.FP32, + ) + k_lo = pl.slice(k_row, [1, HEAD_DIM_CFG // 2], [0, 0]) + k_hi = pl.slice(k_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) - # RoPE: [k_lo, k_hi] -> [k_lo*cos - k_hi*sin, k_hi*cos + k_lo*sin] - k_rot = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) - k_rot = pl.assemble( - k_rot, - pl.sub(pl.col_expand_mul(k_lo, cos_lo), pl.col_expand_mul(k_hi, sin_lo)), - [0, 0], - ) - k_rot = pl.assemble( - k_rot, - pl.add(pl.col_expand_mul(k_hi, cos_hi), pl.col_expand_mul(k_lo, sin_hi)), - [0, HEAD_DIM_CFG // 2], - ) + # RoPE: [k_lo, k_hi] -> [k_lo*cos - k_hi*sin, k_hi*cos + k_lo*sin] + k_rot = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) + k_rot = pl.assemble( + k_rot, + pl.sub(pl.col_expand_mul(k_lo, cos_lo), pl.col_expand_mul(k_hi, sin_lo)), + [0, 0], + ) + k_rot = pl.assemble( + k_rot, + pl.add(pl.col_expand_mul(k_hi, cos_hi), pl.col_expand_mul(k_lo, sin_hi)), + [0, HEAD_DIM_CFG // 2], + ) - # Update KV cache - cache_row = b * NUM_KV_HEADS_CFG * MAX_SEQ_CFG + kvh * MAX_SEQ_CFG + pos - k_cache = pl.assemble(k_cache, pl.cast(k_rot, target_type=pl.BF16), [cache_row, 0]) - v_cache = pl.assemble( - v_cache, - pl.slice(v_proj, [1, HEAD_DIM_CFG], [b, kv_col]), - [cache_row, 0], - ) - - # Flash Decoding Attention (per head with GQA) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - attn_row = pl.create_tensor([1, HIDDEN_CFG], dtype=pl.FP32) - attn_row = pl.mul(attn_row, 0.0) - - for h in pl.parallel(0, NUM_HEADS_CFG, 1, chunk=8): - kvh = h // Q_PER_KV_CFG # GQA: multiple Q heads share one KV head - q_col = h * HEAD_DIM_CFG - - # Apply RoPE to Q - q_row = pl.cast( - pl.slice(q_proj, [1, HEAD_DIM_CFG], [b, q_col]), - target_type=pl.FP32, - ) - q_lo = pl.slice(q_row, [1, HEAD_DIM_CFG // 2], [0, 0]) - q_hi = pl.slice(q_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) - q_rot = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) - q_rot = pl.assemble( - q_rot, - pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), - [0, 0], - ) - q_rot = pl.assemble( - q_rot, - pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), - [0, HEAD_DIM_CFG // 2], - ) - q_rot_bf16 = pl.cast(q_rot, target_type=pl.BF16) - - # Online softmax state for Flash Decoding - oi = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) - li = pl.create_tensor([1, 1], dtype=pl.FP32) - mi = pl.create_tensor([1, 1], dtype=pl.FP32) - oi = pl.mul(oi, 0.0) - li = pl.mul(li, 0.0) - mi = pl.mul(mi, 0.0) - - # Process KV cache in chunks - for sb in pl.range(ctx_blocks): - s0 = sb * SEQ_TILE - valid_len = pl.min(SEQ_TILE, ctx_len - s0) - cache_row0 = b * NUM_KV_HEADS_CFG * MAX_SEQ_CFG + kvh * MAX_SEQ_CFG + s0 + # Update KV cache + cache_row = b * NUM_KV_HEADS_CFG * MAX_SEQ_CFG + kvh * MAX_SEQ_CFG + pos + k_cache = pl.assemble(k_cache, pl.cast(k_rot, target_type=pl.BF16), [cache_row, 0]) + v_cache = pl.assemble( + v_cache, + pl.slice(v_proj, [1, HEAD_DIM_CFG], [b, kv_col]), + [cache_row, 0], + ) + + # Flash Decoding Attention (per head with GQA) + with pl.at(level=pl.Level.CORE_GROUP): + attn_row = pl.create_tensor([1, HIDDEN_CFG], dtype=pl.FP32) + attn_row = pl.mul(attn_row, 0.0) + + for h_chunk in pl.parallel(0, NUM_HEADS_CFG, 8): + with pl.at(level=pl.Level.CORE_GROUP): + for h in pl.range(h_chunk, h_chunk + 8): + kvh = h // Q_PER_KV_CFG # GQA: multiple Q heads share one KV head + q_col = h * HEAD_DIM_CFG + + # Apply RoPE to Q + q_row = pl.cast( + pl.slice(q_proj, [1, HEAD_DIM_CFG], [b, q_col]), + target_type=pl.FP32, + ) + q_lo = pl.slice(q_row, [1, HEAD_DIM_CFG // 2], [0, 0]) + q_hi = pl.slice(q_row, [1, HEAD_DIM_CFG // 2], [0, HEAD_DIM_CFG // 2]) + q_rot = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) + q_rot = pl.assemble( + q_rot, + pl.sub(pl.col_expand_mul(q_lo, cos_lo), pl.col_expand_mul(q_hi, sin_lo)), + [0, 0], + ) + q_rot = pl.assemble( + q_rot, + pl.add(pl.col_expand_mul(q_hi, cos_hi), pl.col_expand_mul(q_lo, sin_hi)), + [0, HEAD_DIM_CFG // 2], + ) + q_rot_bf16 = pl.cast(q_rot, target_type=pl.BF16) + + # Online softmax state for Flash Decoding + oi = pl.create_tensor([1, HEAD_DIM_CFG], dtype=pl.FP32) + li = pl.create_tensor([1, 1], dtype=pl.FP32) + mi = pl.create_tensor([1, 1], dtype=pl.FP32) + oi = pl.mul(oi, 0.0) + li = pl.mul(li, 0.0) + mi = pl.mul(mi, 0.0) + + # Process KV cache in chunks + for sb in pl.range(ctx_blocks): + s0 = sb * SEQ_TILE + valid_len = pl.min(SEQ_TILE, ctx_len - s0) + cache_row0 = b * NUM_KV_HEADS_CFG * MAX_SEQ_CFG + kvh * MAX_SEQ_CFG + s0 - k_tile = pl.slice(k_cache, [SEQ_TILE, HEAD_DIM_CFG], [cache_row0, 0]) - v_tile = pl.slice(v_cache, [SEQ_TILE, HEAD_DIM_CFG], [cache_row0, 0]) + k_tile = pl.slice(k_cache, [SEQ_TILE, HEAD_DIM_CFG], [cache_row0, 0]) + v_tile = pl.slice(v_cache, [SEQ_TILE, HEAD_DIM_CFG], [cache_row0, 0]) - # Q @ K^T * scale - scores = pl.mul(pl.matmul(q_rot_bf16, k_tile, b_trans=True), ATTN_SCALE) - scores_valid = pl.slice(scores, [1, valid_len], [0, 0]) + # Q @ K^T * scale + scores = pl.mul(pl.matmul(q_rot_bf16, k_tile, b_trans=True), ATTN_SCALE) + scores_valid = pl.slice(scores, [1, valid_len], [0, 0]) - # Online softmax (numerically stable) - cur_mi = pl.cast(pl.row_max(scores_valid), target_type=pl.FP32) - exp_scores = pl.exp(pl.row_expand_sub(scores_valid, cur_mi)) - cur_li = pl.cast(pl.row_sum(exp_scores), target_type=pl.FP32) + # Online softmax (numerically stable) + cur_mi = pl.cast(pl.row_max(scores_valid), target_type=pl.FP32) + exp_scores = pl.exp(pl.row_expand_sub(scores_valid, cur_mi)) + cur_li = pl.cast(pl.row_sum(exp_scores), target_type=pl.FP32) - exp_pad = pl.create_tensor([1, SEQ_TILE], dtype=pl.FP32) - exp_pad = pl.mul(exp_pad, 0.0) - exp_pad = pl.assemble(exp_pad, exp_scores, [0, 0]) + exp_pad = pl.create_tensor([1, SEQ_TILE], dtype=pl.FP32) + exp_pad = pl.mul(exp_pad, 0.0) + exp_pad = pl.assemble(exp_pad, exp_scores, [0, 0]) - oi_tmp = pl.matmul( - pl.cast(exp_pad, target_type=pl.BF16), - v_tile, - out_dtype=pl.FP32, - ) - - if sb == 0: - oi = oi_tmp - li = cur_li - mi = cur_mi - else: - mi_new = pl.maximum(mi, cur_mi) - alpha = pl.exp(pl.sub(mi, mi_new)) - beta = pl.exp(pl.sub(cur_mi, mi_new)) - li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li)) - oi = pl.add(pl.row_expand_mul(oi, alpha), pl.row_expand_mul(oi_tmp, beta)) - mi = mi_new - - ctx = pl.row_expand_div(oi, li) - attn_row = pl.assemble(attn_row, ctx, [0, q_col]) - - attn_out = pl.assemble(attn_out, attn_row, [b, 0]) + oi_tmp = pl.matmul( + pl.cast(exp_pad, target_type=pl.BF16), + v_tile, + out_dtype=pl.FP32, + ) + + if sb == 0: + oi = oi_tmp + li = cur_li + mi = cur_mi + else: + mi_new = pl.maximum(mi, cur_mi) + alpha = pl.exp(pl.sub(mi, mi_new)) + beta = pl.exp(pl.sub(cur_mi, mi_new)) + li = pl.add(pl.mul(alpha, li), pl.mul(beta, cur_li)) + oi = pl.add(pl.row_expand_mul(oi, alpha), pl.row_expand_mul(oi_tmp, beta)) + mi = mi_new + + ctx = pl.row_expand_div(oi, li) + attn_row = pl.assemble(attn_row, ctx, [0, q_col]) + + attn_out = pl.assemble(attn_out, attn_row, [b, 0]) # ========================================================================= # Scope 3: Output Projection + Residual + Post RMSNorm + SwiGLU MLP # ========================================================================= - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): + with pl.at(level=pl.Level.CORE_GROUP): for b0 in pl.range(0, BATCH_CFG, BATCH_TILE): # Output projection + residual (first residual connection) resid1_tile = pl.create_tensor([BATCH_TILE, HIDDEN_CFG], dtype=pl.FP32) - for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8): - o0 = ob * Q_OUT_CHUNK - o_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) - o_acc = pl.mul(o_acc, 0.0) - for kb in pl.range(HIDDEN_BLOCKS): - k0 = kb * K_CHUNK - a_chunk = pl.cast( - pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, k0]), - target_type=pl.BF16, - ) - w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0]) - o_acc = pl.add(o_acc, pl.matmul(a_chunk, w_chunk)) - resid = pl.cast( - pl.slice(hidden_states, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]), - target_type=pl.FP32, - ) - resid1_tile = pl.assemble(resid1_tile, pl.add(o_acc, resid), [0, o0]) + for ob_chunk in pl.parallel(0, Q_OUT_BLOCKS, 8): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 8): + o0 = ob * Q_OUT_CHUNK + o_acc = pl.create_tensor([BATCH_TILE, Q_OUT_CHUNK], dtype=pl.FP32) + o_acc = pl.mul(o_acc, 0.0) + for kb in pl.range(HIDDEN_BLOCKS): + k0 = kb * K_CHUNK + a_chunk = pl.cast( + pl.slice(attn_out, [BATCH_TILE, K_CHUNK], [b0, k0]), + target_type=pl.BF16, + ) + w_chunk = pl.slice(wo, [K_CHUNK, Q_OUT_CHUNK], [k0, o0]) + o_acc = pl.add(o_acc, pl.matmul(a_chunk, w_chunk)) + resid = pl.cast( + pl.slice(hidden_states, [BATCH_TILE, Q_OUT_CHUNK], [b0, o0]), + target_type=pl.FP32, + ) + resid1_tile = pl.assemble(resid1_tile, pl.add(o_acc, resid), [0, o0]) # Post RMSNorm (before MLP) sq_sum = pl.create_tensor([BATCH_TILE, 1], dtype=pl.FP32) @@ -383,21 +395,25 @@ def milm_decode_layer( mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) # Down projection - for dob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=4): - d0 = dob * Q_OUT_CHUNK - down_prev = pl.slice(down_proj_tile, [BATCH_TILE, Q_OUT_CHUNK], [0, d0]) - w_down_chunk = pl.slice(w_down, [MLP_OUT_CHUNK, Q_OUT_CHUNK], [o0, d0]) - down_next = pl.add(down_prev, pl.matmul(mlp_chunk_bf16, w_down_chunk)) - down_proj_tile = pl.assemble(down_proj_tile, down_next, [0, d0]) + for dob_chunk in pl.parallel(0, Q_OUT_BLOCKS, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for dob in pl.range(dob_chunk, dob_chunk + 4): + d0 = dob * Q_OUT_CHUNK + down_prev = pl.slice(down_proj_tile, [BATCH_TILE, Q_OUT_CHUNK], [0, d0]) + w_down_chunk = pl.slice(w_down, [MLP_OUT_CHUNK, Q_OUT_CHUNK], [o0, d0]) + down_next = pl.add(down_prev, pl.matmul(mlp_chunk_bf16, w_down_chunk)) + down_proj_tile = pl.assemble(down_proj_tile, down_next, [0, d0]) # Final residual connection - for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=4): - o0 = ob * Q_OUT_CHUNK - down_acc = pl.add( - pl.slice(down_proj_tile, [BATCH_TILE, Q_OUT_CHUNK], [0, o0]), - pl.slice(resid1_tile, [BATCH_TILE, Q_OUT_CHUNK], [0, o0]), - ) - out = pl.assemble(out, pl.cast(down_acc, target_type=pl.BF16), [b0, o0]) + for ob_chunk in pl.parallel(0, Q_OUT_BLOCKS, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 4): + o0 = ob * Q_OUT_CHUNK + down_acc = pl.add( + pl.slice(down_proj_tile, [BATCH_TILE, Q_OUT_CHUNK], [0, o0]), + pl.slice(resid1_tile, [BATCH_TILE, Q_OUT_CHUNK], [0, o0]), + ) + out = pl.assemble(out, pl.cast(down_acc, target_type=pl.BF16), [b0, o0]) return out diff --git a/models/qwen3/14b/qwen3_14b_l3_generate.py b/models/qwen3/14b/qwen3_14b_l3_generate.py index d997c058..fa34339a 100644 --- a/models/qwen3/14b/qwen3_14b_l3_generate.py +++ b/models/qwen3/14b/qwen3_14b_l3_generate.py @@ -247,66 +247,74 @@ def qwen3_prefill_all( ) q_proj_tile = pl.create_tensor([TOK_TILE, hidden], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="prefill_q_proj"): - for ob in pl.parallel(q_out_blocks, chunk=4): - q0 = ob * Q_OUT_CHUNK - tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) - tile_w = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [layer_off_h, q0]) - q_acc = pl.matmul(tile_a, tile_w, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) - tile_w_i = pl.slice( - wq, [K_CHUNK, Q_OUT_CHUNK], [layer_off_h + k0, q0] - ) - q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_w_i) - q_proj_tile = pl.assemble(q_proj_tile, q_acc, [0, q0]) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_q_proj"): + for ob_chunk in pl.parallel(0, q_out_blocks, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 4): + q0 = ob * Q_OUT_CHUNK + tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) + tile_w = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [layer_off_h, q0]) + q_acc = pl.matmul(tile_a, tile_w, out_dtype=pl.FP32) + for kb in pl.range(1, hidden_blocks): + k0 = kb * K_CHUNK + tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) + tile_w_i = pl.slice( + wq, [K_CHUNK, Q_OUT_CHUNK], [layer_off_h + k0, q0] + ) + q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_w_i) + q_proj_tile = pl.assemble(q_proj_tile, q_acc, [0, q0]) k_proj_tile = pl.create_tensor([TOK_TILE, kv_hidden], dtype=pl.FP32) v_proj_tile = pl.create_tensor([TOK_TILE, kv_hidden], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="prefill_kv_proj"): - for ob in pl.parallel(kv_out_blocks, chunk=4): - kv0 = ob * KV_OUT_CHUNK - tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) - tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [layer_off_h, kv0]) - k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) - tile_wk_i = pl.slice( - wk, [K_CHUNK, KV_OUT_CHUNK], [layer_off_h + k0, kv0] - ) - k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i) - k_proj_tile = pl.assemble(k_proj_tile, k_acc, [0, kv0]) - - tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) - tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [layer_off_h, kv0]) - v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) - tile_wv_i = pl.slice( - wv, [K_CHUNK, KV_OUT_CHUNK], [layer_off_h + k0, kv0] - ) - v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i) - v_proj_tile = pl.assemble(v_proj_tile, v_acc, [0, kv0]) - - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="prefill_q_norm"): - for qh in pl.parallel(0, num_heads, chunk=num_heads): - q_col = qh * head_dim - q_head = pl.slice(q_proj_tile, [TOK_TILE, head_dim], [0, q_col]) - q_sq = pl.reshape(pl.row_sum(pl.mul(q_head, q_head)), [TOK_TILE, 1]) - q_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_sq, head_dim_inv), EPS))) - q_normed = pl.col_expand_mul(pl.row_expand_mul(q_head, q_inv_rms), q_norm_w) - q_proj_tile = pl.assemble(q_proj_tile, q_normed, [0, q_col]) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="prefill_k_norm"): - for kh in pl.parallel(0, num_kv_heads, chunk=num_kv_heads): - k_col = kh * head_dim - k_head = pl.slice(k_proj_tile, [TOK_TILE, head_dim], [0, k_col]) - k_sq = pl.reshape(pl.row_sum(pl.mul(k_head, k_head)), [TOK_TILE, 1]) - k_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(k_sq, head_dim_inv), EPS))) - k_normed = pl.col_expand_mul(pl.row_expand_mul(k_head, k_inv_rms), k_norm_w) - k_proj_tile = pl.assemble(k_proj_tile, k_normed, [0, k_col]) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_kv_proj"): + for ob_chunk in pl.parallel(0, kv_out_blocks, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 4): + kv0 = ob * KV_OUT_CHUNK + tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) + tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [layer_off_h, kv0]) + k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) + for kb in pl.range(1, hidden_blocks): + k0 = kb * K_CHUNK + tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) + tile_wk_i = pl.slice( + wk, [K_CHUNK, KV_OUT_CHUNK], [layer_off_h + k0, kv0] + ) + k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i) + k_proj_tile = pl.assemble(k_proj_tile, k_acc, [0, kv0]) + + tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) + tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [layer_off_h, kv0]) + v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) + for kb in pl.range(1, hidden_blocks): + k0 = kb * K_CHUNK + tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) + tile_wv_i = pl.slice( + wv, [K_CHUNK, KV_OUT_CHUNK], [layer_off_h + k0, kv0] + ) + v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i) + v_proj_tile = pl.assemble(v_proj_tile, v_acc, [0, kv0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_q_norm"): + for qh_chunk in pl.parallel(0, num_heads, num_heads): + with pl.at(level=pl.Level.CORE_GROUP): + for qh in pl.range(qh_chunk, qh_chunk + num_heads): + q_col = qh * head_dim + q_head = pl.slice(q_proj_tile, [TOK_TILE, head_dim], [0, q_col]) + q_sq = pl.reshape(pl.row_sum(pl.mul(q_head, q_head)), [TOK_TILE, 1]) + q_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_sq, head_dim_inv), EPS))) + q_normed = pl.col_expand_mul(pl.row_expand_mul(q_head, q_inv_rms), q_norm_w) + q_proj_tile = pl.assemble(q_proj_tile, q_normed, [0, q_col]) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_k_norm"): + for kh_chunk in pl.parallel(0, num_kv_heads, num_kv_heads): + with pl.at(level=pl.Level.CORE_GROUP): + for kh in pl.range(kh_chunk, kh_chunk + num_kv_heads): + k_col = kh * head_dim + k_head = pl.slice(k_proj_tile, [TOK_TILE, head_dim], [0, k_col]) + k_sq = pl.reshape(pl.row_sum(pl.mul(k_head, k_head)), [TOK_TILE, 1]) + k_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(k_sq, head_dim_inv), EPS))) + k_normed = pl.col_expand_mul(pl.row_expand_mul(k_head, k_inv_rms), k_norm_w) + k_proj_tile = pl.assemble(k_proj_tile, k_normed, [0, k_col]) # ── Scope 2: RoPE + KV cache update + causal attention ── attn_tile = pl.create_tensor([TOK_TILE, hidden], dtype=pl.BF16) @@ -324,101 +332,105 @@ def qwen3_prefill_all( all_q_padded = pl.create_tensor( [total_q_groups * Q_HEAD_PAD, head_dim], dtype=pl.BF16 ) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="prefill_q_pad"): - for gi in pl.parallel(0, total_q_groups, chunk=total_q_groups): - all_q_padded = pl.assemble( - all_q_padded, - pl.cast( - pl.full( - [Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], - dtype=pl.FP32, - value=0.0, - ), - target_type=pl.BF16, - ), - [gi * Q_HEAD_PAD + Q_HEAD_BATCH, 0], - ) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_q_pad"): + for gi_chunk in pl.parallel(0, total_q_groups, total_q_groups): + with pl.at(level=pl.Level.CORE_GROUP): + for gi in pl.range(gi_chunk, gi_chunk + total_q_groups): + all_q_padded = pl.assemble( + all_q_padded, + pl.cast( + pl.full( + [Q_HEAD_PAD - Q_HEAD_BATCH, head_dim], + dtype=pl.FP32, + value=0.0, + ), + target_type=pl.BF16, + ), + [gi * Q_HEAD_PAD + Q_HEAD_BATCH, 0], + ) cache_slot = pl.cast( pl.tensor.read(slot_mapping, [b * max_seq + pos]), pl.INDEX ) cache_slot_block = cache_slot // BLOCK_SIZE cache_slot_offset = cache_slot - cache_slot_block * BLOCK_SIZE - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="prefill_rope_kv_cache"): - for ki in pl.parallel(0, num_kv_heads, chunk=8): - kv_col = ki * head_dim - k_lo = pl.reshape( - pl.slice(k_proj_tile, [1, half_dim], [ti, kv_col]), [1, half_dim] - ) - k_hi = pl.reshape( - pl.slice(k_proj_tile, [1, half_dim], [ti, kv_col + half_dim]), - [1, half_dim], - ) - rot_lo = pl.sub( - pl.col_expand_mul(k_lo, cos_lo), - pl.col_expand_mul(k_hi, sin_lo), - ) - rot_hi = pl.add( - pl.col_expand_mul(k_hi, cos_hi), - pl.col_expand_mul(k_lo, sin_hi), - ) - cache_row = ( - (cache_slot_block * num_kv_heads + ki) * BLOCK_SIZE - + cache_slot_offset - ) - k_cache_all = pl.assemble( - k_cache_all, - pl.cast(rot_lo, target_type=pl.BF16), - [layer_off_cache + cache_row, 0], - ) - k_cache_all = pl.assemble( - k_cache_all, - pl.cast(rot_hi, target_type=pl.BF16), - [layer_off_cache + cache_row, half_dim], - ) - v_cache_all = pl.assemble( - v_cache_all, - pl.cast( - pl.reshape( - pl.slice(v_proj_tile, [1, head_dim], [ti, ki * head_dim]), - [1, head_dim], - ), - target_type=pl.BF16, - ), - [layer_off_cache + cache_row, 0], - ) - q_base = ki * q_per_kv - for qi in pl.range(Q_HEAD_BATCH): - q_col = (q_base + qi) * head_dim - q_lo = pl.reshape( - pl.slice(q_proj_tile, [1, half_dim], [ti, q_col]), - [1, half_dim], - ) - q_hi = pl.reshape( - pl.slice(q_proj_tile, [1, half_dim], [ti, q_col + half_dim]), - [1, half_dim], - ) - rot_lo_bf16 = pl.cast( - pl.sub( - pl.col_expand_mul(q_lo, cos_lo), - pl.col_expand_mul(q_hi, sin_lo), - ), - target_type=pl.BF16, - ) - rot_hi_bf16 = pl.cast( - pl.add( - pl.col_expand_mul(q_hi, cos_hi), - pl.col_expand_mul(q_lo, sin_hi), - ), - target_type=pl.BF16, - ) - all_q_padded = pl.assemble( - all_q_padded, rot_lo_bf16, [ki * Q_HEAD_PAD + qi, 0] - ) - all_q_padded = pl.assemble( - all_q_padded, - rot_hi_bf16, - [ki * Q_HEAD_PAD + qi, half_dim], - ) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_rope_kv_cache"): + for ki_chunk in pl.parallel(0, num_kv_heads, 8): + with pl.at(level=pl.Level.CORE_GROUP): + for ki in pl.range(ki_chunk, ki_chunk + 8): + kv_col = ki * head_dim + k_lo = pl.reshape( + pl.slice(k_proj_tile, [1, half_dim], [ti, kv_col]), [1, half_dim] + ) + k_hi = pl.reshape( + pl.slice(k_proj_tile, [1, half_dim], [ti, kv_col + half_dim]), + [1, half_dim], + ) + rot_lo = pl.sub( + pl.col_expand_mul(k_lo, cos_lo), + pl.col_expand_mul(k_hi, sin_lo), + ) + rot_hi = pl.add( + pl.col_expand_mul(k_hi, cos_hi), + pl.col_expand_mul(k_lo, sin_hi), + ) + cache_row = ( + (cache_slot_block * num_kv_heads + ki) * BLOCK_SIZE + + cache_slot_offset + ) + k_cache_all = pl.assemble( + k_cache_all, + pl.cast(rot_lo, target_type=pl.BF16), + [layer_off_cache + cache_row, 0], + ) + k_cache_all = pl.assemble( + k_cache_all, + pl.cast(rot_hi, target_type=pl.BF16), + [layer_off_cache + cache_row, half_dim], + ) + v_cache_all = pl.assemble( + v_cache_all, + pl.cast( + pl.reshape( + pl.slice(v_proj_tile, [1, head_dim], [ti, ki * head_dim]), + [1, head_dim], + ), + target_type=pl.BF16, + ), + [layer_off_cache + cache_row, 0], + ) + q_base = ki * q_per_kv + for qi in pl.range(Q_HEAD_BATCH): + q_col = (q_base + qi) * head_dim + q_lo = pl.reshape( + pl.slice(q_proj_tile, [1, half_dim], [ti, q_col]), + [1, half_dim], + ) + q_hi = pl.reshape( + pl.slice(q_proj_tile, [1, half_dim], [ti, q_col + half_dim]), + [1, half_dim], + ) + rot_lo_bf16 = pl.cast( + pl.sub( + pl.col_expand_mul(q_lo, cos_lo), + pl.col_expand_mul(q_hi, sin_lo), + ), + target_type=pl.BF16, + ) + rot_hi_bf16 = pl.cast( + pl.add( + pl.col_expand_mul(q_hi, cos_hi), + pl.col_expand_mul(q_lo, sin_hi), + ), + target_type=pl.BF16, + ) + all_q_padded = pl.assemble( + all_q_padded, rot_lo_bf16, [ki * Q_HEAD_PAD + qi, 0] + ) + all_q_padded = pl.assemble( + all_q_padded, + rot_hi_bf16, + [ki * Q_HEAD_PAD + qi, half_dim], + ) attn_row = pl.create_tensor([1, hidden], dtype=pl.BF16) for gi in pl.range(total_q_groups): @@ -444,74 +456,80 @@ def qwen3_prefill_all( [max_ctx_blocks * Q_HEAD_BATCH_PAD, 1], dtype=pl.FP32 ) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="prefill_qk_matmul"): - for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): - block_table_idx = b * max_blocks_per_seq + sb - pbid = pl.cast( - pl.tensor.read(block_table, [block_table_idx]), pl.INDEX - ) - cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE - k_tile = pl.slice( - k_cache_all, - [SEQ_TILE, head_dim], - [layer_off_cache + cache_row0, 0], - ) - raw_scores = pl.matmul( - q_padded, k_tile, b_trans=True, out_dtype=pl.FP32 - ) - all_raw_scores = pl.assemble( - all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0] - ) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_qk_matmul"): + for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): + with pl.at(level=pl.Level.CORE_GROUP): + for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): + block_table_idx = b * max_blocks_per_seq + sb + pbid = pl.cast( + pl.tensor.read(block_table, [block_table_idx]), pl.INDEX + ) + cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE + k_tile = pl.slice( + k_cache_all, + [SEQ_TILE, head_dim], + [layer_off_cache + cache_row0, 0], + ) + raw_scores = pl.matmul( + q_padded, k_tile, b_trans=True, out_dtype=pl.FP32 + ) + all_raw_scores = pl.assemble( + all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0] + ) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="prefill_softmax"): - for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): - s0 = sb * SEQ_TILE - valid_len = pl.min(SEQ_TILE, ctx_len - s0) - scores_valid = pl.slice( - all_raw_scores, - [Q_HEAD_BATCH_PAD, SEQ_TILE], - [sb * Q_HEAD_PAD, 0], - valid_shape=[Q_HEAD_BATCH, valid_len], - ) - scores_padded = pl.fillpad( - scores_valid, pad_value=pl.PadValue.min - ) - scores = pl.mul(scores_padded, attn_scale) - cur_mi = pl.row_max(scores) - exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi)) - exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16) - cur_li = pl.row_sum( - pl.cast(exp_scores_bf16, target_type=pl.FP32) - ) - all_exp_padded = pl.assemble( - all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0] - ) - all_cur_mi = pl.assemble( - all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH_PAD, 0] - ) - all_cur_li = pl.assemble( - all_cur_li, cur_li, [sb * Q_HEAD_BATCH_PAD, 0] - ) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_softmax"): + for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): + with pl.at(level=pl.Level.CORE_GROUP): + for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): + s0 = sb * SEQ_TILE + valid_len = pl.min(SEQ_TILE, ctx_len - s0) + scores_valid = pl.slice( + all_raw_scores, + [Q_HEAD_BATCH_PAD, SEQ_TILE], + [sb * Q_HEAD_PAD, 0], + valid_shape=[Q_HEAD_BATCH, valid_len], + ) + scores_padded = pl.fillpad( + scores_valid, pad_value=pl.PadValue.min + ) + scores = pl.mul(scores_padded, attn_scale) + cur_mi = pl.row_max(scores) + exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi)) + exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16) + cur_li = pl.row_sum( + pl.cast(exp_scores_bf16, target_type=pl.FP32) + ) + all_exp_padded = pl.assemble( + all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0] + ) + all_cur_mi = pl.assemble( + all_cur_mi, cur_mi, [sb * Q_HEAD_BATCH_PAD, 0] + ) + all_cur_li = pl.assemble( + all_cur_li, cur_li, [sb * Q_HEAD_BATCH_PAD, 0] + ) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="prefill_sv_matmul"): - for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): - block_table_idx = b * max_blocks_per_seq + sb - pbid = pl.cast( - pl.tensor.read(block_table, [block_table_idx]), pl.INDEX - ) - cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE - exp_tile = pl.slice( - all_exp_padded, [Q_HEAD_PAD, SEQ_TILE], [sb * Q_HEAD_PAD, 0] - ) - v_tile = pl.slice( - v_cache_all, - [SEQ_TILE, head_dim], - [layer_off_cache + cache_row0, 0], - ) - oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32) - all_oi_tmp = pl.assemble( - all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0] - ) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_sv_matmul"): + for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): + with pl.at(level=pl.Level.CORE_GROUP): + for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): + block_table_idx = b * max_blocks_per_seq + sb + pbid = pl.cast( + pl.tensor.read(block_table, [block_table_idx]), pl.INDEX + ) + cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE + exp_tile = pl.slice( + all_exp_padded, [Q_HEAD_PAD, SEQ_TILE], [sb * Q_HEAD_PAD, 0] + ) + v_tile = pl.slice( + v_cache_all, + [SEQ_TILE, head_dim], + [layer_off_cache + cache_row0, 0], + ) + oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32) + all_oi_tmp = pl.assemble( + all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0] + ) with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_online_softmax_init"): oi = pl.full([Q_HEAD_BATCH_PAD, head_dim], dtype=pl.FP32, value=0.0) @@ -669,7 +687,7 @@ def qwen3_prefill_all( ) up_acc = pl.matmul_acc(up_acc, pci, wui) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="prefill_silu"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_silu"): sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_acc)), 1.0)) mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc) mlp_silu_tile = pl.assemble( @@ -832,59 +850,63 @@ def qwen3_decode_all( [0, k0], ) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="decode_q_proj"): - for ob in pl.parallel(q_out_blocks, chunk=4): - q0 = ob * Q_OUT_CHUNK - tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0]) - tile_b = pl.slice( - wq, [SCOPE1_K_CHUNK, Q_OUT_CHUNK], [layer_off_h, q0] - ) - q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) - for kb in pl.range(1, scope1_hidden_blocks): - k0 = kb * SCOPE1_K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0]) - tile_b_i = pl.slice( - wq, - [SCOPE1_K_CHUNK, Q_OUT_CHUNK], - [layer_off_h + k0, q0], - ) - q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i) - q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) - - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="decode_kv_proj"): - for ob in pl.parallel(kv_out_blocks, chunk=4): - kv0 = ob * KV_OUT_CHUNK - tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0]) - tile_wk = pl.slice( - wk, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [layer_off_h, kv0] - ) - k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) - for kb in pl.range(1, scope1_hidden_blocks): - k0 = kb * SCOPE1_K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0]) - tile_wk_i = pl.slice( - wk, - [SCOPE1_K_CHUNK, KV_OUT_CHUNK], - [layer_off_h + k0, kv0], - ) - k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i) - k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="decode_q_proj"): + for ob_chunk in pl.parallel(0, q_out_blocks, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 4): + q0 = ob * Q_OUT_CHUNK + tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0]) + tile_b = pl.slice( + wq, [SCOPE1_K_CHUNK, Q_OUT_CHUNK], [layer_off_h, q0] + ) + q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) + for kb in pl.range(1, scope1_hidden_blocks): + k0 = kb * SCOPE1_K_CHUNK + tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0]) + tile_b_i = pl.slice( + wq, + [SCOPE1_K_CHUNK, Q_OUT_CHUNK], + [layer_off_h + k0, q0], + ) + q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_b_i) + q_proj = pl.assemble(q_proj, q_acc, [b0, q0]) + + with pl.at(level=pl.Level.CORE_GROUP, name_hint="decode_kv_proj"): + for ob_chunk in pl.parallel(0, kv_out_blocks, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 4): + kv0 = ob * KV_OUT_CHUNK + tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0]) + tile_wk = pl.slice( + wk, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [layer_off_h, kv0] + ) + k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) + for kb in pl.range(1, scope1_hidden_blocks): + k0 = kb * SCOPE1_K_CHUNK + tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0]) + tile_wk_i = pl.slice( + wk, + [SCOPE1_K_CHUNK, KV_OUT_CHUNK], + [layer_off_h + k0, kv0], + ) + k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i) + k_proj = pl.assemble(k_proj, k_acc, [b0, kv0]) - tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0]) - tile_wv = pl.slice( - wv, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [layer_off_h, kv0] - ) - v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) - for kb in pl.range(1, scope1_hidden_blocks): - k0 = kb * SCOPE1_K_CHUNK - tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0]) - tile_wv_i = pl.slice( - wv, - [SCOPE1_K_CHUNK, KV_OUT_CHUNK], - [layer_off_h + k0, kv0], - ) - v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i) - v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) + tile_a = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, 0]) + tile_wv = pl.slice( + wv, [SCOPE1_K_CHUNK, KV_OUT_CHUNK], [layer_off_h, kv0] + ) + v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) + for kb in pl.range(1, scope1_hidden_blocks): + k0 = kb * SCOPE1_K_CHUNK + tile_a_i = pl.slice(normed_tile, [BATCH_TILE, SCOPE1_K_CHUNK], [0, k0]) + tile_wv_i = pl.slice( + wv, + [SCOPE1_K_CHUNK, KV_OUT_CHUNK], + [layer_off_h + k0, kv0], + ) + v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i) + v_proj = pl.assemble(v_proj, v_acc, [b0, kv0]) # HF-style per-head Q/K norm before RoPE. for b0 in pl.parallel(0, batch_padded, BATCH_TILE): @@ -942,67 +964,69 @@ def qwen3_decode_all( sin_lo = pl.slice(sin_row, [1, half_dim], [0, 0]) sin_hi = pl.slice(sin_row, [1, half_dim], [0, half_dim]) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="decode_rope_kv_cache"): - for ki in pl.parallel(0, num_kv_heads, chunk=8): - kv_col = ki * head_dim - cache_row = (slot_block * num_kv_heads + ki) * BLOCK_SIZE + slot_offset - k_lo = pl.slice(k_proj_norm, [1, half_dim], [b, kv_col]) - k_hi = pl.slice(k_proj_norm, [1, half_dim], [b, kv_col + half_dim]) - rot_lo = pl.sub( - pl.col_expand_mul(k_lo, cos_lo), - pl.col_expand_mul(k_hi, sin_lo), - ) - rot_hi = pl.add( - pl.col_expand_mul(k_hi, cos_hi), - pl.col_expand_mul(k_lo, sin_hi), - ) - k_cache_all = pl.assemble( - k_cache_all, - pl.cast(rot_lo, target_type=pl.BF16), - [layer_off_cache + cache_row, 0], - ) - k_cache_all = pl.assemble( - k_cache_all, - pl.cast(rot_hi, target_type=pl.BF16), - [layer_off_cache + cache_row, half_dim], - ) - v_cache_all = pl.assemble( - v_cache_all, - pl.cast( - pl.slice(v_proj, [1, head_dim], [b, kv_col]), - target_type=pl.BF16, - ), - [layer_off_cache + cache_row, 0], - ) - q_base = ki * q_per_kv - for qi in pl.range(Q_HEAD_BATCH): - q_col = (q_base + qi) * head_dim - q_lo = pl.slice(q_proj_norm, [1, half_dim], [b, q_col]) - q_hi = pl.slice(q_proj_norm, [1, half_dim], [b, q_col + half_dim]) - rot_lo_bf16 = pl.cast( - pl.sub( - pl.col_expand_mul(q_lo, cos_lo), - pl.col_expand_mul(q_hi, sin_lo), - ), - target_type=pl.BF16, - ) - rot_hi_bf16 = pl.cast( - pl.add( - pl.col_expand_mul(q_hi, cos_hi), - pl.col_expand_mul(q_lo, sin_hi), - ), - target_type=pl.BF16, - ) - all_q_padded = pl.assemble( - all_q_padded, - rot_lo_bf16, - [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, 0], - ) - all_q_padded = pl.assemble( - all_q_padded, - rot_hi_bf16, - [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, half_dim], - ) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="decode_rope_kv_cache"): + for ki_chunk in pl.parallel(0, num_kv_heads, 8): + with pl.at(level=pl.Level.CORE_GROUP): + for ki in pl.range(ki_chunk, ki_chunk + 8): + kv_col = ki * head_dim + cache_row = (slot_block * num_kv_heads + ki) * BLOCK_SIZE + slot_offset + k_lo = pl.slice(k_proj_norm, [1, half_dim], [b, kv_col]) + k_hi = pl.slice(k_proj_norm, [1, half_dim], [b, kv_col + half_dim]) + rot_lo = pl.sub( + pl.col_expand_mul(k_lo, cos_lo), + pl.col_expand_mul(k_hi, sin_lo), + ) + rot_hi = pl.add( + pl.col_expand_mul(k_hi, cos_hi), + pl.col_expand_mul(k_lo, sin_hi), + ) + k_cache_all = pl.assemble( + k_cache_all, + pl.cast(rot_lo, target_type=pl.BF16), + [layer_off_cache + cache_row, 0], + ) + k_cache_all = pl.assemble( + k_cache_all, + pl.cast(rot_hi, target_type=pl.BF16), + [layer_off_cache + cache_row, half_dim], + ) + v_cache_all = pl.assemble( + v_cache_all, + pl.cast( + pl.slice(v_proj, [1, head_dim], [b, kv_col]), + target_type=pl.BF16, + ), + [layer_off_cache + cache_row, 0], + ) + q_base = ki * q_per_kv + for qi in pl.range(Q_HEAD_BATCH): + q_col = (q_base + qi) * head_dim + q_lo = pl.slice(q_proj_norm, [1, half_dim], [b, q_col]) + q_hi = pl.slice(q_proj_norm, [1, half_dim], [b, q_col + half_dim]) + rot_lo_bf16 = pl.cast( + pl.sub( + pl.col_expand_mul(q_lo, cos_lo), + pl.col_expand_mul(q_hi, sin_lo), + ), + target_type=pl.BF16, + ) + rot_hi_bf16 = pl.cast( + pl.add( + pl.col_expand_mul(q_hi, cos_hi), + pl.col_expand_mul(q_lo, sin_hi), + ), + target_type=pl.BF16, + ) + all_q_padded = pl.assemble( + all_q_padded, + rot_lo_bf16, + [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, 0], + ) + all_q_padded = pl.assemble( + all_q_padded, + rot_hi_bf16, + [b * total_q_groups * Q_HEAD_PAD + ki * Q_HEAD_PAD + qi, half_dim], + ) attn_row = pl.create_tensor([1, hidden], dtype=pl.BF16) attn_row_padded = pl.create_tensor( @@ -1025,7 +1049,7 @@ def qwen3_decode_all( [total_q_groups * max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32, ) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="decode_qk_matmul"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="decode_qk_matmul"): for gi in pl.range(total_q_groups): kvh = gi // q_groups q_padded = pl.slice( @@ -1033,74 +1057,80 @@ def qwen3_decode_all( [Q_HEAD_PAD, head_dim], [b * total_q_groups * Q_HEAD_PAD + gi * Q_HEAD_PAD, 0], ) - for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): - block_table_idx = block_table_base + sb - pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX) - cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE - k_tile = pl.slice( - k_cache_all, - [BLOCK_SIZE, head_dim], - [layer_off_cache + cache_row0, 0], - ) - raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32) - all_raw_scores = pl.assemble( - all_raw_scores, raw_scores, - [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], - ) + for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): + with pl.at(level=pl.Level.CORE_GROUP): + for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): + block_table_idx = block_table_base + sb + pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX) + cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE + k_tile = pl.slice( + k_cache_all, + [BLOCK_SIZE, head_dim], + [layer_off_cache + cache_row0, 0], + ) + raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32) + all_raw_scores = pl.assemble( + all_raw_scores, raw_scores, + [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], + ) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="decode_softmax"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="decode_softmax"): for gi in pl.range(total_q_groups): - for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): - s0 = sb * BLOCK_SIZE - valid_len = pl.min(BLOCK_SIZE, ctx_len - s0) - scores_valid = pl.slice( - all_raw_scores, - [Q_HEAD_PAD, BLOCK_SIZE], - [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], - valid_shape=[Q_HEAD_PAD, valid_len], - ) - scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min) - scores = pl.mul(scores_padded, attn_scale) - cur_mi = pl.row_max(scores) - exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi)) - exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16) - exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32) - cur_li = pl.row_sum(exp_scores_fp32) - all_exp_padded = pl.assemble( - all_exp_padded, exp_scores_bf16, - [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], - ) - all_cur_mi = pl.assemble( - all_cur_mi, cur_mi, - [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], - ) - all_cur_li = pl.assemble( - all_cur_li, cur_li, - [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], - ) + for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): + with pl.at(level=pl.Level.CORE_GROUP): + for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): + s0 = sb * BLOCK_SIZE + valid_len = pl.min(BLOCK_SIZE, ctx_len - s0) + scores_valid = pl.slice( + all_raw_scores, + [Q_HEAD_PAD, BLOCK_SIZE], + [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], + valid_shape=[Q_HEAD_PAD, valid_len], + ) + scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min) + scores = pl.mul(scores_padded, attn_scale) + cur_mi = pl.row_max(scores) + exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi)) + exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16) + exp_scores_fp32 = pl.cast(exp_scores_bf16, target_type=pl.FP32) + cur_li = pl.row_sum(exp_scores_fp32) + all_exp_padded = pl.assemble( + all_exp_padded, exp_scores_bf16, + [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], + ) + all_cur_mi = pl.assemble( + all_cur_mi, cur_mi, + [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], + ) + all_cur_li = pl.assemble( + all_cur_li, cur_li, + [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], + ) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="decode_sv_matmul"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="decode_sv_matmul"): for gi in pl.range(total_q_groups): kvh = gi // q_groups - for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): - block_table_idx = block_table_base + sb - pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX) - cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE - exp_tile = pl.slice( - all_exp_padded, - [Q_HEAD_PAD, BLOCK_SIZE], - [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], - ) - v_tile = pl.slice( - v_cache_all, - [BLOCK_SIZE, head_dim], - [layer_off_cache + cache_row0, 0], - ) - oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32) - all_oi_tmp = pl.assemble( - all_oi_tmp, oi_tmp, - [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], - ) + for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): + with pl.at(level=pl.Level.CORE_GROUP): + for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): + block_table_idx = block_table_base + sb + pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX) + cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE + exp_tile = pl.slice( + all_exp_padded, + [Q_HEAD_PAD, BLOCK_SIZE], + [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], + ) + v_tile = pl.slice( + v_cache_all, + [BLOCK_SIZE, head_dim], + [layer_off_cache + cache_row0, 0], + ) + oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32) + all_oi_tmp = pl.assemble( + all_oi_tmp, oi_tmp, + [(gi * max_ctx_blocks + sb) * Q_HEAD_PAD, 0], + ) with pl.at(level=pl.Level.CORE_GROUP, name_hint="decode_online_softmax"): for gi in pl.range(total_q_groups): @@ -1239,7 +1269,7 @@ def qwen3_decode_all( ) up_acc = pl.matmul_acc(up_acc, post_chunk, wu) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="decode_silu"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="decode_silu"): sigmoid = pl.recip(pl.add(pl.exp(pl.neg(gate_acc)), 1.0)) mlp_chunk = pl.mul(pl.mul(gate_acc, sigmoid), up_acc) mlp_chunk_bf16 = pl.cast(mlp_chunk, target_type=pl.BF16) @@ -1353,21 +1383,23 @@ def qwen3_lm_head( lm_head_weight: pl.Tensor[[padded_vocab, hidden], pl.BF16], out: pl.Out[pl.Tensor[[BATCH_TILE, padded_vocab], pl.FP32]], ) -> pl.Tensor[[BATCH_TILE, padded_vocab], pl.FP32]: - with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk], name_hint="lm_head"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="lm_head"): for b0 in pl.range(0, BATCH_TILE, BATCH_TILE): - for ob in pl.parallel(vocab_blocks, chunk=8): - o0 = ob * VOCAB_CHUNK - h0 = pl.slice(hidden_in, [BATCH_TILE, K_CHUNK], [b0, 0]) - w0 = pl.slice(lm_head_weight, [VOCAB_CHUNK, K_CHUNK], [o0, 0]) - acc = pl.matmul(h0, w0, out_dtype=pl.FP32, b_trans=True) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - h_chunk = pl.slice(hidden_in, [BATCH_TILE, K_CHUNK], [b0, k0]) - w_chunk = pl.slice( - lm_head_weight, [VOCAB_CHUNK, K_CHUNK], [o0, k0] - ) - acc = pl.matmul_acc(acc, h_chunk, w_chunk, b_trans=True) - out = pl.assemble(out, acc, [b0, o0]) + for ob_chunk in pl.parallel(0, vocab_blocks, 8): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 8): + o0 = ob * VOCAB_CHUNK + h0 = pl.slice(hidden_in, [BATCH_TILE, K_CHUNK], [b0, 0]) + w0 = pl.slice(lm_head_weight, [VOCAB_CHUNK, K_CHUNK], [o0, 0]) + acc = pl.matmul(h0, w0, out_dtype=pl.FP32, b_trans=True) + for kb in pl.range(1, hidden_blocks): + k0 = kb * K_CHUNK + h_chunk = pl.slice(hidden_in, [BATCH_TILE, K_CHUNK], [b0, k0]) + w_chunk = pl.slice( + lm_head_weight, [VOCAB_CHUNK, K_CHUNK], [o0, k0] + ) + acc = pl.matmul_acc(acc, h_chunk, w_chunk, b_trans=True) + out = pl.assemble(out, acc, [b0, o0]) return out # ── L2: fused final-RMSNorm + LM-head (single chip task) ─────────────── @@ -1425,23 +1457,25 @@ def qwen3_rms_lmhead( ) # Phase 2 – LM-head GEMM: reads rms_normed written above (HBM). # Body identical to qwen3_lm_head with hidden_in → rms_normed. - with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk], name_hint="rms_lmhead_lm_head"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rms_lmhead_lm_head"): for b0 in pl.range(0, BATCH_TILE, BATCH_TILE): - for ob in pl.parallel(vocab_blocks, chunk=8): - o0 = ob * VOCAB_CHUNK - h0 = pl.slice(rms_normed, [BATCH_TILE, K_CHUNK], [b0, 0]) - w0 = pl.slice(lm_head_weight, [VOCAB_CHUNK, K_CHUNK], [o0, 0]) - acc = pl.matmul(h0, w0, out_dtype=pl.FP32, b_trans=True) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - h_chunk = pl.slice( - rms_normed, [BATCH_TILE, K_CHUNK], [b0, k0] - ) - w_chunk = pl.slice( - lm_head_weight, [VOCAB_CHUNK, K_CHUNK], [o0, k0] - ) - acc = pl.matmul_acc(acc, h_chunk, w_chunk, b_trans=True) - out = pl.assemble(out, acc, [b0, o0]) + for ob_chunk in pl.parallel(0, vocab_blocks, 8): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 8): + o0 = ob * VOCAB_CHUNK + h0 = pl.slice(rms_normed, [BATCH_TILE, K_CHUNK], [b0, 0]) + w0 = pl.slice(lm_head_weight, [VOCAB_CHUNK, K_CHUNK], [o0, 0]) + acc = pl.matmul(h0, w0, out_dtype=pl.FP32, b_trans=True) + for kb in pl.range(1, hidden_blocks): + k0 = kb * K_CHUNK + h_chunk = pl.slice( + rms_normed, [BATCH_TILE, K_CHUNK], [b0, k0] + ) + w_chunk = pl.slice( + lm_head_weight, [VOCAB_CHUNK, K_CHUNK], [o0, k0] + ) + acc = pl.matmul_acc(acc, h_chunk, w_chunk, b_trans=True) + out = pl.assemble(out, acc, [b0, o0]) return rms_normed, out # ── HOST SubWorker: sample & prepare next decode inputs ───────────────── diff --git a/models/qwen3/32b/qwen3_32b_decode.py b/models/qwen3/32b/qwen3_32b_decode.py index d503893c..19065e19 100644 --- a/models/qwen3/32b/qwen3_32b_decode.py +++ b/models/qwen3/32b/qwen3_32b_decode.py @@ -331,7 +331,7 @@ def qwen3_decode( # Stage 1 & 2: Output projection + residual addition with hidden_states. for ob in pl.parallel(0, HIDDEN // Q_OUT_CHUNK, 2): - with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.UP_DOWN)], name_hint="out_proj_residual"): + with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.split(pl.SplitMode.UP_DOWN)], name_hint="out_proj_residual"): for oi in pl.range(ob, ob + 2): o0 = oi * Q_OUT_CHUNK hidden_chunk = hidden_states[:, o0 : o0 + Q_OUT_CHUNK] @@ -422,7 +422,7 @@ def qwen3_decode( # Stage 7 & 8: Down projection + final residual writeback. for db in pl.parallel(0, HIDDEN // DOWN_N_CHUNK, 2): - with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.auto_chunk, pl.split(pl.SplitMode.UP_DOWN)], name_hint="down_proj_residual"): + with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.split(pl.SplitMode.UP_DOWN)], name_hint="down_proj_residual"): for di in pl.range(db, db + 2): d0 = di * DOWN_N_CHUNK resid1_tile_chunk = resid1_tile[:, d0 : d0 + DOWN_N_CHUNK] diff --git a/models/qwen3/32b/qwen3_32b_prefill_draft.py b/models/qwen3/32b/qwen3_32b_prefill_draft.py index 430c15e1..02186a0d 100644 --- a/models/qwen3/32b/qwen3_32b_prefill_draft.py +++ b/models/qwen3/32b/qwen3_32b_prefill_draft.py @@ -136,45 +136,47 @@ def prefill_scope123( # Stage 1.2: Q projection (matmul + matmul_acc, FP32 output). q_proj_tile = pl.create_tensor([TOK_TILE, hidden], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for ob in pl.parallel(q_out_blocks, chunk=4): - q0 = ob * Q_OUT_CHUNK - tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) - tile_w = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0]) - q_acc = pl.matmul(tile_a, tile_w, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) - tile_w_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0]) - q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_w_i) - q_proj_tile = pl.assemble(q_proj_tile, q_acc, [0, q0]) + for ob_chunk in pl.parallel(0, q_out_blocks, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 4): + q0 = ob * Q_OUT_CHUNK + tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) + tile_w = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [0, q0]) + q_acc = pl.matmul(tile_a, tile_w, out_dtype=pl.FP32) + for kb in pl.range(1, hidden_blocks): + k0 = kb * K_CHUNK + tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) + tile_w_i = pl.slice(wq, [K_CHUNK, Q_OUT_CHUNK], [k0, q0]) + q_acc = pl.matmul_acc(q_acc, tile_a_i, tile_w_i) + q_proj_tile = pl.assemble(q_proj_tile, q_acc, [0, q0]) # Stage 1.3: K/V projection (matmul + matmul_acc in single incore). k_proj_tile = pl.create_tensor([TOK_TILE, kv_hidden], dtype=pl.FP32) v_proj_tile = pl.create_tensor([TOK_TILE, kv_hidden], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for ob in pl.parallel(kv_out_blocks, chunk=4): - kv0 = ob * KV_OUT_CHUNK - - tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) - tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) - k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) - tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) - k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i) - k_proj_tile = pl.assemble(k_proj_tile, k_acc, [0, kv0]) - - tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) - tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) - v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) - for kb in pl.range(1, hidden_blocks): - k0 = kb * K_CHUNK - tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) - tile_wv_i = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) - v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i) - v_proj_tile = pl.assemble(v_proj_tile, v_acc, [0, kv0]) + for ob_chunk in pl.parallel(0, kv_out_blocks, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for ob in pl.range(ob_chunk, ob_chunk + 4): + kv0 = ob * KV_OUT_CHUNK + + tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) + tile_wk = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) + k_acc = pl.matmul(tile_a, tile_wk, out_dtype=pl.FP32) + for kb in pl.range(1, hidden_blocks): + k0 = kb * K_CHUNK + tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) + tile_wk_i = pl.slice(wk, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) + k_acc = pl.matmul_acc(k_acc, tile_a_i, tile_wk_i) + k_proj_tile = pl.assemble(k_proj_tile, k_acc, [0, kv0]) + + tile_a = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, 0]) + tile_wv = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [0, kv0]) + v_acc = pl.matmul(tile_a, tile_wv, out_dtype=pl.FP32) + for kb in pl.range(1, hidden_blocks): + k0 = kb * K_CHUNK + tile_a_i = pl.slice(normed_tile, [TOK_TILE, K_CHUNK], [0, k0]) + tile_wv_i = pl.slice(wv, [K_CHUNK, KV_OUT_CHUNK], [k0, kv0]) + v_acc = pl.matmul_acc(v_acc, tile_a_i, tile_wv_i) + v_proj_tile = pl.assemble(v_proj_tile, v_acc, [0, kv0]) # ── Scope 2: RoPE + KV cache update + causal attention ── attn_tile = pl.create_tensor([TOK_TILE, hidden], dtype=pl.BF16) @@ -198,62 +200,63 @@ def prefill_scope123( pl.cast(pl.full([Q_HEAD_PAD, head_dim], dtype=pl.FP32, value=0.0), target_type=pl.BF16), [gi * Q_HEAD_PAD, 0], ) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for ki in pl.parallel(0, num_kv_heads, chunk=8): - # K RoPE + cache update. - kv_col = ki * head_dim - k_lo = pl.reshape(pl.slice(k_proj_tile, [1, half_dim], [ti, kv_col]), [1, half_dim]) - k_hi = pl.reshape(pl.slice(k_proj_tile, [1, half_dim], [ti, kv_col + half_dim]), [1, half_dim]) - rot_lo = pl.sub( - pl.col_expand_mul(k_lo, cos_lo), - pl.col_expand_mul(k_hi, sin_lo), - ) - rot_hi = pl.add( - pl.col_expand_mul(k_hi, cos_hi), - pl.col_expand_mul(k_lo, sin_hi), - ) - cache_row = b * num_kv_heads * max_seq + ki * max_seq + pos - k_cache = pl.assemble( - k_cache, - pl.cast(rot_lo, target_type=pl.BF16), - [cache_row, 0], - ) - k_cache = pl.assemble( - k_cache, - pl.cast(rot_hi, target_type=pl.BF16), - [cache_row, half_dim], - ) - # V cache update. - v_cache = pl.assemble( - v_cache, - pl.cast( - pl.reshape(pl.slice(v_proj_tile, [1, head_dim], [ti, ki * head_dim]), [1, head_dim]), - target_type=pl.BF16, - ), - [cache_row, 0], - ) - # Q RoPE + pad. - q_base = ki * q_per_kv - for qi in pl.range(Q_HEAD_BATCH): - q_col = (q_base + qi) * head_dim - q_lo = pl.reshape(pl.slice(q_proj_tile, [1, half_dim], [ti, q_col]), [1, half_dim]) - q_hi = pl.reshape(pl.slice(q_proj_tile, [1, half_dim], [ti, q_col + half_dim]), [1, half_dim]) - rot_lo_bf16 = pl.cast( - pl.sub( - pl.col_expand_mul(q_lo, cos_lo), - pl.col_expand_mul(q_hi, sin_lo), - ), - target_type=pl.BF16, + for ki_chunk in pl.parallel(0, num_kv_heads, 8): + with pl.at(level=pl.Level.CORE_GROUP): + for ki in pl.range(ki_chunk, ki_chunk + 8): + # K RoPE + cache update. + kv_col = ki * head_dim + k_lo = pl.reshape(pl.slice(k_proj_tile, [1, half_dim], [ti, kv_col]), [1, half_dim]) + k_hi = pl.reshape(pl.slice(k_proj_tile, [1, half_dim], [ti, kv_col + half_dim]), [1, half_dim]) + rot_lo = pl.sub( + pl.col_expand_mul(k_lo, cos_lo), + pl.col_expand_mul(k_hi, sin_lo), ) - rot_hi_bf16 = pl.cast( - pl.add( - pl.col_expand_mul(q_hi, cos_hi), - pl.col_expand_mul(q_lo, sin_hi), + rot_hi = pl.add( + pl.col_expand_mul(k_hi, cos_hi), + pl.col_expand_mul(k_lo, sin_hi), + ) + cache_row = b * num_kv_heads * max_seq + ki * max_seq + pos + k_cache = pl.assemble( + k_cache, + pl.cast(rot_lo, target_type=pl.BF16), + [cache_row, 0], + ) + k_cache = pl.assemble( + k_cache, + pl.cast(rot_hi, target_type=pl.BF16), + [cache_row, half_dim], + ) + # V cache update. + v_cache = pl.assemble( + v_cache, + pl.cast( + pl.reshape(pl.slice(v_proj_tile, [1, head_dim], [ti, ki * head_dim]), [1, head_dim]), + target_type=pl.BF16, ), - target_type=pl.BF16, + [cache_row, 0], ) - all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [ki * Q_HEAD_PAD + qi, 0]) - all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [ki * Q_HEAD_PAD + qi, half_dim]) + # Q RoPE + pad. + q_base = ki * q_per_kv + for qi in pl.range(Q_HEAD_BATCH): + q_col = (q_base + qi) * head_dim + q_lo = pl.reshape(pl.slice(q_proj_tile, [1, half_dim], [ti, q_col]), [1, half_dim]) + q_hi = pl.reshape(pl.slice(q_proj_tile, [1, half_dim], [ti, q_col + half_dim]), [1, half_dim]) + rot_lo_bf16 = pl.cast( + pl.sub( + pl.col_expand_mul(q_lo, cos_lo), + pl.col_expand_mul(q_hi, sin_lo), + ), + target_type=pl.BF16, + ) + rot_hi_bf16 = pl.cast( + pl.add( + pl.col_expand_mul(q_hi, cos_hi), + pl.col_expand_mul(q_lo, sin_hi), + ), + target_type=pl.BF16, + ) + all_q_padded = pl.assemble(all_q_padded, rot_lo_bf16, [ki * Q_HEAD_PAD + qi, 0]) + all_q_padded = pl.assemble(all_q_padded, rot_hi_bf16, [ki * Q_HEAD_PAD + qi, half_dim]) attn_row = pl.create_tensor([1, hidden], dtype=pl.BF16) @@ -271,43 +274,46 @@ def prefill_scope123( all_cur_li = pl.create_tensor([max_ctx_blocks * Q_HEAD_PAD, 1], dtype=pl.FP32) # Stage 2.2: QK matmul for all active sb blocks. - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): - s0 = sb * SEQ_TILE - cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0 - k_tile = pl.slice(k_cache, [SEQ_TILE, head_dim], [cache_row0, 0]) - raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32) - all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0]) + for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): + with pl.at(level=pl.Level.CORE_GROUP): + for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): + s0 = sb * SEQ_TILE + cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0 + k_tile = pl.slice(k_cache, [SEQ_TILE, head_dim], [cache_row0, 0]) + raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32) + all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0]) # Stage 2.3: softmax for all active sb blocks. - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): - s0 = sb * SEQ_TILE - valid_len = pl.min(SEQ_TILE, ctx_len - s0) - scores_valid = pl.slice( - all_raw_scores, [Q_HEAD_PAD, SEQ_TILE], - [sb * Q_HEAD_PAD, 0], - valid_shape=[Q_HEAD_BATCH, valid_len], - ) - scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min) - scores = pl.mul(scores_padded, attn_scale) - cur_mi = pl.row_max(scores) - exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi)) - exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16) - cur_li = pl.row_sum(pl.cast(exp_scores_bf16, target_type=pl.FP32)) - all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0]) - all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_PAD, 0]) - all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_PAD, 0]) + for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): + with pl.at(level=pl.Level.CORE_GROUP): + for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): + s0 = sb * SEQ_TILE + valid_len = pl.min(SEQ_TILE, ctx_len - s0) + scores_valid = pl.slice( + all_raw_scores, [Q_HEAD_PAD, SEQ_TILE], + [sb * Q_HEAD_PAD, 0], + valid_shape=[Q_HEAD_BATCH, valid_len], + ) + scores_padded = pl.fillpad(scores_valid, pad_value=pl.PadValue.min) + scores = pl.mul(scores_padded, attn_scale) + cur_mi = pl.row_max(scores) + exp_scores = pl.exp(pl.row_expand_sub(scores, cur_mi)) + exp_scores_bf16 = pl.cast(exp_scores, target_type=pl.BF16) + cur_li = pl.row_sum(pl.cast(exp_scores_bf16, target_type=pl.FP32)) + all_exp_padded = pl.assemble(all_exp_padded, exp_scores_bf16, [sb * Q_HEAD_PAD, 0]) + all_cur_mi = pl.assemble(all_cur_mi, cur_mi, [sb * Q_HEAD_PAD, 0]) + all_cur_li = pl.assemble(all_cur_li, cur_li, [sb * Q_HEAD_PAD, 0]) # Stage 2.4: SV matmul for all active sb blocks. - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for sb in pl.parallel(ctx_blocks, chunk=SB_BATCH): - s0 = sb * SEQ_TILE - cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0 - exp_tile = pl.slice(all_exp_padded, [Q_HEAD_PAD, SEQ_TILE], [sb * Q_HEAD_PAD, 0]) - v_tile = pl.slice(v_cache, [SEQ_TILE, head_dim], [cache_row0, 0]) - oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32) - all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0]) + for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): + with pl.at(level=pl.Level.CORE_GROUP): + for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): + s0 = sb * SEQ_TILE + cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0 + exp_tile = pl.slice(all_exp_padded, [Q_HEAD_PAD, SEQ_TILE], [sb * Q_HEAD_PAD, 0]) + v_tile = pl.slice(v_cache, [SEQ_TILE, head_dim], [cache_row0, 0]) + oi_tmp = pl.matmul(exp_tile, v_tile, out_dtype=pl.FP32) + all_oi_tmp = pl.assemble(all_oi_tmp, oi_tmp, [sb * Q_HEAD_PAD, 0]) # Stage 2.5: online softmax accumulation. with pl.at(level=pl.Level.CORE_GROUP): @@ -429,18 +435,19 @@ def prefill_scope123( # Stage 3.3b: Down projection (matmul_acc chain over intermediate dim). down_fp32_tile = pl.create_tensor([TOK_TILE, hidden], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer): - for dob in pl.parallel(hidden_blocks, chunk=4): - d0 = dob * K_CHUNK - dn_a = pl.slice(mlp_silu_tile, [TOK_TILE, MLP_OUT_CHUNK], [0, 0]) - dn_w = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [0, d0]) - down_acc = pl.matmul(dn_a, dn_w, out_dtype=pl.FP32) - for ob in pl.range(1, mlp_out_blocks): - o0 = ob * MLP_OUT_CHUNK - dn_a_i = pl.slice(mlp_silu_tile, [TOK_TILE, MLP_OUT_CHUNK], [0, o0]) - dn_w_i = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [o0, d0]) - down_acc = pl.matmul_acc(down_acc, dn_a_i, dn_w_i) - down_fp32_tile = pl.assemble(down_fp32_tile, down_acc, [0, d0]) + for dob_chunk in pl.parallel(0, hidden_blocks, 4): + with pl.at(level=pl.Level.CORE_GROUP): + for dob in pl.range(dob_chunk, dob_chunk + 4): + d0 = dob * K_CHUNK + dn_a = pl.slice(mlp_silu_tile, [TOK_TILE, MLP_OUT_CHUNK], [0, 0]) + dn_w = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [0, d0]) + down_acc = pl.matmul(dn_a, dn_w, out_dtype=pl.FP32) + for ob in pl.range(1, mlp_out_blocks): + o0 = ob * MLP_OUT_CHUNK + dn_a_i = pl.slice(mlp_silu_tile, [TOK_TILE, MLP_OUT_CHUNK], [0, o0]) + dn_w_i = pl.slice(w_down, [MLP_OUT_CHUNK, K_CHUNK], [o0, d0]) + down_acc = pl.matmul_acc(down_acc, dn_a_i, dn_w_i) + down_fp32_tile = pl.assemble(down_fp32_tile, down_acc, [0, d0]) # Stage 3.4: Final residual add -> BF16 output. for ob in pl.range(hidden_blocks):