Skip to content

refactor: drop deprecated pl.auto_chunk / chunked_loop_optimizer#372

Open
lyfne123 wants to merge 1 commit into
hw-native-sys:mainfrom
lyfne123:refactor/remove-auto-chunk-explicit-loops
Open

refactor: drop deprecated pl.auto_chunk / chunked_loop_optimizer#372
lyfne123 wants to merge 1 commit into
hw-native-sys:mainfrom
lyfne123:refactor/remove-auto-chunk-explicit-loops

Conversation

@lyfne123
Copy link
Copy Markdown
Contributor

pypto#1504 removed pl.at(optimization=, split=) and the chunked_loop_optimizer sentinel, and deprecated pl.auto_chunk. Migrate all 22 kernels to explicit chunk loops: outer pl.parallel over chunk stride, pl.at inside, inner pl.range over the original step. chunk= values and name_hint preserved; pl.split kept where present.

pypto#1504 removed pl.at(optimization=, split=) and the
chunked_loop_optimizer sentinel, and deprecated pl.auto_chunk. Migrate
all 22 kernels to explicit chunk loops: outer pl.parallel over chunk
stride, pl.at inside, inner pl.range over the original step. chunk=
values and name_hint preserved; pl.split kept where present.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 25, 2026

Review Change Stack

📝 Walkthrough

Walkthrough

This PR refactors parallel loop scheduling across 16 kernel files in the PyPTO library, converting pl.parallel(..., chunk=X) patterns into explicit two-level nested iteration: outer pl.parallel(start, end, chunk_size) loops over chunk blocks paired with inner pl.range loops over elements within each chunk. Across the same scopes, optimization=pl.chunked_loop_optimizer and optimizations=[pl.auto_chunk] directives are removed from pl.at() scopes. No algorithm logic or exported signatures change.

Changes

Loop Scheduling Refactoring

Layer / File(s) Summary
Example programs: beginner and intermediate kernels
examples/advanced/gemm_eltwise.py, examples/beginner/hello_world.py, examples/beginner/matmul.py, examples/intermediate/gemm.py, examples/intermediate/layer_norm.py, examples/intermediate/rms_norm.py, examples/intermediate/rope.py, examples/intermediate/softmax.py
Beginner and intermediate kernels (GEMM+eltwise fused residual, scalar add, matmul, LayerNorm, RMSNorm, RoPE, softmax) restructure chunked row/tile/batch iteration from implicit chunking to explicit outer parallel chunk loops plus inner element-wise ranges; optimization directives are removed while compute logic remains unchanged.
Deepseek V3.2: batch loop and optimization restructuring
models/deepseek/v3_2/deepseek_v3_2_prefill_front_draft.py, models/deepseek/v3_2/deepseek_v3_2_decode_front.py
Prefill front refactors outer batch iteration to parallel over b_chunk (step 4) with inner pl.range(b_chunk, b_chunk + 4), moving layer_id read into the parallel region and removing chunked_loop_optimizer hint; decode front removes auto_chunk optimization wrapper from the Q index reduction stage.
Deepseek V4: multi-file loop scheduling updates
models/deepseek/v4/decode_attention_hca.py, models/deepseek/v4/decode_attention_swa.py, models/deepseek/v4/decode_compressor_ratio128.py, models/deepseek/v4/decode_indexer.py, models/deepseek/v4/hc_post.py, models/deepseek/v4/qkv_proj_rope.py
Attention scatter/top-k (swa), KV cache write (compressor), quantization index (indexer), post-compute (hc_post), and RMS partials (qkv_proj_rope) refactor chunked loops to explicit nested iteration and remove chunked-loop-optimizer/auto-chunk directives across reduction, scatter, and partial-sum stages.
Kimi K2 decode: comprehensive loop restructuring
models/kimi/kimi_k2_decode_draft.py
Q/K/V projection accumulators, RoPE/KV cache and attention head iteration, and expert down-projections are restructured into chunked parallel blocks with inner ranges; chunked-loop-optimizer hint is removed while online softmax and gating-weighted expert combination logic stay functionally equivalent.
MiLM decode: comprehensive scheduling refactoring
models/milm/milm_decode_draft.py
Q/K/V projection, RoPE/KV cache/attention (with batch, KV-head, and query-head nested chunking), output projection, and MLP down-projection loops are refactored to explicit parallel/range nesting across multiple stages; chunked-loop-optimizer is removed while residual addition, online softmax, and MLP gating remain unchanged.
Qwen3 14b L3: large-scale prefill/decode loop restructuring
models/qwen3/14b/qwen3_14b_l3_generate.py
Prefill and decode scopes (Q/K/V projection, RoPE/cache, causal/grouped attention with softmax, SiLU blocks, and LM-head) convert chunked iteration to explicit parallel/range nesting; pl.auto_chunk and chunked_loop_optimizer directives are removed while tensor accumulation, masking, and output assembly remain consistent.
Qwen3 32b: decode and prefill optimization/loop updates
models/qwen3/32b/qwen3_32b_decode.py, models/qwen3/32b/qwen3_32b_prefill_draft.py
Decode removes auto_chunk from out_proj and down_proj scopes while preserving split mode; prefill refactors Q/K/V projection, KV cache/RoPE, attention softmax/matmul, and down-projection loops to explicit two-level chunk/range iteration without prior chunked-loop-optimizer configuration.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related issues

  • hw-native-sys/pypto-lib#314: The PR directly implements the kv_cache_write chunking refactoring in models/deepseek/v4/decode_compressor_ratio128.py by converting pl.parallel(..., chunk=16) with auto_chunk to explicit outer b_chunk parallel plus inner b_idx range loop.
  • hw-native-sys/pypto-lib#156: The PR addresses the same models/deepseek/v3_2/deepseek_v3_2_decode_front.py code region by removing the pl.at(..., optimizations=[pl.auto_chunk]) wrapper from the Q index reduction loop.

