
Fix scalar-indexed slices in V3.2 / Qwen3-32B + sparse_attn rope refactor #288

Merged
zhangqi-chen merged 2 commits into hw-native-sys:main from zhangqi-chen:fix/pl-range-slice-folding-and-sparse-rope
May 15, 2026

Conversation

@zhangqi-chen
Collaborator

Summary

  • Rewrite 28 scalar-leading-axis subscripts x[scalar, expr] to range-slice form x[scalar : scalar + 1, expr] inside two @pl.function decode kernels (deepseek_v3_2_decode_front.py, qwen3_32b_decode.py). This extends the moe_expert fix (#281) from rank-3 weights to rank-2 rope tables and per-batch projections.
  • Where the leading offset was originally ctx_len - 1, introduce pos = ctx_len - 1 first: the form rope_cos[ctx_len - 1 : ctx_len, ...] does not compile because pypto IR does not fold (ctx_len) - (ctx_len - 1) = 1 into a static row dim, and the downstream pl.col_expand_mul static row-dim check rejects it (filed as pypto#1377).
  • Refactor sparse_attn rope-interleave assemble: split cfa_proj_rope_assemble into _matmul + _combine scopes with FP32 intermediate buffers; BF16 add + cast now runs once per H * ROPE_DIM tile instead of per matmul chunk.
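The shape consequence of the rewrite can be seen with plain NumPy as a stand-in for the pl tensor slicing semantics (the array sizes and the index value here are illustrative, not taken from the kernels):

```python
import numpy as np

# Hypothetical rope table: 4096 positions x 64 rotary dims.
rope_cos = np.zeros((4096, 64), dtype=np.float32)

scalar_row = rope_cos[7]     # scalar subscript: leading axis is dropped, shape (64,)
range_row = rope_cos[7:8]    # range slice: leading axis kept as a singleton, shape (1, 64)

assert scalar_row.shape == (64,)
assert range_row.shape == (1, 64)
```

The scalar subscript reduces rank, while the range-slice form preserves it as a singleton leading dimension, which is what rank-sensitive downstream ops expect.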

Verified on a2a3:

  • qwen3_32b_decode: compile + runtime + golden PASS (17.4s)
  • deepseek_v3_2_decode_front: compile + runtime + golden PASS (4/4 outputs)

Related Issues

Related: hw-native-sys/pypto#1377

Rewrite scalar-leading-axis subscripts `x[scalar, expr]` to range-slice
`x[scalar : scalar + 1, expr]` across 28 sites in two decode kernels.
Extends the moe_expert fix (hw-native-sys#281) from rank-3 weights to rank-2 rope and
per-batch projection tensors.

For offsets that were originally `ctx_len - 1`, introduce `pos = ctx_len - 1`
first: the form `rope_cos[ctx_len - 1 : ctx_len, ...]` does not compile
because pypto IR does not fold `(ctx_len) - (ctx_len - 1) = 1` into a
static row dim, and downstream `pl.col_expand_mul` rejects it
(filed as pypto#1377).
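A minimal sketch of the workaround, using NumPy as a stand-in (the `rope_cos` shape and `ctx_len` value are invented for illustration). NumPy itself accepts both spellings at runtime; the difference only matters for the bound expressions the pypto compiler would see:

```python
import numpy as np

rope_cos = np.zeros((4096, 64), dtype=np.float32)
ctx_len = 17

# Rejected form (per pypto#1377): the compiler sees bounds ctx_len - 1 and
# ctx_len, and fails to fold their difference to the static constant 1.
bad = rope_cos[ctx_len - 1 : ctx_len, :]

# Workaround: bind the offset once, so the bounds become pos and pos + 1,
# which trivially fold to a static row dim of 1.
pos = ctx_len - 1
good = rope_cos[pos : pos + 1, :]

assert bad.shape == good.shape == (1, 64)
assert np.array_equal(bad, good)
```

Both forms select the same singleton row; the rewrite only changes how the bounds are spelled.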

Verified on a2a3:
- qwen3_32b_decode: compile + runtime + golden PASS (17.4s)
- deepseek_v3_2_decode_front: compile + runtime + golden PASS (4/4)

Split the single `cfa_proj_rope_assemble` scope into `_matmul` and
`_combine` scopes, with FP32 intermediate buffers for the even/odd
interleave streams. The BF16 add + cast that was inlined inside the
matmul loop now runs once over the full ROPE_DIM tile after the
even/odd matmul outputs are assembled.
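A schematic NumPy sketch of the split described above, under stated assumptions: the tile sizes, weight names, and the use of float16 as a BF16 stand-in are all invented here, and the real kernel operates on pl scopes and staging buffers rather than NumPy arrays:

```python
import numpy as np

H, K, ROPE_DIM = 8, 128, 64  # hypothetical tile sizes
rng = np.random.default_rng(0)
x = rng.standard_normal((H, K)).astype(np.float32)
w_even = rng.standard_normal((K, ROPE_DIM // 2)).astype(np.float32)
w_odd = rng.standard_normal((K, ROPE_DIM // 2)).astype(np.float32)

# _matmul scope: each stream's result lands in an FP32 staging buffer.
rope_even_interleave_buf = x @ w_even   # fp32, (H, ROPE_DIM // 2)
rope_odd_interleave_buf = x @ w_odd     # fp32, (H, ROPE_DIM // 2)

# _combine scope: interleave the even/odd streams over the full
# (H, ROPE_DIM) tile, then cast once (float16 stands in for BF16 here).
o_rope_interleave = np.empty((H, ROPE_DIM), dtype=np.float32)
o_rope_interleave[:, 0::2] = rope_even_interleave_buf
o_rope_interleave[:, 1::2] = rope_odd_interleave_buf
o_rope_interleave = o_rope_interleave.astype(np.float16)

assert o_rope_interleave.shape == (H, ROPE_DIM)
```

The point of the split is that the narrow-precision combine runs once over the assembled tile instead of once per matmul chunk.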
@coderabbitai

coderabbitai Bot commented May 15, 2026


No actionable comments were generated in the recent review. 🎉


📥 Commits

Reviewing files that changed from the base of the PR and between c47b5fb and a39c216.

📒 Files selected for processing (3)
  • models/deepseek/v3_2/deepseek_v3_2_decode_front.py
  • models/deepseek/v4/sparse_attn.py
  • models/qwen3/32b/qwen3_32b_decode.py

📝 Walkthrough

Three model kernels (Qwen3 32B, DeepSeek V3.2, DeepSeek V4) refactor RoPE tensor indexing from scalar/1D patterns to batched singleton slicing (pos:pos+1, b:b+1) across decode attention paths, normalizing how RoPE embeddings and related tensors are accessed and cast. DeepSeek V4 additionally refactors sparse-attention RoPE assembly to use intermediate staging buffers.

Changes

RoPE tensor shape refactoring across decode kernels:

  • models/qwen3/32b/qwen3_32b_decode.py (Qwen3 RoPE cosine/sine singleton-dimension slicing): RoPE cosine and sine extraction for the current position switches from 1D indexing to 2D singleton-dimension slicing (pos:pos+1), affecting the shape of cos_lo, cos_hi, sin_lo, and sin_hi used in subsequent K/Q rotation.
  • models/deepseek/v3_2/deepseek_v3_2_decode_front.py (decode-front multi-stage RoPE refactoring): Across five decode stages (1.7, 1.9, 2.4, 3.3, 4.1), RoPE-related tensor handling uniformly transitions from scalar indexing to batched singleton slicing: RoPE cos/sin use pos:pos+1, Q/K/KV tensors use b:b+1 slices, score rows use scores[b : b + 1, :], and all casts and processing maintain the batched singleton shape.
  • models/deepseek/v4/sparse_attn.py (sparse-attention RoPE assembly with staging buffers): Inverse-RoPE assembly refactors to use intermediate staging buffers (rope_even_interleave_buf, rope_odd_interleave_buf). Per-r0 even/odd RoPE matmul results are written to these buffers, then combined in a separate core-group step into the final BF16 o_rope_interleave tile. The core-group name hint is updated to cfa_proj_rope_assemble_matmul.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#253: Directly related inverse-RoPE/output RoPE assembly refactor in DeepSeek V4 sparse attention with interleaved even/odd staging pattern.
  • hw-native-sys/pypto-lib#157: Related DeepSeek V3.2 decode-front RoPE application and tensor-shape/slicing logic for batched singleton position-based indexing.
  • hw-native-sys/pypto-lib#61: Related Qwen3 decode RoPE handling refactoring with pos:pos+1 slice-shape adjustments in attention and KV-cache write paths.

Suggested labels

bug

Poem

🐰 A rabbit hops through tensor slices
From scalar to singleton, the RoPE device
Three models in sync, one shape to convey
Position embeddings dance the modern way!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 20.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)
  • Title check: the title accurately summarizes the main changes (fixing scalar-indexed slices in two kernels and refactoring sparse_attn rope handling), which matches the changeset content.
  • Description check: the description provides relevant details about the changes, including the 28 subscript rewrites, the ctx_len - 1 workaround, the sparse_attn refactor, and verification results that correspond to the changeset.
  • Linked Issues check: skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: skipped because no linked issues were found for this pull request.




@zhangqi-chen zhangqi-chen merged commit 4683640 into hw-native-sys:main May 15, 2026
6 checks passed
@zhangqi-chen zhangqi-chen deleted the fix/pl-range-slice-folding-and-sparse-rope branch May 15, 2026 01:48

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request updates indexing patterns across DeepSeek and Qwen3 decoding scripts, replacing scalar subscripts with range slices to ensure compatibility with the IR compiler's dimension folding. Additionally, the DeepSeek v4 sparse attention implementation was refactored to use intermediate FP32 buffers for RoPE interleave assembly, consolidating addition and casting operations to improve hardware utilization. I have no feedback to provide as the existing review comments were primarily explanatory and did not identify issues or required actions.
