perf: bunch of features and optimizations for top-k (sampling + sparse attention) #2119
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Threads a per-row device-side row_states_buffer through Python wrappers, TVM bindings, CUDA callers, and kernels; adds radix multi-CTA Top-K kernels, inter-CTA sync primitives, dtype dispatch helpers, JIT wiring, cache buffer options, and expanded tests for FP32/FP16/BF16 and large-batch cases.
Sequence Diagram(s)
sequenceDiagram
participant User as Python caller
participant PyAPI as flashinfer.topk / sampling wrapper
participant JIT as JIT module loader
participant TVMBind as TVM CUDA binding
participant Kernel as Radix Multi‑CTA Kernel (device)
participant RowState as RadixRowState (device buffer)
User->>PyAPI: call top_k / top_k_renorm_probs / top_k_mask_logits
alt no row_states_buffer
PyAPI->>PyAPI: allocate row_states_buffer (~1MB) via _get_cache_buf
end
PyAPI->>JIT: ensure topk module loaded / get binding
JIT->>TVMBind: call radix_topk / module.top_k_* (ptrs, maybe_row_states_buffer, k)
TVMBind->>TVMBind: validate shapes/dtypes, select device & stream
TVMBind->>Kernel: launch RadixTopKMultiCTA / RadixTopKRenormProbMultiCTA / RadixTopKMaskLogitsMultiCTA(...)
Kernel->>RowState: initialize/use per-row RadixRowState, CTAs coordinate via atomics
Kernel-->>TVMBind: write outputs (indices, values, or renormed/masked arrays)
TVMBind-->>PyAPI: return results
PyAPI-->>User: postprocess (dtype conv, optional sort) and return
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a significant performance enhancement for the top-k/top-p renormalization and sampling routines. By refactoring the CUDA kernels to leverage a multi-CTA architecture and prioritizing shared memory usage for logits and probabilities, the changes aim to drastically reduce memory latency and improve throughput, especially for operations involving multiple scan rounds. This optimization builds upon previous work and focuses on efficient parallel processing across the GPU.
Code Review
This pull request introduces a multi-CTA optimization for top-k/top-p sampling, which is a significant performance enhancement. The implementation is well-structured, leveraging advanced CUDA features like inter-CTA synchronization via atomic operations and memory fences to efficiently process large vocabularies. The changes in the Python bindings and C++ interface are consistent with the new kernel's requirements. However, I've identified a critical bug in the kernel launch configuration logic within TopKMaskLogitsMultiCTA. The calculation for chunk_size can lead to requesting more shared memory than available, which would cause kernel launch failures under certain conditions. I have provided a detailed comment with a suggested fix for this issue. Overall, this is a great performance improvement, and with the suggested fix, it should be robust.
include/flashinfer/sampling.cuh (Outdated)
constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

// Calculate how many CTAs needed per row
uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
// Round up chunk_size to multiple of VEC_SIZE
chunk_size = round_up(chunk_size, VEC_SIZE);
// Ensure minimum chunk size
chunk_size = std::max(chunk_size, min_chunk_size);
The logic for calculating chunk_size is incorrect and can lead to requesting more shared memory than is available, causing a kernel launch failure.
Specifically:
- max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); at line 2468 can inflate max_chunk_elements beyond the available shared memory.
- chunk_size = std::max(chunk_size, min_chunk_size); at line 2476 can similarly cause chunk_size to exceed shared memory limits, as it ignores the max_chunk_elements constraint.
This can happen if the available shared memory is small, making max_chunk_elements smaller than min_chunk_size.
I suggest replacing this block with logic that validates against min_chunk_size instead of forcing it, to ensure the kernel configuration is always valid.
constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
if (max_chunk_elements < min_chunk_size) {
// Not enough shared memory for even the minimum chunk size.
return cudaErrorInvalidConfiguration;
}
// Calculate how many CTAs needed per row
uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
// Round up chunk_size to multiple of VEC_SIZE
chunk_size = round_up(chunk_size, VEC_SIZE);
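To make the failure mode concrete, here is a small numeric sketch in Python. The shared-memory budget, fixed overhead, and dtype size below are made-up values chosen to be small; the real numbers come from cudaDeviceGetAttribute and the kernel's temp-storage layout.

# Hypothetical numbers for illustration only.
VEC_SIZE, BLOCK_THREADS = 8, 1024
sizeof_dtype = 2                      # e.g. fp16
max_smem_per_block = 48 * 1024        # a low opt-in shared-memory budget
temp_storage_aligned = 40 * 1024      # fixed per-CTA overhead

def ceil_div(a, b): return -(-a // b)
def round_up(x, m): return ceil_div(x, m) * m

min_chunk_size = VEC_SIZE * BLOCK_THREADS                                          # 8192 elements
max_chunk_elements = (max_smem_per_block - temp_storage_aligned) // sizeof_dtype   # 4096 elements fit
max_chunk_elements = max(max_chunk_elements, min_chunk_size)                       # forced up to 8192

vocab_size = 128256
ctas_per_group = ceil_div(vocab_size, max_chunk_elements)
chunk_size = max(round_up(ceil_div(vocab_size, ctas_per_group), VEC_SIZE), min_chunk_size)

smem_size = temp_storage_aligned + chunk_size * sizeof_dtype
print(smem_size > max_smem_per_block)   # True -> cudaFuncSetAttribute / launch would fail

With the early-return fix above, this configuration is rejected up front instead of failing at launch time.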
Actionable comments posted: 1
📒 Files selected for processing (6)
- csrc/flashinfer_sampling_binding.cu (1 hunks)
- csrc/renorm.cu (1 hunks)
- flashinfer/sampling.py (3 hunks)
- flashinfer/utils.py (1 hunks)
- include/flashinfer/sampling.cuh (1 hunks)
- include/flashinfer/utils.cuh (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
csrc/renorm.cu (1)
- csrc/tvm_ffi_utils.h (1): get_stream (272-274)

flashinfer/sampling.py (4)
- flashinfer/utils.py (3): register_custom_op (319-328), register_custom_op (338-357), _get_cache_buf (206-217)
- csrc/flashinfer_sampling_binding.cu (1): top_k_mask_logits (57-59)
- csrc/renorm.cu (2): top_k_mask_logits (61-84), top_k_mask_logits (61-63)
- flashinfer/logits_processor/operators.py (1): _to_tensor_scalar_tuple (28-34)
🪛 Ruff (0.14.5)
flashinfer/sampling.py
407-407: Unused function argument: row_states_buffer
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
-def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
+def _get_cache_buf(
+    name: str, bytes: int, device: torch.device, zero_init: bool = False
+) -> torch.Tensor:
     key = (name, device)
     buf = _cache_buf.get(key)
     if buf is None or buf.size(0) < bytes:
-        buf = torch.empty(bytes, dtype=torch.uint8, device=device)
+        if zero_init:
+            buf = torch.zeros(bytes, dtype=torch.uint8, device=device)
+        else:
+            buf = torch.empty(bytes, dtype=torch.uint8, device=device)
         _cache_buf[key] = buf
Zero-initialized cache must be cleared on reuse
When zero_init=True we only zero the tensor on first allocation; subsequent reuses skip the zero-fill. The new multi-CTA top-k path stores RowReductionState objects in this cache and assumes every launch starts from a fully cleared buffer. A reused buffer now comes back with stale counters/min/max values, so the first CTA observes non-zero state and the reductions diverge (easy to repro by calling top_k_mask_logits twice on the same device). Please zero the buffer whenever zero_init is requested.
     if buf is None or buf.size(0) < bytes:
         if zero_init:
             buf = torch.zeros(bytes, dtype=torch.uint8, device=device)
         else:
             buf = torch.empty(bytes, dtype=torch.uint8, device=device)
         _cache_buf[key] = buf
+    elif zero_init:
+        buf.zero_()
     return buf

🤖 Prompt for AI Agents
In flashinfer/utils.py around lines 205 to 216, the cache allocator only zeroes
the tensor on first allocation but does not clear reused buffers when
zero_init=True; update the function so that when an existing cached buffer is
found and zero_init is True you explicitly zero it (e.g., buf.zero_() or
buf.fill_(0)) before returning/using it, and keep the existing behavior of
allocating a zeroed tensor for new buffers; ensure the zeroing runs on the
correct device and dtype (torch.uint8).
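A minimal repro sketch of the failure mode described above. It assumes a CUDA device and the public flashinfer.sampling.top_k_mask_logits wrapper referenced elsewhere in this PR; the point is only that the second call reuses the cached row-state buffer.

import torch
import flashinfer

logits = torch.randn(16, 128256, device="cuda", dtype=torch.float32)

# First call: the cached row-state buffer is freshly allocated with zeros.
out1 = flashinfer.sampling.top_k_mask_logits(logits, 50)

# Second call on the same device: _get_cache_buf returns the same buffer,
# but (without the fix) skips the zero-fill, so the kernel may see stale state.
out2 = flashinfer.sampling.top_k_mask_logits(logits, 50)

# With buf.zero_() on reuse, both calls should agree.
print(torch.equal(out1, out2))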
__device__ __forceinline__ void red_release(int* ptr, int val) {
#if (__CUDA_ARCH__ >= 700)
  // SM70 and newer use memory consistency qualifiers
  // Release pattern using acq_rel fence + relaxed modifier
Maybe add some more clarification: besides releasing, this also performs a reduction (sum).
include/flashinfer/sampling.cuh (Outdated)
int persistent_iteration = 0;

// Calculate total number of iterations for persistent loop
uint32_t num_groups = gridDim.x / ctas_per_group;
Should we add tests that explicitly trigger num_groups > 1?
We do have: for vocab_size=128256, each row is split into 4 chunks (one per SM), so once batch_size exceeds 33 the required CTAs exceed 132 (the number of SMs on Hopper) for H100, and the persistent multi-group path is exercised.
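A quick arithmetic check of that claim (Python; the ~32K-element chunk size per SM is an assumption used only to reproduce the "4 chunks" figure from the comment above):

def ceil_div(a, b):
    return -(-a // b)

vocab_size = 128256
chunks_per_row = ceil_div(vocab_size, 32768)   # 4 chunks, one CTA/SM each
num_sms_h100 = 132

for batch_size in (32, 33, 34, 64):
    total_ctas = batch_size * chunks_per_row
    # The multi-group persistent path kicks in once the required CTAs exceed the SM count.
    print(batch_size, total_ctas, total_ctas > num_sms_h100)
# 32 128 False / 33 132 False / 34 136 True / 64 256 True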
Actionable comments posted: 2
♻️ Duplicate comments (1)
include/flashinfer/sampling.cuh (1)
2231-2286: Shared-memory chunk sizing can exceed hardware limits (same issue as earlier review)

In all of the multi-CTA Top-K helpers you compute a per-CTA chunk size from max_smem_per_block, but then forcibly enlarge it to a minimum without re-checking against the shared-memory budget:

- TopKMaskLogitsMultiCTA:
  - max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
  - chunk_size = std::max(chunk_size, min_chunk_size);
- RadixTopKMaskLogitsMultiCTA
- RadixTopKRenormProbMultiCTA
- RadixTopKMultiCTA

On GPUs with relatively small opt-in shared memory, available_for_* / sizeof(...) can be less than min_chunk_size. Forcing max_chunk_elements and chunk_size up past what fits in max_smem_per_block then leads to smem_size > max_smem_per_block, and cudaFuncSetAttribute / cudaLaunchKernel failing at runtime with cudaErrorInvalidConfiguration.

Instead of forcing the minimum, you should validate and bail if the minimum is not achievable:

-  const size_t available_for_logits = max_smem_per_block - temp_storage_aligned;
-  uint32_t max_chunk_elements = available_for_logits / sizeof(DType);
-
-  // Round down to multiple of VEC_SIZE
-  max_chunk_elements = round_down(max_chunk_elements, VEC_SIZE);
-
-  // Ensure minimum chunk size for vectorized access
-  constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
-  max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
+  const size_t available_for_logits = max_smem_per_block - temp_storage_aligned;
+  uint32_t max_chunk_elements = available_for_logits / sizeof(DType);
+
+  // Round down to multiple of VEC_SIZE
+  max_chunk_elements = round_down(max_chunk_elements, VEC_SIZE);
+
+  // Ensure we have enough shared memory for at least one chunk
+  constexpr uint32_t min_chunk_size = VEC_SIZE * BLOCK_THREADS;
+  if (max_chunk_elements < min_chunk_size) {
+    // Not enough shared memory for even a single minimally sized chunk
+    return cudaErrorInvalidConfiguration;
+  }
@@
-  uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
-  uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
-  // Round up chunk_size to multiple of VEC_SIZE
-  chunk_size = round_up(chunk_size, VEC_SIZE);
-  // Ensure minimum chunk size
-  chunk_size = std::max(chunk_size, min_chunk_size);
+  uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
+  uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
+  // Round up chunk_size to multiple of VEC_SIZE; we already know this
+  // rounded size cannot exceed max_chunk_elements.
+  chunk_size = round_up(chunk_size, VEC_SIZE);

Apply the same pattern in:

- RadixTopKMaskLogitsMultiCTA (use sizeof(OrderedType) and its min_chunk_size),
- RadixTopKRenormProbMultiCTA,
- RadixTopKMultiCTA.

This keeps the per-CTA shared-memory requirement provably ≤ max_smem_per_block and avoids runtime launch failures on lower-SMEM parts.

Also applies to: 2915-2963, 3255-3299, 3731-3776
🧹 Nitpick comments (7)
flashinfer/jit/topk.py (1)
17-28: LGTM! The new JIT module generator follows the established pattern for other modules. The JitSpec import on line 18 appears unused (only gen_jit_spec is called), but this is a minor nit that doesn't affect functionality.

-from .core import JitSpec, gen_jit_spec
+from .core import gen_jit_spec

csrc/topk.cu (1)
27-33: Consider adding output shape validation. The function validates input dimensions but doesn't verify that output_indices has the expected shape (batch_size, top_k). This could lead to out-of-bounds writes if the caller passes an incorrectly sized tensor.

   CHECK_INPUT(input);
   CHECK_INPUT(output_indices);
   CHECK_DIM(2, input);            // input: (batch_size, d)
   CHECK_DIM(2, output_indices);   // output_indices: (batch_size, top_k)
   unsigned int batch_size = input.size(0);
   unsigned int d = input.size(1);
+
+  TVM_FFI_ICHECK_EQ(output_indices.size(0), batch_size)
+      << "output_indices batch size mismatch";
+  TVM_FFI_ICHECK_EQ(output_indices.size(1), top_k)
+      << "output_indices second dimension should equal top_k";

flashinfer/topk.py (2)
71-192: Tighten top_k API to be a safer drop-in for torch.topk

Right now top_k doesn't validate k or input shape, and sorted=True is ignored when return_values=False (indices are returned unsorted). For better compatibility and clearer failures:
- Check
input.dim() == 2 and 0 < k <= input.size(1) and raise a ValueError otherwise.
- sort indices even when
return_values=False, or- explicitly disallow
sorted=Truewhenreturn_values=False(e.g., raise), and document that combination.
56-64: Silence lint for unusedrow_states_buffer/output_valuesin fake opIn
_fake_radix_topk,row_states_bufferandoutput_valuesare intentionally unused but trigger ARG001. You can keep the signature while silencing lint by underscoring or explicitly discarding them.-@register_fake_op("flashinfer::radix_topk") -def _fake_radix_topk( - input: torch.Tensor, - top_k: int, - row_states_buffer: Optional[torch.Tensor], - output_values: Optional[torch.Tensor] = None, -) -> torch.Tensor: +@register_fake_op("flashinfer::radix_topk") +def _fake_radix_topk( + input: torch.Tensor, + top_k: int, + _row_states_buffer: Optional[torch.Tensor], + _output_values: Optional[torch.Tensor] = None, +) -> torch.Tensor:flashinfer/sampling.py (1)
378-384: Unusedrow_states_bufferin fake ops can be underscored to placate RuffThe fake implementations of
top_k_renorm_probsandtop_k_mask_logitsneed to keeprow_states_bufferin their signature but don’t use it, which triggers ARG001.-def _fake_top_k_renorm_probs( - probs: torch.Tensor, - maybe_top_k_arr: Optional[torch.Tensor], - top_k_val: int, - row_states_buffer: torch.Tensor, -) -> torch.Tensor: +def _fake_top_k_renorm_probs( + probs: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + _row_states_buffer: torch.Tensor, +) -> torch.Tensor: @@ -def _fake_top_k_mask_logits( - logits: torch.Tensor, - maybe_top_k_arr: Optional[torch.Tensor], - top_k_val: int, - row_states_buffer: torch.Tensor, -) -> torch.Tensor: +def _fake_top_k_mask_logits( + logits: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + _row_states_buffer: torch.Tensor, +) -> torch.Tensor:Also applies to: 413-420
tests/utils/test_topk.py (2)
33-47:verify_topk_correctnessis currently unused and has an unused parameterThe helper takes
indicesbut never uses it, and the function itself isn’t referenced by any test. You can either remove it or wire it into the suite; in both cases, consider dropping the unusedindicesparameter.-def verify_topk_correctness(logits, values, indices, k): +def verify_topk_correctness(logits, values, k): @@ - for i in range(batch_size): + for i in range(batch_size): # Get the k-th largest value (ground truth threshold) kth_largest = torch.kthvalue(-logits[i], k).values.item() * -1
66-67: Avoid binding unusedref_valuesin testsIn several tests you unpack
torch.topkas(ref_values, ref_indices)but only useref_indices. Renaming the unused variable (e.g., to_) will silence Ruff’s RUF059 and make intent clearer.- ref_values, ref_indices = torch.topk(logits, k, dim=-1) + _ref_values, ref_indices = torch.topk(logits, k, dim=-1) @@ - ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=True) + _ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=True) @@ - ref_values, ref_indices = torch.topk(logits, k, dim=-1) + _ref_values, ref_indices = torch.topk(logits, k, dim=-1) @@ - ref_values, ref_indices = torch.topk(logits, k, dim=-1) + _ref_values, ref_indices = torch.topk(logits, k, dim=-1) @@ - ref_values, ref_indices = torch.topk(logits, k, dim=-1) + _ref_values, ref_indices = torch.topk(logits, k, dim=-1)Also applies to: 104-105, 138-139, 161-162, 188-189
📒 Files selected for processing (13)
- csrc/flashinfer_sampling_binding.cu (1 hunks)
- csrc/flashinfer_topk_binding.cu (1 hunks)
- csrc/renorm.cu (1 hunks)
- csrc/topk.cu (1 hunks)
- csrc/tvm_ffi_utils.h (1 hunks)
- flashinfer/__init__.py (1 hunks)
- flashinfer/jit/core.py (1 hunks)
- flashinfer/jit/topk.py (1 hunks)
- flashinfer/sampling.py (9 hunks)
- flashinfer/topk.py (1 hunks)
- include/flashinfer/sampling.cuh (3 hunks)
- tests/utils/test_sampling.py (1 hunks)
- tests/utils/test_topk.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/sampling.cuh
🧬 Code graph analysis (7)
flashinfer/__init__.py (1)
flashinfer/topk.py (1)
top_k(71-192)
csrc/flashinfer_topk_binding.cu (2)
flashinfer/topk.py (2)
top_k(71-192)radix_topk(36-54)csrc/topk.cu (2)
radix_topk(24-65)radix_topk(24-26)
tests/utils/test_sampling.py (1)
flashinfer/sampling.py (6)
top_k_renorm_probs(356-375)top_k_renorm_probs(1249-1320)softmax(52-72)softmax(522-576)top_k_mask_logits(391-411)top_k_mask_logits(1326-1395)
flashinfer/jit/topk.py (1)
flashinfer/jit/core.py (2)
JitSpec(216-315)gen_jit_spec(318-384)
csrc/renorm.cu (1)
csrc/tvm_ffi_utils.h (1)
get_stream(289-291)
tests/utils/test_topk.py (1)
flashinfer/topk.py (1)
top_k(71-192)
flashinfer/sampling.py (6)
flashinfer/utils.py (1)
_get_cache_buf(206-217)csrc/flashinfer_sampling_binding.cu (2)
top_k_renorm_probs(54-56)top_k_mask_logits(58-60)csrc/renorm.cu (4)
top_k_renorm_probs(42-70)top_k_renorm_probs(42-44)top_k_mask_logits(72-101)top_k_mask_logits(72-74)flashinfer/logits_processor/types.py (2)
probs(81-85)logits(74-78)csrc/tvm_ffi_utils.h (1)
Tensor(299-301)flashinfer/topk.py (1)
top_k(71-192)
🪛 Ruff (0.14.8)
flashinfer/jit/core.py
125-128: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
tests/utils/test_topk.py
33-33: Unused function argument: indices
(ARG001)
66-66: Unpacked variable ref_values is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
104-104: Unpacked variable ref_values is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
138-138: Unpacked variable ref_values is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
161-161: Unpacked variable ref_values is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
188-188: Unpacked variable ref_values is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
flashinfer/sampling.py
382-382: Unused function argument: row_states_buffer
(ARG001)
418-418: Unused function argument: row_states_buffer
(ARG001)
flashinfer/topk.py
60-60: Unused function argument: row_states_buffer
(ARG001)
61-61: Unused function argument: output_values
(ARG001)
🔇 Additional comments (14)
csrc/tvm_ffi_utils.h (1)
95-111: LGTM!The new
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16macro follows the established pattern of existing dtype dispatchers, correctly adds FP32 support alongside FP16/BF16, and includes proper error handling for unsupported types.flashinfer/jit/core.py (1)
125-128: LGTM!The addition of
-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLEDflag aligns with the new multi-CTA kernel requirements. The concatenation pattern is consistent with other similar flag definitions in this file (e.g.,sm100a_nvcc_flags,sm103a_nvcc_flags).flashinfer/__init__.py (1)
144-145: LGTM!The new
topkmodule import andtop_kfunction re-export follow the established import patterns in this file and correctly expose the new radix-based top-k functionality at the package level.csrc/flashinfer_topk_binding.cu (1)
1-25: LGTM!Clean FFI binding file following the established pattern. The forward declaration correctly matches the implementation signature in
csrc/topk.cu, and the export registration properly exposes the function.csrc/topk.cu (1)
48-61: LGTM!The dtype dispatch and kernel invocation logic is correct. The optional output validation inside the dispatch block properly handles the case when values output is requested.
csrc/renorm.cu (2)
42-70: LGTM!The
top_k_renorm_probsfunction is correctly updated with the newrow_states_bufferparameter and multi-CTA radix-based implementation. Input validation and dtype dispatch are properly handled.
72-101: LGTM!The
top_k_mask_logitsfunction mirrors the changes totop_k_renorm_probscorrectly. The implementation properly supports the multi-CTA path with appropriate validation and error handling.tests/utils/test_sampling.py (2)
450-516: Solid test coverage for multi-dtype support.The dtype-parameterized test correctly handles the precision differences between FP32 and FP16/BF16:
- FP32 uses exact ground truth comparison
- FP16/BF16 uses property-based validation (sum-to-one, non-zero counts with tolerance)
The tolerance calculation
max(k // 5, 20)appropriately accounts for ties at the pivot in low precision.
519-584: LGTM!The
test_top_k_mask_logitsfunction properly validates the mask logits functionality across all supported dtypes with appropriate precision-aware assertions.csrc/flashinfer_sampling_binding.cu (1)
54-60: Signatures correctly extended withrow_states_bufferThe added
TensorView row_states_bufferparameters ontop_k_renorm_probsandtop_k_mask_logitsmatch the updated implementations and FFI exports, so the binding surface remains consistent.flashinfer/sampling.py (2)
353-375: Ensure C++ kernels support all dtypes allowed by Python assertionsThe custom ops for
top_k_renorm_probsandtop_k_mask_logitsnow assert support forfloat32,float16, andbfloat16without casting tofloat32, and forward tensors directly into the CUDA path.In
csrc/renorm.cu, the launch sites currently dispatch viaDISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16, whose name suggests only fp32/fp16 support. If that macro hasn’t been extended to cover BF16 as well, BF16 inputs will hit an unsupported-dtype path at runtime.Please double-check that:
- The dispatch macro used in
top_k_renorm_probs/top_k_mask_logitshandles BF16, or- If not, either:
- restrict the Python-side assertion to
{float32, float16}, or- insert an explicit
probs = probs.float()/logits = logits.float()cast for BF16.Also applies to: 388-410
1307-1320: Row-state buffer sizing and reuse look reasonableThe public wrappers allocate a 1MB cached
row_states_bufferper op/device and reuse it across calls. GivenRadixRowStateis ~2 KB, this leaves room for ~500 concurrent groups, which exceedsnum_smsfor current architectures and matches the kernel’snum_groupsdesign.No change requested here; just noting that the sizing and
zero_init=Truebehavior are consistent with the C++ multi-CTA implementation (which resetsarrival_counterat the end of each launch).Also applies to: 1379-1395
tests/utils/test_topk.py (1)
50-229: Top‑k test coverage looks solidThe parametrized tests exercise large vocabularies, multiple batch sizes, large k, all supported dtypes, sorted/unsorted behavior, and a compatibility case vs
torch.topk, which is appropriate for this new kernel.include/flashinfer/sampling.cuh (1)
119-205: Radix type traits for float/half/bfloat16 look correct. The new
RadixTopKTraits specializations implement the usual sign-bit tricks to map FP values to an unsigned ordered domain and back, plus explicit -inf constants for each dtype. The implementation matches standard descending-order radix-sort patterns for float, half, and nv_bfloat16.
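For reference, a small Python sketch of the standard sign-bit transform being described, shown for fp32 and an ascending unsigned ordering; the kernel's descending-order variant is the same idea with the comparison direction flipped. This is an illustration of the technique, not the device code.

import struct

def float_to_ordered(x: float) -> int:
    """Map an IEEE-754 float to a uint32 whose unsigned order matches float order."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if bits & 0x80000000:          # negative: flip all bits
        return bits ^ 0xFFFFFFFF
    return bits | 0x80000000       # non-negative: flip only the sign bit

def ordered_to_float(u: int) -> float:
    """Inverse transform."""
    if u & 0x80000000:             # came from a non-negative float
        bits = u & 0x7FFFFFFF
    else:                          # came from a negative float
        bits = u ^ 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]

vals = [-float("inf"), -2.5, -0.0, 0.0, 1.0, 3.5, float("inf")]
keys = [float_to_ordered(v) for v in vals]
assert keys == sorted(keys)                                   # ordering preserved
assert all(ordered_to_float(k) == v for k, v in zip(keys, vals))  # round-trips

The half and bfloat16 specializations apply the same trick on 16-bit keys, which is why the -inf constants are the bit patterns 0xFC00 and 0xFF80 respectively.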
flashinfer/topk.py (Outdated)

# RadixRowState size (histogram[2][256] + remaining_k + prefix + arrival_counter + output_counter)
# = 2*256*4 + 4 + 4 + 4 + 4 = 2064 bytes
RADIX_ROW_STATE_SIZE = 2064
RADIX_ROW_STATE_SIZE is out of sync with the C++ RadixRowState struct
RadixRowState in include/flashinfer/sampling.cuh now includes an extra sum_topk field, so its size is 2068 bytes, not 2064. The Python constant and comment still reflect the older layout.
To avoid users under‑allocating custom buffers based on this constant, update it (and the comment) to match the current struct layout.
-# RadixRowState size (histogram[2][256] + remaining_k + prefix + arrival_counter + output_counter)
-# = 2*256*4 + 4 + 4 + 4 + 4 = 2064 bytes
-RADIX_ROW_STATE_SIZE = 2064
+# RadixRowState size (histogram[2][256] + remaining_k + prefix + arrival_counter
+# + output_counter + sum_topk)
+# = 2*256*4 + 4 + 4 + 4 + 4 + 4 = 2068 bytes
+RADIX_ROW_STATE_SIZE = 2068🤖 Prompt for AI Agents
In flashinfer/topk.py around lines 26 to 28, the RADIX_ROW_STATE_SIZE constant
and its comment are stale because the C++ RadixRowState struct in
include/flashinfer/sampling.cuh added a sum_topk field; update the comment to
show the new size calculation and change RADIX_ROW_STATE_SIZE to 2068 to match
the C++ struct, and verify the byte math matches the struct field sizes so
Python buffer allocations stay in sync with the C++ layout.
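One way to keep the Python constant honest is to mirror the layout in a small ctypes model and assert the size in a test. The field names below come from the comment and review discussion (histogram[2][256], remaining_k, prefix, arrival_counter, output_counter, sum_topk); the exact field types and order are assumptions to be verified against sampling.cuh.

import ctypes

class RadixRowState(ctypes.Structure):
    # Sketch of the layout inferred from the comment in flashinfer/topk.py plus
    # the sum_topk member mentioned above; verify against include/flashinfer/sampling.cuh.
    _fields_ = [
        ("histogram", ctypes.c_uint32 * 256 * 2),  # 2 * 256 * 4 bytes
        ("remaining_k", ctypes.c_uint32),
        ("prefix", ctypes.c_uint32),
        ("arrival_counter", ctypes.c_uint32),
        ("output_counter", ctypes.c_uint32),
        ("sum_topk", ctypes.c_float),
    ]

RADIX_ROW_STATE_SIZE = 2068
assert ctypes.sizeof(RadixRowState) == RADIX_ROW_STATE_SIZE   # 2*256*4 + 5*4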
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/sampling.py (2)
372-404: Fix unused arguments in fake ops and replace dtype assertions with explicit exceptions

The custom ops for
top_k_renorm_probsandtop_k_mask_logitscorrectly threadrow_states_bufferdown to the C++ module and restrict inputs to FP32/FP16/BF16. Two issues need fixing:
- Fake ops: unused
row_states_bufferargument (Ruff ARG001)
The fake implementations acceptrow_states_buffer(required for API parity with the custom ops) but never reference it, triggering linter warnings. Add a trivial reference:@register_fake_op("flashinfer::top_k_renorm_probs") def _fake_top_k_renorm_probs( probs: torch.Tensor, maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, row_states_buffer: torch.Tensor, ) -> torch.Tensor: + _ = row_states_buffer # kept for API parity with the real op return torch.empty_like(probs) @register_fake_op("flashinfer::top_k_mask_logits") def _fake_top_k_mask_logits( logits: torch.Tensor, maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, row_states_buffer: torch.Tensor, ) -> torch.Tensor: + _ = row_states_buffer # kept for API parity with the real op return torch.empty_like(logits)
- Dtype validation: use explicit exceptions instead of
assert
assertcan be stripped withpython -Oand produces less controlled error messages. Replace with explicitTypeError:def top_k_renorm_probs( probs: torch.Tensor, maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, row_states_buffer: torch.Tensor, ) -> torch.Tensor: - assert probs.dtype in [torch.float32, torch.float16, torch.bfloat16], ( - f"Unsupported dtype {probs.dtype}, expected float32, float16, or bfloat16" - ) + if probs.dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise TypeError( + f"Unsupported dtype {probs.dtype}, expected float32, float16, or bfloat16" + ) def top_k_mask_logits( logits: torch.Tensor, maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, row_states_buffer: torch.Tensor, ) -> torch.Tensor: - assert logits.dtype in [torch.float32, torch.float16, torch.bfloat16], ( - f"Unsupported dtype {logits.dtype}, expected float32, float16, or bfloat16" - ) + if logits.dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise TypeError( + f"Unsupported dtype {logits.dtype}, expected float32, float16, or bfloat16" + )
1254-1313: Define _check_tensor_param and _get_cache_buf, then use them in top-p/top-k APIs to fix undefined names and restore validation

Both
top_k_renorm_probsandtop_k_mask_logitscall_check_tensor_paramand_get_cache_buf, but these helpers are missing entirely, causing NameError at runtime. Additionally,top_p_renorm_probslacks parameter validation that the other APIs have.Fix this by:
Add
_check_tensor_paramand_get_cache_bufafter_to_tensor_scalar_tuple(which is also undefined)Implement
_check_tensor_paramto validate scalar or per-row tensor sampling parameters:def _check_tensor_param(param, probs: torch.Tensor) -> None: """Validate scalar or per-row tensor sampling parameters against probs/logits batch size.""" if isinstance(param, torch.Tensor): if param.dim() == 0: raise ValueError( "Expected a 1D tensor of shape (batch_size,) or scalar for sampling parameter, " "got a 0-dimensional tensor" ) if param.dim() == 1: if param.size(0) != probs.size(0): raise ValueError("Sampling parameter tensor batch size mismatch") return raise ValueError( f"Expected a 1D tensor or scalar for sampling parameter, got a {param.dim()}D tensor" )Implement
_get_cache_bufto manage row_states_buffer allocation.Add
_check_tensor_paramcalls to all three renorm/mask APIs before calling the moduleFor
top_p_renorm_probs(lines 1254–1313):_check_tensor_param(top_p, probs) return get_sampling_module().top_p_renorm_probs( probs, *_to_tensor_scalar_tuple(top_p) )Ensure
_to_tensor_scalar_tupleis properly defined to convert scalars/tensors for the kernel calls.This resolves the NameError failures, restores explicit parameter validation for top-p, and ensures all three APIs validate inputs consistently.
♻️ Duplicate comments (1)
flashinfer/utils.py (1)
206-216: Zero-initialized cache buffers must be cleared on reuse to avoid stale row state
_get_cache_bufonly zero-fills on first allocation or resize; when a cached buffer is reused withzero_init=True, it is returned without being cleared. The new multi-CTA top‑k paths (both sampling and radix_topk) store per-rowRowReductionStateobjects in this cache and assume a clean state each launch; reusing a “dirty” buffer leads to stale counters/min/max and incorrect results on subsequent calls.You should zero the buffer whenever it’s reused with
zero_init=True:def _get_cache_buf( name: str, bytes: int, device: torch.device, zero_init: bool = False ) -> torch.Tensor: key = (name, device) buf = _cache_buf.get(key) if buf is None or buf.size(0) < bytes: if zero_init: buf = torch.zeros(bytes, dtype=torch.uint8, device=device) else: buf = torch.empty(bytes, dtype=torch.uint8, device=device) _cache_buf[key] = buf + elif zero_init: + buf.zero_() return bufThis keeps the cache behavior but guarantees callers requesting
zero_init=Truealways get a cleared workspace.
🧹 Nitpick comments (2)
tests/utils/test_sampling.py (1)
450-517: Good per-dtype coverage fortop_k_renorm_probs; consider tightening FP16/BF16 nonzero-count toleranceThe FP32 ground-truth path and the FP16/BF16 softmax-based checks look solid and exercise the new kernels well. For the low-precision branch,
tolerance = max(k // 5, 20)can be quite loose for smallk(e.g.,k=10gives a lower bound below zero), so the nonzero-count assertions effectively only cap the upper side. If you want stronger guarantees that the kernel respects the intendedmin(k, nonzero_input)behavior, you might consider scaling the tolerance relative toexpected_countsor clamping it to something likemin(expected_counts.min().item(), max(k // 5, 20)).csrc/renorm.cu (1)
44-73: Multi-CTA top‑k bindings look correct; consider enforcing device match forrow_states_bufferThe switch to
RadixTopK*MultiCTAwithDISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16and the extrarow_states_bufferargument is wired correctly: FP32/FP16/BF16 are dispatched viac_type, scalar/tensortop_kis handled viamaybe_top_k_arrvstop_k_val, and the row-state workspace is threaded through assampling::RadixRowState*.One minor robustness improvement:
CHECK_INPUT(row_states_buffer)ensures CUDA + contiguity, but doesn’t guard against accidentally passing a buffer on a different CUDA device thanprobs/logits. Adding an explicit device check would make the FFI safer for non-Python callers:void top_k_renorm_probs(TensorView probs, TensorView renorm_probs, Optional<TensorView> maybe_top_k_arr, int64_t top_k_val, TensorView row_states_buffer) { CHECK_INPUT(probs); - CHECK_INPUT(row_states_buffer); + CHECK_INPUT(row_states_buffer); + CHECK_DEVICE(probs, row_states_buffer); ... } void top_k_mask_logits(TensorView logits, TensorView mask_logits, Optional<TensorView> maybe_top_k_arr, int64_t top_k_val, TensorView row_states_buffer) { CHECK_INPUT(logits); - CHECK_INPUT(row_states_buffer); + CHECK_INPUT(row_states_buffer); + CHECK_DEVICE(logits, row_states_buffer); ... }Not blocking for this PR if all current callers pass a same-device buffer, but it will help catch misuses early.
Also applies to: 75-105
📒 Files selected for processing (5)
- csrc/renorm.cu (3 hunks)
- csrc/tvm_ffi_utils.h (1 hunks)
- flashinfer/sampling.py (9 hunks)
- flashinfer/utils.py (1 hunks)
- tests/utils/test_sampling.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
flashinfer/utils.py (1)
csrc/fmha_v2/fused_multihead_attention_utils.h (1)
zeros(768-768)
tests/utils/test_sampling.py (1)
flashinfer/sampling.py (6)
top_k_renorm_probs(375-394)top_k_renorm_probs(1320-1391)softmax(53-73)softmax(524-578)top_k_mask_logits(410-430)top_k_mask_logits(1398-1467)
flashinfer/sampling.py (4)
flashinfer/utils.py (3)
register_custom_op(319-328)register_custom_op(338-357)_get_cache_buf(206-217)csrc/renorm.cu (4)
top_k_renorm_probs(44-73)top_k_renorm_probs(44-46)top_k_mask_logits(75-105)top_k_mask_logits(75-77)csrc/flashinfer_sampling_binding.cu (2)
top_k_renorm_probs(54-56)top_k_mask_logits(58-60)flashinfer/topk.py (1)
top_k(71-192)
🪛 GitHub Actions: pre-commit
flashinfer/sampling.py
[error] 1378-1378: mypy: Name '_check_tensor_param' is not defined. (see ruff: F821 undefined name)
[error] 1451-1451: mypy: Name '_check_tensor_param' is not defined. (see ruff: F821 undefined name)
🪛 Ruff (0.14.8)
flashinfer/sampling.py
401-401: Unused function argument: row_states_buffer
(ARG001)
437-437: Unused function argument: row_states_buffer
(ARG001)
1378-1378: Undefined name _check_tensor_param
(F821)
1451-1451: Undefined name _check_tensor_param
(F821)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/utils/test_sampling.py (1)
519-585:top_k_mask_logitstests nicely mirrortop_k_renorm_probsbehavior across dtypesThe FP32 path’s exact comparison against
top_k_renorm_proband the FP16/BF16 tolerance-based checks (finite counts, sum-to-one after softmax,-infhandling) give good coverage of the new masking kernels. I don’t see functional issues here.csrc/tvm_ffi_utils.h (1)
96-112: FP32/FP16/BF16 dispatcher macro is consistent with existing patternsThe new
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16macro cleanly extends the existing FP16/BF16 dispatcher to handlefloat32_codeasfloatwhile reusing_DISPATCH_CASE_F16and_DISPATCH_CASE_BF16for half/bfloat16. The lambda/return __VA_ARGS__()style matches the other dispatch helpers, and theTVM_FFI_ICHECKdefault keeps unsupported dtypes failing fast.
Actionable comments posted: 0
🧹 Nitpick comments (1)
flashinfer/sampling.py (1)
407-430: LGTM: Parameter threading is correct.The
row_states_bufferparameter is properly integrated intotop_k_mask_logits, with correct dtype validation and module call wiring.Minor style note: Line 422 has an extra blank line that could be removed for consistency.
📒 Files selected for processing (1)
flashinfer/sampling.py(9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/sampling.py (5)
flashinfer/utils.py (3)
register_custom_op(319-328)register_custom_op(338-357)_get_cache_buf(206-217)csrc/flashinfer_sampling_binding.cu (2)
top_k_renorm_probs(54-56)top_k_mask_logits(58-60)csrc/renorm.cu (4)
top_k_renorm_probs(44-73)top_k_renorm_probs(44-46)top_k_mask_logits(75-105)top_k_mask_logits(75-77)flashinfer/logits_processor/operators.py (1)
_to_tensor_scalar_tuple(28-34)flashinfer/topk.py (1)
top_k(71-192)
🪛 Ruff (0.14.8)
flashinfer/sampling.py
401-401: Unused function argument: row_states_buffer
(ARG001)
437-437: Unused function argument: row_states_buffer
(ARG001)
🔇 Additional comments (5)
flashinfer/sampling.py (5)
381-384: LGTM: Dtype validation correctly implemented.The runtime dtype checks appropriately validate FP32, FP16, and BF16 support, matching the C++ dispatcher capabilities shown in the relevant code snippets.
Also applies to: 416-419
372-394: LGTM: Parameter threading is correct.The
row_states_bufferparameter is properly added to the function signature, validated, and passed through to the underlying module call, matching the C++ binding signature fromcsrc/flashinfer_sampling_binding.cu.
396-403: Static analysis false positive: Unused parameters are expected in fake implementations.The fake/mock implementations (
_fake_top_k_renorm_probsand_fake_top_k_mask_logits) correctly accept therow_states_bufferparameter to match the real function signatures but don't use it, which is the intended behavior for test stubs. The Ruff warnings at lines 401 and 437 can be safely ignored.Also applies to: 432-439
1378-1389: LGTM: Buffer allocation and initialization are correct.The 1MB
row_states_bufferallocation strategy is well-designed:
zero_init=Trueensures the arrival_counter starts at 0 for multi-CTA synchronization- Device-specific caching avoids cross-device conflicts
- Function-specific cache keys prevent interference between concurrent operations
- Buffer size is justified by the PR objectives (sufficient for ~500 groups)
The explanatory comment on lines 1458-1460 helpfully documents the kernel's reset behavior.
Also applies to: 1449-1463
1330-1330: LGTM: Documentation accurately reflects dtype support.The documentation updates correctly specify the supported input dtypes (FP32, FP16, BF16) and clarify that output dtypes match input dtypes, which represents an important behavioral improvement over the previous FP32-only implementation.
Also applies to: 1342-1342, 1405-1405, 1417-1417
Actionable comments posted: 0
♻️ Duplicate comments (2)
include/flashinfer/sampling.cuh (1)
2824-2906: Verify shared memory bounds in these launchers.The same shared memory calculation issue applies here as in
RadixTopKMaskLogitsMultiCTA(lines 2517-2529). Both launchers forcemax_chunk_elements = std::max(max_chunk_elements, min_chunk_size), which can exceed available shared memory.Apply the same validation fix suggested for
RadixTopKMaskLogitsMultiCTAto preventcudaErrorInvalidConfigurationon GPUs with limited shared memory.Also applies to: 3301-3373
flashinfer/topk.py (1)
26-26:RADIX_ROW_STATE_SIZEconstant is out of sync with the C++RadixRowStatestruct.The
RadixRowStatestruct ininclude/flashinfer/sampling.cuh(lines 1880-1887) now includes asum_topkfield (line 1886), making the struct size 2068 bytes, not 2064. Update this constant to match:+# RadixRowState size (histogram[2][256] + remaining_k + prefix + arrival_counter +# + output_counter + sum_topk) +# = 2*256*4 + 4 + 4 + 4 + 4 + 4 = 2068 bytes -RADIX_ROW_STATE_SIZE = 2064 +RADIX_ROW_STATE_SIZE = 2068
🧹 Nitpick comments (1)
flashinfer/topk.py (1)
149-149: Remove unused computation.Line 149 computes
input.size(1)but doesn't assign or use the result. This appears to be dead code.Apply this diff:
- input.size(1)
📒 Files selected for processing (3)
- flashinfer/topk.py (1 hunks)
- include/flashinfer/sampling.cuh (2 hunks)
- scripts/task_jit_run_tests_part3.sh (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/sampling.cuh
🧬 Code graph analysis (1)
flashinfer/topk.py (5)
flashinfer/jit/topk.py (1)
gen_topk_module(21-28)flashinfer/utils.py (1)
_get_cache_buf(206-217)flashinfer/jit/core.py (1)
build_and_load(303-315)csrc/flashinfer_topk_binding.cu (1)
radix_topk(20-22)csrc/topk.cu (2)
radix_topk(24-65)radix_topk(24-26)
🪛 Ruff (0.14.8)
flashinfer/topk.py
56-56: Unused function argument: row_states_buffer
(ARG001)
57-57: Unused function argument: output_values
(ARG001)
🔇 Additional comments (14)
scripts/task_jit_run_tests_part3.sh (1)
15-15: LGTM!The new test invocation follows the established pattern and correctly isolates the top-k tests to run separately, consistent with the script's design goal of isolating CUDA memory issues. This aligns well with the PR's expanded test coverage for top-k functionality.
flashinfer/topk.py (4)
52-60: Static analysis false positive: unused parameters in fake op are intentional.Ruff flags
row_states_bufferandoutput_valuesas unused (lines 56-57), but this is expected. Fake ops must match the real op's signature for PyTorch's operator registration system, even if they don't use all parameters.
73-148: LGTM! Excellent documentation.The docstring is comprehensive, well-structured, and clearly documents:
- Parameter types and meanings
- Differences from
torch.topk(unsorted by default, O(n) complexity)- Performance characteristics (optimal for large vocabularies)
- Clear examples with expected outputs
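For quick reference, a hedged usage sketch of the new API. The parameter names (return_values, sorted) follow the review discussion above, and the return convention (values, indices) mirroring torch.topk is an assumption; check flashinfer/topk.py for the authoritative signature and defaults.

import torch
import flashinfer

logits = torch.randn(8, 128256, device="cuda", dtype=torch.float16)

# Indices only (unsorted by default, unlike torch.topk).
indices = flashinfer.top_k(logits, 50, return_values=False)

# Values + indices, sorted in descending order of value.
values, idx = flashinfer.top_k(logits, 50, return_values=True, sorted=True)

# Order-insensitive sanity check against torch.topk (ties at the k-th value may differ).
ref_values, ref_idx = torch.topk(logits.float(), 50, dim=-1)
print(indices.shape, values.shape)   # (8, 50) each
print(torch.sort(idx, -1).values.equal(torch.sort(ref_idx, -1).values.to(idx.dtype)))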
153-162: LGTM! Buffer allocation strategy is sound.The 1MB cache buffer allocation with
zero_init=Trueis well-designed:
- Size calculation (1MB / 2068 bytes ≈ 494 groups) matches the comment's "up to ~500 groups"
- Zero initialization ensures
arrival_counterstarts at 0 for first use- Using
_get_cache_bufavoids repeated allocations
180-192: LGTM! Sorting and return logic are correct.The implementation properly handles both sorted and unsorted cases:
- Sorting uses
torch.sortwithdescending=Truefor values- Indices are reordered via
torch.gatherto match sorted values- Alias
topk = top_kprovides backward compatibilityinclude/flashinfer/sampling.cuh (9)
119-204: LGTM! Type traits implementation is correct.The
RadixTopKTraitsspecializations properly handle FP32, FP16, and BF16:
ToOrderedcorrectly transforms floats to unsigned integers for descending radix sort using IEEE 754 bit manipulationFromOrderedcorrectly reverses the transformationNegInfconstants are accurate (0xFC00 for FP16, 0xFF80 for BF16)num_roundscomputation is generic and correct
1801-1821: LGTM!The launcher function correctly dispatches to the kernel with appropriate template parameters.
1825-1875: LGTM! Multi-CTA synchronization primitives are correctly implemented.The memory ordering primitives properly implement acquire-release semantics:
- ld_acquire uses acquire loads on SM70+, with CG fallback
- red_release correctly uses acq_rel fence before relaxed atomic (addresses earlier review comment with clarification on lines 1844-1845)
- st_release properly fences before release store
- wait_ge correctly spins only on thread 0, then synchronizes all threads
1880-1887: LGTM!RadixRowStatestruct layout is well-designed.The struct correctly includes all necessary fields for multi-CTA coordination:
- Double-buffered histograms enable barrier-free ping-pong
- Size is 2068 bytes (matches the corrected constant in Python)
sum_topkfield supports RenormProb aggregation
1890-2092: LGTM! Radix select helper functions are correctly implemented.The device functions properly implement the multi-round radix select algorithm:
- RadixSuffixSum uses parallel stride-doubling reduction
- RadixFindThresholdBucket correctly identifies the bucket containing the k-th element
- RadixBuildLocalHistogram accurately counts elements matching the current prefix
- RadixSelectOneRound properly coordinates single-CTA and multi-CTA paths with appropriate barriers
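To illustrate the selection step those helpers implement, here is a host-side Python sketch of one radix round: build a 256-bin histogram of the current digit, suffix-sum it (counts of elements ≥ each bucket, since larger digits mean larger values), and pick the bucket containing the k-th largest element. This is a simplified, single-threaded model of the logic, not the device code.

def radix_round(digits, k):
    """digits: current 8-bit digit of each candidate (0..255).
    Returns (threshold_bucket, remaining_k) for the next, finer round."""
    hist = [0] * 256
    for d in digits:
        hist[d] += 1

    # Suffix sum: suffix[b] = number of elements with digit >= b.
    suffix = hist[:]
    for b in range(254, -1, -1):
        suffix[b] += suffix[b + 1]

    # The k-th largest element lives in the highest bucket b with suffix[b] >= k.
    for b in range(255, -1, -1):
        if suffix[b] >= k:
            already_selected = suffix[b + 1] if b < 255 else 0   # strictly higher buckets
            return b, k - already_selected
    raise ValueError("k larger than number of candidates")

# Example: 6 candidates, take k=3.
print(radix_round([10, 200, 200, 37, 251, 90], k=3))   # (200, 2): 251 is above, need 2 from bucket 200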
2190-2483: LGTM! Kernel implementation is well-optimized.The
RadixTopKMaskLogitsKernel_MultiCTAimplementation demonstrates excellent optimization:
- Single load of data into shared memory (Stage 1) avoids repeated global memory access
- Double-buffering histograms eliminates second barrier per round
- Persistent kernel pattern efficiently handles large batch sizes
- Fast path for k >= vocab_size avoids unnecessary computation
2517-2529: Verify shared memory bounds before launching kernel.Lines 2517-2519 force
max_chunk_elementsto at leastmin_chunk_size, which can exceed available shared memory. Similarly, line 2523'schunk_sizecalculation may exceed limits. This could cause kernel launch failures withcudaErrorInvalidConfigurationwhenmax_smem_per_blockis small.Consider adding validation logic to prevent configuration errors:
const uint32_t min_chunk_size = vec_size * BLOCK_THREADS; if (available_for_ordered < min_chunk_size * sizeof(OrderedType)) { // Not enough shared memory for even the minimum chunk size return cudaErrorInvalidConfiguration; } // Then proceed with existing logic, knowing min_chunk_size will fit uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType); max_chunk_elements = (max_chunk_elements / vec_size) * vec_size; max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); // ... rest of calculationNote: This concern was also raised in a previous review but may not have been addressed yet.
2579-2822: LGTM! RenormProb kernel correctly implements normalization.The
RadixTopKRenormProbKernel_MultiCTAproperly:
- Finds the pivot using the shared
RadixSelectFindPivotfunction- Aggregates sum of top-k elements across CTAs via
atomicAddtostate->sum_topk- Normalizes elements >= pivot and zeros others
- Handles k >= vocab_size with a fast normalization path
2916-3286: LGTM! Index collection kernel handles ties correctly.The
RadixTopKKernel_MultiCTAimplementation uses a smart two-pass approach:
- Pass 1 collects all elements > pivot with optimized local counting + single global atomic per CTA
- Pass 2 fills remaining slots with elements == pivot using global atomic coordination
- This ensures ties are handled correctly and respects the k limit
/bot run

/bot run
Actionable comments posted: 4
♻️ Duplicate comments (1)
include/flashinfer/sampling.cuh (1)
2486-2554: Shared-memory chunk sizing can exceed device limits (same issue in all three multi-CTA launchers)

In RadixTopKMaskLogitsMultiCTA, RadixTopKRenormProbMultiCTA, and RadixTopKMultiCTA, the shared-mem sizing pattern is:

  const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
  uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
  max_chunk_elements = round_down(max_chunk_elements, vec_size);
  const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
  max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
  uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
  uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
  chunk_size = round_up(chunk_size, vec_size);
  chunk_size = std::min(chunk_size, max_chunk_elements);
  const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType);

If available_for_ordered / sizeof(OrderedType) < min_chunk_size, max_chunk_elements gets forcibly promoted to min_chunk_size, ignoring the real shared-memory limit. That can make smem_size > max_smem_per_block, causing cudaFuncSetAttribute / kernel launch to fail with cudaErrorInvalidConfiguration on devices with smaller dynamic shared memory.

The safer pattern is:

- Treat min_chunk_size as a requirement, not something to enforce via std::max.
- If the hardware doesn't have enough shared memory to fit even min_chunk_size, return an error early instead of over-requesting.

For example, factor this into a helper and use it in all three launchers:

  constexpr uint32_t BLOCK_THREADS = 1024;
  using OrderedType = typename RadixTopKTraits<DType>::OrderedType;
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size);
  constexpr size_t fixed_smem_aligned = /* existing fixed overhead, already rounded up */;

  const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
  const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;

  uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
  max_chunk_elements = round_down(max_chunk_elements, vec_size);

  // Not enough shared mem even for minimum viable chunk: bail out.
  if (max_chunk_elements < min_chunk_size) {
    return cudaErrorInvalidConfiguration;
  }

  // Normal path: choose chunk_size <= max_chunk_elements.
  uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
  uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
  chunk_size = round_up(chunk_size, vec_size);
  chunk_size = std::min(chunk_size, max_chunk_elements);

  const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType);

This addresses the previously raised concern and keeps the three launchers consistent.

Please confirm against cudaDeviceGetAttribute(cudaDevAttrMaxSharedMemoryPerBlockOptin, ...) for your lowest-SMEM target (e.g., older consumer GPUs) that the revised logic never requests more than the allowed dynamic shared memory.

Also applies to: 2809-2878, 3273-3341
🧹 Nitpick comments (2)
flashinfer/topk.py (1)
52-60: Silence Ruff ARG001 in fake op by explicitly “using” the extra arguments
row_states_bufferandoutput_valuesin_fake_radix_topkare intentionally unused but trigger Ruff ARG001. To keep signatures aligned with the real op while satisfying linting, you can add a trivial use:@register_fake_op("flashinfer::radix_topk") def _fake_radix_topk( input: torch.Tensor, top_k: int, row_states_buffer: Optional[torch.Tensor], output_values: Optional[torch.Tensor] = None, ) -> torch.Tensor: + # Mark unused arguments as used to satisfy linters; required by the op schema. + _ = row_states_buffer, output_values batch_size = input.size(0) return torch.empty(batch_size, top_k, dtype=torch.int32, device=input.device)flashinfer/sampling.py (1)
396-404: Fake sampling ops: unusedrow_states_bufferargument (Ruff ARG001)
_fake_top_k_renorm_probsand_fake_top_k_mask_logitsmust acceptrow_states_bufferto mirror the real kernels, but Ruff flags them as unused. You can placate the linter without changing behavior:@register_fake_op("flashinfer::top_k_renorm_probs") def _fake_top_k_renorm_probs( probs: torch.Tensor, maybe_top_k_arr: Optional[torch.Tensor], top_k_val: int, row_states_buffer: torch.Tensor, ) -> torch.Tensor: - return torch.empty_like(probs) + _ = row_states_buffer # required by op schema; unused in fake implementation + return torch.empty_like(probs)(and similarly for
_fake_top_k_mask_logits).Also applies to: 432-439
📒 Files selected for processing (5)
- flashinfer/aot.py (2 hunks)
- flashinfer/sampling.py (9 hunks)
- flashinfer/topk.py (1 hunks)
- include/flashinfer/sampling.cuh (2 hunks)
- tests/utils/test_sampling.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/sampling.cuh
🧬 Code graph analysis (3)
flashinfer/topk.py (3)
flashinfer/jit/topk.py (1)
gen_topk_module(21-28)csrc/flashinfer_topk_binding.cu (1)
radix_topk(20-22)csrc/topk.cu (2)
radix_topk(24-65)radix_topk(24-26)
flashinfer/aot.py (1)
flashinfer/jit/topk.py (1)
gen_topk_module(21-28)
tests/utils/test_sampling.py (1)
flashinfer/sampling.py (6)
softmax(53-73)softmax(524-578)top_k_renorm_probs(375-394)top_k_renorm_probs(1320-1390)top_k_mask_logits(410-430)top_k_mask_logits(1397-1465)
🪛 Ruff (0.14.8)
flashinfer/topk.py
56-56: Unused function argument: row_states_buffer
(ARG001)
57-57: Unused function argument: output_values
(ARG001)
flashinfer/sampling.py
401-401: Unused function argument: row_states_buffer
(ARG001)
437-437: Unused function argument: row_states_buffer
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (4)
flashinfer/aot.py (1)
70-71: Top‑k JIT module correctly integrated into AOT misc set

Importing gen_topk_module and adding it to the add_misc list is consistent with how other JIT modules are wired, so top‑k will be built both under JIT and AOT flows. No issues from this change.

Also applies to: 524-533
tests/utils/test_sampling.py (2)
362-393: Relaxed logits/probs alignment threshold for top‑k/top‑p sampling is reasonable

Lowering the required match rate to 95% and explicitly documenting both the softmax‑ordering differences and k‑th‑position ties makes this test more robust to numerics without weakening its intent (catching gross divergences between logits and probs paths). The improved failure message with mismatch counts is also helpful.
451-514: New top‑k renorm/mask tests give strong coverage across dtypes, distributions, and ties
test_top_k_renorm_probs and test_top_k_mask_logits look solid:

- They drive large batch/vocab sizes, three distributions, and {fp32, fp16, bf16}.
- They assert output dtype matches input and that renormalized probabilities/softmax(masked_logits) still sum to 1 within a reasonable tolerance.
- Tie‑aware logic (pivot, num_greater, num_ties) checks that the number of non‑zero / finite elements stays within the mathematically valid range (≥ expected k, ≤ k plus ties), accounting for underflow with a ±1 slack.

This should catch a wide range of regressions in the multi‑CTA radix top‑k implementation.
Also applies to: 516-597
include/flashinfer/sampling.cuh (1)
119-205: RadixTopKTraits specializations look correct for fp32/fp16/bf16

The RadixTopKTraits implementations (ordered key mapping, num_rounds, and NegInf) for float, half, and nv_bfloat16 follow the standard sign‑bit trick and IEEE encodings (0xFC00 for fp16 −inf, 0xFF80 for bf16 −inf). This gives you a clean, type‑generic radix pipeline across all three dtypes.
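For reference, the order-preserving mapping that the fp32 specialization relies on can be sketched as below (host-side illustration with hypothetical helper names, not the ones exposed by the header); the fp16/bf16 specializations apply the same scheme to 16-bit keys:

#include <cstdint>
#include <cstring>

// Non-negative floats get the sign bit set, negative floats are bit-flipped,
// so unsigned comparison of the keys matches the float ordering.
static inline uint32_t float_to_ordered(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

static inline float ordered_to_float(uint32_t key) {
  uint32_t bits = (key & 0x80000000u) ? (key & 0x7FFFFFFFu) : ~key;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}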
@register_custom_op(
    "flashinfer::top_k_renorm_probs", mutates_args=("row_states_buffer",)
)
def top_k_renorm_probs(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    row_states_buffer: torch.Tensor,
) -> torch.Tensor:
    probs = probs.float()
    # Support FP32, FP16, BF16
    assert probs.dtype in [torch.float32, torch.float16, torch.bfloat16], (
        f"Unsupported dtype {probs.dtype}, expected float32, float16, or bfloat16"
    )
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    module.top_k_renorm_probs(
        probs,
        renorm_probs,
        maybe_top_k_arr,
        top_k_val,
        row_states_buffer,
    )
    return renorm_probs
🧩 Analysis chain
🏁 Script executed:
# First, let's understand the structure and find the relevant files
git ls-files | grep -E "(sampling|cache)" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 767
🏁 Script executed:
# Check the file size of sampling.py to determine reading strategy
wc -l flashinfer/sampling.pyRepository: flashinfer-ai/flashinfer
Length of output: 95
🏁 Script executed:
# Read the custom op wrapper at lines 372-395
sed -n '360,400p' flashinfer/sampling.pyRepository: flashinfer-ai/flashinfer
Length of output: 1442
🏁 Script executed:
# Read the other mentioned sections (407-440, 1379-1390, 1451-1465)
sed -n '407,440p' flashinfer/sampling.pyRepository: flashinfer-ai/flashinfer
Length of output: 1216
🏁 Script executed:
sed -n '1379,1390p' flashinfer/sampling.pyRepository: flashinfer-ai/flashinfer
Length of output: 484
🏁 Script executed:
sed -n '1451,1465p' flashinfer/sampling.pyRepository: flashinfer-ai/flashinfer
Length of output: 645
🏁 Script executed:
# Find _get_cache_buf implementation
rg "_get_cache_buf" -A 10 -B 2Repository: flashinfer-ai/flashinfer
Length of output: 20383
🏁 Script executed:
# Search for buffer cache key structure and device/stream handling
rg "row_states" -B 3 -A 3Repository: flashinfer-ai/flashinfer
Length of output: 21881
🏁 Script executed:
# Check if there are existing tests for concurrent stream usage
fd "test.*sampling" -o fd "sampling.*test"Repository: flashinfer-ai/flashinfer
Length of output: 191
🏁 Script executed:
# Search for arrival_counter and RadixRowState definition
rg "arrival_counter|RadixRowState" -A 5 -B 2 include/Repository: flashinfer-ai/flashinfer
Length of output: 19138
🏁 Script executed:
# Check the exact behavior of arrival_counter reset
rg "arrival_counter" -A 3 -B 3Repository: flashinfer-ai/flashinfer
Length of output: 11160
🏁 Script executed:
# Look at test file for concurrent stream testing
cat tests/utils/test_sampling.py | head -100Repository: flashinfer-ai/flashinfer
Length of output: 3562
🏁 Script executed:
# Search for any stream-related testing in sampling tests
rg "stream|concurrent" tests/utils/test_sampling.py -B 2 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check if there are any comments or documentation about stream safety
rg "stream|concurrent|thread|sync" flashinfer/sampling.py -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check the RadixRowState structure definition
rg "struct RadixRowState|class RadixRowState" -A 20Repository: flashinfer-ai/flashinfer
Length of output: 1665
🏁 Script executed:
# Verify lines 1379-1390 and 1451-1465 contain the high-level wrappers
sed -n '1375,1395p' flashinfer/sampling.pyRepository: flashinfer-ai/flashinfer
Length of output: 655
🏁 Script executed:
sed -n '1447,1470p' flashinfer/sampling.pyRepository: flashinfer-ai/flashinfer
Length of output: 829
🏁 Script executed:
# Check what those functions are
rg "^def top_k_renorm_probs|^def top_k_mask_logits" -A 2Repository: flashinfer-ai/flashinfer
Length of output: 394
Shared row_states_buffer in sampling top‑k APIs has concurrency risk with multiple CUDA streams
Both the custom op wrappers and high-level sampling APIs for top-k operations share a single 1 MB row_states_buffer per device:
- Custom ops top_k_renorm_probs (lines 372–395) and top_k_mask_logits (lines 407–440) accept row_states_buffer and pass it to CUDA kernels, which use it for inter-CTA synchronization via RadixRowState (histogram, arrival_counter, sum_topk, etc.).
- High-level APIs allocate this buffer via _get_cache_buf("..._row_states_{device}", 1MB, device, zero_init=True) at lines 1379–1390 and 1451–1465, keyed only by device.
The buffer cache in _get_cache_buf is keyed by (name, device) only, not by stream. This means all concurrent kernel launches on the same device that access the same row_states_buffer share one RadixRowState[] array. When kernels launch concurrently on different CUDA streams, their CTAs can interleave reads and writes to arrival_counter, histogram[], output_counter, and sum_topk, causing data corruption.
The kernel does reset arrival_counter at completion (via st_release), but this happens only after the kernel finishes. Concurrent launches see the buffer in a partially-updated state.
Recommended mitigations:
- Key the buffer cache by both (name, device, stream) to isolate concurrent launches, or
- Document that these APIs are single-stream only and add assertions/tests to catch multi-stream usage, or
- Allocate buffers per-call when concurrency is required instead of caching per-device.
If your users commonly parallelize sampling across multiple streams, add a regression test that launches top_k_renorm_probs and top_k_mask_logits concurrently on distinct streams and verifies outputs against a single-stream baseline.
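A minimal sketch of the first mitigation, assuming the _get_cache_buf(name, num_bytes, device, zero_init=...) helper referenced above (the wrapper name here is hypothetical, not part of the current API):

import torch

def _get_row_states_buf(name_prefix: str, num_bytes: int, device: torch.device) -> torch.Tensor:
    # Fold the current CUDA stream into the cache key so that concurrent launches
    # on different streams of the same device never share one RadixRowState array.
    stream_id = torch.cuda.current_stream(device).cuda_stream
    return _get_cache_buf(f"{name_prefix}_{device}_stream{stream_id}", num_bytes, device, zero_init=True)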
Also applies to: lines 407–440, 1379–1390, 1451–1465
🤖 Prompt for AI Agents
In flashinfer/sampling.py around lines 372-395, the top_k_renorm_probs wrapper
(and related top_k_mask_logits and the high-level APIs at lines 1379–1390 and
1451–1465) accepts a shared 1MB row_states_buffer that is cached per (name,
device) only, which causes data races when kernels are launched concurrently on
different CUDA streams; fix by changing the buffer-cache keying to include the
CUDA stream (e.g., (name, device, stream)) so each simultaneous launch gets an
isolated row_states_buffer, or alternatively stop caching per-device and
allocate a fresh buffer per-call when stream isolation is required, and add a
regression test that launches the top-k kernels on multiple streams and verifies
outputs against a single-stream baseline; if choosing not to change allocation,
add clear documentation and runtime assertions preventing multi-stream usage.
# Allocate row_states buffer for multi-CTA path
# For single-CTA path this buffer is not used but we always allocate for simplicity
# 1MB is enough for any reasonable GPU (covers up to ~500 groups)
# zero_init=True ensures arrival_counter starts at 0 on first use
row_states_buffer: Optional[torch.Tensor] = _get_cache_buf(
    f"radix_topk_row_states_{input.device}",
    1024 * 1024,  # 1MB
    input.device,
    zero_init=True,
)
🧩 Analysis chain
🏁 Script executed:
cat flashinfer/topk.pyRepository: flashinfer-ai/flashinfer
Length of output: 6884
Shared row_states_buffer can race across concurrent calls on the same device
row_states_buffer is cached per device via _get_cache_buf (line 154), keyed only by device without stream specificity. All top_k invocations on a given GPU share the same 1 MB buffer. Since the radix_topk kernel mutates this buffer for inter-CTA synchronization (line 31), concurrent launches on different CUDA streams can trample each other's state and corrupt results.
Consider either:
- Making the cache key stream- or call-specific, or
- Documenting and enforcing single-stream usage, or
- Allocating a per-call buffer when heavy concurrency is expected.
Additionally, input.size(1) (line 154) has its result unused, and sorted=True is silently ignored when return_values=False (sorting only applied at lines 176–177).
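If multi-stream use is meant to be supported, a regression test roughly along these lines would catch corruption (a sketch only; it assumes the wrapper in flashinfer/topk.py exposes a top_k(input, k) entry point returning int32 indices, and that a CUDA device is available):

import pytest
import torch

from flashinfer.topk import top_k  # assumed entry point, see note above

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_radix_topk_multi_stream_matches_single_stream():
    torch.manual_seed(0)
    x = torch.randn(16, 8192, device="cuda")
    # Reference indices from torch.topk; sort per row so the comparison is order-insensitive.
    expected = torch.topk(x, k=128, dim=-1).indices.sort(dim=-1).values
    torch.cuda.synchronize()
    s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
    with torch.cuda.stream(s1):
        idx1 = top_k(x, 128)
    with torch.cuda.stream(s2):
        idx2 = top_k(x, 128)
    torch.cuda.synchronize()
    torch.testing.assert_close(idx1.long().sort(dim=-1).values, expected)
    torch.testing.assert_close(idx2.long().sort(dim=-1).values, expected)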
| * \brief Multi-CTA Radix Top-K kernel that returns indices of top-k elements. | ||
| * | ||
| * Uses cooperative multi-CTA radix select to find the k-th largest element, | ||
| * then collects indices of all elements >= pivot. | ||
| */ | ||
| template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, bool SINGLE_CTA, typename DType, | ||
| typename IdType> | ||
| __global__ void __launch_bounds__(BLOCK_THREADS) | ||
| RadixTopKKernel_MultiCTA(DType* input, // [batch, vocab_size] | ||
| IdType* output_indices, // [batch, top_k] | ||
| DType* output_values, // [batch, top_k] or nullptr | ||
| IdType* top_k_arr, // [batch] or nullptr | ||
| uint32_t top_k_val, uint32_t vocab_size, uint32_t batch_size, | ||
| RadixRowState* row_states, // [num_groups] (nullptr if SINGLE_CTA) | ||
| uint32_t chunk_size, // elements per CTA | ||
| uint32_t ctas_per_group) // CTAs per row (1 if SINGLE_CTA) | ||
| { | ||
| // Type traits for FP16/BF16/FP32 support | ||
| using Traits = RadixTopKTraits<DType>; | ||
| using OrderedType = typename Traits::OrderedType; | ||
|
|
||
| constexpr uint32_t RADIX = 256; | ||
| constexpr uint32_t RADIX_BITS = 8; | ||
| constexpr uint32_t NUM_ROUNDS = Traits::template num_rounds<RADIX_BITS>(); | ||
| constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8; | ||
|
|
||
| const uint32_t global_cta_id = blockIdx.x; | ||
| const uint32_t group_id = global_cta_id / ctas_per_group; | ||
| const uint32_t cta_in_group = global_cta_id % ctas_per_group; | ||
| const uint32_t tx = threadIdx.x; | ||
|
|
||
| // Shared memory layout: [fixed storage] [ordered values cache] | ||
| extern __shared__ uint8_t smem[]; | ||
|
|
||
| // Fixed shared memory (at the beginning) | ||
| // When SINGLE_CTA, we need an extra uint32 for output_counter (no global state) | ||
| constexpr size_t num_scalars = SINGLE_CTA ? 5 : 4; | ||
| constexpr size_t fixed_smem_size = | ||
| sizeof(uint32_t) * (RADIX + RADIX + num_scalars); // histogram + suffix + scalars | ||
| uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem); | ||
| uint32_t* suffix_sum = local_histogram + RADIX; | ||
| uint32_t* shared_scalars = suffix_sum + RADIX; // [prefix_cache, remaining_k_cache, found_bucket, | ||
| // found_remaining_k, (output_counter)] | ||
|
|
||
| // Align ordered values cache to 16 bytes | ||
| size_t ordered_offset = ((fixed_smem_size + 15) / 16) * 16; | ||
| OrderedType* shared_ordered = reinterpret_cast<OrderedType*>(smem + ordered_offset); | ||
|
|
||
| // Aliases for scalar shared variables | ||
| #define prefix_cache shared_scalars[0] | ||
| #define remaining_k_cache shared_scalars[1] | ||
| #define found_bucket shared_scalars[2] | ||
| #define found_remaining_k shared_scalars[3] | ||
| #define shared_output_counter shared_scalars[4] // Only valid when SINGLE_CTA | ||
|
|
||
| // State pointer only used when not SINGLE_CTA | ||
| RadixRowState* state = nullptr; | ||
| if constexpr (!SINGLE_CTA) { | ||
| state = &row_states[group_id]; | ||
| } | ||
|
|
||
| // Calculate total number of iterations for persistent loop | ||
| uint32_t num_groups = gridDim.x / ctas_per_group; | ||
| uint32_t total_iterations = (batch_size + num_groups - 1) / num_groups; | ||
|
|
||
| int barrier_phase = 0; | ||
|
|
||
| // Persistent loop over rows | ||
| for (uint32_t iter = 0; iter < total_iterations; iter++) { | ||
| uint32_t row_idx = group_id + iter * num_groups; | ||
|
|
||
| if (row_idx >= batch_size) break; | ||
|
|
||
| const uint32_t chunk_start = cta_in_group * chunk_size; | ||
| const uint32_t chunk_end = min(chunk_start + chunk_size, vocab_size); | ||
|
|
||
| uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx]; | ||
|
|
||
| const uint32_t actual_chunk_size = chunk_end - chunk_start; | ||
|
|
||
| if (k >= vocab_size) { | ||
| // k >= vocab_size: return all indices | ||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| if (chunk_start + i < k) { | ||
| output_indices[row_idx * top_k_val + chunk_start + i] = | ||
| static_cast<IdType>(chunk_start + i); | ||
| if (output_values != nullptr) { | ||
| output_values[row_idx * top_k_val + chunk_start + i] = | ||
| input[row_idx * vocab_size + chunk_start + i]; | ||
| } | ||
| } | ||
| } | ||
| continue; | ||
| } | ||
|
|
||
| // ========== Stage 1: Load and convert to ordered representation in shared memory ========== | ||
| vec_t<DType, VEC_SIZE> input_vec; | ||
| const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE; | ||
|
|
||
| #pragma unroll 2 | ||
| for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { | ||
| input_vec.cast_load(input + row_idx * vocab_size + chunk_start + i); | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| shared_ordered[i + j] = Traits::ToOrdered(input_vec[j]); | ||
| } | ||
| } | ||
| // Handle tail | ||
| for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| shared_ordered[i] = Traits::ToOrdered(input[row_idx * vocab_size + chunk_start + i]); | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Initialize local caches and clear global state | ||
| if (tx == 0) { | ||
| prefix_cache = 0; | ||
| remaining_k_cache = k; | ||
| if constexpr (SINGLE_CTA) { | ||
| shared_output_counter = 0; // Use shared memory counter for single CTA | ||
| } | ||
| } | ||
| // Clear global histograms (only needed for multi-CTA) | ||
| if constexpr (!SINGLE_CTA) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| state->histogram[0][i] = 0; | ||
| state->histogram[1][i] = 0; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Barrier to ensure all CTAs have arrived at this iteration (skip for single CTA) | ||
| if constexpr (!SINGLE_CTA) { | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
|
|
||
| // CTA 0 clears output counter AFTER barrier | ||
| if (cta_in_group == 0 && tx == 0) { | ||
| st_release(&state->output_counter, 0); | ||
| } | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // ========== Stage 2: NUM_ROUNDS of radix select ========== | ||
| // Using double-buffering: round N uses histogram[N % 2] | ||
| // Round N clears histogram[(N+1) % 2] for next round's use | ||
| for (uint32_t round = 0; round < NUM_ROUNDS; ++round) { | ||
| uint32_t shift = ORDERED_BITS - (round + 1) * RADIX_BITS; | ||
| // Read from local cache (no global memory access needed!) | ||
| uint32_t prefix = prefix_cache; | ||
| uint32_t remaining_k = remaining_k_cache; | ||
|
|
||
| // For multi-CTA: pointers to global histograms | ||
| // For single-CTA: these are not used | ||
| uint32_t* current_hist = nullptr; | ||
| uint32_t* other_hist = nullptr; | ||
| if constexpr (!SINGLE_CTA) { | ||
| current_hist = state->histogram[round % 2]; | ||
| other_hist = state->histogram[(round + 1) % 2]; | ||
| } | ||
|
|
||
| // Clear local histogram AND (for multi-CTA) clear the "other" global histogram | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| local_histogram[i] = 0; | ||
| if constexpr (!SINGLE_CTA) { | ||
| other_hist[i] = 0; // Prepare for next round (no barrier needed!) | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Build local histogram from SHARED MEMORY (no global memory access!) | ||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| OrderedType ordered = shared_ordered[i]; | ||
|
|
||
| // Check if this element matches the prefix (high bits determined so far) | ||
| OrderedType mask = | ||
| (round == 0) | ||
| ? OrderedType(0) | ||
| : static_cast<OrderedType>(~OrderedType(0) << (ORDERED_BITS - round * RADIX_BITS)); | ||
| if ((ordered & mask) == static_cast<OrderedType>(prefix)) { | ||
| uint32_t bucket = (ordered >> shift) & 0xFF; | ||
| atomicAdd(&local_histogram[bucket], 1); | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // For multi-CTA: add to global histogram and barrier | ||
| // For single-CTA: local_histogram is already the complete histogram | ||
| if constexpr (!SINGLE_CTA) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| if (local_histogram[i] > 0) { | ||
| atomicAdd(¤t_hist[i], local_histogram[i]); | ||
| } | ||
| } | ||
|
|
||
| // Barrier: wait for all CTAs to finish histogram accumulation | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
|
|
||
| // Load from global histogram to suffix_sum | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| suffix_sum[i] = current_hist[i]; | ||
| } | ||
| } else { | ||
| // Single-CTA: copy local histogram directly to suffix_sum | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| suffix_sum[i] = local_histogram[i]; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Parallel suffix sum in shared memory | ||
| for (uint32_t stride = 1; stride < RADIX; stride *= 2) { | ||
| uint32_t val = 0; | ||
| if (tx < RADIX) { | ||
| val = suffix_sum[tx]; | ||
| if (tx + stride < RADIX) { | ||
| val += suffix_sum[tx + stride]; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
| if (tx < RADIX) { | ||
| suffix_sum[tx] = val; | ||
| } | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // ALL CTAs: find threshold bucket (all compute same result) | ||
| if (tx == 0) { | ||
| found_bucket = 0; | ||
| found_remaining_k = remaining_k; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| if (tx < RADIX) { | ||
| uint32_t count_ge = suffix_sum[tx]; | ||
| uint32_t count_gt = (tx + 1 < RADIX) ? suffix_sum[tx + 1] : 0; | ||
| if (count_ge >= remaining_k && count_gt < remaining_k) { | ||
| found_bucket = tx; | ||
| found_remaining_k = remaining_k - count_gt; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Update local caches (all CTAs have same values) | ||
| if (tx == 0) { | ||
| prefix_cache = prefix | (found_bucket << shift); | ||
| remaining_k_cache = found_remaining_k; | ||
| } | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // Get final ordered pivot from prefix_cache | ||
| OrderedType ordered_pivot = static_cast<OrderedType>(prefix_cache); | ||
|
|
||
| // ========== Stage 3: Collect indices >= pivot ========== | ||
| // Two-pass approach to handle ties correctly: | ||
| // Pass 1: collect all elements strictly > pivot (these must be in top-k) | ||
| // Pass 2: fill remaining slots with elements == pivot | ||
| // | ||
| // Optimization for Pass 1 (> pivot): Use shared memory atomic to count locally, | ||
| // then one global atomic per CTA to get base position, then shared atomic to write. | ||
| // This works because all > pivot elements are guaranteed to be in top-k. | ||
| // | ||
| // For Pass 2 (== pivot): Use global atomic directly since we need cross-CTA | ||
| // coordination to respect the k limit (some == pivot elements may be truncated). | ||
|
|
||
| // Reuse local_histogram[0..1] as counters | ||
| #define local_counter local_histogram[0] | ||
| #define global_base local_histogram[1] | ||
|
|
||
| // Pass 1: Count elements > pivot locally, then write with one global atomic | ||
| if (tx == 0) { | ||
| local_counter = 0; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // First pass: count how many elements > pivot in this CTA | ||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| OrderedType ordered_val = shared_ordered[i]; | ||
| if (ordered_val > ordered_pivot) { | ||
| atomicAdd(&local_counter, 1); | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Get base position for this CTA | ||
| uint32_t cta_count_gt = local_counter; | ||
| if (tx == 0 && cta_count_gt > 0) { | ||
| if constexpr (SINGLE_CTA) { | ||
| global_base = atomicAdd(&shared_output_counter, cta_count_gt); | ||
| } else { | ||
| global_base = atomicAdd(&state->output_counter, cta_count_gt); | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Second pass: write elements > pivot using local shared atomic for position | ||
| if (tx == 0) { | ||
| local_counter = 0; // Reset for use as write position | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| if (cta_count_gt > 0) { | ||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| OrderedType ordered_val = shared_ordered[i]; | ||
| if (ordered_val > ordered_pivot) { | ||
| uint32_t local_pos = atomicAdd(&local_counter, 1); | ||
| int pos = global_base + local_pos; | ||
| // No need to check pos < k here since all > pivot elements are in top-k | ||
| output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i); | ||
| if (output_values != nullptr) { | ||
| output_values[row_idx * top_k_val + pos] = Traits::FromOrdered(ordered_val); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Barrier to ensure all > pivot elements are collected first (only for multi-CTA) | ||
| if constexpr (!SINGLE_CTA) { | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Pass 2: Write elements == pivot | ||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| OrderedType ordered_val = shared_ordered[i]; | ||
| if (ordered_val == ordered_pivot) { | ||
| int pos; | ||
| if constexpr (SINGLE_CTA) { | ||
| pos = atomicAdd(&shared_output_counter, 1); | ||
| } else { | ||
| pos = atomicAdd(&state->output_counter, 1); | ||
| } | ||
| if (pos < static_cast<int>(k)) { | ||
| output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i); | ||
| if (output_values != nullptr) { | ||
| output_values[row_idx * top_k_val + pos] = Traits::FromOrdered(ordered_val); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #undef local_counter | ||
| #undef global_base | ||
| // No barrier needed here - the barrier at the start of next iteration | ||
| // ensures all CTAs complete Stage 3 before output_counter is reset | ||
| } | ||
|
|
||
| // Reset arrival counter for next kernel launch (only for multi-CTA) | ||
| if constexpr (!SINGLE_CTA) { | ||
| if (cta_in_group == 0 && tx == 0) { | ||
| st_release(&state->arrival_counter, 0); | ||
| } | ||
| } | ||
|
|
||
| #undef prefix_cache | ||
| #undef remaining_k_cache | ||
| #undef found_bucket | ||
| #undef found_remaining_k | ||
| #undef shared_output_counter | ||
| } |
🧩 Analysis chain
🏁 Script executed:
# First, find topk.cu and understand how the kernel is called
fd -e cu -e h | grep -i topk | head -20Repository: flashinfer-ai/flashinfer
Length of output: 112
🏁 Script executed:
# Look for the RadixTopKMultiCTA function signature and how it's called
rg "RadixTopKMultiCTA|RadixTopKKernel_MultiCTA" -A 5 -B 5 | head -100Repository: flashinfer-ai/flashinfer
Length of output: 4674
🏁 Script executed:
# Find lines 3273-3341 mentioned in the review
sed -n '3273,3341p' include/flashinfer/sampling.cuh | head -100Repository: flashinfer-ai/flashinfer
Length of output: 3545
🏁 Script executed:
# Check for any existing assertions on top_k_arr in the kernel
rg "assert|Assert|CHECK" include/flashinfer/sampling.cuh | grep -i "top_k" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look for all callers of RadixTopKMultiCTA
rg "RadixTopKMultiCTA" --type cu --type hRepository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
# Verify the exact range 3273-3341 and see what's there
wc -l include/flashinfer/sampling.cuhRepository: flashinfer-ai/flashinfer
Length of output: 104
🏁 Script executed:
# Get lines 3273-3341 to confirm it's the wrapper function
sed -n '3273,3300p' include/flashinfer/sampling.cuhRepository: flashinfer-ai/flashinfer
Length of output: 1592
🏁 Script executed:
# Check if there are any other calls to RadixTopKMultiCTA outside of topk.cu
rg "RadixTopKMultiCTA" --glob "*.cu" --glob "*.h" --glob "*.cuh"Repository: flashinfer-ai/flashinfer
Length of output: 261
🏁 Script executed:
# Examine the exact lines referenced for "Also applies to" - check if wrapper also has the same issue
sed -n '3273,3341p' include/flashinfer/sampling.cuh | grep -A 20 "RadixTopKMultiCTA"Repository: flashinfer-ai/flashinfer
Length of output: 1207
Add validation for top_k_arr parameter to prevent out-of-bounds writes
RadixTopKMultiCTA and RadixTopKKernel_MultiCTA accept an optional per-row top_k_arr parameter but index into output buffers using a fixed stride of top_k_val:
uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];
output_indices[row_idx * top_k_val + pos] = /* ... */;

This will cause out-of-bounds writes if any element in top_k_arr exceeds top_k_val, since the output tensors are allocated as [batch_size, top_k_val].
Currently safe because the only caller (topk.cu) passes top_k_arr == nullptr, but future callers could inadvertently exceed the buffer bounds. Add a validation check at the top of RadixTopKMultiCTA:
if (top_k_arr != nullptr) {
  // Assert all top_k_arr[i] <= top_k_val, or adjust buffer allocation
  assert(/* validate contract */);
}

Alternatively, if per-row k values are needed, allocate output buffers based on max(top_k_arr) rather than top_k_val.
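For the host-side variant of that guard, something like the following could run before the launch (a sketch; it assumes top_k_arr is a device array of int and that pulling Thrust into this path is acceptable):

#include <cstdint>
#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>

// Reject configurations where any per-row k exceeds the allocated output stride.
inline cudaError_t CheckTopKArr(const int* top_k_arr, uint32_t batch_size, uint32_t top_k_val) {
  if (top_k_arr == nullptr) return cudaSuccess;  // uniform-k path needs no check
  thrust::device_ptr<const int> begin(top_k_arr);
  // Synchronous reduction on the default stream; acceptable for a validation path.
  int max_k = thrust::reduce(begin, begin + batch_size, 0, thrust::maximum<int>());
  return (static_cast<uint32_t>(max_k) <= top_k_val) ? cudaSuccess : cudaErrorInvalidValue;
}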
🤖 Prompt for AI Agents
In include/flashinfer/sampling.cuh around lines 2883 to 3258, the kernel uses
per-row k from top_k_arr but always strides outputs by top_k_val which can cause
out-of-bounds writes if any top_k_arr[row] > top_k_val; add a validation at the
start of RadixTopKMultiCTA (before launching kernels / inside the host wrapper
function that prepares top_k_arr) that, if top_k_arr != nullptr, computes the
maximum requested k and either (a) asserts/returns an error when max_k >
top_k_val, or (b) requires/allocates output buffers sized to max_k instead of
top_k_val; ensure the check is deterministic and runs before any kernel
invocation so no kernel can write beyond allocated output buffer.
Actionable comments posted: 1
♻️ Duplicate comments (2)
include/flashinfer/sampling.cuh (2)
2500-2569: Critical: shared-memory chunk sizing can exceed opt-in SMEM (duplicate of prior review)
The pattern max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); can force chunk_size beyond what fits, causing launch failure. Also max_smem_per_block - fixed_smem_aligned can underflow if fixed overhead already exceeds opt-in SMEM. This matches the earlier critical comment and should be fixed in all three launchers.

- const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
+ if (max_smem_per_block <= static_cast<int>(fixed_smem_aligned)) {
+   return cudaErrorInvalidConfiguration;
+ }
+ const size_t available_for_ordered =
+     static_cast<size_t>(max_smem_per_block) - fixed_smem_aligned;
  uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
  max_chunk_elements = round_down(max_chunk_elements, vec_size);
  const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
- max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
+ if (max_chunk_elements < min_chunk_size) {
+   // Not enough SMEM to run with current BLOCK_THREADS / vec_size.
+   return cudaErrorInvalidConfiguration;
+ }
  uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
  uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
  chunk_size = round_up(chunk_size, vec_size);
  chunk_size = std::min(chunk_size, max_chunk_elements);
  const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType);
+ if (smem_size > static_cast<uint32_t>(max_smem_per_block)) {
+   return cudaErrorInvalidConfiguration;
+ }

Apply the same fix pattern in:
- RadixTopKMaskLogitsMultiCTA (Line 2500+)
- RadixTopKRenormProbMultiCTA (Line 2834+)
- RadixTopKMultiCTA (Line 3305+)

Also applies to: 2834-2903, 3305-3373
2913-3291: Prevent OOB writes when top_k_arr[row] > top_k_val (duplicate of prior review)
output_indices / output_values are indexed with a row_idx * top_k_val + pos stride, but k can come from top_k_arr[row]. If any top_k_arr[row] > top_k_val, this can write past the output allocation.

A pragmatic safety fix is to cap the effective k to top_k_val inside RadixTopKKernel_MultiCTA:

- uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];
+ const uint32_t k_req = top_k_arr == nullptr ? top_k_val : static_cast<uint32_t>(top_k_arr[row_idx]);
+ const uint32_t k_out = min(k_req, top_k_val);
+ const uint32_t k = min(k_out, vocab_size);
  ...
- if (k >= vocab_size) {
+ if (k_req >= vocab_size) {
    // k >= vocab_size: return all indices
    for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
-     if (chunk_start + i < k) {
+     if (chunk_start + i < k_out) {
        output_indices[row_idx * top_k_val + chunk_start + i] =
            static_cast<IdType>(chunk_start + i);
        if (output_values != nullptr) {
          output_values[row_idx * top_k_val + chunk_start + i] =
              input[row_idx * vocab_size + chunk_start + i];
        }
      }
    }
    continue;
  }
  ...
- if (pos < static_cast<int>(k)) {
+ if (pos < static_cast<int>(k_out)) {
    output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i);
    ...
  }

(If you require per-row
ktotop_k_valinsideRadixTopKKernel_MultiCTA:- uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx]; + const uint32_t k_req = top_k_arr == nullptr ? top_k_val : static_cast<uint32_t>(top_k_arr[row_idx]); + const uint32_t k_out = min(k_req, top_k_val); + const uint32_t k = min(k_out, vocab_size); ... - if (k >= vocab_size) { + if (k_req >= vocab_size) { // k >= vocab_size: return all indices for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { - if (chunk_start + i < k) { + if (chunk_start + i < k_out) { output_indices[row_idx * top_k_val + chunk_start + i] = static_cast<IdType>(chunk_start + i); if (output_values != nullptr) { output_values[row_idx * top_k_val + chunk_start + i] = input[row_idx * vocab_size + chunk_start + i]; } } } continue; } ... - if (pos < static_cast<int>(k)) { + if (pos < static_cast<int>(k_out)) { output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i); ... }(If you require per-row
k>top_k_val, then the allocator/API needs to size outputs bymax(top_k_arr)instead.)Also applies to: 3305-3373
🧹 Nitpick comments (3)
include/flashinfer/sampling.cuh (3)
119-205: Add compile-time guards for radix configuration and document NaN behavior
RadixTopKTraits<>::num_rounds() assumes RADIX_BITS divides the type width; add a static_assert to prevent silent misconfigs. Also worth a short note on NaN ordering (bitwise mapping makes NaNs “sortable” but not numerically meaningful).

template <typename DType>
struct RadixTopKTraits;

template <>
struct RadixTopKTraits<float> {
  using OrderedType = uint32_t;
  template <uint32_t RADIX_BITS>
  static __host__ __device__ constexpr uint32_t num_rounds() {
+   static_assert((sizeof(OrderedType) * 8) % RADIX_BITS == 0, "RADIX_BITS must divide type width");
    return sizeof(OrderedType) * 8 / RADIX_BITS;
  }
  ...
};
1825-1885: Spin-wait could use a backoff to reduce SM contention under load
wait_ge() busy-spins; consider a tiny backoff (__nanosleep) in the loop to reduce pressure when groups are imbalanced.
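A possible shape for that backoff, reusing the ld_acquire helper from this header (illustrative only, not the shipped implementation):

// Spin with exponential backoff; __nanosleep is available on SM70+.
__device__ __forceinline__ void wait_ge_backoff(int* ptr, int target_val, int thread_idx) {
  if (thread_idx == 0) {
    unsigned ns = 32;
    while (ld_acquire(ptr) < target_val) {
#if (__CUDA_ARCH__ >= 700)
      __nanosleep(ns);
      if (ns < 1024) ns *= 2;  // cap the sleep so wake-up latency stays bounded
#endif
    }
  }
  __syncthreads();
}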
1889-2198: RadixFindThresholdBucket writes shared scalars from multiple lanes; ensure single-writer invariant is explicit
It’s likely only one bucket satisfies the predicate, but that contract isn’t enforced. Consider using a sentinel + atomicCAS (or a warp ballot + single lane write) to make “exactly one writer” deterministic.
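One way to make that explicit with a sentinel and atomicCAS (a sketch that mirrors, but does not replace, the existing RadixFindThresholdBucket):

__device__ __forceinline__ void RadixFindThresholdBucketCAS(const uint32_t* suffix_sum,
                                                            uint32_t remaining_k,
                                                            uint32_t* found_bucket,
                                                            uint32_t* found_remaining_k,
                                                            uint32_t tx) {
  constexpr uint32_t RADIX = 256;
  constexpr uint32_t SENTINEL = 0xFFFFFFFFu;
  if (tx == 0) {
    *found_bucket = SENTINEL;
    *found_remaining_k = remaining_k;
  }
  __syncthreads();
  if (tx < RADIX) {
    uint32_t count_ge = suffix_sum[tx];
    uint32_t count_gt = (tx + 1 < RADIX) ? suffix_sum[tx + 1] : 0;
    if (count_ge >= remaining_k && count_gt < remaining_k) {
      // Exactly one lane wins the CAS and publishes both outputs.
      if (atomicCAS(found_bucket, SENTINEL, tx) == SENTINEL) {
        *found_remaining_k = remaining_k - count_gt;
      }
    }
  }
  __syncthreads();
  if (tx == 0 && *found_bucket == SENTINEL) {
    *found_bucket = 0;  // fall back to the current default when no bucket qualifies
  }
  __syncthreads();
}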
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/sampling.cuh (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/sampling.cuh
🪛 GitHub Actions: pre-commit
include/flashinfer/sampling.cuh
[error] 2030-2032: clang-format formatting check failed. Some files were modified by this hook.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
include/flashinfer/sampling.cuh (1)
1801-1821: TopPRenormProb wrapper wiring looks consistent with the updated kernel
Launch config / SMEM sizing / dispatch structure is coherent.
| // ==================== Multi-CTA Top-K Implementation ==================== | ||
|
|
||
| min_gt_low = | ||
| BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce) | ||
| .Reduce(min_gt_low, MinReduceOp{}); | ||
| __syncthreads(); | ||
| max_le_high = | ||
| BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce) | ||
| .Reduce(max_le_high, MaxReduceOp{}); | ||
| if (tx == 0) { | ||
| temp_storage.block_aggregate.counts[0] = aggregate_gt_pivot_0; | ||
| temp_storage.block_aggregate.counts[1] = aggregate_gt_pivot_1; | ||
| temp_storage.min_val = min_gt_low; | ||
| temp_storage.max_val = max_le_high; | ||
| } | ||
| __syncthreads(); | ||
| aggregate_gt_pivot_0 = temp_storage.block_aggregate.counts[0]; | ||
| aggregate_gt_pivot_1 = temp_storage.block_aggregate.counts[1]; | ||
| min_gt_low = temp_storage.min_val; | ||
| max_le_high = temp_storage.max_val; | ||
|
|
||
| if (aggregate_gt_pivot_1 >= k) { | ||
| low = pivot_1; | ||
| } else if (aggregate_gt_pivot_0 >= k) { | ||
| low = pivot_0; | ||
| high = min(pivot_1, max_le_high); | ||
| } else { | ||
| high = min(pivot_0, max_le_high); | ||
| } | ||
| } while (min_gt_low != max_le_high); | ||
| pivot = low; | ||
| // Acquire/Release primitives for inter-CTA synchronization | ||
| __device__ __forceinline__ int ld_acquire(int* ptr) { | ||
| int state = 0; | ||
|
|
||
| #if (__CUDA_ARCH__ >= 700) | ||
| // SM70 and newer use memory consistency qualifiers | ||
| // Acquire pattern using acquire modifier | ||
| asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); | ||
| #else | ||
| asm volatile("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); | ||
| #endif | ||
|
|
||
| return state; | ||
| } | ||
|
|
||
| __device__ __forceinline__ void red_release(int* ptr, int val) { | ||
| #if (__CUDA_ARCH__ >= 700) | ||
| // SM70 and newer use memory consistency qualifiers | ||
| // Release pattern using acq_rel fence + relaxed modifier | ||
| // (The fence also releases data that was weakly-written by other threads prior to the last | ||
| // syncthreads) | ||
| asm volatile("fence.acq_rel.gpu;\n"); | ||
| asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); | ||
| #else | ||
| __threadfence(); | ||
| atomicAdd(ptr, val); | ||
| #endif | ||
| } | ||
|
|
||
| __device__ __forceinline__ void st_release(int* ptr, int val) { | ||
| #if (__CUDA_ARCH__ >= 700) | ||
| // SM70 and newer use memory consistency qualifiers | ||
| // Release pattern: fence + release store | ||
| asm volatile("fence.acq_rel.gpu;\n"); | ||
| asm volatile("st.release.gpu.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val)); | ||
| #else | ||
| __threadfence(); | ||
| atomicExch(ptr, val); | ||
| #endif | ||
| } | ||
|
|
||
| __device__ __forceinline__ void st_release(uint32_t* ptr, uint32_t val) { | ||
| #if (__CUDA_ARCH__ >= 700) | ||
| asm volatile("fence.acq_rel.gpu;\n"); | ||
| asm volatile("st.release.gpu.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val)); | ||
| #else | ||
| __threadfence(); | ||
| atomicExch(ptr, val); | ||
| #endif | ||
| } | ||
|
|
||
| // Wait until the value at ptr reaches target_val using acquire semantics | ||
| // Only thread 0 spins, then all threads synchronize | ||
| __device__ __forceinline__ void wait_ge(int* ptr, int target_val, int thread_idx) { | ||
| if (thread_idx == 0) { | ||
| #pragma unroll 1 | ||
| while (ld_acquire(ptr) < target_val) { | ||
| } | ||
| } | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // masking | ||
| #pragma unroll 2 | ||
| for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { | ||
| logits_vec.fill(0); | ||
| if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { | ||
| logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); | ||
| // ==================== Multi-CTA Radix Top-K Mask Logits ==================== | ||
|
|
||
| // Global state for multi-CTA radix reduction (one per group) | ||
| struct RadixRowState { | ||
| uint32_t histogram[3][256]; // Triple-buffered histograms for 1-barrier-per-round | ||
| uint32_t remaining_k; // Remaining k after current round | ||
| uint32_t prefix; // Accumulated prefix (high bits of k-th element) | ||
| int arrival_counter; // For inter-CTA synchronization | ||
| int output_counter; // For collecting top-k indices (RadixTopK) | ||
| float sum_topk; // For RenormProb: sum of top-k elements | ||
| }; | ||
|
|
||
| // ==================== Common Device Functions for Radix Top-K ==================== | ||
|
|
||
| /*! | ||
| * \brief Compute suffix sum in shared memory using parallel reduction. | ||
| * | ||
| * After this function, suffix_sum[i] contains the count of elements >= bucket i. | ||
| * This is computed by summing all histogram values from bucket i to 255. | ||
| * | ||
| * \param suffix_sum Shared memory array of size RADIX (256) | ||
| * \param tx Thread index within the block | ||
| */ | ||
| template <uint32_t BLOCK_THREADS> | ||
| __device__ __forceinline__ void RadixSuffixSum(uint32_t* suffix_sum, uint32_t tx) { | ||
| constexpr uint32_t RADIX = 256; | ||
| // Parallel suffix sum: compute count of elements >= each bucket | ||
| for (uint32_t stride = 1; stride < RADIX; stride *= 2) { | ||
| uint32_t val = 0; | ||
| if (tx < RADIX) { | ||
| val = suffix_sum[tx]; | ||
| if (tx + stride < RADIX) { | ||
| val += suffix_sum[tx + stride]; | ||
| } | ||
| } | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| logits_vec[j] = | ||
| (logits_vec[j] > pivot) ? logits_vec[j] : -cuda::std::numeric_limits<float>::infinity(); | ||
| __syncthreads(); | ||
| if (tx < RADIX) { | ||
| suffix_sum[tx] = val; | ||
| } | ||
| if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { | ||
| logits_vec.store(masked_logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); | ||
| __syncthreads(); | ||
| } | ||
| } | ||
|
|
||
| /*! | ||
| * \brief Find the threshold bucket that contains the k-th largest element. | ||
| * | ||
| * The threshold bucket satisfies: count_ge >= k && count_gt < k | ||
| * where count_ge = suffix_sum[bucket] and count_gt = suffix_sum[bucket+1]. | ||
| * | ||
| * \param suffix_sum Shared memory array containing suffix sums | ||
| * \param remaining_k Number of top-k elements still to find | ||
| * \param found_bucket Output: the found threshold bucket | ||
| * \param found_remaining_k Output: remaining_k minus count of elements > threshold | ||
| * \param tx Thread index within the block | ||
| */ | ||
| __device__ __forceinline__ void RadixFindThresholdBucket(uint32_t* suffix_sum, uint32_t remaining_k, | ||
| uint32_t* found_bucket, | ||
| uint32_t* found_remaining_k, uint32_t tx) { | ||
| constexpr uint32_t RADIX = 256; | ||
| // Initialize (only thread 0) | ||
| if (tx == 0) { | ||
| *found_bucket = 0; | ||
| *found_remaining_k = remaining_k; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // All threads in RADIX range check their bucket | ||
| if (tx < RADIX) { | ||
| uint32_t count_ge = suffix_sum[tx]; | ||
| uint32_t count_gt = (tx + 1 < RADIX) ? suffix_sum[tx + 1] : 0; | ||
| if (count_ge >= remaining_k && count_gt < remaining_k) { | ||
| *found_bucket = tx; | ||
| *found_remaining_k = remaining_k - count_gt; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, | ||
| typename DType, typename IdType> | ||
| __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, | ||
| uint32_t top_k_val, uint32_t d) { | ||
| const uint32_t bx = blockIdx.x, tx = threadIdx.x; | ||
| const uint32_t row_idx = bx; | ||
| uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; | ||
| double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1; | ||
| vec_t<float, VEC_SIZE> probs_vec; | ||
| if (k < d) { | ||
| extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>)) | ||
| uint8_t smem_renorm[]; | ||
| auto& temp_storage = | ||
| reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm); | ||
| temp_storage.max_val = 0; | ||
|
|
||
| float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM, | ||
| RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>( | ||
| probs, row_idx, d, temp_storage); | ||
|
|
||
| double low = 0, high = max_val; | ||
| float min_gt_low, max_le_high; | ||
| float sum_low = 1; | ||
| // f(x) = len(nonzero(probs > x)), f(x) is non-increasing | ||
| // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} | ||
| // loop invariant: | ||
| // - f(low) >= k, f(high) < k | ||
| // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) | ||
| // stopping condition: min_gt_low == max_le_high | ||
| // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k | ||
| do { | ||
| double pivot_0 = (high + 2 * low) / 3; | ||
| double pivot_1 = (2 * high + low) / 3; | ||
|
|
||
| ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0}; | ||
| min_gt_low = high; | ||
| max_le_high = low; | ||
| ValueCount<float> threadlocal_aggregate_gt_pivot_0{0, 0}, | ||
| threadlocal_aggregate_gt_pivot_1{0, 0}; | ||
| #pragma unroll 1 | ||
| for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { | ||
| probs_vec.fill(0); | ||
| if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { | ||
| probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); | ||
| } | ||
| ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE]; | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| probs_gt_pivot_0_pair[j] = { | ||
| (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, | ||
| (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; | ||
| probs_gt_pivot_1_pair[j] = { | ||
| (probs_vec[j] > pivot_1) ? probs_vec[j] : 0, | ||
| (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; | ||
|
|
||
| if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { | ||
| min_gt_low = min(min_gt_low, probs_vec[j]); | ||
| } | ||
| if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { | ||
| max_le_high = max(max_le_high, probs_vec[j]); | ||
| } | ||
| threadlocal_aggregate_gt_pivot_0 += probs_gt_pivot_0_pair[j]; | ||
| threadlocal_aggregate_gt_pivot_1 += probs_gt_pivot_1_pair[j]; | ||
| } | ||
| } | ||
| aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>( | ||
| temp_storage.block_prim.reduce_value_count) | ||
| .Sum(threadlocal_aggregate_gt_pivot_0); | ||
| __syncthreads(); | ||
| /*! | ||
| * \brief Build local histogram for one round of radix select. | ||
| * | ||
| * Counts elements in shared_ordered that match the current prefix and bins them | ||
| * by their byte at the current shift position. | ||
| * | ||
| * \tparam OrderedType The ordered integer type (uint16_t or uint32_t) | ||
| * \param shared_ordered Shared memory containing ordered values | ||
| * \param actual_chunk_size Number of elements in this CTA's chunk | ||
| * \param local_histogram Output shared memory histogram | ||
| * \param prefix Current prefix (high bits determined so far) | ||
| * \param shift Bit shift for extracting current byte | ||
| * \param round Current round (0 to NUM_ROUNDS-1) | ||
| * \param tx Thread index | ||
| */ | ||
| template <uint32_t BLOCK_THREADS, typename OrderedType> | ||
| __device__ __forceinline__ void RadixBuildLocalHistogram(const OrderedType* shared_ordered, | ||
| uint32_t actual_chunk_size, | ||
| uint32_t* local_histogram, uint32_t prefix, | ||
| uint32_t shift, uint32_t round, | ||
| uint32_t tx) { | ||
| constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8; | ||
| constexpr uint32_t RADIX_BITS = 8; | ||
|
|
||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| OrderedType ordered = shared_ordered[i]; | ||
|
|
||
| // Check if this element matches the prefix (high bits determined so far) | ||
| OrderedType mask = | ||
| (round == 0) | ||
| ? OrderedType(0) | ||
| : static_cast<OrderedType>(~OrderedType(0) << (ORDERED_BITS - round * RADIX_BITS)); | ||
| if ((ordered & mask) == static_cast<OrderedType>(prefix)) { | ||
| uint32_t bucket = (ordered >> shift) & 0xFF; | ||
| atomicAdd(&local_histogram[bucket], 1); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>( | ||
| temp_storage.block_prim.reduce_value_count) | ||
| .Sum(threadlocal_aggregate_gt_pivot_1); | ||
| __syncthreads(); | ||
| /*! | ||
| * \brief Perform one round of radix select with optional multi-CTA synchronization. | ||
| * | ||
| * This is the core radix select logic used by all TopK kernels. | ||
| * It builds histogram, aggregates across CTAs (if multi-CTA), computes suffix sum, | ||
| * and finds the threshold bucket. | ||
| * | ||
| * \tparam BLOCK_THREADS Number of threads per block | ||
| * \tparam SINGLE_CTA True if single-CTA mode (no inter-CTA sync needed) | ||
| * \tparam OrderedType The ordered integer type | ||
| * | ||
| * \param shared_ordered Shared memory containing ordered values | ||
| * \param actual_chunk_size Number of elements in this CTA's chunk | ||
| * \param local_histogram Shared memory for local histogram (size RADIX) | ||
| * \param suffix_sum Shared memory for suffix sum computation (size RADIX) | ||
| * \param state Pointer to RadixRowState for multi-CTA sync (nullptr if SINGLE_CTA) | ||
| * \param prefix Current prefix value | ||
| * \param remaining_k Current remaining k value | ||
| * \param round Current round (0 to NUM_ROUNDS-1) | ||
| * \param barrier_phase Reference to barrier phase counter | ||
| * \param ctas_per_group Number of CTAs per group | ||
| * \param tx Thread index | ||
| * \param out_new_prefix Output: updated prefix after this round | ||
| * \param out_new_remaining_k Output: updated remaining_k after this round | ||
| */ | ||
| template <uint32_t BLOCK_THREADS, bool SINGLE_CTA, typename OrderedType> | ||
| __device__ __forceinline__ void RadixSelectOneRound( | ||
| const OrderedType* shared_ordered, uint32_t actual_chunk_size, uint32_t* local_histogram, | ||
| uint32_t* suffix_sum, uint32_t* shared_scalars, RadixRowState* state, uint32_t prefix, | ||
| uint32_t remaining_k, uint32_t round, uint32_t iter, int& barrier_phase, uint32_t ctas_per_group, | ||
| uint32_t cta_in_group, uint32_t tx, uint32_t* out_new_prefix, uint32_t* out_new_remaining_k) { | ||
| constexpr uint32_t RADIX = 256; | ||
| constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8; | ||
| constexpr uint32_t RADIX_BITS = 8; | ||
| constexpr uint32_t NUM_ROUNDS = ORDERED_BITS / RADIX_BITS; | ||
| uint32_t shift = ORDERED_BITS - (round + 1) * RADIX_BITS; | ||
| uint32_t global_round = iter * NUM_ROUNDS + round; | ||
|
|
||
| // For multi-CTA: pointers to global histograms (triple buffer) | ||
| uint32_t* current_hist = nullptr; | ||
| uint32_t* next_hist = nullptr; | ||
| if constexpr (!SINGLE_CTA) { | ||
| current_hist = state->histogram[global_round % 3]; | ||
| next_hist = state->histogram[(global_round + 1) % 3]; | ||
| } | ||
|
|
||
| min_gt_low = | ||
| BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce) | ||
| .Reduce(min_gt_low, MinReduceOp{}); | ||
| __syncthreads(); | ||
| max_le_high = | ||
| BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce) | ||
| .Reduce(max_le_high, MaxReduceOp{}); | ||
| if (tx == 0) { | ||
| temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0; | ||
| temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1; | ||
| temp_storage.min_val = min_gt_low; | ||
| temp_storage.max_val = max_le_high; | ||
| // Clear local histogram only | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| local_histogram[i] = 0; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Build local histogram from shared memory | ||
| RadixBuildLocalHistogram<BLOCK_THREADS, OrderedType>(shared_ordered, actual_chunk_size, | ||
| local_histogram, prefix, shift, round, tx); | ||
| __syncthreads(); | ||
|
|
||
| // For multi-CTA: write → (leading CTA clears next) → barrier → read | ||
| // For single-CTA: local_histogram is already the complete histogram | ||
| if constexpr (!SINGLE_CTA) { | ||
| // Accumulate local histogram to global | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| if (local_histogram[i] > 0) { | ||
| atomicAdd(¤t_hist[i], local_histogram[i]); | ||
| } | ||
| __syncthreads(); | ||
| aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0]; | ||
| aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1]; | ||
| min_gt_low = temp_storage.min_val; | ||
| max_le_high = temp_storage.max_val; | ||
|
|
||
| if (aggregate_gt_pivot_1.count >= k) { | ||
| low = pivot_1; | ||
| sum_low = float(aggregate_gt_pivot_1.value); | ||
| } else if (aggregate_gt_pivot_0.count >= k) { | ||
| low = pivot_0; | ||
| high = min(pivot_1, max_le_high); | ||
| sum_low = float(aggregate_gt_pivot_0.value); | ||
| } else { | ||
| high = min(pivot_0, max_le_high); | ||
| } | ||
|
|
||
| // Only leading CTA clears next round's histogram BEFORE barrier | ||
| if (cta_in_group == 0) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| next_hist[i] = 0; | ||
| } | ||
| } while (min_gt_low != max_le_high); | ||
| } | ||
|
|
||
| // Barrier: wait for all CTAs to finish atomicAdd and clearing | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
|
|
||
| normalizer = math::ptx_rcp(max(sum_low, 1e-8)); | ||
| pivot = low; | ||
| // Read current histogram (after barrier, all atomicAdds are complete) | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| suffix_sum[i] = current_hist[i]; | ||
| } | ||
| } else { | ||
| // Single-CTA: copy local histogram directly to suffix_sum | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| suffix_sum[i] = local_histogram[i]; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Compute suffix sum | ||
| RadixSuffixSum<BLOCK_THREADS>(suffix_sum, tx); | ||
|
|
||
| // Find threshold bucket using shared_scalars for found_bucket and found_remaining_k | ||
| // shared_scalars[0] = found_bucket, shared_scalars[1] = found_remaining_k | ||
| RadixFindThresholdBucket(suffix_sum, remaining_k, &shared_scalars[0], &shared_scalars[1], tx); | ||
|
|
||
| // Output new prefix and remaining_k | ||
| *out_new_prefix = prefix | (shared_scalars[0] << shift); | ||
| *out_new_remaining_k = shared_scalars[1]; | ||
| } | ||
|
|
||
| /*! | ||
| * \brief Find the k-th largest element pivot using radix select. | ||
| * | ||
| * This is the main entry point for the radix select algorithm. | ||
| * It performs NUM_ROUNDS of radix select to find the exact pivot value. | ||
| * | ||
| * \tparam BLOCK_THREADS Number of threads per block | ||
| * \tparam VEC_SIZE Vector size for memory access | ||
| * \tparam SINGLE_CTA True if single-CTA mode | ||
| * \tparam DType Data type (float, half, nv_bfloat16) | ||
| * | ||
| * \param input Input data pointer (for this row) | ||
| * \param shared_ordered Shared memory for ordered values | ||
| * \param local_histogram Shared memory for local histogram | ||
| * \param suffix_sum Shared memory for suffix sum | ||
| * \param shared_scalars Shared memory for temporary scalar values (size >= 2) | ||
| * \param state RadixRowState pointer (nullptr if SINGLE_CTA) | ||
| * \param chunk_start Start index in vocab for this CTA | ||
| * \param actual_chunk_size Number of elements in this chunk | ||
| * \param k Number of top elements to select | ||
| * \param barrier_phase Reference to barrier phase counter | ||
| * \param ctas_per_group Number of CTAs per group | ||
| * \param tx Thread index | ||
| * \return The pivot value (k-th largest element) | ||
| */ | ||
| template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, bool SINGLE_CTA, typename DType> | ||
| __device__ __forceinline__ DType RadixSelectFindPivot( | ||
| const DType* input, typename RadixTopKTraits<DType>::OrderedType* shared_ordered, | ||
| uint32_t* local_histogram, uint32_t* suffix_sum, uint32_t* shared_scalars, RadixRowState* state, | ||
| uint32_t chunk_start, uint32_t actual_chunk_size, uint32_t k, int& barrier_phase, | ||
| uint32_t ctas_per_group, uint32_t cta_in_group, uint32_t tx, uint32_t iter = 0) { | ||
| using Traits = RadixTopKTraits<DType>; | ||
| using OrderedType = typename Traits::OrderedType; | ||
| constexpr uint32_t RADIX = 256; | ||
| constexpr uint32_t RADIX_BITS = 8; | ||
| constexpr uint32_t NUM_ROUNDS = Traits::template num_rounds<RADIX_BITS>(); | ||
| constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8; | ||
|
|
||
| // Stage 1: Load and convert to ordered representation | ||
| const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE; | ||
| vec_t<DType, VEC_SIZE> data_vec; | ||
|
|
||
| #pragma unroll 2 | ||
| for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { | ||
| data_vec.cast_load(input + chunk_start + i); | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| shared_ordered[i + j] = Traits::ToOrdered(data_vec[j]); | ||
| } | ||
| } | ||
| // Handle tail | ||
| for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| shared_ordered[i] = Traits::ToOrdered(input[chunk_start + i]); | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Initialize prefix and remaining_k | ||
| uint32_t prefix = 0; | ||
| uint32_t remaining_k = k; | ||
|
|
||
| // Initial barrier (skip for single CTA) | ||
| // Histograms are pre-cleared externally (Python side) and cleared at end of each iteration | ||
| if constexpr (!SINGLE_CTA) { | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
| } | ||
| } | ||
|
|
||
| // Stage 2: NUM_ROUNDS of radix select | ||
| // Double buffer with leading CTA clearing at start of each round | ||
| for (uint32_t round = 0; round < NUM_ROUNDS; ++round) { | ||
| uint32_t new_prefix, new_remaining_k; | ||
| RadixSelectOneRound<BLOCK_THREADS, SINGLE_CTA, OrderedType>( | ||
| shared_ordered, actual_chunk_size, local_histogram, suffix_sum, shared_scalars, state, | ||
| prefix, remaining_k, round, iter, barrier_phase, ctas_per_group, cta_in_group, tx, | ||
| &new_prefix, &new_remaining_k); | ||
| prefix = new_prefix; | ||
| remaining_k = new_remaining_k; | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // Convert final ordered representation back to DType pivot | ||
| return Traits::FromOrdered(static_cast<OrderedType>(prefix)); | ||
| } | ||
|
|
||
| template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, bool SINGLE_CTA, typename DType, | ||
| typename IdType> | ||
| __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKMaskLogitsKernel_MultiCTA( | ||
| DType* logits, // [batch, vocab_size] | ||
| DType* masked_logits, // [batch, vocab_size] | ||
| IdType* top_k_arr, // [batch] or nullptr | ||
| uint32_t top_k_val, uint32_t vocab_size, uint32_t batch_size, | ||
| RadixRowState* row_states, // [num_groups] (nullptr if SINGLE_CTA) | ||
| uint32_t chunk_size, // elements per CTA | ||
| uint32_t ctas_per_group) // CTAs per row (1 if SINGLE_CTA) | ||
| { | ||
| // Type traits for FP16/BF16/FP32 support | ||
| using Traits = RadixTopKTraits<DType>; | ||
| using OrderedType = typename Traits::OrderedType; | ||
|
|
||
| constexpr uint32_t RADIX = 256; // 8-bit radix | ||
| constexpr uint32_t RADIX_BITS = 8; | ||
| constexpr uint32_t NUM_ROUNDS = Traits::template num_rounds<RADIX_BITS>(); | ||
| constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8; | ||
|
|
||
| const uint32_t global_cta_id = blockIdx.x; | ||
| const uint32_t group_id = global_cta_id / ctas_per_group; | ||
| const uint32_t cta_in_group = global_cta_id % ctas_per_group; | ||
| const uint32_t tx = threadIdx.x; | ||
|
|
||
| // Shared memory layout: [fixed storage] [ordered values cache] | ||
| extern __shared__ uint8_t smem[]; | ||
|
|
||
| // Fixed shared memory (at the beginning) | ||
| constexpr size_t fixed_smem_size = | ||
| sizeof(uint32_t) * (RADIX + RADIX + 4); // histogram + suffix + 4 scalars | ||
| uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem); | ||
| uint32_t* suffix_sum = local_histogram + RADIX; | ||
| uint32_t* shared_scalars = | ||
| suffix_sum + RADIX; // [prefix_cache, remaining_k_cache, found_bucket, found_remaining_k] | ||
|
|
||
| // Align ordered values cache to 16 bytes | ||
| size_t ordered_offset = ((fixed_smem_size + 15) / 16) * 16; | ||
| OrderedType* shared_ordered = reinterpret_cast<OrderedType*>(smem + ordered_offset); | ||
|
|
||
| // Aliases for scalar shared variables | ||
| #define prefix_cache shared_scalars[0] | ||
| #define remaining_k_cache shared_scalars[1] | ||
| #define found_bucket shared_scalars[2] | ||
| #define found_remaining_k shared_scalars[3] | ||
|
|
||
| // State pointer only used when not SINGLE_CTA | ||
| RadixRowState* state = nullptr; | ||
| if constexpr (!SINGLE_CTA) { | ||
| state = &row_states[group_id]; | ||
| } | ||
|
|
||
| // Calculate total number of iterations for persistent loop | ||
| uint32_t num_groups = gridDim.x / ctas_per_group; | ||
| uint32_t total_iterations = (batch_size + num_groups - 1) / num_groups; | ||
|
|
||
| int barrier_phase = 0; | ||
|
|
||
| // Persistent loop over rows | ||
| for (uint32_t iter = 0; iter < total_iterations; iter++) { | ||
| uint32_t row_idx = group_id + iter * num_groups; | ||
|
|
||
| if (row_idx >= batch_size) break; | ||
|
|
||
| const uint32_t chunk_start = cta_in_group * chunk_size; | ||
| const uint32_t chunk_end = min(chunk_start + chunk_size, vocab_size); | ||
|
|
||
| uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx]; | ||
|
|
||
| DType pivot = Traits::NegInf(); | ||
|
|
||
| const uint32_t actual_chunk_size = chunk_end - chunk_start; | ||
|
|
||
| if (k >= vocab_size) { | ||
| // k >= vocab_size: no masking needed, just copy | ||
| vec_t<DType, VEC_SIZE> logits_vec_copy; | ||
| #pragma unroll 2 | ||
| for (uint32_t i = tx * VEC_SIZE; i < actual_chunk_size; i += BLOCK_THREADS * VEC_SIZE) { | ||
| if (i + VEC_SIZE <= actual_chunk_size) { | ||
| logits_vec_copy.cast_load(logits + row_idx * vocab_size + chunk_start + i); | ||
| logits_vec_copy.store(masked_logits + row_idx * vocab_size + chunk_start + i); | ||
| } | ||
| } | ||
| // Handle tail | ||
| for (uint32_t i = (actual_chunk_size / VEC_SIZE) * VEC_SIZE + tx; i < actual_chunk_size; | ||
| i += BLOCK_THREADS) { | ||
| masked_logits[row_idx * vocab_size + chunk_start + i] = | ||
| logits[row_idx * vocab_size + chunk_start + i]; | ||
| } | ||
| continue; | ||
| } | ||
|
|
||
| // ========== Stage 1: Load and convert to ordered representation in shared memory ========== | ||
| // This is done ONCE per row, avoiding NUM_ROUNDS global memory reads | ||
| vec_t<DType, VEC_SIZE> logits_vec; | ||
| const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE; | ||
|
|
||
| #pragma unroll 2 | ||
| for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { | ||
| logits_vec.cast_load(logits + row_idx * vocab_size + chunk_start + i); | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| // Use type traits for FP16/BF16/FP32 support | ||
| shared_ordered[i + j] = Traits::ToOrdered(logits_vec[j]); | ||
| } | ||
| } | ||
| // Handle tail | ||
| for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| shared_ordered[i] = Traits::ToOrdered(logits[row_idx * vocab_size + chunk_start + i]); | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Initialize local caches | ||
| if (tx == 0) { | ||
| prefix_cache = 0; | ||
| remaining_k_cache = k; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Barrier to ensure all CTAs have arrived at this iteration (skip for single CTA) | ||
| if constexpr (!SINGLE_CTA) { | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // ========== Stage 2: NUM_ROUNDS of radix select ========== | ||
| // Triple-buffer optimization: only 1 barrier per round | ||
| // - Use global_round = iter * NUM_ROUNDS + round for buffer indexing | ||
| // - Only leading CTA clears next buffer before barrier | ||
| for (uint32_t round = 0; round < NUM_ROUNDS; ++round) { | ||
| uint32_t global_round = iter * NUM_ROUNDS + round; | ||
| uint32_t shift = ORDERED_BITS - (round + 1) * RADIX_BITS; | ||
| // Read from local cache (no global memory access needed!) | ||
| uint32_t prefix = prefix_cache; | ||
| uint32_t remaining_k = remaining_k_cache; | ||
|
|
||
| // For multi-CTA: pointers to global histograms (triple buffer) | ||
| uint32_t* current_hist = nullptr; | ||
| uint32_t* next_hist = nullptr; | ||
| if constexpr (!SINGLE_CTA) { | ||
| current_hist = state->histogram[global_round % 3]; | ||
| next_hist = state->histogram[(global_round + 1) % 3]; | ||
| } | ||
|
|
||
| // Clear local histogram only | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| local_histogram[i] = 0; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Build local histogram from SHARED MEMORY (no global memory access!) | ||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| OrderedType ordered = shared_ordered[i]; | ||
|
|
||
| // Check if this element matches the prefix (high bits determined so far) | ||
| // Use generic mask based on OrderedType bits | ||
| OrderedType mask = | ||
| (round == 0) | ||
| ? OrderedType(0) | ||
| : static_cast<OrderedType>(~OrderedType(0) << (ORDERED_BITS - round * RADIX_BITS)); | ||
| if ((ordered & mask) == static_cast<OrderedType>(prefix)) { | ||
| uint32_t bucket = (ordered >> shift) & 0xFF; | ||
| atomicAdd(&local_histogram[bucket], 1); | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // For multi-CTA: write → (leading CTA clears next) → barrier → read | ||
| // For single-CTA: local_histogram is already the complete histogram | ||
| if constexpr (!SINGLE_CTA) { | ||
| // Accumulate local histogram to global | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| if (local_histogram[i] > 0) { | ||
| atomicAdd(¤t_hist[i], local_histogram[i]); | ||
| } | ||
| } | ||
|
|
||
| // Only leading CTA clears next round's histogram BEFORE barrier | ||
| if (cta_in_group == 0) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| next_hist[i] = 0; | ||
| } | ||
| } | ||
|
|
||
| // Barrier: wait for all CTAs to finish atomicAdd and clearing | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
|
|
||
| // Read current histogram (after barrier, all atomicAdds are complete) | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| suffix_sum[i] = current_hist[i]; | ||
| } | ||
| } else { | ||
| // Single-CTA: copy local histogram directly to suffix_sum | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| suffix_sum[i] = local_histogram[i]; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Parallel suffix sum in shared memory (much faster than global memory!) | ||
| // Compute count of elements >= each bucket value | ||
| for (uint32_t stride = 1; stride < RADIX; stride *= 2) { | ||
| uint32_t val = 0; | ||
| if (tx < RADIX) { | ||
| val = suffix_sum[tx]; | ||
| if (tx + stride < RADIX) { | ||
| val += suffix_sum[tx + stride]; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
| if (tx < RADIX) { | ||
| suffix_sum[tx] = val; | ||
| } | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // ALL CTAs: find threshold bucket (all compute same result) | ||
| // Use shared variable to communicate the found bucket (via macros to shared_scalars[2..3]) | ||
| if (tx == 0) { | ||
| found_bucket = 0; | ||
| found_remaining_k = remaining_k; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| if (tx < RADIX) { | ||
| uint32_t count_ge = suffix_sum[tx]; | ||
| uint32_t count_gt = (tx + 1 < RADIX) ? suffix_sum[tx + 1] : 0; | ||
| if (count_ge >= remaining_k && count_gt < remaining_k) { | ||
| found_bucket = tx; | ||
| found_remaining_k = remaining_k - count_gt; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Update local caches (all CTAs have same values) | ||
| if (tx == 0) { | ||
| prefix_cache = prefix | (found_bucket << shift); | ||
| remaining_k_cache = found_remaining_k; | ||
| } | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // Convert final ordered representation back to DType pivot using type traits | ||
| OrderedType ordered_pivot = static_cast<OrderedType>(prefix_cache); | ||
| pivot = Traits::FromOrdered(ordered_pivot); | ||
|
|
||
| // ========== Stage 3: Final masking pass ========== | ||
| // Reuse logits_vec from Stage 1 | ||
| const DType neg_inf = Traits::NegInf(); | ||
|
|
||
| #pragma unroll 2 | ||
| for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { | ||
| logits_vec.cast_load(logits + row_idx * vocab_size + chunk_start + i); | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| logits_vec[j] = (logits_vec[j] >= pivot) ? logits_vec[j] : neg_inf; | ||
| } | ||
| logits_vec.store(masked_logits + row_idx * vocab_size + chunk_start + i); | ||
| } | ||
|
|
||
| // Handle tail | ||
| for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| DType val = logits[row_idx * vocab_size + chunk_start + i]; | ||
| masked_logits[row_idx * vocab_size + chunk_start + i] = (val >= pivot) ? val : neg_inf; | ||
| } | ||
| } | ||
|
|
||
| // Clear histogram buffers and reset arrival counter for next kernel launch (only for multi-CTA) | ||
| if constexpr (!SINGLE_CTA) { | ||
| // Only leading CTA clears the buffers using release semantics | ||
| if (cta_in_group == 0) { | ||
| for (uint32_t buf = 0; buf < 3; ++buf) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| st_release(&state->histogram[buf][i], 0u); | ||
| } | ||
| } | ||
|
|
||
| if (tx == 0) { | ||
| st_release(&state->arrival_counter, 0); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #undef prefix_cache | ||
| #undef remaining_k_cache | ||
| #undef found_bucket | ||
| #undef found_remaining_k | ||
| } | ||
|
|
||
| template <typename DType, typename IdType> | ||
| cudaError_t RadixTopKMaskLogitsMultiCTA(DType* logits, DType* masked_logits, IdType* top_k_arr, | ||
| uint32_t batch_size, uint32_t top_k_val, | ||
| uint32_t vocab_size, RadixRowState* row_states_buffer, | ||
| cudaStream_t stream = 0) { | ||
| using OrderedType = typename RadixTopKTraits<DType>::OrderedType; | ||
| constexpr uint32_t BLOCK_THREADS = 1024; | ||
| const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size); | ||
|
|
||
| // Get device properties | ||
| int device; | ||
| FLASHINFER_CUDA_CALL(cudaGetDevice(&device)); | ||
| int num_sms; | ||
| FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device)); | ||
| int max_smem_per_block; | ||
| FLASHINFER_CUDA_CALL( | ||
| cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); | ||
|
|
||
| // Fixed shared memory overhead: histogram[256] + suffix_sum[256] + 4 scalars | ||
| constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 4); | ||
| constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16); | ||
|
|
||
| // Calculate max chunk size that fits in shared memory | ||
| const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned; | ||
| uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType); | ||
| max_chunk_elements = round_down(max_chunk_elements, vec_size); | ||
| const uint32_t min_chunk_size = vec_size * BLOCK_THREADS; | ||
| max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); | ||
|
|
||
| uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements); | ||
| uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group); | ||
| chunk_size = round_up(chunk_size, vec_size); | ||
| chunk_size = std::min(chunk_size, max_chunk_elements); | ||
|
|
||
| const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType); | ||
| const bool single_cta = (ctas_per_group == 1); | ||
|
|
||
| // Calculate number of groups (how many rows to process concurrently) | ||
| uint32_t num_groups = std::min(static_cast<uint32_t>(num_sms) / ctas_per_group, batch_size); | ||
| if (num_groups == 0) num_groups = 1; | ||
| uint32_t total_ctas = num_groups * ctas_per_group; | ||
|
|
||
| DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { | ||
| if (single_cta) { | ||
| auto kernel = | ||
| RadixTopKMaskLogitsKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, true, DType, IdType>; | ||
| FLASHINFER_CUDA_CALL( | ||
| cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); | ||
|
|
||
| dim3 nblks(total_ctas); | ||
| dim3 nthrs(BLOCK_THREADS); | ||
| void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &vocab_size, | ||
| &batch_size, &row_states_buffer, &chunk_size, &ctas_per_group}; | ||
| FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); | ||
| } else { | ||
| auto kernel = | ||
| RadixTopKMaskLogitsKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, false, DType, IdType>; | ||
| FLASHINFER_CUDA_CALL( | ||
| cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); | ||
|
|
||
| dim3 nblks(total_ctas); | ||
| dim3 nthrs(BLOCK_THREADS); | ||
| void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &vocab_size, | ||
| &batch_size, &row_states_buffer, &chunk_size, &ctas_per_group}; | ||
| FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); | ||
| } | ||
| }); | ||
|
|
||
| return cudaSuccess; | ||
| } | ||
|
|
||
| // ==================== Multi-CTA Radix Top-K Renorm Probs ==================== | ||
|
|
||
| /*! | ||
| * \brief Multi-CTA Radix Top-K RenormProb kernel with unified single/multi-CTA paths. | ||
| * | ||
| * Finds the k-th largest probability, then normalizes all probs >= pivot to sum to 1, | ||
| * setting all others to 0. Uses the shared RadixSelectFindPivot function. | ||
| */ | ||
| template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, bool SINGLE_CTA, typename DType, | ||
| typename IdType> | ||
| __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKRenormProbKernel_MultiCTA( | ||
| DType* probs, // [batch, vocab_size] | ||
| DType* renormed_prob, // [batch, vocab_size] | ||
| IdType* top_k_arr, // [batch] or nullptr | ||
| uint32_t top_k_val, uint32_t vocab_size, uint32_t batch_size, | ||
| RadixRowState* row_states, // [num_groups] (nullptr if SINGLE_CTA) | ||
| uint32_t chunk_size, // elements per CTA | ||
| uint32_t ctas_per_group) // CTAs per row (1 if SINGLE_CTA) | ||
| { | ||
| using Traits = RadixTopKTraits<DType>; | ||
| using OrderedType = typename Traits::OrderedType; | ||
|
|
||
| constexpr uint32_t RADIX = 256; | ||
|
|
||
| const uint32_t global_cta_id = blockIdx.x; | ||
| const uint32_t group_id = global_cta_id / ctas_per_group; | ||
| const uint32_t cta_in_group = global_cta_id % ctas_per_group; | ||
| const uint32_t tx = threadIdx.x; | ||
|
|
||
| // Shared memory layout: [fixed storage] [ordered values cache] | ||
| extern __shared__ uint8_t smem[]; | ||
|
|
||
| // Fixed shared memory (at the beginning) | ||
| // histogram[256] + suffix[256] + scalars[4] + sum_local[1] | ||
| constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + 4) + sizeof(float); | ||
| uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem); | ||
| uint32_t* suffix_sum = local_histogram + RADIX; | ||
| uint32_t* shared_scalars = suffix_sum + RADIX; | ||
| float* shared_sum = reinterpret_cast<float*>(shared_scalars + 4); | ||
|
|
||
| // Align ordered values cache to 16 bytes | ||
| size_t ordered_offset = ((fixed_smem_size + 15) / 16) * 16; | ||
| OrderedType* shared_ordered = reinterpret_cast<OrderedType*>(smem + ordered_offset); | ||
|
|
||
| // State pointer only used when not SINGLE_CTA | ||
| RadixRowState* state = nullptr; | ||
| if constexpr (!SINGLE_CTA) { | ||
| state = &row_states[group_id]; | ||
| } | ||
|
|
||
| // Calculate total number of iterations for persistent loop | ||
| uint32_t num_groups = gridDim.x / ctas_per_group; | ||
| uint32_t total_iterations = (batch_size + num_groups - 1) / num_groups; | ||
|
|
||
| int barrier_phase = 0; | ||
|
|
||
| // Persistent loop over rows | ||
| for (uint32_t iter = 0; iter < total_iterations; iter++) { | ||
| uint32_t row_idx = group_id + iter * num_groups; | ||
|
|
||
| if (row_idx >= batch_size) break; | ||
|
|
||
| const uint32_t chunk_start = cta_in_group * chunk_size; | ||
| const uint32_t chunk_end = min(chunk_start + chunk_size, vocab_size); | ||
| const uint32_t actual_chunk_size = chunk_end - chunk_start; | ||
|
|
||
| uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx]; | ||
|
|
||
| // For RenormProb, pivot is compared with probs (must be non-negative) | ||
| DType pivot = DType(0); | ||
| float normalizer = 1.0f; | ||
|
|
||
| if (k >= vocab_size) { | ||
| // k >= vocab_size: no filtering needed, just compute sum and renormalize | ||
| // Stage 1: Compute sum | ||
| float thread_sum = 0.0f; | ||
| vec_t<DType, VEC_SIZE> data_vec; | ||
| const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE; | ||
|
|
||
| #pragma unroll 2 | ||
| for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { | ||
| data_vec.cast_load(probs + row_idx * vocab_size + chunk_start + i); | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| thread_sum += float(data_vec[j]); | ||
| } | ||
| } | ||
| // Handle tail | ||
| for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| thread_sum += float(probs[row_idx * vocab_size + chunk_start + i]); | ||
| } | ||
|
|
||
| // Block reduction for sum | ||
| typedef cub::BlockReduce<float, BLOCK_THREADS> BlockReduce; | ||
| __shared__ typename BlockReduce::TempStorage temp_storage; | ||
| float block_sum = BlockReduce(temp_storage).Sum(thread_sum); | ||
| __syncthreads(); | ||
|
|
||
| if constexpr (!SINGLE_CTA) { | ||
| // Multi-CTA: atomic add to global sum | ||
| if (tx == 0) { | ||
| if (cta_in_group == 0) { | ||
| state->sum_topk = 0.0f; // First CTA initializes | ||
| } | ||
| } | ||
| // Barrier for initialization | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
|
|
||
| if (tx == 0 && block_sum > 0) { | ||
| atomicAdd(&state->sum_topk, block_sum); | ||
| } | ||
|
|
||
| // Barrier to ensure all CTAs have contributed | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
|
|
||
| normalizer = math::ptx_rcp(max(state->sum_topk, 1e-8f)); | ||
| } else { | ||
| // Single-CTA: use block_sum directly | ||
| if (tx == 0) { | ||
| *shared_sum = block_sum; | ||
| } | ||
| __syncthreads(); | ||
| normalizer = math::ptx_rcp(max(*shared_sum, 1e-8f)); | ||
| } | ||
|
|
||
| // Normalize and store | ||
| #pragma unroll 2 | ||
| for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { | ||
| data_vec.cast_load(probs + row_idx * vocab_size + chunk_start + i); | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| data_vec[j] = DType(float(data_vec[j]) * normalizer); | ||
| } | ||
| data_vec.store(renormed_prob + row_idx * vocab_size + chunk_start + i); | ||
| } | ||
| for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| renormed_prob[row_idx * vocab_size + chunk_start + i] = | ||
| DType(float(probs[row_idx * vocab_size + chunk_start + i]) * normalizer); | ||
| } | ||
| continue; | ||
| } | ||
|
|
||
| // ========== Stage 1: Find pivot using RadixSelectFindPivot ========== | ||
| pivot = RadixSelectFindPivot<BLOCK_THREADS, VEC_SIZE, SINGLE_CTA, DType>( | ||
| probs + row_idx * vocab_size, shared_ordered, local_histogram, suffix_sum, shared_scalars, | ||
| state, chunk_start, actual_chunk_size, k, barrier_phase, ctas_per_group, cta_in_group, tx, | ||
| iter); | ||
|
|
||
| // ========== Stage 2: Compute sum of elements >= pivot ========== | ||
| float thread_sum = 0.0f; | ||
| vec_t<DType, VEC_SIZE> data_vec; | ||
| const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE; | ||
|
|
||
| #pragma unroll 2 | ||
| for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { | ||
| data_vec.cast_load(probs + row_idx * vocab_size + chunk_start + i); | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| if (data_vec[j] >= pivot) { | ||
| thread_sum += float(data_vec[j]); | ||
| } | ||
| } | ||
| } | ||
| // Handle tail | ||
| for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| DType val = probs[row_idx * vocab_size + chunk_start + i]; | ||
| if (val >= pivot) { | ||
| thread_sum += float(val); | ||
| } | ||
| } | ||
|
|
||
| // Block reduction for sum | ||
| typedef cub::BlockReduce<float, BLOCK_THREADS> BlockReduce; | ||
| __shared__ typename BlockReduce::TempStorage temp_storage; | ||
| float block_sum = BlockReduce(temp_storage).Sum(thread_sum); | ||
| __syncthreads(); | ||
|
|
||
| if constexpr (!SINGLE_CTA) { | ||
| // Multi-CTA: atomic add to global sum | ||
| if (tx == 0) { | ||
| if (cta_in_group == 0) { | ||
| state->sum_topk = 0.0f; // First CTA initializes | ||
| } | ||
| } | ||
| // Barrier for initialization | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
|
|
||
| if (tx == 0 && block_sum > 0) { | ||
| atomicAdd(&state->sum_topk, block_sum); | ||
| } | ||
|
|
||
| // Barrier to ensure all CTAs have contributed | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
|
|
||
| normalizer = math::ptx_rcp(max(state->sum_topk, 1e-8f)); | ||
| } else { | ||
| // Single-CTA: use block_sum directly | ||
| if (tx == 0) { | ||
| *shared_sum = block_sum; | ||
| } | ||
| __syncthreads(); | ||
| normalizer = math::ptx_rcp(max(*shared_sum, 1e-8f)); | ||
| } | ||
|
|
||
| // ========== Stage 3: Normalize elements >= pivot, set others to 0 ========== | ||
| #pragma unroll 2 | ||
| for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { | ||
| data_vec.cast_load(probs + row_idx * vocab_size + chunk_start + i); | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| data_vec[j] = (data_vec[j] >= pivot) ? DType(float(data_vec[j]) * normalizer) : DType(0); | ||
| } | ||
| data_vec.store(renormed_prob + row_idx * vocab_size + chunk_start + i); | ||
| } | ||
| // Handle tail | ||
| for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| DType val = probs[row_idx * vocab_size + chunk_start + i]; | ||
| renormed_prob[row_idx * vocab_size + chunk_start + i] = | ||
| (val >= pivot) ? DType(float(val) * normalizer) : DType(0); | ||
| } | ||
| } | ||
|
|
||
| // Clear histogram buffers and reset arrival counter for next kernel launch (only for multi-CTA) | ||
| if constexpr (!SINGLE_CTA) { | ||
| // Only leading CTA clears the buffers using release semantics | ||
| if (cta_in_group == 0) { | ||
| for (uint32_t buf = 0; buf < 3; ++buf) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| st_release(&state->histogram[buf][i], 0u); | ||
| } | ||
| } | ||
|
|
||
| if (tx == 0) { | ||
| st_release(&state->arrival_counter, 0); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| template <typename DType, typename IdType> | ||
| cudaError_t RadixTopKRenormProbMultiCTA(DType* probs, DType* renormed_prob, IdType* top_k_arr, | ||
| uint32_t batch_size, uint32_t top_k_val, | ||
| uint32_t vocab_size, RadixRowState* row_states_buffer, | ||
| cudaStream_t stream = 0) { | ||
| using OrderedType = typename RadixTopKTraits<DType>::OrderedType; | ||
| constexpr uint32_t BLOCK_THREADS = 1024; | ||
| const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size); | ||
|
|
||
| // Get device properties | ||
| int device; | ||
| FLASHINFER_CUDA_CALL(cudaGetDevice(&device)); | ||
| int num_sms; | ||
| FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device)); | ||
| int max_smem_per_block; | ||
| FLASHINFER_CUDA_CALL( | ||
| cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); | ||
|
|
||
| // Fixed shared memory overhead: histogram[256] + suffix_sum[256] + 4 scalars + 1 float | ||
| constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 4) + sizeof(float); | ||
| constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16); | ||
|
|
||
| // Calculate max chunk size that fits in shared memory | ||
| const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned; | ||
| uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType); | ||
| max_chunk_elements = round_down(max_chunk_elements, vec_size); | ||
| const uint32_t min_chunk_size = vec_size * BLOCK_THREADS; | ||
| max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); | ||
|
|
||
| uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements); | ||
| uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group); | ||
| chunk_size = round_up(chunk_size, vec_size); | ||
| chunk_size = std::min(chunk_size, max_chunk_elements); | ||
|
|
||
| const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType); | ||
| const bool single_cta = (ctas_per_group == 1); | ||
|
|
||
| // Calculate number of groups (how many rows to process concurrently) | ||
| uint32_t num_groups = std::min(static_cast<uint32_t>(num_sms) / ctas_per_group, batch_size); | ||
| if (num_groups == 0) num_groups = 1; | ||
| uint32_t total_ctas = num_groups * ctas_per_group; | ||
|
|
||
| DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { | ||
| if (single_cta) { | ||
| auto kernel = | ||
| RadixTopKRenormProbKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, true, DType, IdType>; | ||
| FLASHINFER_CUDA_CALL( | ||
| cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); | ||
|
|
||
| dim3 nblks(total_ctas); | ||
| dim3 nthrs(BLOCK_THREADS); | ||
| void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &vocab_size, | ||
| &batch_size, &row_states_buffer, &chunk_size, &ctas_per_group}; | ||
| FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); | ||
| } else { | ||
| auto kernel = | ||
| RadixTopKRenormProbKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, false, DType, IdType>; | ||
| FLASHINFER_CUDA_CALL( | ||
| cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); | ||
|
|
||
| dim3 nblks(total_ctas); | ||
| dim3 nthrs(BLOCK_THREADS); | ||
| void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &vocab_size, | ||
| &batch_size, &row_states_buffer, &chunk_size, &ctas_per_group}; | ||
| FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); | ||
| } | ||
| }); | ||
|
|
||
| return cudaSuccess; | ||
| } | ||
|
|
||
| // ==================== Multi-CTA Radix Top-K (Returns Indices) ==================== | ||
|
|
||
| /*! | ||
| * \brief Multi-CTA Radix Top-K kernel that returns indices of top-k elements. | ||
| * | ||
| * Uses cooperative multi-CTA radix select to find the k-th largest element, | ||
| * then collects indices of all elements >= pivot. | ||
| */ | ||
| template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, bool SINGLE_CTA, typename DType, | ||
| typename IdType> | ||
| __global__ void __launch_bounds__(BLOCK_THREADS) | ||
| RadixTopKKernel_MultiCTA(DType* input, // [batch, vocab_size] | ||
| IdType* output_indices, // [batch, top_k] | ||
| DType* output_values, // [batch, top_k] or nullptr | ||
| IdType* top_k_arr, // [batch] or nullptr | ||
| uint32_t top_k_val, uint32_t vocab_size, uint32_t batch_size, | ||
| RadixRowState* row_states, // [num_groups] (nullptr if SINGLE_CTA) | ||
| uint32_t chunk_size, // elements per CTA | ||
| uint32_t ctas_per_group) // CTAs per row (1 if SINGLE_CTA) | ||
| { | ||
| // Type traits for FP16/BF16/FP32 support | ||
| using Traits = RadixTopKTraits<DType>; | ||
| using OrderedType = typename Traits::OrderedType; | ||
|
|
||
| constexpr uint32_t RADIX = 256; | ||
| constexpr uint32_t RADIX_BITS = 8; | ||
| constexpr uint32_t NUM_ROUNDS = Traits::template num_rounds<RADIX_BITS>(); | ||
| constexpr uint32_t ORDERED_BITS = sizeof(OrderedType) * 8; | ||
|
|
||
| const uint32_t global_cta_id = blockIdx.x; | ||
| const uint32_t group_id = global_cta_id / ctas_per_group; | ||
| const uint32_t cta_in_group = global_cta_id % ctas_per_group; | ||
| const uint32_t tx = threadIdx.x; | ||
|
|
||
| // Shared memory layout: [fixed storage] [ordered values cache] | ||
| extern __shared__ uint8_t smem[]; | ||
|
|
||
| // Fixed shared memory (at the beginning) | ||
| // When SINGLE_CTA, we need an extra uint32 for output_counter (no global state) | ||
| constexpr size_t num_scalars = SINGLE_CTA ? 5 : 4; | ||
| constexpr size_t fixed_smem_size = | ||
| sizeof(uint32_t) * (RADIX + RADIX + num_scalars); // histogram + suffix + scalars | ||
| uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem); | ||
| uint32_t* suffix_sum = local_histogram + RADIX; | ||
| uint32_t* shared_scalars = suffix_sum + RADIX; // [prefix_cache, remaining_k_cache, found_bucket, | ||
| // found_remaining_k, (output_counter)] | ||
|
|
||
| // Align ordered values cache to 16 bytes | ||
| size_t ordered_offset = ((fixed_smem_size + 15) / 16) * 16; | ||
| OrderedType* shared_ordered = reinterpret_cast<OrderedType*>(smem + ordered_offset); | ||
|
|
||
| // Aliases for scalar shared variables | ||
| #define prefix_cache shared_scalars[0] | ||
| #define remaining_k_cache shared_scalars[1] | ||
| #define found_bucket shared_scalars[2] | ||
| #define found_remaining_k shared_scalars[3] | ||
| #define shared_output_counter shared_scalars[4] // Only valid when SINGLE_CTA | ||
|
|
||
| // State pointer only used when not SINGLE_CTA | ||
| RadixRowState* state = nullptr; | ||
| if constexpr (!SINGLE_CTA) { | ||
| state = &row_states[group_id]; | ||
| } | ||
|
|
||
| // Calculate total number of iterations for persistent loop | ||
| uint32_t num_groups = gridDim.x / ctas_per_group; | ||
| uint32_t total_iterations = (batch_size + num_groups - 1) / num_groups; | ||
|
|
||
| int barrier_phase = 0; | ||
|
|
||
| // Persistent loop over rows | ||
| for (uint32_t iter = 0; iter < total_iterations; iter++) { | ||
| uint32_t row_idx = group_id + iter * num_groups; | ||
|
|
||
| if (row_idx >= batch_size) break; | ||
|
|
||
| const uint32_t chunk_start = cta_in_group * chunk_size; | ||
| const uint32_t chunk_end = min(chunk_start + chunk_size, vocab_size); | ||
|
|
||
| uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx]; | ||
|
|
||
| const uint32_t actual_chunk_size = chunk_end - chunk_start; | ||
|
|
||
| if (k >= vocab_size) { | ||
| // k >= vocab_size: return all indices | ||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| if (chunk_start + i < k) { | ||
| output_indices[row_idx * top_k_val + chunk_start + i] = | ||
| static_cast<IdType>(chunk_start + i); | ||
| if (output_values != nullptr) { | ||
| output_values[row_idx * top_k_val + chunk_start + i] = | ||
| input[row_idx * vocab_size + chunk_start + i]; | ||
| } | ||
| } | ||
| } | ||
| continue; | ||
| } | ||
|
|
||
| // ========== Stage 1: Load and convert to ordered representation in shared memory ========== | ||
| vec_t<DType, VEC_SIZE> input_vec; | ||
| const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE; | ||
|
|
||
| #pragma unroll 2 | ||
| for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { | ||
| input_vec.cast_load(input + row_idx * vocab_size + chunk_start + i); | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| shared_ordered[i + j] = Traits::ToOrdered(input_vec[j]); | ||
| } | ||
| } | ||
| // Handle tail | ||
| for (uint32_t i = aligned_size + tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| shared_ordered[i] = Traits::ToOrdered(input[row_idx * vocab_size + chunk_start + i]); | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Initialize local caches and clear global state | ||
| if (tx == 0) { | ||
| prefix_cache = 0; | ||
| remaining_k_cache = k; | ||
| if constexpr (SINGLE_CTA) { | ||
| shared_output_counter = 0; // Use shared memory counter for single CTA | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Barrier to ensure all CTAs have arrived at this iteration (skip for single CTA) | ||
| if constexpr (!SINGLE_CTA) { | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
|
|
||
| // CTA 0 clears output counter AFTER barrier (needed for every iteration) | ||
| if (cta_in_group == 0 && tx == 0) { | ||
| st_release(&state->output_counter, 0); | ||
| } | ||
| } | ||
|
|
||
| // ========== Stage 2: NUM_ROUNDS of radix select ========== | ||
| // Triple-buffer optimization: only 1 barrier per round | ||
| // - Use global_round = iter * NUM_ROUNDS + round for buffer indexing | ||
| // - Only leading CTA clears next buffer before barrier | ||
| for (uint32_t round = 0; round < NUM_ROUNDS; ++round) { | ||
| uint32_t global_round = iter * NUM_ROUNDS + round; | ||
| uint32_t shift = ORDERED_BITS - (round + 1) * RADIX_BITS; | ||
| // Read from local cache (no global memory access needed!) | ||
| uint32_t prefix = prefix_cache; | ||
| uint32_t remaining_k = remaining_k_cache; | ||
|
|
||
| // For multi-CTA: pointers to global histograms (triple buffer) | ||
| uint32_t* current_hist = nullptr; | ||
| uint32_t* next_hist = nullptr; | ||
| if constexpr (!SINGLE_CTA) { | ||
| current_hist = state->histogram[global_round % 3]; | ||
| next_hist = state->histogram[(global_round + 1) % 3]; | ||
| } | ||
|
|
||
| // Clear local histogram only | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| local_histogram[i] = 0; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Build local histogram from SHARED MEMORY (no global memory access!) | ||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| OrderedType ordered = shared_ordered[i]; | ||
|
|
||
| // Check if this element matches the prefix (high bits determined so far) | ||
| OrderedType mask = | ||
| (round == 0) | ||
| ? OrderedType(0) | ||
| : static_cast<OrderedType>(~OrderedType(0) << (ORDERED_BITS - round * RADIX_BITS)); | ||
| if ((ordered & mask) == static_cast<OrderedType>(prefix)) { | ||
| uint32_t bucket = (ordered >> shift) & 0xFF; | ||
| atomicAdd(&local_histogram[bucket], 1); | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // For multi-CTA: write → (leading CTA clears next) → barrier → read | ||
| // For single-CTA: local_histogram is already the complete histogram | ||
| if constexpr (!SINGLE_CTA) { | ||
| // Accumulate local histogram to global | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| if (local_histogram[i] > 0) { | ||
| atomicAdd(¤t_hist[i], local_histogram[i]); | ||
| } | ||
| } | ||
|
|
||
| // Only leading CTA clears next round's histogram BEFORE barrier | ||
| if (cta_in_group == 0) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| next_hist[i] = 0; | ||
| } | ||
| } | ||
|
|
||
| // Barrier: wait for all CTAs to finish atomicAdd and clearing | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| __syncthreads(); | ||
|
|
||
| // Read current histogram (after barrier, all atomicAdds are complete) | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| suffix_sum[i] = current_hist[i]; | ||
| } | ||
| } else { | ||
| // Single-CTA: copy local histogram directly to suffix_sum | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| suffix_sum[i] = local_histogram[i]; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Parallel suffix sum in shared memory | ||
| for (uint32_t stride = 1; stride < RADIX; stride *= 2) { | ||
| uint32_t val = 0; | ||
| if (tx < RADIX) { | ||
| val = suffix_sum[tx]; | ||
| if (tx + stride < RADIX) { | ||
| val += suffix_sum[tx + stride]; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
| if (tx < RADIX) { | ||
| suffix_sum[tx] = val; | ||
| } | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // ALL CTAs: find threshold bucket (all compute same result) | ||
| if (tx == 0) { | ||
| found_bucket = 0; | ||
| found_remaining_k = remaining_k; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| if (tx < RADIX) { | ||
| uint32_t count_ge = suffix_sum[tx]; | ||
| uint32_t count_gt = (tx + 1 < RADIX) ? suffix_sum[tx + 1] : 0; | ||
| if (count_ge >= remaining_k && count_gt < remaining_k) { | ||
| found_bucket = tx; | ||
| found_remaining_k = remaining_k - count_gt; | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Update local caches (all CTAs have same values) | ||
| if (tx == 0) { | ||
| prefix_cache = prefix | (found_bucket << shift); | ||
| remaining_k_cache = found_remaining_k; | ||
| } | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // Get final ordered pivot from prefix_cache | ||
| OrderedType ordered_pivot = static_cast<OrderedType>(prefix_cache); | ||
|
|
||
| // ========== Stage 3: Collect indices >= pivot ========== | ||
| // Two-pass approach to handle ties correctly: | ||
| // Pass 1: collect all elements strictly > pivot (these must be in top-k) | ||
| // Pass 2: fill remaining slots with elements == pivot | ||
| // | ||
| // Optimization for Pass 1 (> pivot): Use shared memory atomic to count locally, | ||
| // then one global atomic per CTA to get base position, then shared atomic to write. | ||
| // This works because all > pivot elements are guaranteed to be in top-k. | ||
| // | ||
| // For Pass 2 (== pivot): Use global atomic directly since we need cross-CTA | ||
| // coordination to respect the k limit (some == pivot elements may be truncated). | ||
|
|
||
| // Reuse local_histogram[0..1] as counters | ||
| #define local_counter local_histogram[0] | ||
| #define global_base local_histogram[1] | ||
|
|
||
| // Pass 1: Count elements > pivot locally, then write with one global atomic | ||
| if (tx == 0) { | ||
| local_counter = 0; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // First pass: count how many elements > pivot in this CTA | ||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| OrderedType ordered_val = shared_ordered[i]; | ||
| if (ordered_val > ordered_pivot) { | ||
| atomicAdd(&local_counter, 1); | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Get base position for this CTA | ||
| uint32_t cta_count_gt = local_counter; | ||
| if (tx == 0 && cta_count_gt > 0) { | ||
| if constexpr (SINGLE_CTA) { | ||
| global_base = atomicAdd(&shared_output_counter, cta_count_gt); | ||
| } else { | ||
| global_base = atomicAdd(&state->output_counter, cta_count_gt); | ||
| } | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Second pass: write elements > pivot using local shared atomic for position | ||
| if (tx == 0) { | ||
| local_counter = 0; // Reset for use as write position | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| if (cta_count_gt > 0) { | ||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| OrderedType ordered_val = shared_ordered[i]; | ||
| if (ordered_val > ordered_pivot) { | ||
| uint32_t local_pos = atomicAdd(&local_counter, 1); | ||
| int pos = global_base + local_pos; | ||
| // No need to check pos < k here since all > pivot elements are in top-k | ||
| output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i); | ||
| if (output_values != nullptr) { | ||
| output_values[row_idx * top_k_val + pos] = Traits::FromOrdered(ordered_val); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Barrier to ensure all > pivot elements are collected first (only for multi-CTA) | ||
| if constexpr (!SINGLE_CTA) { | ||
| if (tx == 0) { | ||
| red_release(&state->arrival_counter, 1); | ||
| } | ||
| int target = (barrier_phase + 1) * ctas_per_group; | ||
| wait_ge(&state->arrival_counter, target, tx); | ||
| barrier_phase++; | ||
| } | ||
| __syncthreads(); | ||
|
|
||
| // Pass 2: Write elements == pivot | ||
| for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) { | ||
| OrderedType ordered_val = shared_ordered[i]; | ||
| if (ordered_val == ordered_pivot) { | ||
| int pos; | ||
| if constexpr (SINGLE_CTA) { | ||
| pos = atomicAdd(&shared_output_counter, 1); | ||
| } else { | ||
| pos = atomicAdd(&state->output_counter, 1); | ||
| } | ||
| if (pos < static_cast<int>(k)) { | ||
| output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i); | ||
| if (output_values != nullptr) { | ||
| output_values[row_idx * top_k_val + pos] = Traits::FromOrdered(ordered_val); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #undef local_counter | ||
| #undef global_base | ||
| // No barrier needed here - the barrier at the start of next iteration | ||
| // ensures all CTAs complete Stage 3 before output_counter is reset | ||
| } | ||
|
|
||
| // Clear histogram buffers and reset arrival counter for next kernel launch (only for multi-CTA) | ||
| if constexpr (!SINGLE_CTA) { | ||
| // Only leading CTA clears the buffers using release semantics | ||
| if (cta_in_group == 0) { | ||
| for (uint32_t buf = 0; buf < 3; ++buf) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| st_release(&state->histogram[buf][i], 0u); | ||
| } | ||
| } | ||
|
|
||
| if (tx == 0) { | ||
| st_release(&state->arrival_counter, 0); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #undef prefix_cache | ||
| #undef remaining_k_cache | ||
| #undef found_bucket | ||
| #undef found_remaining_k | ||
| #undef shared_output_counter | ||
| } | ||
|
|
||
| /*! | ||
| * \brief Launch multi-CTA Radix Top-K kernel (returns indices and optionally values) | ||
| * | ||
| * \param input Input tensor [batch_size, vocab_size] | ||
| * \param output_indices Output indices tensor [batch_size, top_k] | ||
| * \param output_values Output values tensor [batch_size, top_k] or nullptr if not needed | ||
| * \param top_k_arr Per-row top-k values or nullptr for uniform top_k | ||
| * \param batch_size Number of rows | ||
| * \param top_k_val Default top-k value (used when top_k_arr is nullptr) | ||
| * \param vocab_size Number of elements per row | ||
| * \param row_states_buffer Buffer for inter-CTA synchronization | ||
| * \param stream CUDA stream | ||
| */ | ||
| template <typename DType, typename IdType> | ||
| cudaError_t RadixTopKMultiCTA(DType* input, IdType* output_indices, DType* output_values, | ||
| IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, | ||
| uint32_t vocab_size, RadixRowState* row_states_buffer, | ||
| cudaStream_t stream = 0) { | ||
| using OrderedType = typename RadixTopKTraits<DType>::OrderedType; | ||
| constexpr uint32_t BLOCK_THREADS = 1024; | ||
| const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size); | ||
|
|
||
| int device; | ||
| FLASHINFER_CUDA_CALL(cudaGetDevice(&device)); | ||
| int num_sms; | ||
| FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device)); | ||
| int max_smem_per_block; | ||
| FLASHINFER_CUDA_CALL( | ||
| cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); | ||
|
|
||
| // Fixed smem: histogram[256] + suffix_sum[256] + scalars (5 for single-CTA path) | ||
| constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 5); | ||
| constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16); | ||
|
|
||
| const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned; | ||
| uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType); | ||
| max_chunk_elements = round_down(max_chunk_elements, vec_size); | ||
| const uint32_t min_chunk_size = vec_size * BLOCK_THREADS; | ||
| max_chunk_elements = std::max(max_chunk_elements, min_chunk_size); | ||
|
|
||
| uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements); | ||
| uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group); | ||
| chunk_size = round_up(chunk_size, vec_size); | ||
| chunk_size = std::min(chunk_size, max_chunk_elements); | ||
|
|
||
| // Determine if we use single-CTA path | ||
| const bool single_cta = (ctas_per_group == 1); | ||
|
|
||
| // Calculate smem_size | ||
| const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType); | ||
|
|
||
| // Calculate number of groups (how many rows to process concurrently) | ||
| uint32_t num_groups = std::min(static_cast<uint32_t>(num_sms) / ctas_per_group, batch_size); | ||
| if (num_groups == 0) num_groups = 1; | ||
| uint32_t total_ctas = num_groups * ctas_per_group; | ||
|
|
||
| DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { | ||
| if (single_cta) { | ||
| auto kernel = RadixTopKKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, true, DType, IdType>; | ||
| FLASHINFER_CUDA_CALL( | ||
| cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); | ||
| dim3 nblks(total_ctas); | ||
| dim3 nthrs(BLOCK_THREADS); | ||
| void* args[] = {&input, &output_indices, &output_values, &top_k_arr, | ||
| &top_k_val, &vocab_size, &batch_size, &row_states_buffer, | ||
| &chunk_size, &ctas_per_group}; | ||
| FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); | ||
| } else { | ||
| auto kernel = RadixTopKKernel_MultiCTA<BLOCK_THREADS, VEC_SIZE, false, DType, IdType>; | ||
| FLASHINFER_CUDA_CALL( | ||
| cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); | ||
| dim3 nblks(total_ctas); | ||
| dim3 nthrs(BLOCK_THREADS); | ||
| void* args[] = {&input, &output_indices, &output_values, &top_k_arr, | ||
| &top_k_val, &vocab_size, &batch_size, &row_states_buffer, | ||
| &chunk_size, &ctas_per_group}; | ||
| FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); | ||
| } | ||
| }); | ||
|
|
||
| return cudaSuccess; | ||
| } |
Rerun clang-format and commit the formatted changes to unblock CI
The pre-commit hook detected formatting issues in the file. Please run clang-format locally and push the updated file to resolve the CI failure.
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 2030-2032: clang-format formatting check failed. Some files were modified by this hook.
🤖 Prompt for AI Agents
In include/flashinfer/sampling.cuh around lines 1823 to 3373 the pre-commit hook
flagged formatting issues; run your project's clang-format configuration on this
file to apply the expected style, stage the updated file, create a commit with
the formatted changes, and push it to the branch so CI/pre-commit checks pass.
Ensure you use the repo’s clang-format config (or the same pre-commit hook) so
the exact formatting rules are applied before committing.
upd upd
force-pushed from 3843536 to c912dec (Compare)
|
/bot run |
|
/bot run |
Actionable comments posted: 2
♻️ Duplicate comments (3)
include/flashinfer/sampling.cuh (3)
2843-2857: Same shared memory calculation bug as RadixTopKMaskLogitsMultiCTA. Lines 2851-2852 have the identical issue:
const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
This can force chunk_size to exceed available shared memory. Apply the same fix as suggested for lines 2509-2526.
3313-3326: Same shared memory calculation bug affects this launcher too. Lines 3320-3321:
const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
Apply the same validation fix as suggested for the other launchers.
3216-3249: Out-of-bounds writes in index collection phase. Both passes of the index collection phase write to:
output_indices[row_idx * top_k_val + pos]
output_values[row_idx * top_k_val + pos]  // if not nullptr
But pos can reach k, which may exceed top_k_val when top_k_arr != nullptr. The check at line 3245 (if (pos < static_cast<int>(k))) doesn't prevent the issue; it should check pos < top_k_val. This is part of the same validation issue flagged for lines 2975-2992.
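For the shared-memory sizing issue, one possible guard is sketched below. It reuses the launcher-local names from the diff (chunk_size, fixed_smem_aligned, OrderedType, max_smem_per_block) and only illustrates the kind of validation meant here, not necessarily the fix the authors will adopt:

if (fixed_smem_aligned + static_cast<size_t>(chunk_size) * sizeof(OrderedType) >
    static_cast<size_t>(max_smem_per_block)) {
  // The rounded-up chunk no longer fits in opt-in shared memory; fail fast instead of
  // letting cudaFuncSetAttribute or the launch fail later.
  return cudaErrorInvalidConfiguration;
}

For the out-of-bounds write, the bound simply needs to respect the row stride as well as k; a minimal sketch under the same assumptions:

if (pos < static_cast<int>(min(k, top_k_val))) {
  output_indices[row_idx * top_k_val + pos] = static_cast<IdType>(chunk_start + i);
  if (output_values != nullptr) {
    output_values[row_idx * top_k_val + pos] = Traits::FromOrdered(ordered_val);
  }
}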
🧹 Nitpick comments (1)
include/flashinfer/sampling.cuh (1)
1840-1852: Comment clarification is helpful but could be more precise. The existing review comment suggested adding clarification about the reduction. The current comment at line 1844-1845 states "Release pattern using acq_rel fence + relaxed modifier" but doesn't explicitly mention that this also performs an atomic add (reduction). Consider expanding:
__device__ __forceinline__ void red_release(int* ptr, int val) {
#if (__CUDA_ARCH__ >= 700)
  // SM70 and newer use memory consistency qualifiers
-  // Release pattern using acq_rel fence + relaxed modifier
+  // Release pattern using acq_rel fence + relaxed atomic add (reduction + release)
  // (The fence also releases data that was weakly-written by other threads prior to the last
  // syncthreads)
Based on learnings, this addresses the past review comment suggesting clarification.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/sampling.cuh (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/sampling.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
include/flashinfer/sampling.cuh (6)
119-204: Type traits implementation looks solid. The RadixTopKTraits specializations correctly implement IEEE 754 bit manipulation for descending order:
- Positive values: flip sign bit to maintain order
- Negative values: flip all bits to reverse order
- The negative infinity constants are correct for each format
The num_rounds() template function appropriately computes the number of radix iterations based on the ordered type size and radix bits.

1826-1875: Synchronization primitives are correctly implemented. The acquire/release memory ordering for SM70+ and conservative fallbacks for older architectures are appropriate. The wait_ge function correctly has only thread 0 spin to avoid warp divergence overhead, followed by __syncthreads() to broadcast the result.

1880-1887: RadixRowState design is efficient. The triple-buffered histograms enable pipelining without extra synchronization barriers, which is a good optimization. The struct size (~3KB per state) is reasonable for the coordination it provides.
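For orientation, a rough sketch of what such a per-row state could look like follows; the field names and layout are assumptions inferred from the description above (triple-buffered 256-bin histograms plus a few counters), not the actual definition in sampling.cuh.

  // Hedged sketch of a per-row coordination record for the multi-CTA radix path.
  // Three 256-bin histogram buffers let round r+1 be cleared while round r is
  // still being read, so CTAs need only one barrier per round.
  struct RadixRowStateSketch {
    int arrival_counter;              // inter-CTA arrival count for the barrier
    int round_counter;                // monotonically increasing global round id
    unsigned int remaining_k;         // elements still to be selected for this row
    unsigned int histograms[3][256];  // triple-buffered, indexed by (global_round % 3)
  };
  // 3 * 256 * 4 bytes = 3072 bytes of histograms, i.e. roughly the ~3KB per state
  // mentioned above.
  static_assert(sizeof(RadixRowStateSketch) >= 3 * 256 * sizeof(unsigned int), "size sanity");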
1900-2189: Device helper functions are well-structured and correct. The radix selection algorithm is properly decomposed:
- RadixSuffixSum: efficient parallel suffix sum
- RadixFindThresholdBucket: correct threshold detection
- RadixBuildLocalHistogram: proper prefix matching with atomic operations
- RadixSelectOneRound: complex but correct multi-CTA coordination with triple buffering
- RadixSelectFindPivot: clean entry point

The triple-buffering pattern (using global_round % 3) with the leading CTA clearing the next buffer before the barrier is a clever optimization that avoids race conditions.
2634-2713: Fast-path optimization for k >= vocab_size is well-implemented. The sum-and-normalize fast path correctly handles the case where no filtering is needed. The multi-CTA sum aggregation with proper barrier synchronization (initialize → barrier → atomic adds → barrier → read) is correct.
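To illustrate that barrier choreography (initialize → barrier → atomic adds → barrier → read), a simplified sketch follows; the names row_sum and arrive_cnt are placeholders, not the actual kernel symbols, and the real kernel uses the RadixRowState barrier rather than this hand-rolled spin.

  // Hedged sketch: multi-CTA sum for the k >= vocab_size fast path.
  // Assumes row_sum and arrive_cnt were zeroed before the group started (the
  // "initialize" step); each CTA publishes a partial sum, then waits for all CTAs.
  __device__ void MultiCtaSumSketch(float block_sum, float* row_sum, int* arrive_cnt,
                                    int num_ctas_in_group) {
    if (threadIdx.x == 0) {
      atomicAdd(row_sum, block_sum);  // publish this CTA's partial sum
      __threadfence();                // make it visible to the other CTAs
      atomicAdd(arrive_cnt, 1);       // signal arrival
      while (atomicAdd(arrive_cnt, 0) < num_ctas_in_group) {
        // spin until every CTA in the group has contributed
      }
    }
    __syncthreads();  // broadcast "all partials are in" to the rest of the block
    // *row_sum now holds the full row sum and can be used to normalize the chunk
  }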
2321-2442: Inline radix select implementation is correct and well-optimized. The kernel inlines the radix selection logic with local caches (prefix_cache, remaining_k_cache) to avoid repeated global memory access. The triple-buffering with global_round = iter * NUM_ROUNDS + round correctly handles the persistent loop over multiple rows. The barrier choreography (write → leading CTA clears → barrier → read) matches the helper function and is race-free.
  // Fixed shared memory overhead: histogram[256] + suffix_sum[256] + 4 scalars
  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (256 + 256 + 4);
  constexpr size_t fixed_smem_aligned = round_up(fixed_smem_size, 16);

  // Calculate max chunk size that fits in shared memory
  const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
  uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
  max_chunk_elements = round_down(max_chunk_elements, vec_size);
  const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
  max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);

  uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
  uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
  chunk_size = round_up(chunk_size, vec_size);
  chunk_size = std::min(chunk_size, max_chunk_elements);

  const uint32_t smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType);
  const bool single_cta = (ctas_per_group == 1);
Critical: Shared memory calculation can exceed available capacity.
The logic at lines 2517-2518 can inflate max_chunk_elements beyond the available shared memory:
const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);If available_for_ordered is small (e.g., 10K elements worth), but min_chunk_size is large (e.g., 16K elements for vec_size=16, BLOCK_THREADS=1024), then max_chunk_elements becomes 16K. Subsequently:
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
chunk_size = round_up(chunk_size, vec_size);
chunk_size = std::min(chunk_size, max_chunk_elements); // Still 16K!The final smem_size = fixed_smem_aligned + chunk_size * sizeof(OrderedType) will exceed max_smem_per_block, causing cudaErrorInvalidConfiguration at kernel launch.
Apply the fix suggested in the past review:
const size_t available_for_ordered = max_smem_per_block - fixed_smem_aligned;
uint32_t max_chunk_elements = available_for_ordered / sizeof(OrderedType);
max_chunk_elements = round_down(max_chunk_elements, vec_size);
const uint32_t min_chunk_size = vec_size * BLOCK_THREADS;
- max_chunk_elements = std::max(max_chunk_elements, min_chunk_size);
+ if (max_chunk_elements < min_chunk_size) {
+ return cudaErrorInvalidConfiguration;
+ }
uint32_t ctas_per_group = ceil_div(vocab_size, max_chunk_elements);
uint32_t chunk_size = ceil_div(vocab_size, ctas_per_group);
chunk_size = round_up(chunk_size, vec_size);
- chunk_size = std::min(chunk_size, max_chunk_elements);Also applies to: lines 2843-2857 (RadixTopKRenormProbMultiCTA) and lines 3313-3326 (RadixTopKMultiCTA).
Based on learnings, this matches the critical issue flagged by gemini-code-assist in past reviews.
🤖 Prompt for AI Agents
include/flashinfer/sampling.cuh lines 2509-2526: the code allows
max_chunk_elements to be raised above the actual available_for_ordered by taking
std::max with min_chunk_size; to fix, compute available_limit =
round_down(available_for_ordered / sizeof(OrderedType), vec_size) first, then
cap the minimum chunk size to that limit (if min_chunk_size > available_limit
set min_chunk_size = available_limit) and finally set max_chunk_elements =
std::min(available_limit, /*previous computation*/ available_limit) or simply
use max_chunk_elements = available_limit after rounding; ensure chunk_size
calculation remains bounded by available_limit so smem_size cannot exceed
max_smem_per_block; apply the same clamping change to the analogous blocks at
lines 2843-2857 and 3313-3326.
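A standalone sketch of the requested clamping logic is given below; the helper name and signature are illustrative only, the file's round_up/ceil_div helpers are written out inline, and the point is simply that the launcher should fail with cudaErrorInvalidConfiguration instead of inflating the chunk size.

  #include <algorithm>
  #include <cstdint>
  #include <cuda_runtime.h>

  // Hedged sketch: choose a chunk size that is guaranteed to fit in the opt-in
  // shared-memory budget; fail cleanly when even the minimum chunk does not fit.
  inline cudaError_t ComputeChunkSizeSketch(size_t max_smem_per_block, size_t fixed_smem_aligned,
                                            size_t elem_size, uint32_t vec_size,
                                            uint32_t block_threads, uint32_t vocab_size,
                                            uint32_t* chunk_size, uint32_t* ctas_per_group) {
    const size_t available = max_smem_per_block - fixed_smem_aligned;
    uint32_t max_chunk = static_cast<uint32_t>(available / elem_size);
    max_chunk -= max_chunk % vec_size;  // round down to a vector multiple
    const uint32_t min_chunk = vec_size * block_threads;
    if (max_chunk < min_chunk) {
      return cudaErrorInvalidConfiguration;  // don't inflate past the budget
    }
    *ctas_per_group = (vocab_size + max_chunk - 1) / max_chunk;
    uint32_t chunk = (vocab_size + *ctas_per_group - 1) / *ctas_per_group;
    chunk = ((chunk + vec_size - 1) / vec_size) * vec_size;  // round up to a vector multiple
    *chunk_size = std::min(chunk, max_chunk);                // always fits in shared memory
    return cudaSuccess;
  }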
      uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];

      const uint32_t actual_chunk_size = chunk_end - chunk_start;

      if (k >= vocab_size) {
        // k >= vocab_size: return all indices
        for (uint32_t i = tx; i < actual_chunk_size; i += BLOCK_THREADS) {
          if (chunk_start + i < k) {
            output_indices[row_idx * top_k_val + chunk_start + i] =
                static_cast<IdType>(chunk_start + i);
            if (output_values != nullptr) {
              output_values[row_idx * top_k_val + chunk_start + i] =
                  input[row_idx * vocab_size + chunk_start + i];
            }
          }
        }
        continue;
      }
Critical: Out-of-bounds writes when per-row k exceeds top_k_val.
At line 2975, k is read from top_k_arr[row_idx] which can differ from top_k_val:
uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[row_idx];However, output buffers are allocated with stride top_k_val:
- output_indices[row_idx * top_k_val + pos] (lines 2983, 3216, 3246)
- output_values[row_idx * top_k_val + pos] (lines 2986, 3218, 3248)
If any top_k_arr[row_idx] > top_k_val, writes will overflow into the next row's buffer. For example, at lines 2982-2989:
if (chunk_start + i < k) {
output_indices[row_idx * top_k_val + chunk_start + i] = ...When k > top_k_val and chunk_start + i >= top_k_val, this writes beyond the allocated row.
Add validation in the host launcher (lines 3297-3364):
template <typename DType, typename IdType>
cudaError_t RadixTopKMultiCTA(DType* input, IdType* output_indices, DType* output_values,
IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val,
uint32_t vocab_size, RadixRowState* row_states_buffer,
cudaStream_t stream = 0) {
using OrderedType = typename RadixTopKTraits<DType>::OrderedType;
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(DType), vocab_size);
// Validate top_k_arr if provided
if (top_k_arr != nullptr) {
// Check that all top_k_arr[i] <= top_k_val
// For now, document the requirement or add host-side validation
// Kernel cannot safely validate this without extra overhead
}Alternatively, change the output buffer allocation to use max(top_k_arr) instead of top_k_val, but that requires upstream changes.
Also applies to lines 2982-2989 (same kernel, k >= vocab_size case).
Based on learnings, this matches the major issue flagged by coderabbitai in past reviews.
🤖 Prompt for AI Agents
In include/flashinfer/sampling.cuh around lines 2975-2992 there is an
out-of-bounds write when per-row k is read from top_k_arr[row_idx] but output
rows are laid out with stride top_k_val; if top_k_arr[row] > top_k_val writes
overflow the next row. Fix by validating top_k_arr in the host launcher
(RadixTopKMultiCTA around lines 3297-3364): if top_k_arr != nullptr iterate the
array on the host to ensure every top_k_arr[i] <= top_k_val and return a
cudaError / fail early (or document and assert) when violated; alternatively
change the allocation/stride to use max(top_k_arr) instead of top_k_val if
upstream allocation can be adjusted, but prefer the host-side validation and
explicit error return to prevent kernel OOB writes.
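One way to realize that host-side check is sketched here; the synchronous copy and the helper's signature are illustrative assumptions rather than the actual launcher code, and the extra stream synchronization is the cost of the safety check.

  #include <cstdint>
  #include <vector>
  #include <cuda_runtime.h>

  // Hedged sketch: verify max(top_k_arr) <= top_k_val before launching, so the
  // kernel can never index past a row's (batch, top_k_val) output slab.
  inline cudaError_t ValidateTopKArrSketch(const int32_t* d_top_k_arr, uint32_t batch_size,
                                           uint32_t top_k_val, cudaStream_t stream) {
    if (d_top_k_arr == nullptr) return cudaSuccess;  // uniform k: nothing to check
    std::vector<int32_t> h_top_k(batch_size);
    cudaError_t err = cudaMemcpyAsync(h_top_k.data(), d_top_k_arr, batch_size * sizeof(int32_t),
                                      cudaMemcpyDeviceToHost, stream);
    if (err != cudaSuccess) return err;
    err = cudaStreamSynchronize(stream);  // adds a sync; a debug-only build may prefer this
    if (err != cudaSuccess) return err;
    for (uint32_t i = 0; i < batch_size; ++i) {
      if (h_top_k[i] < 0 || static_cast<uint32_t>(h_top_k[i]) > top_k_val) {
        return cudaErrorInvalidValue;  // per-row k must fit in the output stride
      }
    }
    return cudaSuccess;
  }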
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/logits_processor/operators.py (1)
121-142: Cast maybe_top_k_arr to int32 before calling the native module.
These operators call get_sampling_module().top_k_{...} directly (Line 139-141, 185-187) but don't apply the .int() cast that exists in flashinfer/sampling.py. If a user passes top_k as an int64 tensor, this can break or produce wrong reads in C++.

      top_k = self._get_param("top_k", kwargs, required=True)
      maybe_top_k_arr, top_k_val = _to_tensor_scalar_tuple(top_k)
  @@
  +   if maybe_top_k_arr is not None:
  +       maybe_top_k_arr = maybe_top_k_arr.int()
  @@
      renorm_probs = get_sampling_module().top_k_renorm_probs(
          tensor.data, maybe_top_k_arr, top_k_val, row_states_buffer
      )

Apply the same in LogitsTopKOp.__call__.
Also applies to: 167-187
♻️ Duplicate comments (4)
flashinfer/topk.py (1)
69-165: Concurrency hazard: cached row_states_buffer is shared across streams on the same device.
row_states_buffer is a single per-device cache entry (Line 138-145), but kernels mutate it for inter-CTA sync. Concurrent top_k() calls on different CUDA streams can trample each other.
Concrete options:
- include the current stream in the cache key (best-effort in Python), or
- allocate per-call when concurrency is expected, or
- clearly document “single-stream per device” usage constraints.
include/flashinfer/sampling.cuh (2)
2509-2524: Shared-memory sizing bug can force invalid configurations (must clamp, don't std::max).
These blocks can inflate max_chunk_elements beyond what available_for_ordered supports (Line 2518-2519, 2851-2853, 3314-3316), which can make smem_size exceed cudaDevAttrMaxSharedMemoryPerBlockOptin and fail launches.
Use the pattern from the prior review: if max_chunk_elements < min_chunk_size, return cudaErrorInvalidConfiguration (or reduce BLOCK_THREADS/vec) rather than forcing max_chunk_elements upward.
Also applies to: 2843-2857, 3311-3321

2975-2988: Potential OOB writes when top_k_arr[row] > top_k_val (validate contract in launcher).
Kernel indexes outputs with stride top_k_val but uses per-row k from top_k_arr (Line 2975, writes at Line 2983-2986 and Line 3214-3216 / 3242-3244). If any k > top_k_val, this will write past the row's allocation.
Fix: in RadixTopKMultiCTA(...) (Line 3290+), if top_k_arr != nullptr, validate max(top_k_arr) <= top_k_val on host (or document + assert in debug builds) before launching.
Also applies to: 3213-3216, 3241-3244, 3290-3358
flashinfer/logits_processor/operators.py (1)
132-138: Same multi-stream race risk: cached row_states_buffer is shared per device.
Even with op-specific names, each op still shares a single buffer across streams for that op+device. Concurrent calls can collide.
Also applies to: 178-184
🧹 Nitpick comments (2)
flashinfer/topk.py (1)
54-63: Static analysis: unused args in fake op; either prefix with _ or ignore in linter config.
Ruff flags row_states_buffer / output_values as unused in _fake_radix_topk (Line 55-62). Consider renaming to _row_states_buffer / _output_values.

include/flashinfer/sampling.cuh (1)

1866-1875: Avoid hard spin in wait_ge (reduce contention / TDR risk).
while (ld_acquire(ptr) < target_val) {} (Line 1871-1872) can saturate memory and SM issue slots. Consider inserting __nanosleep() / exponential backoff, and/or a max-iteration safety break in debug builds.
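A minimal sketch of the backoff idea follows; the cap values and __nanosleep usage are one possible mitigation rather than the current implementation, and a real version would keep the ld_acquire load instead of the plain volatile read used here for brevity.

  // Hedged sketch: spin with exponential backoff to reduce memory traffic and
  // issue-slot pressure while waiting for a counter to reach a target value.
  __device__ __forceinline__ void wait_ge_backoff_sketch(volatile int* ptr, int target_val) {
    if (threadIdx.x == 0) {
      unsigned ns = 32;
      while (*ptr < target_val) {
  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
        __nanosleep(ns);                   // available on SM70+
  #endif
        ns = (ns < 4096) ? ns * 2 : 4096;  // cap the backoff
      }
    }
    __syncthreads();  // broadcast the "ready" observation to the whole block
  }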
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- csrc/flashinfer_topk_binding.cu (1 hunks)
- csrc/topk.cu (1 hunks)
- flashinfer/logits_processor/operators.py (2 hunks)
- flashinfer/topk.py (1 hunks)
- include/flashinfer/sampling.cuh (2 hunks)
- tests/utils/test_logits_processor.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/flashinfer_topk_binding.cu
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/sampling.cuh
🧬 Code graph analysis (3)
flashinfer/logits_processor/operators.py (2)
- flashinfer/utils.py (1): _get_cache_buf (206-217)
- flashinfer/sampling.py (4): top_k_renorm_probs (375-394), top_k_renorm_probs (1320-1390), top_k_mask_logits (410-430), top_k_mask_logits (1397-1465)

csrc/topk.cu (3)
- flashinfer/topk.py (1): top_k (69-164)
- flashinfer/comm/cuda_ipc.py (2): cudaSetDevice (149-150), cudaGetErrorString (146-147)
- csrc/tvm_ffi_utils.h (1): get_stream (294-296)

flashinfer/topk.py (4)
- flashinfer/jit/topk.py (1): gen_topk_module (21-28)
- flashinfer/utils.py (1): _get_cache_buf (206-217)
- csrc/flashinfer_topk_binding.cu (1): radix_topk (20-21)
- csrc/topk.cu (2): radix_topk (24-60), radix_topk (24-25)
🪛 Ruff (0.14.8)
flashinfer/topk.py
58-58: Unused function argument: row_states_buffer
(ARG001)
59-59: Unused function argument: output_values
(ARG001)
🔇 Additional comments (3)
tests/utils/test_logits_processor.py (2)
642-642: LGTM: Appropriate tolerance for floating-point probability comparisons. The change from strict equality to torch.allclose is reasonable given the multi-CTA optimizations and expanded data-type support introduced in this PR. Default tolerances (rtol=1e-5, atol=1e-8) are appropriate for comparing normalized probabilities.

666-666: Verify the necessity of equal_nan=True in the top-k logits masking assertion. The test includes equal_nan=True when comparing LogitsPipe's TopK implementation against the reference flashinfer.sampling.top_k_mask_logits. Both implementations are documented to set non-top-k values to -inf, yet the parameter suggests NaN values may be produced. Given that this test parametrizes neginf_input to specifically test edge cases with -inf values in the logits, clarify whether:
- The multi-CTA kernel produces NaN in specific edge cases (e.g., all logits are -inf, fewer than k finite values)
- NaN production is intentional behavior or an unintended side effect
- equal_nan=True is a temporary workaround or the correct long-term solution

If NaN production is unintended, the kernel should be fixed to consistently produce -inf instead.
include/flashinfer/sampling.cuh (1)
119-204: Traits look consistent; good foundation for fp16/bf16 support.
The ordered transforms and NegInf() encodings are sensible and keep the radix code generic across float/half/bf16.
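To spell out the bit trick being approved, here is a sketch of an fp32 ordered-key transform; the helper name is an assumption, and the fp16/bf16 specializations would apply the same rule to 16-bit payloads (e.g. via __half_as_ushort).

  // Hedged sketch: map an fp32 key to a uint32 whose unsigned ordering matches
  // the float ordering, so the radix passes can walk buckets from the top down.
  //   - positive values (sign bit 0): set the sign bit, preserving their order
  //   - negative values (sign bit 1): flip all bits, so values closer to zero
  //     map to larger keys
  __device__ __forceinline__ uint32_t ToOrderedKeySketch(float x) {
    uint32_t bits = __float_as_uint(x);
    return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
  }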
void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
                Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_indices);
  CHECK_INPUT(output_values);
  CHECK_DIM(2, input);           // input: (batch_size, d)
  CHECK_DIM(2, output_indices);  // output_indices: (batch_size, top_k)
  CHECK_DIM(2, output_values);   // output_values: (batch_size, top_k)

  unsigned int batch_size = input.size(0);
  unsigned int d = input.size(1);
Add basic shape + top_k range validation (prevents kernel OOB / nonsense launches).
Only rank is validated today. At minimum, enforce:
- top_k > 0
- top_k <= input.size(1)
- output_{indices,values}.size(1) == top_k
- batch dims match across tensors
This avoids silent misbehavior when callers pass a mismatched k vs output shapes.
🤖 Prompt for AI Agents
In csrc/topk.cu around lines 24 to 35, add explicit shape and top_k range
validation: check top_k > 0 and top_k <= input.size(1), verify
output_indices.size(0) and output_values.size(0) equal input.size(0), verify
output_indices.size(1) and output_values.size(1) equal top_k, and ensure
maybe_row_states_buffer (if present) has a compatible batch dimension; use the
same CHECK_* macros (or error paths) already used in the file so failures are
reported consistently and prevent launching kernels with out-of-bounds sizes.
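A sketch of what such guards could look like in the binding is shown below, reusing the TVM_FFI_ICHECK macro that already appears later in this file; the helper name and error messages are assumptions.

  // Hedged sketch: validate shapes and top_k in radix_topk() before any launch.
  void ValidateRadixTopKArgsSketch(int64_t batch_size, int64_t d, int64_t top_k,
                                   int64_t out_idx_rows, int64_t out_idx_cols,
                                   int64_t out_val_rows, int64_t out_val_cols) {
    TVM_FFI_ICHECK(top_k > 0 && top_k <= d)
        << "top_k must be in (0, d], got top_k=" << top_k << ", d=" << d;
    TVM_FFI_ICHECK(out_idx_rows == batch_size && out_val_rows == batch_size)
        << "output batch dims must match input batch size";
    TVM_FFI_ICHECK(out_idx_cols == top_k && out_val_cols == top_k)
        << "output_indices/output_values must have exactly top_k columns";
  }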
void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
                Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_indices);
  CHECK_INPUT(output_values);
  CHECK_DIM(2, input);           // input: (batch_size, d)
  CHECK_DIM(2, output_indices);  // output_indices: (batch_size, top_k)
  CHECK_DIM(2, output_values);   // output_values: (batch_size, top_k)

  unsigned int batch_size = input.size(0);
  unsigned int d = input.size(1);

  cudaSetDevice(input.device().device_id);
  auto stream = get_stream(input.device());

  cudaError_t status;
  auto dtype = input.dtype();

  // Get row_states_buffer if provided (for multi-CTA path)
  sampling::RadixRowState* row_states_ptr = nullptr;
  if (maybe_row_states_buffer.has_value()) {
    row_states_ptr =
        static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
  }

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
    status = sampling::RadixTopKMultiCTA<c_type, int32_t>(
        static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
        static_cast<c_type*>(output_values.data_ptr()),
        nullptr,  // top_k_arr
        batch_size, static_cast<uint32_t>(top_k), d, row_states_ptr, stream);
    return true;
  });

  TVM_FFI_ICHECK(status == cudaSuccess)
      << "RadixTopK failed with error code " << cudaGetErrorString(status);
}
Fix BF16 dispatch + initialize status to avoid UB / wrong dtype behavior.
Right now the kernel dispatch only covers FP32/FP16 (Line 49), but the Python API advertises BF16 support. Also status is uninitialized if dispatch doesn’t run, and the final TVM_FFI_ICHECK(status == cudaSuccess) becomes undefined behavior.
void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
@@
- cudaSetDevice(input.device().device_id);
+ TVM_FFI_ICHECK(cudaSetDevice(input.device().device_id) == cudaSuccess);
auto stream = get_stream(input.device());
- cudaError_t status;
+ cudaError_t status = cudaErrorInvalidValue;
auto dtype = input.dtype();
@@
- DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
+ DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16_BF16(dtype, c_type, [&] {
status = sampling::RadixTopKMultiCTA<c_type, int32_t>(
@@
TVM_FFI_ICHECK(status == cudaSuccess)
<< "RadixTopK failed with error code " << cudaGetErrorString(status);
}If there is no ..._BF16 dispatch macro available, you should either add it (preferred, given nv_bfloat16 support in include/flashinfer/sampling.cuh) or explicitly reject bf16 here with a clear error before launch.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In csrc/topk.cu around lines 24 to 60, status is left uninitialized and the
dtype dispatch only covers FP32/FP16, so BF16 calls lead to UB or wrong
behavior; initialize status (e.g., to cudaErrorInvalidValue) before the dispatch
to avoid undefined reads, then extend the dispatch to include BF16 by using the
BF16-capable dispatch macro (or add a
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16_BF16 variant that maps nv_bfloat16 to
the correct c_type and launches sampling::RadixTopKMultiCTA for bf16), and if
you cannot add a BF16 dispatch, explicitly check for bf16 and return a clear
error (set status to an appropriate cudaError and fail fast) before the kernel
launch so the final TVM_FFI_ICHECK sees a defined value.
@register_custom_op(
    "flashinfer::radix_topk", mutates_args=("row_states_buffer", "output_values")
)
def radix_topk(
    input: torch.Tensor,
    top_k: int,
    row_states_buffer: Optional[torch.Tensor],
    output_values: torch.Tensor,
) -> torch.Tensor:
    device = input.device
    # Supports float32, float16, bfloat16
    assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], (
        f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16"
    )
    batch_size = input.size(0)
    output_indices = torch.empty(
        batch_size, top_k, dtype=torch.int32, device=device
    )
    module.radix_topk(
        input, output_indices, output_values, row_states_buffer, top_k
    )
    return output_indices
Replace assert dtype guard with a real exception.
assert can be stripped with python -O, turning invalid dtypes into hard-to-debug native crashes.
- assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], (
- f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16"
- )
+ if input.dtype not in (torch.float32, torch.float16, torch.bfloat16):
+ raise TypeError(
+ f"Unsupported dtype {input.dtype}, expected float32/float16/bfloat16"
+ )🤖 Prompt for AI Agents
In flashinfer/topk.py around lines 31 to 53, replace the assert-based dtype
guard with an explicit runtime exception: check if input.dtype is one of
(torch.float32, torch.float16, torch.bfloat16) and if not raise a TypeError (or
ValueError) with the existing descriptive message (e.g. f"Unsupported dtype
{input.dtype}, expected float32, float16, or bfloat16"); do not use assert so
the check remains active under optimized execution.
def top_k(
    input: torch.Tensor,
    k: int,
    sorted: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Radix-based Top-K selection.

    This function selects the top-k largest elements from each row of the input
    tensor. It uses an efficient radix-based selection algorithm that is
    particularly fast for large vocabularies.

    This is designed as a drop-in replacement for ``torch.topk`` with better
    performance for large tensors (vocab_size > 10000).

    Parameters
    ----------
    input : torch.Tensor
        Input tensor of shape ``(batch_size, d)`` containing the values to select from.
        Supported dtypes: ``float32``, ``float16``, ``bfloat16``.
    k : int
        Number of top elements to select from each row.
    sorted : bool, optional
        If True, the returned top-k elements will be sorted in descending order.
        Default is False (unsorted, which is faster).

    Returns
    -------
    values : torch.Tensor
        Tensor of shape ``(batch_size, k)`` containing the top-k values.
        Same dtype as input.
    indices : torch.Tensor
        Tensor of shape ``(batch_size, k)`` with int64 dtype containing the
        indices of the top-k elements.

    Note
    ----
    - Unlike ``torch.topk``, the default behavior returns unsorted results for
      better performance. Set ``sorted=True`` if you need sorted output.
    - The radix-based algorithm is O(n) in vocabulary size, compared to O(n log k)
      for heap-based methods, making it faster for large vocabularies.
    - For small vocabularies (< 1000), ``torch.topk`` may be faster.

    Examples
    --------
    >>> import torch
    >>> import flashinfer
    >>> torch.manual_seed(42)
    >>> batch_size = 4
    >>> vocab_size = 32000
    >>> k = 256
    >>> logits = torch.randn(batch_size, vocab_size, device="cuda")
    >>> values, indices = flashinfer.top_k(logits, k)
    >>> values.shape, indices.shape
    (torch.Size([4, 256]), torch.Size([4, 256]))

    With sorting enabled (for compatibility with torch.topk):

    >>> values_sorted, indices_sorted = flashinfer.top_k(logits, k, sorted=True)
    >>> # Values are now in descending order within each row

    See Also
    --------
    torch.topk : PyTorch's built-in top-k function
    sampling.top_k_mask_logits : Top-k masking for logits (sets non-top-k to -inf)
    sampling.top_k_renorm_probs : Top-k filtering and renormalization for probabilities
    """
    batch_size = input.size(0)
    device = input.device

    # Allocate row_states buffer for multi-CTA path
    # 1MB is enough for any reasonable GPU (covers up to ~500 groups)
    row_states_buffer: Optional[torch.Tensor] = _get_cache_buf(
        f"radix_topk_row_states_{input.device}",
        1024 * 1024,  # 1MB
        input.device,
        zero_init=True,
    )

    # Allocate output_values for kernel to write directly
    output_values = torch.empty(batch_size, k, dtype=input.dtype, device=device)

    # Get indices using radix-based selection
    indices_int32 = get_topk_module().radix_topk(
        input, k, row_states_buffer, output_values
    )

    # Convert to int64 for compatibility
    indices = indices_int32.long()

    if sorted:
        # Sort within each row by value (descending)
        sorted_values, sort_indices = torch.sort(output_values, dim=-1, descending=True)
        sorted_indices = torch.gather(indices, dim=-1, index=sort_indices)
        return sorted_values, sorted_indices

    return output_values, indices
Validate input rank + k bounds in Python (fast fail, better errors).
Add checks for input.ndim == 2, k > 0, and k <= input.size(1) before allocating outputs.
🤖 Prompt for AI Agents
In flashinfer/topk.py around lines 69 to 165, add an early fast-fail input
validation before any buffer/allocation: verify input.ndim == 2 and raise a
ValueError with a clear message if not; verify k is an int > 0 and raise
ValueError if not; verify k <= input.size(1) and raise ValueError if exceeded.
Perform these checks immediately after computing batch_size/device (or at the
top of the function) so allocations and kernel calls are skipped on invalid
inputs, using concise, user-friendly error messages.
tests/utils/test_logits_processor.py (outdated)
    # Allow small differences due to floating point precision in intermediate steps
    diff_ratio = (samples_pipe != samples_direct).sum().item() / batch_size
    assert diff_ratio < 0.01, f"Too many differences: {diff_ratio * 100:.2f}%"
Consider tightening the tolerance for deterministic sampling tests.
Allowing up to 1% of samples to differ seems lenient for a test using deterministic generators. While floating-point precision differences in intermediate probability normalization steps can affect boundary cases, consider:
- Tightening the threshold: For batch_size=989, this allows ~10 samples to differ. Could this be reduced to 0.1% (< 1 sample per 1000)?
- Adding diagnostics: Log which samples differ and their probability differences to distinguish genuine floating-point edge cases from correctness issues.
- Documenting expectations: Expand the comment to explain specific scenarios where differences are expected (e.g., ties near the top-k or top-p boundaries).
Apply this diff to add diagnostic output:
# Allow small differences due to floating point precision in intermediate steps
diff_ratio = (samples_pipe != samples_direct).sum().item() / batch_size
- assert diff_ratio < 0.01, f"Too many differences: {diff_ratio * 100:.2f}%"
+ assert diff_ratio < 0.01, (
+ f"Too many differences: {diff_ratio * 100:.2f}% "
+ f"({(samples_pipe != samples_direct).sum().item()} out of {batch_size} samples differ)"
+ )📝 Committable suggestion
    # Allow small differences due to floating point precision in intermediate steps
    diff_ratio = (samples_pipe != samples_direct).sum().item() / batch_size
    assert diff_ratio < 0.01, (
        f"Too many differences: {diff_ratio * 100:.2f}% "
        f"({(samples_pipe != samples_direct).sum().item()} out of {batch_size} samples differ)"
    )
🤖 Prompt for AI Agents
In tests/utils/test_logits_processor.py around lines 821 to 823, the
deterministic sampling test tolerance is too permissive (1%); tighten it to 0.1%
(0.001) to reduce allowed differing samples, add diagnostic output that logs the
indices of differing samples and a small summary of their probability
differences (e.g., count, max diff, example diffs) to help distinguish FP edge
cases from logic bugs, and expand the existing comment to explain when small
differences are expected (ties near top-k/top-p boundaries and floating-point
normalization rounding). Ensure the assertion message includes the diagnostic
summary when it fails so CI shows which samples differ.
/bot run
Actionable comments posted: 0
♻️ Duplicate comments (1)
tests/utils/test_logits_processor.py (1)
821-823: The 2% mismatch allowance is very permissive; tighten + add diagnostics
This is the same concern previously raised (then 1%); increasing to 2% further reduces the test's ability to catch regressions. Add failure diagnostics so CI tells you which samples differ.

  # Allow small differences due to floating point precision in intermediate steps
  diff_ratio = (samples_pipe != samples_direct).sum().item() / batch_size
- assert diff_ratio < 0.02, f"Too many differences: {diff_ratio * 100:.2f}%"
+ diff_cnt = (samples_pipe != samples_direct).sum().item()
+ diff_idx = (samples_pipe != samples_direct).nonzero(as_tuple=False).flatten()
+ assert diff_ratio < 0.001, (
+     f"Too many differences: {diff_ratio * 100:.2f}% "
+     f"({diff_cnt}/{batch_size}); first_diff_idx={diff_idx[:10].tolist()}"
+ )
🧹 Nitpick comments (2)
tests/utils/test_logits_processor.py (2)
642-642: Prefer explicit tolerances (and/or torch.testing.assert_close) instead of default torch.allclose
Default rtol/atol can be surprisingly strict/loose depending on dtype/device and PyTorch version; this can turn into CI flakiness or masked regressions.

- assert torch.allclose(samples_pipe, samples_direct)
+ torch.testing.assert_close(samples_pipe, samples_direct, rtol=1e-4, atol=1e-6)

666-666: For logits masking, validate the mask pattern explicitly (don't rely only on allclose)
If this is validating top_k_mask_logits, it's useful to assert that the "masked positions" (e.g., -inf) match exactly, then use assert_close for the remaining finite values.

- assert torch.allclose(samples_pipe, samples_direct, equal_nan=True)
+ # masked pattern should match exactly (e.g., -inf locations)
+ assert torch.equal(torch.isneginf(samples_pipe), torch.isneginf(samples_direct))
+ # remaining finite values should be close
+ torch.testing.assert_close(
+     samples_pipe[~torch.isneginf(samples_pipe)],
+     samples_direct[~torch.isneginf(samples_direct)],
+     rtol=1e-4,
+     atol=1e-6,
+ )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/utils/test_logits_processor.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
📌 Description
This PR implements several features and optimizations for top-k:
Multi-CTA optimization
A follow-up to #2044, this PR optimizes the top-k/top-p filters using multi-CTA kernels. The major advantage over the main branch is that logits/probabilities are kept in shared memory, so multiple rounds of scanning barely affect performance.
Radix Top-K
The number of radix bits is set to 8.
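For intuition, 8 radix bits means 256 buckets per pass, and the number of passes follows from the width of the ordered key; the snippet below only illustrates that arithmetic and is not code from this PR.

  #include <cstdint>

  // Hedged sketch: radix-8 passes needed per ordered-key width.
  constexpr int kRadixBits = 8;  // 2^8 = 256 buckets per pass
  template <typename OrderedKey>
  constexpr int NumRadixRoundsSketch() {
    return (static_cast<int>(sizeof(OrderedKey)) * 8 + kRadixBits - 1) / kRadixBits;
  }
  static_assert(NumRadixRoundsSketch<uint32_t>() == 4, "fp32 keys take 4 passes");
  static_assert(NumRadixRoundsSketch<uint16_t>() == 2, "fp16/bf16 keys take 2 passes");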
More Top-K APIs and Data Types
Speedup
On H100, the speedup over 0.5.2 is available at:
https://docs.google.com/spreadsheets/d/1DO8_11gzv-EUACCY6q4IMIHa8SaYv4q8hJ6gZl-D0mU/edit?usp=sharing (3-14x faster).
On consumer GPUs (e.g. Ada6000) the gap is even larger: for small batch sizes and large vocabularies the speedup can be as large as 100x.
🔍 Related Issues
#2044
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed, and all tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Tests