Skip to content

Update: optimize DeepSeek hc_pre tiling#317

Merged
zhangqi-chen merged 2 commits into
hw-native-sys:mainfrom
high-cloud:optimize/hc-pre
May 20, 2026
Merged

Update: optimize DeepSeek hc_pre tiling#317
zhangqi-chen merged 2 commits into
hw-native-sys:mainfrom
high-cloud:optimize/hc-pre

Conversation

@high-cloud
Copy link
Copy Markdown
Contributor

@high-cloud high-cloud commented May 19, 2026

Summary

  • Tile RMS, linear, comb sinkhorn, and mix_x over the token axis in hc_pre
  • Remove the standalone x cast stage and cast where data is consumed
  • Use larger linear K tiles while keeping RMS tiles within vector limits
  • Replace the manual pre-value transpose loop with incore pl.transpose + pl.store

Related Issues

None

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 19, 2026

Review Change Stack

Warning

Rate limit exceeded

@high-cloud has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 58 minutes and 27 seconds before requesting another review.

You’ve run out of usage credits. Purchase more in the billing tab.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: b59e40e7-fd8d-46b2-8f6e-0a0fc8ab4a8b

📥 Commits

Reviewing files that changed from the base of the PR and between bee8dfc and 71a3cc5.

📒 Files selected for processing (1)
  • models/deepseek/v4/hc_pre.py
📝 Walkthrough

Walkthrough

This PR restructures the DeepSeek-V4 hyper-connections pre-mix kernel with independent tiling schemes for RMS, linear mixing, and comb sinkhorn stages. The compute flow now uses separate tile sizes and chunk sizes, stores pre-activation values with a transpose for efficient access, rewrites comb sinkhorn with tiled row-softmax, and reaccumulates x_mixed using the pre-transposed layout and D_CHUNK-blocked x data.

Changes

HC Pre-Mix Kernel Tiling Restructure

Layer / File(s) Summary
Tiling constants and chunking scheme
models/deepseek/v4/hc_pre.py
New tiling constants (RMS_T_TILE, LINEAR_T_TILE, COMB_T_TILE) and separate kernel chunk sizes (RMS_K_CHUNK, LINEAR_K_CHUNK) are introduced to drive the rewritten compute kernels. Derived block counts (RMS_K_BLOCKS, LINEAR_K_BLOCKS) are computed from these constants.
RMS and linear mix computation with tiling
models/deepseek/v4/hc_pre.py
The hc_pre function allocates tiled intermediates and computes inv_rms via a tiled loop over RMS_T_TILE and RMS_K_BLOCKS, then computes mix_raw via a tiled loop over LINEAR_T_TILE and LINEAR_K_BLOCKS. Pre-activation values are assembled into pre_val_store for later transpose.
Comb sinkhorn with tiled softmax and iterations
models/deepseek/v4/hc_pre.py
The pre-value storage is transposed into pre_val_t for efficient vector loading. Comb sinkhorn is rewritten to operate on COMB_T_TILE-sized row blocks, applying row-softmax with max-subtraction and eps stabilization, performing tiled Sinkhorn iterations, and writing results using comb_t_idx and column addressing.
x_mixed accumulation with tiled mixing
models/deepseek/v4/hc_pre.py
The x_mixed accumulation is rewritten to reshape via [T, D], process tokens in parallel over T_TILE, load pre-transposed pre_val_t rows and BF16→FP32 chunks of x in D_CHUNK blocks, build y_tile by row-wise expansion-multiply, store back to x_mixed_view, and reshape to [B, S, D].
Golden Torch reference update for new chunking
models/deepseek/v4/hc_pre.py
The golden_hc_pre reference implementation is updated to iterate using RMS_K_CHUNK for rsqrt accumulation and LINEAR_K_CHUNK for mix accumulation, matching the new kernel's independent chunking strategy.

Sequence Diagram(s)

