Refactor: fuse and simplify DeepSeek V4 qkv_proj_rope scopes #332

Merged
zhangqi-chen merged 1 commit into hw-native-sys:main from zhangqi-chen:refactor/dsv4-qkv-proj-rope-fuse
May 20, 2026

Conversation

@zhangqi-chen
Collaborator

Summary

  • Fuse attn_norm (rms + apply + bf16 cast) into one scope and drop the intermediate token_x_fp32 GM tensor
  • Fuse qr rms/apply/bf16-cast and kv rms/nope-apply/rope-apply into single scopes each, dropping dead fp32 write-backs
  • Rewrite qr_proj_matmul in the qwen3 q_proj style (create_tensor accumulator + pl.pipeline K-loop with if-first matmul/matmul_acc)
  • Drop all chunked_loop_optimizer hints
  • Fold the q/kv RoPE reassemble+write loops inside their pl.at scopes via an intermediate fp32 tensor; use pl.pipeline(stage=2)
  • Convert all pl.slice/pl.assemble to subscript [:] sugar
  • Add in-core T-dim range(2) tiling to the fused kv_rms and q_head_rms_rope scopes so they stay within the Vec buffer at B=64

Validated on a2a3 (B=64, T=128): q / kv / qr / qr_scale all PASS.
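
The "if-first matmul/matmul_acc" K-loop mentioned in the summary can be illustrated in plain Python. This is only a sketch of the accumulation pattern, not the pl API: the helper names `matmul`, `matmul_acc`, and `k_chunked_matmul` are hypothetical stand-ins, and nested lists replace tensors.

```python
def matmul(a, b):
    """Plain matrix product of two row-major nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def matmul_acc(acc, a, b):
    """Return acc + a @ b, standing in for an accumulate-into-tile matmul."""
    prod = matmul(a, b)
    return [[acc[i][j] + prod[i][j] for j in range(len(prod[0]))]
            for i in range(len(prod))]


def k_chunked_matmul(x, w, k_chunk):
    """Compute x @ w by walking the shared K dimension in chunks."""
    acc = None
    for k0 in range(0, len(w), k_chunk):
        x_blk = [row[k0:k0 + k_chunk] for row in x]
        w_blk = w[k0:k0 + k_chunk]
        if k0 == 0:
            # First chunk initialises the accumulator.
            acc = matmul(x_blk, w_blk)
        else:
            # Later chunks add their partial product to it.
            acc = matmul_acc(acc, x_blk, w_blk)
    return acc


x = [[1, 2, 3, 4], [5, 6, 7, 8]]
w = [[1, 0], [0, 1], [1, 1], [2, 0]]
result = k_chunked_matmul(x, w, k_chunk=2)
print(result)  # [[12, 5], [28, 13]]
```

Branching on the first iteration avoids zero-filling the accumulator before the loop, which is presumably why the pattern is paired with a `create_tensor` accumulator in the real kernel.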

@coderabbitai

coderabbitai Bot commented May 20, 2026

Walkthrough

This PR refactors the DeepSeek-V4 QKV projection and rotary embedding pipeline to eliminate assemble-based storage and consolidate BF16 computation using tiled RMSNorm and direct indexing across all projection paths.

Changes

QKV Projection and RoPE Optimization

| Layer / File(s) | Summary |
|---|---|
| Tiling configuration and fused token normalization (models/deepseek/v4/qkv_proj_rope.py) | Introduces RMS_T_TILES and T_TILE tiling parameters and consolidates the two-stage FP32→BF16 normalization into a single Stage 0 block that directly computes and writes token_x_bf16. |
| QR projection, normalization, and quantization (models/deepseek/v4/qkv_proj_rope.py) | QR matmul consumes token_x_bf16 and writes qr_fp32 directly; fused RMSNorm applies gamma and writes qr_bf16 via block-wise stores; per-token int8 quantization fills qr and qr_scale through direct indexing. |
| Q projection and per-head RoPE (models/deepseek/v4/qkv_proj_rope.py) | Q matmul writes q_proj_fp32 directly; per-head RMSNorm and RoPE computation use RMS_T_TILES/T_TILE tiling with normalized NOPE/ROPE computed in BF16, reassembled in FP32, and written back to q_flat via direct indexing. |
| KV projection, normalization, NOPE/RoPE, and reassembly (models/deepseek/v4/qkv_proj_rope.py) | KV matmul writes kv_fp32 directly with tiled D-block accumulation; fused RMSNorm + NOPE/RoPE uses tiled loops and direct BF16 writes into kv; ROPE-even/odd intermediates are generated and reassembled via direct indexing into kv_rope_fp32, then finalized as BF16 back to kv. |
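
The T-dimension tiling of the fused RMSNorm scopes is safe because each row is normalized independently. A minimal sketch of that property, in plain Python rather than pl code: the function names and the `eps` value are assumptions for illustration, and the BF16 cast is omitted.

```python
import math


def rms_norm_rows(rows, gamma, eps=1e-6):
    """RMS-normalise each row and scale by gamma (eps is an assumed value)."""
    out = []
    for row in rows:
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v * inv_rms * g for v, g in zip(row, gamma)])
    return out


def rms_norm_tiled(x, gamma, t_tiles, eps=1e-6):
    """Process the T rows of x in t_tiles equal slices, one slice at a time."""
    t = len(x)
    assert t >= t_tiles and t % t_tiles == 0, "t_tiles must evenly divide T"
    t_tile = t // t_tiles
    out = []
    for i in range(t_tiles):
        out.extend(rms_norm_rows(x[i * t_tile:(i + 1) * t_tile], gamma, eps))
    return out


x = [[1.0, 2.0], [3.0, 4.0], [0.5, -0.5], [2.0, 0.0]]
gamma = [1.0, 0.5]
tiled = rms_norm_tiled(x, gamma, t_tiles=2)
untiled = rms_norm_rows(x, gamma)
print(tiled == untiled)  # True: tiling over rows does not change the result
```

Only the working-set size changes with the tile count, which matches the stated goal of keeping the scopes within the Vec buffer.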

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#219: Earlier DeepSeek-V4 QKV rope refactoring that introduced BF16-typed precast intermediates and reorganized matmul/dataflow stages.

Poem

🐰 The rope unwinds with tiling grace,
Direct writes replace assemble's pace,
From token norm to KV rope so tight,
Each path optimized, batched just right. ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title 'Refactor: fuse and simplify DeepSeek V4 qkv_proj_rope scopes' directly matches the main objective of the changeset, which refactors and fuses multiple computation scopes in the qkv_proj_rope implementation. |
| Description check | ✅ Passed | The description provides detailed, specific information about all major changes including fusions, rewrites, and tiling optimizations, directly corresponding to the changeset modifications. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/deepseek/v4/qkv_proj_rope.py`:
- Around line 52-53: The tiling assumes T divides evenly by RMS_T_TILES, but
T_TILE = T // RMS_T_TILES can be zero or skip tail rows; add a guard: either
assert T >= RMS_T_TILES and T % RMS_T_TILES == 0 immediately after
RMS_T_TILES/T_TILE definition, or (preferable) compute tiles so tails are
handled by replacing T_TILE = T // RMS_T_TILES with a safe scheme (e.g., compute
tile_count = RMS_T_TILES; compute start = i * T_TILE and end = min(T, (i+1) *
T_TILE) for each tile, or use ceil division for T_TILE and clamp end indices)
and update the tiled loops that use RMS_T_TILES and T_TILE (the loops referenced
in the comment) to use start/end bounds so the final partial tile is processed.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 07e4c9c1-3d64-4f45-8600-c1709225f6ef

📥 Commits

Reviewing files that changed from the base of the PR and between a83c8ef and cea2532.

📒 Files selected for processing (1)
  • models/deepseek/v4/qkv_proj_rope.py

Comment on lines +52 to +53
RMS_T_TILES = 2
T_TILE = T // RMS_T_TILES

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard the 2-way T tiling contract.

T_TILE = T // RMS_T_TILES makes the tiled loops at Lines 178-179 and 251-252 cover exactly RMS_T_TILES * T_TILE rows. If T is odd, or ever smaller than RMS_T_TILES, the tail rows are skipped or the tile size becomes zero, so the last tokens never get RMS/RoPE applied. Please either assert that contract here or handle the final partial tile explicitly.

Suggested minimal guard
 RMS_T_TILES = 2
+assert T >= RMS_T_TILES and T % RMS_T_TILES == 0, \
+    "RMS_T_TILES must evenly divide T"
 T_TILE = T // RMS_T_TILES
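
The difference between floor-division tiling and the tail-safe alternative described in this comment can be checked with a few lines of plain Python. This is a sketch of the index arithmetic only; `floor_tiles`, `clamped_tiles`, and `covered_rows` are hypothetical helpers, not functions from the PR.

```python
def floor_tiles(t, n_tiles):
    """Tile bounds as the PR computes them: T_TILE = T // RMS_T_TILES."""
    t_tile = t // n_tiles
    return [(i * t_tile, (i + 1) * t_tile) for i in range(n_tiles)]


def clamped_tiles(t, n_tiles):
    """Tail-safe bounds: ceil-divide the tile size, then clamp to T."""
    t_tile = -(-t // n_tiles)  # ceiling division
    return [(min(t, i * t_tile), min(t, (i + 1) * t_tile))
            for i in range(n_tiles)]


def covered_rows(tiles):
    """Number of rows touched by a list of (start, end) tiles."""
    return sum(end - start for start, end in tiles)


print(floor_tiles(128, 2))    # [(0, 64), (64, 128)]  even T: full coverage
print(floor_tiles(129, 2))    # [(0, 64), (64, 128)]  odd T: row 128 is skipped
print(clamped_tiles(129, 2))  # [(0, 65), (65, 129)]  tail row is processed
print(floor_tiles(1, 2))      # [(0, 0), (0, 0)]      T < tiles: zero-sized tiles
```

With the validated shape (T=128) both schemes agree, so the assert is the cheaper option unless odd or very small T must be supported.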

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the qkv_proj_rope function in models/deepseek/v4/qkv_proj_rope.py to optimize performance and improve code structure, notably by updating parallel loop patterns, replacing slice operations with direct indexing, and optimizing tensor assembly. The review identified a critical race condition where tensors allocated in global memory lack a head dimension, causing data corruption during parallel execution. Additionally, suggestions were provided to fuse specific scopes to reduce global memory bandwidth usage and improve overall efficiency.

Comment on lines +174 to +175
q_rot_even_bf16 = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16)
q_rot_odd_bf16 = pl.create_tensor([T, ROPE_HALF], dtype=pl.BF16)

critical

The tensors q_rot_even_bf16 and q_rot_odd_bf16 are created inside the pl.parallel(H) loop but lack a head dimension (e.g., [H, T, ROPE_HALF]). Since these tensors cross pl.at scopes, they are allocated in Global Memory (GM). Without indexing by the head index h, all parallel iterations will overwrite the same memory locations, leading to a critical race condition.

To fix this, either:

  1. Move the declarations outside the pl.parallel loop and add an H dimension, then index them with [h, ...] throughout the loop.
  2. Fuse the q_head_rms_rope and q_rope_reassemble scopes so these can be local variables (though this is difficult due to the tiling inversion between T and ROPE_DIM).
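
The hazard described above can be modeled without any hardware. The sketch below is a sequential simulation under a stated assumption: every head finishes its first scope before any head starts its second, which is one possible schedule for a parallel loop. `run_heads` is a hypothetical helper, and Python lists stand in for GM tensors.

```python
def run_heads(num_heads, t, per_head_scratch):
    """Simulate two scopes per head that communicate through a scratch buffer."""
    if per_head_scratch:
        # Option 1 from the review: one [T] slice per head, i.e. an [H, T] buffer.
        scratch = [[None] * t for _ in range(num_heads)]
    else:
        # Buffer without a head dimension: every head aliases the same storage.
        shared = [None] * t
        scratch = [shared] * num_heads

    # Scope 1: each head writes its own values into its scratch view.
    for h in range(num_heads):
        for row in range(t):
            scratch[h][row] = (h, row)

    # Scope 2: each head reads back what it expects to be its own values.
    return [list(scratch[h]) for h in range(num_heads)]


shared_result = run_heads(num_heads=3, t=2, per_head_scratch=False)
indexed_result = run_heads(num_heads=3, t=2, per_head_scratch=True)
print(shared_result[0])   # [(2, 0), (2, 1)]  head 0 sees head 2's data
print(indexed_result[0])  # [(0, 0), (0, 1)]  head 0 sees its own data
```

Whether the real kernel hits this schedule depends on how the runtime allocates tensors created inside the parallel loop, which the source does not settle.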

Comment on lines +205 to +227
+q_rope_fp32 = pl.create_tensor([T, ROPE_DIM], dtype=pl.FP32)
 with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_reassemble"):
     for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2):
         pair_col = rope_col // 2
-        q_rot_even_chunk = pl.slice(q_rot_even_bf16, [T, ROPE_PAIR_CHUNK], [0, pair_col])
-        q_rot_odd_chunk = pl.slice(q_rot_odd_bf16, [T, ROPE_PAIR_CHUNK], [0, pair_col])
+        q_rot_even_chunk = q_rot_even_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK]
+        q_rot_odd_chunk = q_rot_odd_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK]
         q_rot_chunk = pl.matmul(
             q_rot_even_chunk,
-            pl.slice(even_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]),
+            even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK],
             out_dtype=pl.FP32,
         )
         q_rot_chunk = pl.matmul_acc(
             q_rot_chunk,
             q_rot_odd_chunk,
-            pl.slice(odd_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]),
+            odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK],
         )
+        q_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] = q_rot_chunk

-with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_write"):
-    h0 = h * HEAD_DIM
-    q_flat = pl.assemble(q_flat, pl.cast(q_rot_chunk, target_type=pl.BF16, mode="rint"), [0, h0 + NOPE_DIM + rope_col])
+with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_write"):
+    h0 = h * HEAD_DIM
+    for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2):
+        q_rope_chunk = q_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK]
+        q_flat[:, h0 + NOPE_DIM + rope_col : h0 + NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(q_rope_chunk, target_type=pl.BF16, mode="rint")

high

The q_rope_reassemble and q_rope_write scopes can be fused into a single pl.at scope. This would eliminate the need for the intermediate GM tensor q_rope_fp32 (which also suffers from the race condition mentioned above), saving GM bandwidth and improving performance. You can write the results of the matmuls directly to q_flat after casting.

Suggested change
with pl.at(level=pl.Level.CORE_GROUP, name_hint="q_rope_reassemble_write"):
    h0 = h * HEAD_DIM
    for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2):
        pair_col = rope_col // 2
        # Note: q_rot_even_bf16 should be indexed by h if fixed as suggested above
        q_rot_even_chunk = q_rot_even_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK]
        q_rot_odd_chunk = q_rot_odd_bf16[:, pair_col : pair_col + ROPE_PAIR_CHUNK]
        q_rot_chunk = pl.matmul(
            q_rot_even_chunk,
            even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK],
            out_dtype=pl.FP32,
        )
        q_rot_chunk = pl.matmul_acc(
            q_rot_chunk,
            q_rot_odd_chunk,
            odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK],
        )
        q_flat[:, h0 + NOPE_DIM + rope_col : h0 + NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(q_rot_chunk, target_type=pl.BF16, mode="rint")
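
For readers unfamiliar with the reassemble step, the matmul pair interleaves the even and odd halves back into one row. The sketch below assumes that even_select_t and odd_select_t are 0/1 selection matrices routing pair column i to output columns 2i and 2i+1; the source does not show their contents, so that is an assumption, as are the helper names and the relation ROPE_PAIR_CHUNK = ROPE_CHUNK // 2.

```python
def matmul(a, b):
    """Plain matrix product of two row-major nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def select_matrix(rope_half, parity):
    """Assumed layout: row i has a single 1 at column 2*i + parity."""
    m = [[0.0] * (2 * rope_half) for _ in range(rope_half)]
    for i in range(rope_half):
        m[i][2 * i + parity] = 1.0
    return m


def reassemble(even, odd, rope_chunk):
    """Interleave even/odd halves chunk by chunk, mirroring the loop indices."""
    rope_half = len(even[0])
    rope_dim = 2 * rope_half
    pair_chunk = rope_chunk // 2
    even_sel = select_matrix(rope_half, 0)
    odd_sel = select_matrix(rope_half, 1)
    out = [[0.0] * rope_dim for _ in even]
    for rope_col in range(0, rope_dim, rope_chunk):
        pair_col = rope_col // 2
        e = [r[pair_col:pair_col + pair_chunk] for r in even]
        o = [r[pair_col:pair_col + pair_chunk] for r in odd]
        e_sel = [r[rope_col:rope_col + rope_chunk]
                 for r in even_sel[pair_col:pair_col + pair_chunk]]
        o_sel = [r[rope_col:rope_col + rope_chunk]
                 for r in odd_sel[pair_col:pair_col + pair_chunk]]
        even_part = matmul(e, e_sel)   # fills the even output columns
        odd_part = matmul(o, o_sel)    # the accumulate step fills the odd ones
        for t, row in enumerate(out):
            for c in range(rope_chunk):
                row[rope_col + c] = even_part[t][c] + odd_part[t][c]
    return out


even = [[1.0, 3.0, 5.0, 7.0], [10.0, 30.0, 50.0, 70.0]]
odd = [[2.0, 4.0, 6.0, 8.0], [20.0, 40.0, 60.0, 80.0]]
interleaved = reassemble(even, odd, rope_chunk=4)
print(interleaved[0])  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
```

Because each output chunk depends only on its own matmul pair, the cast-and-store to q_flat can follow immediately in the same loop body, which is what the fused suggestion relies on.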

Comment on lines +285 to +306
+kv_rope_fp32 = pl.create_tensor([T, ROPE_DIM], dtype=pl.FP32)
 with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_reassemble"):
     for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2):
         pair_col = rope_col // 2
-        kv_rot_even_chunk = pl.slice(kv_rot_even_tmp, [T, ROPE_PAIR_CHUNK], [0, pair_col])
-        kv_rot_odd_chunk = pl.slice(kv_rot_odd_tmp, [T, ROPE_PAIR_CHUNK], [0, pair_col])
+        kv_rot_even_chunk = kv_rot_even_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK]
+        kv_rot_odd_chunk = kv_rot_odd_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK]
         kv_rot_chunk = pl.matmul(
             kv_rot_even_chunk,
-            pl.slice(even_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]),
+            even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK],
             out_dtype=pl.FP32,
         )
         kv_rot_chunk = pl.matmul_acc(
             kv_rot_chunk,
             kv_rot_odd_chunk,
-            pl.slice(odd_select_t, [ROPE_PAIR_CHUNK, ROPE_CHUNK], [pair_col, rope_col]),
+            odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK],
         )
+        kv_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK] = kv_rot_chunk

-with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_write"):
-    kv = pl.assemble(kv, pl.cast(kv_rot_chunk, target_type=pl.BF16, mode="rint"), [0, NOPE_DIM + rope_col])
+with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_write"):
+    for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2):
+        kv_rope_chunk = kv_rope_fp32[:, rope_col : rope_col + ROPE_CHUNK]
+        kv[:, NOPE_DIM + rope_col : NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(kv_rope_chunk, target_type=pl.BF16, mode="rint")

high

Similar to the Q path, the kv_rope_reassemble and kv_rope_write scopes should be fused to eliminate the intermediate GM tensor kv_rope_fp32 and save GM round-trips.

Suggested change
with pl.at(level=pl.Level.CORE_GROUP, name_hint="kv_rope_reassemble_write"):
    for rope_col in pl.pipeline(0, ROPE_DIM, ROPE_CHUNK, stage=2):
        pair_col = rope_col // 2
        kv_rot_even_chunk = kv_rot_even_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK]
        kv_rot_odd_chunk = kv_rot_odd_tmp[:, pair_col : pair_col + ROPE_PAIR_CHUNK]
        kv_rot_chunk = pl.matmul(
            kv_rot_even_chunk,
            even_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK],
            out_dtype=pl.FP32,
        )
        kv_rot_chunk = pl.matmul_acc(
            kv_rot_chunk,
            kv_rot_odd_chunk,
            odd_select_t[pair_col : pair_col + ROPE_PAIR_CHUNK, rope_col : rope_col + ROPE_CHUNK],
        )
        kv[:, NOPE_DIM + rope_col : NOPE_DIM + rope_col + ROPE_CHUNK] = pl.cast(kv_rot_chunk, target_type=pl.BF16, mode="rint")
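
A rough estimate of the GM round-trip that this fusion would remove, as a back-of-the-envelope sketch. T=128 comes from the validated configuration in the PR description; ROPE_DIM=64 is an assumed value for illustration only, and the helper name is hypothetical.

```python
FP32_BYTES = 4


def intermediate_round_trip_bytes(t, rope_dim, elem_bytes=FP32_BYTES):
    """Bytes written to, then read back from, an intermediate [t, rope_dim] tensor."""
    one_way = t * rope_dim * elem_bytes
    return 2 * one_way


# T=128 as validated in the PR; rope_dim=64 is an assumption.
saved = intermediate_round_trip_bytes(t=128, rope_dim=64)
print(saved)  # 65536 bytes per kv_rope_fp32-sized intermediate
```

On the Q path the same amount would be multiplied by the number of heads, since q_rope_fp32 is produced once per head.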

@zhangqi-chen zhangqi-chen merged commit 2b8a22e into hw-native-sys:main May 20, 2026
10 of 11 checks passed
@zhangqi-chen zhangqi-chen deleted the refactor/dsv4-qkv-proj-rope-fuse branch May 20, 2026 07:29