diff --git a/models/deepseek/v4/hc_pre.py b/models/deepseek/v4/hc_pre.py index 6b5af6a..ade041a 100644 --- a/models/deepseek/v4/hc_pre.py +++ b/models/deepseek/v4/hc_pre.py @@ -35,9 +35,14 @@ # tiling T_TILE = 16 -K_CHUNK = 128 if T >= 128 else (256 if T >= 64 else 512) +RMS_T_TILE = 16 +LINEAR_T_TILE = 16 +COMB_T_TILE = 16 +RMS_K_CHUNK = 128 +LINEAR_K_CHUNK = 512 D_CHUNK = 512 -HC_DIM_BLOCKS = HC_DIM // K_CHUNK +RMS_K_BLOCKS = HC_DIM // RMS_K_CHUNK +LINEAR_K_BLOCKS = HC_DIM // LINEAR_K_CHUNK D_BLOCKS = D // D_CHUNK RMS_PIPE_STAGE = 1 if T >= 64 else 4 @@ -55,60 +60,59 @@ def hc_pre( x_flat = pl.reshape(x, [T, HC_DIM]) post_flat = pl.reshape(post, [T * HC_MULT]) comb_flat = pl.reshape(comb, [T * HC_MULT * HC_MULT]) - x_flat_fp32 = pl.create_tensor([T, HC_DIM], dtype=pl.FP32) inv_rms = pl.create_tensor([1, T], dtype=pl.FP32) mixes = pl.create_tensor([T, MIX_PAD], dtype=pl.FP32) mix_raw = pl.create_tensor([T, MIX_PAD], dtype=pl.FP32) - - for kb in pl.parallel(HC_DIM_BLOCKS): - k0 = kb * K_CHUNK - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cast_x"): - x_chunk_fp32 = pl.cast( - pl.slice(x_flat, [T, K_CHUNK], [0, k0]), + pre_val_store = pl.create_tensor([T, HC_PAD], dtype=pl.FP32) + pre_val_t = pl.create_tensor([HC_PAD, T], dtype=pl.FP32) + + for t0 in pl.parallel(0, T, RMS_T_TILE): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="rms"): + sq_sum = pl.full([1, RMS_T_TILE], dtype=pl.FP32, value=0.0) + for kb in pl.pipeline(RMS_K_BLOCKS, stage=RMS_PIPE_STAGE): + k0 = kb * RMS_K_CHUNK + x_chunk = pl.cast( + pl.slice(x_flat, [RMS_T_TILE, RMS_K_CHUNK], [t0, k0]), + target_type=pl.FP32, + ) + sq_sum = pl.add( + sq_sum, + pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, RMS_T_TILE]), + ) + inv_rms_val = pl.rsqrt(pl.add(pl.mul(sq_sum, HC_DIM_INV), NORM_EPS), high_precision=True) + inv_rms = pl.assemble(inv_rms, inv_rms_val, [0, t0]) + + for t0 in pl.parallel(0, T, LINEAR_T_TILE): + with pl.at( + level=pl.Level.CORE_GROUP, + optimizations=[pl.split(pl.SplitMode.UP_DOWN)], + name_hint="linear", + ): + x_lin_0 = pl.cast( + pl.slice(x_flat, [LINEAR_T_TILE, LINEAR_K_CHUNK], [t0, 0]), target_type=pl.FP32, ) - x_flat_fp32 = pl.assemble(x_flat_fp32, x_chunk_fp32, [0, k0]) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="rms"): - sq_sum = pl.full([1, T], dtype=pl.FP32, value=0.0) - for kb in pl.pipeline(HC_DIM_BLOCKS, stage=RMS_PIPE_STAGE): - k0 = kb * K_CHUNK - x_chunk = pl.slice(x_flat_fp32, [T, K_CHUNK], [0, k0]) - sq_sum = pl.add( - sq_sum, - pl.reshape(pl.row_sum(pl.mul(x_chunk, x_chunk)), [1, T]), - ) - inv_rms_val = pl.rsqrt(pl.add(pl.mul(sq_sum, HC_DIM_INV), NORM_EPS), high_precision=True) - inv_rms = pl.assemble(inv_rms, inv_rms_val, [0, 0]) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="linear"): - x_lin_0_mat = pl.load(x_flat_fp32, [0, 0], [T, K_CHUNK], target_memory=pl.MemorySpace.Mat) - w_lin_0_mat = pl.load( - hc_fn, - [0, 0], - [MIX_PAD, K_CHUNK], - valid_shapes=[MIX_HC, K_CHUNK], - target_memory=pl.MemorySpace.Mat, - transpose=True, - ) - x_lin_0_left = pl.move(x_lin_0_mat, target_memory=pl.MemorySpace.Left) - w_lin_0_right = pl.move(w_lin_0_mat, target_memory=pl.MemorySpace.Right) - mix_acc = pl.matmul(x_lin_0_left, w_lin_0_right) - for kb in pl.range(1, HC_DIM_BLOCKS): - kl0 = kb * K_CHUNK - x_lin_mat = pl.load(x_flat_fp32, [0, kl0], [T, K_CHUNK], target_memory=pl.MemorySpace.Mat) - w_lin_mat = pl.load( + w_lin_0 = pl.slice( hc_fn, - [0, kl0], - [MIX_PAD, K_CHUNK], - valid_shapes=[MIX_HC, K_CHUNK], - target_memory=pl.MemorySpace.Mat, - transpose=True, + [MIX_PAD, LINEAR_K_CHUNK], + [0, 0], + valid_shape=[MIX_HC, LINEAR_K_CHUNK], ) - x_lin_left = pl.move(x_lin_mat, target_memory=pl.MemorySpace.Left) - w_lin_right = pl.move(w_lin_mat, target_memory=pl.MemorySpace.Right) - mix_acc = pl.matmul_acc(mix_acc, x_lin_left, w_lin_right) - mix_raw = pl.store(mix_acc, [0, 0], mix_raw) + mix_acc = pl.matmul(x_lin_0, w_lin_0, b_trans=True, out_dtype=pl.FP32) + for kb in pl.pipeline(1, LINEAR_K_BLOCKS, stage=2): + kl0 = kb * LINEAR_K_CHUNK + x_lin = pl.cast( + pl.slice(x_flat, [LINEAR_T_TILE, LINEAR_K_CHUNK], [t0, kl0]), + target_type=pl.FP32, + ) + w_lin = pl.slice( + hc_fn, + [MIX_PAD, LINEAR_K_CHUNK], + [0, kl0], + valid_shape=[MIX_HC, LINEAR_K_CHUNK], + ) + mix_acc = pl.matmul_acc(mix_acc, x_lin, w_lin, b_trans=True) + mix_raw = pl.assemble(mix_raw, mix_acc, [t0, 0]) with pl.at(level=pl.Level.CORE_GROUP, name_hint="linear_scale"): mixes = pl.assemble( @@ -130,6 +134,7 @@ def hc_pre( pl.col_expand_mul(ones_hc, pre_base), ) pre_val = pl.add(pl.recip(pl.add(pl.exp(pl.neg(pre_logits)), 1.0)), HC_EPS) + pre_val_store = pl.assemble(pre_val_store, pre_val, [0, 0]) post_base = pl.reshape(pl.slice(hc_base, [HC_PAD], [HC_MULT]), [1, HC_PAD]) post_logits = pl.add( @@ -152,94 +157,107 @@ def hc_pre( post_pad_flat = pl.reshape(post_pad, [T * HC_PAD]) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="comb_sinkhorn"): - row0 = pl.fillpad(pl.load( - comb_logits, - [0, 0 * HC_MULT], - [T, HC_PAD], - valid_shapes=[T, HC_MULT], - target_memory=pl.MemorySpace.Vec, - ), pad_value=pl.PadValue.min) - row1 = pl.fillpad(pl.load( - comb_logits, - [0, 1 * HC_MULT], - [T, HC_PAD], - valid_shapes=[T, HC_MULT], - target_memory=pl.MemorySpace.Vec, - ), pad_value=pl.PadValue.min) - row2 = pl.fillpad(pl.load( - comb_logits, - [0, 2 * HC_MULT], - [T, HC_PAD], - valid_shapes=[T, HC_MULT], - target_memory=pl.MemorySpace.Vec, - ), pad_value=pl.PadValue.min) - row3 = pl.fillpad(pl.load( - comb_logits, - [0, 3 * HC_MULT], - [T, HC_PAD], - valid_shapes=[T, HC_MULT], - target_memory=pl.MemorySpace.Vec, - ), pad_value=pl.PadValue.min) - - row_max_tmp = pl.create_tile([T, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec) - row_sum_tmp = pl.create_tile([T, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec) - row0_exp = pl.exp(pl.row_expand_sub(row0, pl.row_max(row0, row_max_tmp))) - row1_exp = pl.exp(pl.row_expand_sub(row1, pl.row_max(row1, row_max_tmp))) - row2_exp = pl.exp(pl.row_expand_sub(row2, pl.row_max(row2, row_max_tmp))) - row3_exp = pl.exp(pl.row_expand_sub(row3, pl.row_max(row3, row_max_tmp))) - row0_soft = pl.add(pl.row_expand_div(row0_exp, pl.row_sum(row0_exp, row_sum_tmp)), HC_EPS) - row1_soft = pl.add(pl.row_expand_div(row1_exp, pl.row_sum(row1_exp, row_sum_tmp)), HC_EPS) - row2_soft = pl.add(pl.row_expand_div(row2_exp, pl.row_sum(row2_exp, row_sum_tmp)), HC_EPS) - row3_soft = pl.add(pl.row_expand_div(row3_exp, pl.row_sum(row3_exp, row_sum_tmp)), HC_EPS) - - row0_eff = pl.tile.fillpad(pl.tile.set_validshape(row0_soft, T, HC_MULT), pad_value=pl.PadValue.zero) - row1_eff = pl.tile.fillpad(pl.tile.set_validshape(row1_soft, T, HC_MULT), pad_value=pl.PadValue.zero) - row2_eff = pl.tile.fillpad(pl.tile.set_validshape(row2_soft, T, HC_MULT), pad_value=pl.PadValue.zero) - row3_eff = pl.tile.fillpad(pl.tile.set_validshape(row3_soft, T, HC_MULT), pad_value=pl.PadValue.zero) - - row_sum_tmp_iter = pl.create_tile([T, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec) - col_sum = pl.add(pl.add(row0_eff, row1_eff), pl.add(row2_eff, row3_eff)) - col_sum = pl.add(col_sum, HC_EPS) - row0_cur = pl.div(row0_eff, col_sum) - row1_cur = pl.div(row1_eff, col_sum) - row2_cur = pl.div(row2_eff, col_sum) - row3_cur = pl.div(row3_eff, col_sum) - - for _ in pl.unroll(HC_SINKHORN_ITER - 1): - row0_norm = pl.row_expand_div(row0_cur, pl.add(pl.row_sum(row0_cur, row_sum_tmp_iter), HC_EPS)) - row1_norm = pl.row_expand_div(row1_cur, pl.add(pl.row_sum(row1_cur, row_sum_tmp_iter), HC_EPS)) - row2_norm = pl.row_expand_div(row2_cur, pl.add(pl.row_sum(row2_cur, row_sum_tmp_iter), HC_EPS)) - row3_norm = pl.row_expand_div(row3_cur, pl.add(pl.row_sum(row3_cur, row_sum_tmp_iter), HC_EPS)) - col_sum = pl.add(pl.add(row0_norm, row1_norm), pl.add(row2_norm, row3_norm)) + with pl.at(level=pl.Level.CORE_GROUP, name_hint="transpose_pre"): + for t0 in pl.range(0, T, T_TILE): + pre_tile = pl.load( + pre_val_store, + [t0, 0], + [T_TILE, HC_PAD], + target_memory=pl.MemorySpace.Vec, + ) + pre_tile_t = pl.transpose(pre_tile, axis1=0, axis2=1) + pre_val_t = pl.store(pre_tile_t, [0, t0], pre_val_t) + + for t0 in pl.parallel(0, T, COMB_T_TILE): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="comb_sinkhorn"): + row0 = pl.fillpad(pl.load( + comb_logits, + [t0, 0 * HC_MULT], + [COMB_T_TILE, HC_PAD], + valid_shapes=[COMB_T_TILE, HC_MULT], + target_memory=pl.MemorySpace.Vec, + ), pad_value=pl.PadValue.min) + row1 = pl.fillpad(pl.load( + comb_logits, + [t0, 1 * HC_MULT], + [COMB_T_TILE, HC_PAD], + valid_shapes=[COMB_T_TILE, HC_MULT], + target_memory=pl.MemorySpace.Vec, + ), pad_value=pl.PadValue.min) + row2 = pl.fillpad(pl.load( + comb_logits, + [t0, 2 * HC_MULT], + [COMB_T_TILE, HC_PAD], + valid_shapes=[COMB_T_TILE, HC_MULT], + target_memory=pl.MemorySpace.Vec, + ), pad_value=pl.PadValue.min) + row3 = pl.fillpad(pl.load( + comb_logits, + [t0, 3 * HC_MULT], + [COMB_T_TILE, HC_PAD], + valid_shapes=[COMB_T_TILE, HC_MULT], + target_memory=pl.MemorySpace.Vec, + ), pad_value=pl.PadValue.min) + + row_max_tmp = pl.create_tile([COMB_T_TILE, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec) + row_sum_tmp = pl.create_tile([COMB_T_TILE, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec) + row0_exp = pl.exp(pl.row_expand_sub(row0, pl.row_max(row0, row_max_tmp))) + row1_exp = pl.exp(pl.row_expand_sub(row1, pl.row_max(row1, row_max_tmp))) + row2_exp = pl.exp(pl.row_expand_sub(row2, pl.row_max(row2, row_max_tmp))) + row3_exp = pl.exp(pl.row_expand_sub(row3, pl.row_max(row3, row_max_tmp))) + row0_soft = pl.add(pl.row_expand_div(row0_exp, pl.row_sum(row0_exp, row_sum_tmp)), HC_EPS) + row1_soft = pl.add(pl.row_expand_div(row1_exp, pl.row_sum(row1_exp, row_sum_tmp)), HC_EPS) + row2_soft = pl.add(pl.row_expand_div(row2_exp, pl.row_sum(row2_exp, row_sum_tmp)), HC_EPS) + row3_soft = pl.add(pl.row_expand_div(row3_exp, pl.row_sum(row3_exp, row_sum_tmp)), HC_EPS) + + row0_eff = pl.tile.fillpad(pl.tile.set_validshape(row0_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero) + row1_eff = pl.tile.fillpad(pl.tile.set_validshape(row1_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero) + row2_eff = pl.tile.fillpad(pl.tile.set_validshape(row2_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero) + row3_eff = pl.tile.fillpad(pl.tile.set_validshape(row3_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero) + + row_sum_tmp_iter = pl.create_tile([COMB_T_TILE, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec) + col_sum = pl.add(pl.add(row0_eff, row1_eff), pl.add(row2_eff, row3_eff)) col_sum = pl.add(col_sum, HC_EPS) - row0_cur = pl.div(row0_norm, col_sum) - row1_cur = pl.div(row1_norm, col_sum) - row2_cur = pl.div(row2_norm, col_sum) - row3_cur = pl.div(row3_norm, col_sum) - - for comb_t_idx in pl.unroll(T): - for c in pl.unroll(HC_MULT): - pl.write( - comb_flat, - [comb_t_idx * HC_MULT * HC_MULT + 0 * HC_MULT + c], - pl.read(row0_cur, [comb_t_idx, c]), - ) - pl.write( - comb_flat, - [comb_t_idx * HC_MULT * HC_MULT + 1 * HC_MULT + c], - pl.read(row1_cur, [comb_t_idx, c]), - ) - pl.write( - comb_flat, - [comb_t_idx * HC_MULT * HC_MULT + 2 * HC_MULT + c], - pl.read(row2_cur, [comb_t_idx, c]), - ) - pl.write( - comb_flat, - [comb_t_idx * HC_MULT * HC_MULT + 3 * HC_MULT + c], - pl.read(row3_cur, [comb_t_idx, c]), - ) + row0_cur = pl.div(row0_eff, col_sum) + row1_cur = pl.div(row1_eff, col_sum) + row2_cur = pl.div(row2_eff, col_sum) + row3_cur = pl.div(row3_eff, col_sum) + + for _ in pl.unroll(HC_SINKHORN_ITER - 1): + row0_norm = pl.row_expand_div(row0_cur, pl.add(pl.row_sum(row0_cur, row_sum_tmp_iter), HC_EPS)) + row1_norm = pl.row_expand_div(row1_cur, pl.add(pl.row_sum(row1_cur, row_sum_tmp_iter), HC_EPS)) + row2_norm = pl.row_expand_div(row2_cur, pl.add(pl.row_sum(row2_cur, row_sum_tmp_iter), HC_EPS)) + row3_norm = pl.row_expand_div(row3_cur, pl.add(pl.row_sum(row3_cur, row_sum_tmp_iter), HC_EPS)) + col_sum = pl.add(pl.add(row0_norm, row1_norm), pl.add(row2_norm, row3_norm)) + col_sum = pl.add(col_sum, HC_EPS) + row0_cur = pl.div(row0_norm, col_sum) + row1_cur = pl.div(row1_norm, col_sum) + row2_cur = pl.div(row2_norm, col_sum) + row3_cur = pl.div(row3_norm, col_sum) + + for ti in pl.unroll(COMB_T_TILE): + for c in pl.unroll(HC_MULT): + comb_t_idx = t0 + ti + pl.write( + comb_flat, + [comb_t_idx * HC_MULT * HC_MULT + 0 * HC_MULT + c], + pl.read(row0_cur, [ti, c]), + ) + pl.write( + comb_flat, + [comb_t_idx * HC_MULT * HC_MULT + 1 * HC_MULT + c], + pl.read(row1_cur, [ti, c]), + ) + pl.write( + comb_flat, + [comb_t_idx * HC_MULT * HC_MULT + 2 * HC_MULT + c], + pl.read(row2_cur, [ti, c]), + ) + pl.write( + comb_flat, + [comb_t_idx * HC_MULT * HC_MULT + 3 * HC_MULT + c], + pl.read(row3_cur, [ti, c]), + ) with pl.at(level=pl.Level.CORE_GROUP, name_hint="write_post"): for token_idx in pl.range(0, T, 1): @@ -250,25 +268,90 @@ def hc_pre( pl.read(post_pad_flat, [token_idx * HC_PAD + h]), ) - pre_val_flat = pl.reshape(pre_val, [T * HC_PAD]) x_mixed_view = pl.reshape(x_mixed, [T, D]) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="mix_x"): - for token_idx in pl.range(0, T, 1): + for t0 in pl.parallel(0, T, T_TILE): + with pl.at(level=pl.Level.CORE_GROUP, name_hint="mix_x"): + pre0 = pl.reshape( + pl.load( + pre_val_t, + [0, t0], + [1, T_TILE], + target_memory=pl.MemorySpace.Vec, + ), + [T_TILE, 1], + ) + pre1 = pl.reshape( + pl.load( + pre_val_t, + [1, t0], + [1, T_TILE], + target_memory=pl.MemorySpace.Vec, + ), + [T_TILE, 1], + ) + pre2 = pl.reshape( + pl.load( + pre_val_t, + [2, t0], + [1, T_TILE], + target_memory=pl.MemorySpace.Vec, + ), + [T_TILE, 1], + ) + pre3 = pl.reshape( + pl.load( + pre_val_t, + [3, t0], + [1, T_TILE], + target_memory=pl.MemorySpace.Vec, + ), + [T_TILE, 1], + ) for db in pl.range(D_BLOCKS): d0 = db * D_CHUNK - y_row = pl.tile.full([1, D_CHUNK], dtype=pl.FP32, value=0.0) - for h in pl.range(HC_MULT): - pre_th = pl.read(pre_val_flat, [token_idx * HC_PAD + h]) - x_row = pl.load( - x_flat_fp32, - [token_idx, h * D + d0], - [1, D_CHUNK], + x0 = pl.cast( + pl.load( + x_flat, + [t0, 0 * D + d0], + [T_TILE, D_CHUNK], target_memory=pl.MemorySpace.Vec, - ) - y_row = pl.add(y_row, pl.mul(x_row, pre_th)) + ), + target_type=pl.FP32, + ) + x1 = pl.cast( + pl.load( + x_flat, + [t0, 1 * D + d0], + [T_TILE, D_CHUNK], + target_memory=pl.MemorySpace.Vec, + ), + target_type=pl.FP32, + ) + x2 = pl.cast( + pl.load( + x_flat, + [t0, 2 * D + d0], + [T_TILE, D_CHUNK], + target_memory=pl.MemorySpace.Vec, + ), + target_type=pl.FP32, + ) + x3 = pl.cast( + pl.load( + x_flat, + [t0, 3 * D + d0], + [T_TILE, D_CHUNK], + target_memory=pl.MemorySpace.Vec, + ), + target_type=pl.FP32, + ) + y_tile = pl.add( + pl.add(pl.row_expand_mul(x0, pre0), pl.row_expand_mul(x1, pre1)), + pl.add(pl.row_expand_mul(x2, pre2), pl.row_expand_mul(x3, pre3)), + ) x_mixed_view = pl.store( - pl.cast(y_row, target_type=pl.BF16, mode="rint"), - [token_idx, d0], + pl.cast(y_tile, target_type=pl.BF16, mode="rint"), + [t0, d0], x_mixed_view, ) x_mixed = pl.reshape(x_mixed_view, [B, S, D]) @@ -301,17 +384,17 @@ def golden_hc_pre(tensors): x_flat_2d = x_flat.reshape(T, HC_DIM) sq_sum = torch.zeros(T, 1, dtype=torch.float32) - for k0 in range(0, HC_DIM, K_CHUNK): - x_chunk = x_flat_2d[:, k0:k0 + K_CHUNK] + for k0 in range(0, HC_DIM, RMS_K_CHUNK): + x_chunk = x_flat_2d[:, k0:k0 + RMS_K_CHUNK] sq_sum += (x_chunk * x_chunk).sum(dim=1, keepdim=True) rsqrt = torch.rsqrt(sq_sum * HC_DIM_INV + NORM_EPS) mix_cols = [] for m in range(MIX_HC): mix_col = torch.zeros(T, 1, dtype=torch.float32) - for k0 in range(0, HC_DIM, K_CHUNK): - x_chunk = x_flat_2d[:, k0:k0 + K_CHUNK] - w_chunk = hc_fn[m:m + 1, k0:k0 + K_CHUNK] + for k0 in range(0, HC_DIM, LINEAR_K_CHUNK): + x_chunk = x_flat_2d[:, k0:k0 + LINEAR_K_CHUNK] + w_chunk = hc_fn[m:m + 1, k0:k0 + LINEAR_K_CHUNK] mix_col += (x_chunk * w_chunk).sum(dim=1, keepdim=True) mix_cols.append(mix_col * rsqrt) mixes = torch.cat(mix_cols, dim=1).reshape(B, S, MIX_HC) # [B, S, mix_hc]