diff --git a/models/deepseek/v4/qkv_proj_rope.py b/models/deepseek/v4/qkv_proj_rope.py index 3939afd2..dc2c0b83 100644 --- a/models/deepseek/v4/qkv_proj_rope.py +++ b/models/deepseek/v4/qkv_proj_rope.py @@ -49,6 +49,8 @@ HEAD_GROUP_BLOCKS = (H * HEAD_DIM) // (HEAD_CHUNK * HEAD_GROUP) D_BLOCKS = D // D_CHUNK KV_BLOCKS = HEAD_DIM // KV_CHUNK +RMS_T_TILES = 2 +T_TILE = T // RMS_T_TILES @pl.jit.inline @@ -72,241 +74,236 @@ def qkv_proj_rope( ): x_flat = pl.reshape(x, [T, D]) - # Stage 0.1: fused attn_norm -> token_x_fp32 - token_x_fp32 = pl.create_tensor([T, D], dtype=pl.FP32) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="attn_norm_rms"): + # Stage 0: fused attn_norm (rms + apply + bf16 cast) -> token_x_bf16 for the split AIV->AIC flow. + token_x_bf16 = pl.create_tensor([T, D], dtype=pl.BF16) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="attn_norm"): x_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) for db in pl.range(D_BLOCKS): d0 = db * D_CHUNK - x_chunk = pl.cast(pl.slice(x_flat, [T, D_CHUNK], [0, d0]), target_type=pl.FP32) + x_chunk = pl.cast(x_flat[:, d0 : d0 + D_CHUNK], target_type=pl.FP32) x_sq_sum = pl.add(x_sq_sum, pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, T])) x_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(x_sq_sum, 1.0 / D), EPS))) - - x_inv_rms_t = pl.reshape(x_inv_rms, [T, 1]) - for db in pl.parallel(0, D_BLOCKS, 1): - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="attn_norm_apply"): - d0 = db * D_CHUNK - x_chunk = pl.cast(pl.slice(x_flat, [T, D_CHUNK], [0, d0]), target_type=pl.FP32) - norm_w_chunk = pl.reshape(pl.slice(norm_w, [D_CHUNK], [d0]), [1, D_CHUNK]) - x_normed = pl.col_expand_mul(pl.row_expand_mul(x_chunk, x_inv_rms_t), norm_w_chunk) - token_x_fp32 = pl.assemble(token_x_fp32, x_normed, [0, d0]) - - # Stage 0.2: pre-cast token_x for split AIV->AIC flow. - token_x_bf16 = pl.create_tensor([T, D], dtype=pl.BF16) - for db in pl.parallel(0, D_BLOCKS, 1): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="token_x_cast_bf16"): - d0 = db * D_CHUNK - x_chunk_fp32 = pl.slice(token_x_fp32, [T, D_CHUNK], [0, d0]) - token_x_bf16 = pl.assemble(token_x_bf16, pl.cast(x_chunk_fp32, target_type=pl.BF16, mode="rint"), [0, d0]) + x_inv_rms_t = pl.reshape(x_inv_rms, [T, 1]) + for db_apply in pl.range(D_BLOCKS): + d0_apply = db_apply * D_CHUNK + x_apply_chunk = pl.cast(x_flat[:, d0_apply : d0_apply + D_CHUNK], target_type=pl.FP32) + norm_w_chunk = pl.reshape(norm_w[d0_apply : d0_apply + D_CHUNK], [1, D_CHUNK]) + x_normed = pl.col_expand_mul(pl.row_expand_mul(x_apply_chunk, x_inv_rms_t), norm_w_chunk) + token_x_bf16[:, d0_apply : d0_apply + D_CHUNK] = pl.cast(x_normed, target_type=pl.BF16, mode="rint") # Stage 1/2.1: qr = rms_norm(token_x @ wq_a, gamma_cq) qr_fp32 = pl.create_tensor([T, Q_LORA], dtype=pl.FP32) - for qb in pl.parallel(0, Q_BLOCKS, 1): + for qb in pl.parallel(Q_BLOCKS): with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_proj_matmul"): q_a_col0 = qb * Q_LORA_CHUNK - d0_0 = 0 - x_chunk_bf16_0 = pl.slice(token_x_bf16, [T, D_CHUNK], [0, d0_0]) - w_chunk_0 = pl.slice(wq_a, [D_CHUNK, Q_LORA_CHUNK], [d0_0, q_a_col0]) - q_acc = pl.matmul(x_chunk_bf16_0, w_chunk_0, out_dtype=pl.FP32) - for db in pl.range(1, D_BLOCKS): + q_acc = pl.create_tensor([T, Q_LORA_CHUNK], dtype=pl.FP32) + for db in pl.pipeline(0, D_BLOCKS, stage=2): d0 = db * D_CHUNK - q_x_chunk_bf16 = pl.slice(token_x_bf16, [T, D_CHUNK], [0, d0]) - w_chunk = pl.slice(wq_a, [D_CHUNK, Q_LORA_CHUNK], [d0, q_a_col0]) - q_acc = pl.matmul_acc(q_acc, q_x_chunk_bf16, w_chunk) - qr_fp32 = pl.assemble(qr_fp32, q_acc, [0, q_a_col0]) - + tile_a = token_x_bf16[:, d0 : d0 + D_CHUNK] + tile_b = wq_a[d0 : d0 + D_CHUNK, q_a_col0 : q_a_col0 + Q_LORA_CHUNK] + if db == 0: + q_acc = pl.matmul(tile_a, tile_b, out_dtype=pl.FP32) + else: + q_acc = pl.matmul_acc(q_acc, tile_a, tile_b) + qr_fp32[:, q_a_col0 : q_a_col0 + Q_LORA_CHUNK] = q_acc + + # Stage 2.1: fused qr rms_norm (rms + apply + bf16 cast) for the W8A8 dynamic activation path. + qr_bf16 = pl.create_tensor([T, Q_LORA], dtype=pl.BF16) with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_rms"): qr_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) for qb in pl.range(Q_BLOCKS): qr_sq_col0 = qb * Q_LORA_CHUNK - qr_chunk = pl.slice(qr_fp32, [T, Q_LORA_CHUNK], [0, qr_sq_col0]) + qr_chunk = qr_fp32[:, qr_sq_col0 : qr_sq_col0 + Q_LORA_CHUNK] qr_sq_sum = pl.add(qr_sq_sum, pl.reshape(pl.row_sum(pl.mul(qr_chunk, qr_chunk)), [1, T])) qr_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(qr_sq_sum, 1.0 / Q_LORA), EPS))) - - qr_inv_rms_t = pl.reshape(qr_inv_rms, [T, 1]) - for qb in pl.parallel(0, Q_BLOCKS, 1): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_norm_apply"): - qr_norm_col0 = qb * Q_LORA_CHUNK - qr_chunk = pl.slice(qr_fp32, [T, Q_LORA_CHUNK], [0, qr_norm_col0]) + qr_inv_rms_t = pl.reshape(qr_inv_rms, [T, 1]) + for qb_apply in pl.range(Q_BLOCKS): + qr_norm_col0 = qb_apply * Q_LORA_CHUNK + qr_apply_chunk = qr_fp32[:, qr_norm_col0 : qr_norm_col0 + Q_LORA_CHUNK] gamma_chunk = pl.reshape( - pl.cast(pl.slice(gamma_cq, [Q_LORA_CHUNK], [qr_norm_col0]), target_type=pl.FP32), + pl.cast(gamma_cq[qr_norm_col0 : qr_norm_col0 + Q_LORA_CHUNK], target_type=pl.FP32), [1, Q_LORA_CHUNK], ) - qr_normed = pl.col_expand_mul(pl.row_expand_mul(qr_chunk, qr_inv_rms_t), gamma_chunk) - qr_fp32 = pl.assemble(qr_fp32, qr_normed, [0, qr_norm_col0]) - - # Stage 2.2: pre-cast normalized qr for the W8A8 dynamic activation path. - qr_bf16 = pl.create_tensor([T, Q_LORA], dtype=pl.BF16) - for qb in pl.parallel(0, Q_BLOCKS, 1): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_cast_bf16"): - qr_store_col0 = qb * Q_LORA_CHUNK - qr_chunk_fp32 = pl.slice(qr_fp32, [T, Q_LORA_CHUNK], [0, qr_store_col0]) - qr_bf16 = pl.assemble(qr_bf16, pl.cast(qr_chunk_fp32, target_type=pl.BF16, mode="rint"), [0, qr_store_col0]) + qr_normed = pl.col_expand_mul(pl.row_expand_mul(qr_apply_chunk, qr_inv_rms_t), gamma_chunk) + qr_bf16[:, qr_norm_col0 : qr_norm_col0 + Q_LORA_CHUNK] = pl.cast(qr_normed, target_type=pl.BF16, mode="rint") # Stage 2.3: W8A8C16 activation path: quantize normalized qr per token. qr_scale_dq = pl.create_tensor([T, 1], dtype=pl.FP32) with pl.at(level=pl.Level.CORE_GROUP, name_hint="qr_quant"): qr_amax = pl.full([1, T], dtype=pl.FP32, value=INT8_AMAX_EPS) for q0 in pl.range(0, Q_LORA, QUANT_CHUNK): - qr_a_f32 = pl.cast(pl.slice(qr_bf16, [T, QUANT_CHUNK], [0, q0]), target_type=pl.FP32) + qr_a_f32 = pl.cast(qr_bf16[:, q0 : q0 + QUANT_CHUNK], target_type=pl.FP32) qr_a_abs = pl.maximum(qr_a_f32, pl.neg(qr_a_f32)) qr_a_max = pl.reshape(pl.row_max(qr_a_abs), [1, T]) qr_amax = pl.maximum(qr_amax, qr_a_max) qr_scale_quant_row = pl.div(pl.full([1, T], dtype=pl.FP32, value=INT8_SCALE_MAX), qr_amax) qr_scale_dq = pl.reshape(pl.recip(qr_scale_quant_row), [T, 1]) - qr_scale = pl.assemble(qr_scale, qr_scale_dq, [0, 0]) + qr_scale[:, :] = qr_scale_dq qr_scale_quant = pl.reshape(qr_scale_quant_row, [T, 1]) for q1 in pl.range(0, Q_LORA, QUANT_CHUNK): - qr_q_f32 = pl.cast(pl.slice(qr_bf16, [T, QUANT_CHUNK], [0, q1]), target_type=pl.FP32) + qr_q_f32 = pl.cast(qr_bf16[:, q1 : q1 + QUANT_CHUNK], target_type=pl.FP32) qr_q_scaled = pl.row_expand_mul(qr_q_f32, qr_scale_quant) qr_q_i32 = pl.cast(qr_q_scaled, target_type=pl.INT32, mode="rint") qr_q_half = pl.cast(qr_q_i32, target_type=pl.FP16, mode="round") - qr = pl.assemble(qr, pl.cast(qr_q_half, target_type=pl.INT8, mode="trunc"), [0, q1]) + qr[:, q1 : q1 + QUANT_CHUNK] = pl.cast(qr_q_half, target_type=pl.INT8, mode="trunc") # Stage 3: W8A8C16 q_proj = qr_i8 @ wq_b, then dequantize to FP32. q_proj_fp32 = pl.create_tensor([T, H * HEAD_DIM], dtype=pl.FP32) - for hb in pl.parallel(0, Q_PROJ_HEAD_BLOCKS, 1): + for hb in pl.parallel(Q_PROJ_HEAD_BLOCKS): h0 = hb * Q_PROJ_OUT_CHUNK - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qproj_matmul", optimization=pl.chunked_loop_optimizer): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qproj_matmul"): q0_0 = 0 - qr_i8_0 = pl.slice(qr, [T, Q_PROJ_CHUNK], [0, q0_0]) - wq_0 = pl.slice(wq_b, [Q_PROJ_CHUNK, Q_PROJ_OUT_CHUNK], [q0_0, h0]) + qr_i8_0 = qr[:, q0_0 : q0_0 + Q_PROJ_CHUNK] + wq_0 = wq_b[q0_0 : q0_0 + Q_PROJ_CHUNK, h0 : h0 + Q_PROJ_OUT_CHUNK] col_acc = pl.matmul(qr_i8_0, wq_0, out_dtype=pl.INT32) for qb in pl.range(1, Q_PROJ_BLOCKS): qr_proj_col0 = qb * Q_PROJ_CHUNK - qr_i8_chunk = pl.slice(qr, [T, Q_PROJ_CHUNK], [0, qr_proj_col0]) - wq_chunk = pl.slice(wq_b, [Q_PROJ_CHUNK, Q_PROJ_OUT_CHUNK], [qr_proj_col0, h0]) + qr_i8_chunk = qr[:, qr_proj_col0 : qr_proj_col0 + Q_PROJ_CHUNK] + wq_chunk = wq_b[qr_proj_col0 : qr_proj_col0 + Q_PROJ_CHUNK, h0 : h0 + Q_PROJ_OUT_CHUNK] col_acc = pl.matmul_acc(col_acc, qr_i8_chunk, wq_chunk) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="qproj_dequant", optimization=pl.chunked_loop_optimizer): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="qproj_dequant"): col_fp32 = pl.cast(col_acc, target_type=pl.FP32, mode="none") - w_scale = pl.slice(wq_b_scale, [1, Q_PROJ_OUT_CHUNK], [hb, 0]) + w_scale = wq_b_scale[hb : hb + 1, :] col_dequant = pl.col_expand_mul(pl.row_expand_mul(col_fp32, qr_scale_dq), w_scale) - q_proj_fp32 = pl.assemble(q_proj_fp32, col_dequant, [0, h0]) + q_proj_fp32[:, h0 : h0 + Q_PROJ_OUT_CHUNK] = col_dequant # Stage 4: per-head RMSNorm + RoPE on q q_flat = pl.reshape(q, [T, H * HEAD_DIM]) - for h in pl.parallel(0, H, 1): - with pl.at(level=pl.Level.CORE_GROUP, optimization=pl.chunked_loop_optimizer, name_hint="q_head_rms_rope"): + for h in pl.parallel(H): + q_rot_even_bf16 = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16) + q_rot_odd_bf16 = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_head_rms_rope"): h0 = h * HEAD_DIM - q_head_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for db in pl.range(HEAD_DIM // HEAD_CHUNK): - d0 = h0 + db * HEAD_CHUNK - q_head_chunk = pl.slice(q_proj_fp32, [T, HEAD_CHUNK], [0, d0]) - q_head_sq_sum = pl.add(q_head_sq_sum, pl.reshape(pl.row_sum(pl.mul(q_head_chunk, q_head_chunk)), [1, T])) - q_head_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))) - q_head_inv_rms_t = pl.reshape(q_head_inv_rms, [T, 1]) - - for nb in pl.range(NOPE_DIM // HEAD_CHUNK): - n0 = nb * HEAD_CHUNK - q_nope_chunk = pl.slice(q_proj_fp32, [T, HEAD_CHUNK], [0, h0 + n0]) - q_normed = pl.row_expand_mul(q_nope_chunk, q_head_inv_rms_t) - q_flat = pl.assemble(q_flat, pl.cast(q_normed, target_type=pl.BF16, mode="rint"), [0, h0 + n0]) - - q_rope = pl.slice(q_proj_fp32, [T, ROPE_DIM], [0, h0 + NOPE_DIM]) - q_rope_norm = pl.row_expand_mul(q_rope, q_head_inv_rms_t) - q_even = pl.tensor.gather(q_rope_norm, mask_pattern=pl.tile.MaskPattern.P0101) - q_odd = pl.tensor.gather(q_rope_norm, mask_pattern=pl.tile.MaskPattern.P1010) - cos = pl.cast(pl.slice(rope_cos, [T, ROPE_HALF], [0, 0]), target_type=pl.FP32) - sin = pl.cast(pl.slice(rope_sin, [T, ROPE_HALF], [0, 0]), target_type=pl.FP32) - q_rot_even = pl.sub(pl.mul(q_even, cos), pl.mul(q_odd, sin)) - q_rot_odd = pl.add(pl.mul(q_even, sin), pl.mul(q_odd, cos)) - q_rot_even_bf16 = pl.cast(q_rot_even, target_type=pl.BF16, mode="rint") - q_rot_odd_bf16 = pl.cast(q_rot_odd, target_type=pl.BF16, mode="rint") - - for rope_col in pl.range(0, ROPE_DIM, ROPE_CHUNK): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_reassemble"): + for tt in pl.range(RMS_T_TILES): + t0 = tt * T_TILE + q_head_sq_sum = pl.full([1, T_TILE], dtype=pl.FP32, value=0.0) + for db in pl.range(HEAD_DIM // HEAD_CHUNK): + d0 = h0 + db * HEAD_CHUNK + q_head_chunk = q_proj_fp32[t0 : t0 + T_TILE, d0 : d0 + HEAD_CHUNK] + q_head_sq_sum = pl.add(q_head_sq_sum, pl.reshape(pl.row_sum(pl.mul(q_head_chunk, q_head_chunk)), [1, T_TILE])) + q_head_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))) + q_head_inv_rms_t = pl.reshape(q_head_inv_rms, [T_TILE, 1]) + + for nb in pl.range(NOPE_DIM // HEAD_CHUNK): + n0 = nb * HEAD_CHUNK + q_nope_chunk = q_proj_fp32[t0 : t0 + T_TILE, h0 + n0 : h0 + n0 + HEAD_CHUNK] + q_normed = pl.row_expand_mul(q_nope_chunk, q_head_inv_rms_t) + q_flat[t0 : t0 + T_TILE, h0 + n0 : h0 + n0 + HEAD_CHUNK] = pl.cast(q_normed, target_type=pl.BF16, mode="rint") + + q_rope = q_proj_fp32[t0 : t0 + T_TILE, h0 + NOPE_DIM : h0 + NOPE_DIM + ROPE_DIM] + q_rope_norm = pl.row_expand_mul(q_rope, q_head_inv_rms_t) + q_even = pl.tensor.gather(q_rope_norm, mask_pattern=pl.tile.MaskPattern.P0101) + q_odd = pl.tensor.gather(q_rope_norm, mask_pattern=pl.tile.MaskPattern.P1010) + cos = pl.cast(rope_cos[t0 : t0 + T_TILE, :ROPE_HALF], target_type=pl.FP32) + sin = pl.cast(rope_sin[t0 : t0 + T_TILE, :ROPE_HALF], target_type=pl.FP32) + q_rot_even = pl.sub(pl.mul(q_even, cos), pl.mul(q_odd, sin)) + q_rot_odd = pl.add(pl.mul(q_even, sin), pl.mul(q_odd, cos)) + q_rot_even_bf16[t0 : t0 + T_TILE, :] = pl.cast(q_rot_even, target_type=pl.BF16, mode="rint") + q_rot_odd_bf16[t0 : t0 + T_TILE, :] = pl.cast(q_rot_odd, target_type=pl.BF16, mode="rint") + + q_rope_fp32 = pl.create_tensor([T, ROPE_DIM], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_reassemble"): + for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): pair_col = rope_col // 2 - q_rot_even_chunk = pl.slice(q_rot_even_bf16, [T, ROPE_PAIR_CHUNK], [0, pair_col]) - q_rot_odd_chunk = pl.slice(q_rot_odd_bf16, [T, ROPE_PAIR_CHUNK], [0, pair_col]) + q_rot_even_chunk = q_rot_even_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK] + q_rot_odd_chunk = q_rot_odd_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK] q_rot_chunk = pl.matmul( q_rot_even_chunk, - pl.slice(even_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), + even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], out_dtype=pl.FP32, ) q_rot_chunk = pl.matmul_acc( q_rot_chunk, q_rot_odd_chunk, - pl.slice(odd_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), + odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], ) + q_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] = q_rot_chunk - with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_write"): - h0 = h * HEAD_DIM - q_flat = pl.assemble(q_flat, pl.cast(q_rot_chunk, target_type=pl.BF16, mode="rint"), [0, h0 + NOPE_DIM + rope_col]) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_write"): + h0 = h * HEAD_DIM + for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): + q_rope_chunk = q_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] + q_flat[:, h0 + NOPE_DIM + rope_col : h0 + NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(q_rope_chunk, target_type=pl.BF16, mode="rint") q = pl.reshape(q_flat, [T, H, HEAD_DIM]) # Stage 5/6: kv = rms_norm(token_x @ wkv, gamma_ckv) + RoPE kv_fp32 = pl.create_tensor([T, HEAD_DIM], dtype=pl.FP32) - for kb in pl.parallel(0, KV_BLOCKS, 1): + for kb in pl.parallel(KV_BLOCKS): with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_proj_matmul"): kv_col0 = kb * KV_CHUNK d0_0 = 0 - x_chunk_bf16_0 = pl.slice(token_x_bf16, [T, D_CHUNK], [0, d0_0]) - wkv_chunk_0 = pl.slice(wkv, [D_CHUNK, KV_CHUNK], [d0_0, kv_col0]) + x_chunk_bf16_0 = token_x_bf16[:, d0_0 : d0_0 + D_CHUNK] + wkv_chunk_0 = wkv[d0_0 : d0_0 + D_CHUNK, kv_col0 : kv_col0 + KV_CHUNK] kv_acc = pl.matmul(x_chunk_bf16_0, wkv_chunk_0, out_dtype=pl.FP32) for db in pl.range(1, D_BLOCKS): d0 = db * D_CHUNK - kv_x_chunk_bf16 = pl.slice(token_x_bf16, [T, D_CHUNK], [0, d0]) - wkv_chunk = pl.slice(wkv, [D_CHUNK, KV_CHUNK], [d0, kv_col0]) + kv_x_chunk_bf16 = token_x_bf16[:, d0 : d0 + D_CHUNK] + wkv_chunk = wkv[d0 : d0 + D_CHUNK, kv_col0 : kv_col0 + KV_CHUNK] kv_acc = pl.matmul_acc(kv_acc, kv_x_chunk_bf16, wkv_chunk) - kv_fp32 = pl.assemble(kv_fp32, kv_acc, [0, kv_col0]) + kv_fp32[:, kv_col0 : kv_col0 + KV_CHUNK] = kv_acc - with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rms"): - kv_sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for kb in pl.range(KV_BLOCKS): - kv_sq_col0 = kb * KV_CHUNK - kv_chunk = pl.slice(kv_fp32, [T, KV_CHUNK], [0, kv_sq_col0]) - kv_sq_sum = pl.add(kv_sq_sum, pl.reshape(pl.row_sum(pl.mul(kv_chunk, kv_chunk)), [1, T])) - kv_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(kv_sq_sum, 1.0 / HEAD_DIM), EPS))) - - kv_inv_rms_t = pl.reshape(kv_inv_rms, [T, 1]) - for nb in pl.parallel(0, NOPE_DIM // KV_CHUNK, 1): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_norm_nope"): - n0 = nb * KV_CHUNK - kv_chunk = pl.slice(kv_fp32, [T, KV_CHUNK], [0, n0]) - gamma_kv_chunk = pl.reshape( - pl.cast(pl.slice(gamma_ckv, [KV_CHUNK], [n0]), target_type=pl.FP32), - [1, KV_CHUNK], - ) - kv_normed = pl.col_expand_mul(pl.row_expand_mul(kv_chunk, kv_inv_rms_t), gamma_kv_chunk) - kv = pl.assemble(kv, pl.cast(kv_normed, target_type=pl.BF16, mode="rint"), [0, n0]) + # Stage 5.1: fused kv rms_norm (rms + nope apply + rope apply). kv_rot_even_tmp = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16) kv_rot_odd_tmp = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_apply"): - kv_rope = pl.slice(kv_fp32, [T, ROPE_DIM], [0, NOPE_DIM]) - gamma_rope = pl.reshape( - pl.cast(pl.slice(gamma_ckv, [ROPE_DIM], [NOPE_DIM]), target_type=pl.FP32), - [1, ROPE_DIM], - ) - kv_rope_norm = pl.col_expand_mul(pl.row_expand_mul(kv_rope, kv_inv_rms_t), gamma_rope) - kv_even = pl.tensor.gather(kv_rope_norm, mask_pattern=pl.tile.MaskPattern.P0101) - kv_odd = pl.tensor.gather(kv_rope_norm, mask_pattern=pl.tile.MaskPattern.P1010) - cos = pl.cast(pl.slice(rope_cos, [T, ROPE_HALF], [0, 0]), target_type=pl.FP32) - sin = pl.cast(pl.slice(rope_sin, [T, ROPE_HALF], [0, 0]), target_type=pl.FP32) - kv_rot_even = pl.sub(pl.mul(kv_even, cos), pl.mul(kv_odd, sin)) - kv_rot_odd = pl.add(pl.mul(kv_even, sin), pl.mul(kv_odd, cos)) - kv_rot_even_tmp = pl.assemble(kv_rot_even_tmp, pl.cast(kv_rot_even, target_type=pl.BF16, mode="rint"), [0, 0]) - kv_rot_odd_tmp = pl.assemble(kv_rot_odd_tmp, pl.cast(kv_rot_odd, target_type=pl.BF16, mode="rint"), [0, 0]) - - for rope_col in pl.range(0, ROPE_DIM, ROPE_CHUNK): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_reassemble"): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rms"): + for tt in pl.range(RMS_T_TILES): + t0 = tt * T_TILE + kv_sq_sum = pl.full([1, T_TILE], dtype=pl.FP32, value=0.0) + for kb in pl.range(KV_BLOCKS): + kv_sq_col0 = kb * KV_CHUNK + kv_chunk = kv_fp32[t0 : t0 + T_TILE, kv_sq_col0 : kv_sq_col0 + KV_CHUNK] + kv_sq_sum = pl.add(kv_sq_sum, pl.reshape(pl.row_sum(pl.mul(kv_chunk, kv_chunk)), [1, T_TILE])) + kv_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(kv_sq_sum, 1.0 / HEAD_DIM), EPS))) + kv_inv_rms_t = pl.reshape(kv_inv_rms, [T_TILE, 1]) + for nb in pl.range(NOPE_DIM // KV_CHUNK): + n0 = nb * KV_CHUNK + kv_nope_chunk = kv_fp32[t0 : t0 + T_TILE, n0 : n0 + KV_CHUNK] + gamma_kv_chunk = pl.reshape( + pl.cast(gamma_ckv[n0 : n0 + KV_CHUNK], target_type=pl.FP32), + [1, KV_CHUNK], + ) + kv_normed = pl.col_expand_mul(pl.row_expand_mul(kv_nope_chunk, kv_inv_rms_t), gamma_kv_chunk) + kv[t0 : t0 + T_TILE, n0 : n0 + KV_CHUNK] = pl.cast(kv_normed, target_type=pl.BF16, mode="rint") + + kv_rope = kv_fp32[t0 : t0 + T_TILE, NOPE_DIM : NOPE_DIM + ROPE_DIM] + gamma_rope = pl.reshape( + pl.cast(gamma_ckv[NOPE_DIM : NOPE_DIM + ROPE_DIM], target_type=pl.FP32), + [1, ROPE_DIM], + ) + kv_rope_norm = pl.col_expand_mul(pl.row_expand_mul(kv_rope, kv_inv_rms_t), gamma_rope) + kv_even = pl.tensor.gather(kv_rope_norm, mask_pattern=pl.tile.MaskPattern.P0101) + kv_odd = pl.tensor.gather(kv_rope_norm, mask_pattern=pl.tile.MaskPattern.P1010) + cos = pl.cast(rope_cos[t0 : t0 + T_TILE, :ROPE_HALF], target_type=pl.FP32) + sin = pl.cast(rope_sin[t0 : t0 + T_TILE, :ROPE_HALF], target_type=pl.FP32) + kv_rot_even = pl.sub(pl.mul(kv_even, cos), pl.mul(kv_odd, sin)) + kv_rot_odd = pl.add(pl.mul(kv_even, sin), pl.mul(kv_odd, cos)) + kv_rot_even_tmp[t0 : t0 + T_TILE, :] = pl.cast(kv_rot_even, target_type=pl.BF16, mode="rint") + kv_rot_odd_tmp[t0 : t0 + T_TILE, :] = pl.cast(kv_rot_odd, target_type=pl.BF16, mode="rint") + + kv_rope_fp32 = pl.create_tensor([T, ROPE_DIM], dtype=pl.FP32) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_reassemble"): + for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): pair_col = rope_col // 2 - kv_rot_even_chunk = pl.slice(kv_rot_even_tmp, [T, ROPE_PAIR_CHUNK], [0, pair_col]) - kv_rot_odd_chunk = pl.slice(kv_rot_odd_tmp, [T, ROPE_PAIR_CHUNK], [0, pair_col]) + kv_rot_even_chunk = kv_rot_even_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK] + kv_rot_odd_chunk = kv_rot_odd_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK] kv_rot_chunk = pl.matmul( kv_rot_even_chunk, - pl.slice(even_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), + even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], out_dtype=pl.FP32, ) kv_rot_chunk = pl.matmul_acc( kv_rot_chunk, kv_rot_odd_chunk, - pl.slice(odd_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]), + odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK], ) + kv_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] = kv_rot_chunk - with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_write"): - kv = pl.assemble(kv, pl.cast(kv_rot_chunk, target_type=pl.BF16, mode="rint"), [0, NOPE_DIM + rope_col]) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_write"): + for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2): + kv_rope_chunk = kv_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] + kv[:, NOPE_DIM + rope_col : NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(kv_rope_chunk, target_type=pl.BF16, mode="rint") return q