[Example] Optimize topk selector for B=1 and large S cases #2108

Rachmanino wants to merge 3 commits into tile-ai:main

Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the required checks before requesting review. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

Important: Review skipped. Draft detected; please check the settings in the CodeRabbit UI or the ⚙️ Run configuration if you want drafts reviewed.
📝 Walkthrough

A new file introduces a 3-stage GPU topk selector optimized for large sequence lengths and batch size 1. Stage 1 builds a radix histogram, Stage 2 computes thresholds and collects candidates, and Stage 3 performs radix refinement rounds to finalize top-k indices. A Python wrapper and a test/benchmark harness are included.
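To make the walkthrough concrete, here is a minimal CPU sketch of the same radix-select idea (an illustrative NumPy stand-in, not the PR's kernels; `monotone_key` and `radix_topk_indices` are hypothetical names):

```python
import numpy as np

def monotone_key(x: np.ndarray) -> np.ndarray:
    """Map float32 bits to uint32 so that unsigned order matches float order."""
    u = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return np.where((u >> 31) == 1, ~u, u | np.uint32(0x80000000))

def radix_topk_indices(x: np.ndarray, k: int) -> np.ndarray:
    keys = monotone_key(x)
    idx = np.arange(keys.size)
    out = []
    for shift in (24, 16, 8, 0):                  # four 8-bit digits, MSB first
        bins = (keys >> shift) & 0xFF
        hist = np.bincount(bins, minlength=256)   # "stage 1": digit histogram
        above, b = 0, 255                         # threshold bin: everything in
        while b >= 0 and above + hist[b] <= k:    # strictly higher bins fits in k
            above += hist[b]
            b -= 1
        if b < 0:                                 # all remaining elements fit
            out.extend(idx)
            break
        out.extend(idx[bins > b])                 # "stage 2": direct winners
        k -= above
        if k == 0:
            break
        keep = bins == b                          # "stage 3": refine only the
        keys, idx = keys[keep], idx[keep]         # threshold-bin candidates
    else:
        out.extend(idx[:k])                       # still tied after the last digit
    return np.asarray(out, dtype=np.int64)

x = np.random.randn(131072).astype(np.float32)
sel = radix_topk_indices(x, 2048)
assert np.array_equal(np.sort(x[sel]), np.sort(x)[-2048:])
```

Stage 1 corresponds to the histogram, stage 2 to the direct emission of bins above the threshold, and stage 3 to the per-digit candidate refinement.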
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 4 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 2
🧹 Nitpick comments (3)
examples/deepseek_v32/topk_selector_3stages.py (3)
371-408: Test/benchmark harness nits. Optional cleanups in `test_topk_selector`:

- `torch.zeros(batch, dtype=torch.int32).cuda()` / `torch.ones(...).cuda() * seq_len` allocates on CPU then copies. Use `device='cuda'` directly.
- `print(f'{input.shape=}')` and the bare `print(indexes)` / `print(indexes_ref)` on a 2048-wide tensor are noisy; either gate them behind a verbose flag or drop them.
- `set(ref_np)` works because `ref_np` is 1D `int32`, but `set_ref - set_trt` would be a more useful diagnostic than just the intersection size.
- `run_regression_perf` is defined but never invoked in `__main__`; either wire it up or remove it if dead.

No correctness impact; ignore if you'd rather keep this as a scratch harness. A cleaned-up version is sketched after the prompt below.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_v32/topk_selector_3stages.py` around lines 371 - 408, In test_topk_selector, avoid CPU->GPU copies by creating starts and ends directly on CUDA (use device='cuda' when constructing tensors for starts/ends), remove or gate noisy prints (the f'{input.shape=}' and printing indexes/indexes_ref) behind a verbose flag or delete them, replace the intersection diagnostic with a more useful difference check (e.g., compute set_ref - set_trt to see missed items) instead of only reporting intersection size, and either invoke run_regression_perf from the __main__ block or delete that unused function to remove dead code.
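A cleaned-up harness incorporating these nits might look like the following sketch (assumptions: the `tl_topk(input, starts, ends, topk)` signature used elsewhere in this file, float32 inputs, and `topk=2048`; this is not the PR's actual test body):

```python
import torch

def test_topk_selector(batch=1, seq_len=131072, topk=2048, verbose=False):
    device = "cuda"
    x = torch.randn(batch, seq_len, dtype=torch.float32, device=device)
    # allocate directly on the GPU instead of .cuda()-copying CPU tensors
    starts = torch.zeros(batch, dtype=torch.int32, device=device)
    ends = torch.full((batch,), seq_len, dtype=torch.int32, device=device)
    indexes = tl_topk(x, starts, ends, topk)          # wrapper from this file
    indexes_ref = torch.topk(x, topk, dim=-1).indices
    if verbose:                                       # gate the noisy prints
        print(f"{x.shape=}")
        print(indexes)
        print(indexes_ref)
    set_trt = set(indexes[0].tolist())
    set_ref = set(indexes_ref[0].tolist())
    missed = set_ref - set_trt                        # what the kernel dropped
    print(f"overlap={len(set_ref & set_trt)}/{topk}, missed={sorted(missed)[:8]}")
```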
287: `round` shadows the Python builtin. Ruff flags it (A001). A trivial rename keeps tooling clean and avoids a footgun if any host-side helper that relies on the builtin gets pulled in later.
♻️ Proposed fix
```diff
-            for round in T.serial(4):
+            for r in T.serial(4):
                 if l_new_topk <= 0:
                     break
-                r_idx = round % 2
+                r_idx = r % 2
                 l_start_pos = topk - l_new_topk
@@
-                        l_bin_id32 = T.cast(
-                            ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]])
-                              >> (24 - round * 8)) & 0xFF), T.int32)
+                        l_bin_id32 = T.cast(
+                            ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]])
+                              >> (24 - r * 8)) & 0xFF), T.int32)
@@
                         if l_bin_id32 > l_threshold_bin_id:
                             pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1,
                                                return_prev=True) + l_start_pos
                             index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
                         elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
-                            if round == 3:
+                            if r == 3:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_v32/topk_selector_3stages.py` at line 287, The loop variable name "round" shadows the Python builtin (Ruff A001); rename the loop variable in the for loop that reads "for round in T.serial(4):" to a non-built-in identifier (e.g., "round_idx" or "rnd") and update all references inside that loop body (and any nested scopes) to the new name so functionality is unchanged and the linter warning is cleared.
354-368: Per-invocation buffer allocation and JIT lookup inflate measured latency. Five fresh CUDA tensors (`indexes`, `global_histogram`, `direct_counter`, `candidate_idx`, `candidate_count`) are allocated on every call, three of them via `torch.zeros`, which adds a `cudaMemsetAsync` to the critical path. On top of that, `tl_topk_stage{1,2,3}_impl(...)` is invoked from inside `tl_topk` on every call, paying a JIT cache lookup (and a re-specialization on `topk` for stages 2/3) per benchmark iteration. For a kernel claiming ~30 µs latency, both costs are non-trivial and skew the reported numbers.

Consider hoisting the JIT impls to module level (or memoizing on `topk`) and either preallocating the scratch buffers in the caller or caching them keyed by `(batch, topk, device)`:
-def tl_topk(input, starts, ends, topk): - batch, seq_len = input.shape - indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device) - global_histogram = torch.zeros(batch, RADIX, dtype=torch.int32, device=input.device) - direct_counter = torch.zeros(batch, RADIX + 1, dtype=torch.int32, device=input.device) - candidate_idx = torch.empty(batch, SMEM_INPUT_SIZE, dtype=torch.int32, device=input.device) - candidate_count = torch.zeros(batch, dtype=torch.int32, device=input.device) - - stage1 = tl_topk_stage1_impl() - stage2 = tl_topk_stage2_impl(topk) - stage3 = tl_topk_stage3_impl(topk) +_stage1_cache = None +_stage2_cache: dict = {} +_stage3_cache: dict = {} + +def _get_stages(topk): + global _stage1_cache + if _stage1_cache is None: + _stage1_cache = tl_topk_stage1_impl() + if topk not in _stage2_cache: + _stage2_cache[topk] = tl_topk_stage2_impl(topk) + _stage3_cache[topk] = tl_topk_stage3_impl(topk) + return _stage1_cache, _stage2_cache[topk], _stage3_cache[topk] + +def tl_topk(input, starts, ends, topk, *, scratch=None): + batch, _ = input.shape + if scratch is None: + scratch = ( + torch.zeros(batch, topk, dtype=torch.int32, device=input.device), + torch.zeros(batch, RADIX, dtype=torch.int32, device=input.device), + torch.zeros(batch, RADIX + 1, dtype=torch.int32, device=input.device), + torch.empty(batch, SMEM_INPUT_SIZE, dtype=torch.int32, device=input.device), + torch.zeros(batch, dtype=torch.int32, device=input.device), + ) + indexes, global_histogram, direct_counter, candidate_idx, candidate_count = scratch + stage1, stage2, stage3 = _get_stages(topk)Also note: if a caller does pass in reused scratch buffers,
global_histogram,direct_counter,candidate_count, andindexesneed to be zeroed before each call (stage 1 atomic_adds intoglobal_histogram; stage 2 atomic_adds intodirect_counterandcandidate_count). Worth documenting on the public API.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/deepseek_v32/topk_selector_3stages.py` around lines 354 - 368, The current tl_topk function allocates five CUDA tensors (indexes, global_histogram, direct_counter, candidate_idx, candidate_count) and calls tl_topk_stage1_impl / tl_topk_stage2_impl / tl_topk_stage3_impl on every invocation, causing per-call cudaMemsetAsync and JIT lookup overhead; fix by hoisting/memoizing the JIT kernels (tl_topk_stage1_impl, tl_topk_stage2_impl, tl_topk_stage3_impl) to module scope or caching them keyed by topk, and change tl_topk to accept optional preallocated scratch buffers (indexes, global_histogram, direct_counter, candidate_idx, candidate_count) or pull them from a cache keyed by (batch, topk, device); when reusing buffers ensure you zero global_histogram, direct_counter, candidate_count and indexes before each call (since stage1/2 use atomic_add), and add API docs/comments noting the buffer reuse/zeroing contract.
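On the caller side, reuse under that contract could look like this sketch (hedged: it assumes the hypothetical `scratch=` keyword from the diff above plus this file's `RADIX` and `SMEM_INPUT_SIZE` constants):

```python
import torch

B, S, K, iters = 1, 131072, 2048, 100                 # assumed shapes
x = torch.randn(B, S, device="cuda")
starts = torch.zeros(B, dtype=torch.int32, device="cuda")
ends = torch.full((B,), S, dtype=torch.int32, device="cuda")
scratch = (
    torch.zeros(B, K, dtype=torch.int32, device="cuda"),
    torch.zeros(B, RADIX, dtype=torch.int32, device="cuda"),
    torch.zeros(B, RADIX + 1, dtype=torch.int32, device="cuda"),
    torch.empty(B, SMEM_INPUT_SIZE, dtype=torch.int32, device="cuda"),
    torch.zeros(B, dtype=torch.int32, device="cuda"),
)
indexes, global_histogram, direct_counter, _, candidate_count = scratch
for _ in range(iters):
    # re-zero the atomically-updated buffers before every call (see contract above)
    indexes.zero_()
    global_histogram.zero_()
    direct_counter.zero_()
    candidate_count.zero_()
    out = tl_topk(x, starts, ends, K, scratch=scratch)
```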
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 621c2ea6-dfa7-4b08-ae96-b0a49f70cd32
📒 Files selected for processing (1)
examples/deepseek_v32/topk_selector_3stages.py
```python
if tx < RADIX:
    for i in T.serial(8):
        offset = 1 << i
        T.sync_threads(3, RADIX)
        if tx < RADIX - offset:
            l_val = s_histogram[tx] + s_histogram[tx + offset]
        T.sync_threads(3, RADIX)
        if tx < RADIX - offset:
            s_histogram[tx] = l_val

    T.sync_threads(3, RADIX)
    if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
        s_threshold_bin_id[0] = tx
T.sync_threads()
l_threshold_bin_id = s_threshold_bin_id[0]
l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
T.sync_threads()
```
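As an editorial aside, the loop above reads as a log-step (Hillis-Steele style) inclusive suffix sum over the 256 bins; a NumPy rendering of that reading (an assumption, not taken from the PR):

```python
import numpy as np

hist = np.random.randint(0, 10, size=256)         # stand-in for s_histogram
s = hist.copy()
for i in range(8):                                # log2(256) doubling steps
    offset = 1 << i
    s[:256 - offset] = s[:256 - offset] + s[offset:]   # s[tx] += s[tx + offset]
# s[b] now counts the elements in bins >= b, which is what the threshold test
# (s_histogram[tx] > l_new_topk >= s_histogram[tx + 1]) scans for.
assert np.array_equal(s, np.cumsum(hist[::-1])[::-1])
```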
`s_threshold_bin_id` is read without a guaranteed write.

If no `tx` satisfies `s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk` — e.g. when the total number of valid elements is ≤ `topk`, or when all valid elements fall in a single bin and `s_histogram[0] == l_new_topk` — `s_threshold_bin_id[0]` is never written, and the subsequent reads at line 181 / line 269 pick up uninitialized shared memory. The intended (B=1, S=131072) shape doesn't hit this, but it's worth a tiny initialization to keep the kernel robust against `seq_len < topk` callers and `starts`/`ends` ranges that exclude most of the row.
🛡️ Proposed fix (apply analogously in stage 3 around line 267)
```diff
     T.fill(s_histogram, 0)
     T.sync_threads()
     if tx < RADIX:
         s_histogram[tx] = global_histogram[bx, tx]
+    if tx == 0:
+        s_threshold_bin_id[0] = 0
     T.sync_threads()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/deepseek_v32/topk_selector_3stages.py` around lines 167 - 183,
Initialize the shared scalar s_threshold_bin_id before the per-thread reduction
so it is always valid if no thread writes it: set s_threshold_bin_id[0] = 0 (or
another safe default) prior to the "if tx < RADIX:" reduction loop, call
T.sync_threads() after that initialization, then proceed with the existing
reduction and the subsequent read into l_threshold_bin_id and use in l_new_topk;
apply the same initialization + sync pattern in the stage-3 analogous block (the
places referencing s_threshold_bin_id, tx, s_histogram, l_new_topk, and
l_threshold_bin_id).
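A quick sketch of the first scenario in this finding (hedged: shapes and the `tl_topk` call are assumptions; without the initialization fix the result is undefined rather than guaranteed to misbehave):

```python
import torch

B, S, K = 1, 1024, 2048                            # fewer valid elements than topk
x = torch.randn(B, S, device="cuda")
starts = torch.zeros(B, dtype=torch.int32, device="cuda")
ends = torch.full((B,), S, dtype=torch.int32, device="cuda")
idx = tl_topk(x, starts, ends, K)   # no tx satisfies the threshold predicate, so
                                    # s_threshold_bin_id[0] may be read uninitialized
```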
```python
    pos = T.atomic_add(
        direct_counter[bx, l_bin_id32 + 1], 1, return_prev=True)
    index[bx, l_bin_offset + pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
    pos = T.atomic_add(candidate_count[bx], 1, return_prev=True)
    if pos < SMEM_INPUT_SIZE:
        candidate_idx[bx, pos] = input_idx
```
Candidate count is incremented unconditionally — stage 3 can read past `s_input_idx`.

`T.atomic_add(candidate_count[bx], 1, return_prev=True)` fires before the `pos < SMEM_INPUT_SIZE` guard, so when the threshold bucket holds more than `SMEM_INPUT_SIZE` elements (e.g., highly degenerate inputs where many values share the same fp16 high-8 bits — large duplicate runs, constant slabs, saturated/inf-clipped activations), `candidate_count[bx]` ends up larger than the candidate buffer. Stage 3 then loads that uncapped value as `l_num_input` (line 277, propagated to `s_num_input[0]` at line 283), and the round-0 loop at lines 301-306 indexes `s_input_idx[r_idx, s * BLOCK_SIZE + tx]` with `tx + s * BLOCK_SIZE` going up to `l_num_input - 1`, which is a shared-memory OOB read past the `[2, SMEM_INPUT_SIZE]` buffer.

Random inputs from `torch.randn` won't trigger this, but it's a latent correctness/safety bug for adversarial or quantized inputs. The easiest fix is to clamp on the consumer side (stage 3, line 277):
🛡️ Proposed clamp in stage 3
```diff
             l_num_input = candidate_count[bx]
+            if l_num_input > SMEM_INPUT_SIZE:
+                l_num_input = SMEM_INPUT_SIZE
```

Or guard the producer so `candidate_count` stops growing once the buffer is full:
🛡️ Alternative: only count what fits
```diff
             elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
-                pos = T.atomic_add(candidate_count[bx], 1, return_prev=True)
-                if pos < SMEM_INPUT_SIZE:
-                    candidate_idx[bx, pos] = input_idx
+                if candidate_count[bx] < SMEM_INPUT_SIZE:
+                    pos = T.atomic_add(candidate_count[bx], 1, return_prev=True)
+                    if pos < SMEM_INPUT_SIZE:
+                        candidate_idx[bx, pos] = input_idx
```

Note that the second variant has a benign TOCTOU window, so a couple of extra atomic_adds may sneak in — pair it with the stage-3 clamp to be fully safe. The comment at lines 31-34 ("Assumes the threshold bucket size after the first pass is < 4K elements") should also be tightened or replaced with a hard cap.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    pos = T.atomic_add(
        direct_counter[bx, l_bin_id32 + 1], 1, return_prev=True)
    index[bx, l_bin_offset + pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
    if candidate_count[bx] < SMEM_INPUT_SIZE:
        pos = T.atomic_add(candidate_count[bx], 1, return_prev=True)
        if pos < SMEM_INPUT_SIZE:
            candidate_idx[bx, pos] = input_idx
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/deepseek_v32/topk_selector_3stages.py` around lines 196 - 202, The
bug is that candidate_count is incremented unconditionally in the producer
(atomic_add on candidate_count[bx]) which can exceed the candidate buffer and
cause stage_3 to read past s_input_idx; fix by either (A) clamping the consumer
read of l_num_input to at most SMEM_INPUT_SIZE when stage 3 reads
candidate_count (so l_num_input = min(candidate_count[bx], SMEM_INPUT_SIZE)
before using it / before writing s_num_input[0]) or (B) change the producer
logic around candidate_count/candidate_idx to only increment/write when the
measured pos < SMEM_INPUT_SIZE (i.e., perform the atomic_add only if you will
store, or do a compare-and-add pattern so candidate_count never grows beyond
SMEM_INPUT_SIZE); for full safety apply both: cap candidate_count on the
producer and also clamp in the consumer; relevant symbols: candidate_count,
candidate_idx, SMEM_INPUT_SIZE, stage 3 read of l_num_input / s_num_input.
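A degenerate-input sketch matching this finding (hedged: this turns the review's reasoning into a repro candidate, not a confirmed failing test):

```python
import torch

B, S, K = 1, 131072, 2048
x = torch.full((B, S), 1.0, device="cuda")         # every value lands in one radix bin
starts = torch.zeros(B, dtype=torch.int32, device="cuda")
ends = torch.full((B,), S, dtype=torch.int32, device="cuda")
idx = tl_topk(x, starts, ends, K)   # the threshold bucket holds all S elements, so
                                    # candidate_count far exceeds SMEM_INPUT_SIZE and
                                    # stage 3 may read out of bounds without the clamp
```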
opt: stage-2/3 perf fixes for 3-stage topk selector
This PR introduces a topk variant that launches 3 kernels to split the work across SMs, significantly reducing latency to ~30 µs at the (B, S) = (1, 131072) shape on H200. Still under optimization.
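For reproducing the quoted number, a minimal timing sketch (hedged: `K=2048` is a guess, and `tl_topk` is the wrapper from this example file):

```python
import torch

def bench_us(fn, iters=100, warmup=10):
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters * 1e3   # ms per run -> microseconds

B, S, K = 1, 131072, 2048
x = torch.randn(B, S, device="cuda")
starts = torch.zeros(B, dtype=torch.int32, device="cuda")
ends = torch.full((B,), S, dtype=torch.int32, device="cuda")
print(f"tl_topk: {bench_us(lambda: tl_topk(x, starts, ends, K)):.1f} us")
```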