Refactor dsv4 decode_sparse_attn; restore a2a3 ring sizes; drop sim from decode_{swa,csa,hca}#393
Conversation
- Inline derived _BLOCKS expressions (SPARSE_ATTN_BLOCKS, ROPE_PACK_SPMD_BLOCKS, A/B K/N/AMAX blocks) at use sites; keep ORI_MAX_BLOCKS / CMP_MAX_BLOCKS pool configs. - Drop one-shot GATHER_TOKEN_TILE constant; use literal 4. - Rename _CHUNK constants to _TILE; collapse the T-conditional B_N_TILE / QUANT_TILE to single fixed values for the current decode shape. - Shorten pl.at / pl.spmd name_hint strings to the bare role (gather_kv, qk_softmax, pv, merge, norm, rope_*, proj_*, quant). - Add outer parens around the ceil-div tile-count expression where it was previously hidden behind a symbol, fixing latent operator precedence (T * H * ((TOPK + TILE - 1) // TILE)). - Move intermediate pl.create_tensor allocations next to the first pl.parallel that consumes them, grouped per stage. - Drop column-aligned signatures and aligned trailing comments. - Trim the module docstring to a single-line summary. - Switch DEFAULT_COMPRESS_RATIO to 0 for the standalone harness. Verified on a2a3 NPU: attn_out PASS (ratio_allclose).
The outer attention loop now uses `pl.parallel(T)` so each iteration processes one token directly, removing the redundant inner `pl.range(ATTN_TOKEN_TILE)` wrapper in the qk_softmax / pv / merge / norm scopes. Validated on a2a3 with the standalone harness.
Move the per-(token, head-tile) softmax scratch (sparse_exp, sparse_blk_mi/li/oi, sparse_mi/li/oi) inside the inner `pl.parallel(0, H, H_TILE)` task and drop the now-redundant T and H factors from their shapes. Each declaration sits right before its first-use `pl.at` scope. Also rename `MATMUL_ROW_PAD` to `H_TILE` to reflect its role as the head-tile size now that it is the row count of the local scratch tensors. `attn_rope_stage` and `o_packed` stay at outer scope because the later rope_slice and proj_a_accum loops consume them across tasks.
…se_attn - Collapse the per-token `merge` and `norm` `pl.at` scopes into a single `merge_norm` scope; the per-token mi/li/oi now stay in-register from the online-softmax recurrence into the sink-norm cast, dropping the `sparse_mi`, `sparse_li`, and `sparse_oi` per-task scratch tensors. - Replace the `pl.slice(..., valid_shape=...)` boilerplate around the QK score scaling with `pl.set_validshape`, which is the canonical way to attach a runtime-valid shape to an existing tile. Validated on a2a3 with the standalone harness.
- Replace every `pl.slice` / `pl.assemble` with the `dst[r:r+H, c:c+W]` subscript sugar; the single `pl.slice(..., valid_shape=...)` becomes `pl.set_validshape` on the subscript-sliced source. - Shorten the verbose per-scope prefixes (`gather_` -> `g_`, `merge_` -> `m_`, `norm_` -> `n_`, `rope_slice_` -> `rs_`, `rope_apply_` -> `ra_`, `rope_pack_/rope_combine_` -> `rp_`) and the common verbose suffixes (`_window_valid` -> `_win_v`, `_cmp_topk_valid` -> `_tk_v`, `_seq_final/used` -> `_seq_end/len`, `_sparse_kv_base` -> `_kv_base`, `_block_row` -> `_row`, `_tile_start` -> `_s0`, `_tile_valid` -> `_s_v`). Prefixes stay distinct per scope to keep `@pl.jit.inline` SSA-safe. - Inline `o_proj_even`/`o_proj_odd` style multi-line ops onto one line where they fit, and break only by `pl` operator when they don't, per the project's coding style. The orphan multi-line `pl.cast(...)` in `proj_a_store` and `quant` is also collapsed. Validated on a2a3 with the standalone harness.
- Merge rope_slice + rope_apply into a single `rope` scope; drop the `o_proj_even`/`o_proj_odd` GM scratch round-trip. Relies on the local pypto#1532 fix (one InCore param can now be loaded with both `b_trans=True` and `b_trans=False` in the same scope). - Fuse proj_a_accum + proj_a_store into one `proj_a` scope. The vec post-process (BF16 store + per-row partial amax) is T-tiled (`A_T_TILE=16`) inside the scope as a pypto#1472 workaround — without it the AIV UB live set exceeds the 192KB limit. K loop stays in peel-first-iter form because the `pl.create_tensor` + `if k0 == 0` carry style hits pypto#1540 on the 3D wo_a slice. - Fuse proj_b_accum + proj_b_store similarly with `B_T_TILE=16`; the 2D INT8 carry uses the `pl.create_tensor` + `if k0 == 0` form now that pypto#1501 is fixed. - Quant kernel: K is tiled as a second parallel axis (`QUANT_K_TILE`) so the quant task count doubles (T/TOKEN_TILE * K/QUANT_K_TILE). - Drop the redundant `g_dt` 4-token batching in gather_kv — use `pl.parallel(T)` directly. 2D `pl.read` lets `ori_block_table_flat`, `cmp_sparse_indices_flat`, and `cmp_block_table_flat` reshapes go. - Hoist `q_flat`/`ori_kv_flat`/`cmp_kv_flat` to right before their first use; trim verbose stage banners and outdated comments. Validated on a2a3 with the standalone harness.
…e_attn Convert every `for X in pl.parallel(...): with pl.at(name_hint="Y")` pattern to the single `for X in pl.spmd(N, name_hint="Y")` form (and inline the dependent index/stride math into the body). Doubly-nested parallel pairs are linearized into one spmd over the product of the two block counts. - gather_kv: pl.spmd(T) - rope: pl.spmd(T // ROPE_TOKEN_TILE) - proj_a: pl.spmd(O_GROUPS * (O_LORA // A_N_TILE)) - quant: pl.spmd((T // QUANT_TOKEN_TILE) * ((O_GROUPS * O_LORA) // QUANT_K_TILE)) - proj_b: pl.spmd(D // B_N_TILE) The Stage-2 attention block (qk_softmax / pv / merge_norm) keeps its nested `pl.parallel(T) > pl.parallel(0, H, H_TILE)` outer structure because each inner `pl.at` is its own scope; converting to spmd there would require splitting one task per scope and is unrelated to this sweep. Validated on a2a3 with the standalone harness.
Collapse the per-token `qk_softmax` and `pv` `pl.at` scopes into a single `qk_pv` scope. The BF16 exp scores now feed the V matmul directly in the same scope, eliminating the `sparse_exp` GM scratch round-trip between the two matmuls. Two separate slice variables (`qk_kv_k` and `qk_kv_v`) on the same `sparse_kv` parameter give the cube one transpose=True load (QK) and one transpose=False load (PV) — requires the pypto#1532 fix landed locally. SPARSE_ATTN_TILE is halved from 64 to 32 so the K and V cube right-buffer copies (each `SPARSE_ATTN_TILE * HEAD_DIM * 2` B) fit together in the 64KB L1B. Validated on a2a3 with the standalone harness.
- Replace the outer `pl.parallel(T) > pl.parallel(0, H, H_TILE)` task nest with two `pl.spmd(T)` loops, one per scope. Each spmd body walks `H_BLOCKS` sequentially via inner `pl.range(H // H_TILE)`, so the InCore region is provided by spmd itself (no explicit pl.at). - Hoist `sparse_blk_mi`/`sparse_blk_li`/`sparse_blk_oi` to function level and size them for all (token, h-block, K-tile) chunks. qk_pv writes the full sweep; merge_norm reads it back. Per-spmd-task offsets pick the owning token's slice. - Rename `SPARSE_ATTN_TILE` to `ATTN_K_TILE` to match the A_K_TILE / B_K_TILE convention used by the projection scopes. Validated on a2a3 with the standalone harness.
Revert PR hw-native-sys#380 for the a2a3 jobs in ci.yml and daily_ci.yml only; sim jobs keep the smaller values from hw-native-sys#380. - PTO2_RING_DEP_POOL: 524288 -> 1048576 - PTO2_RING_TASK_WINDOW: 524288 -> 1048576 - PTO2_RING_HEAP: 1 GiB -> 4 GiB
Top-level dsv4 decode layers aren't supported on a2a3sim / a5sim;
restrict the argparse --platform choices to {a2a3, a5}.
|
Caution Review failedPull request was closed or merged during review 📝 WalkthroughWalkthroughThis PR updates CI environment sizing for A2A3 tests, restricts platform support to production variants in three decode scripts, and refactors the sparse attention decode kernel into a unified entry point with updated tiling constants and inverse-RoPE selector logic. ChangesRuntime Environment Sizing
Platform Support Constraints
Sparse Attention Decode Kernel Refactor
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request simplifies the platform choices across several DeepSeek-V4 decode scripts and heavily refactors decode_sparse_attn.py to use SPMD-based tiling for attention, inverse RoPE, and projection stages. The review feedback highlights two optimization opportunities in decode_sparse_attn.py: guarding redundant global memory writes of or_scale_dq in the quantization loop to prevent bus contention, and refactoring the proj_b pipeline to use a "peel-first-iter" pattern to avoid wasted tensor allocation and conditional branching.
| or_amax = pl.maximum(or_amax, or_a_part) | ||
| or_sq_row = pl.div(pl.full([1, QUANT_TOKEN_TILE], dtype=pl.FP32, value=INT8_SCALE_MAX), or_amax) | ||
| or_scale_dq = pl.reshape(pl.recip(or_sq_row), [QUANT_TOKEN_TILE, 1]) | ||
| o_r_scale_dq[quant_t0:quant_t0 + QUANT_TOKEN_TILE, 0:1] = or_scale_dq |
There was a problem hiding this comment.
In the quant SPMD loop, multiple blocks with the same qt_idx but different qk_idx concurrently write the identical or_scale_dq value to the same global memory location o_r_scale_dq[quant_t0:quant_t0 + QUANT_TOKEN_TILE, 0:1]. This redundant write causes memory bus contention and cache coherency overhead on the accelerator.
We should guard this write with if qk_idx == 0: to ensure it is only written once per token tile.
| o_r_scale_dq[quant_t0:quant_t0 + QUANT_TOKEN_TILE, 0:1] = or_scale_dq | |
| if qk_idx == 0: | |
| o_r_scale_dq[quant_t0:quant_t0 + QUANT_TOKEN_TILE, 0:1] = or_scale_dq |
| acc_b = pl.create_tensor([T, B_N_TILE], dtype=pl.INT32) | ||
| for kb in pl.pipeline(0, (O_GROUPS * O_LORA) // B_K_TILE, stage=2): | ||
| k0 = kb * B_K_TILE | ||
| xb_k_chunk = o_r_i8[:, k0:k0 + B_K_TILE] | ||
| wb_k_chunk = wo_b[n0:n0 + B_N_TILE, k0:k0 + B_K_TILE] | ||
| if k0 == 0: | ||
| acc_b = pl.matmul(xb_k_chunk, wb_k_chunk, b_trans=True, out_dtype=pl.INT32) | ||
| else: | ||
| acc_b = pl.matmul_acc(acc_b, xb_k_chunk, wb_k_chunk, b_trans=True) |
There was a problem hiding this comment.
The tensor acc_b is allocated via pl.create_tensor at line 335, but is immediately overwritten by pl.matmul during the first iteration (k0 == 0) of the pipeline loop. This results in a wasted memory allocation in the limited accelerator memory (UB/L1/L2). Additionally, checking if k0 == 0 inside the pipelined loop introduces an unnecessary conditional branch.
We can refactor this to use the same "peel-first-iter" pattern as proj_a (lines 287-294), which avoids both the wasted allocation and the conditional branch inside the loop.
xb0_chunk = o_r_i8[:, 0:B_K_TILE]
wb0_chunk = wo_b[n0:n0 + B_N_TILE, 0:B_K_TILE]
acc_b = pl.matmul(xb0_chunk, wb0_chunk, b_trans=True, out_dtype=pl.INT32)
for kb in pl.pipeline(1, (O_GROUPS * O_LORA) // B_K_TILE, stage=2):
k0 = kb * B_K_TILE
xb_k_chunk = o_r_i8[:, k0:k0 + B_K_TILE]
wb_k_chunk = wo_b[n0:n0 + B_N_TILE, k0:k0 + B_K_TILE]
acc_b = pl.matmul_acc(acc_b, xb_k_chunk, wb_k_chunk, b_trans=True)
Summary
decode_sparse_attn: switchpl.parallel+pl.attopl.spmd, fuserope/proj_a/proj_band double-quant stages, fuseqk_softmax+pvthen splitqk_pvandmerge_norminto separatepl.spmdloops, fusemerge+norm, shrink scratch tensors, dropATTN_TOKEN_TILE, and tighten valid-shape / local constants / naming.PTO2_RING_DEP_POOL524288→1048576,PTO2_RING_TASK_WINDOW524288→1048576,PTO2_RING_HEAP1 GiB→4 GiB. Sim jobs keep the smaller values.-pchoices inmodels/deepseek/v4/decode_{swa,csa,hca}.pyto{a2a3, a5}(top-level decode layers don't run on sim).