diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index df7c8ad8..8b17dbf4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -330,9 +330,9 @@ jobs: shell: bash env: PYTHONPATH: ${{ github.workspace }} - PTO2_RING_DEP_POOL: 524288 - PTO2_RING_TASK_WINDOW: 524288 - PTO2_RING_HEAP: 1073741824 + PTO2_RING_DEP_POOL: 1048576 + PTO2_RING_TASK_WINDOW: 1048576 + PTO2_RING_HEAP: 4294967296 run: | set +e if [ "${{ github.event_name }}" = "pull_request" ]; then diff --git a/.github/workflows/daily_ci.yml b/.github/workflows/daily_ci.yml index 4563a2ae..0b8ebbe8 100644 --- a/.github/workflows/daily_ci.yml +++ b/.github/workflows/daily_ci.yml @@ -242,9 +242,9 @@ jobs: shell: bash env: PYTHONPATH: ${{ github.workspace }} - PTO2_RING_DEP_POOL: 524288 - PTO2_RING_TASK_WINDOW: 524288 - PTO2_RING_HEAP: 1073741824 + PTO2_RING_DEP_POOL: 1048576 + PTO2_RING_TASK_WINDOW: 1048576 + PTO2_RING_HEAP: 4294967296 PYPTO_LOG_LEVEL: error PYPTO_WARNING_LEVEL: none run: | diff --git a/models/deepseek/v4/decode_csa.py b/models/deepseek/v4/decode_csa.py index e7184770..861bf2a2 100644 --- a/models/deepseek/v4/decode_csa.py +++ b/models/deepseek/v4/decode_csa.py @@ -399,7 +399,7 @@ def build_tensor_specs(layer_id: int = 0): parser = argparse.ArgumentParser() parser.add_argument("-p", "--platform", type=str, default="a2a3", - choices=["a2a3", "a2a3sim", "a5", "a5sim"]) + choices=["a2a3", "a5"]) parser.add_argument("-d", "--device", type=int, default=0) parser.add_argument("--layer-id", type=int, default=0) parser.add_argument("--seed", type=int, default=0) diff --git a/models/deepseek/v4/decode_hca.py b/models/deepseek/v4/decode_hca.py index accd357b..089a9dd1 100644 --- a/models/deepseek/v4/decode_hca.py +++ b/models/deepseek/v4/decode_hca.py @@ -347,7 +347,7 @@ def build_tensor_specs(layer_id: int = 0): parser = argparse.ArgumentParser() parser.add_argument("-p", "--platform", type=str, default="a2a3", - choices=["a2a3", "a2a3sim", "a5", "a5sim"]) + choices=["a2a3", "a5"]) parser.add_argument("-d", "--device", type=int, default=0) parser.add_argument("--layer-id", type=int, default=0) parser.add_argument("--seed", type=int, default=0) diff --git a/models/deepseek/v4/decode_sparse_attn.py b/models/deepseek/v4/decode_sparse_attn.py index 7a69805b..85e0f334 100644 --- a/models/deepseek/v4/decode_sparse_attn.py +++ b/models/deepseek/v4/decode_sparse_attn.py @@ -6,38 +6,7 @@ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. # See LICENSE in the root of the software repository for the full text of the License. # ----------------------------------------------------------------------------------------------------------- -"""DeepSeek-V4 sparse attention with grouped output projection (decode). - -Corresponds to model.py Attention.forward decode branch lines 533-542: - o = sparse_attn(q, kv_cache, attn_sink, topk_idxs, softmax_scale) - apply_rotary_emb(o[..., -rope_dim:], freqs_cis, inverse=True) - o = o.view(bsz, seqlen, self.n_local_groups, -1) - wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) - o = torch.einsum("bsgd,grd->bsgr", o, wo_a) - x = self.wo_b(o.flatten(2)) - -Inputs: -- q : query tensor from MLA prolog (RoPE already applied) -- ori_kv / cmp_kv : paged sliding-window and paged compressed KV pools -- cmp_sparse_indices: per-token absolute indices (window + compressed concat) - computed by orchestrator (window topk + indexer/HCA topk) -- attn_sink : per-head sink term added inside softmax -- freqs_cos/sin : split-half inverse-RoPE tables for the sparse-attn output -- wo_a : grouped first-stage output-projection weights from model.py:537-541 -- wo_b / wo_b_scale : grouped second-stage output-projection W8 per-channel weights - -Standalone contract: -- `cmp_sparse_indices[t, :]` may contain `-1` pads. -- entries in `[0, WIN)` address the logical sliding-window ring slots. -- entries in `[WIN, WIN + cmp_valid)` address compressed cache slots, where - `seq_used = seqused_kv[b] - S + 1 + s` - and `cmp_valid = max(seq_used - min(WIN, seq_used), 0)`. -- the grouped projection layout matches: - `o.view(bsz, seqlen, self.n_local_groups, -1)` with - `self.n_local_groups == O_GROUPS` and `-1 == O_GROUP_IN`. - -The standalone harness exposes `--compress-ratio {0,4,128}` for testing. -""" +"""DeepSeek-V4 sparse attention with grouped output projection (decode).""" import pypto.language as pl @@ -67,30 +36,29 @@ # kernel-local SUPPORTED_COMPRESS_RATIOS = (0, 4, 128) -DEFAULT_COMPRESS_RATIO = 128 -ORI_MAX_BLOCKS = 1 # paged-KV pool: ori (sliding-window) blocks per batch +DEFAULT_COMPRESS_RATIO = 0 +ORI_MAX_BLOCKS = 1 # paged-KV pool: ori (sliding-window) blocks per batch ORI_BLOCK_NUM = B * ORI_MAX_BLOCKS -CMP_MAX_BLOCKS = 64 # paged-KV pool: compressed blocks per batch +CMP_MAX_BLOCKS = 64 # paged-KV pool: compressed blocks per batch CMP_BLOCK_NUM = B * CMP_MAX_BLOCKS # tiling -GATHER_TOKEN_TILE = 4 -ATTN_TOKEN_TILE = 8 ROPE_TOKEN_TILE = 4 ROPE_PACK_TOKEN_TILE = 32 ROPE_PACK_GROUP_TILE = 1 -ROPE_PACK_SPMD_BLOCKS = (T // ROPE_PACK_TOKEN_TILE) * O_GROUPS -MATMUL_ROW_PAD = 16 -SPARSE_ATTN_TILE = 64 -SPARSE_ATTN_BLOCKS = (TOPK + SPARSE_ATTN_TILE - 1) // SPARSE_ATTN_TILE -ROPE_CHUNK = 16 -ROPE_INTERLEAVE_CHUNK = 2 * ROPE_CHUNK -A_K_CHUNK = 128 -A_N_CHUNK = 128 -B_K_CHUNK = 128 -B_N_CHUNK = 128 if T >= 128 else 256 -QUANT_CHUNK = 32 if T >= 128 else (128 if T >= 64 else 256) +H_TILE = 16 +ATTN_K_TILE = 32 +ROPE_TILE = 16 +ROPE_INTERLEAVE_TILE = 2 * ROPE_TILE +A_T_TILE = 16 +A_K_TILE = 128 +A_N_TILE = 128 +B_T_TILE = 16 +B_K_TILE = 128 +B_N_TILE = 128 +QUANT_TILE = 32 QUANT_TOKEN_TILE = 8 +QUANT_K_TILE = O_GROUPS * O_LORA // 2 def get_standalone_cmp_valid(compress_ratio: int) -> int: @@ -106,461 +74,303 @@ def get_standalone_cmp_valid(compress_ratio: int) -> int: @pl.jit.inline def sparse_attn( - q: pl.Tensor[[T, H, HEAD_DIM], pl.BF16], - ori_kv: pl.Tensor[[ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], - ori_block_table: pl.Tensor[[B, ORI_MAX_BLOCKS], pl.INT32], - cmp_kv: pl.Tensor[[CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], - cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], - cmp_sparse_indices: pl.Tensor[[T, TOPK], pl.INT32], - attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B], pl.INT32], - freqs_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], - freqs_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], - even_select_local: pl.Tensor[[ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK], pl.BF16], - odd_select_local: pl.Tensor[[ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK], pl.BF16], - wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], - wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], - wo_b_scale: pl.Tensor[[D], pl.FP32], - attn_out: pl.Tensor[[T, D], pl.BF16], + q: pl.Tensor[[T, H, HEAD_DIM], pl.BF16], + ori_kv: pl.Tensor[[ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], + ori_block_table: pl.Tensor[[B, ORI_MAX_BLOCKS], pl.INT32], + cmp_kv: pl.Tensor[[CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], + cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], + cmp_sparse_indices: pl.Tensor[[T, TOPK], pl.INT32], + attn_sink: pl.Tensor[[H], pl.FP32], + seqused_kv: pl.Tensor[[B], pl.INT32], + freqs_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], + freqs_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], + even_select_local: pl.Tensor[[ROPE_INTERLEAVE_TILE, ROPE_TILE], pl.BF16], + odd_select_local: pl.Tensor[[ROPE_INTERLEAVE_TILE, ROPE_TILE], pl.BF16], + wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], + wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], + wo_b_scale: pl.Tensor[[D], pl.FP32], + attn_out: pl.Tensor[[T, D], pl.BF16], ): """Run sparse decode attention, inverse RoPE, and grouped output projection.""" - A_K_BLOCKS = O_GROUP_IN // A_K_CHUNK - A_N_BLOCKS = O_LORA // A_N_CHUNK - A_AMAX_BLOCKS = O_GROUPS * A_N_BLOCKS - B_K_BLOCKS = (O_GROUPS * O_LORA) // B_K_CHUNK - B_N_BLOCKS = D // B_N_CHUNK - - q_flat = pl.reshape(q, [T * H, HEAD_DIM]) + # Gather the sliding-window + compressed-cache rows into a per-token packed KV list. ori_kv_flat = pl.reshape(ori_kv, [ORI_BLOCK_NUM * BLOCK_SIZE, HEAD_DIM]) - ori_block_table_flat = pl.reshape(ori_block_table, [B * ORI_MAX_BLOCKS]) cmp_kv_flat = pl.reshape(cmp_kv, [CMP_BLOCK_NUM * BLOCK_SIZE, HEAD_DIM]) - cmp_block_table_flat = pl.reshape(cmp_block_table, [B * CMP_MAX_BLOCKS]) - cmp_sparse_indices_flat = pl.reshape(cmp_sparse_indices, [T * TOPK]) sparse_kv = pl.create_tensor([T * TOPK, HEAD_DIM], dtype=pl.BF16) + for g_t in pl.spmd(T, name_hint="gather_kv"): + g_b = g_t // S + g_s = g_t - g_b * S + g_seq_end = pl.read(seqused_kv, [g_b]) + g_seq_len = g_seq_end - S + 1 + g_s + g_win_v = pl.min(WIN, g_seq_len) + g_cmp_v = g_seq_len - g_win_v + g_tk_v = pl.min(IDX_TOPK, g_cmp_v) + g_sparse_k = g_win_v + g_tk_v + g_kv_base = g_t * TOPK + + # Window prefix: contiguous, copy as one row block. + g_ori_blk = pl.cast(pl.read(ori_block_table, [g_b, 0]), pl.INDEX) + g_ori_row = g_ori_blk * BLOCK_SIZE + window_rows = pl.set_validshape(ori_kv_flat[g_ori_row : g_ori_row + WIN, 0 : HEAD_DIM], g_win_v, HEAD_DIM) + sparse_kv[g_kv_base : g_kv_base + WIN, 0 : HEAD_DIM] = window_rows + + # Compressed-cache hits after the window prefix (sparse row-gather). + for g_kk in pl.range(g_tk_v): + g_raw = pl.read(cmp_sparse_indices, [g_t, g_win_v + g_kk]) + g_slot = g_raw - WIN + g_blk = pl.cast(pl.read(cmp_block_table, [g_b, g_slot // BLOCK_SIZE]), pl.INDEX) + g_src_row = g_blk * BLOCK_SIZE + g_slot % BLOCK_SIZE + g_dst_row = g_kv_base + g_win_v + g_kk + sparse_kv[g_dst_row : g_dst_row + 1, 0 : HEAD_DIM] = cmp_kv_flat[g_src_row : g_src_row + 1, 0 : HEAD_DIM] + + # Zero-pad the tail so ratio-0/128 sanity modes stay deterministic. + zero_kv_row = pl.full([1, HEAD_DIM], dtype=pl.BF16, value=0.0) + for g_pad_kk in pl.range(g_sparse_k, TOPK): + g_pad_row = g_kv_base + g_pad_kk + sparse_kv[g_pad_row : g_pad_row + 1, 0 : HEAD_DIM] = zero_kv_row + + # Sparse-K attention: qk_pv writes per-tile (mi, li, oi) into GM scratch, + # merge_norm reads them back. ATTN_K_TILE keeps K and V right-buffer + # copies together under the 64KB L1B limit. + q_flat = pl.reshape(q, [T * H, HEAD_DIM]) attn_rope_stage = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.BF16) - sparse_exp = pl.create_tensor([T * H * SPARSE_ATTN_BLOCKS, SPARSE_ATTN_TILE], dtype=pl.BF16) - sparse_blk_mi = pl.create_tensor([T * H * SPARSE_ATTN_BLOCKS, 1], dtype=pl.FP32) - sparse_blk_li = pl.create_tensor([T * H * SPARSE_ATTN_BLOCKS, 1], dtype=pl.FP32) - sparse_blk_oi = pl.create_tensor([T * H * SPARSE_ATTN_BLOCKS, HEAD_DIM], dtype=pl.FP32) - sparse_mi = pl.create_tensor([T * H, 1], dtype=pl.FP32) - sparse_li = pl.create_tensor([T * H, 1], dtype=pl.FP32) - sparse_oi = pl.create_tensor([T * H, HEAD_DIM], dtype=pl.FP32) - o_proj_even = pl.create_tensor([T * H, HALF_ROPE], dtype=pl.FP32) - o_proj_odd = pl.create_tensor([T * H, HALF_ROPE], dtype=pl.FP32) - rope_even_interleave_buf = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) - rope_odd_interleave_buf = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) o_packed = pl.create_tensor([O_GROUPS * T, O_GROUP_IN], dtype=pl.BF16) - - # Stage 1: gather the sparse KV rows selected by the sliding-window path and - # the compressed-cache path into one per-token packed KV list. - for gather_t0 in pl.parallel(0, T, GATHER_TOKEN_TILE): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_gather_kv_topk_tile"): - for gather_dt in pl.range(GATHER_TOKEN_TILE): - gather_t = gather_t0 + gather_dt - gather_b = gather_t // S - gather_s = gather_t - gather_b * S - gather_seq_final = pl.read(seqused_kv, [gather_b]) - gather_seq_used = gather_seq_final - S + 1 + gather_s - gather_window_valid = pl.min(WIN, gather_seq_used) - gather_cmp_valid = gather_seq_used - gather_window_valid - gather_cmp_topk_valid = pl.min(IDX_TOPK, gather_cmp_valid) - gather_sparse_k = gather_window_valid + gather_cmp_topk_valid - gather_ori_block_base = gather_b * ORI_MAX_BLOCKS - gather_cmp_block_base = gather_b * CMP_MAX_BLOCKS - gather_sparse_idx_base = gather_t * TOPK - gather_sparse_kv_base = gather_t * TOPK - - # The standalone decode contract uses a contiguous full-window - # prefix, so copy that prefix as one row block and leave the - # truly sparse compressed tail on the dynamic row-gather path. - gather_ori_blk = pl.cast(pl.read(ori_block_table_flat, [gather_ori_block_base]), pl.INDEX) - gather_ori_row = gather_ori_blk * BLOCK_SIZE - window_rows = pl.slice( - ori_kv_flat, - [WIN, HEAD_DIM], - [gather_ori_row, 0], - valid_shape=[gather_window_valid, HEAD_DIM], - ) - sparse_kv = pl.assemble(sparse_kv, window_rows, [gather_sparse_kv_base, 0]) - - # Append compressed-cache hits after the window prefix. - for gather_cmp_kk in pl.range(gather_cmp_topk_valid): - gather_cmp_idx_pos = gather_sparse_idx_base + gather_window_valid + gather_cmp_kk - gather_cmp_raw_idx = pl.read(cmp_sparse_indices_flat, [gather_cmp_idx_pos]) - gather_cmp_slot = gather_cmp_raw_idx - WIN - gather_cmp_block_slot = gather_cmp_slot // BLOCK_SIZE - gather_cmp_block_pos = gather_cmp_block_base + gather_cmp_block_slot - gather_cmp_blk = pl.cast(pl.read(cmp_block_table_flat, [gather_cmp_block_pos]), pl.INDEX) - gather_cmp_intra = gather_cmp_slot % BLOCK_SIZE - gather_cmp_row = gather_cmp_blk * BLOCK_SIZE + gather_cmp_intra - sparse_kv = pl.assemble( - sparse_kv, - cmp_kv_flat[gather_cmp_row : gather_cmp_row + 1, 0 : HEAD_DIM], - [gather_sparse_kv_base + gather_window_valid + gather_cmp_kk, 0], - ) - - # Keep padded rows deterministic for ratio-0/128 sanity modes. - zero_kv_row = pl.full([1, HEAD_DIM], dtype=pl.BF16, value=0.0) - for gather_pad_kk in pl.range(gather_sparse_k, TOPK): - sparse_kv = pl.assemble(sparse_kv, zero_kv_row, [gather_sparse_kv_base + gather_pad_kk, 0]) - - for attn_t0 in pl.parallel(0, T, ATTN_TOKEN_TILE): - for h0 in pl.parallel(0, H, MATMUL_ROW_PAD): - # Stage 2a: QK + tile-local softmax for every sparse-K tile. - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_sparse_attn_qk_softmax_tile"): - for qk_dt in pl.range(ATTN_TOKEN_TILE): - qk_t = attn_t0 + qk_dt - qk_b = qk_t // S - qk_s = qk_t - qk_b * S - qk_seq_final = pl.read(seqused_kv, [qk_b]) - qk_seq_used = qk_seq_final - S + 1 + qk_s - qk_window_valid = pl.min(WIN, qk_seq_used) - qk_cmp_valid = qk_seq_used - qk_window_valid - qk_cmp_topk_valid = pl.min(IDX_TOPK, qk_cmp_valid) - qk_sparse_k = qk_window_valid + qk_cmp_topk_valid - qk_sparse_kv_base = qk_t * TOPK - qk_head_row = qk_t * H + h0 - qk_q_batch = q_flat[qk_head_row : qk_head_row + MATMUL_ROW_PAD, 0 : HEAD_DIM] - - for qk_sb in pl.range(SPARSE_ATTN_BLOCKS): - qk_tile_start = qk_sb * SPARSE_ATTN_TILE - if qk_tile_start < qk_sparse_k: - qk_tile_valid = pl.min(SPARSE_ATTN_TILE, qk_sparse_k - qk_tile_start) - qk_kv_tile = sparse_kv[ - qk_sparse_kv_base + qk_tile_start : qk_sparse_kv_base + qk_tile_start + SPARSE_ATTN_TILE, - 0 : HEAD_DIM, - ] - qk_raw_scores = pl.matmul(qk_q_batch, qk_kv_tile, b_trans=True, out_dtype=pl.FP32) - qk_scores_valid = pl.slice( - pl.mul(qk_raw_scores, SOFTMAX_SCALE), - [MATMUL_ROW_PAD, SPARSE_ATTN_TILE], - [0, 0], - valid_shape=[MATMUL_ROW_PAD, qk_tile_valid], - ) - qk_scores = pl.fillpad(qk_scores_valid, pad_value=pl.PadValue.min) - qk_mi = pl.row_max(qk_scores) - qk_exp_scores = pl.exp(pl.row_expand_sub(qk_scores, qk_mi)) - qk_exp_scores_bf16 = pl.cast(qk_exp_scores, target_type=pl.BF16) - qk_li = pl.row_sum(pl.cast(qk_exp_scores_bf16, target_type=pl.FP32)) - qk_block_row = qk_t * H * SPARSE_ATTN_BLOCKS + qk_sb * H + h0 - sparse_exp = pl.assemble(sparse_exp, qk_exp_scores_bf16, [qk_block_row, 0]) - sparse_blk_mi = pl.assemble(sparse_blk_mi, qk_mi, [qk_block_row, 0]) - sparse_blk_li = pl.assemble(sparse_blk_li, qk_li, [qk_block_row, 0]) - - # Stage 2b: PV for each sparse-K tile. Keep the online merge in a - # separate scope under FLASH so the AIV live set stays bounded. - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_sparse_attn_pv_tile"): - for pv_dt in pl.range(ATTN_TOKEN_TILE): - pv_t = attn_t0 + pv_dt - pv_b = pv_t // S - pv_s = pv_t - pv_b * S - pv_seq_final = pl.read(seqused_kv, [pv_b]) - pv_seq_used = pv_seq_final - S + 1 + pv_s - pv_window_valid = pl.min(WIN, pv_seq_used) - pv_cmp_valid = pv_seq_used - pv_window_valid - pv_cmp_topk_valid = pl.min(IDX_TOPK, pv_cmp_valid) - pv_sparse_k = pv_window_valid + pv_cmp_topk_valid - pv_sparse_kv_base = pv_t * TOPK - for pv_sb in pl.range(SPARSE_ATTN_BLOCKS): - pv_tile_start = pv_sb * SPARSE_ATTN_TILE - if pv_tile_start < pv_sparse_k: - pv_block_row = pv_t * H * SPARSE_ATTN_BLOCKS + pv_sb * H + h0 - pv_exp = sparse_exp[pv_block_row : pv_block_row + MATMUL_ROW_PAD, 0 : SPARSE_ATTN_TILE] - pv_kv_tile = sparse_kv[ - pv_sparse_kv_base + pv_tile_start : pv_sparse_kv_base + pv_tile_start + SPARSE_ATTN_TILE, - 0 : HEAD_DIM, - ] - pv_oi_tmp = pl.matmul(pv_exp, pv_kv_tile, out_dtype=pl.FP32) - sparse_blk_oi = pl.assemble(sparse_blk_oi, pv_oi_tmp, [pv_block_row, 0]) - - # Stage 2c: online-softmax merge across sparse-K tiles. - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_sparse_attn_merge_tile"): - for merge_dt in pl.range(ATTN_TOKEN_TILE): - merge_t = attn_t0 + merge_dt - merge_b = merge_t // S - merge_s = merge_t - merge_b * S - merge_seq_final = pl.read(seqused_kv, [merge_b]) - merge_seq_used = merge_seq_final - S + 1 + merge_s - merge_window_valid = pl.min(WIN, merge_seq_used) - merge_cmp_valid = merge_seq_used - merge_window_valid - merge_cmp_topk_valid = pl.min(IDX_TOPK, merge_cmp_valid) - merge_sparse_k = merge_window_valid + merge_cmp_topk_valid - merge_head_row = merge_t * H + h0 - merge_block_row0 = merge_t * H * SPARSE_ATTN_BLOCKS + h0 - merge_mi = sparse_blk_mi[merge_block_row0 : merge_block_row0 + MATMUL_ROW_PAD, 0 : 1] - merge_li = sparse_blk_li[merge_block_row0 : merge_block_row0 + MATMUL_ROW_PAD, 0 : 1] - merge_oi = sparse_blk_oi[merge_block_row0 : merge_block_row0 + MATMUL_ROW_PAD, 0 : HEAD_DIM] - - for merge_sb in pl.range(1, SPARSE_ATTN_BLOCKS): - merge_tile_start = merge_sb * SPARSE_ATTN_TILE - if merge_tile_start < merge_sparse_k: - merge_block_row = merge_t * H * SPARSE_ATTN_BLOCKS + merge_sb * H + h0 - merge_cur_mi = sparse_blk_mi[merge_block_row : merge_block_row + MATMUL_ROW_PAD, 0 : 1] - merge_cur_li = sparse_blk_li[merge_block_row : merge_block_row + MATMUL_ROW_PAD, 0 : 1] - merge_cur_oi = sparse_blk_oi[merge_block_row : merge_block_row + MATMUL_ROW_PAD, 0 : HEAD_DIM] - merge_mi_new = pl.maximum(merge_mi, merge_cur_mi) - merge_alpha = pl.exp(pl.sub(merge_mi, merge_mi_new)) - merge_beta = pl.exp(pl.sub(merge_cur_mi, merge_mi_new)) - merge_li = pl.add(pl.mul(merge_alpha, merge_li), pl.mul(merge_beta, merge_cur_li)) - merge_oi = pl.add( - pl.row_expand_mul(merge_oi, merge_alpha), - pl.row_expand_mul(merge_cur_oi, merge_beta), - ) - merge_mi = merge_mi_new - - sparse_mi = pl.assemble(sparse_mi, merge_mi, [merge_head_row, 0]) - sparse_li = pl.assemble(sparse_li, merge_li, [merge_head_row, 0]) - sparse_oi = pl.assemble(sparse_oi, merge_oi, [merge_head_row, 0]) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_sparse_attn_norm_tile"): - for norm_dt in pl.range(ATTN_TOKEN_TILE): - norm_t = attn_t0 + norm_dt - norm_attn_head_row = norm_t * H + h0 - norm_oi = sparse_oi[norm_attn_head_row : norm_attn_head_row + MATMUL_ROW_PAD, 0 : HEAD_DIM] - norm_mi = sparse_mi[norm_attn_head_row : norm_attn_head_row + MATMUL_ROW_PAD, 0 : 1] - norm_li = sparse_li[norm_attn_head_row : norm_attn_head_row + MATMUL_ROW_PAD, 0 : 1] - norm_sink_bias = pl.reshape(attn_sink[h0 : h0 + MATMUL_ROW_PAD], [MATMUL_ROW_PAD, 1]) - norm_sink_tile = pl.add(pl.sub(norm_mi, norm_mi), norm_sink_bias) - norm_denom = pl.add(norm_li, pl.exp(pl.sub(norm_sink_tile, norm_mi))) - oi_out = pl.row_expand_div(norm_oi, norm_denom) - attn_stage_row = pl.cast( - oi_out[0 : MATMUL_ROW_PAD, 0 : HEAD_DIM], - target_type=pl.BF16, - ) - attn_rope_stage = pl.assemble( - attn_rope_stage, - attn_stage_row[0 : MATMUL_ROW_PAD, NOPE_DIM:HEAD_DIM], - [norm_attn_head_row, 0], - ) - - for norm_head_i in pl.range(MATMUL_ROW_PAD): - norm_global_head = h0 + norm_head_i - norm_g = norm_global_head // HEADS_PER_GROUP - norm_hh = norm_global_head - norm_g * HEADS_PER_GROUP - norm_pack_row = norm_g * T + norm_t - norm_head_col = norm_hh * HEAD_DIM - o_packed = pl.assemble( - o_packed, - attn_stage_row[norm_head_i : norm_head_i + 1, 0:NOPE_DIM], - [norm_pack_row, norm_head_col], - ) - - # Stage 3: inverse RoPE on the rope slice of the attention output by - # deinterleaving even/odd lanes, rotating them, then reinterleaving. - for rope_t0 in pl.parallel(0, T, ROPE_TOKEN_TILE): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_rope_slice_tile"): - for rope_dt in pl.range(ROPE_TOKEN_TILE): - rope_slice_t = rope_t0 + rope_dt - rope_slice_head_row = rope_slice_t * H - + sparse_blk_mi = pl.create_tensor([T * (H // H_TILE) * ((TOPK + ATTN_K_TILE - 1) // ATTN_K_TILE) * H_TILE, 1], dtype=pl.FP32) + sparse_blk_li = pl.create_tensor([T * (H // H_TILE) * ((TOPK + ATTN_K_TILE - 1) // ATTN_K_TILE) * H_TILE, 1], dtype=pl.FP32) + sparse_blk_oi = pl.create_tensor([T * (H // H_TILE) * ((TOPK + ATTN_K_TILE - 1) // ATTN_K_TILE) * H_TILE, HEAD_DIM], dtype=pl.FP32) + + for qk_t in pl.spmd(T, name_hint="qk_pv"): + qk_b = qk_t // S + qk_s = qk_t - qk_b * S + qk_seq_end = pl.read(seqused_kv, [qk_b]) + qk_seq_len = qk_seq_end - S + 1 + qk_s + qk_win_v = pl.min(WIN, qk_seq_len) + qk_tk_v = pl.min(IDX_TOPK, qk_seq_len - qk_win_v) + qk_sparse_k = qk_win_v + qk_tk_v + qk_kv_base = qk_t * TOPK + qk_token_base = qk_t * (H // H_TILE) * ((TOPK + ATTN_K_TILE - 1) // ATTN_K_TILE) * H_TILE + for qk_h_idx in pl.range((H // H_TILE)): + qk_h0 = qk_h_idx * H_TILE + qk_head_row = qk_t * H + qk_h0 + qk_q_tile = q_flat[qk_head_row : qk_head_row + H_TILE, 0 : HEAD_DIM] + qk_blk_base = qk_token_base + qk_h_idx * ((TOPK + ATTN_K_TILE - 1) // ATTN_K_TILE) * H_TILE + + for qk_sb in pl.range(((TOPK + ATTN_K_TILE - 1) // ATTN_K_TILE)): + qk_s0 = qk_sb * ATTN_K_TILE + if qk_s0 < qk_sparse_k: + qk_s_v = pl.min(ATTN_K_TILE, qk_sparse_k - qk_s0) + qk_kv_k = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM] + qk_raw = pl.matmul(qk_q_tile, qk_kv_k, b_trans=True, out_dtype=pl.FP32) + qk_scores_v = pl.set_validshape(pl.mul(qk_raw, SOFTMAX_SCALE), H_TILE, qk_s_v) + qk_scores = pl.fillpad(qk_scores_v, pad_value=pl.PadValue.min) + qk_mi = pl.row_max(qk_scores) + qk_exp = pl.exp(pl.row_expand_sub(qk_scores, qk_mi)) + qk_exp_bf16 = pl.cast(qk_exp, target_type=pl.BF16) + qk_li = pl.row_sum(pl.cast(qk_exp_bf16, target_type=pl.FP32)) + qk_kv_v = sparse_kv[qk_kv_base + qk_s0 : qk_kv_base + qk_s0 + ATTN_K_TILE, 0 : HEAD_DIM] + qk_oi = pl.matmul(qk_exp_bf16, qk_kv_v, out_dtype=pl.FP32) + qk_row = qk_blk_base + qk_sb * H_TILE + sparse_blk_mi[qk_row : qk_row + H_TILE, 0 : 1] = qk_mi + sparse_blk_li[qk_row : qk_row + H_TILE, 0 : 1] = qk_li + sparse_blk_oi[qk_row : qk_row + H_TILE, 0 : HEAD_DIM] = qk_oi + + # Online-softmax merge across sparse-K tiles, then sink-norm. + for m_t in pl.spmd(T, name_hint="merge_norm"): + m_b = m_t // S + m_s = m_t - m_b * S + m_seq_end = pl.read(seqused_kv, [m_b]) + m_seq_len = m_seq_end - S + 1 + m_s + m_win_v = pl.min(WIN, m_seq_len) + m_tk_v = pl.min(IDX_TOPK, m_seq_len - m_win_v) + m_sparse_k = m_win_v + m_tk_v + m_token_base = m_t * (H // H_TILE) * ((TOPK + ATTN_K_TILE - 1) // ATTN_K_TILE) * H_TILE + + for m_h_idx in pl.range((H // H_TILE)): + m_h0 = m_h_idx * H_TILE + m_blk_base = m_token_base + m_h_idx * ((TOPK + ATTN_K_TILE - 1) // ATTN_K_TILE) * H_TILE + m_mi = sparse_blk_mi[m_blk_base : m_blk_base + H_TILE, 0 : 1] + m_li = sparse_blk_li[m_blk_base : m_blk_base + H_TILE, 0 : 1] + m_oi = sparse_blk_oi[m_blk_base : m_blk_base + H_TILE, 0 : HEAD_DIM] + + for m_sb in pl.range(1, ((TOPK + ATTN_K_TILE - 1) // ATTN_K_TILE)): + m_s0 = m_sb * ATTN_K_TILE + if m_s0 < m_sparse_k: + m_row = m_blk_base + m_sb * H_TILE + m_cur_mi = sparse_blk_mi[m_row : m_row + H_TILE, 0 : 1] + m_cur_li = sparse_blk_li[m_row : m_row + H_TILE, 0 : 1] + m_cur_oi = sparse_blk_oi[m_row : m_row + H_TILE, 0 : HEAD_DIM] + m_mi_new = pl.maximum(m_mi, m_cur_mi) + m_alpha = pl.exp(pl.sub(m_mi, m_mi_new)) + m_beta = pl.exp(pl.sub(m_cur_mi, m_mi_new)) + m_li = pl.add(pl.mul(m_alpha, m_li), pl.mul(m_beta, m_cur_li)) + m_oi = pl.add(pl.row_expand_mul(m_oi, m_alpha), pl.row_expand_mul(m_cur_oi, m_beta)) + m_mi = m_mi_new + + n_sink_bias = pl.reshape(attn_sink[m_h0 : m_h0 + H_TILE], [H_TILE, 1]) + n_sink_tile = pl.add(pl.sub(m_mi, m_mi), n_sink_bias) + n_denom = pl.add(m_li, pl.exp(pl.sub(n_sink_tile, m_mi))) + n_out = pl.cast(pl.row_expand_div(m_oi, n_denom)[0 : H_TILE, 0 : HEAD_DIM], target_type=pl.BF16) + n_rope_row = m_t * H + m_h0 + attn_rope_stage[n_rope_row : n_rope_row + H_TILE, 0 : ROPE_DIM] = n_out[0 : H_TILE, NOPE_DIM : HEAD_DIM] + + for n_hi in pl.range(H_TILE): + n_gh = m_h0 + n_hi + n_g = n_gh // HEADS_PER_GROUP + n_hh = n_gh - n_g * HEADS_PER_GROUP + n_pack_row = n_g * T + m_t + n_col = n_hh * HEAD_DIM + o_packed[n_pack_row : n_pack_row + 1, n_col : n_col + NOPE_DIM] = n_out[n_hi : n_hi + 1, 0 : NOPE_DIM] + + # Inverse RoPE: deinterleave even/odd lanes, rotate with cos/sin, reinterleave. + rope_even_buf = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) + rope_odd_buf = pl.create_tensor([T * H, ROPE_DIM], dtype=pl.FP32) + for r_idx in pl.spmd(T // ROPE_TOKEN_TILE, name_hint="rope"): + r_t0 = r_idx * ROPE_TOKEN_TILE + for r_dt in pl.range(ROPE_TOKEN_TILE): + r_t = r_t0 + r_dt + r_row = r_t * H + + for r_r0 in pl.range(0, HALF_ROPE, ROPE_TILE): # Split interleaved rope lanes into even and odd halves. - for rope_slice_r0 in pl.range(0, HALF_ROPE, ROPE_CHUNK): - rope_tile = attn_rope_stage[ - rope_slice_head_row : rope_slice_head_row + H, - 2 * rope_slice_r0 : 2 * rope_slice_r0 + ROPE_INTERLEAVE_CHUNK, - ] - rope_slice_even_chunk = pl.matmul(rope_tile, even_select_local, out_dtype=pl.FP32) - rope_slice_odd_chunk = pl.matmul(rope_tile, odd_select_local, out_dtype=pl.FP32) - o_proj_even = pl.assemble(o_proj_even, rope_slice_even_chunk, [rope_slice_head_row, rope_slice_r0]) - o_proj_odd = pl.assemble(o_proj_odd, rope_slice_odd_chunk, [rope_slice_head_row, rope_slice_r0]) - - # Stage 3: rotate even/odd halves and immediately reinterleave them. Keeping - # the BF16 rint cast local avoids the new orchestration SSA collision on - # written-then-read RoPE GM scratch tensors. - for rope_apply_t0 in pl.parallel(0, T, ROPE_TOKEN_TILE): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_rope_apply_assemble_tile"): - for rope_apply_dt in pl.range(ROPE_TOKEN_TILE): - rope_apply_t = rope_apply_t0 + rope_apply_dt - rope_apply_head_row = rope_apply_t * H - - # Reinterleave the rotated even/odd halves back to rope lane order. - for rope_asm_r0 in pl.range(0, HALF_ROPE, ROPE_CHUNK): - cos_chunk = pl.cast( - freqs_cos[rope_apply_t : rope_apply_t + 1, rope_asm_r0 : rope_asm_r0 + ROPE_CHUNK], - target_type=pl.FP32, - ) - sin_chunk = pl.cast( - freqs_sin[rope_apply_t : rope_apply_t + 1, rope_asm_r0 : rope_asm_r0 + ROPE_CHUNK], - target_type=pl.FP32, - ) - rope_apply_even_chunk = o_proj_even[ - rope_apply_head_row : rope_apply_head_row + H, - rope_asm_r0 : rope_asm_r0 + ROPE_CHUNK, - ] - rope_apply_odd_chunk = o_proj_odd[ - rope_apply_head_row : rope_apply_head_row + H, - rope_asm_r0 : rope_asm_r0 + ROPE_CHUNK, - ] - rope_even_acc = pl.add( - pl.col_expand_mul(rope_apply_even_chunk, cos_chunk), - pl.col_expand_mul(rope_apply_odd_chunk, sin_chunk), - ) - rope_odd_acc = pl.sub( - pl.col_expand_mul(rope_apply_odd_chunk, cos_chunk), - pl.col_expand_mul(rope_apply_even_chunk, sin_chunk), - ) - rope_rot_even_chunk = pl.cast(rope_even_acc, target_type=pl.BF16, mode="rint") - rope_rot_odd_chunk = pl.cast(rope_odd_acc, target_type=pl.BF16, mode="rint") - rope_even_interleave = pl.matmul( - rope_rot_even_chunk, - even_select_local, - b_trans=True, - out_dtype=pl.FP32, - ) - rope_odd_interleave = pl.matmul( - rope_rot_odd_chunk, - odd_select_local, - b_trans=True, - out_dtype=pl.FP32, - ) - rope_even_interleave_buf = pl.assemble( - rope_even_interleave_buf, - rope_even_interleave, - [rope_apply_head_row, 2 * rope_asm_r0], - ) - rope_odd_interleave_buf = pl.assemble( - rope_odd_interleave_buf, - rope_odd_interleave, - [rope_apply_head_row, 2 * rope_asm_r0], - ) - - for rope_pack_block in pl.spmd(ROPE_PACK_SPMD_BLOCKS, name_hint="cfa_proj_rope_pack_group_spmd"): - rope_pack_token_block = rope_pack_block // O_GROUPS - rope_pack_g = rope_pack_block - rope_pack_token_block * O_GROUPS - rope_combine_t0 = rope_pack_token_block * ROPE_PACK_TOKEN_TILE - - for rope_combine_dt in pl.range(ROPE_PACK_TOKEN_TILE): - rope_combine_t = rope_combine_t0 + rope_combine_dt - rope_pack_head_row = rope_combine_t * H + rope_pack_g * HEADS_PER_GROUP - - # Merge and write only this group's inverse-RoPE tail. - rope_even_tile = rope_even_interleave_buf[ - rope_pack_head_row : rope_pack_head_row + HEADS_PER_GROUP, - 0 : ROPE_DIM, - ] - rope_odd_tile = rope_odd_interleave_buf[ - rope_pack_head_row : rope_pack_head_row + HEADS_PER_GROUP, - 0 : ROPE_DIM, - ] - rope_full = pl.cast( - pl.add(rope_even_tile, rope_odd_tile), - target_type=pl.BF16, - ) - rope_pack_row = rope_pack_g * T + rope_combine_t - for rope_pack_hh in pl.range(HEADS_PER_GROUP): - rope_pack_head_col = rope_pack_hh * HEAD_DIM + NOPE_DIM - o_packed = pl.assemble( - o_packed, - rope_full[rope_pack_hh : rope_pack_hh + 1, 0:ROPE_DIM], - [rope_pack_row, rope_pack_head_col], - ) - + r_tile = attn_rope_stage[r_row : r_row + H, 2 * r_r0 : 2 * r_r0 + ROPE_INTERLEAVE_TILE] + r_even = pl.matmul(r_tile, even_select_local, out_dtype=pl.FP32) + r_odd = pl.matmul(r_tile, odd_select_local, out_dtype=pl.FP32) + + # Rotate the even/odd halves with cos/sin. + r_cos = pl.cast(freqs_cos[r_t : r_t + 1, r_r0 : r_r0 + ROPE_TILE], target_type=pl.FP32) + r_sin = pl.cast(freqs_sin[r_t : r_t + 1, r_r0 : r_r0 + ROPE_TILE], target_type=pl.FP32) + r_even_rot = pl.add(pl.col_expand_mul(r_even, r_cos), pl.col_expand_mul(r_odd, r_sin)) + r_odd_rot = pl.sub(pl.col_expand_mul(r_odd, r_cos), pl.col_expand_mul(r_even, r_sin)) + r_even_bf16 = pl.cast(r_even_rot, target_type=pl.BF16, mode="rint") + r_odd_bf16 = pl.cast(r_odd_rot, target_type=pl.BF16, mode="rint") + + # Reinterleave the rotated halves back to rope-lane order. + r_even_il = pl.matmul(r_even_bf16, even_select_local, b_trans=True, out_dtype=pl.FP32) + r_odd_il = pl.matmul(r_odd_bf16, odd_select_local, b_trans=True, out_dtype=pl.FP32) + rope_even_buf[r_row : r_row + H, 2 * r_r0 : 2 * r_r0 + ROPE_INTERLEAVE_TILE] = r_even_il + rope_odd_buf[r_row : r_row + H, 2 * r_r0 : 2 * r_r0 + ROPE_INTERLEAVE_TILE] = r_odd_il + + for rp_block in pl.spmd((T // ROPE_PACK_TOKEN_TILE) * O_GROUPS, name_hint="rope_pack"): + rp_tb = rp_block // O_GROUPS + rp_g = rp_block - rp_tb * O_GROUPS + rp_t0 = rp_tb * ROPE_PACK_TOKEN_TILE + + for rp_dt in pl.range(ROPE_PACK_TOKEN_TILE): + rp_t = rp_t0 + rp_dt + rp_row = rp_t * H + rp_g * HEADS_PER_GROUP + + # Merge and write only this group's inverse-RoPE tail of o_packed. + rp_even = rope_even_buf[rp_row : rp_row + HEADS_PER_GROUP, 0 : ROPE_DIM] + rp_odd = rope_odd_buf[rp_row : rp_row + HEADS_PER_GROUP, 0 : ROPE_DIM] + rp_full = pl.cast(pl.add(rp_even, rp_odd), target_type=pl.BF16) + rp_pack_row = rp_g * T + rp_t + for rp_hh in pl.range(HEADS_PER_GROUP): + rp_col = rp_hh * HEAD_DIM + NOPE_DIM + o_packed[rp_pack_row : rp_pack_row + 1, rp_col : rp_col + ROPE_DIM] = rp_full[rp_hh : rp_hh + 1, 0 : ROPE_DIM] + + # Grouped BF16 projection `o_packed @ wo_a^T` -> `o_r`. Vec post-process + # (BF16 store + per-row partial amax) is T-tiled inside the scope as a + # pypto#1472 workaround — without it the fused proj_a AIV side oversizes + # UB and AllocateMemoryAddr rejects the kernel. o_r = pl.create_tensor([T, O_GROUPS * O_LORA], dtype=pl.BF16) - o_r_i8 = pl.create_tensor([T, O_GROUPS * O_LORA], dtype=pl.INT8) - o_r_amax_parts = pl.create_tensor([A_AMAX_BLOCKS, T], dtype=pl.FP32) - o_r_scale_dq = pl.create_tensor([T, 1], dtype=pl.FP32) - - # Stage 5: grouped BF16 projection `o_packed @ wo_a^T`, producing the - # low-rank intermediate activation `o_r`. - for g in pl.parallel(0, O_GROUPS, 1): + o_r_amax_parts = pl.create_tensor([O_GROUPS * (O_LORA // A_N_TILE), T], dtype=pl.FP32) + for proj_a_block in pl.spmd(O_GROUPS * (O_LORA // A_N_TILE), name_hint="proj_a"): + # K-split BF16 matmul for one wo_a output tile. Stays in + # peel-first-iter form: the `pl.create_tensor` + `if k0 == 0` + # carry hits pypto#1540 on the 3D wo_a slice. + g = proj_a_block // (O_LORA // A_N_TILE) + nb = proj_a_block - g * (O_LORA // A_N_TILE) row_base_o = g * T out_col_g = g * O_LORA - - for nb in pl.parallel(0, A_N_BLOCKS, 1): - n0 = nb * A_N_CHUNK - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_stage_a_accum"): - # K-split BF16 matmul for one wo_a output tile. - xa0_chunk = o_packed[row_base_o:row_base_o + T, 0:A_K_CHUNK] - wa0_chunk = wo_a[g:g + 1, n0:n0 + A_N_CHUNK, 0:A_K_CHUNK] - acc_a = pl.matmul(xa0_chunk, wa0_chunk, b_trans=True, out_dtype=pl.FP32) - for kb in pl.pipeline(1, A_K_BLOCKS, stage=2): - k0 = kb * A_K_CHUNK - xa_k_chunk = o_packed[row_base_o:row_base_o + T, k0:k0 + A_K_CHUNK] - wa_k_chunk = wo_a[g:g + 1, n0:n0 + A_N_CHUNK, k0:k0 + A_K_CHUNK] - acc_a = pl.matmul_acc(acc_a, xa_k_chunk, wa_k_chunk, b_trans=True) - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_stage_a_store_amax"): - # Store BF16 activations and expose the tile's row-wise partial amax for quant. - acc_a_2d = pl.reshape(acc_a, [T, A_N_CHUNK]) - acc_a_bf16 = pl.cast( - acc_a_2d, - target_type=pl.BF16, - ) - o_r[:, out_col_g + n0:out_col_g + n0 + A_N_CHUNK] = acc_a_bf16 - acc_a_f32 = pl.cast(acc_a_bf16, target_type=pl.FP32) - acc_a_abs = pl.maximum(acc_a_f32, pl.neg(acc_a_f32)) - acc_a_amax = pl.reshape(pl.row_max(acc_a_abs), [1, T]) - amax_part_row = g * A_N_BLOCKS + nb - o_r_amax_parts[amax_part_row:amax_part_row + 1, 0:T] = acc_a_amax - - # Stage 6: per-row symmetric INT8 quantization of `o_r` for the W8A8C16 - # second projection stage. - for quant_t0 in pl.parallel(0, T, QUANT_TOKEN_TILE): - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_stage_b_quant_tile"): - or_amax = pl.full([1, QUANT_TOKEN_TILE], dtype=pl.FP32, value=INT8_AMAX_EPS) - for ab in pl.range(0, A_AMAX_BLOCKS, 1): - or_a_part = o_r_amax_parts[ab:ab + 1, quant_t0:quant_t0 + QUANT_TOKEN_TILE] - or_amax = pl.maximum(or_amax, or_a_part) - or_sq_row = pl.div(pl.full([1, QUANT_TOKEN_TILE], dtype=pl.FP32, value=INT8_SCALE_MAX), or_amax) - or_scale_dq = pl.reshape(pl.recip(or_sq_row), [QUANT_TOKEN_TILE, 1]) - o_r_scale_dq[quant_t0:quant_t0 + QUANT_TOKEN_TILE, 0:1] = or_scale_dq - or_sq_col = pl.reshape(or_sq_row, [QUANT_TOKEN_TILE, 1]) - for k1 in pl.range(0, O_GROUPS * O_LORA, QUANT_CHUNK): - or_q_f32 = pl.cast(o_r[quant_t0:quant_t0 + QUANT_TOKEN_TILE, k1:k1 + QUANT_CHUNK], target_type=pl.FP32) - or_q_scaled = pl.row_expand_mul(or_q_f32, or_sq_col) - or_q_i32 = pl.cast(or_q_scaled, target_type=pl.INT32, mode="rint") - or_q_half = pl.cast(or_q_i32, target_type=pl.FP16, mode="round") - o_r_i8[quant_t0:quant_t0 + QUANT_TOKEN_TILE, k1:k1 + QUANT_CHUNK] = pl.cast( - or_q_half, - target_type=pl.INT8, - mode="trunc", - ) - - # Stage 7: INT8 projection `o_r_i8 @ wo_b^T`, then dequantize with the - # activation and weight scales into the final BF16 output. - for nb in pl.parallel(0, B_N_BLOCKS, 1): - n0 = nb * B_N_CHUNK - - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_stage_b_accum"): - # K-split INT8 GEMM for one output-channel tile. - xb0_chunk = o_r_i8[:, 0:B_K_CHUNK] - wb0_chunk = wo_b[n0:n0 + B_N_CHUNK, 0:B_K_CHUNK] - acc_b = pl.matmul(xb0_chunk, wb0_chunk, b_trans=True, out_dtype=pl.INT32) - for kb in pl.pipeline(1, B_K_BLOCKS, stage=2): - k0 = kb * B_K_CHUNK - xb_k_chunk = o_r_i8[:, k0:k0 + B_K_CHUNK] - wb_k_chunk = wo_b[n0:n0 + B_N_CHUNK, k0:k0 + B_K_CHUNK] + n0 = nb * A_N_TILE + amax_part_row = g * (O_LORA // A_N_TILE) + nb + + xa0_chunk = o_packed[row_base_o:row_base_o + T, 0:A_K_TILE] + wa0_chunk = wo_a[g:g + 1, n0:n0 + A_N_TILE, 0:A_K_TILE] + acc_a = pl.matmul(xa0_chunk, wa0_chunk, b_trans=True, out_dtype=pl.FP32) + for kb in pl.pipeline(1, O_GROUP_IN // A_K_TILE, stage=2): + k0 = kb * A_K_TILE + xa_k_chunk = o_packed[row_base_o:row_base_o + T, k0:k0 + A_K_TILE] + wa_k_chunk = wo_a[g:g + 1, n0:n0 + A_N_TILE, k0:k0 + A_K_TILE] + acc_a = pl.matmul_acc(acc_a, xa_k_chunk, wa_k_chunk, b_trans=True) + + acc_a_2d = pl.reshape(acc_a, [T, A_N_TILE]) + for tb in pl.range(0, T, A_T_TILE): + acc_t = acc_a_2d[tb:tb + A_T_TILE, 0:A_N_TILE] + acc_t_bf16 = pl.cast(acc_t, target_type=pl.BF16) + o_r[tb:tb + A_T_TILE, out_col_g + n0:out_col_g + n0 + A_N_TILE] = acc_t_bf16 + acc_t_f32 = pl.cast(acc_t_bf16, target_type=pl.FP32) + acc_t_abs = pl.maximum(acc_t_f32, pl.neg(acc_t_f32)) + acc_t_amax = pl.reshape(pl.row_max(acc_t_abs), [1, A_T_TILE]) + o_r_amax_parts[amax_part_row:amax_part_row + 1, tb:tb + A_T_TILE] = acc_t_amax + + # Per-row symmetric INT8 quant of `o_r`, K-tiled as a second parallel axis. + o_r_i8 = pl.create_tensor([T, O_GROUPS * O_LORA], dtype=pl.INT8) + o_r_scale_dq = pl.create_tensor([T, 1], dtype=pl.FP32) + for q_block in pl.spmd((T // QUANT_TOKEN_TILE) * ((O_GROUPS * O_LORA) // QUANT_K_TILE), name_hint="quant"): + qt_idx = q_block // ((O_GROUPS * O_LORA) // QUANT_K_TILE) + qk_idx = q_block - qt_idx * ((O_GROUPS * O_LORA) // QUANT_K_TILE) + quant_t0 = qt_idx * QUANT_TOKEN_TILE + k0 = qk_idx * QUANT_K_TILE + + or_amax = pl.full([1, QUANT_TOKEN_TILE], dtype=pl.FP32, value=INT8_AMAX_EPS) + for ab in pl.range(0, O_GROUPS * (O_LORA // A_N_TILE), 1): + or_a_part = o_r_amax_parts[ab:ab + 1, quant_t0:quant_t0 + QUANT_TOKEN_TILE] + or_amax = pl.maximum(or_amax, or_a_part) + or_sq_row = pl.div(pl.full([1, QUANT_TOKEN_TILE], dtype=pl.FP32, value=INT8_SCALE_MAX), or_amax) + or_scale_dq = pl.reshape(pl.recip(or_sq_row), [QUANT_TOKEN_TILE, 1]) + o_r_scale_dq[quant_t0:quant_t0 + QUANT_TOKEN_TILE, 0:1] = or_scale_dq + or_sq_col = pl.reshape(or_sq_row, [QUANT_TOKEN_TILE, 1]) + for k1 in pl.range(k0, k0 + QUANT_K_TILE, QUANT_TILE): + or_q_f32 = pl.cast(o_r[quant_t0:quant_t0 + QUANT_TOKEN_TILE, k1:k1 + QUANT_TILE], target_type=pl.FP32) + or_q_scaled = pl.row_expand_mul(or_q_f32, or_sq_col) + or_q_i32 = pl.cast(or_q_scaled, target_type=pl.INT32, mode="rint") + or_q_half = pl.cast(or_q_i32, target_type=pl.FP16, mode="round") + o_r_i8[quant_t0:quant_t0 + QUANT_TOKEN_TILE, k1:k1 + QUANT_TILE] = pl.cast(or_q_half, target_type=pl.INT8, mode="trunc") + + # INT8 projection `o_r_i8 @ wo_b^T`, then dequantize -> final BF16 output. + for nb in pl.spmd(D // B_N_TILE, name_hint="proj_b"): + # K-split INT8 GEMM + dequant in one scope. T-tiled vec post-process + # is the pypto#1472 workaround (same as proj_a). + n0 = nb * B_N_TILE + acc_b = pl.create_tensor([T, B_N_TILE], dtype=pl.INT32) + for kb in pl.pipeline(0, (O_GROUPS * O_LORA) // B_K_TILE, stage=2): + k0 = kb * B_K_TILE + xb_k_chunk = o_r_i8[:, k0:k0 + B_K_TILE] + wb_k_chunk = wo_b[n0:n0 + B_N_TILE, k0:k0 + B_K_TILE] + if k0 == 0: + acc_b = pl.matmul(xb_k_chunk, wb_k_chunk, b_trans=True, out_dtype=pl.INT32) + else: acc_b = pl.matmul_acc(acc_b, xb_k_chunk, wb_k_chunk, b_trans=True) - with pl.at(level=pl.Level.CORE_GROUP, name_hint="cfa_proj_stage_b_store"): - # Apply the per-row and per-channel dequant scales before casting to BF16. - wb_scale_chunk = pl.reshape(wo_b_scale[n0:n0 + B_N_CHUNK], [1, B_N_CHUNK]) - attn_chunk = pl.cast(acc_b, target_type=pl.FP32, mode="none") - attn_chunk = pl.col_expand_mul(pl.row_expand_mul(attn_chunk, o_r_scale_dq), wb_scale_chunk) - attn_out[:, n0:n0 + B_N_CHUNK] = pl.cast(attn_chunk, target_type=pl.BF16, mode="rint") + wb_scale_chunk = pl.reshape(wo_b_scale[n0:n0 + B_N_TILE], [1, B_N_TILE]) + for b_tb in pl.range(0, T, B_T_TILE): + acc_b_t = acc_b[b_tb:b_tb + B_T_TILE, 0:B_N_TILE] + b_scale_t = o_r_scale_dq[b_tb:b_tb + B_T_TILE, 0:1] + attn_t = pl.cast(acc_b_t, target_type=pl.FP32, mode="none") + attn_t = pl.col_expand_mul(pl.row_expand_mul(attn_t, b_scale_t), wb_scale_chunk) + attn_out[b_tb:b_tb + B_T_TILE, n0:n0 + B_N_TILE] = pl.cast(attn_t, target_type=pl.BF16, mode="rint") return attn_out @pl.jit def sparse_attn_test( - q: pl.Tensor[[T, H, HEAD_DIM], pl.BF16], - ori_kv: pl.Tensor[[ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], - ori_block_table: pl.Tensor[[B, ORI_MAX_BLOCKS], pl.INT32], - cmp_kv: pl.Tensor[[CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], - cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], - cmp_sparse_indices: pl.Tensor[[T, TOPK], pl.INT32], - attn_sink: pl.Tensor[[H], pl.FP32], - seqused_kv: pl.Tensor[[B], pl.INT32], - freqs_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], - freqs_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], - even_select_local: pl.Tensor[[ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK], pl.BF16], - odd_select_local: pl.Tensor[[ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK], pl.BF16], - wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], - wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], - wo_b_scale: pl.Tensor[[D], pl.FP32], - attn_out: pl.Out[pl.Tensor[[T, D], pl.BF16]], + q: pl.Tensor[[T, H, HEAD_DIM], pl.BF16], + ori_kv: pl.Tensor[[ORI_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], + ori_block_table: pl.Tensor[[B, ORI_MAX_BLOCKS], pl.INT32], + cmp_kv: pl.Tensor[[CMP_BLOCK_NUM, BLOCK_SIZE, 1, HEAD_DIM], pl.BF16], + cmp_block_table: pl.Tensor[[B, CMP_MAX_BLOCKS], pl.INT32], + cmp_sparse_indices: pl.Tensor[[T, TOPK], pl.INT32], + attn_sink: pl.Tensor[[H], pl.FP32], + seqused_kv: pl.Tensor[[B], pl.INT32], + freqs_cos: pl.Tensor[[T, ROPE_DIM], pl.BF16], + freqs_sin: pl.Tensor[[T, ROPE_DIM], pl.BF16], + even_select_local: pl.Tensor[[ROPE_INTERLEAVE_TILE, ROPE_TILE], pl.BF16], + odd_select_local: pl.Tensor[[ROPE_INTERLEAVE_TILE, ROPE_TILE], pl.BF16], + wo_a: pl.Tensor[[O_GROUPS, O_LORA, O_GROUP_IN], pl.BF16], + wo_b: pl.Tensor[[D, O_GROUPS * O_LORA], pl.INT8], + wo_b_scale: pl.Tensor[[D], pl.FP32], + attn_out: pl.Out[pl.Tensor[[T, D], pl.BF16]], ): attn_out = sparse_attn( q, @@ -663,8 +473,8 @@ def golden_sparse_attn(tensors): block_mi = [] block_li = [] block_oi = [] - for tile_start in range(0, kv_b.shape[0], SPARSE_ATTN_TILE): - kv_tile = kv_b[tile_start:tile_start + SPARSE_ATTN_TILE] + for tile_start in range(0, kv_b.shape[0], ATTN_K_TILE): + kv_tile = kv_b[tile_start:tile_start + ATTN_K_TILE] scores = (q_t @ kv_tile.T) * SOFTMAX_SCALE mi = scores.max(dim=-1, keepdim=True).values exp_scores = torch.exp(scores - mi).to(torch.bfloat16).float() @@ -793,15 +603,15 @@ def init_sin(): def init_odd_select_local(): """Build the chunk-local selector that extracts odd rope lanes from interleaved inputs.""" - matrix = torch.zeros((ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK)) - for i in range(ROPE_CHUNK): + matrix = torch.zeros((ROPE_INTERLEAVE_TILE, ROPE_TILE)) + for i in range(ROPE_TILE): matrix[2 * i + 1, i] = 1 return matrix def init_even_select_local(): """Build the chunk-local selector that extracts even rope lanes from interleaved inputs.""" - matrix = torch.zeros((ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK)) - for i in range(ROPE_CHUNK): + matrix = torch.zeros((ROPE_INTERLEAVE_TILE, ROPE_TILE)) + for i in range(ROPE_TILE): matrix[2 * i, i] = 1 return matrix @@ -831,8 +641,8 @@ def init_wo_b_scale(): TensorSpec("seqused_kv", [B], torch.int32, init_value=init_seqused_kv), TensorSpec("freqs_cos", [T, ROPE_DIM], torch.bfloat16, init_value=init_cos), TensorSpec("freqs_sin", [T, ROPE_DIM], torch.bfloat16, init_value=init_sin), - TensorSpec("even_select_local", [ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK], torch.bfloat16, init_value=init_even_select_local), - TensorSpec("odd_select_local", [ROPE_INTERLEAVE_CHUNK, ROPE_CHUNK], torch.bfloat16, init_value=init_odd_select_local), + TensorSpec("even_select_local", [ROPE_INTERLEAVE_TILE, ROPE_TILE], torch.bfloat16, init_value=init_even_select_local), + TensorSpec("odd_select_local", [ROPE_INTERLEAVE_TILE, ROPE_TILE], torch.bfloat16, init_value=init_odd_select_local), TensorSpec("wo_a", [O_GROUPS, O_LORA, O_GROUP_IN], torch.bfloat16, init_value=init_wo_a), TensorSpec("wo_b", [D, O_GROUPS * O_LORA], torch.int8, init_value=init_wo_b), TensorSpec("wo_b_scale", [D], torch.float32, init_value=init_wo_b_scale), diff --git a/models/deepseek/v4/decode_swa.py b/models/deepseek/v4/decode_swa.py index fac23d8f..dd769dc2 100644 --- a/models/deepseek/v4/decode_swa.py +++ b/models/deepseek/v4/decode_swa.py @@ -304,7 +304,7 @@ def build_tensor_specs(layer_id: int = 0): parser = argparse.ArgumentParser() parser.add_argument("-p", "--platform", type=str, default="a2a3", - choices=["a2a3", "a2a3sim", "a5", "a5sim"]) + choices=["a2a3", "a5"]) parser.add_argument("-d", "--device", type=int, default=0) parser.add_argument("--layer-id", type=int, default=0) parser.add_argument("--seed", type=int, default=0)