[Example] Optimize topk selector for B=1 and large S cases#2108

Closed
Rachmanino wants to merge 3 commits into tile-ai:main from Rachmanino:topk_splitk

Conversation

@Rachmanino
Collaborator

@Rachmanino Rachmanino commented Apr 27, 2026

This PR introduces a topk variant that launches 3 kernels to split the work across different SMs, significantly reducing latency to ~30 µs at the (B, S) = (1, 131072) shape on H200.

Still under optimization.

Summary by CodeRabbit

  • New Features
    • GPU-accelerated top-k selection implementation providing efficient ranking for large-scale data processing.
    • Optimized multi-stage architecture for improved throughput and performance.
    • Includes comprehensive testing harness and performance benchmarking tools for validation and analysis.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai Bot commented Apr 27, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 704bc5b0-5cfb-45b1-a5e0-5025c5a18da2

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

A new file introduces a 3-stage GPU topk selector optimized for large sequence lengths and batch size 1. Stage 1 builds a radix histogram, Stage 2 computes thresholds and collects candidates, and Stage 3 performs radix refinement rounds to finalize top-k indices. A Python wrapper and test/benchmark harness are included.
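The three-stage flow can be illustrated with a small CPU model. The sketch below is illustrative only (the function name radix_topk_indices is made up, and non-negative uint32 keys stand in for the kernel's order-preserving fp16 bit conversion); the real implementation parallelizes each stage across GPU blocks and replaces the final sort with further radix rounds over the lower bytes:

```python
import numpy as np

def radix_topk_indices(keys, k):
    """Illustrative CPU model of the 3-stage radix top-k over uint32 keys."""
    # Stage 1: histogram over the high byte (bits 24..31) of every key.
    hi = ((keys >> 24) & 0xFF).astype(np.int64)
    hist = np.bincount(hi, minlength=256)
    # Stage 2: suffix[b] counts keys whose high byte is >= b; the threshold
    # bucket t is the largest b that still holds at least k keys at or above it.
    suffix = np.cumsum(hist[::-1])[::-1]
    t = int(np.nonzero(suffix >= k)[0].max())
    direct = np.nonzero(hi > t)[0]   # definitely in the top-k
    cands = np.nonzero(hi == t)[0]   # ties that Stage 3 must refine
    # Stage 3 (simplified): the kernel runs further radix rounds over the
    # lower bytes; a plain sort of the small candidate set stands in here.
    need = k - direct.size
    refined = cands[np.argsort(keys[cands])[::-1][:need]]
    return np.concatenate([direct, refined])
```

The key property is that Stage 2 only ever sorts the single threshold bucket's worth of ties, not the full sequence, which is what makes the large-S case cheap.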

Changes

Cohort / File(s): TileLang GPU TopK Implementation (examples/deepseek_v32/topk_selector_3stages.py)
Summary: New multi-stage topk selector with 3 kernel stages (histogram building, threshold computation & candidate collection, radix refinement), a Python wrapper tl_topk, a correctness test, and a performance benchmark harness against torch.topk.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐰 Three stages of sorting, so swift and so bright,
Radix histograms gleaming in GPU light,
Thresholds computed, candidates refined,
A topk selector brilliantly designed! ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 25.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
  • Title check ✅ Passed: The title clearly and specifically describes the main optimization, a topk selector optimized for B=1 (batch size 1) and large S (sequence length), which aligns with the changeset introducing a 3-stage GPU topk selector for these exact parameters.
  • Linked Issues check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Description check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@Rachmanino Rachmanino marked this pull request as draft April 27, 2026 14:51
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (3)
examples/deepseek_v32/topk_selector_3stages.py (3)

371-408: Test/benchmark harness nits.

Optional cleanups in test_topk_selector:

  • torch.zeros(batch, dtype=torch.int32).cuda() / torch.ones(...).cuda() * seq_len allocates on CPU then copies. Use device='cuda' directly.
  • print(f'{input.shape=}') and the bare print(indexes) / print(indexes_ref) on a 2048-wide tensor are noisy; either gate behind a verbose flag or drop.
  • set(ref_np) works because ref_np is 1D int32, but set_ref - set_trt would be a more useful diagnostic than just intersection size.
  • run_regression_perf is defined but never invoked in __main__; either wire it up or remove if dead.

No correctness impact; ignore if you'd rather keep this as a scratch harness.
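As a concrete illustration of the set-difference suggestion above, reporting which reference indices the kernel missed (and which extras it returned) is more actionable than the overlap size alone. The names set_ref/set_trt follow the comment; the index values below are made up:

```python
set_ref = {3, 5, 8, 13, 21}   # indices torch.topk returned (hypothetical values)
set_trt = {3, 5, 13, 21, 34}  # indices the kernel returned (hypothetical values)

missed = set_ref - set_trt    # in the reference but absent from the kernel output
extra = set_trt - set_ref     # in the kernel output but absent from the reference
print(f"overlap={len(set_ref & set_trt)} missed={sorted(missed)} extra={sorted(extra)}")
```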

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_v32/topk_selector_3stages.py` around lines 371 - 408, In
test_topk_selector, avoid CPU->GPU copies by creating starts and ends directly
on CUDA (use device='cuda' when constructing tensors for starts/ends), remove or
gate noisy prints (the f'{input.shape=}' and printing indexes/indexes_ref)
behind a verbose flag or delete them, replace the intersection diagnostic with a
more useful difference check (e.g., compute set_ref - set_trt to see missed
items) instead of only reporting intersection size, and either invoke
run_regression_perf from the __main__ block or delete that unused function to
remove dead code.

287-287: round shadows the Python builtin.

Ruff flag (A001). Trivial rename keeps tooling clean and avoids a footgun if any host‑side helper gets pulled in later that relies on the builtin.

♻️ Proposed fix
-            for round in T.serial(4):
+            for r in T.serial(4):
                 if l_new_topk <= 0:
                     break
 
-                r_idx = round % 2
+                r_idx = r % 2
                 l_start_pos = topk - l_new_topk
@@
-                        l_bin_id32 = T.cast(
-                            ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]])
-                              >> (24 - round * 8)) & 0xFF), T.int32)
+                        l_bin_id32 = T.cast(
+                            ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]])
+                              >> (24 - r * 8)) & 0xFF), T.int32)
@@
-                        if l_bin_id32 > l_threshold_bin_id:
+                        if l_bin_id32 > l_threshold_bin_id:
                             pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1,
                                                return_prev=True) + l_start_pos
                             index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
                         elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
-                            if round == 3:
+                            if r == 3:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_v32/topk_selector_3stages.py` at line 287, The loop
variable name "round" shadows the Python builtin (Ruff A001); rename the loop
variable in the for loop that reads "for round in T.serial(4):" to a
non-built-in identifier (e.g., "round_idx" or "rnd") and update all references
inside that loop body (and any nested scopes) to the new name so functionality
is unchanged and the linter warning is cleared.

354-368: Per‑invocation buffer allocation and JIT lookup inflate measured latency.

Five fresh CUDA tensors (indexes, global_histogram, direct_counter, candidate_idx, candidate_count) are allocated on every call, three of them via torch.zeros which adds a cudaMemsetAsync to the critical path. On top of that, tl_topk_stage{1,2,3}_impl(...) is invoked from inside tl_topk on every call, paying a JIT cache lookup (and a re‑specialization on topk for stages 2/3) per benchmark iteration. For a kernel claiming ~30 µs latency, both costs are non‑trivial and skew the reported numbers.

Consider hoisting the JIT impls to module level (or memoizing on topk) and either preallocating the scratch buffers in the caller or caching them keyed by (batch, topk, device):

♻️ Sketch
-def tl_topk(input, starts, ends, topk):
-    batch, seq_len = input.shape
-    indexes = torch.zeros(batch, topk, dtype=torch.int32, device=input.device)
-    global_histogram = torch.zeros(batch, RADIX, dtype=torch.int32, device=input.device)
-    direct_counter = torch.zeros(batch, RADIX + 1, dtype=torch.int32, device=input.device)
-    candidate_idx = torch.empty(batch, SMEM_INPUT_SIZE, dtype=torch.int32, device=input.device)
-    candidate_count = torch.zeros(batch, dtype=torch.int32, device=input.device)
-
-    stage1 = tl_topk_stage1_impl()
-    stage2 = tl_topk_stage2_impl(topk)
-    stage3 = tl_topk_stage3_impl(topk)
+_stage1_cache = None
+_stage2_cache: dict = {}
+_stage3_cache: dict = {}
+
+def _get_stages(topk):
+    global _stage1_cache
+    if _stage1_cache is None:
+        _stage1_cache = tl_topk_stage1_impl()
+    if topk not in _stage2_cache:
+        _stage2_cache[topk] = tl_topk_stage2_impl(topk)
+        _stage3_cache[topk] = tl_topk_stage3_impl(topk)
+    return _stage1_cache, _stage2_cache[topk], _stage3_cache[topk]
+
+def tl_topk(input, starts, ends, topk, *, scratch=None):
+    batch, _ = input.shape
+    if scratch is None:
+        scratch = (
+            torch.zeros(batch, topk, dtype=torch.int32, device=input.device),
+            torch.zeros(batch, RADIX, dtype=torch.int32, device=input.device),
+            torch.zeros(batch, RADIX + 1, dtype=torch.int32, device=input.device),
+            torch.empty(batch, SMEM_INPUT_SIZE, dtype=torch.int32, device=input.device),
+            torch.zeros(batch, dtype=torch.int32, device=input.device),
+        )
+    indexes, global_histogram, direct_counter, candidate_idx, candidate_count = scratch
+    stage1, stage2, stage3 = _get_stages(topk)

Also note: if a caller does pass in reused scratch buffers, global_histogram, direct_counter, candidate_count, and indexes need to be zeroed before each call (stage 1 atomic_adds into global_histogram; stage 2 atomic_adds into direct_counter and candidate_count). Worth documenting on the public API.
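The caching idea can also be packaged as a shape-keyed helper that zeroes the atomic_add targets on every reuse. This is a minimal sketch only: numpy stands in for torch, and get_scratch plus the RADIX/SMEM_INPUT_SIZE values are illustrative assumptions, not the example's API:

```python
import numpy as np

RADIX = 256              # assumed radix width from the kernel constants
SMEM_INPUT_SIZE = 4096   # assumed candidate-buffer capacity

_scratch_cache = {}

def get_scratch(batch, topk):
    """Reuse per-shape scratch buffers, zeroing the ones the kernels
    accumulate into via atomic_add (sketch; numpy stands in for torch)."""
    key = (batch, topk)
    if key not in _scratch_cache:
        _scratch_cache[key] = (
            np.empty((batch, topk), dtype=np.int32),            # indexes
            np.empty((batch, RADIX), dtype=np.int32),           # global_histogram
            np.empty((batch, RADIX + 1), dtype=np.int32),       # direct_counter
            np.empty((batch, SMEM_INPUT_SIZE), dtype=np.int32), # candidate_idx
            np.empty((batch,), dtype=np.int32),                 # candidate_count
        )
    indexes, hist, counter, cand_idx, cand_count = _scratch_cache[key]
    for buf in (indexes, hist, counter, cand_count):  # atomic_add targets
        buf.fill(0)
    return _scratch_cache[key]
```

candidate_idx is deliberately left un-zeroed, mirroring the torch.empty allocation in the current code.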

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_v32/topk_selector_3stages.py` around lines 354 - 368, The
current tl_topk function allocates five CUDA tensors (indexes, global_histogram,
direct_counter, candidate_idx, candidate_count) and calls tl_topk_stage1_impl /
tl_topk_stage2_impl / tl_topk_stage3_impl on every invocation, causing per-call
cudaMemsetAsync and JIT lookup overhead; fix by hoisting/memoizing the JIT
kernels (tl_topk_stage1_impl, tl_topk_stage2_impl, tl_topk_stage3_impl) to
module scope or caching them keyed by topk, and change tl_topk to accept
optional preallocated scratch buffers (indexes, global_histogram,
direct_counter, candidate_idx, candidate_count) or pull them from a cache keyed
by (batch, topk, device); when reusing buffers ensure you zero global_histogram,
direct_counter, candidate_count and indexes before each call (since stage1/2 use
atomic_add), and add API docs/comments noting the buffer reuse/zeroing contract.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 621c2ea6-dfa7-4b08-ae96-b0a49f70cd32

📥 Commits

Reviewing files that changed from the base of the PR and between 225ead6 and a569349.

📒 Files selected for processing (1)
  • examples/deepseek_v32/topk_selector_3stages.py

Comment on lines +167 to +183
            if tx < RADIX:
                for i in T.serial(8):
                    offset = 1 << i
                    T.sync_threads(3, RADIX)
                    if tx < RADIX - offset:
                        l_val = s_histogram[tx] + s_histogram[tx + offset]
                    T.sync_threads(3, RADIX)
                    if tx < RADIX - offset:
                        s_histogram[tx] = l_val

                T.sync_threads(3, RADIX)
                if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
                    s_threshold_bin_id[0] = tx
            T.sync_threads()
            l_threshold_bin_id = s_threshold_bin_id[0]
            l_new_topk = l_new_topk - s_histogram[l_threshold_bin_id + 1]
            T.sync_threads()
Contributor

⚠️ Potential issue | 🟡 Minor

s_threshold_bin_id is read without a guaranteed write.

If no tx satisfies s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk (e.g. when the total number of valid elements is ≤ topk, or when all valid elements fall in a single bin so that s_histogram[0] == l_new_topk), s_threshold_bin_id[0] is never written, and the subsequent reads at line 181 / line 269 pick up uninitialized shared memory. The intended (B=1, S=131072) shape doesn't hit this, but a tiny initialization keeps the kernel robust against seq_len < topk callers and starts/ends ranges that exclude most of the row.

🛡️ Proposed fix (apply analogously in stage 3 around line 267)
             T.fill(s_histogram, 0)
             T.sync_threads()
             if tx < RADIX:
                 s_histogram[tx] = global_histogram[bx, tx]
+            if tx == 0:
+                s_threshold_bin_id[0] = 0
             T.sync_threads()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_v32/topk_selector_3stages.py` around lines 167 - 183,
Initialize the shared scalar s_threshold_bin_id before the per-thread reduction
so it is always valid if no thread writes it: set s_threshold_bin_id[0] = 0 (or
another safe default) prior to the "if tx < RADIX:" reduction loop, call
T.sync_threads() after that initialization, then proceed with the existing
reduction and the subsequent read into l_threshold_bin_id and use in l_new_topk;
apply the same initialization + sync pattern in the stage-3 analogous block (the
places referencing s_threshold_bin_id, tx, s_histogram, l_new_topk, and
l_threshold_bin_id).

Comment on lines +196 to +202
                        pos = T.atomic_add(
                            direct_counter[bx, l_bin_id32 + 1], 1, return_prev=True)
                        index[bx, l_bin_offset + pos] = input_idx
                    elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
                        pos = T.atomic_add(candidate_count[bx], 1, return_prev=True)
                        if pos < SMEM_INPUT_SIZE:
                            candidate_idx[bx, pos] = input_idx
Contributor

⚠️ Potential issue | 🔴 Critical

Candidate count is incremented unconditionally — stage 3 can read past s_input_idx.

T.atomic_add(candidate_count[bx], 1, return_prev=True) fires before the pos < SMEM_INPUT_SIZE guard, so when the threshold bucket holds more than SMEM_INPUT_SIZE elements (e.g., highly degenerate inputs where many values share the same fp16 high‑8 bits — large duplicate runs, constant slabs, saturated/inf‑clipped activations) candidate_count[bx] ends up larger than the candidate buffer. Stage 3 then loads that uncapped value as l_num_input (line 277, propagated to s_num_input[0] at line 283) and the round‑0 loop at lines 301–306 indexes s_input_idx[r_idx, s * BLOCK_SIZE + tx] with tx + s * BLOCK_SIZE going up to l_num_input - 1, which is a shared‑memory OOB read past the [2, SMEM_INPUT_SIZE] buffer.

Random inputs from torch.randn won't trigger this, but it's a latent correctness/safety bug for adversarial or quantized inputs. Easiest fix is to clamp on the consumer side (stage 3 line 277):

🛡️ Proposed clamp in stage 3
-            l_num_input = candidate_count[bx]
+            l_num_input = candidate_count[bx]
+            if l_num_input > SMEM_INPUT_SIZE:
+                l_num_input = SMEM_INPUT_SIZE

Or guard the producer so candidate_count stops growing once the buffer is full:

🛡️ Alternative: only count what fits
                     elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
-                        pos = T.atomic_add(candidate_count[bx], 1, return_prev=True)
-                        if pos < SMEM_INPUT_SIZE:
-                            candidate_idx[bx, pos] = input_idx
+                        if candidate_count[bx] < SMEM_INPUT_SIZE:
+                            pos = T.atomic_add(candidate_count[bx], 1, return_prev=True)
+                            if pos < SMEM_INPUT_SIZE:
+                                candidate_idx[bx, pos] = input_idx

Note that the second variant has a benign TOCTOU window so a couple of extra atomic_adds may sneak in — pair it with the stage‑3 clamp to be fully safe. The comment at lines 31-34 ("Assumes the threshold bucket size after the first pass is < 4K elements") should also be tightened or replaced with a hard cap.
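The TOCTOU point generalizes: a guard followed by a separate atomic increment leaves a window where several threads pass the guard before any increment lands. The pure-Python sketch below (the class name BoundedCounter is made up; a lock models what would be a single compare-and-add on the GPU) shows the race-free shape, where the check and the bump form one atomic step:

```python
import threading

class BoundedCounter:
    """Counter that never exceeds cap: check and increment are one atomic step."""

    def __init__(self, cap):
        self.cap = cap
        self.value = 0
        self._lock = threading.Lock()

    def try_claim(self):
        # Unlike `if count < cap: atomic_add(count)` issued as two separate
        # operations, the guard and the increment here cannot interleave with
        # other threads, so value is hard-capped at cap.
        with self._lock:
            if self.value < self.cap:
                pos = self.value
                self.value += 1
                return pos
            return None
```

Since the kernel's two-step guard admits a few extra increments, the consumer-side clamp stays necessary, which is why applying both fixes is recommended.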

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

                        pos = T.atomic_add(
                            direct_counter[bx, l_bin_id32 + 1], 1, return_prev=True)
                        index[bx, l_bin_offset + pos] = input_idx
-                    elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
-                        pos = T.atomic_add(candidate_count[bx], 1, return_prev=True)
-                        if pos < SMEM_INPUT_SIZE:
-                            candidate_idx[bx, pos] = input_idx
+                    elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
+                        if candidate_count[bx] < SMEM_INPUT_SIZE:
+                            pos = T.atomic_add(candidate_count[bx], 1, return_prev=True)
+                            if pos < SMEM_INPUT_SIZE:
+                                candidate_idx[bx, pos] = input_idx
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/deepseek_v32/topk_selector_3stages.py` around lines 196 - 202, The
bug is that candidate_count is incremented unconditionally in the producer
(atomic_add on candidate_count[bx]) which can exceed the candidate buffer and
cause stage_3 to read past s_input_idx; fix by either (A) clamping the consumer
read of l_num_input to at most SMEM_INPUT_SIZE when stage 3 reads
candidate_count (so l_num_input = min(candidate_count[bx], SMEM_INPUT_SIZE)
before using it / before writing s_num_input[0]) or (B) change the producer
logic around candidate_count/candidate_idx to only increment/write when the
measured pos < SMEM_INPUT_SIZE (i.e., perform the atomic_add only if you will
store, or do a compare-and-add pattern so candidate_count never grows beyond
SMEM_INPUT_SIZE); for full safety apply both: cap candidate_count on the
producer and also clamp in the consumer; relevant symbols: candidate_count,
candidate_idx, SMEM_INPUT_SIZE, stage 3 read of l_num_input / s_num_input.
