refactor: drop deprecated pl.auto_chunk / chunked_loop_optimizer #372
lyfne123 wants to merge 1 commit into
pypto#1504 removed pl.at(optimization=, split=) and the chunked_loop_optimizer sentinel, and deprecated pl.auto_chunk. Migrate all 22 kernels to explicit chunk loops: outer pl.parallel over chunk stride, pl.at inside, inner pl.range over the original step. chunk= values and name_hint preserved; pl.split kept where present. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
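The index arithmetic behind this migration can be sketched in plain Python. This is not PyPTO code: the built-in `range` stands in for both `pl.parallel` and `pl.range`, and `chunked_indices` is a hypothetical helper name used only for illustration. The sketch assumes the trip count is a multiple of the chunk stride, since the inner stop is not clamped here.

```python
def chunked_indices(start, stop, step, chunk):
    """Enumerate indices via an outer chunk-stride loop and an inner step loop.

    Plain-Python stand-in: the outer `range` plays the role of pl.parallel,
    the inner `range` plays the role of pl.range. Not PyPTO API.
    """
    stride = step * chunk  # the outer loop advances one whole chunk at a time
    visited = []
    for chunk_start in range(start, stop, stride):
        # The inner loop walks the original step inside the current chunk.
        for i in range(chunk_start, chunk_start + stride, step):
            visited.append(i)
    return visited


# With a trip count divisible by the chunk, the visit order matches the flat loop.
flat = list(range(0, 16, 1))
chunked = chunked_indices(0, 16, 1, chunk=4)
```

When the divisibility assumption does not hold, the last chunk runs past `stop`, which is the case the review comments focus on.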
Walkthrough
This PR refactors parallel loop scheduling across 16 kernel files in the PyPTO library.
Changes: Loop Scheduling Refactoring
Estimated code review effort: 3 (Moderate), about 25 minutes
Pre-merge checks: 4 passed, 1 failed (1 warning)
Code Review
This pull request migrates various examples and model implementations from the deprecated chunked_loop_optimizer and auto_chunk to explicit chunk loops. The review feedback highlights several critical bugs where inner loop bounds are not properly clamped, potentially causing out-of-bounds memory accesses when dimensions are not multiples of the chunk size. Additionally, the reviewer noted inconsistencies such as redundant nested pl.at blocks, variable shadowing, and incomplete refactoring of some parallel loops still using the deprecated chunk= parameter.
| resid = pl.assemble(resid, resid_sum, [0, n0]) | ||
| for nb_chunk in pl.parallel(0, n_blocks, 1 * chunk): | ||
| with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.split(pl.SplitMode.UP_DOWN)]): | ||
| for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1): |
The inner loop pl.range(nb_chunk, nb_chunk + 1 * chunk, 1) does not account for cases where n_blocks is not a multiple of chunk. In the final iteration of the outer pl.parallel loop, nb will exceed n_blocks, leading to out-of-bounds memory accesses when slicing tensors (e.g., wo at line 59). Use pl.min to clamp the stop condition.
- for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1):
+ for nb in pl.range(nb_chunk, pl.min(nb_chunk + 1 * chunk, n_blocks), 1):
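The tail overshoot described in this comment can be reproduced with a small plain-Python sketch. Built-in `range` and `min` stand in for `pl.range` and `pl.min`, `inner_indices` is a hypothetical helper, and the sizes (10 blocks, chunk of 4) are made-up values chosen so that the block count is not a multiple of the chunk.

```python
def inner_indices(n_blocks, chunk, clamp):
    """Collect every index the inner loop visits across all chunks.

    Plain-Python stand-in for the outer/inner loop pair; not PyPTO API.
    """
    visited = []
    for nb_chunk in range(0, n_blocks, chunk):
        stop = nb_chunk + chunk
        if clamp:
            # Clamp the stop so the last, partial chunk ends at n_blocks.
            stop = min(stop, n_blocks)
        visited.extend(range(nb_chunk, stop))
    return visited


unclamped = inner_indices(10, 4, clamp=False)
clamped = inner_indices(10, 4, clamp=True)
# Indices that fall outside the valid range [0, n_blocks).
overshoot = [nb for nb in unclamped if nb >= 10]
```

Here the unclamped variant visits two indices past the end, while the clamped variant visits exactly the valid ones.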
| for mb in pl.range(mb_chunk, mb_chunk + m_tile * m_chunk, m_tile): | ||
| for nb in pl.range(nb_chunk, nb_chunk + n_tile * n_chunk, n_tile): |
These inner loops can exceed the dimensions m and n if they are not multiples of the total chunk size (m_tile * m_chunk and n_tile * n_chunk). This will result in out-of-bounds slicing of tensors a, b, and c. You should use pl.min to ensure the loop indices stay within the valid range.
- for mb in pl.range(mb_chunk, mb_chunk + m_tile * m_chunk, m_tile):
-     for nb in pl.range(nb_chunk, nb_chunk + n_tile * n_chunk, n_tile):
+ for mb in pl.range(mb_chunk, pl.min(mb_chunk + m_tile * m_chunk, m), m_tile):
+     for nb in pl.range(nb_chunk, pl.min(nb_chunk + n_tile * n_chunk, n), n_tile):
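A two-dimensional version of the same clamp can be sketched in plain Python. The loop nest below is an assumption about how the outer chunk loops are arranged (the excerpt shows only the inner loops), `tiled_origins` is a hypothetical helper, and the sizes are made up. Note that clamping the stop only removes whole tiles past the edge; it assumes `m` and `n` are multiples of `m_tile` and `n_tile`, otherwise the last tile itself would still be partial.

```python
def tiled_origins(m, n, m_tile, n_tile, m_chunk, n_chunk):
    """Return the (row, col) origin of every tile visited by the chunked nest.

    Plain-Python stand-in; `range`/`min` replace pl.parallel, pl.range, pl.min.
    """
    m_stride = m_tile * m_chunk  # rows covered by one outer chunk
    n_stride = n_tile * n_chunk  # columns covered by one outer chunk
    origins = []
    for mb_chunk in range(0, m, m_stride):
        for nb_chunk in range(0, n, n_stride):
            # Clamp both stops so partial chunks end at the matrix edge.
            m_stop = min(mb_chunk + m_stride, m)
            n_stop = min(nb_chunk + n_stride, n)
            for mb in range(mb_chunk, m_stop, m_tile):
                for nb in range(nb_chunk, n_stop, n_tile):
                    origins.append((mb, nb))
    return origins


# m = 96 is not a multiple of the 64-row chunk stride, so the second row chunk is partial.
origins = tiled_origins(m=96, n=64, m_tile=32, n_tile=32, m_chunk=2, n_chunk=2)
```

Every tile origin appears once and no tile extends past row 96 or column 64.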
| for b_chunk in pl.parallel(0, BATCH_CFG, 4): | ||
| with pl.at(level=pl.Level.CORE_GROUP): | ||
| layer_id = pl.tensor.read(layer_id_t, [0]) | ||
| for b in pl.range(b_chunk, b_chunk + 4): |
The inner loop pl.range(b_chunk, b_chunk + 4) will iterate beyond BATCH_CFG if it is not a multiple of 4. This will cause invalid memory reads from seq_lens at line 146 and hidden_states at line 164. Use pl.min(b_chunk + 4, BATCH_CFG) as the stop condition.
- for b in pl.range(b_chunk, b_chunk + 4):
+ for b in pl.range(b_chunk, pl.min(b_chunk + 4, BATCH_CFG)):
| Stage 0 (matmul: attn_out x wo) and Stage 1 (residual add) can be: | ||
| - Fused: single pl.at block with chunked_loop_optimizer (mix mode) | ||
| - Fused: single pl.at block with auto_chunk (mix mode) |
The comment is being updated to use auto_chunk, but auto_chunk is deprecated and its usage is being removed from the code in this PR. Since the implementation has migrated to explicit chunk loops, the documentation should reflect this change to avoid confusion.
- - Fused: single pl.at block with auto_chunk (mix mode)
+ - Fused: single pl.at block with explicit chunk loops (mix mode)
| # Fused Q path (local fusion trial for former incore_0/1): | ||
| # directly accumulates q_proj_tile from x -> wq_a -> q_norm -> wq_b | ||
| # without materializing full qr_tile. | ||
| for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8): |
This pl.parallel loop (and others at lines 199 and 374) still uses the deprecated chunk= parameter, but the wrapping pl.at block at line 143 no longer includes the chunked_loop_optimizer. This is inconsistent with the PR's objective of migrating to explicit chunk loops. These inner loops should also be refactored.
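One way to find leftover `chunk=` arguments across the migrated files is a small scan with Python's standard `ast` module. This is a sketch, not part of the PR: `find_chunk_kwargs` is a hypothetical helper, and it assumes the calls are spelled as attribute calls such as `pl.parallel(...)`.

```python
import ast


def find_chunk_kwargs(source, func_names=("parallel",)):
    """Return (line, call_name) for attribute calls that pass a chunk= keyword.

    Only the syntax is inspected; nothing is imported or executed.
    """
    hits = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        # Match calls of the form <something>.parallel(..., chunk=...).
        if isinstance(func, ast.Attribute) and func.attr in func_names:
            if any(kw.arg == "chunk" for kw in node.keywords):
                hits.append((node.lineno, func.attr))
    return sorted(hits)


# Two loop headers taken from the diff fragments in this review.
sample = (
    "for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8):\n"
    "    pass\n"
    "for ob_chunk in pl.parallel(0, KV_OUT_BLOCKS, 8):\n"
    "    pass\n"
)
hits = find_chunk_kwargs(sample)
```

Only the first loop is reported, since the second one already uses a positional chunk stride.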
| v_acc = pl.add(v_acc, pl.matmul(normed_bf16, wv_chunk)) | ||
| k_proj = pl.assemble(k_proj, pl.cast(k_acc, target_type=pl.BF16), [b0, kv0]) | ||
| v_proj = pl.assemble(v_proj, pl.cast(v_acc, target_type=pl.BF16), [b0, kv0]) | ||
| for ob_chunk in pl.parallel(0, KV_OUT_BLOCKS, 8): |
| with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_q_proj"): | ||
| for ob_chunk in pl.parallel(0, q_out_blocks, 4): | ||
| with pl.at(level=pl.Level.CORE_GROUP): |
There is a redundant nested pl.at block structure here. The outer pl.at (line 250) wraps the pl.parallel loop, while the inner pl.at (line 252) is inside it. This is inconsistent with the pattern established in other files in this PR. You should remove the outer pl.at and move its name_hint to the inner block.
Code Review
This pull request migrates various kernels across example scripts and model implementations (DeepSeek, Kimi, MILM, Qwen) from implicit chunking via chunked_loop_optimizer or auto_chunk to explicit chunked loop structures using nested pl.parallel and pl.range constructs. Feedback highlights potential out-of-bounds access issues in several files where the inner loop upper bounds do not account for cases where the total number of blocks is not a multiple of the chunk size. Additionally, some nested loops were found to still use the deprecated chunk parameter, which is inconsistent with the migration's goal of using explicit chunking.
| resid = pl.assemble(resid, resid_sum, [0, n0]) | ||
| for nb_chunk in pl.parallel(0, n_blocks, 1 * chunk): | ||
| with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.split(pl.SplitMode.UP_DOWN)]): | ||
| for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1): |
The inner loop upper bound nb_chunk + 1 * chunk does not account for cases where n_blocks is not a multiple of chunk. This can lead to out-of-bounds access in the last chunk. It is safer to use pl.min to clamp the bound.
- for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1):
+ for nb in pl.range(nb_chunk, pl.min(n_blocks, nb_chunk + 1 * chunk), 1):
| all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0]) | ||
| for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): | ||
| with pl.at(level=pl.Level.CORE_GROUP): | ||
| for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): |
ctx_blocks is dynamic. The inner loop should be guarded with pl.min(ctx_blocks, sb_chunk + SB_BATCH) to prevent out-of-bounds access when ctx_blocks is not a multiple of SB_BATCH. This also applies to other ctx_blocks loops in this file (lines 289, 310).
- for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+ for sb in pl.range(sb_chunk, pl.min(ctx_blocks, sb_chunk + SB_BATCH)):
| # Fused Q path (local fusion trial for former incore_0/1): | ||
| # directly accumulates q_proj_tile from x -> wq_a -> q_norm -> wq_b | ||
| # without materializing full qr_tile. | ||
| for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8): |
This nested pl.parallel loop still uses the deprecated chunk=8 parameter. According to the PR description, all kernels should be migrated to explicit chunk loops. This inconsistency may lead to compilation errors or unoptimized code since the surrounding pl.at block no longer includes the necessary optimizers. Similar instances occur at lines 200 and 374.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/advanced/gemm_eltwise.py`:
- Around line 14-15: Update the mix-mode documentation that still references
`auto_chunk` to describe the current implementation’s explicit chunking: replace
mentions of `auto_chunk` with a description that the fused/mix mode uses an
outer `pl.parallel(...)` to split work into chunks and an inner `pl.range(...)`
loop for per-chunk iteration (i.e., explicit outer parallel + inner range
chunking). Edit the comment blocks in the example that describe "Fused: single
pl.at block with auto_chunk (mix mode)" so they instead mention the outer
`pl.parallel` + inner `pl.range` chunking pattern (also update the analogous
phrasing later in the file where `auto_chunk` appears).
In `@models/qwen3/14b/qwen3_14b_l3_generate.py`:
- Around line 460-463: The chunked inner loop over sb (using pl.range(sb_chunk,
sb_chunk + SB_BATCH)) can run past the valid ctx_blocks when ctx_blocks <
SB_BATCH or not divisible; change the iteration to compute an end guard (e.g.,
sb_end = min(sb_chunk + SB_BATCH, ctx_blocks)) and iterate pl.range(sb_chunk,
sb_end) (and use sb_end wherever the chunk end is assumed) in the prefill
attention loops that contain pl.parallel, pl.at, pl.range and the
block_table_idx = b * max_blocks_per_seq + sb calculation to avoid
reading/writing beyond valid attention blocks.
- Around line 1060-1063: The decode-attention loop that iterates sb from
pl.range(sb_chunk, sb_chunk + SB_BATCH) can run past ctx_blocks for the final
partial chunk; update the loop to cap the upper bound or guard each iteration so
sb < ctx_blocks to avoid indexing invalid block_table/cache rows and temporary
buffers (e.g., compute end = min(sb_chunk + SB_BATCH, ctx_blocks) and use
pl.range(sb_chunk, end) or add an if sb < ctx_blocks inside the inner loop).
Apply this fix to the occurrences around pl.parallel / pl.at / pl.range using
sb_chunk, SB_BATCH, ctx_blocks, and block_table_idx (the same pattern appears at
the other two spots mentioned).
In `@models/qwen3/32b/qwen3_32b_prefill_draft.py`:
- Around line 277-280: The sb loop can iterate past ctx_blocks when
sb_chunk+SB_BATCH exceeds ctx_blocks; update the chunked loops using
pl.range(sb_chunk, min(sb_chunk + SB_BATCH, ctx_blocks)) or add an inner
conditional guard (if sb >= ctx_blocks: break/continue) before computing s0 = sb
* SEQ_TILE to prevent out-of-bounds access; apply the same fix to the other
chunked sb loops that use SB_BATCH, pl.range and compute s0 so all prefill
attention stages check bounds against ctx_blocks.
📒 Files selected for processing (21)
examples/advanced/gemm_eltwise.py
examples/beginner/hello_world.py
examples/beginner/matmul.py
examples/intermediate/gemm.py
examples/intermediate/layer_norm.py
examples/intermediate/rms_norm.py
examples/intermediate/rope.py
examples/intermediate/softmax.py
models/deepseek/v3_2/deepseek_v3_2_decode_front.py
models/deepseek/v3_2/deepseek_v3_2_prefill_front_draft.py
models/deepseek/v4/decode_attention_hca.py
models/deepseek/v4/decode_attention_swa.py
models/deepseek/v4/decode_compressor_ratio128.py
models/deepseek/v4/decode_indexer.py
models/deepseek/v4/hc_post.py
models/deepseek/v4/qkv_proj_rope.py
models/kimi/kimi_k2_decode_draft.py
models/milm/milm_decode_draft.py
models/qwen3/14b/qwen3_14b_l3_generate.py
models/qwen3/32b/qwen3_32b_decode.py
models/qwen3/32b/qwen3_32b_prefill_draft.py
| - Fused: single pl.at block with auto_chunk (mix mode) | ||
| - Split: separate pl.at blocks for each stage (split mode) |
Update mix-mode docs to match explicit chunk loops.
Line 14 and Line 39 still mention auto_chunk, but this implementation now uses explicit outer pl.parallel + inner pl.range chunking.
✏️ Suggested doc fix
- - Fused: single pl.at block with auto_chunk (mix mode)
+ - Fused: single pl.at block with explicit chunk loops (mix mode)
@@
- """Build fused matmul + elementwise program with auto_chunk."""
+ """Build fused matmul + elementwise program with explicit chunk loops."""
Also applies to: 39-39
| for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): | ||
| with pl.at(level=pl.Level.CORE_GROUP): | ||
| for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): | ||
| block_table_idx = b * max_blocks_per_seq + sb |
Add a tail guard for chunked sb iteration in prefill attention stages.
Line 462 / Line 483 / Line 515 iterate up to sb_chunk + SB_BATCH unconditionally; when ctx_blocks < SB_BATCH (or not divisible by SB_BATCH) this reads/writes past valid attention blocks.
💡 Suggested fix
  with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_qk_matmul"):
      for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
          with pl.at(level=pl.Level.CORE_GROUP):
              for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+                 if sb < ctx_blocks:
                      block_table_idx = b * max_blocks_per_seq + sb
                      pbid = pl.cast(
                          pl.tensor.read(block_table, [block_table_idx]), pl.INDEX
                      )
                      cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE
                      k_tile = pl.slice(
                          k_cache_all,
                          [SEQ_TILE, head_dim],
                          [layer_off_cache + cache_row0, 0],
                      )
                      raw_scores = pl.matmul(
                          q_padded, k_tile, b_trans=True, out_dtype=pl.FP32
                      )
                      all_raw_scores = pl.assemble(
                          all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0]
                      )
  with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_softmax"):
      for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
          with pl.at(level=pl.Level.CORE_GROUP):
              for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+                 if sb < ctx_blocks:
                      s0 = sb * SEQ_TILE
                      valid_len = pl.min(SEQ_TILE, ctx_len - s0)
                      ...
  with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_sv_matmul"):
      for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
          with pl.at(level=pl.Level.CORE_GROUP):
              for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+                 if sb < ctx_blocks:
                      block_table_idx = b * max_blocks_per_seq + sb
                      ...
Also applies to: 481-484, 513-516
| for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): | ||
| with pl.at(level=pl.Level.CORE_GROUP): | ||
| for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): | ||
| block_table_idx = block_table_base + sb |
Decode attention has the same out-of-range chunk-tail bug on sb.
Line 1062 / Line 1081 / Line 1115 loop to sb_chunk + SB_BATCH without sb < ctx_blocks checks, so partial chunks can index invalid block-table/cache regions and invalid temporary rows.
💡 Suggested fix
  for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
      with pl.at(level=pl.Level.CORE_GROUP):
          for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+             if sb < ctx_blocks:
                  block_table_idx = block_table_base + sb
                  pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX)
                  ...
  for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
      with pl.at(level=pl.Level.CORE_GROUP):
          for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+             if sb < ctx_blocks:
                  s0 = sb * BLOCK_SIZE
                  valid_len = pl.min(BLOCK_SIZE, ctx_len - s0)
                  ...
  for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
      with pl.at(level=pl.Level.CORE_GROUP):
          for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+             if sb < ctx_blocks:
                  block_table_idx = block_table_base + sb
                  ...
Also applies to: 1079-1082, 1113-1116
| for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH): | ||
| with pl.at(level=pl.Level.CORE_GROUP): | ||
| for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH): | ||
| s0 = sb * SEQ_TILE |
Chunked sb loops need a bounds check in prefill attention stages.
Line 279 / Line 289 / Line 310 currently execute beyond ctx_blocks for non-full chunks (and typically here since SB_BATCH is much larger than active context blocks), which risks invalid cache/block accesses and invalid temporary-tile indexing.
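To get a feel for how many inner iterations fall outside the valid range when the batch is larger than the active context, the count can be computed directly. This is a plain-Python sketch with made-up sizes; `out_of_range_iterations` is a hypothetical helper and the numbers are not taken from the model configuration.

```python
def out_of_range_iterations(ctx_blocks, sb_batch):
    """Count inner-loop indices >= ctx_blocks when the stop is not clamped."""
    total = 0
    for sb_chunk in range(0, ctx_blocks, sb_batch):
        # Each chunk runs to sb_chunk + sb_batch; anything past ctx_blocks is invalid.
        total += max(0, sb_chunk + sb_batch - ctx_blocks)
    return total


# Hypothetical sizes: a short context against a large batch, and an exact fit.
short_context = out_of_range_iterations(ctx_blocks=3, sb_batch=64)
exact_fit = out_of_range_iterations(ctx_blocks=128, sb_batch=64)
```

With a 3-block context and a batch of 64, nearly the whole chunk is out of range, which matches the concern raised in this comment.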
💡 Suggested fix
  for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
      with pl.at(level=pl.Level.CORE_GROUP):
          for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+             if sb < ctx_blocks:
                  s0 = sb * SEQ_TILE
                  cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
                  k_tile = pl.slice(k_cache, [SEQ_TILE, head_dim], [cache_row0, 0])
                  raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
                  all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])
  for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
      with pl.at(level=pl.Level.CORE_GROUP):
          for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+             if sb < ctx_blocks:
                  s0 = sb * SEQ_TILE
                  valid_len = pl.min(SEQ_TILE, ctx_len - s0)
                  ...
  for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
      with pl.at(level=pl.Level.CORE_GROUP):
          for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+             if sb < ctx_blocks:
                  s0 = sb * SEQ_TILE
                  cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
                  ...
Also applies to: 287-290, 308-311