Update: optimize DeepSeek hc_pre tiling#317
Conversation
|
Warning Rate limit exceeded
You’ve run out of usage credits. Purchase more in the billing tab. ⌛ How to resolve this issue?After the wait time has elapsed, a review can be triggered using the We recommend that you space out your commits to avoid hitting the rate limit. 🚦 How do rate limits work?CodeRabbit enforces hourly rate limits for each developer per organization. Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout. Please see our FAQ for further information. ℹ️ Review info⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
📝 WalkthroughWalkthroughThis PR restructures the DeepSeek-V4 hyper-connections pre-mix kernel with independent tiling schemes for RMS, linear mixing, and comb sinkhorn stages. The compute flow now uses separate tile sizes and chunk sizes, stores pre-activation values with a transpose for efficient access, rewrites comb sinkhorn with tiled row-softmax, and reaccumulates x_mixed using the pre-transposed layout and D_CHUNK-blocked x data. ChangesHC Pre-Mix Kernel Tiling Restructure
Sequence Diagram(s)flowchart TD
A["Allocate tiled intermediates<br/>inv_rms, mixes, mix_raw, pre_val_store, pre_val_t"] --> B["Compute inv_rms<br/>Loop: RMS_T_TILE × RMS_K_BLOCKS"]
B --> C["Compute mix_raw<br/>Loop: LINEAR_T_TILE × LINEAR_K_BLOCKS"]
C --> D["Assemble pre_val_store<br/>from pre_val"]
D --> E["Transpose pre_val_store<br/>into pre_val_t"]
E --> F["Comb sinkhorn<br/>Process COMB_T_TILE row blocks"]
F --> G["Row-softmax per block<br/>with max-subtraction + eps"]
G --> H["Tiled Sinkhorn iterations<br/>row/column normalization"]
H --> I["x_mixed accumulation<br/>Reshape to T,D view"]
I --> J["Load pre_val_t rows<br/>pre0–pre3 per T_TILE"]
J --> K["Load x chunks<br/>BF16→FP32 in D_CHUNK blocks"]
K --> L["Build y_tile<br/>Expand-multiply pre × x"]
L --> M["Store x_mixed_view<br/>reshape to B,S,D"]
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (1)
models/deepseek/v4/hc_pre.py (1)
278-357:⚠️ Potential issue | 🟠 Major | ⚡ Quick winSame
HC_MULT == 4assumption applies here.The
mix_xblock loads exactly four pre values (pre0–pre3) and four x chunks (x0–x3), again hardcoding the assumption thatHC_MULT == 4. This is the same issue as flagged in the comb sinkhorn section and should be addressed together.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/deepseek/v4/hc_pre.py` around lines 278 - 357, The mix_x block currently hardcodes four loads (pre0–pre3) and four x chunks (x0–x3), assuming HC_MULT == 4; change this to iterate over the actual HC_MULT value (or a computed HCM = HC_MULT) so the number of pre loads and x loads is dynamic: allocate or build lists/arrays of pre_i by looping t0 loads from pre_val_t, and x_i by looping db loads from x_flat, then compute y_tile by summing row_expand_mul(x_i, pre_i) across the loop instead of the fixed pl.add chain; update the mix_x block where pre0..pre3, x0..x3, and the final y_tile computation occur to use the looped variables so the code works for any HC_MULT.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@models/deepseek/v4/hc_pre.py`:
- Around line 177-266: The comb_sinkhorn loop currently hardcodes four rows
(row0..row3) and writes indices 0..3, which breaks if HC_MULT != 4; either add a
module-level static check asserting HC_MULT == 4 (e.g., assert HC_MULT == 4 ...)
to fail fast, or refactor the kernel to iterate over HC_MULT: replace the four
explicit loads/exp/soft/div/write blocks (references: comb_logits, comb_flat,
row0..row3, row*_exp, row*_soft, row*_cur) with loops over an index r in
range(HC_MULT) that allocate per-row tiles dynamically or use arrays/tiles
indexed by r, and make the final write compute base index using r so the code
works for any HC_MULT.
---
Duplicate comments:
In `@models/deepseek/v4/hc_pre.py`:
- Around line 278-357: The mix_x block currently hardcodes four loads
(pre0–pre3) and four x chunks (x0–x3), assuming HC_MULT == 4; change this to
iterate over the actual HC_MULT value (or a computed HCM = HC_MULT) so the
number of pre loads and x loads is dynamic: allocate or build lists/arrays of
pre_i by looping t0 loads from pre_val_t, and x_i by looping db loads from
x_flat, then compute y_tile by summing row_expand_mul(x_i, pre_i) across the
loop instead of the fixed pl.add chain; update the mix_x block where pre0..pre3,
x0..x3, and the final y_tile computation occur to use the looped variables so
the code works for any HC_MULT.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 18afe7bb-cad7-457e-8e2e-30bd59c0a39c
📒 Files selected for processing (1)
models/deepseek/v4/hc_pre.py
| for t0 in pl.parallel(0, T, COMB_T_TILE): | ||
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="comb_sinkhorn"): | ||
| row0 = pl.fillpad(pl.load( | ||
| comb_logits, | ||
| [t0, 0 * HC_MULT], | ||
| [COMB_T_TILE, HC_PAD], | ||
| valid_shapes=[COMB_T_TILE, HC_MULT], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), pad_value=pl.PadValue.min) | ||
| row1 = pl.fillpad(pl.load( | ||
| comb_logits, | ||
| [t0, 1 * HC_MULT], | ||
| [COMB_T_TILE, HC_PAD], | ||
| valid_shapes=[COMB_T_TILE, HC_MULT], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), pad_value=pl.PadValue.min) | ||
| row2 = pl.fillpad(pl.load( | ||
| comb_logits, | ||
| [t0, 2 * HC_MULT], | ||
| [COMB_T_TILE, HC_PAD], | ||
| valid_shapes=[COMB_T_TILE, HC_MULT], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), pad_value=pl.PadValue.min) | ||
| row3 = pl.fillpad(pl.load( | ||
| comb_logits, | ||
| [t0, 3 * HC_MULT], | ||
| [COMB_T_TILE, HC_PAD], | ||
| valid_shapes=[COMB_T_TILE, HC_MULT], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), pad_value=pl.PadValue.min) | ||
|
|
||
| row_max_tmp = pl.create_tile([COMB_T_TILE, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec) | ||
| row_sum_tmp = pl.create_tile([COMB_T_TILE, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec) | ||
| row0_exp = pl.exp(pl.row_expand_sub(row0, pl.row_max(row0, row_max_tmp))) | ||
| row1_exp = pl.exp(pl.row_expand_sub(row1, pl.row_max(row1, row_max_tmp))) | ||
| row2_exp = pl.exp(pl.row_expand_sub(row2, pl.row_max(row2, row_max_tmp))) | ||
| row3_exp = pl.exp(pl.row_expand_sub(row3, pl.row_max(row3, row_max_tmp))) | ||
| row0_soft = pl.add(pl.row_expand_div(row0_exp, pl.row_sum(row0_exp, row_sum_tmp)), HC_EPS) | ||
| row1_soft = pl.add(pl.row_expand_div(row1_exp, pl.row_sum(row1_exp, row_sum_tmp)), HC_EPS) | ||
| row2_soft = pl.add(pl.row_expand_div(row2_exp, pl.row_sum(row2_exp, row_sum_tmp)), HC_EPS) | ||
| row3_soft = pl.add(pl.row_expand_div(row3_exp, pl.row_sum(row3_exp, row_sum_tmp)), HC_EPS) | ||
|
|
||
| row0_eff = pl.tile.fillpad(pl.tile.set_validshape(row0_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero) | ||
| row1_eff = pl.tile.fillpad(pl.tile.set_validshape(row1_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero) | ||
| row2_eff = pl.tile.fillpad(pl.tile.set_validshape(row2_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero) | ||
| row3_eff = pl.tile.fillpad(pl.tile.set_validshape(row3_soft, COMB_T_TILE, HC_MULT), pad_value=pl.PadValue.zero) | ||
|
|
||
| row_sum_tmp_iter = pl.create_tile([COMB_T_TILE, 1], dtype=pl.FP32, target_memory=pl.MemorySpace.Vec) | ||
| col_sum = pl.add(pl.add(row0_eff, row1_eff), pl.add(row2_eff, row3_eff)) | ||
| col_sum = pl.add(col_sum, HC_EPS) | ||
| row0_cur = pl.div(row0_norm, col_sum) | ||
| row1_cur = pl.div(row1_norm, col_sum) | ||
| row2_cur = pl.div(row2_norm, col_sum) | ||
| row3_cur = pl.div(row3_norm, col_sum) | ||
|
|
||
| for comb_t_idx in pl.unroll(T): | ||
| for c in pl.unroll(HC_MULT): | ||
| pl.write( | ||
| comb_flat, | ||
| [comb_t_idx * HC_MULT * HC_MULT + 0 * HC_MULT + c], | ||
| pl.read(row0_cur, [comb_t_idx, c]), | ||
| ) | ||
| pl.write( | ||
| comb_flat, | ||
| [comb_t_idx * HC_MULT * HC_MULT + 1 * HC_MULT + c], | ||
| pl.read(row1_cur, [comb_t_idx, c]), | ||
| ) | ||
| pl.write( | ||
| comb_flat, | ||
| [comb_t_idx * HC_MULT * HC_MULT + 2 * HC_MULT + c], | ||
| pl.read(row2_cur, [comb_t_idx, c]), | ||
| ) | ||
| pl.write( | ||
| comb_flat, | ||
| [comb_t_idx * HC_MULT * HC_MULT + 3 * HC_MULT + c], | ||
| pl.read(row3_cur, [comb_t_idx, c]), | ||
| ) | ||
| row0_cur = pl.div(row0_eff, col_sum) | ||
| row1_cur = pl.div(row1_eff, col_sum) | ||
| row2_cur = pl.div(row2_eff, col_sum) | ||
| row3_cur = pl.div(row3_eff, col_sum) | ||
|
|
||
| for _ in pl.unroll(HC_SINKHORN_ITER - 1): | ||
| row0_norm = pl.row_expand_div(row0_cur, pl.add(pl.row_sum(row0_cur, row_sum_tmp_iter), HC_EPS)) | ||
| row1_norm = pl.row_expand_div(row1_cur, pl.add(pl.row_sum(row1_cur, row_sum_tmp_iter), HC_EPS)) | ||
| row2_norm = pl.row_expand_div(row2_cur, pl.add(pl.row_sum(row2_cur, row_sum_tmp_iter), HC_EPS)) | ||
| row3_norm = pl.row_expand_div(row3_cur, pl.add(pl.row_sum(row3_cur, row_sum_tmp_iter), HC_EPS)) | ||
| col_sum = pl.add(pl.add(row0_norm, row1_norm), pl.add(row2_norm, row3_norm)) | ||
| col_sum = pl.add(col_sum, HC_EPS) | ||
| row0_cur = pl.div(row0_norm, col_sum) | ||
| row1_cur = pl.div(row1_norm, col_sum) | ||
| row2_cur = pl.div(row2_norm, col_sum) | ||
| row3_cur = pl.div(row3_norm, col_sum) | ||
|
|
||
| for ti in pl.unroll(COMB_T_TILE): | ||
| for c in pl.unroll(HC_MULT): | ||
| comb_t_idx = t0 + ti | ||
| pl.write( | ||
| comb_flat, | ||
| [comb_t_idx * HC_MULT * HC_MULT + 0 * HC_MULT + c], | ||
| pl.read(row0_cur, [ti, c]), | ||
| ) | ||
| pl.write( | ||
| comb_flat, | ||
| [comb_t_idx * HC_MULT * HC_MULT + 1 * HC_MULT + c], | ||
| pl.read(row1_cur, [ti, c]), | ||
| ) | ||
| pl.write( | ||
| comb_flat, | ||
| [comb_t_idx * HC_MULT * HC_MULT + 2 * HC_MULT + c], | ||
| pl.read(row2_cur, [ti, c]), | ||
| ) | ||
| pl.write( | ||
| comb_flat, | ||
| [comb_t_idx * HC_MULT * HC_MULT + 3 * HC_MULT + c], | ||
| pl.read(row3_cur, [ti, c]), | ||
| ) |
There was a problem hiding this comment.
Hardcoded assumption that HC_MULT == 4 will break if model config changes.
The comb sinkhorn implementation explicitly loads/processes four rows (row0–row3) and writes results for row indices 0, 1, 2, 3. If M.hc_mult differs from 4, this kernel will silently produce incorrect results or access out-of-bounds memory.
Consider adding a static assertion at module level to guard this assumption:
assert HC_MULT == 4, f"comb_sinkhorn kernel assumes HC_MULT == 4, got {HC_MULT}"Or refactor to dynamically handle the HC_MULT dimension.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@models/deepseek/v4/hc_pre.py` around lines 177 - 266, The comb_sinkhorn loop
currently hardcodes four rows (row0..row3) and writes indices 0..3, which breaks
if HC_MULT != 4; either add a module-level static check asserting HC_MULT == 4
(e.g., assert HC_MULT == 4 ...) to fail fast, or refactor the kernel to iterate
over HC_MULT: replace the four explicit loads/exp/soft/div/write blocks
(references: comb_logits, comb_flat, row0..row3, row*_exp, row*_soft, row*_cur)
with loops over an index r in range(HC_MULT) that allocate per-row tiles
dynamically or use arrays/tiles indexed by r, and make the final write compute
base index using r so the code works for any HC_MULT.
There was a problem hiding this comment.
Code Review
This pull request refactors the hc_pre kernel by introducing more granular tiling constants and parallelizing the RMS, linear, and Sinkhorn computation blocks. The feedback identifies that the mix_x and comb_sinkhorn implementations hardcode the number of hyper-connection components to four, which limits the kernel's flexibility. It is recommended to replace these hardcoded variables with loops using pl.unroll(HC_MULT) to ensure the implementation remains generic and maintainable if the configuration changes.
| pre0 = pl.reshape( | ||
| pl.load( | ||
| pre_val_t, | ||
| [0, t0], | ||
| [1, T_TILE], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), | ||
| [T_TILE, 1], | ||
| ) | ||
| pre1 = pl.reshape( | ||
| pl.load( | ||
| pre_val_t, | ||
| [1, t0], | ||
| [1, T_TILE], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), | ||
| [T_TILE, 1], | ||
| ) | ||
| pre2 = pl.reshape( | ||
| pl.load( | ||
| pre_val_t, | ||
| [2, t0], | ||
| [1, T_TILE], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), | ||
| [T_TILE, 1], | ||
| ) | ||
| pre3 = pl.reshape( | ||
| pl.load( | ||
| pre_val_t, | ||
| [3, t0], | ||
| [1, T_TILE], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), | ||
| [T_TILE, 1], | ||
| ) | ||
| for db in pl.range(D_BLOCKS): | ||
| d0 = db * D_CHUNK | ||
| y_row = pl.tile.full([1, D_CHUNK], dtype=pl.FP32, value=0.0) | ||
| for h in pl.range(HC_MULT): | ||
| pre_th = pl.read(pre_val_flat, [token_idx * HC_PAD + h]) | ||
| x_row = pl.load( | ||
| x_flat_fp32, | ||
| [token_idx, h * D + d0], | ||
| [1, D_CHUNK], | ||
| x0 = pl.cast( | ||
| pl.load( | ||
| x_flat, | ||
| [t0, 0 * D + d0], | ||
| [T_TILE, D_CHUNK], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ) | ||
| y_row = pl.add(y_row, pl.mul(x_row, pre_th)) | ||
| ), | ||
| target_type=pl.FP32, | ||
| ) | ||
| x1 = pl.cast( | ||
| pl.load( | ||
| x_flat, | ||
| [t0, 1 * D + d0], | ||
| [T_TILE, D_CHUNK], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), | ||
| target_type=pl.FP32, | ||
| ) | ||
| x2 = pl.cast( | ||
| pl.load( | ||
| x_flat, | ||
| [t0, 2 * D + d0], | ||
| [T_TILE, D_CHUNK], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), | ||
| target_type=pl.FP32, | ||
| ) | ||
| x3 = pl.cast( | ||
| pl.load( | ||
| x_flat, | ||
| [t0, 3 * D + d0], | ||
| [T_TILE, D_CHUNK], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), | ||
| target_type=pl.FP32, | ||
| ) | ||
| y_tile = pl.add( | ||
| pl.add(pl.row_expand_mul(x0, pre0), pl.row_expand_mul(x1, pre1)), | ||
| pl.add(pl.row_expand_mul(x2, pre2), pl.row_expand_mul(x3, pre3)), | ||
| ) |
There was a problem hiding this comment.
The mix_x implementation hardcodes the number of hyper-connection components to 4 by using explicit variables pre0..3 and x0..3. This makes the kernel non-generic and fragile if HC_MULT is changed in the configuration. Although this unrolling was likely done for performance (moving pre loads out of the inner loop), it should be implemented using a loop with pl.unroll(HC_MULT) to maintain maintainability while preserving the optimization.
| row0 = pl.fillpad(pl.load( | ||
| comb_logits, | ||
| [t0, 0 * HC_MULT], | ||
| [COMB_T_TILE, HC_PAD], | ||
| valid_shapes=[COMB_T_TILE, HC_MULT], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), pad_value=pl.PadValue.min) | ||
| row1 = pl.fillpad(pl.load( | ||
| comb_logits, | ||
| [t0, 1 * HC_MULT], | ||
| [COMB_T_TILE, HC_PAD], | ||
| valid_shapes=[COMB_T_TILE, HC_MULT], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), pad_value=pl.PadValue.min) | ||
| row2 = pl.fillpad(pl.load( | ||
| comb_logits, | ||
| [t0, 2 * HC_MULT], | ||
| [COMB_T_TILE, HC_PAD], | ||
| valid_shapes=[COMB_T_TILE, HC_MULT], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), pad_value=pl.PadValue.min) | ||
| row3 = pl.fillpad(pl.load( | ||
| comb_logits, | ||
| [t0, 3 * HC_MULT], | ||
| [COMB_T_TILE, HC_PAD], | ||
| valid_shapes=[COMB_T_TILE, HC_MULT], | ||
| target_memory=pl.MemorySpace.Vec, | ||
| ), pad_value=pl.PadValue.min) |
There was a problem hiding this comment.
The comb_sinkhorn block hardcodes the loading of 4 rows (row0..3). While this matches the current HC_MULT=4 configuration, it limits the kernel's flexibility. Consider using a loop to load and process rows generically, which would allow the kernel to support different hyper-connection widths without manual code changes.
b06a6eb to
99b8d01
Compare
- Tile RMS, linear, comb sinkhorn, and mix_x over the token axis - Remove the standalone x cast stage and cast where data is consumed - Use larger linear K tiles while keeping RMS tiles within vector limits
99b8d01 to
71a3cc5
Compare
Summary
pl.transpose+pl.storeRelated Issues
None