flowchart TD
    A["Allocate tiled intermediates<br/>inv_rms, mixes, mix_raw, pre_val_store, pre_val_t"] --> B["Compute inv_rms<br/>Loop: RMS_T_TILE × RMS_K_BLOCKS"]
    B --> C["Compute mix_raw<br/>Loop: LINEAR_T_TILE × LINEAR_K_BLOCKS"]
    C --> D["Assemble pre_val_store<br/>from pre_val"]
    D --> E["Transpose pre_val_store<br/>into pre_val_t"]
    E --> F["Comb sinkhorn<br/>Process COMB_T_TILE row blocks"]
    F --> G["Row-softmax per block<br/>with max-subtraction + eps"]
    G --> H["Tiled Sinkhorn iterations<br/>row/column normalization"]
    H --> I["x_mixed accumulation<br/>Reshape to T,D view"]
    I --> J["Load pre_val_t rows<br/>pre0–pre3 per T_TILE"]
    J --> K["Load x chunks<br/>BF16→FP32 in D_CHUNK blocks"]
    K --> L["Build y_tile<br/>Expand-multiply pre × x"]
    L --> M["Store x_mixed_view<br/>reshape to B,S,D"]
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • hw-native-sys/pypto-lib#278: Both PRs refactor models/deepseek/v4/hc_pre.py to rebuild the core HC pipeline—changing how mixes are computed and how downstream comb and x_mixed are written/indexed—so this PR's tiled inv_rms/mix_raw/pre_val_t restructure is directly related to #278's matmul-based refactor.

Poem

🐰 A kernel reborn with tiling refined,
Chunks split apart, each stage redesigned,
Pre-values transposed for swifter access,
Sinkhorn steps tiled through each row's caress,
Now x_mixed flows through structured grace.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main change: optimizing DeepSeek hc_pre tiling strategy with new tile constants and restructured kernel logic.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description check ✅ Passed The pull request description clearly relates to the changeset, describing tiling optimizations for RMS, linear, comb sinkhorn, and mix_x operations in hc_pre kernel with specific technical details.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
models/deepseek/v4/hc_pre.py (1)

278-357: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Same HC_MULT == 4 assumption applies here.

The mix_x block loads exactly four pre values (pre0pre3) and four x chunks (x0x3), again hardcoding the assumption that HC_MULT == 4. This is the same issue as flagged in the comb sinkhorn section and should be addressed together.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/hc_pre.py` around lines 278 - 357, The mix_x block
currently hardcodes four loads (pre0–pre3) and four x chunks (x0–x3), assuming
HC_MULT == 4; change this to iterate over the actual HC_MULT value (or a
computed HCM = HC_MULT) so the number of pre loads and x loads is dynamic:
allocate or build lists/arrays of pre_i by looping t0 loads from pre_val_t, and
x_i by looping db loads from x_flat, then compute y_tile by summing
row_expand_mul(x_i, pre_i) across the loop instead of the fixed pl.add chain;
update the mix_x block where pre0..pre3, x0..x3, and the final y_tile
computation occur to use the looped variables so the code works for any HC_MULT.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/deepseek/v4/hc_pre.py`:
- Around line 177-266: The comb_sinkhorn loop currently hardcodes four rows
(row0..row3) and writes indices 0..3, which breaks if HC_MULT != 4; either add a
module-level static check asserting HC_MULT == 4 (e.g., assert HC_MULT == 4 ...)
to fail fast, or refactor the kernel to iterate over HC_MULT: replace the four
explicit loads/exp/soft/div/write blocks (references: comb_logits, comb_flat,
row0..row3, row*_exp, row*_soft, row*_cur) with loops over an index r in
range(HC_MULT) that allocate per-row tiles dynamically or use arrays/tiles
indexed by r, and make the final write compute base index using r so the code
works for any HC_MULT.

---

Duplicate comments:
In `@models/deepseek/v4/hc_pre.py`:
- Around line 278-357: The mix_x block currently hardcodes four loads
(pre0–pre3) and four x chunks (x0–x3), assuming HC_MULT == 4; change this to
iterate over the actual HC_MULT value (or a computed HCM = HC_MULT) so the
number of pre loads and x loads is dynamic: allocate or build lists/arrays of
pre_i by looping t0 loads from pre_val_t, and x_i by looping db loads from
x_flat, then compute y_tile by summing row_expand_mul(x_i, pre_i) across the
loop instead of the fixed pl.add chain; update the mix_x block where pre0..pre3,
x0..x3, and the final y_tile computation occur to use the looped variables so
the code works for any HC_MULT.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 18afe7bb-cad7-457e-8e2e-30bd59c0a39c

📥 Commits

Reviewing files that changed from the base of the PR and between 05127d2 and bee8dfc.

📒 Files selected for processing (1)
  • models/deepseek/v4/hc_pre.py