Possibly related PRs

  • hw-native-sys/pypto-lib#189: Both PRs modify models/qwen3/32b/qwen3_32b_decode.py's CORE_GROUP scheduling for out_proj_residual and down_proj_residual scopes, adjusting optimization directives in the same decode output/MLP stages.
  • hw-native-sys/pypto-lib#332: Both PRs refactor models/deepseek/v4/qkv_proj_rope.py by removing chunked_loop_optimizer from the attn_norm and qr_rms RMS partial-sum scopes.
  • hw-native-sys/pypto-lib#262: Both PRs refactor models/qwen3/14b/qwen3_14b_l3_generate.py by converting chunked loops to explicit outer parallel chunk blocks plus inner ranges and removing chunked_loop_optimizer directives across Scope 1 Q/K/V and Scope 2 attention/cache stages.

Poem

🐰 Loops are nested, chunks now clear,
Outer parallels drawing near!
Inner ranges march in place,
Optimizers leave no trace.
Kernels hum at their own pace!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 15.38% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title 'refactor: drop deprecated pl.auto_chunk / chunked_loop_optimizer' directly and accurately summarizes the main change: removing deprecated loop optimization directives across all 22 kernels.
Description check ✅ Passed The description clearly explains the refactoring scope, referencing the motivating upstream change (pypto#1504), specifying the migration pattern, and noting what is preserved (chunk values, name_hint, and pl.split).
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request migrates various examples and model implementations from the deprecated chunked_loop_optimizer and auto_chunk to explicit chunk loops. The review feedback highlights several critical bugs where inner loop bounds are not properly clamped, potentially causing out-of-bounds memory accesses when dimensions are not multiples of the chunk size. Additionally, the reviewer noted inconsistencies such as redundant nested pl.at blocks, variable shadowing, and incomplete refactoring of some parallel loops still using the deprecated chunk= parameter.

resid = pl.assemble(resid, resid_sum, [0, n0])
for nb_chunk in pl.parallel(0, n_blocks, 1 * chunk):
with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.split(pl.SplitMode.UP_DOWN)]):
for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The inner loop pl.range(nb_chunk, nb_chunk + 1 * chunk, 1) does not account for cases where n_blocks is not a multiple of chunk. In the final iteration of the outer pl.parallel loop, nb will exceed n_blocks, leading to out-of-bounds memory accesses when slicing tensors (e.g., wo at line 59). Use pl.min to clamp the stop condition.

Suggested change
for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1):
for nb in pl.range(nb_chunk, pl.min(nb_chunk + 1 * chunk, n_blocks), 1):

Comment on lines +52 to +53
for mb in pl.range(mb_chunk, mb_chunk + m_tile * m_chunk, m_tile):
for nb in pl.range(nb_chunk, nb_chunk + n_tile * n_chunk, n_tile):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

These inner loops can exceed the dimensions m and n if they are not multiples of the total chunk size (m_tile * m_chunk and n_tile * n_chunk). This will result in out-of-bounds slicing of tensors a, b, and c. You should use pl.min to ensure the loop indices stay within the valid range.

