Conversation

@yzh119
Collaborator

@yzh119 yzh119 commented Nov 20, 2025

📌 Description

This PR implements several features and optimizations for top-k:

Multi-CTA optimization

Follow-up to #2044: this PR optimizes the top-k/top-p filters using multi-CTA optimizations. More specifically, we:

  • split the vocabulary into chunks and let each CTA handle one chunk
  • make sure the logits/probs fit into shared memory within a CTA
  • use global memory to store the data structures needed for cross-CTA synchronization
  • make sure the total number of CTAs (a multiple of num_chunks) does not exceed the number of SMs, using a persistent loop to iterate over rows when the batch size is greater than the number of groups (num_ctas / num_chunks); a sketch of this launch math follows below

The major advantage over the main branch is that logits/probabilities are kept in shared memory, so multiple rounds of scanning barely affect performance.
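As a rough illustration of the launch-configuration math in the list above (a hypothetical sketch, not the actual launcher in include/flashinfer/sampling.cuh; smem_budget_bytes, vec_size, and num_sms are placeholder parameters):

    # Sketch of how chunk size, CTAs per row, and the persistent loop count can be
    # derived under an assumed per-CTA shared-memory budget.
    def plan_multi_cta(vocab_size, batch_size, num_sms, smem_budget_bytes,
                       dtype_bytes=4, vec_size=4):
        max_chunk_elems = smem_budget_bytes // dtype_bytes
        max_chunk_elems -= max_chunk_elems % vec_size        # round down to a VEC_SIZE multiple
        ctas_per_group = -(-vocab_size // max_chunk_elems)   # ceil_div: CTAs cooperating on one row
        chunk_size = -(-vocab_size // ctas_per_group)
        chunk_size += (-chunk_size) % vec_size               # round up to a VEC_SIZE multiple
        num_groups = max(min(num_sms // ctas_per_group, batch_size), 1)
        total_ctas = num_groups * ctas_per_group             # kept <= num_sms
        rows_per_group = -(-batch_size // num_groups)        # persistent-loop iterations per group
        return chunk_size, ctas_per_group, num_groups, total_ctas, rows_per_group

With an illustrative 128 KB budget, plan_multi_cta(128256, 34, 132, 128 * 1024) gives ctas_per_group = 4 and a two-iteration persistent loop, consistent with the 4-chunk split discussed later in this thread.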

Radix Top-K

The number of radix bits is set to 8, i.e., 256 histogram bins per pass.
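For readers unfamiliar with radix selection, a minimal NumPy sketch of the 8-bit digit-by-digit pivot search is shown below; it is purely illustrative: the CUDA kernels in this PR build the 256-bin histograms per CTA in shared memory, merge them across CTAs through RadixRowState, and also handle fp16/bf16.

    # Reference sketch of 8-bit radix top-k selection on a single fp32 row.
    # Assumes 1 <= k <= len(row) and no NaNs.
    import numpy as np

    def float_to_ordered_u32(x):
        # Map fp32 bit patterns to uint32 keys whose unsigned order matches float order.
        u = np.asarray(x, dtype=np.float32).view(np.uint32)
        mask = np.where(u >> 31, np.uint32(0xFFFFFFFF), np.uint32(0x80000000))
        return u ^ mask

    def radix_topk_mask(row, k, radix_bits=8):
        keys = float_to_ordered_u32(row)
        num_bins = 1 << radix_bits
        remaining_k = k
        candidate = np.ones(keys.shape, dtype=bool)      # elements still tied at the pivot
        for shift in range(32 - radix_bits, -1, -radix_bits):
            digits = (keys >> shift) & (num_bins - 1)
            hist = np.bincount(digits[candidate], minlength=num_bins)
            covered = 0
            for b in range(num_bins - 1, -1, -1):        # walk bins from the largest digit down
                if covered + hist[b] >= remaining_k:
                    pivot_digit = b
                    break
                covered += hist[b]
            remaining_k -= covered                       # larger digits are definitely in the top-k
            candidate &= digits == pivot_digit           # only the pivot bin needs further refinement
        pivot_key = keys[candidate][0]                   # all survivors now share the pivot key
        return keys >= pivot_key                         # at least k entries; ties at the pivot kept

For example, radix_topk_mask(np.random.randn(256000).astype(np.float32), 100) keeps the 100 largest entries (plus any exact ties at the pivot), which mirrors the tie behavior the FP16/BF16 tests account for.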

More Top-K APIs and Data Types

  • Data types: previously only fp32 was supported; this PR adds fp16 and bf16.
  • APIs: besides top-k mask logits / renorm probs, we add APIs for raw top-k (returning indices/values) and for fusion with page-table construction (for sparse attention); see the usage sketch below.
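A minimal usage sketch of the new public API (hedged; the exact signature and defaults live in flashinfer/topk.py, and the return_values/sorted parameter names come from the review discussion further below):

    # Hedged usage sketch; the (values, indices) return and the sorted flag follow the
    # descriptions in this PR and may differ slightly from the final API.
    import torch
    import flashinfer

    logits = torch.randn(4, 256000, device="cuda", dtype=torch.bfloat16)
    values, indices = flashinfer.top_k(logits, 100, sorted=True)       # raw top-k
    masked = flashinfer.sampling.top_k_mask_logits(logits, 100)        # existing API, now fp16/bf16 capable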

Speedup

On H100, the speedup over v0.5.2 is available at:
https://docs.google.com/spreadsheets/d/1DO8_11gzv-EUACCY6q4IMIHa8SaYv4q8hJ6gZl-D0mU/edit?usp=sharing (3-14x faster).

On consumer GPUs (e.g., Ada6000), the gap is even larger, e.g., for a small batch size with a large vocabulary:

v0.5.2
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 10, duration: 3004.42 us, effective bandwidth: 2.73 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 100, duration: 3633.15 us, effective bandwidth: 2.25 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 1000, duration: 4258.82 us, effective bandwidth: 1.92 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 5000, duration: 4256.77 us, effective bandwidth: 1.92 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 10, duration: 2376.80 us, effective bandwidth: 3.45 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 100, duration: 3627.01 us, effective bandwidth: 2.26 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 1000, duration: 3945.38 us, effective bandwidth: 2.08 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 5000, duration: 4259.84 us, effective bandwidth: 1.92 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 10, duration: 3316.74 us, effective bandwidth: 2.47 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 100, duration: 3624.96 us, effective bandwidth: 2.26 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 1000, duration: 4566.02 us, effective bandwidth: 1.79 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 5000, duration: 4576.26 us, effective bandwidth: 1.79 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 10, duration: 3003.39 us, effective bandwidth: 2.73 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 100, duration: 4260.86 us, effective bandwidth: 1.92 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 1000, duration: 3947.52 us, effective bandwidth: 2.08 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 5000, duration: 5514.24 us, effective bandwidth: 1.49 GB/s

Multi-CTA optimization (only) 
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 10, duration: 322.56 us, effective bandwidth: 25.40 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 100, duration: 388.10 us, effective bandwidth: 21.11 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 1000, duration: 455.68 us, effective bandwidth: 17.98 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 5000, duration: 455.68 us, effective bandwidth: 17.98 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 10, duration: 257.02 us, effective bandwidth: 31.87 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 100, duration: 388.10 us, effective bandwidth: 21.11 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 1000, duration: 421.89 us, effective bandwidth: 19.42 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 5000, duration: 455.68 us, effective bandwidth: 17.98 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 10, duration: 355.33 us, effective bandwidth: 23.05 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 100, duration: 388.10 us, effective bandwidth: 21.11 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 1000, duration: 486.40 us, effective bandwidth: 16.84 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 5000, duration: 488.45 us, effective bandwidth: 16.77 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 10, duration: 320.51 us, effective bandwidth: 25.56 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 100, duration: 452.61 us, effective bandwidth: 18.10 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 1000, duration: 422.91 us, effective bandwidth: 19.37 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 5000, duration: 585.73 us, effective bandwidth: 13.99 GB/s

Multi-CTA + Radix (this PR)
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 10, duration: 29.70 us, effective bandwidth: 275.86 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 100, duration: 28.67 us, effective bandwidth: 285.71 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 1000, duration: 29.70 us, effective bandwidth: 275.86 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=1), k: 5000, duration: 28.67 us, effective bandwidth: 285.71 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 10, duration: 29.70 us, effective bandwidth: 275.86 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 100, duration: 28.67 us, effective bandwidth: 285.71 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 1000, duration: 29.70 us, effective bandwidth: 275.86 GB/s
vocab_size: 256000, batch_size: 4, distrib: normal_distribution(std=5), k: 5000, duration: 28.67 us, effective bandwidth: 285.71 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 10, duration: 28.67 us, effective bandwidth: 285.71 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 100, duration: 28.90 us, effective bandwidth: 283.50 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 1000, duration: 28.67 us, effective bandwidth: 285.71 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=0.1), k: 5000, duration: 29.66 us, effective bandwidth: 276.16 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 10, duration: 28.67 us, effective bandwidth: 285.71 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 100, duration: 28.67 us, effective bandwidth: 285.71 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 1000, duration: 28.67 us, effective bandwidth: 285.71 GB/s
vocab_size: 256000, batch_size: 4, distrib: gumbel_distribution(beta=1), k: 5000, duration: 28.86 us, effective bandwidth: 283.81 GB/s

The speedup can be as large as 100x.
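For reference, the effective-bandwidth column above is consistent with counting one read plus one write of the fp32 rows:

    # Inferred formula for the "effective bandwidth" column (read + write of a
    # batch_size x vocab_size fp32 tensor); not quoted from the benchmark script.
    batch_size, vocab_size, dtype_bytes, duration_us = 4, 256000, 4, 28.67
    bytes_moved = 2 * batch_size * vocab_size * dtype_bytes     # 8.192 MB
    print(bytes_moved / (duration_us * 1e-6) / 1e9)             # ~285.7 GB/s, as reported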

🔍 Related Issues

#2044

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a radix-based GPU top-k API and a public top_k function (returns values and indices, optional sorting).
  • Improvements

    • New multi-CTA top-k path with per-call row-states buffer for better large-batch throughput; FP32/FP16/BF16 supported.
    • Sampling ops and logits/probs top-k now accept and allocate a row-states buffer; small utility helpers and NVCC flags updated.
    • Exposed topk JIT module and top_k at package level.
  • Tests

    • New and expanded tests for top-k and sampling across dtypes, large batches, and relaxed numeric tolerances.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Contributor

coderabbitai bot commented Nov 20, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Threads a per-row device-side row_states_buffer through Python wrappers, TVM bindings, CUDA callers, and kernels; adds radix multi‑CTA Top‑K kernels, inter‑CTA sync primitives, dtype dispatch helpers, JIT wiring, cache buffer options, and expanded tests for FP32/FP16/BF16 and large‑batch cases.

Changes

Cohort / File(s) Summary
Bindings & callers
csrc/flashinfer_sampling_binding.cu, csrc/renorm.cu, csrc/topk.cu, csrc/flashinfer_topk_binding.cu
Add row_states_buffer parameter to sampling/renorm/topk bindings; update renorm/mask to launch Radix multi‑CTA variants; add radix_topk binding declaration and implement radix_topk caller with optional row‑states handling, dtype dispatch, and CUDA error checks.
Device headers & kernels
include/flashinfer/sampling.cuh
Introduce RadixTopKTraits specializations, OrderedType conversions, RadixRowState, inter‑CTA sync primitives, radix utilities (histogram/suffix/pivot), and multi‑CTA kernel/launcher templates for mask/renorm/topk.
Utility headers
include/flashinfer/utils.cuh
Make ceil_div/round_up constexpr noexcept and add new round_down helper.
Python API & JIT wiring
flashinfer/sampling.py, flashinfer/topk.py, flashinfer/__init__.py, flashinfer/jit/topk.py, flashinfer/jit/core.py
Expose top_k/topk APIs and JIT spec; sampling ops accept row_states_buffer and allocate per‑call cache buffer (≈1MB) when absent; add topk module and adjust NVCC flags.
Utils & dispatch helpers
flashinfer/utils.py, csrc/tvm_ffi_utils.h
_get_cache_buf gains zero_init option; add DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16 macro to dispatch float/half/bf16.
New CUDA top‑k implementation
csrc/topk.cu
Add radix_topk caller with input/output validation, optional row‑states and output‑values handling, dtype dispatch to RadixTopKMultiCTA, and CUDA status checks.
Bindings export file
csrc/flashinfer_topk_binding.cu
New TVM FFI declaration and export for radix_topk (declaration-only in diff).
Tests
tests/utils/test_sampling.py, tests/utils/test_topk.py
Broaden sampling tests across FP32/FP16/BF16; add comprehensive flashinfer.top_k tests (accuracy thresholds, sorted/unsorted, single/large batches, large k) to exercise multi‑CTA and dtype paths.
Integration & scripts
scripts/task_jit_run_tests_part3.sh, flashinfer/aot.py
Run new topk tests in JIT test runner; include gen_topk_module() in AOT module generation.
Consumers
flashinfer/logits_processor/operators.py
Allocate and pass per‑device row_states_buffer (1MB, zeroed) into top_k_renorm_probs/top_k_mask_logits calls.
Tests tolerance updates
tests/utils/test_logits_processor.py
Relaxed exact-equality assertions to approximate comparisons to accommodate floating-point/dtype differences.

Sequence Diagram(s)

sequenceDiagram
    participant User as Python caller
    participant PyAPI as flashinfer.topk / sampling wrapper
    participant JIT as JIT module loader
    participant TVMBind as TVM CUDA binding
    participant Kernel as Radix Multi‑CTA Kernel (device)
    participant RowState as RadixRowState (device buffer)

    User->>PyAPI: call top_k / top_k_renorm_probs / top_k_mask_logits
    alt no row_states_buffer
        PyAPI->>PyAPI: allocate row_states_buffer (~1MB) via _get_cache_buf
    end
    PyAPI->>JIT: ensure topk module loaded / get binding
    JIT->>TVMBind: call radix_topk / module.top_k_* (ptrs, maybe_row_states_buffer, k)
    TVMBind->>TVMBind: validate shapes/dtypes, select device & stream
    TVMBind->>Kernel: launch RadixTopKMultiCTA / RadixTopKRenormProbMultiCTA / RadixTopKMaskLogitsMultiCTA(...)
    Kernel->>RowState: initialize/use per-row RadixRowState, CTAs coordinate via atomics
    Kernel-->>TVMBind: write outputs (indices, values, or renormed/masked arrays)
    TVMBind-->>PyAPI: return results
    PyAPI-->>User: postprocess (dtype conv, optional sort) and return

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

  • Review focus:
    • include/flashinfer/sampling.cuh: multi‑CTA sync primitives, RadixRowState layout, histogram/pivot correctness, memory alignment/padding.
    • csrc/renorm.cu and csrc/topk.cu: dtype dispatch macro usage, optional buffer handling, stream/device selection, and CUDA error propagation.
    • TVM bindings and argument order changes: csrc/flashinfer_sampling_binding.cu, csrc/flashinfer_topk_binding.cu.
    • Python wrappers: flashinfer/sampling.py, flashinfer/topk.py: cache allocation size/semantics, fake-op parity, public signature changes.
    • Tests: thresholds and large-batch test reliability when exercising multi‑CTA paths.

Possibly related PRs

Suggested reviewers

  • cyx-6
  • aleozlx
  • djmmoss
  • wenscarl
  • bkryu
  • kahyunnam
  • yongwww
  • jiahanc

Poem

🐰
I hopped through kernels, soft and fleet,
Row‑states hummed beneath my feet,
CTAs counted, histograms spun,
Radix carrots sorted one by one,
Top‑K grown — a crunchy treat!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 31.43% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'perf: bunch of features and optimizations for top-k (sampling + sparse attention)' clearly describes the main changes: performance improvements and optimizations for top-k operations used in sampling and sparse attention. It is specific and directly relates to the changeset.
Description check ✅ Passed The PR description covers all required sections: detailed explanation of multi-CTA optimization, radix top-K, expanded data types and APIs, performance benchmarks, related issues, and pre-commit/test status. It exceeds template requirements with comprehensive technical details and performance data.


@gemini-code-assist
Contributor

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance enhancement for the top-k/top-p renormalization and sampling routines. By refactoring the CUDA kernels to leverage a multi-CTA architecture and prioritizing shared memory usage for logits and probabilities, the changes aim to drastically reduce memory latency and improve throughput, especially for operations involving multiple scan rounds. This optimization builds upon previous work and focuses on efficient parallel processing across the GPU.

Highlights

  • Multi-CTA Optimization: The top-k/top-p renorm/sampling process has been optimized using a multi-CTA approach, where the vocabulary is split into chunks, and each CTA processes a chunk.
  • Shared Memory Utilization: Logits and probabilities are now ensured to fit into shared memory within each CTA, significantly improving performance by reducing global memory access during multi-round scans.
  • Cross-CTA Synchronization: Global memory is utilized to store data structures necessary for efficient synchronization between CTAs, enabling coordinated processing across the GPU.
  • Dynamic Grid Sizing: The total number of CTAs is dynamically managed to not exceed the number of Streaming Multiprocessors (SMs), with a looping mechanism to iterate over rows when the batch size is larger than the number of available groups.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a multi-CTA optimization for top-k/top-p sampling, which is a significant performance enhancement. The implementation is well-structured, leveraging advanced CUDA features like inter-CTA synchronization via atomic operations and memory fences to efficiently process large vocabularies. The changes in the Python bindings and C++ interface are consistent with the new kernel's requirements. However, I've identified a critical bug in the kernel launch configuration logic within TopKMaskLogitsMultiCTA. The calculation for chunk_size can lead to requesting more shared memory than available, which would cause kernel launch failures under certain conditions. I have provided a detailed comment with a suggested fix for this issue. Overall, this is a great performance improvement, and with the suggested fix, it should be robust.

Comment on lines 2467 to 2476
constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

// Calculate how many CTAs needed per row
uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
// Round up chunk_size to multiple of VEC_SIZE
chunk_size = round_up(chunk_size, VEC_SIZE);
// Ensure minimum chunk size
chunk_size = std::max(chunk_size, min_chunk_size);

critical

The logic for calculating chunk_size is incorrect and can lead to requesting more shared memory than is available, causing a kernel launch failure.

Specifically:

  • max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); at line 2468 can inflate max_chunk_elements beyond the available shared memory.
  • chunk_size = std::max(chunk_size, min_chunk_size); at line 2476 can similarly cause chunk_size to exceed shared memory limits, as it ignores the max_chunk_elements constraint.

This can happen if the available shared memory is small, making max_chunk_elements smaller than min_chunk_size.

I suggest replacing this block with logic that validates against min_chunk_size instead of forcing it, to ensure the kernel configuration is always valid.

      constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
      if (max_chunk_elements < min_chunk_size) {
        // Not enough shared memory for even the minimum chunk size.
        return cudaErrorInvalidConfiguration;
      }

      // Calculate how many CTAs needed per row
      uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
      uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
      // Round up chunk_size to multiple of VEC_SIZE
      chunk_size = round_up(chunk_size, VEC_SIZE);


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 049e8db and 51af95c.

📒 Files selected for processing (6)
  • csrc/flashinfer_sampling_binding.cu (1 hunks)
  • csrc/renorm.cu (1 hunks)
  • flashinfer/sampling.py (3 hunks)
  • flashinfer/utils.py (1 hunks)
  • include/flashinfer/sampling.cuh (1 hunks)
  • include/flashinfer/utils.cuh (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/renorm.cu (1)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
flashinfer/sampling.py (4)
flashinfer/utils.py (3)
  • register_custom_op (319-328)
  • register_custom_op (338-357)
  • _get_cache_buf (206-217)
csrc/flashinfer_sampling_binding.cu (1)
  • top_k_mask_logits (57-59)
csrc/renorm.cu (2)
  • top_k_mask_logits (61-84)
  • top_k_mask_logits (61-63)
flashinfer/logits_processor/operators.py (1)
  • _to_tensor_scalar_tuple (28-34)
🪛 Ruff (0.14.5)
flashinfer/sampling.py

407-407: Unused function argument: row_states_buffer

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment on lines 205 to 216

-def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
+def _get_cache_buf(
+    name: str, bytes: int, device: torch.device, zero_init: bool = False
+) -> torch.Tensor:
     key = (name, device)
     buf = _cache_buf.get(key)
     if buf is None or buf.size(0) < bytes:
-        buf = torch.empty(bytes, dtype=torch.uint8, device=device)
+        if zero_init:
+            buf = torch.zeros(bytes, dtype=torch.uint8, device=device)
+        else:
+            buf = torch.empty(bytes, dtype=torch.uint8, device=device)
         _cache_buf[key] = buf

⚠️ Potential issue | 🔴 Critical

Zero-initialized cache must be cleared on reuse

When zero_init=True we only zero the tensor on first allocation; subsequent reuses skip the zero-fill. The new multi-CTA top-k path stores RowReductionState objects in this cache and assumes every launch starts from a fully cleared buffer. A reused buffer now comes back with stale counters/min/max values, so the first CTA observes non-zero state and the reductions diverge (easy to repro by calling top_k_mask_logits twice on the same device). Please zero the buffer whenever zero_init is requested.

     if buf is None or buf.size(0) < bytes:
         if zero_init:
             buf = torch.zeros(bytes, dtype=torch.uint8, device=device)
         else:
             buf = torch.empty(bytes, dtype=torch.uint8, device=device)
         _cache_buf[key] = buf
+    elif zero_init:
+        buf.zero_()
     return buf
🤖 Prompt for AI Agents
In flashinfer/utils.py around lines 205 to 216, the cache allocator only zeroes
the tensor on first allocation but does not clear reused buffers when
zero_init=True; update the function so that when an existing cached buffer is
found and zero_init is True you explicitly zero it (e.g., buf.zero_() or
buf.fill_(0)) before returning/using it, and keep the existing behavior of
allocating a zeroed tensor for new buffers; ensure the zeroing runs on the
correct device and dtype (torch.uint8).

__device__ __forceinline__ void red_release(int* ptr, int val) {
#if (__CUDA_ARCH__ >= 700)
// SM70 and newer use memory consistency qualifiers
// Release pattern using acq_rel fence + relaxed modifier

Maybe add some more clarification: besides releasing, this also performs a reduction (sum).

int persistent_iteration = 0;

// Calculate total number of iterations for persistent loop
uint32_t num_groups = gridDim.x / ctas_per_group;

Should we add tests that explicitly trigger num_groups > 1?

Collaborator Author


We do have: for vocab_size=128256, the vocabulary is split into 4 chunks (one per SM), so when batch_size is greater than 33, the total number of CTAs required (4 per row) exceeds 132, the number of SMs on Hopper (H100).
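A quick arithmetic check of that claim (a hypothetical back-of-the-envelope; the 132-SM count is H100-specific):

    # With 4 CTAs per row and 132 SMs, at most 132 // 4 = 33 rows run concurrently,
    # so batch_size > 33 forces the persistent loop to iterate over rows.
    ctas_per_group, num_sms = 4, 132
    concurrent_rows = num_sms // ctas_per_group       # 33
    batch_size = 34
    print(batch_size * ctas_per_group)                # 136 CTAs required > 132 SMs available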

@yzh119 yzh119 requested a review from kahyunnam as a code owner December 11, 2025 12:34
@yzh119 yzh119 changed the title perf: using multi-cta optimization for top-k/top-p perf: bunch of features and optimizations for top-k (sampling + sparse attention) Dec 11, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
include/flashinfer/sampling.cuh (1)

2231-2286: Shared‑memory chunk sizing can exceed hardware limits (same issue as earlier review)

In all of the multi‑CTA Top‑K helpers you compute a per‑CTA chunk size from max_smem_per_block, but then forcibly enlarge it to a minimum without re‑checking against the shared‑memory budget:

  • TopKMaskLogitsMultiCTA:
    • max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
    • chunk_size = std::max(chunk_size, min_chunk_size);
  • RadixTopKMaskLogitsMultiCTA
  • RadixTopKRenormProbMultiCTA
  • RadixTopKMultiCTA

On GPUs with relatively small opt‑in shared memory, available_for_* / sizeof(...) can be less than min_chunk_size. Forcing max_chunk_elements and chunk_size up past what fits in max_smem_per_block then leads to:

  • smem_size > max_smem_per_block, and
  • cudaFuncSetAttribute / cudaLaunchKernel failing at runtime with cudaErrorInvalidConfiguration.

Instead of forcing the minimum, you should validate and bail if the minimum is not achievable:

-      const size_t available_for_logits = max_smem_per_block - temp_storage_aligned;
-      uint32_t max_chunk_elements = available_for_logits / sizeof(DType);
-
-      // Round down to multiple of VEC_SIZE
-      max_chunk_elements = round_down(max_chunk_elements, VEC_SIZE);
-
-      // Ensure minimum chunk size for vectorized access
-      constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
-      max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
+      const size_t available_for_logits = max_smem_per_block - temp_storage_aligned;
+      uint32_t max_chunk_elements = available_for_logits / sizeof(DType);
+
+      // Round down to multiple of VEC_SIZE
+      max_chunk_elements = round_down(max_chunk_elements, VEC_SIZE);
+
+      // Ensure we have enough shared memory for at least one chunk
+      constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
+      if (max_chunk_elements < min_chunk_size) {
+        // Not enough shared memory for even a single minimally sized chunk
+        return cudaErrorInvalidConfiguration;
+      }
@@
-      uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
-      uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
-      // Round up chunk_size to multiple of VEC_SIZE
-      chunk_size = round_up(chunk_size, VEC_SIZE);
-      // Ensure minimum chunk size
-      chunk_size = std::max(chunk_size, min_chunk_size);
+      uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
+      uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
+      // Round up chunk_size to multiple of VEC_SIZE; we already know this
+      // rounded size cannot exceed max_chunk_elements.
+      chunk_size = round_up(chunk_size, VEC_SIZE);

Apply the same pattern in:

  • RadixTopKMaskLogitsMultiCTA (use sizeof(OrderedType) and its min_chunk_size),
  • RadixTopKRenormProbMultiCTA,
  • RadixTopKMultiCTA.

This keeps the per‑CTA shared‑memory requirement provably ≤ max_smem_per_block and avoids runtime launch failures on lower‑SMEM parts.

Also applies to: 2915-2963, 3255-3299, 3731-3776

🧹 Nitpick comments (7)
flashinfer/jit/topk.py (1)

17-28: LGTM!

The new JIT module generator follows the established pattern for other modules. The JitSpec import on line 18 appears unused (only gen_jit_spec is called), but this is a minor nit that doesn't affect functionality.

-from .core import JitSpec, gen_jit_spec
+from .core import gen_jit_spec
csrc/topk.cu (1)

27-33: Consider adding output shape validation.

The function validates input dimensions but doesn't verify that output_indices has the expected shape (batch_size, top_k). This could lead to out-of-bounds writes if the caller passes an incorrectly sized tensor.

  CHECK_INPUT(input);
  CHECK_INPUT(output_indices);
  CHECK_DIM(2, input);           // input: (batch_size, d)
  CHECK_DIM(2, output_indices);  // output_indices: (batch_size, top_k)

  unsigned int batch_size = input.size(0);
  unsigned int d = input.size(1);
+
+  TVM_FFI_ICHECK_EQ(output_indices.size(0), batch_size)
+      << "output_indices batch size mismatch";
+  TVM_FFI_ICHECK_EQ(output_indices.size(1), top_k)
+      << "output_indices second dimension should equal top_k";
flashinfer/topk.py (2)

71-192: Tighten top_k API to be a safer drop‑in for torch.topk

Right now top_k doesn’t validate k or input shape, and sorted=True is ignored when return_values=False (indices are returned unsorted). For better compatibility and clearer failures:

  • Check input.dim() == 2 and 0 < k <= input.size(1) and raise a ValueError otherwise.
  • Either:
    • sort indices even when return_values=False, or
    • explicitly disallow sorted=True when return_values=False (e.g., raise), and document that combination.

56-64: Silence lint for unused row_states_buffer / output_values in fake op

In _fake_radix_topk, row_states_buffer and output_values are intentionally unused but trigger ARG001. You can keep the signature while silencing lint by underscoring or explicitly discarding them.

-@register_fake_op("flashinfer::radix_topk")
-def _fake_radix_topk(
-    input: torch.Tensor,
-    top_k: int,
-    row_states_buffer: Optional[torch.Tensor],
-    output_values: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
+@register_fake_op("flashinfer::radix_topk")
+def _fake_radix_topk(
+    input: torch.Tensor,
+    top_k: int,
+    _row_states_buffer: Optional[torch.Tensor],
+    _output_values: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
flashinfer/sampling.py (1)

378-384: Unused row_states_buffer in fake ops can be underscored to placate Ruff

The fake implementations of top_k_renorm_probs and top_k_mask_logits need to keep row_states_buffer in their signature but don’t use it, which triggers ARG001.

-def _fake_top_k_renorm_probs(
-    probs: torch.Tensor,
-    maybe_top_k_arr: Optional[torch.Tensor],
-    top_k_val: int,
-    row_states_buffer: torch.Tensor,
-) -> torch.Tensor:
+def _fake_top_k_renorm_probs(
+    probs: torch.Tensor,
+    maybe_top_k_arr: Optional[torch.Tensor],
+    top_k_val: int,
+    _row_states_buffer: torch.Tensor,
+) -> torch.Tensor:
@@
-def _fake_top_k_mask_logits(
-    logits: torch.Tensor,
-    maybe_top_k_arr: Optional[torch.Tensor],
-    top_k_val: int,
-    row_states_buffer: torch.Tensor,
-) -> torch.Tensor:
+def _fake_top_k_mask_logits(
+    logits: torch.Tensor,
+    maybe_top_k_arr: Optional[torch.Tensor],
+    top_k_val: int,
+    _row_states_buffer: torch.Tensor,
+) -> torch.Tensor:

Also applies to: 413-420

tests/utils/test_topk.py (2)

33-47: verify_topk_correctness is currently unused and has an unused parameter

The helper takes indices but never uses it, and the function itself isn’t referenced by any test. You can either remove it or wire it into the suite; in both cases, consider dropping the unused indices parameter.

-def verify_topk_correctness(logits, values, indices, k):
+def verify_topk_correctness(logits, values, k):
@@
-    for i in range(batch_size):
+    for i in range(batch_size):
         # Get the k-th largest value (ground truth threshold)
         kth_largest = torch.kthvalue(-logits[i], k).values.item() * -1

66-67: Avoid binding unused ref_values in tests

In several tests you unpack torch.topk as (ref_values, ref_indices) but only use ref_indices. Renaming the unused variable (e.g., to _) will silence Ruff’s RUF059 and make intent clearer.

-    ref_values, ref_indices = torch.topk(logits, k, dim=-1)
+    _ref_values, ref_indices = torch.topk(logits, k, dim=-1)
@@
-    ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=True)
+    _ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=True)
@@
-    ref_values, ref_indices = torch.topk(logits, k, dim=-1)
+    _ref_values, ref_indices = torch.topk(logits, k, dim=-1)
@@
-    ref_values, ref_indices = torch.topk(logits, k, dim=-1)
+    _ref_values, ref_indices = torch.topk(logits, k, dim=-1)
@@
-    ref_values, ref_indices = torch.topk(logits, k, dim=-1)
+    _ref_values, ref_indices = torch.topk(logits, k, dim=-1)

Also applies to: 104-105, 138-139, 161-162, 188-189

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 51af95c and b391eb7.

📒 Files selected for processing (13)
  • csrc/flashinfer_sampling_binding.cu (1 hunks)
  • csrc/flashinfer_topk_binding.cu (1 hunks)
  • csrc/renorm.cu (1 hunks)
  • csrc/topk.cu (1 hunks)
  • csrc/tvm_ffi_utils.h (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/jit/core.py (1 hunks)
  • flashinfer/jit/topk.py (1 hunks)
  • flashinfer/sampling.py (9 hunks)
  • flashinfer/topk.py (1 hunks)
  • include/flashinfer/sampling.cuh (3 hunks)
  • tests/utils/test_sampling.py (1 hunks)
  • tests/utils/test_topk.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/sampling.cuh
🧬 Code graph analysis (7)
flashinfer/__init__.py (1)
flashinfer/topk.py (1)
  • top_k (71-192)
csrc/flashinfer_topk_binding.cu (2)
flashinfer/topk.py (2)
  • top_k (71-192)
  • radix_topk (36-54)
csrc/topk.cu (2)
  • radix_topk (24-65)
  • radix_topk (24-26)
tests/utils/test_sampling.py (1)
flashinfer/sampling.py (6)
  • top_k_renorm_probs (356-375)
  • top_k_renorm_probs (1249-1320)
  • softmax (52-72)
  • softmax (522-576)
  • top_k_mask_logits (391-411)
  • top_k_mask_logits (1326-1395)
flashinfer/jit/topk.py (1)
flashinfer/jit/core.py (2)
  • JitSpec (216-315)
  • gen_jit_spec (318-384)
csrc/renorm.cu (1)
csrc/tvm_ffi_utils.h (1)
  • get_stream (289-291)
tests/utils/test_topk.py (1)
flashinfer/topk.py (1)
  • top_k (71-192)
flashinfer/sampling.py (6)
flashinfer/utils.py (1)
  • _get_cache_buf (206-217)
csrc/flashinfer_sampling_binding.cu (2)
  • top_k_renorm_probs (54-56)
  • top_k_mask_logits (58-60)
csrc/renorm.cu (4)
  • top_k_renorm_probs (42-70)
  • top_k_renorm_probs (42-44)
  • top_k_mask_logits (72-101)
  • top_k_mask_logits (72-74)
flashinfer/logits_processor/types.py (2)
  • probs (81-85)
  • logits (74-78)
csrc/tvm_ffi_utils.h (1)
  • Tensor (299-301)
flashinfer/topk.py (1)
  • top_k (71-192)
🪛 Ruff (0.14.8)
flashinfer/jit/core.py

125-128: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

tests/utils/test_topk.py

33-33: Unused function argument: indices

(ARG001)


66-66: Unpacked variable ref_values is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


104-104: Unpacked variable ref_values is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


138-138: Unpacked variable ref_values is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


161-161: Unpacked variable ref_values is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


188-188: Unpacked variable ref_values is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

flashinfer/sampling.py

382-382: Unused function argument: row_states_buffer

(ARG001)


418-418: Unused function argument: row_states_buffer

(ARG001)

flashinfer/topk.py

60-60: Unused function argument: row_states_buffer

(ARG001)


61-61: Unused function argument: output_values

(ARG001)

🔇 Additional comments (14)
csrc/tvm_ffi_utils.h (1)

95-111: LGTM!

The new DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16 macro follows the established pattern of existing dtype dispatchers, correctly adds FP32 support alongside FP16/BF16, and includes proper error handling for unsupported types.

flashinfer/jit/core.py (1)

125-128: LGTM!

The addition of -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED flag aligns with the new multi-CTA kernel requirements. The concatenation pattern is consistent with other similar flag definitions in this file (e.g., sm100a_nvcc_flags, sm103a_nvcc_flags).

flashinfer/__init__.py (1)

144-145: LGTM!

The new topk module import and top_k function re-export follow the established import patterns in this file and correctly expose the new radix-based top-k functionality at the package level.

csrc/flashinfer_topk_binding.cu (1)

1-25: LGTM!

Clean FFI binding file following the established pattern. The forward declaration correctly matches the implementation signature in csrc/topk.cu, and the export registration properly exposes the function.

csrc/topk.cu (1)

48-61: LGTM!

The dtype dispatch and kernel invocation logic is correct. The optional output validation inside the dispatch block properly handles the case when values output is requested.

csrc/renorm.cu (2)

42-70: LGTM!

The top_k_renorm_probs function is correctly updated with the new row_states_buffer parameter and multi-CTA radix-based implementation. Input validation and dtype dispatch are properly handled.


72-101: LGTM!

The top_k_mask_logits function mirrors the changes to top_k_renorm_probs correctly. The implementation properly supports the multi-CTA path with appropriate validation and error handling.

tests/utils/test_sampling.py (2)

450-516: Solid test coverage for multi-dtype support.

The dtype-parameterized test correctly handles the precision differences between FP32 and FP16/BF16:

  • FP32 uses exact ground truth comparison
  • FP16/BF16 uses property-based validation (sum-to-one, non-zero counts with tolerance)

The tolerance calculation max(k // 5, 20) appropriately accounts for ties at the pivot in low precision.


519-584: LGTM!

The test_top_k_mask_logits function properly validates the mask logits functionality across all supported dtypes with appropriate precision-aware assertions.

csrc/flashinfer_sampling_binding.cu (1)

54-60: Signatures correctly extended with row_states_buffer

The added TensorView row_states_buffer parameters on top_k_renorm_probs and top_k_mask_logits match the updated implementations and FFI exports, so the binding surface remains consistent.

flashinfer/sampling.py (2)

353-375: Ensure C++ kernels support all dtypes allowed by Python assertions

The custom ops for top_k_renorm_probs and top_k_mask_logits now assert support for float32, float16, and bfloat16 without casting to float32, and forward tensors directly into the CUDA path.

In csrc/renorm.cu, the launch sites currently dispatch via DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16, whose name suggests only fp32/fp16 support. If that macro hasn’t been extended to cover BF16 as well, BF16 inputs will hit an unsupported-dtype path at runtime.

Please double-check that:

  • The dispatch macro used in top_k_renorm_probs / top_k_mask_logits handles BF16, or
  • If not, either:
    • restrict the Python-side assertion to {float32, float16}, or
    • insert an explicit probs = probs.float() / logits = logits.float() cast for BF16.

Also applies to: 388-410


1307-1320: Row-state buffer sizing and reuse look reasonable

The public wrappers allocate a 1MB cached row_states_buffer per op/device and reuse it across calls. Given RadixRowState is ~2 KB, this leaves room for ~500 concurrent groups, which exceeds num_sms for current architectures and matches the kernel’s num_groups design.

No change requested here; just noting that the sizing and zero_init=True behavior are consistent with the C++ multi-CTA implementation (which resets arrival_counter at the end of each launch).

Also applies to: 1379-1395

tests/utils/test_topk.py (1)

50-229: Top‑k test coverage looks solid

The parametrized tests exercise large vocabularies, multiple batch sizes, large k, all supported dtypes, sorted/unsorted behavior, and a compatibility case vs torch.topk, which is appropriate for this new kernel.

include/flashinfer/sampling.cuh (1)

119-205: Radix type traits for float/half/bfloat16 look correct

The new RadixTopKTraits specializations implement the usual sign‑bit tricks to map FP values to an unsigned ordered domain and back, plus explicit -inf constants for each dtype. The implementation matches standard descending‑order radix‑sort patterns for float, half, and nv_bfloat16.

Comment on lines 26 to 28
# RadixRowState size (histogram[2][256] + remaining_k + prefix + arrival_counter + output_counter)
# = 2*256*4 + 4 + 4 + 4 + 4 = 2064 bytes
RADIX_ROW_STATE_SIZE = 2064

⚠️ Potential issue | 🟡 Minor

RADIX_ROW_STATE_SIZE is out of sync with the C++ RadixRowState struct

RadixRowState in include/flashinfer/sampling.cuh now includes an extra sum_topk field, so its size is 2068 bytes, not 2064. The Python constant and comment still reflect the older layout.

To avoid users under‑allocating custom buffers based on this constant, update it (and the comment) to match the current struct layout.

-# RadixRowState size (histogram[2][256] + remaining_k + prefix + arrival_counter + output_counter)
-# = 2*256*4 + 4 + 4 + 4 + 4 = 2064 bytes
-RADIX_ROW_STATE_SIZE = 2064
+# RadixRowState size (histogram[2][256] + remaining_k + prefix + arrival_counter
+# + output_counter + sum_topk)
+# = 2*256*4 + 4 + 4 + 4 + 4 + 4 = 2068 bytes
+RADIX_ROW_STATE_SIZE = 2068
🤖 Prompt for AI Agents
In flashinfer/topk.py around lines 26 to 28, the RADIX_ROW_STATE_SIZE constant
and its comment are stale because the C++ RadixRowState struct in
include/flashinfer/sampling.cuh added a sum_topk field; update the comment to
show the new size calculation and change RADIX_ROW_STATE_SIZE to 2068 to match
the C++ struct, and verify the byte math matches the struct field sizes so
Python buffer allocations stay in sync with the C++ layout.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/sampling.py (2)

372-404: Fix unused arguments in fake ops and replace dtype assertions with explicit exceptions

The custom ops for top_k_renorm_probs and top_k_mask_logits correctly thread row_states_buffer down to the C++ module and restrict inputs to FP32/FP16/BF16. Two issues need fixing:

  1. Fake ops: unused row_states_buffer argument (Ruff ARG001)
    The fake implementations accept row_states_buffer (required for API parity with the custom ops) but never reference it, triggering linter warnings. Add a trivial reference:
 @register_fake_op("flashinfer::top_k_renorm_probs")
 def _fake_top_k_renorm_probs(
     probs: torch.Tensor,
     maybe_top_k_arr: Optional[torch.Tensor],
     top_k_val: int,
     row_states_buffer: torch.Tensor,
 ) -> torch.Tensor:
+    _ = row_states_buffer  # kept for API parity with the real op
     return torch.empty_like(probs)

 @register_fake_op("flashinfer::top_k_mask_logits")
 def _fake_top_k_mask_logits(
     logits: torch.Tensor,
     maybe_top_k_arr: Optional[torch.Tensor],
     top_k_val: int,
     row_states_buffer: torch.Tensor,
 ) -> torch.Tensor:
+    _ = row_states_buffer  # kept for API parity with the real op
     return torch.empty_like(logits)
  1. Dtype validation: use explicit exceptions instead of assert
    assert can be stripped with python -O and produces less controlled error messages. Replace with explicit TypeError:
 def top_k_renorm_probs(
     probs: torch.Tensor,
     maybe_top_k_arr: Optional[torch.Tensor],
     top_k_val: int,
     row_states_buffer: torch.Tensor,
 ) -> torch.Tensor:
-    assert probs.dtype in [torch.float32, torch.float16, torch.bfloat16], (
-        f"Unsupported dtype {probs.dtype}, expected float32, float16, or bfloat16"
-    )
+    if probs.dtype not in (torch.float32, torch.float16, torch.bfloat16):
+        raise TypeError(
+            f"Unsupported dtype {probs.dtype}, expected float32, float16, or bfloat16"
+        )

 def top_k_mask_logits(
     logits: torch.Tensor,
     maybe_top_k_arr: Optional[torch.Tensor],
     top_k_val: int,
     row_states_buffer: torch.Tensor,
 ) -> torch.Tensor:
-    assert logits.dtype in [torch.float32, torch.float16, torch.bfloat16], (
-        f"Unsupported dtype {logits.dtype}, expected float32, float16, or bfloat16"
-    )
+    if logits.dtype not in (torch.float32, torch.float16, torch.bfloat16):
+        raise TypeError(
+            f"Unsupported dtype {logits.dtype}, expected float32, float16, or bfloat16"
+        )

1254-1313: Define _check_tensor_param and _get_cache_buf, then use them in top-p/top-k APIs to fix undefined names and restore validation

Both top_k_renorm_probs and top_k_mask_logits call _check_tensor_param and _get_cache_buf, but these helpers are missing entirely, causing NameError at runtime. Additionally, top_p_renorm_probs lacks parameter validation that the other APIs have.

Fix this by:

  1. Add _check_tensor_param and _get_cache_buf after _to_tensor_scalar_tuple (which is also undefined)

    Implement _check_tensor_param to validate scalar or per-row tensor sampling parameters:

    def _check_tensor_param(param, probs: torch.Tensor) -> None:
        """Validate scalar or per-row tensor sampling parameters against probs/logits batch size."""
        if isinstance(param, torch.Tensor):
            if param.dim() == 0:
                raise ValueError(
                    "Expected a 1D tensor of shape (batch_size,) or scalar for sampling parameter, "
                    "got a 0-dimensional tensor"
                )
            if param.dim() == 1:
                if param.size(0) != probs.size(0):
                    raise ValueError("Sampling parameter tensor batch size mismatch")
                return
            raise ValueError(
                f"Expected a 1D tensor or scalar for sampling parameter, got a {param.dim()}D tensor"
            )

    Implement _get_cache_buf to manage row_states_buffer allocation.

  2. Add _check_tensor_param calls to all three renorm/mask APIs before calling the module

    For top_p_renorm_probs (lines 1254–1313):

    _check_tensor_param(top_p, probs)
    return get_sampling_module().top_p_renorm_probs(
        probs, *_to_tensor_scalar_tuple(top_p)
    )
  3. Ensure _to_tensor_scalar_tuple is properly defined to convert scalars/tensors for the kernel calls.

This resolves the NameError failures, restores explicit parameter validation for top-p, and ensures all three APIs validate inputs consistently.

♻️ Duplicate comments (1)
flashinfer/utils.py (1)

206-216: Zero-initialized cache buffers must be cleared on reuse to avoid stale row state

_get_cache_buf only zero-fills on first allocation or resize; when a cached buffer is reused with zero_init=True, it is returned without being cleared. The new multi-CTA top‑k paths (both sampling and radix_topk) store per-row RowReductionState objects in this cache and assume a clean state each launch; reusing a “dirty” buffer leads to stale counters/min/max and incorrect results on subsequent calls.

You should zero the buffer whenever it’s reused with zero_init=True:

 def _get_cache_buf(
     name: str, bytes: int, device: torch.device, zero_init: bool = False
 ) -> torch.Tensor:
     key = (name, device)
     buf = _cache_buf.get(key)
     if buf is None or buf.size(0) < bytes:
         if zero_init:
             buf = torch.zeros(bytes, dtype=torch.uint8, device=device)
         else:
             buf = torch.empty(bytes, dtype=torch.uint8, device=device)
         _cache_buf[key] = buf
+    elif zero_init:
+        buf.zero_()
     return buf

This keeps the cache behavior but guarantees callers requesting zero_init=True always get a cleared workspace.

🧹 Nitpick comments (2)
tests/utils/test_sampling.py (1)

450-517: Good per-dtype coverage for top_k_renorm_probs; consider tightening FP16/BF16 nonzero-count tolerance

The FP32 ground-truth path and the FP16/BF16 softmax-based checks look solid and exercise the new kernels well. For the low-precision branch, tolerance = max(k // 5, 20) can be quite loose for small k (e.g., k=10 gives a lower bound below zero), so the nonzero-count assertions effectively only cap the upper side. If you want stronger guarantees that the kernel respects the intended min(k, nonzero_input) behavior, you might consider scaling the tolerance relative to expected_counts or clamping it to something like min(expected_counts.min().item(), max(k // 5, 20)).

csrc/renorm.cu (1)

44-73: Multi-CTA top‑k bindings look correct; consider enforcing device match for row_states_buffer

The switch to RadixTopK*MultiCTA with DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16 and the extra row_states_buffer argument is wired correctly: FP32/FP16/BF16 are dispatched via c_type, scalar/tensor top_k is handled via maybe_top_k_arr vs top_k_val, and the row-state workspace is threaded through as sampling::RadixRowState*.

One minor robustness improvement: CHECK_INPUT(row_states_buffer) ensures CUDA + contiguity, but doesn’t guard against accidentally passing a buffer on a different CUDA device than probs/logits. Adding an explicit device check would make the FFI safer for non-Python callers:

 void top_k_renorm_probs(TensorView probs, TensorView renorm_probs,
                         Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
                         TensorView row_states_buffer) {
   CHECK_INPUT(probs);
-  CHECK_INPUT(row_states_buffer);
+  CHECK_INPUT(row_states_buffer);
+  CHECK_DEVICE(probs, row_states_buffer);
   ...
 }

 void top_k_mask_logits(TensorView logits, TensorView mask_logits,
                        Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
                        TensorView row_states_buffer) {
   CHECK_INPUT(logits);
-  CHECK_INPUT(row_states_buffer);
+  CHECK_INPUT(row_states_buffer);
+  CHECK_DEVICE(logits, row_states_buffer);
   ...
 }

Not blocking for this PR if all current callers pass a same-device buffer, but it will help catch misuses early.

Also applies to: 75-105

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b391eb7 and ca14df1.

📒 Files selected for processing (5)
  • csrc/renorm.cu (3 hunks)
  • csrc/tvm_ffi_utils.h (1 hunks)
  • flashinfer/sampling.py (9 hunks)
  • flashinfer/utils.py (1 hunks)
  • tests/utils/test_sampling.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
flashinfer/utils.py (1)
csrc/fmha_v2/fused_multihead_attention_utils.h (1)
  • zeros (768-768)
tests/utils/test_sampling.py (1)
flashinfer/sampling.py (6)
  • top_k_renorm_probs (375-394)
  • top_k_renorm_probs (1320-1391)
  • softmax (53-73)
  • softmax (524-578)
  • top_k_mask_logits (410-430)
  • top_k_mask_logits (1398-1467)
flashinfer/sampling.py (4)
flashinfer/utils.py (3)
  • register_custom_op (319-328)
  • register_custom_op (338-357)
  • _get_cache_buf (206-217)
csrc/renorm.cu (4)
  • top_k_renorm_probs (44-73)
  • top_k_renorm_probs (44-46)
  • top_k_mask_logits (75-105)
  • top_k_mask_logits (75-77)
csrc/flashinfer_sampling_binding.cu (2)
  • top_k_renorm_probs (54-56)
  • top_k_mask_logits (58-60)
flashinfer/topk.py (1)
  • top_k (71-192)
🪛 GitHub Actions: pre-commit
flashinfer/sampling.py

[error] 1378-1378: mypy: Name '_check_tensor_param' is not defined. (see ruff: F821 undefined name)


[error] 1451-1451: mypy: Name '_check_tensor_param' is not defined. (see ruff: F821 undefined name)

🪛 Ruff (0.14.8)
flashinfer/sampling.py

401-401: Unused function argument: row_states_buffer

(ARG001)


437-437: Unused function argument: row_states_buffer

(ARG001)


1378-1378: Undefined name _check_tensor_param

(F821)


1451-1451: Undefined name _check_tensor_param

(F821)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/utils/test_sampling.py (1)

519-585: top_k_mask_logits tests nicely mirror top_k_renorm_probs behavior across dtypes

The FP32 path’s exact comparison against top_k_renorm_prob and the FP16/BF16 tolerance-based checks (finite counts, sum-to-one after softmax, -inf handling) give good coverage of the new masking kernels. I don’t see functional issues here.

csrc/tvm_ffi_utils.h (1)

96-112: FP32/FP16/BF16 dispatcher macro is consistent with existing patterns

The new DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16 macro cleanly extends the existing FP16/BF16 dispatcher to handle float32_code as float while reusing _DISPATCH_CASE_F16 and _DISPATCH_CASE_BF16 for half/bfloat16. The lambda/return __VA_ARGS__() style matches the other dispatch helpers, and the TVM_FFI_ICHECK default keeps unsupported dtypes failing fast.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
flashinfer/sampling.py (1)

407-430: LGTM: Parameter threading is correct.

The row_states_buffer parameter is properly integrated into top_k_mask_logits, with correct dtype validation and module call wiring.

Minor style note: Line 422 has an extra blank line that could be removed for consistency.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ca14df1 and 7d71381.

📒 Files selected for processing (1)
  • flashinfer/sampling.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/sampling.py (5)
flashinfer/utils.py (3)
  • register_custom_op (319-328)
  • register_custom_op (338-357)
  • _get_cache_buf (206-217)
csrc/flashinfer_sampling_binding.cu (2)
  • top_k_renorm_probs (54-56)
  • top_k_mask_logits (58-60)
csrc/renorm.cu (4)
  • top_k_renorm_probs (44-73)
  • top_k_renorm_probs (44-46)
  • top_k_mask_logits (75-105)
  • top_k_mask_logits (75-77)
flashinfer/logits_processor/operators.py (1)
  • _to_tensor_scalar_tuple (28-34)
flashinfer/topk.py (1)
  • top_k (71-192)
🪛 Ruff (0.14.8)
flashinfer/sampling.py

401-401: Unused function argument: row_states_buffer

(ARG001)


437-437: Unused function argument: row_states_buffer

(ARG001)

🔇 Additional comments (5)
flashinfer/sampling.py (5)

381-384: LGTM: Dtype validation correctly implemented.

The runtime dtype checks appropriately validate FP32, FP16, and BF16 support, matching the C++ dispatcher capabilities shown in the relevant code snippets.

Also applies to: 416-419


372-394: LGTM: Parameter threading is correct.

The row_states_buffer parameter is properly added to the function signature, validated, and passed through to the underlying module call, matching the C++ binding signature from csrc/flashinfer_sampling_binding.cu.


396-403: Static analysis false positive: Unused parameters are expected in fake implementations.

The fake/mock implementations (_fake_top_k_renorm_probs and _fake_top_k_mask_logits) correctly accept the row_states_buffer parameter to match the real function signatures but don't use it, which is the intended behavior for test stubs. The Ruff warnings at lines 401 and 437 can be safely ignored.

Also applies to: 432-439


1378-1389: LGTM: Buffer allocation and initialization are correct.

The 1MB row_states_buffer allocation strategy is well-designed:

  • zero_init=True ensures the arrival_counter starts at 0 for multi-CTA synchronization
  • Device-specific caching avoids cross-device conflicts
  • Function-specific cache keys prevent interference between concurrent operations
  • Buffer size is justified by the PR objectives (sufficient for ~500 groups)

The explanatory comment on lines 1458-1460 helpfully documents the kernel's reset behavior.

Also applies to: 1449-1463


1330-1330: LGTM: Documentation accurately reflects dtype support.

The documentation updates correctly specify the supported input dtypes (FP32, FP16, BF16) and clarify that output dtypes match input dtypes, which represents an important behavioral improvement over the previous FP32-only implementation.

Also applies to: 1342-1342, 1405-1405, 1417-1417

@yzh119 yzh119 requested a review from yongwww as a code owner December 11, 2025 21:26
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
include/flashinfer/sampling.cuh (1)

2824-2906: Verify shared memory bounds in these launchers.

The same shared memory calculation issue applies here as in RadixTopKMaskLogitsMultiCTA (lines 2517-2529). Both launchers force max_chunk_elements = std::max(max_chunk_elements, min_chunk_size), which can exceed available shared memory.

Apply the same validation fix suggested for RadixTopKMaskLogitsMultiCTA to prevent cudaErrorInvalidConfiguration on GPUs with limited shared memory.

Also applies to: 3301-3373

flashinfer/topk.py (1)

26-26: RADIX_ROW_STATE_SIZE constant is out of sync with the C++ RadixRowState struct.

The RadixRowState struct in include/flashinfer/sampling.cuh (lines 1880-1887) now includes a sum_topk field (line 1886), making the struct size 2068 bytes, not 2064. Update this constant to match:

+# RadixRowState size (histogram[2][256] + remaining_k + prefix + arrival_counter
+# + output_counter + sum_topk)
+# = 2*256*4 + 4 + 4 + 4 + 4 + 4 = 2068 bytes
-RADIX_ROW_STATE_SIZE = 2064
+RADIX_ROW_STATE_SIZE = 2068
🧹 Nitpick comments (1)
flashinfer/topk.py (1)

149-149: Remove unused computation.

Line 149 computes input.size(1) but doesn't assign or use the result. This appears to be dead code.

Apply this diff:

-    input.size(1)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7d71381 and 72fea03.

📒 Files selected for processing (3)
  • flashinfer/topk.py (1 hunks)
  • include/flashinfer/sampling.cuh (2 hunks)
  • scripts/task_jit_run_tests_part3.sh (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/sampling.cuh
🧬 Code graph analysis (1)
flashinfer/topk.py (5)
flashinfer/jit/topk.py (1)
  • gen_topk_module (21-28)
flashinfer/utils.py (1)
  • _get_cache_buf (206-217)
flashinfer/jit/core.py (1)
  • build_and_load (303-315)
csrc/flashinfer_topk_binding.cu (1)
  • radix_topk (20-22)
csrc/topk.cu (2)
  • radix_topk (24-65)
  • radix_topk (24-26)
🪛 Ruff (0.14.8)
flashinfer/topk.py

56-56: Unused function argument: row_states_buffer

(ARG001)


57-57: Unused function argument: output_values

(ARG001)

🔇 Additional comments (14)
scripts/task_jit_run_tests_part3.sh (1)

15-15: LGTM!

The new test invocation follows the established pattern and correctly isolates the top-k tests to run separately, consistent with the script's design goal of isolating CUDA memory issues. This aligns well with the PR's expanded test coverage for top-k functionality.

flashinfer/topk.py (4)

52-60: Static analysis false positive: unused parameters in fake op are intentional.

Ruff flags row_states_buffer and output_values as unused (lines 56-57), but this is expected. Fake ops must match the real op's signature for PyTorch's operator registration system, even if they don't use all parameters.


73-148: LGTM! Excellent documentation.

The docstring is comprehensive, well-structured, and clearly documents:

  • Parameter types and meanings
  • Differences from torch.topk (unsorted by default, O(n) complexity)
  • Performance characteristics (optimal for large vocabularies)
  • Clear examples with expected outputs

153-162: LGTM! Buffer allocation strategy is sound.

The 1MB cache buffer allocation with zero_init=True is well-designed:

  • Size calculation (1MB / 2068 bytes ≈ 507 groups) matches the comment's "up to ~500 groups"
  • Zero initialization ensures arrival_counter starts at 0 for first use
  • Using _get_cache_buf avoids repeated allocations

180-192: LGTM! Sorting and return logic are correct.

The implementation properly handles both sorted and unsorted cases:

  • Sorting uses torch.sort with descending=True for values
  • Indices are reordered via torch.gather to match sorted values
  • Alias topk = top_k provides backward compatibility
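
For readers following along, a minimal sketch of the sort-then-gather pattern the bullets above describe (the tensors and names below are illustrative, not the actual flashinfer/topk.py code):

import torch

# Unsorted top-k results as the kernel might return them (illustrative data).
values = torch.tensor([[0.2, 0.9, 0.5],
                       [0.7, 0.1, 0.3]])
indices = torch.tensor([[10, 42, 7],
                        [3, 99, 21]])

# Sort values row-wise in descending order, then reorder indices to match.
sorted_values, order = torch.sort(values, dim=-1, descending=True)
sorted_indices = torch.gather(indices, dim=-1, index=order)

print(sorted_values)   # [[0.9, 0.5, 0.2], [0.7, 0.3, 0.1]]
print(sorted_indices)  # [[42,  7, 10], [ 3, 21, 99]]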
include/flashinfer/sampling.cuh (9)

119-204: LGTM! Type traits implementation is correct.

The RadixTopKTraits specializations properly handle FP32, FP16, and BF16:

  • ToOrdered correctly transforms floats to unsigned integers for descending radix sort using IEEE 754 bit manipulation
  • FromOrdered correctly reverses the transformation
  • NegInf constants are accurate (0xFC00 for FP16, 0xFF80 for BF16)
  • num_rounds computation is generic and correct
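
As an illustration of the bit trick these bullets refer to, here is a small host-side Python sketch of the standard monotone float-to-unsigned mapping (the exact ToOrdered/FromOrdered in sampling.cuh may differ in detail; only the FP32 case is shown):

import struct

def to_ordered(x: float) -> int:
    """Map an FP32 value to a uint32 whose unsigned order matches the float order:
    negatives get all bits flipped, non-negatives just the sign bit."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits ^ 0xFFFFFFFF) if (bits & 0x80000000) else (bits | 0x80000000)

def from_ordered(u: int) -> float:
    """Inverse of to_ordered."""
    bits = (u ^ 0x80000000) if (u & 0x80000000) else (u ^ 0xFFFFFFFF)
    return struct.unpack("<f", struct.pack("<I", bits))[0]

xs = [-3.5, -1.0, 0.0, 0.125, 2.0, 7.25]   # exactly representable in FP32
assert all(from_ordered(to_ordered(x)) == x for x in xs)
assert sorted(xs) == sorted(xs, key=to_ordered)  # unsigned order matches float order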

1801-1821: LGTM!

The launcher function correctly dispatches to the kernel with appropriate template parameters.


1825-1875: LGTM! Multi-CTA synchronization primitives are correctly implemented.

The memory ordering primitives properly implement acquire-release semantics:

  • ld_acquire uses acquire loads on SM70+, with CG fallback
  • red_release correctly uses acq_rel fence before relaxed atomic (addresses earlier review comment with clarification on lines 1844-1845)
  • st_release properly fences before release store
  • wait_ge correctly spins only on thread 0, then synchronizes all threads
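
To make the barrier pattern above concrete, here is a CPU-only Python sketch of the arrival-counter protocol (a condition variable stands in for the acquire/release atomics; the class and names are illustrative, not part of the codebase):

import threading

class ArrivalBarrier:
    """CPU-only stand-in for the arrival-counter barrier described above.

    Each "CTA" adds 1 to a shared counter (red_release) and then waits until the
    counter reaches (phase + 1) * num_ctas (wait_ge). Memory-ordering details of
    the GPU primitives are not modeled here.
    """

    def __init__(self, num_ctas: int) -> None:
        self.num_ctas = num_ctas
        self.counter = 0
        self.cv = threading.Condition()

    def arrive_and_wait(self, phase: int) -> None:
        target = (phase + 1) * self.num_ctas
        with self.cv:
            self.counter += 1                 # red_release(&arrival_counter, 1)
            self.cv.notify_all()
            while self.counter < target:      # the spin in wait_ge, without busy-waiting
                self.cv.wait()

def cta(barrier: ArrivalBarrier, num_phases: int) -> None:
    for phase in range(num_phases):
        # ... per-phase work (e.g., one radix round) would go here ...
        barrier.arrive_and_wait(phase)

barrier = ArrivalBarrier(num_ctas=4)
threads = [threading.Thread(target=cta, args=(barrier, 3)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(barrier.counter)  # 12 arrivals: 4 CTAs x 3 phases; the counter is never reset mid-run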

1880-1887: LGTM! RadixRowState struct layout is well-designed.

The struct correctly includes all necessary fields for multi-CTA coordination:

  • Double-buffered histograms enable barrier-free ping-pong
  • Size is 2068 bytes (matches the corrected constant in Python)
  • sum_topk field supports RenormProb aggregation
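
A quick way to sanity-check the 2068-byte figure is to mirror the described layout with ctypes (field names follow the review text above; this is not the CUDA header itself):

import ctypes

class RadixRowStateMirror(ctypes.Structure):
    """Host mirror of the layout described above: double-buffered 256-bucket
    histograms plus five 4-byte scalars."""
    _fields_ = [
        ("histogram", (ctypes.c_uint32 * 256) * 2),
        ("remaining_k", ctypes.c_uint32),
        ("prefix", ctypes.c_uint32),
        ("arrival_counter", ctypes.c_int32),
        ("output_counter", ctypes.c_int32),
        ("sum_topk", ctypes.c_float),
    ]

assert ctypes.sizeof(RadixRowStateMirror) == 2 * 256 * 4 + 5 * 4 == 2068
print(1024 * 1024 // ctypes.sizeof(RadixRowStateMirror))  # ~507 row states fit in the 1 MB buffer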

1890-2092: LGTM! Radix select helper functions are correctly implemented.

The device functions properly implement the multi-round radix select algorithm:

  • RadixSuffixSum uses parallel stride-doubling reduction
  • RadixFindThresholdBucket correctly identifies the bucket containing the k-th element
  • RadixBuildLocalHistogram accurately counts elements matching the current prefix
  • RadixSelectOneRound properly coordinates single-CTA and multi-CTA paths with appropriate barriers
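
For reference, a single-threaded Python sketch of the multi-round radix select these helpers implement; the real kernel distributes the histogram across threads and CTAs, while this version only shows the per-round logic on already-ordered unsigned keys (see the to_ordered sketch earlier):

def radix_kth_largest(keys, k, radix_bits=8, total_bits=32):
    radix = 1 << radix_bits
    prefix, remaining_k = 0, k
    for rnd in range(total_bits // radix_bits):
        shift = total_bits - (rnd + 1) * radix_bits
        # Bits already fixed by earlier rounds (everything above the current digit).
        mask = ((1 << total_bits) - 1) ^ ((1 << (shift + radix_bits)) - 1)
        hist = [0] * radix                           # RadixBuildLocalHistogram
        for key in keys:
            if (key & mask) == prefix:
                hist[(key >> shift) & (radix - 1)] += 1
        suffix = hist[:]                             # RadixSuffixSum: #keys >= bucket
        for b in range(radix - 2, -1, -1):
            suffix[b] += suffix[b + 1]
        for b in range(radix):                       # RadixFindThresholdBucket
            count_ge = suffix[b]
            count_gt = suffix[b + 1] if b + 1 < radix else 0
            if count_ge >= remaining_k and count_gt < remaining_k:
                prefix |= b << shift
                remaining_k -= count_gt
                break
    return prefix  # ordered representation of the k-th largest key (the pivot)

keys32 = [0x0000FFFF, 0x80000001, 0x80000000, 0x12345678]
assert radix_kth_largest(keys32, k=2) == 0x80000000
assert radix_kth_largest(keys32, k=4) == 0x0000FFFF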

2190-2483: LGTM! Kernel implementation is well-optimized.

The RadixTopKMaskLogitsKernel_MultiCTA implementation demonstrates excellent optimization:

  • Single load of data into shared memory (Stage 1) avoids repeated global memory access
  • Double-buffering histograms eliminates second barrier per round
  • Persistent kernel pattern efficiently handles large batch sizes
  • Fast path for k >= vocab_size avoids unnecessary computation

2517-2529: Verify shared memory bounds before launching kernel.

Lines 2517-2519 force max_chunk_elements to at least min_chunk_size, which can exceed available shared memory. Similarly, line 2523's chunk_size calculation may exceed limits. This could cause kernel launch failures with cudaErrorInvalidConfiguration when max_smem_per_block is small.

Consider adding validation logic to prevent configuration errors:

const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
if (available_for_ordered < min_chunk_size * sizeof(OrderedType)) {
  // Not enough shared memory for even the minimum chunk size
  return cudaErrorInvalidConfiguration;
}

// Then proceed with existing logic, knowing min_chunk_size will fit
uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
max_chunk_elements = (max_chunk_elements / vec_size) * vec_size;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
// ... rest of calculation

Note: This concern was also raised in a previous review but may not have been addressed yet.


2579-2822: LGTM! RenormProb kernel correctly implements normalization.

The RadixTopKRenormProbKernel_MultiCTA properly:

  • Finds the pivot using the shared RadixSelectFindPivot function
  • Aggregates sum of top-k elements across CTAs via atomicAdd to state->sum_topk
  • Normalizes elements >= pivot and zeros others
  • Handles k >= vocab_size with a fast normalization path
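
The renormalization step reduces to a simple masked rescale once the pivot is known; a small torch sketch (illustrative, not the kernel):

import torch

def renorm_with_pivot(probs: torch.Tensor, pivot: torch.Tensor) -> torch.Tensor:
    """Keep entries >= pivot, zero the rest, rescale the kept mass to sum to one per row."""
    keep = probs >= pivot.unsqueeze(-1)
    kept = torch.where(keep, probs, torch.zeros_like(probs))
    return kept / kept.sum(dim=-1, keepdim=True)

probs = torch.tensor([[0.05, 0.40, 0.25, 0.30]])
pivot = torch.tensor([0.25])                 # k-th largest prob for k = 3
out = renorm_with_pivot(probs, pivot)
print(out)            # [[0.0000, 0.4211, 0.2632, 0.3158]]
print(out.sum(-1))    # [1.]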

2916-3286: LGTM! Index collection kernel handles ties correctly.

The RadixTopKKernel_MultiCTA implementation uses a smart two-pass approach:

  • Pass 1 collects all elements > pivot with optimized local counting + single global atomic per CTA
  • Pass 2 fills remaining slots with elements == pivot using global atomic coordination
  • This ensures ties are handled correctly and respects the k limit
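
A minimal Python sketch of the two-pass tie handling described above (sequential; the kernel does the same thing with shared/global atomic counters):

def collect_topk_indices(values, k, pivot):
    """Pass 1: every element strictly above the pivot is in the top-k.
    Pass 2: ties at the pivot fill the remaining slots until k entries exist."""
    out = [i for i, v in enumerate(values) if v > pivot]
    for i, v in enumerate(values):
        if len(out) == k:
            break
        if v == pivot:
            out.append(i)
    return out

values = [0.9, 0.5, 0.5, 0.1, 0.5, 0.7]
print(collect_topk_indices(values, k=3, pivot=0.5))  # [0, 5, 1]: only one of the 0.5 ties is kept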

@yzh119
Collaborator Author

yzh119 commented Dec 11, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !190 has been created, and the CI pipeline #40056862 is currently running. I'll report back once the pipeline job completes.

@yzh119
Collaborator Author

yzh119 commented Dec 11, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !190 has been updated with latest changes, and the CI pipeline #40058924 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

♻️ Duplicate comments (1)
include/flashinfer/sampling.cuh (1)

2486-2554: Shared‑memory chunk sizing can exceed device limits (same issue in all three multi‑CTA launchers)

In RadixTopKMaskLogitsMultiCTA, RadixTopKRenormProbMultiCTA, and RadixTopKMultiCTA, the shared‑mem sizing pattern is:

const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
max_chunk_elements = round_down(max_chunk_elements, vec_size);
const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
chunk_size = round_up(chunk_size, vec_size);
chunk_size = std::min(chunk_size, max_chunk_elements);

const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType);

If available_for_ordered / sizeof(OrderedType) < min_chunk_size, max_chunk_elements gets forcibly promoted to min_chunk_size, ignoring the real shared‑memory limit. That can make smem_size > max_smem_per_block, causing cudaFuncSetAttribute / kernel launch to fail with cudaErrorInvalidConfiguration on devices with smaller dynamic shared memory.

The safer pattern is:

  • Treat min_chunk_size as a requirement, not something to enforce via std::max.
  • If the hardware doesn’t have enough shared memory to fit even min_chunk_size, return an error early instead of over‑requesting.

For example, factor this into a helper and use it in all three launchers:

constexpr uint32_t BLOCK_THREADS = 1024;
using OrderedType = typename RadixTopKTraits<DType>::OrderedType;

const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size);
constexpr size_t fixed_smem_aligned = /* existing fixed overhead, already rounded up */;

const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;

uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
max_chunk_elements = round_down(max_chunk_elements, vec_size);

// Not enough shared mem even for minimum viable chunk: bail out.
if (max_chunk_elements < min_chunk_size) {
  return cudaErrorInvalidConfiguration;
}

// Normal path: choose chunk_size ≤ max_chunk_elements.
uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
chunk_size = round_up(chunk_size, vec_size);
chunk_size = std::min(chunk_size, max_chunk_elements);

const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType);

This addresses the previously raised concern and keeps the three launchers consistent.

Please confirm against cudaDeviceGetAttribute(cudaDevAttrMaxSharedMemoryPerBlockOptin, ...) for your lowest‑SMEM target (e.g., older consumer GPUs) that the revised logic never requests more than the allowed dynamic shared memory.

Also applies to: 2809-2878, 3273-3341

🧹 Nitpick comments (2)
flashinfer/topk.py (1)

52-60: Silence Ruff ARG001 in fake op by explicitly “using” the extra arguments

row_states_buffer and output_values in _fake_radix_topk are intentionally unused but trigger Ruff ARG001. To keep signatures aligned with the real op while satisfying linting, you can add a trivial use:

@register_fake_op("flashinfer::radix_topk")
def _fake_radix_topk(
    input: torch.Tensor,
    top_k: int,
    row_states_buffer: Optional[torch.Tensor],
    output_values: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+    # Mark unused arguments as used to satisfy linters; required by the op schema.
+    _ = row_states_buffer, output_values
     batch_size = input.size(0)
     return torch.empty(batch_size, top_k, dtype=torch.int32, device=input.device)
flashinfer/sampling.py (1)

396-404: Fake sampling ops: unused row_states_buffer argument (Ruff ARG001)

_fake_top_k_renorm_probs and _fake_top_k_mask_logits must accept row_states_buffer to mirror the real kernels, but Ruff flags them as unused. You can placate the linter without changing behavior:

@register_fake_op("flashinfer::top_k_renorm_probs")
def _fake_top_k_renorm_probs(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    row_states_buffer: torch.Tensor,
) -> torch.Tensor:
-    return torch.empty_like(probs)
+    _ = row_states_buffer  # required by op schema; unused in fake implementation
+    return torch.empty_like(probs)

(and similarly for _fake_top_k_mask_logits).

Also applies to: 432-439

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4532d88 and 566b432.

📒 Files selected for processing (5)
  • flashinfer/aot.py (2 hunks)
  • flashinfer/sampling.py (9 hunks)
  • flashinfer/topk.py (1 hunks)
  • include/flashinfer/sampling.cuh (2 hunks)
  • tests/utils/test_sampling.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/sampling.cuh
🧬 Code graph analysis (3)
flashinfer/topk.py (3)
flashinfer/jit/topk.py (1)
  • gen_topk_module (21-28)
csrc/flashinfer_topk_binding.cu (1)
  • radix_topk (20-22)
csrc/topk.cu (2)
  • radix_topk (24-65)
  • radix_topk (24-26)
flashinfer/aot.py (1)
flashinfer/jit/topk.py (1)
  • gen_topk_module (21-28)
tests/utils/test_sampling.py (1)
flashinfer/sampling.py (6)
  • softmax (53-73)
  • softmax (524-578)
  • top_k_renorm_probs (375-394)
  • top_k_renorm_probs (1320-1390)
  • top_k_mask_logits (410-430)
  • top_k_mask_logits (1397-1465)
🪛 Ruff (0.14.8)
flashinfer/topk.py

56-56: Unused function argument: row_states_buffer

(ARG001)


57-57: Unused function argument: output_values

(ARG001)

flashinfer/sampling.py

401-401: Unused function argument: row_states_buffer

(ARG001)


437-437: Unused function argument: row_states_buffer

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (4)
flashinfer/aot.py (1)

70-71: Top‑k JIT module correctly integrated into AOT misc set

Importing gen_topk_module and adding it to the add_misc list is consistent with how other JIT modules are wired, so top‑k will be built both under JIT and AOT flows. No issues from this change.

Also applies to: 524-533

tests/utils/test_sampling.py (2)

362-393: Relaxed logits/probs alignment threshold for top‑k/top‑p sampling is reasonable

Lowering the required match rate to 95% and explicitly documenting both the softmax‑ordering differences and k‑th‑position ties makes this test more robust to numerics without weakening its intent (catching gross divergences between logits and probs paths). The improved failure message with mismatch counts is also helpful.


451-514: New top‑k renorm/mask tests give strong coverage across dtypes, distributions, and ties

test_top_k_renorm_probs and test_top_k_mask_logits look solid:

  • They drive large batch/vocab sizes, three distributions, and {fp32, fp16, bf16}.
  • They assert output dtype matches input and that renormalized probabilities/softmax(masked_logits) still sum to 1 within a reasonable tolerance.
  • Tie‑aware logic (pivot, num_greater, num_ties) checks that the number of non‑zero / finite elements stays within the mathematically valid range (≥ expected k, ≤ k plus ties), accounting for underflow with a ±1 slack.

This should catch a wide range of regressions in the multi‑CTA radix top‑k implementation.

Also applies to: 516-597
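
The tie-aware bound the tests use can be written compactly in torch; a sketch (the helper name is illustrative, and the real tests additionally allow a ±1 slack for FP16/BF16 underflow):

import torch

def surviving_count_bounds(row: torch.Tensor, k: int):
    """Valid range for the number of entries a tie-aware top-k filter may keep:
    at least k, at most num_greater + num_ties (every pivot tie may survive)."""
    pivot = torch.topk(row, k).values.min()
    num_greater = int((row > pivot).sum())
    num_ties = int((row == pivot).sum())
    return k, num_greater + num_ties

row = torch.tensor([0.3, 0.7, 0.7, 0.1, 0.7, 0.9])
print(surviving_count_bounds(row, k=3))  # (3, 4): pivot 0.7 appears three times, 0.9 is above it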

include/flashinfer/sampling.cuh (1)

119-205: RadixTopKTraits specializations look correct for fp32/fp16/bf16

The RadixTopKTraits implementations (ordered key mapping, num_rounds, and NegInf) for float, half, and nv_bfloat16 follow the standard sign‑bit trick and IEEE encodings (0xFC00 for fp16 −inf, 0xFF80 for bf16 −inf). This gives you a clean, type‑generic radix pipeline across all three dtypes.

Comment on lines +372 to 395
@register_custom_op(
    "flashinfer::top_k_renorm_probs", mutates_args=("row_states_buffer",)
)
def top_k_renorm_probs(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    row_states_buffer: torch.Tensor,
) -> torch.Tensor:
    probs = probs.float()
    # Support FP32, FP16, BF16
    assert probs.dtype in [torch.float32, torch.float16, torch.bfloat16], (
        f"Unsupported dtype {probs.dtype}, expected float32, float16, or bfloat16"
    )
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    module.top_k_renorm_probs(
        probs,
        renorm_probs,
        maybe_top_k_arr,
        top_k_val,
        row_states_buffer,
    )
    return renorm_probs

Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's understand the structure and find the relevant files
git ls-files | grep -E "(sampling|cache)" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 767


🏁 Script executed:

# Check the file size of sampling.py to determine reading strategy
wc -l flashinfer/sampling.py

Repository: flashinfer-ai/flashinfer

Length of output: 95


🏁 Script executed:

# Read the custom op wrapper at lines 372-395
sed -n '360,400p' flashinfer/sampling.py

Repository: flashinfer-ai/flashinfer

Length of output: 1442


🏁 Script executed:

# Read the other mentioned sections (407-440, 1379-1390, 1451-1465)
sed -n '407,440p' flashinfer/sampling.py

Repository: flashinfer-ai/flashinfer

Length of output: 1216


🏁 Script executed:

sed -n '1379,1390p' flashinfer/sampling.py

Repository: flashinfer-ai/flashinfer

Length of output: 484


🏁 Script executed:

sed -n '1451,1465p' flashinfer/sampling.py

Repository: flashinfer-ai/flashinfer

Length of output: 645


🏁 Script executed:

# Find _get_cache_buf implementation
rg "_get_cache_buf" -A 10 -B 2

Repository: flashinfer-ai/flashinfer

Length of output: 20383


🏁 Script executed:

# Search for buffer cache key structure and device/stream handling
rg "row_states" -B 3 -A 3

Repository: flashinfer-ai/flashinfer

Length of output: 21881


🏁 Script executed:

# Check if there are existing tests for concurrent stream usage
fd "test.*sampling" -o fd "sampling.*test"

Repository: flashinfer-ai/flashinfer

Length of output: 191


🏁 Script executed:

# Search for arrival_counter and RadixRowState definition
rg "arrival_counter|RadixRowState" -A 5 -B 2 include/

Repository: flashinfer-ai/flashinfer

Length of output: 19138


🏁 Script executed:

# Check the exact behavior of arrival_counter reset
rg "arrival_counter" -A 3 -B 3

Repository: flashinfer-ai/flashinfer

Length of output: 11160


🏁 Script executed:

# Look at test file for concurrent stream testing
cat tests/utils/test_sampling.py | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 3562


🏁 Script executed:

# Search for any stream-related testing in sampling tests
rg "stream|concurrent" tests/utils/test_sampling.py -B 2 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there are any comments or documentation about stream safety
rg "stream|concurrent|thread|sync" flashinfer/sampling.py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check the RadixRowState structure definition
rg "struct RadixRowState|class RadixRowState" -A 20

Repository: flashinfer-ai/flashinfer

Length of output: 1665


🏁 Script executed:

# Verify lines 1379-1390 and 1451-1465 contain the high-level wrappers
sed -n '1375,1395p' flashinfer/sampling.py

Repository: flashinfer-ai/flashinfer

Length of output: 655


🏁 Script executed:

sed -n '1447,1470p' flashinfer/sampling.py

Repository: flashinfer-ai/flashinfer

Length of output: 829


🏁 Script executed:

# Check what those functions are
rg "^def top_k_renorm_probs|^def top_k_mask_logits" -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 394


Shared row_states_buffer in sampling top‑k APIs has concurrency risk with multiple CUDA streams

Both the custom op wrappers and high-level sampling APIs for top-k operations share a single 1 MB row_states_buffer per device:

  • Custom ops top_k_renorm_probs (lines 372–395) and top_k_mask_logits (lines 407–440) accept row_states_buffer and pass it to CUDA kernels, which use it for inter-CTA synchronization via RadixRowState (histogram, arrival_counter, sum_topk, etc.).
  • High-level APIs allocate this buffer via _get_cache_buf("..._row_states_{device}", 1MB, device, zero_init=True) at lines 1379–1390 and 1451–1465, keyed only by device.

The buffer cache in _get_cache_buf is keyed by (name, device) only, not by stream. This means all concurrent kernel launches on the same device that access the same row_states_buffer share one RadixRowState[] array. When kernels launch concurrently on different CUDA streams, their CTAs can interleave reads and writes to arrival_counter, histogram[], output_counter, and sum_topk, causing data corruption.

The kernel does reset arrival_counter at completion (via st_release), but this happens only after the kernel finishes. Concurrent launches see the buffer in a partially-updated state.

Recommended mitigations:

  • Key the buffer cache by both (name, device, stream) to isolate concurrent launches, or
  • Document that these APIs are single-stream only and add assertions/tests to catch multi-stream usage, or
  • Allocate buffers per-call when concurrency is required instead of caching per-device.

If your users commonly parallelize sampling across multiple streams, add a regression test that launches top_k_renorm_probs and top_k_mask_logits concurrently on distinct streams and verifies outputs against a single-stream baseline.

Also applies to: lines 407–440, 1379–1390, 1451–1465
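
A hypothetical sketch of the first mitigation, keying the cached buffer by stream as well as by name and device (the helper, cache dict, and key layout below are illustrative; the real cache lives in flashinfer/utils.py and may need a different zeroing/eviction policy):

import torch

_buf_cache: dict = {}

def get_row_states_buffer(name: str, nbytes: int, device: torch.device) -> torch.Tensor:
    # Include the current CUDA stream in the cache key so concurrent launches on
    # different streams never share the same RadixRowState array.
    stream_id = torch.cuda.current_stream(device).cuda_stream if device.type == "cuda" else 0
    key = (name, device, stream_id)
    buf = _buf_cache.get(key)
    if buf is None:
        # Zero-initialized so arrival_counter starts at 0 for the first launch on this stream.
        buf = torch.zeros(nbytes, dtype=torch.uint8, device=device)
        _buf_cache[key] = buf
    return buf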

🤖 Prompt for AI Agents
In flashinfer/sampling.py around lines 372-395, the top_k_renorm_probs wrapper
(and related top_k_mask_logits and the high-level APIs at lines 1379–1390 and
1451–1465) accepts a shared 1MB row_states_buffer that is cached per (name,
device) only, which causes data races when kernels are launched concurrently on
different CUDA streams; fix by changing the buffer-cache keying to include the
CUDA stream (e.g., (name, device, stream)) so each simultaneous launch gets an
isolated row_states_buffer, or alternatively stop caching per-device and
allocate a fresh buffer per-call when stream isolation is required, and add a
regression test that launches the top-k kernels on multiple streams and verifies
outputs against a single-stream baseline; if choosing not to change allocation,
add clear documentation and runtime assertions preventing multi-stream usage.

Comment on lines 155 to 164
    # Allocate row_states buffer for multi-CTA path
    # For single-CTA path this buffer is not used but we always allocate for simplicity
    # 1MB is enough for any reasonable GPU (covers up to ~500 groups)
    # zero_init=True ensures arrival_counter starts at 0 on first use
    row_states_buffer: Optional[torch.Tensor] = _get_cache_buf(
        f"radix_topk_row_states_{input.device}",
        1024 * 1024,  # 1MB
        input.device,
        zero_init=True,
    )
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat flashinfer/topk.py

Repository: flashinfer-ai/flashinfer

Length of output: 6884


Shared row_states_buffer can race across concurrent calls on the same device

row_states_buffer is cached per device via _get_cache_buf (line 154), keyed only by device without stream specificity. All top_k invocations on a given GPU share the same 1 MB buffer. Since the radix_topk kernel mutates this buffer for inter-CTA synchronization (line 31), concurrent launches on different CUDA streams can trample each other's state and corrupt results.

Consider either:

  • Making the cache key stream- or call-specific, or
  • Documenting and enforcing single-stream usage, or
  • Allocating a per-call buffer when heavy concurrency is expected.

Additionally, input.size(1) (line 154) has its result unused, and sorted=True is silently ignored when return_values=False (sorting only applied at lines 176–177).

Comment on lines 2883 to 3258
* \brief Multi-CTA Radix Top-K kernel that returns indices of top-k elements.
*
* Uses cooperative multi-CTA radix select to find the k-th largest element,
* then collects indices of all elements >= pivot.
*/
template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, bool SINGLE_CTA, typename DType,
typename IdType>
__global__ void __launch_bounds__(BLOCK_THREADS)
RadixTopKKernel_MultiCTA(DType* input, // [batch, vocab_size]
IdType* output_indices, // [batch, top_k]
DType* output_values, // [batch, top_k] or nullptr
IdType* top_k_arr, // [batch] or nullptr
uint32_t top_k_val, uint32_t vocab_size, uint32_t batch_size,
RadixRowState* row_states, // [num_groups] (nullptr if SINGLE_CTA)
uint32_t chunk_size, // elements per CTA
uint32_t ctas_per_group) // CTAs per row (1 if SINGLE_CTA)
{
// Type traits for FP16/BF16/FP32 support
using Traits = RadixTopKTraits<DType>;
using OrderedType = typename Traits::OrderedType;

constexpr uint32_t RADIX = 256;
constexpr uint32_t RADIX_BITS = 8;
constexpr uint32_t NUM_ROUNDS = Traits::template num_rounds<RADIX_BITS>();
constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8;

const uint32_t global_cta_id = blockIdx.x;
const uint32_t group_id = global_cta_id / ctas_per_group;
const uint32_t cta_in_group = global_cta_id % ctas_per_group;
const uint32_t tx = threadIdx.x;

// Shared memory layout: [fixed storage] [ordered values cache]
extern __shared__ uint8_t smem[];

// Fixed shared memory (at the beginning)
// When SINGLE_CTA, we need an extra uint32 for output_counter (no global state)
constexpr size_t num_scalars = SINGLE_CTA ? 5 : 4;
constexpr size_t fixed_smem_size =
sizeof(uint32_t) * (RADIX + RADIX + num_scalars); // histogram + suffix + scalars
uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem);
uint32_t* suffix_sum = local_histogram + RADIX;
uint32_t* shared_scalars = suffix_sum + RADIX; // [prefix_cache, remaining_k_cache, found_bucket,
// found_remaining_k, (output_counter)]

// Align ordered values cache to 16 bytes
size_t ordered_offset = ((fixed_smem_size + 15) / 16) * 16;
OrderedType* shared_ordered = reinterpret_cast<OrderedType*>(smem + ordered_offset);

// Aliases for scalar shared variables
#define prefix_cache shared_scalars[0]
#define remaining_k_cache shared_scalars[1]
#define found_bucket shared_scalars[2]
#define found_remaining_k shared_scalars[3]
#define shared_output_counter shared_scalars[4] // Only valid when SINGLE_CTA

// State pointer only used when not SINGLE_CTA
RadixRowState* state = nullptr;
if constexpr (!SINGLE_CTA) {
state = &row_states[group_id];
}

// Calculate total number of iterations for persistent loop
uint32_t num_groups = gridDim.x / ctas_per_group;
uint32_t total_iterations = (batch_size + num_groups - 1) / num_groups;

int barrier_phase = 0;

// Persistent loop over rows
for (uint32_t iter = 0; iter < total_iterations; iter++) {
uint32_t row_idx = group_id + iter * num_groups;

if (row_idx >= batch_size) break;

const uint32_t chunk_start = cta_in_group * chunk_size;
const uint32_t chunk_end = min(chunk_start + chunk_size, vocab_size);

uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];

const uint32_t actual_chunk_size = chunk_end - chunk_start;

if (k >= vocab_size) {
// k >= vocab_size: return all indices
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
if (chunk_start + i < k) {
output_indices[row_idx * top_k_val + chunk_start + i] =
static_cast<IdType>(chunk_start + i);
if (output_values != nullptr) {
output_values[row_idx * top_k_val + chunk_start + i] =
input[row_idx * vocab_size + chunk_start + i];
}
}
}
continue;
}

// ========== Stage 1: Load and convert to ordered representation in shared memory ==========
vec_t<DType, VEC_SIZE> input_vec;
const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE;

#pragma unroll 2
for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) {
input_vec.cast_load(input + row_idx * vocab_size + chunk_start + i);
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
shared_ordered[i + j] = Traits::ToOrdered(input_vec[j]);
}
}
// Handle tail
for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) {
shared_ordered[i] = Traits::ToOrdered(input[row_idx * vocab_size + chunk_start + i]);
}
__syncthreads();

// Initialize local caches and clear global state
if (tx == 0) {
prefix_cache = 0;
remaining_k_cache = k;
if constexpr (SINGLE_CTA) {
shared_output_counter = 0; // Use shared memory counter for single CTA
}
}
// Clear global histograms (only needed for multi-CTA)
if constexpr (!SINGLE_CTA) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
state->histogram[0][i] = 0;
state->histogram[1][i] = 0;
}
}
__syncthreads();

// Barrier to ensure all CTAs have arrived at this iteration (skip for single CTA)
if constexpr (!SINGLE_CTA) {
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
int target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
__syncthreads();

// CTA 0 clears output counter AFTER barrier
if (cta_in_group == 0 && tx == 0) {
st_release(&state->output_counter, 0);
}
__syncthreads();
}

// ========== Stage 2: NUM_ROUNDS of radix select ==========
// Using double-buffering: round N uses histogram[N % 2]
// Round N clears histogram[(N+1) % 2] for next round's use
for (uint32_t round = 0; round < NUM_ROUNDS; ++round) {
uint32_t shift = ORDERED_BITS - (round + 1) * RADIX_BITS;
// Read from local cache (no global memory access needed!)
uint32_t prefix = prefix_cache;
uint32_t remaining_k = remaining_k_cache;

// For multi-CTA: pointers to global histograms
// For single-CTA: these are not used
uint32_t* current_hist = nullptr;
uint32_t* other_hist = nullptr;
if constexpr (!SINGLE_CTA) {
current_hist = state->histogram[round % 2];
other_hist = state->histogram[(round + 1) % 2];
}

// Clear local histogram AND (for multi-CTA) clear the "other" global histogram
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
local_histogram[i] = 0;
if constexpr (!SINGLE_CTA) {
other_hist[i] = 0; // Prepare for next round (no barrier needed!)
}
}
__syncthreads();

// Build local histogram from SHARED MEMORY (no global memory access!)
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
OrderedType ordered = shared_ordered[i];

// Check if this element matches the prefix (high bits determined so far)
OrderedType mask =
(round == 0)
? OrderedType(0)
: static_cast<OrderedType>(~OrderedType(0) << (ORDERED_BITS - round * RADIX_BITS));
if ((ordered & mask) == static_cast<OrderedType>(prefix)) {
uint32_t bucket = (ordered >> shift) & 0xFF;
atomicAdd(&local_histogram[bucket], 1);
}
}
__syncthreads();

// For multi-CTA: add to global histogram and barrier
// For single-CTA: local_histogram is already the complete histogram
if constexpr (!SINGLE_CTA) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
if (local_histogram[i] > 0) {
atomicAdd(&current_hist[i], local_histogram[i]);
}
}

// Barrier: wait for all CTAs to finish histogram accumulation
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
int target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
__syncthreads();

// Load from global histogram to suffix_sum
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
suffix_sum[i] = current_hist[i];
}
} else {
// Single-CTA: copy local histogram directly to suffix_sum
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
suffix_sum[i] = local_histogram[i];
}
}
__syncthreads();

// Parallel suffix sum in shared memory
for (uint32_t stride = 1; stride < RADIX; stride *= 2) {
uint32_t val = 0;
if (tx < RADIX) {
val = suffix_sum[tx];
if (tx + stride < RADIX) {
val += suffix_sum[tx + stride];
}
}
__syncthreads();
if (tx < RADIX) {
suffix_sum[tx] = val;
}
__syncthreads();
}

// ALL CTAs: find threshold bucket (all compute same result)
if (tx == 0) {
found_bucket = 0;
found_remaining_k = remaining_k;
}
__syncthreads();

if (tx < RADIX) {
uint32_t count_ge = suffix_sum[tx];
uint32_t count_gt = (tx + 1 < RADIX) ? suffix_sum[tx + 1] : 0;
if (count_ge >= remaining_k && count_gt < remaining_k) {
found_bucket = tx;
found_remaining_k = remaining_k - count_gt;
}
}
__syncthreads();

// Update local caches (all CTAs have same values)
if (tx == 0) {
prefix_cache = prefix | (found_bucket << shift);
remaining_k_cache = found_remaining_k;
}
__syncthreads();
}

// Get final ordered pivot from prefix_cache
OrderedType ordered_pivot = static_cast<OrderedType>(prefix_cache);

// ========== Stage 3: Collect indices >= pivot ==========
// Two-pass approach to handle ties correctly:
// Pass 1: collect all elements strictly > pivot (these must be in top-k)
// Pass 2: fill remaining slots with elements == pivot
//
// Optimization for Pass 1 (> pivot): Use shared memory atomic to count locally,
// then one global atomic per CTA to get base position, then shared atomic to write.
// This works because all > pivot elements are guaranteed to be in top-k.
//
// For Pass 2 (== pivot): Use global atomic directly since we need cross-CTA
// coordination to respect the k limit (some == pivot elements may be truncated).

// Reuse local_histogram[0..1] as counters
#define local_counter local_histogram[0]
#define global_base local_histogram[1]

// Pass 1: Count elements > pivot locally, then write with one global atomic
if (tx == 0) {
local_counter = 0;
}
__syncthreads();

// First pass: count how many elements > pivot in this CTA
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
OrderedType ordered_val = shared_ordered[i];
if (ordered_val > ordered_pivot) {
atomicAdd(&local_counter, 1);
}
}
__syncthreads();

// Get base position for this CTA
uint32_t cta_count_gt = local_counter;
if (tx == 0 && cta_count_gt > 0) {
if constexpr (SINGLE_CTA) {
global_base = atomicAdd(&shared_output_counter, cta_count_gt);
} else {
global_base = atomicAdd(&state->output_counter, cta_count_gt);
}
}
__syncthreads();

// Second pass: write elements > pivot using local shared atomic for position
if (tx == 0) {
local_counter = 0; // Reset for use as write position
}
__syncthreads();

if (cta_count_gt > 0) {
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
OrderedType ordered_val = shared_ordered[i];
if (ordered_val > ordered_pivot) {
uint32_t local_pos = atomicAdd(&local_counter, 1);
int pos = global_base + local_pos;
// No need to check pos < k here since all > pivot elements are in top-k
output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i);
if (output_values != nullptr) {
output_values[row_idx * top_k_val + pos] = Traits::FromOrdered(ordered_val);
}
}
}
}

// Barrier to ensure all > pivot elements are collected first (only for multi-CTA)
if constexpr (!SINGLE_CTA) {
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
int target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
}
__syncthreads();

// Pass 2: Write elements == pivot
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
OrderedType ordered_val = shared_ordered[i];
if (ordered_val == ordered_pivot) {
int pos;
if constexpr (SINGLE_CTA) {
pos = atomicAdd(&shared_output_counter, 1);
} else {
pos = atomicAdd(&state->output_counter, 1);
}
if (pos < static_cast<int>(k)) {
output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i);
if (output_values != nullptr) {
output_values[row_idx * top_k_val + pos] = Traits::FromOrdered(ordered_val);
}
}
}
}

#undef local_counter
#undef global_base
// No barrier needed here - the barrier at the start of next iteration
// ensures all CTAs complete Stage 3 before output_counter is reset
}

// Reset arrival counter for next kernel launch (only for multi-CTA)
if constexpr (!SINGLE_CTA) {
if (cta_in_group == 0 && tx == 0) {
st_release(&state->arrival_counter, 0);
}
}

#undef prefix_cache
#undef remaining_k_cache
#undef found_bucket
#undef found_remaining_k
#undef shared_output_counter
}
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, find topk.cu and understand how the kernel is called
fd -e cu -e h | grep -i topk | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 112


🏁 Script executed:

# Look for the RadixTopKMultiCTA function signature and how it's called
rg "RadixTopKMultiCTA|RadixTopKKernel_MultiCTA" -A 5 -B 5 | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 4674


🏁 Script executed:

# Find lines 3273-3341 mentioned in the review
sed -n '3273,3341p' include/flashinfer/sampling.cuh | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 3545


🏁 Script executed:

# Check for any existing assertions on top_k_arr in the kernel
rg "assert|Assert|CHECK" include/flashinfer/sampling.cuh | grep -i "top_k" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look for all callers of RadixTopKMultiCTA
rg "RadixTopKMultiCTA" --type cu --type h

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

# Verify the exact range 3273-3341 and see what's there
wc -l include/flashinfer/sampling.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 104


🏁 Script executed:

# Get lines 3273-3341 to confirm it's the wrapper function
sed -n '3273,3300p' include/flashinfer/sampling.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 1592


🏁 Script executed:

# Check if there are any other calls to RadixTopKMultiCTA outside of topk.cu
rg "RadixTopKMultiCTA" --glob "*.cu" --glob "*.h" --glob "*.cuh"

Repository: flashinfer-ai/flashinfer

Length of output: 261


🏁 Script executed:

# Examine the exact lines referenced for "Also applies to" - check if wrapper also has the same issue
sed -n '3273,3341p' include/flashinfer/sampling.cuh | grep -A 20 "RadixTopKMultiCTA"

Repository: flashinfer-ai/flashinfer

Length of output: 1207


Add validation for top_k_arr parameter to prevent out-of-bounds writes

RadixTopKMultiCTA and RadixTopKKernel_MultiCTA accept an optional per-row top_k_arr parameter but index into output buffers using a fixed stride of top_k_val:

uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];
output_indices[row_idx * top_k_val + pos] = /* ... */;

This will cause out-of-bounds writes if any element in top_k_arr exceeds top_k_val, since the output tensors are allocated as [batch_size, top_k_val].

Currently safe because the only caller (topk.cu) passes top_k_arr == nullptr, but future callers could inadvertently exceed the buffer bounds. Add a validation check at the top of RadixTopKMultiCTA:

if (top_k_arr != nullptr) {
  // Assert all top_k_arr[i] <= top_k_val, or adjust buffer allocation
  assert(/* validate contract */);
}

Alternatively, if per-row k values are needed, allocate output buffers based on max(top_k_arr) rather than top_k_val.
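
If the check is done on the host side instead, a hypothetical Python-level guard could look like this (name and error message are illustrative):

import torch

def check_per_row_k(top_k_arr: torch.Tensor, top_k_val: int) -> None:
    """Hypothetical host-side guard for the contract above: every per-row k must
    fit in the [batch, top_k_val] output allocation."""
    max_k = int(top_k_arr.max().item())
    if max_k > top_k_val:
        raise ValueError(
            f"per-row top-k ({max_k}) exceeds the allocated top_k_val ({top_k_val}); "
            "either clamp the values or size the output buffers by max(top_k_arr)"
        )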

🤖 Prompt for AI Agents
In include/flashinfer/sampling.cuh around lines 2883 to 3258, the kernel uses
per-row k from top_k_arr but always strides outputs by top_k_val which can cause
out-of-bounds writes if any top_k_arr[row] > top_k_val; add a validation at the
start of RadixTopKMultiCTA (before launching kernels / inside the host wrapper
function that prepares top_k_arr) that, if top_k_arr != nullptr, computes the
maximum requested k and either (a) asserts/returns an error when max_k >
top_k_val, or (b) requires/allocates output buffers sized to max_k instead of
top_k_val; ensure the check is deterministic and runs before any kernel
invocation so no kernel can write beyond allocated output buffer.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
include/flashinfer/sampling.cuh (2)

2500-2569: Critical: shared-memory chunk sizing can exceed opt-in SMEM (duplicate of prior review)
The pattern max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); can force chunk_size beyond what fits, causing launch failure. Also max_smem_per_block - fixed_smem_aligned can underflow if fixed overhead already exceeds opt-in SMEM. This matches the earlier critical comment and should be fixed in all three launchers.

-  const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
+  if (max_smem_per_block <= static_cast<int>(fixed_smem_aligned)) {
+    return cudaErrorInvalidConfiguration;
+  }
+  const size_t available_for_ordered =
+      static_cast<size_t>(max_smem_per_block) - fixed_smem_aligned;

   uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
   max_chunk_elements = round_down(max_chunk_elements, vec_size);
   const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
-  max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
+  if (max_chunk_elements < min_chunk_size) {
+    // Not enough SMEM to run with current BLOCK_THREADS / vec_size.
+    return cudaErrorInvalidConfiguration;
+  }

   uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
   uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
   chunk_size = round_up(chunk_size, vec_size);
   chunk_size = std::min(chunk_size, max_chunk_elements);

   const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType);
+  if (smem_size > static_cast<uint32_t>(max_smem_per_block)) {
+    return cudaErrorInvalidConfiguration;
+  }

Apply the same fix pattern in:

  • RadixTopKMaskLogitsMultiCTA (Line 2500+)
  • RadixTopKRenormProbMultiCTA (Line 2834+)
  • RadixTopKMultiCTA (Line 3305+)

Also applies to: 2834-2903, 3305-3373


2913-3291: Prevent OOB writes when top_k_arr[row] > top_k_val (duplicate of prior review)
output_indices/output_values are indexed with a row_idx * top_k_val + pos stride, but k can come from top_k_arr[row]. If any top_k_arr[row] > top_k_val, this can write past the output allocation.

A pragmatic safety fix is to cap the effective k to top_k_val inside RadixTopKKernel_MultiCTA:

-    uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];
+    const uint32_t k_req = top_k_arr == nullptr ? top_k_val : static_cast<uint32_t>(top_k_arr[row_idx]);
+    const uint32_t k_out = min(k_req, top_k_val);
+    const uint32_t k = min(k_out, vocab_size);

     ...
-    if (k >= vocab_size) {
+    if (k_req >= vocab_size) {
       // k >= vocab_size: return all indices
       for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
-        if (chunk_start + i < k) {
+        if (chunk_start + i < k_out) {
           output_indices[row_idx * top_k_val + chunk_start + i] =
               static_cast<IdType>(chunk_start + i);
           if (output_values != nullptr) {
             output_values[row_idx * top_k_val + chunk_start + i] =
                 input[row_idx * vocab_size + chunk_start + i];
           }
         }
       }
       continue;
     }
     ...
-        if (pos < static_cast<int>(k)) {
+        if (pos < static_cast<int>(k_out)) {
           output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i);
           ...
         }

(If you require per-row k > top_k_val, then the allocator/API needs to size outputs by max(top_k_arr) instead.)

Also applies to: 3305-3373

🧹 Nitpick comments (3)
include/flashinfer/sampling.cuh (3)

119-205: Add compile-time guards for radix configuration and document NaN behavior
RadixTopKTraits<>::num_rounds() assumes RADIX_BITS divides the type width; add a static_assert to prevent silent misconfigs. Also worth a short note on NaN ordering (bitwise mapping makes NaNs “sortable” but not numerically meaningful).

 template <typename DType>
 struct RadixTopKTraits;

 template <>
 struct RadixTopKTraits<float> {
   using OrderedType = uint32_t;

   template <uint32_t RADIX_BITS>
   static __host__ __device__ constexpr uint32_t num_rounds() {
+    static_assert((sizeof(OrderedType) * 8) % RADIX_BITS == 0, "RADIX_BITS must divide type width");
     return sizeof(OrderedType) * 8 / RADIX_BITS;
   }
   ...
 };

1825-1885: Spin-wait could use a backoff to reduce SM contention under load
wait_ge() busy-spins; consider a tiny backoff (__nanosleep) in the loop to reduce pressure when groups are imbalanced.


1889-2198: RadixFindThresholdBucket writes shared scalars from multiple lanes; ensure single-writer invariant is explicit
It’s likely only one bucket satisfies the predicate, but that contract isn’t enforced. Consider using a sentinel + atomicCAS (or a warp ballot + single lane write) to make “exactly one writer” deterministic.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 566b432 and 3843536.

📒 Files selected for processing (1)
  • include/flashinfer/sampling.cuh (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/sampling.cuh
🪛 GitHub Actions: pre-commit
include/flashinfer/sampling.cuh

[error] 2030-2032: clang-format formatting check failed. Some files were modified by this hook.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
include/flashinfer/sampling.cuh (1)

1801-1821: TopPRenormProb wrapper wiring looks consistent with the updated kernel
Launch config / SMEM sizing / dispatch structure is coherent.

Comment on lines 1823 to 3364
// ==================== Multi-CTA Top-K Implementation ====================

min_gt_low =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(min_gt_low, MinReduceOp{});
__syncthreads();
max_le_high =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(max_le_high, MaxReduceOp{});
if (tx == 0) {
temp_storage.block_aggregate.counts[0] = aggregate_gt_pivot_0;
temp_storage.block_aggregate.counts[1] = aggregate_gt_pivot_1;
temp_storage.min_val = min_gt_low;
temp_storage.max_val = max_le_high;
}
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.counts[0];
aggregate_gt_pivot_1 = temp_storage.block_aggregate.counts[1];
min_gt_low = temp_storage.min_val;
max_le_high = temp_storage.max_val;

if (aggregate_gt_pivot_1 >= k) {
low = pivot_1;
} else if (aggregate_gt_pivot_0 >= k) {
low = pivot_0;
high = min(pivot_1, max_le_high);
} else {
high = min(pivot_0, max_le_high);
}
} while (min_gt_low != max_le_high);
pivot = low;
// Acquire/Release primitives for inter-CTA synchronization
__device__ __forceinline__ int ld_acquire(int* ptr) {
int state = 0;

#if (__CUDA_ARCH__ >= 700)
// SM70 and newer use memory consistency qualifiers
// Acquire pattern using acquire modifier
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#else
asm volatile("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr));
#endif

return state;
}

__device__ __forceinline__ void red_release(int* ptr, int val) {
#if (__CUDA_ARCH__ >= 700)
// SM70 and newer use memory consistency qualifiers
// Release pattern using acq_rel fence + relaxed modifier
// (The fence also releases data that was weakly-written by other threads prior to the last
// syncthreads)
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
__threadfence();
atomicAdd(ptr, val);
#endif
}

__device__ __forceinline__ void st_release(int* ptr, int val) {
#if (__CUDA_ARCH__ >= 700)
// SM70 and newer use memory consistency qualifiers
// Release pattern: fence + release store
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("st.release.gpu.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
__threadfence();
atomicExch(ptr, val);
#endif
}

__device__ __forceinline__ void st_release(uint32_t* ptr, uint32_t val) {
#if (__CUDA_ARCH__ >= 700)
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("st.release.gpu.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val));
#else
__threadfence();
atomicExch(ptr, val);
#endif
}

// Wait until the value at ptr reaches target_val using acquire semantics
// Only thread 0 spins, then all threads synchronize
__device__ __forceinline__ void wait_ge(int* ptr, int target_val, int thread_idx) {
if (thread_idx == 0) {
#pragma unroll 1
while (ld_acquire(ptr) < target_val) {
}
}
__syncthreads();
}

// masking
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
logits_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
// ==================== Multi-CTA Radix Top-K Mask Logits ====================

// Global state for multi-CTA radix reduction (one per group)
struct RadixRowState {
uint32_t histogram[3][256]; // Triple-buffered histograms for 1-barrier-per-round
uint32_t remaining_k; // Remaining k after current round
uint32_t prefix; // Accumulated prefix (high bits of k-th element)
int arrival_counter; // For inter-CTA synchronization
int output_counter; // For collecting top-k indices (RadixTopK)
float sum_topk; // For RenormProb: sum of top-k elements
};

// ==================== Common Device Functions for Radix Top-K ====================

/*!
* \brief Compute suffix sum in shared memory using parallel reduction.
*
* After this function, suffix_sum[i] contains the count of elements >= bucket i.
* This is computed by summing all histogram values from bucket i to 255.
*
* \param suffix_sum Shared memory array of size RADIX (256)
* \param tx Thread index within the block
*/
template <uint32_t BLOCK_THREADS>
__device__ __forceinline__ void RadixSuffixSum(uint32_t* suffix_sum, uint32_t tx) {
constexpr uint32_t RADIX = 256;
// Parallel suffix sum: compute count of elements >= each bucket
for (uint32_t stride = 1; stride < RADIX; stride *= 2) {
uint32_t val = 0;
if (tx < RADIX) {
val = suffix_sum[tx];
if (tx + stride < RADIX) {
val += suffix_sum[tx + stride];
}
}
    __syncthreads();
    if (tx < RADIX) {
      suffix_sum[tx] = val;
    }
    __syncthreads();
  }
}
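
// For intuition, the same suffix-sum semantics written as a single-threaded host reference:
// out[i] counts elements whose current radix byte is >= i, which the device version above
// computes in log2(256) = 8 strided steps.
inline void RadixSuffixSumHost(const uint32_t* hist, uint32_t* out) {
  constexpr uint32_t RADIX = 256;
  uint32_t running = 0;
  for (int i = RADIX - 1; i >= 0; --i) {
    running += hist[i];
    out[i] = running;
  }
}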

/*!
* \brief Find the threshold bucket that contains the k-th largest element.
*
* The threshold bucket satisfies: count_ge >= k && count_gt < k
* where count_ge = suffix_sum[bucket] and count_gt = suffix_sum[bucket+1].
*
* \param suffix_sum Shared memory array containing suffix sums
* \param remaining_k Number of top-k elements still to find
* \param found_bucket Output: the found threshold bucket
* \param found_remaining_k Output: remaining_k minus count of elements > threshold
* \param tx Thread index within the block
*/
__device__ __forceinline__ void RadixFindThresholdBucket(uint32_t* suffix_sum, uint32_t remaining_k,
uint32_t* found_bucket,
uint32_t* found_remaining_k, uint32_t tx) {
constexpr uint32_t RADIX = 256;
// Initialize (only thread 0)
if (tx == 0) {
*found_bucket = 0;
*found_remaining_k = remaining_k;
}
__syncthreads();

// All threads in RADIX range check their bucket
if (tx < RADIX) {
uint32_t count_ge = suffix_sum[tx];
uint32_t count_gt = (tx + 1 < RADIX) ? suffix_sum[tx + 1] : 0;
if (count_ge >= remaining_k && count_gt < remaining_k) {
*found_bucket = tx;
*found_remaining_k = remaining_k - count_gt;
}
}
__syncthreads();
}
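
// Tiny worked example of the threshold-bucket condition (host-side, for intuition only):
// with remaining_k = 5, suffix_sum[180] = 7 (elements whose byte is >= 180) and
// suffix_sum[181] = 3 (elements whose byte is > 180), bucket 180 holds the k-th element and
// 5 - 3 = 2 of the ties inside bucket 180 still belong to the top-k.
inline bool ThresholdBucketExample() {
  uint32_t remaining_k = 5, count_ge = 7, count_gt = 3;
  bool is_threshold = (count_ge >= remaining_k) && (count_gt < remaining_k);  // true
  uint32_t new_remaining_k = remaining_k - count_gt;                          // 2
  return is_threshold && new_remaining_k == 2;
}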

template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
typename DType, typename IdType>
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr,
uint32_t top_k_val, uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;
uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
vec_t<float, VEC_SIZE> probs_vec;
if (k < d) {
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
uint8_t smem_renorm[];
auto& temp_storage =
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
temp_storage.max_val = 0;

float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);

double low = 0, high = max_val;
float min_gt_low, max_le_high;
float sum_low = 1;
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
// loop invariant:
// - f(low) >= k, f(high) < k
// - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
// stopping condition: min_gt_low == max_le_high
// - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
do {
double pivot_0 = (high + 2 * low) / 3;
double pivot_1 = (2 * high + low) / 3;

ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
min_gt_low = high;
max_le_high = low;
ValueCount<float> threadlocal_aggregate_gt_pivot_0{0, 0},
threadlocal_aggregate_gt_pivot_1{0, 0};
#pragma unroll 1
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot_0_pair[j] = {
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot_1_pair[j] = {
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};

if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
min_gt_low = min(min_gt_low, probs_vec[j]);
}
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
max_le_high = max(max_le_high, probs_vec[j]);
}
threadlocal_aggregate_gt_pivot_0 += probs_gt_pivot_0_pair[j];
threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1_pair[j];
}
}
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum(threadlocal_aggregate_gt_pivot_0);
__syncthreads();
/*!
* \brief Build local histogram for one round of radix select.
*
* Counts elements in shared_ordered that match the current prefix and bins them
* by their byte at the current shift position.
*
* \tparam OrderedType The ordered integer type (uint16_t or uint32_t)
* \param shared_ordered Shared memory containing ordered values
* \param actual_chunk_size Number of elements in this CTA's chunk
* \param local_histogram Output shared memory histogram
* \param prefix Current prefix (high bits determined so far)
* \param shift Bit shift for extracting current byte
* \param round Current round (0 to NUM_ROUNDS-1)
* \param tx Thread index
*/
template <uint32_t BLOCK_THREADS, typename OrderedType>
__device__ __forceinline__ void RadixBuildLocalHistogram(const OrderedType* shared_ordered,
uint32_t actual_chunk_size,
uint32_t* local_histogram, uint32_t prefix,
uint32_t shift, uint32_t round,
uint32_t tx) {
constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8;
constexpr uint32_t RADIX_BITS = 8;

for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
OrderedType ordered = shared_ordered[i];

// Check if this element matches the prefix (high bits determined so far)
OrderedType mask =
(round == 0)
? OrderedType(0)
: static_cast<OrderedType>(~OrderedType(0) << (ORDERED_BITS - round * RADIX_BITS));
if ((ordered & mask) == static_cast<OrderedType>(prefix)) {
uint32_t bucket = (ordered >> shift) & 0xFF;
atomicAdd(&local_histogram[bucket], 1);
}
}
}

      aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                                  temp_storage.block_prim.reduce_value_count)
                                  .Sum(threadlocal_aggregate_gt_pivot_1);
      __syncthreads();

      min_gt_low =
          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
              .Reduce(min_gt_low, MinReduceOp{});
      __syncthreads();
      max_le_high =
          BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
              .Reduce(max_le_high, MaxReduceOp{});
      if (tx == 0) {
        temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
        temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
        temp_storage.min_val = min_gt_low;
        temp_storage.max_val = max_le_high;
      }
      __syncthreads();
      aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0];
      aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1];
      min_gt_low = temp_storage.min_val;
      max_le_high = temp_storage.max_val;

      if (aggregate_gt_pivot_1.count >= k) {
        low = pivot_1;
        sum_low = float(aggregate_gt_pivot_1.value);
      } else if (aggregate_gt_pivot_0.count >= k) {
        low = pivot_0;
        high = min(pivot_1, max_le_high);
        sum_low = float(aggregate_gt_pivot_0.value);
      } else {
        high = min(pivot_0, max_le_high);
      }
    } while (min_gt_low != max_le_high);

    normalizer = math::ptx_rcp(max(sum_low, 1e-8));
    pivot = low;
  }

  // normalize
#pragma unroll 2
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    probs_vec.fill(0);
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
    }
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
    }
  }
}

/*!
* \brief Perform one round of radix select with optional multi-CTA synchronization.
*
* This is the core radix select logic used by all TopK kernels.
* It builds histogram, aggregates across CTAs (if multi-CTA), computes suffix sum,
* and finds the threshold bucket.
*
* \tparam BLOCK_THREADS Number of threads per block
* \tparam SINGLE_CTA True if single-CTA mode (no inter-CTA sync needed)
* \tparam OrderedType The ordered integer type
*
* \param shared_ordered Shared memory containing ordered values
* \param actual_chunk_size Number of elements in this CTA's chunk
* \param local_histogram Shared memory for local histogram (size RADIX)
* \param suffix_sum Shared memory for suffix sum computation (size RADIX)
* \param state Pointer to RadixRowState for multi-CTA sync (nullptr if SINGLE_CTA)
* \param prefix Current prefix value
* \param remaining_k Current remaining k value
* \param round Current round (0 to NUM_ROUNDS-1)
* \param barrier_phase Reference to barrier phase counter
* \param ctas_per_group Number of CTAs per group
* \param tx Thread index
* \param out_new_prefix Output: updated prefix after this round
* \param out_new_remaining_k Output: updated remaining_k after this round
*/
template <uint32_t BLOCK_THREADS, bool SINGLE_CTA, typename OrderedType>
__device__ __forceinline__ void RadixSelectOneRound(
const OrderedType* shared_ordered, uint32_t actual_chunk_size, uint32_t* local_histogram,
uint32_t* suffix_sum, uint32_t* shared_scalars, RadixRowState* state, uint32_t prefix,
uint32_t remaining_k, uint32_t round, uint32_t iter, int& barrier_phase, uint32_t ctas_per_group,
uint32_t cta_in_group, uint32_t tx, uint32_t* out_new_prefix, uint32_t* out_new_remaining_k) {
constexpr uint32_t RADIX = 256;
constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8;
constexpr uint32_t RADIX_BITS = 8;
constexpr uint32_t NUM_ROUNDS = ORDERED_BITS / RADIX_BITS;
uint32_t shift = ORDERED_BITS - (round + 1) * RADIX_BITS;
uint32_t global_round = iter * NUM_ROUNDS + round;

// For multi-CTA: pointers to global histograms (triple buffer)
uint32_t* current_hist = nullptr;
uint32_t* next_hist = nullptr;
if constexpr (!SINGLE_CTA) {
current_hist = state->histogram[global_round % 3];
next_hist = state->histogram[(global_round + 1) % 3];
}

  // Clear local histogram only
  for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
    local_histogram[i] = 0;
  }
  __syncthreads();

  // Build local histogram from shared memory
  RadixBuildLocalHistogram<BLOCK_THREADS, OrderedType>(shared_ordered, actual_chunk_size,
                                                       local_histogram, prefix, shift, round, tx);
  __syncthreads();

  // For multi-CTA: write → (leading CTA clears next) → barrier → read
  // For single-CTA: local_histogram is already the complete histogram
  if constexpr (!SINGLE_CTA) {
    // Accumulate local histogram to global
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      if (local_histogram[i] > 0) {
        atomicAdd(&current_hist[i], local_histogram[i]);
      }
    }

    // Only leading CTA clears next round's histogram BEFORE barrier
    if (cta_in_group == 0) {
      for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
        next_hist[i] = 0;
      }
    }

    // Barrier: wait for all CTAs to finish atomicAdd and clearing
    if (tx == 0) {
      red_release(&state->arrival_counter, 1);
    }
    int target = (barrier_phase + 1) * ctas_per_group;
    wait_ge(&state->arrival_counter, target, tx);
    barrier_phase++;
    __syncthreads();

    // Read current histogram (after barrier, all atomicAdds are complete)
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      suffix_sum[i] = current_hist[i];
    }
  } else {
    // Single-CTA: copy local histogram directly to suffix_sum
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      suffix_sum[i] = local_histogram[i];
    }
  }
  __syncthreads();

  // Compute suffix sum
  RadixSuffixSum<BLOCK_THREADS>(suffix_sum, tx);

  // Find threshold bucket using shared_scalars for found_bucket and found_remaining_k
  // shared_scalars[0] = found_bucket, shared_scalars[1] = found_remaining_k
  RadixFindThresholdBucket(suffix_sum, remaining_k, &shared_scalars[0], &shared_scalars[1], tx);

  // Output new prefix and remaining_k
  *out_new_prefix = prefix | (shared_scalars[0] << shift);
  *out_new_remaining_k = shared_scalars[1];
}
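
// Host reference for the full radix-select recurrence above: each round histogram-counts the
// current byte among keys that still match the prefix, picks the threshold byte, and appends it
// to the prefix; after sizeof(key) rounds the prefix is exactly the k-th largest key.
// Single-threaded sketch over already-encoded 32-bit keys (see the assumed ToOrdered sketch
// earlier), with k >= 1 and k <= n:
inline uint32_t RadixSelectKthLargestHost(const uint32_t* keys, uint32_t n, uint32_t k) {
  uint32_t prefix = 0, remaining_k = k;
  for (uint32_t round = 0; round < 4; ++round) {
    uint32_t shift = 32 - (round + 1) * 8;
    uint32_t mask = (round == 0) ? 0u : ~0u << (32 - round * 8);
    uint32_t hist[256] = {0};
    for (uint32_t i = 0; i < n; ++i) {
      if ((keys[i] & mask) == prefix) hist[(keys[i] >> shift) & 0xFF]++;
    }
    uint32_t suffix = 0, bucket = 0;
    for (int b = 255; b >= 0; --b) {
      suffix += hist[b];            // suffix == count of candidates with byte >= b
      if (suffix >= remaining_k) {  // first bucket (from the top) reaching remaining_k
        bucket = static_cast<uint32_t>(b);
        remaining_k -= (suffix - hist[b]);  // drop the strictly-greater candidates
        break;
      }
    }
    prefix |= bucket << shift;
  }
  return prefix;  // encoded value of the k-th largest element
}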

/*!
* \brief Find the k-th largest element pivot using radix select.
*
* This is the main entry point for the radix select algorithm.
* It performs NUM_ROUNDS of radix select to find the exact pivot value.
*
* \tparam BLOCK_THREADS Number of threads per block
* \tparam VEC_SIZE Vector size for memory access
* \tparam SINGLE_CTA True if single-CTA mode
* \tparam DType Data type (float, half, nv_bfloat16)
*
* \param input Input data pointer (for this row)
* \param shared_ordered Shared memory for ordered values
* \param local_histogram Shared memory for local histogram
* \param suffix_sum Shared memory for suffix sum
* \param shared_scalars Shared memory for temporary scalar values (size >= 2)
* \param state RadixRowState pointer (nullptr if SINGLE_CTA)
* \param chunk_start Start index in vocab for this CTA
* \param actual_chunk_size Number of elements in this chunk
* \param k Number of top elements to select
* \param barrier_phase Reference to barrier phase counter
* \param ctas_per_group Number of CTAs per group
* \param tx Thread index
* \return The pivot value (k-th largest element)
*/
template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, bool SINGLE_CTA, typename DType>
__device__ __forceinline__ DType RadixSelectFindPivot(
const DType* input, typename RadixTopKTraits<DType>::OrderedType* shared_ordered,
uint32_t* local_histogram, uint32_t* suffix_sum, uint32_t* shared_scalars, RadixRowState* state,
uint32_t chunk_start, uint32_t actual_chunk_size, uint32_t k, int& barrier_phase,
uint32_t ctas_per_group, uint32_t cta_in_group, uint32_t tx, uint32_t iter = 0) {
using Traits = RadixTopKTraits<DType>;
using OrderedType = typename Traits::OrderedType;
constexpr uint32_t RADIX = 256;
constexpr uint32_t RADIX_BITS = 8;
constexpr uint32_t NUM_ROUNDS = Traits::template num_rounds<RADIX_BITS>();
constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8;

// Stage 1: Load and convert to ordered representation
const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE;
vec_t<DType, VEC_SIZE> data_vec;

#pragma unroll 2
  for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) {
    data_vec.cast_load(input + chunk_start + i);
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      shared_ordered[i + j] = Traits::ToOrdered(data_vec[j]);
    }
  }
  // Handle tail
  for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) {
    shared_ordered[i] = Traits::ToOrdered(input[chunk_start + i]);
  }
  __syncthreads();

  // Initialize prefix and remaining_k
  uint32_t prefix = 0;
  uint32_t remaining_k = k;

  // Initial barrier (skip for single CTA)
  // Histograms are pre-cleared externally (Python side) and cleared at end of each iteration
  if constexpr (!SINGLE_CTA) {
    if (tx == 0) {
      red_release(&state->arrival_counter, 1);
    }
    int target = (barrier_phase + 1) * ctas_per_group;
    wait_ge(&state->arrival_counter, target, tx);
    barrier_phase++;
    __syncthreads();
  }

  // Stage 2: NUM_ROUNDS of radix select
  // Double buffer with leading CTA clearing at start of each round
  for (uint32_t round = 0; round < NUM_ROUNDS; ++round) {
    uint32_t new_prefix, new_remaining_k;
    RadixSelectOneRound<BLOCK_THREADS, SINGLE_CTA, OrderedType>(
        shared_ordered, actual_chunk_size, local_histogram, suffix_sum, shared_scalars, state,
        prefix, remaining_k, round, iter, barrier_phase, ctas_per_group, cta_in_group, tx,
        &new_prefix, &new_remaining_k);
    prefix = new_prefix;
    remaining_k = new_remaining_k;
    __syncthreads();
  }

  // Convert final ordered representation back to DType pivot
  return Traits::FromOrdered(static_cast<OrderedType>(prefix));
}

template <typename DType>
cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, float* top_p_arr,
                           uint32_t batch_size, float top_p_val, uint32_t d,
                           cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
  auto compute_capacity = GetCudaComputeCapability();
  DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
    const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
    dim3 nblks(batch_size);
    dim3 nthrs(BLOCK_THREADS);
    void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d};
    DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
      auto kernel = TopPRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType>;
      FLASHINFER_CUDA_CALL(
          cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
    });
    return cudaSuccess;
  });
  return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
                           uint32_t batch_size, uint32_t top_k_val, uint32_t d,
                           cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
  auto compute_capacity = GetCudaComputeCapability();
  DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
    const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
    dim3 nblks(batch_size);
    dim3 nthrs(BLOCK_THREADS);
    void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d};
    DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
      auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
      FLASHINFER_CUDA_CALL(
          cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
    });
    return cudaSuccess;
  });
  return cudaSuccess;
}
template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, bool SINGLE_CTA, typename DType,
typename IdType>
__global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKMaskLogitsKernel_MultiCTA(
DType* logits, // [batch, vocab_size]
DType* masked_logits, // [batch, vocab_size]
IdType* top_k_arr, // [batch] or nullptr
uint32_t top_k_val, uint32_t vocab_size, uint32_t batch_size,
RadixRowState* row_states, // [num_groups] (nullptr if SINGLE_CTA)
uint32_t chunk_size, // elements per CTA
uint32_t ctas_per_group) // CTAs per row (1 if SINGLE_CTA)
{
// Type traits for FP16/BF16/FP32 support
using Traits = RadixTopKTraits<DType>;
using OrderedType = typename Traits::OrderedType;

constexpr uint32_t RADIX = 256; // 8-bit radix
constexpr uint32_t RADIX_BITS = 8;
constexpr uint32_t NUM_ROUNDS = Traits::template num_rounds<RADIX_BITS>();
constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8;

const uint32_t global_cta_id = blockIdx.x;
const uint32_t group_id = global_cta_id / ctas_per_group;
const uint32_t cta_in_group = global_cta_id % ctas_per_group;
const uint32_t tx = threadIdx.x;

// Shared memory layout: [fixed storage] [ordered values cache]
extern __shared__ uint8_t smem[];

// Fixed shared memory (at the beginning)
constexpr size_t fixed_smem_size =
sizeof(uint32_t) * (RADIX + RADIX + 4); // histogram + suffix + 4 scalars
uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem);
uint32_t* suffix_sum = local_histogram + RADIX;
uint32_t* shared_scalars =
suffix_sum + RADIX; // [prefix_cache, remaining_k_cache, found_bucket, found_remaining_k]

// Align ordered values cache to 16 bytes
size_t ordered_offset = ((fixed_smem_size + 15) / 16) * 16;
OrderedType* shared_ordered = reinterpret_cast<OrderedType*>(smem + ordered_offset);

// Aliases for scalar shared variables
#define prefix_cache shared_scalars[0]
#define remaining_k_cache shared_scalars[1]
#define found_bucket shared_scalars[2]
#define found_remaining_k shared_scalars[3]

// State pointer only used when not SINGLE_CTA
RadixRowState* state = nullptr;
if constexpr (!SINGLE_CTA) {
state = &row_states[group_id];
}

// Calculate total number of iterations for persistent loop
uint32_t num_groups = gridDim.x / ctas_per_group;
uint32_t total_iterations = (batch_size + num_groups - 1) / num_groups;

int barrier_phase = 0;

// Persistent loop over rows
for (uint32_t iter = 0; iter < total_iterations; iter++) {
uint32_t row_idx = group_id + iter * num_groups;

if (row_idx >= batch_size) break;

const uint32_t chunk_start = cta_in_group * chunk_size;
const uint32_t chunk_end = min(chunk_start + chunk_size, vocab_size);

uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];

DType pivot = Traits::NegInf();

const uint32_t actual_chunk_size = chunk_end - chunk_start;

if (k >= vocab_size) {
// k >= vocab_size: no masking needed, just copy
vec_t<DType, VEC_SIZE> logits_vec_copy;
#pragma unroll 2
for (uint32_t i = tx * VEC_SIZE; i < actual_chunk_size; i += BLOCK_THREADS * VEC_SIZE) {
if (i + VEC_SIZE <= actual_chunk_size) {
logits_vec_copy.cast_load(logits + row_idx * vocab_size + chunk_start + i);
logits_vec_copy.store(masked_logits + row_idx * vocab_size + chunk_start + i);
}
}
// Handle tail
for (uint32_t i = (actual_chunk_size / VEC_SIZE) * VEC_SIZE + tx; i < actual_chunk_size;
i += BLOCK_THREADS) {
masked_logits[row_idx * vocab_size + chunk_start + i] =
logits[row_idx * vocab_size + chunk_start + i];
}
continue;
}

// ========== Stage 1: Load and convert to ordered representation in shared memory ==========
// This is done ONCE per row, avoiding NUM_ROUNDS global memory reads
vec_t<DType, VEC_SIZE> logits_vec;
const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE;

#pragma unroll 2
for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) {
logits_vec.cast_load(logits + row_idx * vocab_size + chunk_start + i);
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
// Use type traits for FP16/BF16/FP32 support
shared_ordered[i + j] = Traits::ToOrdered(logits_vec[j]);
}
}
// Handle tail
for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) {
shared_ordered[i] = Traits::ToOrdered(logits[row_idx * vocab_size + chunk_start + i]);
}
__syncthreads();

// Initialize local caches
if (tx == 0) {
prefix_cache = 0;
remaining_k_cache = k;
}
__syncthreads();

// Barrier to ensure all CTAs have arrived at this iteration (skip for single CTA)
if constexpr (!SINGLE_CTA) {
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
int target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
__syncthreads();
}

// ========== Stage 2: NUM_ROUNDS of radix select ==========
// Triple-buffer optimization: only 1 barrier per round
// - Use global_round = iter * NUM_ROUNDS + round for buffer indexing
// - Only leading CTA clears next buffer before barrier
for (uint32_t round = 0; round < NUM_ROUNDS; ++round) {
uint32_t global_round = iter * NUM_ROUNDS + round;
uint32_t shift = ORDERED_BITS - (round + 1) * RADIX_BITS;
// Read from local cache (no global memory access needed!)
uint32_t prefix = prefix_cache;
uint32_t remaining_k = remaining_k_cache;

// For multi-CTA: pointers to global histograms (triple buffer)
uint32_t* current_hist = nullptr;
uint32_t* next_hist = nullptr;
if constexpr (!SINGLE_CTA) {
current_hist = state->histogram[global_round % 3];
next_hist = state->histogram[(global_round + 1) % 3];
}

// Clear local histogram only
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
local_histogram[i] = 0;
}
__syncthreads();

// Build local histogram from SHARED MEMORY (no global memory access!)
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
OrderedType ordered = shared_ordered[i];

// Check if this element matches the prefix (high bits determined so far)
// Use generic mask based on OrderedType bits
OrderedType mask =
(round == 0)
? OrderedType(0)
: static_cast<OrderedType>(~OrderedType(0) << (ORDERED_BITS - round * RADIX_BITS));
if ((ordered & mask) == static_cast<OrderedType>(prefix)) {
uint32_t bucket = (ordered >> shift) & 0xFF;
atomicAdd(&local_histogram[bucket], 1);
}
}
__syncthreads();

// For multi-CTA: write → (leading CTA clears next) → barrier → read
// For single-CTA: local_histogram is already the complete histogram
if constexpr (!SINGLE_CTA) {
// Accumulate local histogram to global
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
if (local_histogram[i] > 0) {
atomicAdd(&current_hist[i], local_histogram[i]);
}
}

// Only leading CTA clears next round's histogram BEFORE barrier
if (cta_in_group == 0) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
next_hist[i] = 0;
}
}

// Barrier: wait for all CTAs to finish atomicAdd and clearing
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
int target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
__syncthreads();

// Read current histogram (after barrier, all atomicAdds are complete)
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
suffix_sum[i] = current_hist[i];
}
} else {
// Single-CTA: copy local histogram directly to suffix_sum
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
suffix_sum[i] = local_histogram[i];
}
}
__syncthreads();

// Parallel suffix sum in shared memory (much faster than global memory!)
// Compute count of elements >= each bucket value
for (uint32_t stride = 1; stride < RADIX; stride *= 2) {
uint32_t val = 0;
if (tx < RADIX) {
val = suffix_sum[tx];
if (tx + stride < RADIX) {
val += suffix_sum[tx + stride];
}
}
__syncthreads();
if (tx < RADIX) {
suffix_sum[tx] = val;
}
__syncthreads();
}

// ALL CTAs: find threshold bucket (all compute same result)
// Use shared variable to communicate the found bucket (via macros to shared_scalars[2..3])
if (tx == 0) {
found_bucket = 0;
found_remaining_k = remaining_k;
}
__syncthreads();

if (tx < RADIX) {
uint32_t count_ge = suffix_sum[tx];
uint32_t count_gt = (tx + 1 < RADIX) ? suffix_sum[tx + 1] : 0;
if (count_ge >= remaining_k && count_gt < remaining_k) {
found_bucket = tx;
found_remaining_k = remaining_k - count_gt;
}
}
__syncthreads();

// Update local caches (all CTAs have same values)
if (tx == 0) {
prefix_cache = prefix | (found_bucket << shift);
remaining_k_cache = found_remaining_k;
}
__syncthreads();
}

// Convert final ordered representation back to DType pivot using type traits
OrderedType ordered_pivot = static_cast<OrderedType>(prefix_cache);
pivot = Traits::FromOrdered(ordered_pivot);

// ========== Stage 3: Final masking pass ==========
// Reuse logits_vec from Stage 1
const DType neg_inf = Traits::NegInf();

#pragma unroll 2
for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) {
logits_vec.cast_load(logits + row_idx * vocab_size + chunk_start + i);
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
logits_vec[j] = (logits_vec[j] >= pivot) ? logits_vec[j] : neg_inf;
}
logits_vec.store(masked_logits + row_idx * vocab_size + chunk_start + i);
}

// Handle tail
for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) {
DType val = logits[row_idx * vocab_size + chunk_start + i];
masked_logits[row_idx * vocab_size + chunk_start + i] = (val >= pivot) ? val : neg_inf;
}
}

// Clear histogram buffers and reset arrival counter for next kernel launch (only for multi-CTA)
if constexpr (!SINGLE_CTA) {
// Only leading CTA clears the buffers using release semantics
if (cta_in_group == 0) {
for (uint32_t buf = 0; buf < 3; ++buf) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
st_release(&state->histogram[buf][i], 0u);
}
}

if (tx == 0) {
st_release(&state->arrival_counter, 0);
}
}
}

#undef prefix_cache
#undef remaining_k_cache
#undef found_bucket
#undef found_remaining_k
}
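
// Functionally, the kernel above reduces to: find the k-th largest logit, keep every logit at
// or above it, and write -inf elsewhere; ties with the pivot are kept, so the surviving count
// can exceed k. A quadratic host reference for checking small cases (not meant to be fast):
inline void TopKMaskLogitsHostReference(const float* logits, float* out, uint32_t d, uint32_t k) {
  for (uint32_t i = 0; i < d; ++i) {
    // rank = number of elements strictly greater than logits[i]
    uint32_t rank = 0;
    for (uint32_t j = 0; j < d; ++j) {
      if (logits[j] > logits[i]) ++rank;
    }
    // Kept iff fewer than k elements are strictly greater, i.e. logits[i] >= k-th largest.
    out[i] = (rank < k) ? logits[i] : -cuda::std::numeric_limits<float>::infinity();
  }
}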

template <typename DType, typename IdType>
cudaError_t RadixTopKMaskLogitsMultiCTA(DType* logits, DType* masked_logits, IdType* top_k_arr,
uint32_t batch_size, uint32_t top_k_val,
uint32_t vocab_size, RadixRowState* row_states_buffer,
cudaStream_t stream = 0) {
using OrderedType = typename RadixTopKTraits<DType>::OrderedType;
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size);

// Get device properties
int device;
FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
int num_sms;
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device));
int max_smem_per_block;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));

// Fixed shared memory overhead: histogram[256] + suffix_sum[256] + 4 scalars
constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 4);
constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16);

// Calculate max chunk size that fits in shared memory
const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
max_chunk_elements = round_down(max_chunk_elements, vec_size);
const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
chunk_size = round_up(chunk_size, vec_size);
chunk_size = std::min(chunk_size, max_chunk_elements);

const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType);
const bool single_cta = (ctas_per_group == 1);

// Calculate number of groups (how many rows to process concurrently)
uint32_t num_groups = std::min(static_cast<uint32_t>(num_sms) / ctas_per_group, batch_size);
if (num_groups == 0) num_groups = 1;
uint32_t total_ctas = num_groups * ctas_per_group;

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
if (single_cta) {
auto kernel =
RadixTopKMaskLogitsKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, true, DType, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

dim3 nblks(total_ctas);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &vocab_size,
&batch_size, &row_states_buffer, &chunk_size, &ctas_per_group};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
    } else {
auto kernel =
RadixTopKMaskLogitsKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, false, DType, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

dim3 nblks(total_ctas);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &vocab_size,
&batch_size, &row_states_buffer, &chunk_size, &ctas_per_group};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
}
});

return cudaSuccess;
}
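
// To make the chunking concrete, assume (illustrative numbers, not measured in this PR) 228 KB
// of opt-in shared memory per block, 132 SMs, fp32 inputs (4-byte ordered keys), vec_size = 4,
// block_threads = 1024, vocab_size = 256000 and batch_size = 4. Mirroring the launcher above
// gives fixed_smem_aligned = 2064 B, max_chunk_elements = 57852, ctas_per_group = 5,
// chunk_size = 51200, num_groups = min(132 / 5, 4) = 4, i.e. 20 CTAs in flight:
inline uint32_t IllustrateChunking(uint32_t vocab_size = 256000, uint32_t batch_size = 4,
                                   uint32_t max_smem = 228 * 1024, uint32_t num_sms = 132,
                                   uint32_t vec_size = 4, uint32_t block_threads = 1024) {
  uint32_t fixed = ((sizeof(uint32_t) * (256 + 256 + 4) + 15) / 16) * 16;
  uint32_t max_chunk = (max_smem - fixed) / sizeof(uint32_t);  // fp32 ordered keys are 4 bytes
  max_chunk = std::max(max_chunk / vec_size * vec_size, vec_size * block_threads);
  uint32_t ctas_per_group = (vocab_size + max_chunk - 1) / max_chunk;  // ceil_div
  uint32_t chunk_size = (vocab_size + ctas_per_group - 1) / ctas_per_group;
  chunk_size = std::min((chunk_size + vec_size - 1) / vec_size * vec_size, max_chunk);
  uint32_t num_groups = std::max(std::min(num_sms / ctas_per_group, batch_size), 1u);
  return num_groups * ctas_per_group;  // total CTAs launched (20 in the example above)
}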

// ==================== Multi-CTA Radix Top-K Renorm Probs ====================

/*!
* \brief Multi-CTA Radix Top-K RenormProb kernel with unified single/multi-CTA paths.
*
* Finds the k-th largest probability, then normalizes all probs >= pivot to sum to 1,
* setting all others to 0. Uses the shared RadixSelectFindPivot function.
*/
template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, bool SINGLE_CTA, typename DType,
typename IdType>
__global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKRenormProbKernel_MultiCTA(
DType* probs, // [batch, vocab_size]
DType* renormed_prob, // [batch, vocab_size]
IdType* top_k_arr, // [batch] or nullptr
uint32_t top_k_val, uint32_t vocab_size, uint32_t batch_size,
RadixRowState* row_states, // [num_groups] (nullptr if SINGLE_CTA)
uint32_t chunk_size, // elements per CTA
uint32_t ctas_per_group) // CTAs per row (1 if SINGLE_CTA)
{
using Traits = RadixTopKTraits<DType>;
using OrderedType = typename Traits::OrderedType;

constexpr uint32_t RADIX = 256;

const uint32_t global_cta_id = blockIdx.x;
const uint32_t group_id = global_cta_id / ctas_per_group;
const uint32_t cta_in_group = global_cta_id % ctas_per_group;
const uint32_t tx = threadIdx.x;

// Shared memory layout: [fixed storage] [ordered values cache]
extern __shared__ uint8_t smem[];

// Fixed shared memory (at the beginning)
// histogram[256] + suffix[256] + scalars[4] + sum_local[1]
constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + 4) + sizeof(float);
uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem);
uint32_t* suffix_sum = local_histogram + RADIX;
uint32_t* shared_scalars = suffix_sum + RADIX;
float* shared_sum = reinterpret_cast<float*>(shared_scalars + 4);

// Align ordered values cache to 16 bytes
size_t ordered_offset = ((fixed_smem_size + 15) / 16) * 16;
OrderedType* shared_ordered = reinterpret_cast<OrderedType*>(smem + ordered_offset);

// State pointer only used when not SINGLE_CTA
RadixRowState* state = nullptr;
if constexpr (!SINGLE_CTA) {
state = &row_states[group_id];
}

// Calculate total number of iterations for persistent loop
uint32_t num_groups = gridDim.x / ctas_per_group;
uint32_t total_iterations = (batch_size + num_groups - 1) / num_groups;

int barrier_phase = 0;

// Persistent loop over rows
for (uint32_t iter = 0; iter < total_iterations; iter++) {
uint32_t row_idx = group_id + iter * num_groups;

if (row_idx >= batch_size) break;

const uint32_t chunk_start = cta_in_group * chunk_size;
const uint32_t chunk_end = min(chunk_start + chunk_size, vocab_size);
const uint32_t actual_chunk_size = chunk_end - chunk_start;

uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];

// For RenormProb, pivot is compared with probs (must be non-negative)
DType pivot = DType(0);
float normalizer = 1.0f;

if (k >= vocab_size) {
// k >= vocab_size: no filtering needed, just compute sum and renormalize
// Stage 1: Compute sum
float thread_sum = 0.0f;
vec_t<DType, VEC_SIZE> data_vec;
const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE;

#pragma unroll 2
for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) {
data_vec.cast_load(probs + row_idx * vocab_size + chunk_start + i);
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
thread_sum += float(data_vec[j]);
}
}
// Handle tail
for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) {
thread_sum += float(probs[row_idx * vocab_size + chunk_start + i]);
}

// Block reduction for sum
typedef cub::BlockReduce<float, BLOCK_THREADS> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
float block_sum = BlockReduce(temp_storage).Sum(thread_sum);
__syncthreads();

if constexpr (!SINGLE_CTA) {
// Multi-CTA: atomic add to global sum
if (tx == 0) {
if (cta_in_group == 0) {
state->sum_topk = 0.0f; // First CTA initializes
}
}
// Barrier for initialization
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
int target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
__syncthreads();

if (tx == 0 && block_sum > 0) {
atomicAdd(&state->sum_topk, block_sum);
}

// Barrier to ensure all CTAs have contributed
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
__syncthreads();

normalizer = math::ptx_rcp(max(state->sum_topk, 1e-8f));
} else {
// Single-CTA: use block_sum directly
if (tx == 0) {
*shared_sum = block_sum;
}
__syncthreads();
normalizer = math::ptx_rcp(max(*shared_sum, 1e-8f));
}

// Normalize and store
#pragma unroll 2
for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) {
data_vec.cast_load(probs + row_idx * vocab_size + chunk_start + i);
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
data_vec[j] = DType(float(data_vec[j]) * normalizer);
}
data_vec.store(renormed_prob + row_idx * vocab_size + chunk_start + i);
}
for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) {
renormed_prob[row_idx * vocab_size + chunk_start + i] =
DType(float(probs[row_idx * vocab_size + chunk_start + i]) * normalizer);
}
continue;
}

// ========== Stage 1: Find pivot using RadixSelectFindPivot ==========
pivot = RadixSelectFindPivot<BLOCK_THREADS, VEC_SIZE, SINGLE_CTA, DType>(
probs + row_idx * vocab_size, shared_ordered, local_histogram, suffix_sum, shared_scalars,
state, chunk_start, actual_chunk_size, k, barrier_phase, ctas_per_group, cta_in_group, tx,
iter);

// ========== Stage 2: Compute sum of elements >= pivot ==========
float thread_sum = 0.0f;
vec_t<DType, VEC_SIZE> data_vec;
const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE;

#pragma unroll 2
for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) {
data_vec.cast_load(probs + row_idx * vocab_size + chunk_start + i);
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
if (data_vec[j] >= pivot) {
thread_sum += float(data_vec[j]);
}
}
}
// Handle tail
for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) {
DType val = probs[row_idx * vocab_size + chunk_start + i];
if (val >= pivot) {
thread_sum += float(val);
}
}

// Block reduction for sum
typedef cub::BlockReduce<float, BLOCK_THREADS> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
float block_sum = BlockReduce(temp_storage).Sum(thread_sum);
__syncthreads();

if constexpr (!SINGLE_CTA) {
// Multi-CTA: atomic add to global sum
if (tx == 0) {
if (cta_in_group == 0) {
state->sum_topk = 0.0f; // First CTA initializes
}
}
// Barrier for initialization
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
int target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
__syncthreads();

if (tx == 0 && block_sum > 0) {
atomicAdd(&state->sum_topk, block_sum);
}

// Barrier to ensure all CTAs have contributed
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
__syncthreads();

normalizer = math::ptx_rcp(max(state->sum_topk, 1e-8f));
} else {
// Single-CTA: use block_sum directly
if (tx == 0) {
*shared_sum = block_sum;
}
__syncthreads();
normalizer = math::ptx_rcp(max(*shared_sum, 1e-8f));
}

// ========== Stage 3: Normalize elements >= pivot, set others to 0 ==========
#pragma unroll 2
for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) {
data_vec.cast_load(probs + row_idx * vocab_size + chunk_start + i);
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
data_vec[j] = (data_vec[j] >= pivot) ? DType(float(data_vec[j]) * normalizer) : DType(0);
}
data_vec.store(renormed_prob + row_idx * vocab_size + chunk_start + i);
}
// Handle tail
for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) {
DType val = probs[row_idx * vocab_size + chunk_start + i];
renormed_prob[row_idx * vocab_size + chunk_start + i] =
(val >= pivot) ? DType(float(val) * normalizer) : DType(0);
}
}

// Clear histogram buffers and reset arrival counter for next kernel launch (only for multi-CTA)
if constexpr (!SINGLE_CTA) {
// Only leading CTA clears the buffers using release semantics
if (cta_in_group == 0) {
for (uint32_t buf = 0; buf < 3; ++buf) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
st_release(&state->histogram[buf][i], 0u);
}
}

if (tx == 0) {
st_release(&state->arrival_counter, 0);
}
}
}
}
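
// The renormalization itself is easy to state on the host: keep probabilities at or above the
// pivot, scale them by the reciprocal of their sum, zero everything else. Reference sketch:
inline void TopKRenormProbHostReference(const float* probs, float* out, uint32_t d, float pivot) {
  float sum = 0.0f;
  for (uint32_t i = 0; i < d; ++i) {
    if (probs[i] >= pivot) sum += probs[i];
  }
  const float normalizer = 1.0f / std::max(sum, 1e-8f);  // same clamp as the kernel
  for (uint32_t i = 0; i < d; ++i) {
    out[i] = (probs[i] >= pivot) ? probs[i] * normalizer : 0.0f;
  }
}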

template <typename DType, typename IdType>
cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr,
                           uint32_t batch_size, uint32_t top_k_val, uint32_t d,
                           cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
  auto compute_capacity = GetCudaComputeCapability();
  DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
    const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
    dim3 nblks(batch_size);
    dim3 nthrs(BLOCK_THREADS);
    void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d};
    DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
      auto kernel = TopKMaskLogitsKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
      FLASHINFER_CUDA_CALL(
          cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
    });
    return cudaSuccess;
  });
  return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t RadixTopKRenormProbMultiCTA(DType* probs, DType* renormed_prob, IdType* top_k_arr,
uint32_t batch_size, uint32_t top_k_val,
uint32_t vocab_size, RadixRowState* row_states_buffer,
cudaStream_t stream = 0) {
using OrderedType = typename RadixTopKTraits<DType>::OrderedType;
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size);

// Get device properties
int device;
FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
int num_sms;
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device));
int max_smem_per_block;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));

// Fixed shared memory overhead: histogram[256] + suffix_sum[256] + 4 scalars + 1 float
constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 4) + sizeof(float);
constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16);

// Calculate max chunk size that fits in shared memory
const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
max_chunk_elements = round_down(max_chunk_elements, vec_size);
const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
chunk_size = round_up(chunk_size, vec_size);
chunk_size = std::min(chunk_size, max_chunk_elements);

const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType);
const bool single_cta = (ctas_per_group == 1);

// Calculate number of groups (how many rows to process concurrently)
uint32_t num_groups = std::min(static_cast<uint32_t>(num_sms) / ctas_per_group, batch_size);
if (num_groups == 0) num_groups = 1;
uint32_t total_ctas = num_groups * ctas_per_group;

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
if (single_cta) {
auto kernel =
RadixTopKRenormProbKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, true, DType, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

dim3 nblks(total_ctas);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &vocab_size,
&batch_size, &row_states_buffer, &chunk_size, &ctas_per_group};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
auto kernel =
RadixTopKRenormProbKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, false, DType, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

dim3 nblks(total_ctas);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &vocab_size,
&batch_size, &row_states_buffer, &chunk_size, &ctas_per_group};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
    }
  });

return cudaSuccess;
}

// ==================== Multi-CTA Radix Top-K (Returns Indices) ====================

/*!
* \brief Multi-CTA Radix Top-K kernel that returns indices of top-k elements.
*
* Uses cooperative multi-CTA radix select to find the k-th largest element,
* then collects indices of all elements >= pivot.
*/
template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, bool SINGLE_CTA, typename DType,
typename IdType>
__global__ void __launch_bounds__(BLOCK_THREADS)
RadixTopKKernel_MultiCTA(DType* input, // [batch, vocab_size]
IdType* output_indices, // [batch, top_k]
DType* output_values, // [batch, top_k] or nullptr
IdType* top_k_arr, // [batch] or nullptr
uint32_t top_k_val, uint32_t vocab_size, uint32_t batch_size,
RadixRowState* row_states, // [num_groups] (nullptr if SINGLE_CTA)
uint32_t chunk_size, // elements per CTA
uint32_t ctas_per_group) // CTAs per row (1 if SINGLE_CTA)
{
// Type traits for FP16/BF16/FP32 support
using Traits = RadixTopKTraits<DType>;
using OrderedType = typename Traits::OrderedType;

constexpr uint32_t RADIX = 256;
constexpr uint32_t RADIX_BITS = 8;
constexpr uint32_t NUM_ROUNDS = Traits::template num_rounds<RADIX_BITS>();
constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8;

const uint32_t global_cta_id = blockIdx.x;
const uint32_t group_id = global_cta_id / ctas_per_group;
const uint32_t cta_in_group = global_cta_id % ctas_per_group;
const uint32_t tx = threadIdx.x;

// Shared memory layout: [fixed storage] [ordered values cache]
extern __shared__ uint8_t smem[];

// Fixed shared memory (at the beginning)
// When SINGLE_CTA, we need an extra uint32 for output_counter (no global state)
constexpr size_t num_scalars = SINGLE_CTA ? 5 : 4;
constexpr size_t fixed_smem_size =
sizeof(uint32_t) * (RADIX + RADIX + num_scalars); // histogram + suffix + scalars
uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem);
uint32_t* suffix_sum = local_histogram + RADIX;
uint32_t* shared_scalars = suffix_sum + RADIX; // [prefix_cache, remaining_k_cache, found_bucket,
// found_remaining_k, (output_counter)]

// Align ordered values cache to 16 bytes
size_t ordered_offset = ((fixed_smem_size + 15) / 16) * 16;
OrderedType* shared_ordered = reinterpret_cast<OrderedType*>(smem + ordered_offset);

// Aliases for scalar shared variables
#define prefix_cache shared_scalars[0]
#define remaining_k_cache shared_scalars[1]
#define found_bucket shared_scalars[2]
#define found_remaining_k shared_scalars[3]
#define shared_output_counter shared_scalars[4] // Only valid when SINGLE_CTA

// State pointer only used when not SINGLE_CTA
RadixRowState* state = nullptr;
if constexpr (!SINGLE_CTA) {
state = &row_states[group_id];
}

// Calculate total number of iterations for persistent loop
uint32_t num_groups = gridDim.x / ctas_per_group;
uint32_t total_iterations = (batch_size + num_groups - 1) / num_groups;

int barrier_phase = 0;

// Persistent loop over rows
for (uint32_t iter = 0; iter < total_iterations; iter++) {
uint32_t row_idx = group_id + iter * num_groups;

if (row_idx >= batch_size) break;

const uint32_t chunk_start = cta_in_group * chunk_size;
const uint32_t chunk_end = min(chunk_start + chunk_size, vocab_size);

uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];

const uint32_t actual_chunk_size = chunk_end - chunk_start;

if (k >= vocab_size) {
// k >= vocab_size: return all indices
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
if (chunk_start + i < k) {
output_indices[row_idx * top_k_val + chunk_start + i] =
static_cast<IdType>(chunk_start + i);
if (output_values != nullptr) {
output_values[row_idx * top_k_val + chunk_start + i] =
input[row_idx * vocab_size + chunk_start + i];
}
}
}
continue;
}

// ========== Stage 1: Load and convert to ordered representation in shared memory ==========
vec_t<DType, VEC_SIZE> input_vec;
const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE;

#pragma unroll 2
for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) {
input_vec.cast_load(input + row_idx * vocab_size + chunk_start + i);
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
shared_ordered[i + j] = Traits::ToOrdered(input_vec[j]);
}
}
// Handle tail
for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) {
shared_ordered[i] = Traits::ToOrdered(input[row_idx * vocab_size + chunk_start + i]);
}
__syncthreads();

// Initialize local caches and clear global state
if (tx == 0) {
prefix_cache = 0;
remaining_k_cache = k;
if constexpr (SINGLE_CTA) {
shared_output_counter = 0; // Use shared memory counter for single CTA
}
}
__syncthreads();

// Barrier to ensure all CTAs have arrived at this iteration (skip for single CTA)
if constexpr (!SINGLE_CTA) {
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
int target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
__syncthreads();

// CTA 0 clears output counter AFTER barrier (needed for every iteration)
if (cta_in_group == 0 && tx == 0) {
st_release(&state->output_counter, 0);
}
}

// ========== Stage 2: NUM_ROUNDS of radix select ==========
// Triple-buffer optimization: only 1 barrier per round
// - Use global_round = iter * NUM_ROUNDS + round for buffer indexing
// - Only leading CTA clears next buffer before barrier
for (uint32_t round = 0; round < NUM_ROUNDS; ++round) {
uint32_t global_round = iter * NUM_ROUNDS + round;
uint32_t shift = ORDERED_BITS - (round + 1) * RADIX_BITS;
// Read from local cache (no global memory access needed!)
uint32_t prefix = prefix_cache;
uint32_t remaining_k = remaining_k_cache;

// For multi-CTA: pointers to global histograms (triple buffer)
uint32_t* current_hist = nullptr;
uint32_t* next_hist = nullptr;
if constexpr (!SINGLE_CTA) {
current_hist = state->histogram[global_round % 3];
next_hist = state->histogram[(global_round + 1) % 3];
}

// Clear local histogram only
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
local_histogram[i] = 0;
}
__syncthreads();

// Build local histogram from SHARED MEMORY (no global memory access!)
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
OrderedType ordered = shared_ordered[i];

// Check if this element matches the prefix (high bits determined so far)
OrderedType mask =
(round == 0)
? OrderedType(0)
: static_cast<OrderedType>(~OrderedType(0) << (ORDERED_BITS - round * RADIX_BITS));
if ((ordered & mask) == static_cast<OrderedType>(prefix)) {
uint32_t bucket = (ordered >> shift) & 0xFF;
atomicAdd(&local_histogram[bucket], 1);
}
}
__syncthreads();

// For multi-CTA: write → (leading CTA clears next) → barrier → read
// For single-CTA: local_histogram is already the complete histogram
if constexpr (!SINGLE_CTA) {
// Accumulate local histogram to global
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
if (local_histogram[i] > 0) {
atomicAdd(&current_hist[i], local_histogram[i]);
}
}

// Only leading CTA clears next round's histogram BEFORE barrier
if (cta_in_group == 0) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
next_hist[i] = 0;
}
}

// Barrier: wait for all CTAs to finish atomicAdd and clearing
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
int target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
__syncthreads();

// Read current histogram (after barrier, all atomicAdds are complete)
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
suffix_sum[i] = current_hist[i];
}
} else {
// Single-CTA: copy local histogram directly to suffix_sum
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
suffix_sum[i] = local_histogram[i];
}
}
__syncthreads();

// Parallel suffix sum in shared memory
for (uint32_t stride = 1; stride < RADIX; stride *= 2) {
uint32_t val = 0;
if (tx < RADIX) {
val = suffix_sum[tx];
if (tx + stride < RADIX) {
val += suffix_sum[tx + stride];
}
}
__syncthreads();
if (tx < RADIX) {
suffix_sum[tx] = val;
}
__syncthreads();
}

// ALL CTAs: find threshold bucket (all compute same result)
if (tx == 0) {
found_bucket = 0;
found_remaining_k = remaining_k;
}
__syncthreads();

if (tx < RADIX) {
uint32_t count_ge = suffix_sum[tx];
uint32_t count_gt = (tx + 1 < RADIX) ? suffix_sum[tx + 1] : 0;
if (count_ge >= remaining_k && count_gt < remaining_k) {
found_bucket = tx;
found_remaining_k = remaining_k - count_gt;
}
}
__syncthreads();

// Update local caches (all CTAs have same values)
if (tx == 0) {
prefix_cache = prefix | (found_bucket << shift);
remaining_k_cache = found_remaining_k;
}
__syncthreads();
}

// Get final ordered pivot from prefix_cache
OrderedType ordered_pivot = static_cast<OrderedType>(prefix_cache);

// ========== Stage 3: Collect indices >= pivot ==========
// Two-pass approach to handle ties correctly:
// Pass 1: collect all elements strictly > pivot (these must be in top-k)
// Pass 2: fill remaining slots with elements == pivot
//
// Optimization for Pass 1 (> pivot): Use shared memory atomic to count locally,
// then one global atomic per CTA to get base position, then shared atomic to write.
// This works because all > pivot elements are guaranteed to be in top-k.
//
// For Pass 2 (== pivot): Use global atomic directly since we need cross-CTA
// coordination to respect the k limit (some == pivot elements may be truncated).

// Reuse local_histogram[0..1] as counters
#define local_counter local_histogram[0]
#define global_base local_histogram[1]

// Pass 1: Count elements > pivot locally, then write with one global atomic
if (tx == 0) {
local_counter = 0;
}
__syncthreads();

// First pass: count how many elements > pivot in this CTA
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
OrderedType ordered_val = shared_ordered[i];
if (ordered_val > ordered_pivot) {
atomicAdd(&local_counter, 1);
}
}
__syncthreads();

// Get base position for this CTA
uint32_t cta_count_gt = local_counter;
if (tx == 0 && cta_count_gt > 0) {
if constexpr (SINGLE_CTA) {
global_base = atomicAdd(&shared_output_counter, cta_count_gt);
} else {
global_base = atomicAdd(&state->output_counter, cta_count_gt);
}
}
__syncthreads();

// Second pass: write elements > pivot using local shared atomic for position
if (tx == 0) {
local_counter = 0; // Reset for use as write position
}
__syncthreads();

if (cta_count_gt > 0) {
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
OrderedType ordered_val = shared_ordered[i];
if (ordered_val > ordered_pivot) {
uint32_t local_pos = atomicAdd(&local_counter, 1);
int pos = global_base + local_pos;
// No need to check pos < k here since all > pivot elements are in top-k
output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i);
if (output_values != nullptr) {
output_values[row_idx * top_k_val + pos] = Traits::FromOrdered(ordered_val);
}
}
}
}

// Barrier to ensure all > pivot elements are collected first (only for multi-CTA)
if constexpr (!SINGLE_CTA) {
if (tx == 0) {
red_release(&state->arrival_counter, 1);
}
int target = (barrier_phase + 1) * ctas_per_group;
wait_ge(&state->arrival_counter, target, tx);
barrier_phase++;
}
__syncthreads();

// Pass 2: Write elements == pivot
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
OrderedType ordered_val = shared_ordered[i];
if (ordered_val == ordered_pivot) {
int pos;
if constexpr (SINGLE_CTA) {
pos = atomicAdd(&shared_output_counter, 1);
} else {
pos = atomicAdd(&state->output_counter, 1);
}
if (pos < static_cast<int>(k)) {
output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i);
if (output_values != nullptr) {
output_values[row_idx * top_k_val + pos] = Traits::FromOrdered(ordered_val);
}
}
}
}

#undef local_counter
#undef global_base
// No barrier needed here - the barrier at the start of next iteration
// ensures all CTAs complete Stage 3 before output_counter is reset
}

// Clear histogram buffers and reset arrival counter for next kernel launch (only for multi-CTA)
if constexpr (!SINGLE_CTA) {
// Only leading CTA clears the buffers using release semantics
if (cta_in_group == 0) {
for (uint32_t buf = 0; buf < 3; ++buf) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
st_release(&state->histogram[buf][i], 0u);
}
}

if (tx == 0) {
st_release(&state->arrival_counter, 0);
}
}
}

#undef prefix_cache
#undef remaining_k_cache
#undef found_bucket
#undef found_remaining_k
#undef shared_output_counter
}
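
// The two-pass collection above only matters when several entries tie with the pivot: strictly
// greater elements always fit within k, and ties fill whatever slots remain (extras are dropped).
// Host-side sketch of the same policy (hypothetical helper, for intuition only):
inline uint32_t CollectTopKIndicesHost(const float* vals, uint32_t d, float pivot, uint32_t k,
                                       uint32_t* out_indices) {
  uint32_t count = 0;
  for (uint32_t i = 0; i < d && count < k; ++i) {  // pass 1: strictly greater than pivot
    if (vals[i] > pivot) out_indices[count++] = i;
  }
  for (uint32_t i = 0; i < d && count < k; ++i) {  // pass 2: ties with the pivot, up to k
    if (vals[i] == pivot) out_indices[count++] = i;
  }
  return count;  // equals k whenever pivot is the k-th largest value
}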

/*!
* \brief Launch multi-CTA Radix Top-K kernel (returns indices and optionally values)
*
* \param input Input tensor [batch_size, vocab_size]
* \param output_indices Output indices tensor [batch_size, top_k]
* \param output_values Output values tensor [batch_size, top_k] or nullptr if not needed
* \param top_k_arr Per-row top-k values or nullptr for uniform top_k
* \param batch_size Number of rows
* \param top_k_val Default top-k value (used when top_k_arr is nullptr)
* \param vocab_size Number of elements per row
* \param row_states_buffer Buffer for inter-CTA synchronization
* \param stream CUDA stream
*/
template <typename DType, typename IdType>
cudaError_t RadixTopKMultiCTA(DType* input, IdType* output_indices, DType* output_values,
IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val,
uint32_t vocab_size, RadixRowState* row_states_buffer,
cudaStream_t stream = 0) {
using OrderedType = typename RadixTopKTraits<DType>::OrderedType;
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size);

int device;
FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
int num_sms;
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device));
int max_smem_per_block;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));

// Fixed smem: histogram[256] + suffix_sum[256] + scalars (5 for single-CTA path)
constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 5);
constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16);

const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
max_chunk_elements = round_down(max_chunk_elements, vec_size);
const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
chunk_size = round_up(chunk_size, vec_size);
chunk_size = std::min(chunk_size, max_chunk_elements);

// Determine if we use single-CTA path
const bool single_cta = (ctas_per_group == 1);

// Calculate smem_size
const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType);

// Calculate number of groups (how many rows to process concurrently)
uint32_t num_groups = std::min(static_cast<uint32_t>(num_sms) / ctas_per_group, batch_size);
if (num_groups == 0) num_groups = 1;
uint32_t total_ctas = num_groups * ctas_per_group;
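  // Illustrative sizing example (assumed hardware, not measured): on an H100 SXM
  // (132 SMs, 227 KB opt-in smem per block) with fp16 input and vocab_size = 256000,
  // roughly 115K ordered 16-bit keys fit per CTA, so ctas_per_group = 3,
  // chunk_size rounds to ~85.3K keys (~167 KB of smem), and with batch_size >= 44
  // this yields num_groups = 44 and total_ctas = 132, one CTA per SM.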

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
if (single_cta) {
auto kernel = RadixTopKKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, true, DType, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
dim3 nblks(total_ctas);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&input, &output_indices, &output_values, &top_k_arr,
&top_k_val, &vocab_size, &batch_size, &row_states_buffer,
&chunk_size, &ctas_per_group};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
auto kernel = RadixTopKKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, false, DType, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
dim3 nblks(total_ctas);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&input, &output_indices, &output_values, &top_k_arr,
&top_k_val, &vocab_size, &batch_size, &row_states_buffer,
&chunk_size, &ctas_per_group};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
}
});

return cudaSuccess;
}
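
For readers wiring this launcher up outside the library, here is a minimal host-side sketch. It assumes float32 input, int32 indices, the flashinfer::sampling namespace and include path, and that one zero-initialized RadixRowState per SM is a sufficient row_states_buffer (num_groups never exceeds the SM count); none of this is the library's official setup, which allocates and caches the buffer on the Python side.

// Illustrative host-side usage sketch (not the library's official API surface):
// float32 logits, int32 indices, uniform k, one RadixRowState per SM.
#include <cstdint>
#include <cuda_runtime.h>
#include "flashinfer/sampling.cuh"  // assumed include path

cudaError_t RunRadixTopK(float* d_logits, int32_t* d_indices, float* d_values,
                         uint32_t batch_size, uint32_t vocab_size, uint32_t k,
                         cudaStream_t stream) {
  using namespace flashinfer::sampling;  // assumed namespace
  int device = 0, num_sms = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);
  // The launcher uses at most num_sms CTAs, hence at most num_sms groups, so one
  // RadixRowState per SM is a safe upper bound for the synchronization buffer.
  RadixRowState* d_row_states = nullptr;
  cudaError_t err = cudaMalloc(&d_row_states, num_sms * sizeof(RadixRowState));
  if (err != cudaSuccess) return err;
  // Cross-CTA counters and histograms must start zeroed before the first launch.
  cudaMemsetAsync(d_row_states, 0, num_sms * sizeof(RadixRowState), stream);
  err = RadixTopKMultiCTA<float, int32_t>(d_logits, d_indices, d_values,
                                          /*top_k_arr=*/nullptr, batch_size, k,
                                          vocab_size, d_row_states, stream);
  cudaStreamSynchronize(stream);  // a real caller would cache and reuse the buffer
  cudaFree(d_row_states);
  return err;
}

In practice the row-state buffer would be allocated once and reused across calls, as the Python wrapper does with its per-device cache.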

⚠️ Potential issue | 🟡 Minor

Rerun clang-format and commit the formatted changes to unblock CI
The pre-commit hook detected formatting issues in the file. Please run clang-format locally and push the updated file to resolve the CI failure.

🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 2030-2032: clang-format formatting check failed. Some files were modified by this hook.

🤖 Prompt for AI Agents
In include/flashinfer/sampling.cuh around lines 1823 to 3373 the pre-commit hook
flagged formatting issues; run your project's clang-format configuration on this
file to apply the expected style, stage the updated file, create a commit with
the formatted changes, and push it to the branch so CI/pre-commit checks pass.
Ensure you use the repo’s clang-format config (or the same pre-commit hook) so
the exact formatting rules are applied before committing.

upd

upd
@yzh119
Collaborator Author

yzh119 commented Dec 12, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !190 has been updated with latest changes, and the CI pipeline #40074508 is currently running. I'll report back once the pipeline job completes.

@yzh119
Collaborator Author

yzh119 commented Dec 12, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !190 has been updated with latest changes, and the CI pipeline #40074623 is currently running. I'll report back once the pipeline job completes.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (3)
include/flashinfer/sampling.cuh (3)

2843-2857: Same shared memory calculation bug as RadixTopKMaskLogitsMultiCTA.

Lines 2851-2852 have the identical issue:

const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

This can force chunk_size to exceed available shared memory. Apply the same fix as suggested for lines 2509-2526.


3313-3326: Same shared memory calculation bug affects this launcher too.

Lines 3320-3321:

const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

Apply the same validation fix as suggested for the other launchers.


3216-3249: Out-of-bounds writes in index collection phase.

Both passes of the index collection phase write to:

output_indices[row_idx * top_k_val + pos]
output_values[row_idx * top_k_val + pos]  // if not nullptr

But pos can reach k which may exceed top_k_val when top_k_arr != nullptr. The check at line 3245 (if (pos < static_cast<int>(k))) doesn't prevent the issue—it should check pos < top_k_val.

This is part of the same validation issue flagged for lines 2975-2992.

🧹 Nitpick comments (1)
include/flashinfer/sampling.cuh (1)

1840-1852: Comment clarification is helpful but could be more precise.

The existing review comment suggested adding clarification about the reduction. The current comment at line 1844-1845 states "Release pattern using acq_rel fence + relaxed modifier" but doesn't explicitly mention that this also performs an atomic add (reduction). Consider expanding:

 __device__ __forceinline__ void red_release(int* ptr, int val) {
 #if (__CUDA_ARCH__ >= 700)
   // SM70 and newer use memory consistency qualifiers
-  // Release pattern using acq_rel fence + relaxed modifier
+  // Release pattern using acq_rel fence + relaxed atomic add (reduction + release)
   // (The fence also releases data that was weakly-written by other threads prior to the last
   // syncthreads)

Based on learnings, this addresses the past review comment suggesting clarification.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3843536 and c912dec.

📒 Files selected for processing (1)
  • include/flashinfer/sampling.cuh (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/sampling.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
include/flashinfer/sampling.cuh (6)

119-204: Type traits implementation looks solid.

The RadixTopKTraits specializations correctly implement IEEE 754 bit manipulation for descending order:

  • Positive values: flip sign bit to maintain order
  • Negative values: flip all bits to reverse order
  • The negative infinity constants are correct for each format

The num_rounds() template function appropriately computes the number of radix iterations based on the ordered type size and radix bits.
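
For concreteness, a minimal sketch of the float32 ordered-key transform described above is given here; the names are illustrative and the actual RadixTopKTraits specializations may differ in detail (fp16/bf16 apply the analogous trick to 16-bit keys).

#include <cstdint>

// Monotone float32 -> uint32 mapping: larger floats map to larger unsigned keys,
// so the radix passes can pick the top-k out of the highest buckets.
__device__ __forceinline__ uint32_t ToOrderedF32(float x) {
  uint32_t bits = __float_as_uint(x);
  // non-negative: set the sign bit so positives order above all negatives;
  // negative: flip every bit so more-negative values get smaller keys.
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

__device__ __forceinline__ float FromOrderedF32(uint32_t key) {
  // Invert the transform: keys with the MSB set came from non-negative floats.
  uint32_t bits = (key & 0x80000000u) ? (key ^ 0x80000000u) : ~key;
  return __uint_as_float(bits);
}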


1826-1875: Synchronization primitives are correctly implemented.

The acquire/release memory ordering for SM70+ and conservative fallbacks for older architectures are appropriate. The wait_ge function correctly has only thread 0 spin to avoid warp divergence overhead, followed by __syncthreads() to broadcast the result.


1880-1887: RadixRowState design is efficient.

The triple-buffered histograms enable pipelining without extra synchronization barriers, which is a good optimization. The struct size (~3KB per state) is reasonable for the coordination it provides.



1900-2189: Device helper functions are well-structured and correct.

The radix selection algorithm is properly decomposed:

  • RadixSuffixSum: Efficient parallel suffix sum
  • RadixFindThresholdBucket: Correct threshold detection
  • RadixBuildLocalHistogram: Proper prefix matching with atomic operations
  • RadixSelectOneRound: Complex but correct multi-CTA coordination with triple buffering
  • RadixSelectFindPivot: Clean entry point

The triple-buffering pattern (using global_round % 3) with the leading CTA clearing the next buffer before the barrier is a clever optimization that avoids race conditions.


2634-2713: Fast-path optimization for k >= vocab_size is well-implemented.

The sum-and-normalize fast path correctly handles the case where no filtering is needed. The multi-CTA sum aggregation with proper barrier synchronization (initialize → barrier → atomic adds → barrier → read) is correct.


2321-2442: Inline radix select implementation is correct and well-optimized.

The kernel inlines the radix selection logic with local caches (prefix_cache, remaining_k_cache) to avoid repeated global memory access. The triple-buffering with global_round = iter * NUM_ROUNDS + round correctly handles the persistent loop over multiple rows. The barrier choreography (write → leading CTA clears → barrier → read) matches the helper function and is race-free.

Comment on lines +2509 to +2526
// Fixed shared memory overhead: histogram[256] + suffix_sum[256] + 4 scalars
constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 4);
constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16);

// Calculate max chunk size that fits in shared memory
const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
max_chunk_elements = round_down(max_chunk_elements, vec_size);
const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
chunk_size = round_up(chunk_size, vec_size);
chunk_size = std::min(chunk_size, max_chunk_elements);

const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType);
const bool single_cta = (ctas_per_group == 1);

⚠️ Potential issue | 🔴 Critical

Critical: Shared memory calculation can exceed available capacity.

The logic at lines 2517-2518 can inflate max_chunk_elements beyond the available shared memory:

const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

If available_for_ordered is small (e.g., 10K elements worth), but min_chunk_size is large (e.g., 16K elements for vec_size=16, BLOCK_THREADS=1024), then max_chunk_elements becomes 16K. Subsequently:

uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
chunk_size = round_up(chunk_size, vec_size);
chunk_size = std::min(chunk_size, max_chunk_elements);  // Still 16K!

The final smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType) will exceed max_smem_per_block, causing cudaErrorInvalidConfiguration at kernel launch.

Apply the fix suggested in the past review:

   const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
   uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
   max_chunk_elements = round_down(max_chunk_elements, vec_size);
   const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
-  max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
+  if (max_chunk_elements < min_chunk_size) {
+    return cudaErrorInvalidConfiguration;
+  }

   uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
   uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
   chunk_size = round_up(chunk_size, vec_size);
-  chunk_size = std::min(chunk_size, max_chunk_elements);

Also applies to: lines 2843-2857 (RadixTopKRenormProbMultiCTA) and lines 3313-3326 (RadixTopKMultiCTA).

Based on learnings, this matches the critical issue flagged by gemini-code-assist in past reviews.

🤖 Prompt for AI Agents
include/flashinfer/sampling.cuh lines 2509-2526: the code allows
max_chunk_elements to be raised above the actual available_for_ordered by taking
std::max with min_chunk_size; to fix, compute available_limit =
round_down(available_for_ordered / sizeof(OrderedType), vec_size) first and never
let max_chunk_elements exceed it: either return cudaErrorInvalidConfiguration when
available_limit < min_chunk_size, or clamp chunk_size to available_limit, so that
smem_size cannot exceed max_smem_per_block; apply the same clamping change to the
analogous blocks at lines 2843-2857 and 3313-3326.

Comment on lines 2975 to 2992
uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];

const uint32_t actual_chunk_size = chunk_end - chunk_start;

if (k >= vocab_size) {
// k >= vocab_size: return all indices
for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
if (chunk_start + i < k) {
output_indices[row_idx * top_k_val + chunk_start + i] =
static_cast<IdType>(chunk_start + i);
if (output_values != nullptr) {
output_values[row_idx * top_k_val + chunk_start + i] =
input[row_idx * vocab_size + chunk_start + i];
}
}
}
continue;
}

⚠️ Potential issue | 🔴 Critical

Critical: Out-of-bounds writes when per-row k exceeds top_k_val.

At line 2975, k is read from top_k_arr[row_idx] which can differ from top_k_val:

uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];

However, output buffers are allocated with stride top_k_val:

  • output_indices[row_idx * top_k_val + pos] (lines 2983, 3216, 3246)
  • output_values[row_idx * top_k_val + pos] (lines 2986, 3218, 3248)

If any top_k_arr[row_idx] > top_k_val, writes will overflow into the next row's buffer. For example, at lines 2982-2989:

if (chunk_start + i < k) {
  output_indices[row_idx * top_k_val + chunk_start + i] = ...

When k > top_k_val and chunk_start + i >= top_k_val, this writes beyond the allocated row.

Add validation in the host launcher (lines 3297-3364):

template <typename DType, typename IdType>
cudaError_t RadixTopKMultiCTA(DType* input, IdType* output_indices, DType* output_values,
                              IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val,
                              uint32_t vocab_size, RadixRowState* row_states_buffer,
                              cudaStream_t stream = 0) {
  using OrderedType = typename RadixTopKTraits<DType>::OrderedType;
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size);
  
  // Validate top_k_arr if provided
  if (top_k_arr != nullptr) {
    // Check that all top_k_arr[i] <= top_k_val
    // For now, document the requirement or add host-side validation
    // Kernel cannot safely validate this without extra overhead
  }

Alternatively, change the output buffer allocation to use max(top_k_arr) instead of top_k_val, but that requires upstream changes.

Also applies to lines 2982-2989 (same kernel, k >= vocab_size case).

Based on learnings, this matches the major issue flagged by coderabbitai in past reviews.

🤖 Prompt for AI Agents
In include/flashinfer/sampling.cuh around lines 2975-2992 there is an
out-of-bounds write when per-row k is read from top_k_arr[row_idx] but output
rows are laid out with stride top_k_val; if top_k_arr[row] > top_k_val writes
overflow the next row. Fix by validating top_k_arr in the host launcher
(RadixTopKMultiCTA around lines 3297-3364): if top_k_arr != nullptr iterate the
array on the host to ensure every top_k_arr[i] <= top_k_val and return a
cudaError / fail early (or document and assert) when violated; alternatively
change the allocation/stride to use max(top_k_arr) instead of top_k_val if
upstream allocation can be adjusted, but prefer the host-side validation and
explicit error return to prevent kernel OOB writes.
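
A hedged sketch of that host-side guard follows (it assumes top_k_arr resides in device memory and so must be copied back before inspection; the copy-and-sync cost is the price of the check, and all names here are illustrative rather than the committed fix):

// Illustrative guard: reject launches whose per-row k would overflow the
// top_k_val-strided output rows.
#include <cstdint>
#include <vector>
#include <cuda_runtime.h>

template <typename IdType>
cudaError_t ValidateTopKArr(const IdType* d_top_k_arr, uint32_t batch_size,
                            uint32_t top_k_val, cudaStream_t stream) {
  if (d_top_k_arr == nullptr) return cudaSuccess;  // uniform k, nothing to check
  std::vector<IdType> h_top_k(batch_size);
  cudaError_t err = cudaMemcpyAsync(h_top_k.data(), d_top_k_arr,
                                    batch_size * sizeof(IdType),
                                    cudaMemcpyDeviceToHost, stream);
  if (err != cudaSuccess) return err;
  err = cudaStreamSynchronize(stream);
  if (err != cudaSuccess) return err;
  for (uint32_t i = 0; i < batch_size; ++i) {
    if (static_cast<uint32_t>(h_top_k[i]) > top_k_val) {
      return cudaErrorInvalidValue;  // per-row k must not exceed the output stride
    }
  }
  return cudaSuccess;
}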

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/logits_processor/operators.py (1)

121-142: Cast maybe_top_k_arr to int32 before calling the native module.
These operators call get_sampling_module().top_k_{...} directly (Line 139-141, 185-187) but don’t apply the .int() cast that exists in flashinfer/sampling.py. If a user passes top_k as an int64 tensor, this can break or produce wrong reads in C++.

         top_k = self._get_param("top_k", kwargs, required=True)
         maybe_top_k_arr, top_k_val = _to_tensor_scalar_tuple(top_k)
@@
+        if maybe_top_k_arr is not None:
+            maybe_top_k_arr = maybe_top_k_arr.int()
@@
         renorm_probs = get_sampling_module().top_k_renorm_probs(
             tensor.data, maybe_top_k_arr, top_k_val, row_states_buffer
         )

Apply the same in LogitsTopKOp.__call__.

Also applies to: 167-187

♻️ Duplicate comments (4)
flashinfer/topk.py (1)

69-165: Concurrency hazard: cached row_states_buffer is shared across streams on the same device.
row_states_buffer is a single per-device cache entry (Line 138-145), but kernels mutate it for inter-CTA sync. Concurrent top_k() calls on different CUDA streams can trample each other.

Concrete options:

  • include the current stream in the cache key (best-effort in Python), or
  • allocate per-call when concurrency is expected, or
  • clearly document “single-stream per device” usage constraints.
include/flashinfer/sampling.cuh (2)

2509-2524: Shared-memory sizing bug can force invalid configurations (must clamp, don’t std::max).
These blocks can inflate max_chunk_elements beyond what available_for_ordered supports (Line 2518-2519, 2851-2853, 3314-3316), which can make smem_size exceed cudaDevAttrMaxSharedMemoryPerBlockOptin and fail launches.

Use the pattern from the prior review: if max_chunk_elements < min_chunk_size, return cudaErrorInvalidConfiguration (or reduce BLOCK_THREADS/vec) rather than forcing max_chunk_elements upward.

Also applies to: 2843-2857, 3311-3321


2975-2988: Potential OOB writes when top_k_arr[row] > top_k_val (validate contract in launcher).
Kernel indexes outputs with stride top_k_val but uses per-row k from top_k_arr (Line 2975, writes at Line 2983-2986 and Line 3214-3216 / 3242-3244). If any k > top_k_val, this will write past the row’s allocation.

Fix: in RadixTopKMultiCTA(...) (Line 3290+), if top_k_arr != nullptr, validate max(top_k_arr) <= top_k_val on host (or document + assert in debug builds) before launching.

Also applies to: 3213-3216, 3241-3244, 3290-3358

flashinfer/logits_processor/operators.py (1)

132-138: Same multi-stream race risk: cached row_states_buffer is shared per device.
Even with op-specific names, each op still shares a single buffer across streams for that op+device. Concurrent calls can collide.

Also applies to: 178-184

🧹 Nitpick comments (2)
flashinfer/topk.py (1)

54-63: Static analysis: unused args in fake op — either prefix with _ or ignore in linter config.
Ruff flags row_states_buffer/output_values as unused in _fake_radix_topk (Line 55-62). Consider renaming to _row_states_buffer / _output_values.

include/flashinfer/sampling.cuh (1)

1866-1875: Avoid hard spin in wait_ge (reduce contention / TDR risk).
while (ld_acquire(ptr) < target_val) {} (Line 1871-1872) can saturate memory and SM issue slots. Consider inserting __nanosleep() / exponential backoff, and/or a max-iteration safety break in debug builds.
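
One possible shape of that backoff, sketched under the assumption that ld_acquire is the existing acquire-load helper in sampling.cuh and that only thread 0 spins before a block-wide broadcast (__nanosleep needs sm_70 or newer, hence the guard):

// Sketch only: exponential backoff between acquire-loads to ease memory pressure.
__device__ __forceinline__ void wait_ge_backoff(int* ptr, int target_val, uint32_t tx) {
  if (tx == 0) {
    unsigned ns = 32;  // start short, back off up to ~1 us
    while (ld_acquire(ptr) < target_val) {
#if (__CUDA_ARCH__ >= 700)
      __nanosleep(ns);
      if (ns < 1024u) ns *= 2;
#endif
    }
  }
  __syncthreads();  // broadcast the satisfied wait to the rest of the block
}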

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c912dec and 00b8bcf.

📒 Files selected for processing (6)
  • csrc/flashinfer_topk_binding.cu (1 hunks)
  • csrc/topk.cu (1 hunks)
  • flashinfer/logits_processor/operators.py (2 hunks)
  • flashinfer/topk.py (1 hunks)
  • include/flashinfer/sampling.cuh (2 hunks)
  • tests/utils/test_logits_processor.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/flashinfer_topk_binding.cu
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/sampling.cuh
🧬 Code graph analysis (3)
flashinfer/logits_processor/operators.py (2)
flashinfer/utils.py (1)
  • _get_cache_buf (206-217)
flashinfer/sampling.py (4)
  • top_k_renorm_probs (375-394)
  • top_k_renorm_probs (1320-1390)
  • top_k_mask_logits (410-430)
  • top_k_mask_logits (1397-1465)
csrc/topk.cu (3)
flashinfer/topk.py (1)
  • top_k (69-164)
flashinfer/comm/cuda_ipc.py (2)
  • cudaSetDevice (149-150)
  • cudaGetErrorString (146-147)
csrc/tvm_ffi_utils.h (1)
  • get_stream (294-296)
flashinfer/topk.py (4)
flashinfer/jit/topk.py (1)
  • gen_topk_module (21-28)
flashinfer/utils.py (1)
  • _get_cache_buf (206-217)
csrc/flashinfer_topk_binding.cu (1)
  • radix_topk (20-21)
csrc/topk.cu (2)
  • radix_topk (24-60)
  • radix_topk (24-25)
🪛 Ruff (0.14.8)
flashinfer/topk.py

58-58: Unused function argument: row_states_buffer

(ARG001)


59-59: Unused function argument: output_values

(ARG001)

🔇 Additional comments (3)
tests/utils/test_logits_processor.py (2)

642-642: LGTM: Appropriate tolerance for floating-point probability comparisons.

The change from strict equality to torch.allclose is reasonable given the multi-CTA optimizations and expanded data-type support introduced in this PR. Default tolerances (rtol=1e-5, atol=1e-8) are appropriate for comparing normalized probabilities.


666-666: Verify the necessity of equal_nan=True in top-k logits masking assertion.

The test includes equal_nan=True when comparing LogitsPipe's TopK implementation against the reference flashinfer.sampling.top_k_mask_logits. Both implementations are documented to set non-top-k values to -inf, yet the parameter suggests NaN values may be produced. Given that this test parametrizes neginf_input to specifically test edge cases with -inf values in the logits, clarify whether:

  • The multi-CTA kernel produces NaN in specific edge cases (e.g., all logits are -inf, fewer than k finite values)
  • NaN production is intentional behavior or an unintended side effect
  • equal_nan=True is a temporary workaround or the correct long-term solution

If NaN production is unintended, the kernel should be fixed to consistently produce -inf instead.

include/flashinfer/sampling.cuh (1)

119-204: Traits look consistent; good foundation for fp16/bf16 support.
The ordered transforms and NegInf() encodings are sensible and keep the radix code generic across float/half/bf16.

Comment on lines +24 to +35
void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
CHECK_INPUT(input);
CHECK_INPUT(output_indices);
CHECK_INPUT(output_values);
CHECK_DIM(2, input); // input: (batch_size, d)
CHECK_DIM(2, output_indices); // output_indices: (batch_size, top_k)
CHECK_DIM(2, output_values); // output_values: (batch_size, top_k)

unsigned int batch_size = input.size(0);
unsigned int d = input.size(1);


⚠️ Potential issue | 🟠 Major

Add basic shape + top_k range validation (prevents kernel OOB / nonsense launches).
Only rank is validated today. At minimum, enforce:

  • top_k > 0
  • top_k <= input.size(1)
  • output_{indices,values}.size(1) == top_k
  • batch dims match across tensors

This avoids silent misbehavior when callers pass a mismatched k vs output shapes.

🤖 Prompt for AI Agents
In csrc/topk.cu around lines 24 to 35, add explicit shape and top_k range
validation: check top_k > 0 and top_k <= input.size(1), verify
output_indices.size(0) and output_values.size(0) equal input.size(0), verify
output_indices.size(1) and output_values.size(1) equal top_k, and ensure
maybe_row_states_buffer (if present) has a compatible batch dimension; use the
same CHECK_* macros (or error paths) already used in the file so failures are
reported consistently and prevent launching kernels with out-of-bounds sizes.

Comment on lines +24 to +60
void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
CHECK_INPUT(input);
CHECK_INPUT(output_indices);
CHECK_INPUT(output_values);
CHECK_DIM(2, input); // input: (batch_size, d)
CHECK_DIM(2, output_indices); // output_indices: (batch_size, top_k)
CHECK_DIM(2, output_values); // output_values: (batch_size, top_k)

unsigned int batch_size = input.size(0);
unsigned int d = input.size(1);

cudaSetDevice(input.device().device_id);
auto stream = get_stream(input.device());

cudaError_t status;
auto dtype = input.dtype();

// Get row_states_buffer if provided (for multi-CTA path)
sampling::RadixRowState* row_states_ptr = nullptr;
if (maybe_row_states_buffer.has_value()) {
row_states_ptr =
static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
}

DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
status = sampling::RadixTopKMultiCTA<c_type, int32_t>(
static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
static_cast<c_type*>(output_values.data_ptr()),
nullptr, // top_k_arr
batch_size, static_cast<uint32_t>(top_k), d, row_states_ptr, stream);
return true;
});

TVM_FFI_ICHECK(status == cudaSuccess)
<< "RadixTopK failed with error code " << cudaGetErrorString(status);
}

⚠️ Potential issue | 🔴 Critical

Fix BF16 dispatch + initialize status to avoid UB / wrong dtype behavior.
Right now the kernel dispatch only covers FP32/FP16 (Line 49), but the Python API advertises BF16 support. Also status is uninitialized if dispatch doesn’t run, and the final TVM_FFI_ICHECK(status == cudaSuccess) becomes undefined behavior.

 void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
                 Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
@@
-  cudaSetDevice(input.device().device_id);
+  TVM_FFI_ICHECK(cudaSetDevice(input.device().device_id) == cudaSuccess);
   auto stream = get_stream(input.device());
 
-  cudaError_t status;
+  cudaError_t status = cudaErrorInvalidValue;
   auto dtype = input.dtype();
@@
-  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16_BF16(dtype, c_type, [&] {
     status = sampling::RadixTopKMultiCTA<c_type, int32_t>(
@@
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "RadixTopK failed with error code " << cudaGetErrorString(status);
 }

If there is no ..._BF16 dispatch macro available, you should either add it (preferred, given nv_bfloat16 support in include/flashinfer/sampling.cuh) or explicitly reject bf16 here with a clear error before launch.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In csrc/topk.cu around lines 24 to 60, status is left uninitialized and the
dtype dispatch only covers FP32/FP16, so BF16 calls lead to UB or wrong
behavior; initialize status (e.g., to cudaErrorInvalidValue) before the dispatch
to avoid undefined reads, then extend the dispatch to include BF16 by using the
BF16-capable dispatch macro (or add a
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16_BF16 variant that maps nv_bfloat16 to
the correct c_type and launches sampling::RadixTopKMultiCTA for bf16), and if
you cannot add a BF16 dispatch, explicitly check for bf16 and return a clear
error (set status to an appropriate cudaError and fail fast) before the kernel
launch so the final TVM_FFI_ICHECK sees a defined value.

Comment on lines +31 to +53
@register_custom_op(
"flashinfer::radix_topk", mutates_args=("row_states_buffer", "output_values")
)
def radix_topk(
input: torch.Tensor,
top_k: int,
row_states_buffer: Optional[torch.Tensor],
output_values: torch.Tensor,
) -> torch.Tensor:
device = input.device
# Supports float32, float16, bfloat16
assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], (
f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16"
)
batch_size = input.size(0)
output_indices = torch.empty(
batch_size, top_k, dtype=torch.int32, device=device
)
module.radix_topk(
input, output_indices, output_values, row_states_buffer, top_k
)
return output_indices


⚠️ Potential issue | 🟠 Major

Replace assert dtype guard with a real exception.
assert can be stripped with python -O, turning invalid dtypes into hard-to-debug native crashes.

-        assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], (
-            f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16"
-        )
+        if input.dtype not in (torch.float32, torch.float16, torch.bfloat16):
+            raise TypeError(
+                f"Unsupported dtype {input.dtype}, expected float32/float16/bfloat16"
+            )
🤖 Prompt for AI Agents
In flashinfer/topk.py around lines 31 to 53, replace the assert-based dtype
guard with an explicit runtime exception: check if input.dtype is one of
(torch.float32, torch.float16, torch.bfloat16) and if not raise a TypeError (or
ValueError) with the existing descriptive message (e.g. f"Unsupported dtype
{input.dtype}, expected float32, float16, or bfloat16"); do not use assert so
the check remains active under optimized execution.

Comment on lines +69 to +165
def top_k(
input: torch.Tensor,
k: int,
sorted: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Radix-based Top-K selection.
This function selects the top-k largest elements from each row of the input
tensor. It uses an efficient radix-based selection algorithm that is
particularly fast for large vocabularies.
This is designed as a drop-in replacement for ``torch.topk`` with better
performance for large tensors (vocab_size > 10000).
Parameters
----------
input : torch.Tensor
Input tensor of shape ``(batch_size, d)`` containing the values to select from.
Supported dtypes: ``float32``, ``float16``, ``bfloat16``.
k : int
Number of top elements to select from each row.
sorted : bool, optional
If True, the returned top-k elements will be sorted in descending order.
Default is False (unsorted, which is faster).
Returns
-------
values : torch.Tensor
Tensor of shape ``(batch_size, k)`` containing the top-k values.
Same dtype as input.
indices : torch.Tensor
Tensor of shape ``(batch_size, k)`` with int64 dtype containing the
indices of the top-k elements.
Note
----
- Unlike ``torch.topk``, the default behavior returns unsorted results for
better performance. Set ``sorted=True`` if you need sorted output.
- The radix-based algorithm is O(n) in vocabulary size, compared to O(n log k)
for heap-based methods, making it faster for large vocabularies.
- For small vocabularies (< 1000), ``torch.topk`` may be faster.
Examples
--------
>>> import torch
>>> import flashinfer
>>> torch.manual_seed(42)
>>> batch_size = 4
>>> vocab_size = 32000
>>> k = 256
>>> logits = torch.randn(batch_size, vocab_size, device="cuda")
>>> values, indices = flashinfer.top_k(logits, k)
>>> values.shape, indices.shape
(torch.Size([4, 256]), torch.Size([4, 256]))
With sorting enabled (for compatibility with torch.topk):
>>> values_sorted, indices_sorted = flashinfer.top_k(logits, k, sorted=True)
>>> # Values are now in descending order within each row
See Also
--------
torch.topk : PyTorch's built-in top-k function
sampling.top_k_mask_logits : Top-k masking for logits (sets non-top-k to -inf)
sampling.top_k_renorm_probs : Top-k filtering and renormalization for probabilities
"""
batch_size = input.size(0)
device = input.device

# Allocate row_states buffer for multi-CTA path
# 1MB is enough for any reasonable GPU (covers up to ~500 groups)
row_states_buffer: Optional[torch.Tensor] = _get_cache_buf(
f"radix_topk_row_states_{input.device}",
1024 * 1024, # 1MB
input.device,
zero_init=True,
)

# Allocate output_values for kernel to write directly
output_values = torch.empty(batch_size, k, dtype=input.dtype, device=device)

# Get indices using radix-based selection
indices_int32 = get_topk_module().radix_topk(
input, k, row_states_buffer, output_values
)

# Convert to int64 for compatibility
indices = indices_int32.long()

if sorted:
# Sort within each row by value (descending)
sorted_values, sort_indices = torch.sort(output_values, dim=-1, descending=True)
sorted_indices = torch.gather(indices, dim=-1, index=sort_indices)
return sorted_values, sorted_indices

return output_values, indices


⚠️ Potential issue | 🟡 Minor

Validate input rank + k bounds in Python (fast fail, better errors).
Add checks for input.ndim == 2, k > 0, and k <= input.size(1) before allocating outputs.

🤖 Prompt for AI Agents
In flashinfer/topk.py around lines 69 to 165, add an early fast-fail input
validation before any buffer/allocation: verify input.ndim == 2 and raise a
ValueError with a clear message if not; verify k is an int > 0 and raise
ValueError if not; verify k <= input.size(1) and raise ValueError if exceeded.
Perform these checks immediately after computing batch_size/device (or at the
top of the function) so allocations and kernel calls are skipped on invalid
inputs, using concise, user-friendly error messages.

Comment on lines 821 to 823
# Allow small differences due to floating point precision in intermediate steps
diff_ratio = (samples_pipe != samples_direct).sum().item() / batch_size
assert diff_ratio < 0.01, f"Too many differences: {diff_ratio * 100:.2f}%"

⚠️ Potential issue | 🟡 Minor

Consider tightening the tolerance for deterministic sampling tests.

Allowing up to 1% of samples to differ seems lenient for a test using deterministic generators. While floating-point precision differences in intermediate probability normalization steps can affect boundary cases, consider:

  1. Tightening the threshold: For batch_size=989, this allows ~10 samples to differ. Could this be reduced to 0.1% (< 1 sample per 1000)?
  2. Adding diagnostics: Log which samples differ and their probability differences to distinguish genuine floating-point edge cases from correctness issues.
  3. Documenting expectations: Expand the comment to explain specific scenarios where differences are expected (e.g., ties near the top-k or top-p boundaries).

Apply this diff to add diagnostic output:

     # Allow small differences due to floating point precision in intermediate steps
     diff_ratio = (samples_pipe != samples_direct).sum().item() / batch_size
-    assert diff_ratio < 0.01, f"Too many differences: {diff_ratio * 100:.2f}%"
+    assert diff_ratio < 0.01, (
+        f"Too many differences: {diff_ratio * 100:.2f}% "
+        f"({(samples_pipe != samples_direct).sum().item()} out of {batch_size} samples differ)"
+    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Allow small differences due to floating point precision in intermediate steps
diff_ratio = (samples_pipe != samples_direct).sum().item() / batch_size
assert diff_ratio < 0.01, f"Too many differences: {diff_ratio * 100:.2f}%"
# Allow small differences due to floating point precision in intermediate steps
diff_ratio = (samples_pipe != samples_direct).sum().item() / batch_size
assert diff_ratio < 0.01, (
f"Too many differences: {diff_ratio * 100:.2f}% "
f"({(samples_pipe != samples_direct).sum().item()} out of {batch_size} samples differ)"
)
🤖 Prompt for AI Agents
In tests/utils/test_logits_processor.py around lines 821 to 823, the
deterministic sampling test tolerance is too permissive (1%); tighten it to 0.1%
(0.001) to reduce allowed differing samples, add diagnostic output that logs the
indices of differing samples and a small summary of their probability
differences (e.g., count, max diff, example diffs) to help distinguish FP edge
cases from logic bugs, and expand the existing comment to explain when small
differences are expected (ties near top-k/top-p boundaries and floating-point
normalization rounding). Ensure the assertion message includes the diagnostic
summary when it fails so CI shows which samples differ.

@yzh119
Collaborator Author

yzh119 commented Dec 12, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !190 has been updated with latest changes, and the CI pipeline #40075621 is currently running. I'll report back once the pipeline job completes.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
tests/utils/test_logits_processor.py (1)

821-823: The 2% mismatch allowance is very permissive; tighten + add diagnostics
This is the same concern previously raised (then 1%); increasing to 2% further reduces the test’s ability to catch regressions. Add failure diagnostics so CI tells you which samples differ.

         # Allow small differences due to floating point precision in intermediate steps
         diff_ratio = (samples_pipe != samples_direct).sum().item() / batch_size
-        assert diff_ratio < 0.02, f"Too many differences: {diff_ratio * 100:.2f}%"
+        diff_cnt = (samples_pipe != samples_direct).sum().item()
+        diff_idx = (samples_pipe != samples_direct).nonzero(as_tuple=False).flatten()
+        assert diff_ratio < 0.001, (
+            f"Too many differences: {diff_ratio * 100:.2f}% "
+            f"({diff_cnt}/{batch_size}); first_diff_idx={diff_idx[:10].tolist()}"
+        )
🧹 Nitpick comments (2)
tests/utils/test_logits_processor.py (2)

642-642: Prefer explicit tolerances (and/or torch.testing.assert_close) instead of default torch.allclose
Default rtol/atol can be surprisingly strict/loose depending on dtype/device and PyTorch version; this can turn into CI flakiness or masked regressions.

-        assert torch.allclose(samples_pipe, samples_direct)
+        torch.testing.assert_close(samples_pipe, samples_direct, rtol=1e-4, atol=1e-6)

666-666: For logits masking, validate mask pattern explicitly (don’t only rely on allclose)
If this is validating top_k_mask_logits, it’s useful to assert that the “masked positions” (e.g., -inf) match exactly, then use assert_close for the remaining finite values.

-        assert torch.allclose(samples_pipe, samples_direct, equal_nan=True)
+        # masked pattern should match exactly (e.g., -inf locations)
+        assert torch.equal(torch.isneginf(samples_pipe), torch.isneginf(samples_direct))
+        # remaining finite values should be close
+        torch.testing.assert_close(
+            samples_pipe[~torch.isneginf(samples_pipe)],
+            samples_direct[~torch.isneginf(samples_direct)],
+            rtol=1e-4,
+            atol=1e-6,
+        )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 00b8bcf and 63f07f1.

📒 Files selected for processing (1)
  • tests/utils/test_logits_processor.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

@yzh119 yzh119 merged commit f6a9899 into flashinfer-ai:main Dec 12, 2025
4 checks passed