Comment on lines +177 to +266
for t0 in pl.parallel(0, T, COMB_T_TILE):
with pl.at(level=pl.Level.CORE_GROUP, name_hint="comb_sinkhorn"):
row0 = pl.fillpad(pl.load(
comb_logits,
[t0, 0 * HC_MULT],
[COMB_T_TILE, HC_PAD],
valid_shapes=[COMB_T_TILE, HC_MULT],
target_memory=pl.MemorySpace.Vec,
), pad_value=pl.PadValue.min)
row1 = pl.fillpad(pl.load(
comb_logits,
[t0, 1 * HC_MULT],
[COMB_T_TILE, HC_PAD],
valid_shapes=[COMB_T_TILE, HC_MULT],
target_memory=pl.MemorySpace.Vec,
), pad_value=pl.PadValue.min)
row2 = pl.fillpad(pl.load(
comb_logits,
[t0, 2 * HC_MULT],
[COMB_T_TILE, HC_PAD],
valid_shapes=[COMB_T_TILE, HC_MULT],
target_memory=pl.MemorySpace.Vec,
), pad_value=pl.PadValue.min)
row3 = pl.fillpad(pl.load(
comb_logits,
[t0, 3 * HC_MULT],
[COMB_T_TILE, HC_PAD],
valid_shapes=[COMB_T_TILE, HC_MULT],
target_memory=pl.MemorySpace.Vec,
), pad_value=pl.PadValue.min)