Suggested change
for mb in pl.range(mb_chunk, mb_chunk + m_tile * m_chunk, m_tile):
for nb in pl.range(nb_chunk, nb_chunk + n_tile * n_chunk, n_tile):
for mb in pl.range(mb_chunk, pl.min(mb_chunk + m_tile * m_chunk, m), m_tile):
for nb in pl.range(nb_chunk, pl.min(nb_chunk + n_tile * n_chunk, n), n_tile):

for b_chunk in pl.parallel(0, BATCH_CFG, 4):
with pl.at(level=pl.Level.CORE_GROUP):
layer_id = pl.tensor.read(layer_id_t, [0])
for b in pl.range(b_chunk, b_chunk + 4):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The inner loop pl.range(b_chunk, b_chunk + 4) will iterate beyond BATCH_CFG if it is not a multiple of 4. This will cause invalid memory reads from seq_lens at line 146 and hidden_states at line 164. Use pl.min(b_chunk + 4, BATCH_CFG) as the stop condition.

Suggested change
for b in pl.range(b_chunk, b_chunk + 4):
for b in pl.range(b_chunk, pl.min(b_chunk + 4, BATCH_CFG)):


Stage 0 (matmul: attn_out x wo) and Stage 1 (residual add) can be:
- Fused: single pl.at block with chunked_loop_optimizer (mix mode)
- Fused: single pl.at block with auto_chunk (mix mode)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment is being updated to use auto_chunk, but auto_chunk is deprecated and its usage is being removed from the code in this PR. Since the implementation has migrated to explicit chunk loops, the documentation should reflect this change to avoid confusion.

Suggested change
- Fused: single pl.at block with auto_chunk (mix mode)
- Fused: single pl.at block with explicit chunk loops (mix mode)

# Fused Q path (local fusion trial for former incore_0/1):
# directly accumulates q_proj_tile from x -> wq_a -> q_norm -> wq_b
# without materializing full qr_tile.
for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This pl.parallel loop (and others at lines 199 and 374) still uses the deprecated chunk= parameter, but the wrapping pl.at block at line 143 no longer includes the chunked_loop_optimizer. This is inconsistent with the PR's objective of migrating to explicit chunk loops. These inner loops should also be refactored.

v_acc = pl.add(v_acc, pl.matmul(normed_bf16, wv_chunk))
k_proj = pl.assemble(k_proj, pl.cast(k_acc, target_type=pl.BF16), [b0, kv0])
v_proj = pl.assemble(v_proj, pl.cast(v_acc, target_type=pl.BF16), [b0, kv0])
for ob_chunk in pl.parallel(0, KV_OUT_BLOCKS, 8):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The loop variable ob_chunk shadows the one defined at line 152. While they are in separate loop scopes, using unique names is a better practice to improve code clarity and avoid potential maintenance errors.

Comment on lines +250 to +252
with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_q_proj"):
for ob_chunk in pl.parallel(0, q_out_blocks, 4):
with pl.at(level=pl.Level.CORE_GROUP):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a redundant nested pl.at block structure here. The outer pl.at (line 250) wraps the pl.parallel loop, while the inner pl.at (line 252) is inside it. This is inconsistent with the pattern established in other files in this PR. You should remove the outer pl.at and move its name_hint to the inner block.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request migrates various kernels across example scripts and model implementations (DeepSeek, Kimi, MILM, Qwen) from implicit chunking via chunked_loop_optimizer or auto_chunk to explicit chunked loop structures using nested pl.parallel and pl.range constructs. Feedback highlights potential out-of-bounds access issues in several files where the inner loop upper bounds do not account for cases where the total number of blocks is not a multiple of the chunk size. Additionally, some nested loops were found to still use the deprecated chunk parameter, which is inconsistent with the migration's goal of using explicit chunking.

resid = pl.assemble(resid, resid_sum, [0, n0])
for nb_chunk in pl.parallel(0, n_blocks, 1 * chunk):
with pl.at(level=pl.Level.CORE_GROUP, optimizations=[pl.split(pl.SplitMode.UP_DOWN)]):
for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The inner loop upper bound nb_chunk + 1 * chunk does not account for cases where n_blocks is not a multiple of chunk. This can lead to out-of-bounds access in the last chunk. It is safer to use pl.min to clamp the bound.

Suggested change
for nb in pl.range(nb_chunk, nb_chunk + 1 * chunk, 1):
for nb in pl.range(nb_chunk, pl.min(n_blocks, nb_chunk + 1 * chunk), 1):

all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])
for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
with pl.at(level=pl.Level.CORE_GROUP):
for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

ctx_blocks is dynamic. The inner loop should be guarded with pl.min(ctx_blocks, sb_chunk + SB_BATCH) to prevent out-of-bounds access when ctx_blocks is not a multiple of SB_BATCH. This also applies to other ctx_blocks loops in this file (lines 289, 310).

Suggested change
for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
for sb in pl.range(sb_chunk, pl.min(ctx_blocks, sb_chunk + SB_BATCH)):

# Fused Q path (local fusion trial for former incore_0/1):
# directly accumulates q_proj_tile from x -> wq_a -> q_norm -> wq_b
# without materializing full qr_tile.
for ob in pl.parallel(0, Q_OUT_BLOCKS, 1, chunk=8):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This nested pl.parallel loop still uses the deprecated chunk=8 parameter. According to the PR description, all kernels should be migrated to explicit chunk loops. This inconsistency may lead to compilation errors or unoptimized code since the surrounding pl.at block no longer includes the necessary optimizers. Similar instances occur at lines 200 and 374.

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/advanced/gemm_eltwise.py`:
- Around line 14-15: Update the mix-mode documentation that still references
`auto_chunk` to describe the current implementation’s explicit chunking: replace
mentions of `auto_chunk` with a description that the fused/mix mode uses an
outer `pl.parallel(...)` to split work into chunks and an inner `pl.range(...)`
loop for per-chunk iteration (i.e., explicit outer parallel + inner range
chunking). Edit the comment blocks in the example that describe "Fused: single
pl.at block with auto_chunk (mix mode)" so they instead mention the outer
`pl.parallel` + inner `pl.range` chunking pattern (also update the analogous
phrasing later in the file where `auto_chunk` appears).

In `@models/qwen3/14b/qwen3_14b_l3_generate.py`:
- Around line 460-463: The chunked inner loop over sb (using pl.range(sb_chunk,
sb_chunk + SB_BATCH)) can run past the valid ctx_blocks when ctx_blocks <
SB_BATCH or not divisible; change the iteration to compute an end guard (e.g.,
sb_end = min(sb_chunk + SB_BATCH, ctx_blocks)) and iterate pl.range(sb_chunk,
sb_end) (and use sb_end wherever the chunk end is assumed) in the prefill
attention loops that contain pl.parallel, pl.at, pl.range and the
block_table_idx = b * max_blocks_per_seq + sb calculation to avoid
reading/writing beyond valid attention blocks.
- Around line 1060-1063: The decode-attention loop that iterates sb from
pl.range(sb_chunk, sb_chunk + SB_BATCH) can run past ctx_blocks for the final
partial chunk; update the loop to cap the upper bound or guard each iteration so
sb < ctx_blocks to avoid indexing invalid block_table/cache rows and temporary
buffers (e.g., compute end = min(sb_chunk + SB_BATCH, ctx_blocks) and use
pl.range(sb_chunk, end) or add an if sb < ctx_blocks inside the inner loop).
Apply this fix to the occurrences around pl.parallel / pl.at / pl.range using
sb_chunk, SB_BATCH, ctx_blocks, and block_table_idx (the same pattern appears at
the other two spots mentioned).

In `@models/qwen3/32b/qwen3_32b_prefill_draft.py`:
- Around line 277-280: The sb loop can iterate past ctx_blocks when
sb_chunk+SB_BATCH exceeds ctx_blocks; update the chunked loops using
pl.range(sb_chunk, min(sb_chunk + SB_BATCH, ctx_blocks)) or add an inner
conditional guard (if sb >= ctx_blocks: break/continue) before computing s0 = sb
* SEQ_TILE to prevent out-of-bounds access; apply the same fix to the other
chunked sb loops that use SB_BATCH, pl.range and compute s0 so all prefill
attention stages check bounds against ctx_blocks.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 72991c5a-1da3-4703-8c60-e6b8498d9d8a

📥 Commits

Reviewing files that changed from the base of the PR and between cddcc84 and 43ffc48.

📒 Files selected for processing (21)
  • examples/advanced/gemm_eltwise.py
  • examples/beginner/hello_world.py
  • examples/beginner/matmul.py
  • examples/intermediate/gemm.py
  • examples/intermediate/layer_norm.py
  • examples/intermediate/rms_norm.py
  • examples/intermediate/rope.py
  • examples/intermediate/softmax.py
  • models/deepseek/v3_2/deepseek_v3_2_decode_front.py
  • models/deepseek/v3_2/deepseek_v3_2_prefill_front_draft.py
  • models/deepseek/v4/decode_attention_hca.py
  • models/deepseek/v4/decode_attention_swa.py
  • models/deepseek/v4/decode_compressor_ratio128.py
  • models/deepseek/v4/decode_indexer.py
  • models/deepseek/v4/hc_post.py
  • models/deepseek/v4/qkv_proj_rope.py
  • models/kimi/kimi_k2_decode_draft.py
  • models/milm/milm_decode_draft.py
  • models/qwen3/14b/qwen3_14b_l3_generate.py
  • models/qwen3/32b/qwen3_32b_decode.py
  • models/qwen3/32b/qwen3_32b_prefill_draft.py

Comment on lines +14 to 15
- Fused: single pl.at block with auto_chunk (mix mode)
- Split: separate pl.at blocks for each stage (split mode)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update mix-mode docs to match explicit chunk loops.

Line 14 and Line 39 still mention auto_chunk, but this implementation now uses explicit outer pl.parallel + inner pl.range chunking.

✏️ Suggested doc fix
-  - Fused: single pl.at block with auto_chunk (mix mode)
+  - Fused: single pl.at block with explicit chunk loops (mix mode)
@@
-    """Build fused matmul + elementwise program with auto_chunk."""
+    """Build fused matmul + elementwise program with explicit chunk loops."""

Also applies to: 39-39

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/advanced/gemm_eltwise.py` around lines 14 - 15, Update the mix-mode
documentation that still references `auto_chunk` to describe the current
implementation’s explicit chunking: replace mentions of `auto_chunk` with a
description that the fused/mix mode uses an outer `pl.parallel(...)` to split
work into chunks and an inner `pl.range(...)` loop for per-chunk iteration
(i.e., explicit outer parallel + inner range chunking). Edit the comment blocks
in the example that describe "Fused: single pl.at block with auto_chunk (mix
mode)" so they instead mention the outer `pl.parallel` + inner `pl.range`
chunking pattern (also update the analogous phrasing later in the file where
`auto_chunk` appears).

Comment on lines +460 to +463
for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
with pl.at(level=pl.Level.CORE_GROUP):
for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
block_table_idx = b * max_blocks_per_seq + sb
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Add a tail guard for chunked sb iteration in prefill attention stages.

Line 462 / Line 483 / Line 515 iterate up to sb_chunk + SB_BATCH unconditionally; when ctx_blocks < SB_BATCH (or not divisible by SB_BATCH) this reads/writes past valid attention blocks.

💡 Suggested fix
 with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_qk_matmul"):
     for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
         with pl.at(level=pl.Level.CORE_GROUP):
             for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+                if sb < ctx_blocks:
                     block_table_idx = b * max_blocks_per_seq + sb
                     pbid = pl.cast(
                         pl.tensor.read(block_table, [block_table_idx]), pl.INDEX
                     )
                     cache_row0 = (pbid * num_kv_heads + kvh) * BLOCK_SIZE
                     k_tile = pl.slice(
                         k_cache_all,
                         [SEQ_TILE, head_dim],
                         [layer_off_cache + cache_row0, 0],
                     )
                     raw_scores = pl.matmul(
                         q_padded, k_tile, b_trans=True, out_dtype=pl.FP32
                     )
                     all_raw_scores = pl.assemble(
                         all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0]
                     )

 with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_softmax"):
     for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
         with pl.at(level=pl.Level.CORE_GROUP):
             for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+                if sb < ctx_blocks:
                     s0 = sb * SEQ_TILE
                     valid_len = pl.min(SEQ_TILE, ctx_len - s0)
                     ...

 with pl.at(level=pl.Level.CORE_GROUP, name_hint="prefill_sv_matmul"):
     for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
         with pl.at(level=pl.Level.CORE_GROUP):
             for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+                if sb < ctx_blocks:
                     block_table_idx = b * max_blocks_per_seq + sb
                     ...

Also applies to: 481-484, 513-516

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/qwen3_14b_l3_generate.py` around lines 460 - 463, The
chunked inner loop over sb (using pl.range(sb_chunk, sb_chunk + SB_BATCH)) can
run past the valid ctx_blocks when ctx_blocks < SB_BATCH or not divisible;
change the iteration to compute an end guard (e.g., sb_end = min(sb_chunk +
SB_BATCH, ctx_blocks)) and iterate pl.range(sb_chunk, sb_end) (and use sb_end
wherever the chunk end is assumed) in the prefill attention loops that contain
pl.parallel, pl.at, pl.range and the block_table_idx = b * max_blocks_per_seq +
sb calculation to avoid reading/writing beyond valid attention blocks.

Comment on lines +1060 to +1063
for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
with pl.at(level=pl.Level.CORE_GROUP):
for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
block_table_idx = block_table_base + sb
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Decode attention has the same out-of-range chunk-tail bug on sb.

Line 1062 / Line 1081 / Line 1115 loop to sb_chunk + SB_BATCH without sb < ctx_blocks checks, so partial chunks can index invalid block-table/cache regions and invalid temporary rows.

💡 Suggested fix
 for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
     with pl.at(level=pl.Level.CORE_GROUP):
         for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+            if sb < ctx_blocks:
                 block_table_idx = block_table_base + sb
                 pbid = pl.cast(pl.tensor.read(block_table, [block_table_idx]), pl.INDEX)
                 ...

 for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
     with pl.at(level=pl.Level.CORE_GROUP):
         for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+            if sb < ctx_blocks:
                 s0 = sb * BLOCK_SIZE
                 valid_len = pl.min(BLOCK_SIZE, ctx_len - s0)
                 ...

 for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
     with pl.at(level=pl.Level.CORE_GROUP):
         for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+            if sb < ctx_blocks:
                 block_table_idx = block_table_base + sb
                 ...

Also applies to: 1079-1082, 1113-1116

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/14b/qwen3_14b_l3_generate.py` around lines 1060 - 1063, The
decode-attention loop that iterates sb from pl.range(sb_chunk, sb_chunk +
SB_BATCH) can run past ctx_blocks for the final partial chunk; update the loop
to cap the upper bound or guard each iteration so sb < ctx_blocks to avoid
indexing invalid block_table/cache rows and temporary buffers (e.g., compute end
= min(sb_chunk + SB_BATCH, ctx_blocks) and use pl.range(sb_chunk, end) or add an
if sb < ctx_blocks inside the inner loop). Apply this fix to the occurrences
around pl.parallel / pl.at / pl.range using sb_chunk, SB_BATCH, ctx_blocks, and
block_table_idx (the same pattern appears at the other two spots mentioned).

Comment on lines +277 to +280
for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
with pl.at(level=pl.Level.CORE_GROUP):
for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
s0 = sb * SEQ_TILE
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Chunked sb loops need a bounds check in prefill attention stages.

Line 279 / Line 289 / Line 310 currently execute beyond ctx_blocks for non-full chunks (and typically here since SB_BATCH is much larger than active context blocks), which risks invalid cache/block accesses and invalid temporary-tile indexing.

💡 Suggested fix
 for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
     with pl.at(level=pl.Level.CORE_GROUP):
         for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+            if sb < ctx_blocks:
                 s0 = sb * SEQ_TILE
                 cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
                 k_tile = pl.slice(k_cache, [SEQ_TILE, head_dim], [cache_row0, 0])
                 raw_scores = pl.matmul(q_padded, k_tile, b_trans=True, out_dtype=pl.FP32)
                 all_raw_scores = pl.assemble(all_raw_scores, raw_scores, [sb * Q_HEAD_PAD, 0])

 for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
     with pl.at(level=pl.Level.CORE_GROUP):
         for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+            if sb < ctx_blocks:
                 s0 = sb * SEQ_TILE
                 valid_len = pl.min(SEQ_TILE, ctx_len - s0)
                 ...

 for sb_chunk in pl.parallel(0, ctx_blocks, SB_BATCH):
     with pl.at(level=pl.Level.CORE_GROUP):
         for sb in pl.range(sb_chunk, sb_chunk + SB_BATCH):
+            if sb < ctx_blocks:
                 s0 = sb * SEQ_TILE
                 cache_row0 = b * num_kv_heads * max_seq + kvh * max_seq + s0
                 ...

Also applies to: 287-290, 308-311

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/qwen3/32b/qwen3_32b_prefill_draft.py` around lines 277 - 280, The sb
loop can iterate past ctx_blocks when sb_chunk+SB_BATCH exceeds ctx_blocks;
update the chunked loops using pl.range(sb_chunk, min(sb_chunk + SB_BATCH,
ctx_blocks)) or add an inner conditional guard (if sb >= ctx_blocks:
break/continue) before computing s0 = sb * SEQ_TILE to prevent out-of-bounds
access; apply the same fix to the other chunked sb loops that use SB_BATCH,
pl.range and compute s0 so all prefill attention stages check bounds against
ctx_blocks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant