Refactor dsv4 decode_sparse_attn; restore a2a3 ring sizes; drop sim from decode_{swa,csa,hca} #393

Merged: zhangqi-chen merged 11 commits into hw-native-sys:main from zhangqi-chen:decode-sparse-attn on May 26, 2026

@zhangqi-chen
Collaborator

Summary

  • Refactor dsv4 decode_sparse_attn: switch pl.parallel+pl.at to pl.spmd, fuse rope/proj_a/proj_b and double-quant stages, fuse qk_softmax+pv then split qk_pv and merge_norm into separate pl.spmd loops, fuse merge+norm, shrink scratch tensors, drop ATTN_TOKEN_TILE, and tighten valid-shape / local constants / naming.
  • CI: revert the ring-buffer reduction from PR #380 (chore(ci): Reduce PTO2 ring buffer sizes in CI workflows) for the a2a3 jobs only (ci.yml + daily_ci.yml). PTO2_RING_DEP_POOL 524288→1048576, PTO2_RING_TASK_WINDOW 524288→1048576, PTO2_RING_HEAP 1 GiB→4 GiB. Sim jobs keep the smaller values.
  • Restrict -p choices in models/deepseek/v4/decode_{swa,csa,hca}.py to {a2a3, a5} (top-level decode layers don't run on sim).

- Inline derived _BLOCKS expressions (SPARSE_ATTN_BLOCKS,
  ROPE_PACK_SPMD_BLOCKS, A/B K/N/AMAX blocks) at use sites; keep
  ORI_MAX_BLOCKS / CMP_MAX_BLOCKS pool configs.
- Drop one-shot GATHER_TOKEN_TILE constant; use literal 4.
- Rename _CHUNK constants to _TILE; collapse the T-conditional
  B_N_TILE / QUANT_TILE to single fixed values for the current
  decode shape.
- Shorten pl.at / pl.spmd name_hint strings to the bare role
  (gather_kv, qk_softmax, pv, merge, norm, rope_*, proj_*, quant).
- Add outer parens around the ceil-div tile-count expression where it
  was previously hidden behind a symbol, fixing latent operator
  precedence (T * H * ((TOPK + TILE - 1) // TILE)).
- Move intermediate pl.create_tensor allocations next to the first
  pl.parallel that consumes them, grouped per stage.
- Drop column-aligned signatures and aligned trailing comments.
- Trim the module docstring to a single-line summary.
- Switch DEFAULT_COMPRESS_RATIO to 0 for the standalone harness.

Verified on a2a3 NPU: attn_out PASS (ratio_allclose).
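
The precedence point in the ceil-div bullet can be checked in plain Python, where `*` and `//` share one precedence level and associate left to right. This is a minimal sketch: the function names and the sizes below are hypothetical, chosen only so that TOPK is not a multiple of TILE.

```python
def tile_count_fixed(t, h, topk, tile):
    """Ceil-div tile count per (token, head), then scaled by T * H."""
    return t * h * ((topk + tile - 1) // tile)


def tile_count_unparenthesized(t, h, topk, tile):
    """Same expression without the outer parens: evaluates as (t*h*(...)) // tile."""
    return t * h * (topk + tile - 1) // tile


# Hypothetical sizes, not taken from the kernel.
T, H, TOPK, TILE = 2, 3, 6, 4
print(tile_count_fixed(T, H, TOPK, TILE))            # 2 * 3 * ceil(6 / 4) = 12
print(tile_count_unparenthesized(T, H, TOPK, TILE))  # (2 * 3 * 9) // 4 = 13
```

The two forms agree whenever `TOPK + TILE - 1` is a multiple of `TILE`, which is why the bug can stay latent for some shapes.
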
The outer attention loop now uses `pl.parallel(T)` so each iteration
processes one token directly, removing the redundant inner
`pl.range(ATTN_TOKEN_TILE)` wrapper in the qk_softmax / pv / merge /
norm scopes. Validated on a2a3 with the standalone harness.
Move the per-(token, head-tile) softmax scratch (sparse_exp,
sparse_blk_mi/li/oi, sparse_mi/li/oi) inside the inner
`pl.parallel(0, H, H_TILE)` task and drop the now-redundant T and H
factors from their shapes. Each declaration sits right before its
first-use `pl.at` scope. Also rename `MATMUL_ROW_PAD` to `H_TILE` to
reflect its role as the head-tile size now that it is the row count
of the local scratch tensors.

`attn_rope_stage` and `o_packed` stay at outer scope because the
later rope_slice and proj_a_accum loops consume them across tasks.
…se_attn

- Collapse the per-token `merge` and `norm` `pl.at` scopes into a single
  `merge_norm` scope; the per-token mi/li/oi now stay in-register from the
  online-softmax recurrence into the sink-norm cast, dropping the
  `sparse_mi`, `sparse_li`, and `sparse_oi` per-task scratch tensors.
- Replace the `pl.slice(..., valid_shape=...)` boilerplate around the QK
  score scaling with `pl.set_validshape`, which is the canonical way to
  attach a runtime-valid shape to an existing tile.

Validated on a2a3 with the standalone harness.
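
For readers unfamiliar with the recurrence behind `mi`/`li`/`oi`, here is a minimal pure-Python sketch of a standard online-softmax merge over K tiles. It is an illustration, not the kernel's code: scalar values stand in for V rows, the function names are hypothetical, and the sink term of the sink-norm is omitted because its exact form is not described here.

```python
import math


def direct_attention(scores, values):
    """Reference: softmax(scores) dotted with values, computed in one pass."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def online_attention(scores, values, tile):
    """Tile-by-tile merge carrying a running max (mi), sum (li), and output (oi)."""
    mi, li, oi = float("-inf"), 0.0, 0.0
    for s0 in range(0, len(scores), tile):
        blk_s = scores[s0:s0 + tile]
        blk_v = values[s0:s0 + tile]
        blk_mi = max(blk_s)
        blk_p = [math.exp(s - blk_mi) for s in blk_s]
        blk_li = sum(blk_p)
        blk_oi = sum(p * v for p, v in zip(blk_p, blk_v))
        new_mi = max(mi, blk_mi)
        alpha = math.exp(mi - new_mi)      # rescales the running state
        beta = math.exp(blk_mi - new_mi)   # rescales the new block
        li = alpha * li + beta * blk_li
        oi = alpha * oi + beta * blk_oi
        mi = new_mi
    return oi / li  # final normalization (no sink term in this sketch)


scores = [0.5, -1.0, 2.0, 0.0, 1.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(abs(online_attention(scores, values, tile=2) - direct_attention(scores, values)) < 1e-12)
```

Because the running state is only three quantities per row, it can be carried straight into the normalization step without a scratch round-trip, which is the property the `merge_norm` fusion relies on.
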
- Replace every `pl.slice` / `pl.assemble` with the `dst[r:r+H, c:c+W]`
  subscript sugar; the single `pl.slice(..., valid_shape=...)` becomes
  `pl.set_validshape` on the subscript-sliced source.
- Shorten the verbose per-scope prefixes (`gather_` -> `g_`, `merge_` ->
  `m_`, `norm_` -> `n_`, `rope_slice_` -> `rs_`, `rope_apply_` -> `ra_`,
  `rope_pack_/rope_combine_` -> `rp_`) and the common verbose suffixes
  (`_window_valid` -> `_win_v`, `_cmp_topk_valid` -> `_tk_v`,
  `_seq_final/used` -> `_seq_end/len`, `_sparse_kv_base` -> `_kv_base`,
  `_block_row` -> `_row`, `_tile_start` -> `_s0`, `_tile_valid` ->
  `_s_v`). Prefixes stay distinct per scope to keep `@pl.jit.inline`
  SSA-safe.
- Inline `o_proj_even`/`o_proj_odd` style multi-line ops onto one line
  where they fit, and break only by `pl` operator when they don't, per
  the project's coding style. The orphan multi-line `pl.cast(...)` in
  `proj_a_store` and `quant` is also collapsed.

Validated on a2a3 with the standalone harness.
- Merge rope_slice + rope_apply into a single `rope` scope; drop the
  `o_proj_even`/`o_proj_odd` GM scratch round-trip. Relies on the local
  pypto#1532 fix (one InCore param can now be loaded with both
  `b_trans=True` and `b_trans=False` in the same scope).
- Fuse proj_a_accum + proj_a_store into one `proj_a` scope. The vec
  post-process (BF16 store + per-row partial amax) is T-tiled
  (`A_T_TILE=16`) inside the scope as a pypto#1472 workaround — without
  it the AIV UB live set exceeds the 192KB limit. K loop stays in
  peel-first-iter form because the `pl.create_tensor` + `if k0 == 0`
  carry style hits pypto#1540 on the 3D wo_a slice.
- Fuse proj_b_accum + proj_b_store similarly with `B_T_TILE=16`; the
  2D INT8 carry uses the `pl.create_tensor` + `if k0 == 0` form now
  that pypto#1501 is fixed.
- Quant kernel: K is tiled as a second parallel axis (`QUANT_K_TILE`)
  so the quant task count doubles (T/TOKEN_TILE * K/QUANT_K_TILE).
- Drop the redundant `g_dt` 4-token batching in gather_kv — use
  `pl.parallel(T)` directly. 2D `pl.read` lets `ori_block_table_flat`,
  `cmp_sparse_indices_flat`, and `cmp_block_table_flat` reshapes go.
- Hoist `q_flat`/`ori_kv_flat`/`cmp_kv_flat` to right before their first
  use; trim verbose stage banners and outdated comments.

Validated on a2a3 with the standalone harness.
…e_attn

Convert every `for X in pl.parallel(...): with pl.at(name_hint="Y")`
pattern to the single `for X in pl.spmd(N, name_hint="Y")` form (and
inline the dependent index/stride math into the body). Doubly-nested
parallel pairs are linearized into one spmd over the product of the
two block counts.

- gather_kv: pl.spmd(T)
- rope: pl.spmd(T // ROPE_TOKEN_TILE)
- proj_a: pl.spmd(O_GROUPS * (O_LORA // A_N_TILE))
- quant: pl.spmd((T // QUANT_TOKEN_TILE) * ((O_GROUPS * O_LORA) // QUANT_K_TILE))
- proj_b: pl.spmd(D // B_N_TILE)
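
The linearization of a nested pair into one flat block index can be sketched in plain Python. This is an assumption-laden illustration: the sizes are hypothetical, the helper name is invented, and the token-major ordering (outer index varies slowest) is an assumption rather than something the commit states. The names `qt_idx` and `qk_idx` follow the quant loop's indices.

```python
def split_block_index(block_idx, inner_blocks):
    """Recover (outer_idx, inner_idx) from a flat index over outer * inner blocks."""
    return divmod(block_idx, inner_blocks)


# Hypothetical sizes, not taken from the kernel.
T, QUANT_TOKEN_TILE = 32, 16
K, QUANT_K_TILE = 4096, 2048

qt_blocks = T // QUANT_TOKEN_TILE
qk_blocks = K // QUANT_K_TILE

# One flat loop over the product of the two block counts ...
flat = [split_block_index(b, qk_blocks) for b in range(qt_blocks * qk_blocks)]
# ... visits the same (qt_idx, qk_idx) pairs as the nested loops it replaces.
nested = [(qt_idx, qk_idx) for qt_idx in range(qt_blocks) for qk_idx in range(qk_blocks)]
print(flat == nested)
```
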

The Stage-2 attention block (qk_softmax / pv / merge_norm) keeps its
nested `pl.parallel(T) > pl.parallel(0, H, H_TILE)` outer structure
because each inner `pl.at` is its own scope; converting to spmd there
would require splitting one task per scope and is unrelated to this
sweep.

Validated on a2a3 with the standalone harness.
Collapse the per-token `qk_softmax` and `pv` `pl.at` scopes into a
single `qk_pv` scope. The BF16 exp scores now feed the V matmul
directly in the same scope, eliminating the `sparse_exp` GM scratch
round-trip between the two matmuls.

Two separate slice variables (`qk_kv_k` and `qk_kv_v`) on the same
`sparse_kv` parameter give the cube one transpose=True load (QK) and
one transpose=False load (PV) — requires the pypto#1532 fix landed
locally.

SPARSE_ATTN_TILE is halved from 64 to 32 so the K and V cube
right-buffer copies (each `SPARSE_ATTN_TILE * HEAD_DIM * 2` B) fit
together in the 64KB L1B.
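
The buffer arithmetic can be checked with a few lines of Python. `HEAD_DIM = 512` and the 2-byte element size are assumptions made for this sketch (the commit only gives the formula `SPARSE_ATTN_TILE * HEAD_DIM * 2` B and the 64KB limit); the function name is hypothetical.

```python
L1B_BYTES = 64 * 1024
BYTES_PER_ELEM = 2   # assumption: the "* 2" in the formula is a 2-byte element
HEAD_DIM = 512       # assumption: not stated in this PR


def kv_right_buffer_bytes(attn_k_tile, head_dim=HEAD_DIM):
    """Total bytes for the K copy plus the V copy of one tile."""
    one_copy = attn_k_tile * head_dim * BYTES_PER_ELEM
    return 2 * one_copy


for tile in (64, 32):
    total = kv_right_buffer_bytes(tile)
    print(tile, total, total <= L1B_BYTES)
# 64 131072 False
# 32 65536 True
```

Under these assumptions a tile of 64 needs twice the available space once both copies are resident, while 32 fits exactly.
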

Validated on a2a3 with the standalone harness.
- Replace the outer `pl.parallel(T) > pl.parallel(0, H, H_TILE)` task
  nest with two `pl.spmd(T)` loops, one per scope. Each spmd body
  walks `H_BLOCKS` sequentially via inner `pl.range(H // H_TILE)`, so
  the InCore region is provided by spmd itself (no explicit pl.at).
- Hoist `sparse_blk_mi`/`sparse_blk_li`/`sparse_blk_oi` to function
  level and size them for all (token, h-block, K-tile) chunks. qk_pv
  writes the full sweep; merge_norm reads it back. Per-spmd-task
  offsets pick the owning token's slice.
- Rename `SPARSE_ATTN_TILE` to `ATTN_K_TILE` to match the A_K_TILE /
  B_K_TILE convention used by the projection scopes.

Validated on a2a3 with the standalone harness.
Revert PR hw-native-sys#380 for the a2a3 jobs in ci.yml and daily_ci.yml only;
sim jobs keep the smaller values from hw-native-sys#380.

- PTO2_RING_DEP_POOL: 524288 -> 1048576
- PTO2_RING_TASK_WINDOW: 524288 -> 1048576
- PTO2_RING_HEAP: 1 GiB -> 4 GiB
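
As a quick sanity check on the numbers above, the GiB values can be expanded to the raw integers that appear in the workflow files. This is only an arithmetic sketch; the dictionary and helper below are illustrative and not part of the CI configuration.

```python
GIB = 1024 ** 3

# (value from #380, value restored for the a2a3 jobs), as quoted in this PR.
RING_SIZES = {
    "PTO2_RING_DEP_POOL": (524288, 1048576),
    "PTO2_RING_TASK_WINDOW": (524288, 1048576),
    "PTO2_RING_HEAP": (1 * GIB, 4 * GIB),
}


def growth_factors(sizes):
    """Ratio of restored value to reduced value for each variable."""
    return {name: new // old for name, (old, new) in sizes.items()}


print(RING_SIZES["PTO2_RING_HEAP"])  # (1073741824, 4294967296)
print(growth_factors(RING_SIZES))
```
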
Top-level dsv4 decode layers aren't supported on a2a3sim / a5sim;
restrict the argparse --platform choices to {a2a3, a5}.
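
A restricted `choices` list in `argparse` looks roughly like the sketch below. The `-p` / `--platform` option and the `{a2a3, a5}` set come from this PR; the parser description, the default value, and the `build_parser` helper are assumptions for illustration and may differ from the actual scripts.

```python
import argparse


def build_parser():
    """Build a parser that accepts only hardware platforms (no sim variants)."""
    parser = argparse.ArgumentParser(description="decode layer harness (illustrative)")
    parser.add_argument(
        "-p", "--platform",
        choices=["a2a3", "a5"],  # a2a3sim / a5sim are no longer accepted
        default="a2a3",          # assumption: the real default is not shown in this PR
    )
    return parser


print(build_parser().parse_args(["-p", "a5"]).platform)
```

With `choices` set, passing `-p a2a3sim` makes `argparse` print a usage error and exit with status 2, so unsupported platforms are rejected before any kernel code runs.
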
@coderabbitai

coderabbitai Bot commented May 26, 2026

Caution

Review failed

Pull request was closed or merged during review

📝 Walkthrough

This PR updates CI environment sizing for A2A3 tests, restricts platform support to production variants in three decode scripts, and refactors the sparse attention decode kernel into a unified entry point with updated tiling constants and inverse-RoPE selector logic.

Changes

Runtime Environment Sizing

| Layer / File(s) | Summary |
| --- | --- |
| A2A3 ring environment tuning (.github/workflows/ci.yml, .github/workflows/daily_ci.yml) | PTO2_RING_DEP_POOL, PTO2_RING_TASK_WINDOW, and PTO2_RING_HEAP are increased from smaller values (524288/524288/1073741824) to larger values (1048576/1048576/4294967296) in both ci.yml and daily_ci.yml for the A2A3 test jobs. |

Platform Support Constraints

| Layer / File(s) | Summary |
| --- | --- |
| Platform choice restriction (models/deepseek/v4/decode_csa.py, models/deepseek/v4/decode_hca.py, models/deepseek/v4/decode_swa.py) | The --platform argparse option in three decode scripts is restricted to only ["a2a3", "a5"], removing previously supported simulation variants (a2a3sim, a5sim). |

Sparse Attention Decode Kernel Refactor

| Layer / File(s) | Summary |
| --- | --- |
| Tiling constants and configuration (models/deepseek/v4/decode_sparse_attn.py) | Module docstring is updated; DEFAULT_COMPRESS_RATIO changes from 128 to 0; new tile parameters (H_TILE, ATTN_K_TILE, ROPE_TILE, ROPE_INTERLEAVE_TILE) are introduced, and QUANT_K_TILE is updated to depend on O_GROUPS and O_LORA, replacing prior tiling constants. |
| Unified sparse_attn kernel implementation (models/deepseek/v4/decode_sparse_attn.py) | A new sparse_attn(...) JIT kernel centralizes the full decode path: packed sparse KV gather (window + compressed hits with zero-padding), sparse attention QK/softmax PV over tiles with online merge and sink-norm, inverse RoPE (deinterleave/rotate/reinterleave), and grouped projection (BF16 stage-1 + symmetric INT8 quant/stage-2 GEMM/dequant). The sparse_attn_test(...) function is refactored to wrap this kernel. |
| Rope selector initialization and tensor specifications (models/deepseek/v4/decode_sparse_attn.py) | The torch golden reference iterates sparse KV tiles using ATTN_K_TILE instead of SPARSE_ATTN_TILE. The init_odd_select_local() function builds selectors with new dimensions based on ROPE_INTERLEAVE_TILE and ROPE_TILE. Tensor specs for even_select_local and odd_select_local are updated to reflect the new selector shapes. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#380: Updates the same CI workflow environment variables (PTO2_RING_DEP_POOL, PTO2_RING_TASK_WINDOW, PTO2_RING_HEAP) for A2A3/sim test jobs.
  • hw-native-sys/pypto-lib#361: Modifies models/deepseek/v4/decode_sparse_attn.py sparse KV tiling and golden reference softmax/merge logic to align with kernel computation.
  • hw-native-sys/pypto-lib#225: Introduces a fused sparse attention + grouped output projection decode kernel matching this PR's unified sparse_attn kernel refactor.

Poem

🐰 With rings resized and platforms pruned clean,
A sparse attention kernel takes the stage between—
KV gathered, soft-maxed, ropes inverse-spun,
Grouped projections complete what's begun!

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The pull request title accurately summarizes the three main changes: refactoring decode_sparse_attn, restoring a2a3 ring sizes, and dropping sim from decode modules. |
| Description check | ✅ Passed | The pull request description provides detailed explanations of all three key change areas and directly relates to the actual modifications in the changeset. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 85.71% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |



Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request simplifies the platform choices across several DeepSeek-V4 decode scripts and heavily refactors decode_sparse_attn.py to use SPMD-based tiling for attention, inverse RoPE, and projection stages. The review feedback highlights two optimization opportunities in decode_sparse_attn.py: guarding redundant global memory writes of or_scale_dq in the quantization loop to prevent bus contention, and refactoring the proj_b pipeline to use a "peel-first-iter" pattern to avoid wasted tensor allocation and conditional branching.

or_amax = pl.maximum(or_amax, or_a_part)
or_sq_row = pl.div(pl.full([1, QUANT_TOKEN_TILE], dtype=pl.FP32, value=INT8_SCALE_MAX), or_amax)
or_scale_dq = pl.reshape(pl.recip(or_sq_row), [QUANT_TOKEN_TILE, 1])
o_r_scale_dq[quant_t0:quant_t0 + QUANT_TOKEN_TILE, 0:1] = or_scale_dq
high

In the quant SPMD loop, multiple blocks with the same qt_idx but different qk_idx concurrently write the identical or_scale_dq value to the same global memory location o_r_scale_dq[quant_t0:quant_t0 + QUANT_TOKEN_TILE, 0:1]. This redundant write causes memory bus contention and cache coherency overhead on the accelerator.

We should guard this write with if qk_idx == 0: to ensure it is only written once per token tile.

Suggested change
- o_r_scale_dq[quant_t0:quant_t0 + QUANT_TOKEN_TILE, 0:1] = or_scale_dq
+ if qk_idx == 0:
+     o_r_scale_dq[quant_t0:quant_t0 + QUANT_TOKEN_TILE, 0:1] = or_scale_dq

Comment on lines +335 to 343
acc_b = pl.create_tensor([T, B_N_TILE], dtype=pl.INT32)
for kb in pl.pipeline(0, (O_GROUPS * O_LORA) // B_K_TILE, stage=2):
    k0 = kb * B_K_TILE
    xb_k_chunk = o_r_i8[:, k0:k0 + B_K_TILE]
    wb_k_chunk = wo_b[n0:n0 + B_N_TILE, k0:k0 + B_K_TILE]
    if k0 == 0:
        acc_b = pl.matmul(xb_k_chunk, wb_k_chunk, b_trans=True, out_dtype=pl.INT32)
    else:
        acc_b = pl.matmul_acc(acc_b, xb_k_chunk, wb_k_chunk, b_trans=True)
medium

The tensor acc_b is allocated via pl.create_tensor at line 335, but is immediately overwritten by pl.matmul during the first iteration (k0 == 0) of the pipeline loop. This results in a wasted memory allocation in the limited accelerator memory (UB/L1/L2). Additionally, checking if k0 == 0 inside the pipelined loop introduces an unnecessary conditional branch.

We can refactor this to use the same "peel-first-iter" pattern as proj_a (lines 287-294), which avoids both the wasted allocation and the conditional branch inside the loop.

        xb0_chunk = o_r_i8[:, 0:B_K_TILE]
        wb0_chunk = wo_b[n0:n0 + B_N_TILE, 0:B_K_TILE]
        acc_b = pl.matmul(xb0_chunk, wb0_chunk, b_trans=True, out_dtype=pl.INT32)
        for kb in pl.pipeline(1, (O_GROUPS * O_LORA) // B_K_TILE, stage=2):
            k0 = kb * B_K_TILE
            xb_k_chunk = o_r_i8[:, k0:k0 + B_K_TILE]
            wb_k_chunk = wo_b[n0:n0 + B_N_TILE, k0:k0 + B_K_TILE]
            acc_b = pl.matmul_acc(acc_b, xb_k_chunk, wb_k_chunk, b_trans=True)

@zhangqi-chen zhangqi-chen merged commit e84fb6f into hw-native-sys:main May 26, 2026
4 of 7 checks passed
@zhangqi-chen zhangqi-chen deleted the decode-sparse-attn branch May 26, 2026 13:16