row_max_tmp = pl.create_tile([COMB_T_TILE, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec)
row_sum_tmp = pl.create_tile([COMB_T_TILE, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec)
row0_exp = pl.exp(pl.row_expand_sub(row0, pl.row_max(row0, row_max_tmp)))
row1_exp = pl.exp(pl.row_expand_sub(row1, pl.row_max(row1, row_max_tmp)))
row2_exp = pl.exp(pl.row_expand_sub(row2, pl.row_max(row2, row_max_tmp)))
row3_exp = pl.exp(pl.row_expand_sub(row3, pl.row_max(row3, row_max_tmp)))
row0_soft = pl.add(pl.row_expand_div(row0_exp, pl.row_sum(row0_exp, row_sum_tmp)), HC_EPS)
row1_soft = pl.add(pl.row_expand_div(row1_exp, pl.row_sum(row1_exp, row_sum_tmp)), HC_EPS)
row2_soft = pl.add(pl.row_expand_div(row2_exp, pl.row_sum(row2_exp, row_sum_tmp)), HC_EPS)
row3_soft = pl.add(pl.row_expand_div(row3_exp, pl.row_sum(row3_exp, row_sum_tmp)), HC_EPS)

row0_eff = pl.tile.fillpad(pl.tile.set_validshape(row0_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero)
row1_eff = pl.tile.fillpad(pl.tile.set_validshape(row1_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero)
row2_eff = pl.tile.fillpad(pl.tile.set_validshape(row2_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero)
row3_eff = pl.tile.fillpad(pl.tile.set_validshape(row3_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero)

row_sum_tmp_iter = pl.create_tile([COMB_T_TILE, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec)
col_sum = pl.add(pl.add(row0_eff, row1_eff), pl.add(row2_eff, row3_eff))
col_sum = pl.add(col_sum, HC_EPS)
row0_cur = pl.div(row0_norm, col_sum)
row1_cur = pl.div(row1_norm, col_sum)
row2_cur = pl.div(row2_norm, col_sum)
row3_cur = pl.div(row3_norm, col_sum)

for comb_t_idx in pl.unroll(T):
for c in pl.unroll(HC_MULT):
pl.write(
comb_flat,
[comb_t_idx * HC_MULT * HC_MULT + 0 * HC_MULT + c],
pl.read(row0_cur, [comb_t_idx, c]),
)
pl.write(
comb_flat,
[comb_t_idx * HC_MULT * HC_MULT + 1 * HC_MULT + c],
pl.read(row1_cur, [comb_t_idx, c]),
)
pl.write(
comb_flat,
[comb_t_idx * HC_MULT * HC_MULT + 2 * HC_MULT + c],
pl.read(row2_cur, [comb_t_idx, c]),
)
pl.write(
comb_flat,
[comb_t_idx * HC_MULT * HC_MULT + 3 * HC_MULT + c],
pl.read(row3_cur, [comb_t_idx, c]),
)
row0_cur = pl.div(row0_eff, col_sum)
row1_cur = pl.div(row1_eff, col_sum)
row2_cur = pl.div(row2_eff, col_sum)
row3_cur = pl.div(row3_eff, col_sum)

for _ in pl.unroll(HC_SINKHORN_ITER - 1):
row0_norm = pl.row_expand_div(row0_cur, pl.add(pl.row_sum(row0_cur, row_sum_tmp_iter), HC_EPS))
row1_norm = pl.row_expand_div(row1_cur, pl.add(pl.row_sum(row1_cur, row_sum_tmp_iter), HC_EPS))
row2_norm = pl.row_expand_div(row2_cur, pl.add(pl.row_sum(row2_cur, row_sum_tmp_iter), HC_EPS))
row3_norm = pl.row_expand_div(row3_cur, pl.add(pl.row_sum(row3_cur, row_sum_tmp_iter), HC_EPS))
col_sum = pl.add(pl.add(row0_norm, row1_norm), pl.add(row2_norm, row3_norm))
col_sum = pl.add(col_sum, HC_EPS)
row0_cur = pl.div(row0_norm, col_sum)
row1_cur = pl.div(row1_norm, col_sum)
row2_cur = pl.div(row2_norm, col_sum)
row3_cur = pl.div(row3_norm, col_sum)

for ti in pl.unroll(COMB_T_TILE):
for c in pl.unroll(HC_MULT):
comb_t_idx = t0 + ti
pl.write(
comb_flat,
[comb_t_idx * HC_MULT * HC_MULT + 0 * HC_MULT + c],
pl.read(row0_cur, [ti, c]),
)
pl.write(
comb_flat,
[comb_t_idx * HC_MULT * HC_MULT + 1 * HC_MULT + c],
pl.read(row1_cur, [ti, c]),
)
pl.write(
comb_flat,
[comb_t_idx * HC_MULT * HC_MULT + 2 * HC_MULT + c],
pl.read(row2_cur, [ti, c]),
)
pl.write(
comb_flat,
[comb_t_idx * HC_MULT * HC_MULT + 3 * HC_MULT + c],
pl.read(row3_cur, [ti, c]),
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Hardcoded assumption that HC_MULT == 4 will break if model config changes.

The comb sinkhorn implementation explicitly loads/processes four rows (row0row3) and writes results for row indices 0, 1, 2, 3. If M.hc_mult differs from 4, this kernel will silently produce incorrect results or access out-of-bounds memory.

Consider adding a static assertion at module level to guard this assumption:

assert HC_MULT == 4, f"comb_sinkhorn kernel assumes HC_MULT == 4, got {HC_MULT}"

Or refactor to dynamically handle the HC_MULT dimension.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/deepseek/v4/hc_pre.py` around lines 177 - 266, The comb_sinkhorn loop
currently hardcodes four rows (row0..row3) and writes indices 0..3, which breaks
if HC_MULT != 4; either add a module-level static check asserting HC_MULT == 4
(e.g., assert HC_MULT == 4 ...) to fail fast, or refactor the kernel to iterate
over HC_MULT: replace the four explicit loads/exp/soft/div/write blocks
(references: comb_logits, comb_flat, row0..row3, row*_exp, row*_soft, row*_cur)
with loops over an index r in range(HC_MULT) that allocate per-row tiles
dynamically or use arrays/tiles indexed by r, and make the final write compute
base index using r so the code works for any HC_MULT.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the hc_pre kernel by introducing more granular tiling constants and parallelizing the RMS, linear, and Sinkhorn computation blocks. The feedback identifies that the mix_x and comb_sinkhorn implementations hardcode the number of hyper-connection components to four, which limits the kernel's flexibility. It is recommended to replace these hardcoded variables with loops using pl.unroll(HC_MULT) to ensure the implementation remains generic and maintainable if the configuration changes.

Comment on lines +280 to +357
pre0 = pl.reshape(
pl.load(
pre_val_t,
[0, t0],
[1, T_TILE],
target_memory=pl.MemorySpace.Vec,
),
[T_TILE, 1],
)
pre1 = pl.reshape(
pl.load(
pre_val_t,
[1, t0],
[1, T_TILE],
target_memory=pl.MemorySpace.Vec,
),
[T_TILE, 1],
)
pre2 = pl.reshape(
pl.load(
pre_val_t,
[2, t0],
[1, T_TILE],
target_memory=pl.MemorySpace.Vec,
),
[T_TILE, 1],
)
pre3 = pl.reshape(
pl.load(
pre_val_t,
[3, t0],
[1, T_TILE],
target_memory=pl.MemorySpace.Vec,
),
[T_TILE, 1],
)
for db in pl.range(D_BLOCKS):
d0 = db * D_CHUNK
y_row = pl.tile.full([1, D_CHUNK], dtype=pl.FP32, value=0.0)
for h in pl.range(HC_MULT):
pre_th = pl.read(pre_val_flat, [token_idx * HC_PAD + h])
x_row = pl.load(
x_flat_fp32,
[token_idx, h * D + d0],
[1, D_CHUNK],
x0 = pl.cast(
pl.load(
x_flat,
[t0, 0 * D + d0],
[T_TILE, D_CHUNK],
target_memory=pl.MemorySpace.Vec,
)
y_row = pl.add(y_row, pl.mul(x_row, pre_th))
),
target_type=pl.FP32,
)
x1 = pl.cast(
pl.load(
x_flat,
[t0, 1 * D + d0],
[T_TILE, D_CHUNK],
target_memory=pl.MemorySpace.Vec,
),
target_type=pl.FP32,
)
x2 = pl.cast(
pl.load(
x_flat,
[t0, 2 * D + d0],
[T_TILE, D_CHUNK],
target_memory=pl.MemorySpace.Vec,
),
target_type=pl.FP32,
)
x3 = pl.cast(
pl.load(
x_flat,
[t0, 3 * D + d0],
[T_TILE, D_CHUNK],
target_memory=pl.MemorySpace.Vec,
),
target_type=pl.FP32,
)
y_tile = pl.add(
pl.add(pl.row_expand_mul(x0, pre0), pl.row_expand_mul(x1, pre1)),
pl.add(pl.row_expand_mul(x2, pre2), pl.row_expand_mul(x3, pre3)),
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The mix_x implementation hardcodes the number of hyper-connection components to 4 by using explicit variables pre0..3 and x0..3. This makes the kernel non-generic and fragile if HC_MULT is changed in the configuration. Although this unrolling was likely done for performance (moving pre loads out of the inner loop), it should be implemented using a loop with pl.unroll(HC_MULT) to maintain maintainability while preserving the optimization.

Comment on lines +179 to +206
row0 = pl.fillpad(pl.load(
comb_logits,
[t0, 0 * HC_MULT],
[COMB_T_TILE, HC_PAD],
valid_shapes=[COMB_T_TILE, HC_MULT],
target_memory=pl.MemorySpace.Vec,
), pad_value=pl.PadValue.min)
row1 = pl.fillpad(pl.load(
comb_logits,
[t0, 1 * HC_MULT],
[COMB_T_TILE, HC_PAD],
valid_shapes=[COMB_T_TILE, HC_MULT],
target_memory=pl.MemorySpace.Vec,
), pad_value=pl.PadValue.min)
row2 = pl.fillpad(pl.load(
comb_logits,
[t0, 2 * HC_MULT],
[COMB_T_TILE, HC_PAD],
valid_shapes=[COMB_T_TILE, HC_MULT],
target_memory=pl.MemorySpace.Vec,
), pad_value=pl.PadValue.min)
row3 = pl.fillpad(pl.load(
comb_logits,
[t0, 3 * HC_MULT],
[COMB_T_TILE, HC_PAD],
valid_shapes=[COMB_T_TILE, HC_MULT],
target_memory=pl.MemorySpace.Vec,
), pad_value=pl.PadValue.min)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comb_sinkhorn block hardcodes the loading of 4 rows (row0..3). While this matches the current HC_MULT=4 configuration, it limits the kernel's flexibility. Consider using a loop to load and process rows generically, which would allow the kernel to support different hyper-connection widths without manual code changes.

@high-cloud high-cloud force-pushed the optimize/hc-pre branch 2 times, most recently from b06a6eb to 99b8d01 Compare May 20, 2026 04:51
- Tile RMS, linear, comb sinkhorn, and mix_x over the token axis
- Remove the standalone x cast stage and cast where data is consumed
- Use larger linear K tiles while keeping RMS tiles within vector limits
@zhangqi-chen zhangqi-chen merged commit 6dbe78c into hw-native-sys:main May 20, 2026
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants