Conversation

Contributor

@JasonJ2021 JasonJ2021 commented Aug 24, 2025

📌 Description

This commit introduces a new radix sort-based top-k sampling algorithm, radik_sampling_from_probs(), which is ported from https://github.com/leefige/radik.git and follows the design of https://arxiv.org/abs/2501.14336.

🔍 Related Issues

#1243

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
    • Add tests to prove correctness of radix-sorting based top-k sampling algorithms
  • All tests are passing (unittest, etc.).

Reviewer Notes

The current radix-sort based top-k implementation only supports cases where k <= 1024. I'm still looking for a way to integrate it into the existing top-k sampling. An ideal scenario would be to use the radix implementation when k <= 1024 and the batch size is small, and use the existing implementation for all other cases.

The radix implementation requires an additional buffer to store intermediate results for sorting. For example, a top-k sampling task with batch_size=1024, vocab_size=128512, and k=10 requires a workspace buffer of roughly 1 GB. To speed up memory allocation, the current implementation pre-allocates a 64 MB buffer.

The SamplingFromRadiKSelectKernel currently has an address-misalignment issue, so I have manually set vec_size=1: the kernel needs to read the sampling source data from the top_k_select_result address, but that address is not aligned the way top_k_val is. We might consider using an AlignedAllocator in the future. However, this kernel does not have a significant impact on end-to-end latency.
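For illustration, a minimal sketch of the kind of helper an aligned allocator could provide (names are illustrative, not actual FlashInfer code): sub-buffer pointers carved out of the workspace would be rounded up to a 16-byte boundary before use, so vectorized loads on top_k_select_result remain legal.

#include <cstddef>
#include <cstdint>

// Illustrative only: round a raw workspace pointer up to `alignment` bytes
// before handing it out as a typed sub-buffer, so vec_size > 1 loads stay safe.
template <typename T>
T* AlignUpPtr(void* ptr, size_t alignment = 16) {
  uintptr_t p = reinterpret_cast<uintptr_t>(ptr);
  return reinterpret_cast<T*>((p + alignment - 1) & ~(uintptr_t)(alignment - 1));
}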

Compared to the existing top-k sampling implementation, the radix-sort based approach significantly reduces end-to-end latency, especially with smaller batches. I selected some benchmark data with relatively small batch sizes; the complete benchmark data is in the attachment. Benchmarks were conducted on an H20 GPU.

Original Implementation of top_k_sampling_from_probs:

vocab_size: 128512, batch_size: 1, distrib: normal_distribution(std=1), deterministic: False, k: 10, duration: 523.30 us, effective bandwidth: 0.98 GB/s
vocab_size: 128512, batch_size: 1, distrib: normal_distribution(std=1), deterministic: False, k: 100, duration: 340.35 us, effective bandwidth: 1.51 GB/s
vocab_size: 128512, batch_size: 1, distrib: normal_distribution(std=1), deterministic: False, k: 1000, duration: 195.71 us, effective bandwidth: 2.63 GB/s
vocab_size: 128512, batch_size: 16, distrib: normal_distribution(std=1), deterministic: False, k: 10, duration: 1412.06 us, effective bandwidth: 5.82 GB/s
vocab_size: 128512, batch_size: 16, distrib: normal_distribution(std=1), deterministic: False, k: 100, duration: 961.15 us, effective bandwidth: 8.56 GB/s
vocab_size: 128512, batch_size: 16, distrib: normal_distribution(std=1), deterministic: False, k: 1000, duration: 696.38 us, effective bandwidth: 11.81 GB/s
vocab_size: 128512, batch_size: 32, distrib: normal_distribution(std=1), deterministic: False, k: 10, duration: 1028.06 us, effective bandwidth: 16.00 GB/s
vocab_size: 128512, batch_size: 32, distrib: normal_distribution(std=1), deterministic: False, k: 100, duration: 839.38 us, effective bandwidth: 19.60 GB/s
vocab_size: 128512, batch_size: 32, distrib: normal_distribution(std=1), deterministic: False, k: 1000, duration: 655.28 us, effective bandwidth: 25.10 GB/s

radik_sampling_from_probs:

vocab_size: 128512, batch_size: 1, distrib: normal_distribution(std=1), deterministic: False, k: 10, duration: 101.28 us, effective bandwidth: 5.08 GB/s
vocab_size: 128512, batch_size: 1, distrib: normal_distribution(std=1), deterministic: False, k: 100, duration: 102.05 us, effective bandwidth: 5.04 GB/s
vocab_size: 128512, batch_size: 1, distrib: normal_distribution(std=1), deterministic: False, k: 1000, duration: 104.32 us, effective bandwidth: 4.93 GB/s
vocab_size: 128512, batch_size: 16, distrib: normal_distribution(std=1), deterministic: False, k: 10, duration: 111.55 us, effective bandwidth: 73.73 GB/s
vocab_size: 128512, batch_size: 16, distrib: normal_distribution(std=1), deterministic: False, k: 100, duration: 121.98 us, effective bandwidth: 67.43 GB/s
vocab_size: 128512, batch_size: 16, distrib: normal_distribution(std=1), deterministic: False, k: 1000, duration: 114.82 us, effective bandwidth: 71.63 GB/s
vocab_size: 128512, batch_size: 32, distrib: normal_distribution(std=1), deterministic: False, k: 10, duration: 120.45 us, effective bandwidth: 136.57 GB/s
vocab_size: 128512, batch_size: 32, distrib: normal_distribution(std=1), deterministic: False, k: 100, duration: 121.34 us, effective bandwidth: 135.56 GB/s
vocab_size: 128512, batch_size: 32, distrib: normal_distribution(std=1), deterministic: False, k: 1000, duration: 138.34 us, effective bandwidth: 118.91 GB/s

benchmark.log

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @JasonJ2021, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new radix sort-based top-k sampling algorithm, radik_sampling_from_probs, to the FlashInfer library. This algorithm, ported from an external project, aims to significantly improve the end-to-end latency of top-k sampling, particularly for scenarios with smaller batch sizes. While currently supporting k values up to 1024 and requiring an additional workspace buffer, it offers a substantial performance boost over the existing implementation.

Highlights

  • New Algorithm: Introduces radik_sampling_from_probs, a radix sort-based top-k sampling algorithm, to enhance performance.
  • Performance: Achieves significant latency reduction for top-k sampling, especially with smaller batch sizes, as demonstrated by benchmarks.
  • Limitations: The current implementation of the radix-based sampling algorithm is limited to k values less than or equal to 1024.
  • Resource Usage: Requires an additional workspace buffer (pre-allocated 64MB) for intermediate sorting results.
  • Future Integration: Proposed integration strategy involves using the new radix algorithm for k <= 1024 and small batch sizes, and the existing method otherwise.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new radix sort-based top-k sampling algorithm. The implementation looks promising and shows significant performance improvements in the benchmarks. My review focuses on correctness and performance aspects of the new implementation. I've found a critical race condition in the CUDA kernel, a bug related to handling variable top_k values, a performance issue with workspace allocation, and an inaccurate docstring. Addressing these points will improve the robustness and usability of the new feature.

Comment on lines +2655 to +2659
    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(tmpK.data(), top_k_arr, sizeof(int) * batch_size,
                                         cudaMemcpyDefault, stream));
  }
  uint32_t max_top_k_val =
      top_k_arr == nullptr ? top_k_val : *std::max_element(tmpK.begin(), tmpK.end());
Contributor

critical

There is a race condition here. cudaMemcpyAsync is used to copy top_k_arr from device to the host vector tmpK. However, tmpK is accessed on the host immediately after in the std::max_element call on line 2659 without synchronizing the stream. This can lead to incorrect max_top_k_val calculation or undefined behavior. A cudaStreamSynchronize(stream) is needed before accessing tmpK on the host. Also, it's better to be explicit with cudaMemcpyDeviceToHost.

    FLASHINFER_CUDA_CALL(cudaMemcpyAsync(tmpK.data(), top_k_arr, sizeof(int) * batch_size,
                                         cudaMemcpyDeviceToHost, stream));
    FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream));

Comment on lines +886 to +888
    workspace_buffer = _get_cache_buf(
        "radik_sampling_from_probs_workspace", 64 * 1024 * 1024, probs.device
    )
Contributor

high

The function allocates a fixed-size 64MB workspace buffer. However, the underlying CUDA kernel might require a much larger buffer depending on batch_size and vocab_size (e.g., ~1GB for a use case mentioned in the PR description). The C++ implementation handles this by allocating a new buffer with cudaMalloc if the provided one is insufficient. This hidden, potentially slow, allocation can be a significant performance bottleneck, especially if this function is called in a loop.

It would be better to provide a way for the user to calculate the required workspace size and pre-allocate it. For example, you could expose a helper function like get_radik_sampling_workspace_size(batch_size, vocab_size, max_top_k).
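For example, a rough sketch of what such a helper could look like (the size terms below are assumptions inferred from the buffers mentioned in this PR, not the kernel's actual workspace layout):

#include <cstddef>

// Hypothetical sketch: estimate the workspace needed by radik_sampling_from_probs.
// The breakdown (ping-pong candidate buffers, per-row 2^12-bin histograms,
// counters, and top-k output staging) is an assumption, not the real layout.
size_t GetRadiKSamplingWorkspaceSize(size_t batch_size, size_t vocab_size,
                                     size_t max_top_k, size_t dtype_bytes) {
  size_t candidates = 2 * batch_size * vocab_size * dtype_bytes;          // ping-pong sort buffers
  size_t hist = sizeof(int) * batch_size * ((1 << 12) + 1);               // per-row histograms
  size_t counters = sizeof(int) * batch_size * 3;                         // globalCount + taskLen[2]
  size_t selected = batch_size * max_top_k * (dtype_bytes + sizeof(int)); // top-k values + indices
  return candidates + hist + counters + selected;
}

Callers could then allocate the workspace once outside their sampling loop and pass it in.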


  if (selected_probs != nullptr) {
    FLASHINFER_CUDA_CALL(cudaMemcpy(selected_probs, top_k_select_result,
                                    sizeof(T) * batch_size * top_k_val, cudaMemcpyDefault));
Contributor

high

The cudaMemcpy to copy the selected probabilities uses top_k_val in the size calculation. When top_k is passed as a tensor from Python, top_k_val will be 0, and no data will be copied. This should use max_top_k_val to ensure all selected probabilities for all requests in the batch are copied correctly. The user of selected_probs will need to ensure the buffer is allocated to (batch_size, max_top_k_val).

                                    sizeof(T) * batch_size * max_top_k_val, cudaMemcpyDefault));

Comment on lines 850 to 855
r"""Fused GPU kernel for top-k sampling from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
which is more efficient than the naive implementation that launches a series of kernels.
Contributor

medium

The docstring appears to be copied from another sampling function and incorrectly describes the algorithm. It mentions "rejection sampling without explicit sorting", but this implementation is based on radix sort. The docstring should be updated to accurately reflect the algorithm being used.

Suggested change
- r"""Fused GPU kernel for top-k sampling from probabilities,
- this operator implements GPU-based rejection sampling without explicit sorting.
- Check the `blog post <https://flashinfer.ai/2025/03/10/sampling.html>`_ for more details.
- The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
- which is more efficient than the naive implementation that launches a series of kernels.
+ r"""Fused GPU kernel for top-k sampling from probabilities based on radix sort.
+ This operator is ported from https://github.com/leefige/radik.git and follows the design of https://arxiv.org/abs/2501.14336.
+ The radix sort based selection is implemented in a single CUDA kernel,
+ which can be more efficient than other methods for certain cases (e.g. small k).

@yzh119 yzh119 self-requested a review August 24, 2025 08:46
  FLASHINFER_CUDA_CALL(
      cudaMemsetAsync(histPtr, 0, sizeof(int) * batch_size * ((1 << 12) + 1), stream));

  DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
Collaborator

CI failed because of excessive shared memory usage (the public CI ran on a machine with a small shared memory size, ~98 KB/SM), so BLOCK_THREADS needs to be tweaked to keep shared memory usage bounded.

We can create a new macro; the heuristics defined in DISPATCH_COMPUTE_CAP_NUM_THREADS are designed specifically for certain kernels and might not work for radix top-k.
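For example, the new heuristic could look roughly like the following sketch (the name and the shared-memory model are assumptions, not existing FlashInfer code): pick the largest block size whose estimated shared-memory footprint still fits the budget.

#include <cstddef>
#include <cstdint>

// Hypothetical sketch: choose BLOCK_THREADS for the radix top-k kernel so that
// fixed smem (e.g. the bin histogram) plus per-thread smem stays under budget.
inline uint32_t PickRadixTopKBlockThreads(size_t smem_fixed_bytes,
                                          size_t smem_per_thread_bytes,
                                          size_t smem_budget_bytes) {
  for (uint32_t block_threads = 1024; block_threads >= 128; block_threads /= 2) {
    if (smem_fixed_bytes + block_threads * smem_per_thread_bytes <= smem_budget_bytes) {
      return block_threads;
    }
  }
  return 128;  // fall back to the smallest configuration
}

The chosen value could then be plugged into a dedicated dispatch macro for the radix top-k kernels.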

Contributor Author

ok, I'll try to fix it

Contributor Author

The shared memory issue has been resolved.

A new macro is used to determine BLOCK_THREADS for the selectCandidateEx kernel, while all other kernel launches use the existing macro. I am not sure whether this is the correct approach or whether it is consistent with the existing FlashInfer code style.

@yzh119
Collaborator

yzh119 commented Sep 3, 2025

Hi @JasonJ2021 is this PR ready for review?

@JasonJ2021
Contributor Author

JasonJ2021 commented Sep 5, 2025 via email

@yzh119 yzh119 changed the title [DRAFT] feat: support radix-based top-k sampling algorithm feat: support radix-based top-k sampling algorithm Sep 5, 2025
int* taskLenPtr[2]{globalCountPtr + batch_size, globalCountPtr + 2 * batch_size};
std::vector<int> tmpTaskLen(2 * batch_size, d);

FLASHINFER_CUDA_CALL(cudaMemcpyAsync(taskLenPtr[0], tmpTaskLen.data(),
Collaborator

Haven't profiled the performance yet, but could we use a standalone CUDA kernel (and enable PDL) to initialize these global buffers (e.g. taskLenPtr, kPtr, and histPtr) instead of issuing multiple cudaMemcpyAsync calls like this?
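Roughly, the idea would be something like the sketch below (buffer names and the uniform-k initialization are assumptions based on this diff, not actual code); PDL would be an optional further step on top of it:

// Hypothetical sketch: fuse the host-side memcpy/memset initialization of
// globalCount, taskLen, k and the histogram into one grid-stride kernel.
__global__ void InitRadixTopKBuffers(int* global_count, int* task_len, int* k_ptr,
                                     int* hist, int batch_size, int init_task_len,
                                     int init_k, int hist_len) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  for (int i = tid; i < batch_size; i += stride) {
    global_count[i] = 0;
    task_len[i] = init_task_len;               // taskLenPtr[0]
    task_len[batch_size + i] = init_task_len;  // taskLenPtr[1]
    k_ptr[i] = init_k;                         // uniform-k case only
  }
  for (int i = tid; i < batch_size * hist_len; i += stride) {
    hist[i] = 0;
  }
}

With PDL enabled, the following compute kernel could be launched with the programmatic-stream-serialization attribute and call cudaGridDependencySynchronize() before touching these buffers, hiding part of the launch latency.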

Contributor Author

Sorry, I'm not experienced in CUDA programming.
My understanding is to first use a standalone kernel (e.g., initialize_buffer()) to initialize these global buffers, and then leverage PDL to allow this initialization kernel to execute concurrently with a subsequent compute kernel (e.g., countBinExKernel) in a single stream. Is my understanding correct?

flag ^= 1;
}
// clear globalCount
FLASHINFER_CUDA_CALL(cudaMemsetAsync(globalCountPtr, 0, sizeof(int) * batch_size, stream));
Collaborator

ditto

}

// === Iter 3 ===
FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream));
Collaborator

Is it possible to get rid of the cudaStreamSynchronize call here?
It seems we rely on the host-side maxTaskLen to configure the launch of countbin_iter3_kernel. Could countbin_iter3_kernel instead be rewritten as a persistent kernel (with a fixed grid configuration) that reads the device-side maxTaskLen value inside the kernel, so that we can reduce the synchronization overhead here?
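A rough sketch of that persistent-kernel shape (the kernel name, argument names, and the per-row key layout are hypothetical): launch with a fixed grid and read each row's task length from device memory, so the host no longer needs maxTaskLen to size the grid.

#include <cstdint>

// Hypothetical sketch: fixed-grid "iteration 3" bin counting that reads the
// per-row task length on the device instead of a host-side maxTaskLen.
__global__ void CountBinIter3Persistent(const uint32_t* keys, const int* task_len,
                                        int* hist, int batch_size, int row_stride,
                                        int hist_len, int shift) {
  for (int row = blockIdx.y; row < batch_size; row += gridDim.y) {
    const int len = task_len[row];                        // device-side value, no sync needed
    const uint32_t* row_keys = keys + (size_t)row * row_stride;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len;
         i += gridDim.x * blockDim.x) {
      int bin = (row_keys[i] >> shift) & (hist_len - 1);  // assumes hist_len is a power of two
      atomicAdd(&hist[row * hist_len + bin], 1);
    }
  }
}

The grid would then be fixed (e.g. sized to a multiple of the SM count) regardless of the current task lengths.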

});

// check global histgram
FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream));
Collaborator

ditto

)


def radik_sampling_from_probs(
Collaborator

I would also encourage dispatching top-k sampling to this function for small k.

@yzh119
Collaborator

yzh119 commented Sep 10, 2025

Hi @JasonJ2021 do you have time to address the comments?

@JasonJ2021
Contributor Author

Hi @JasonJ2021 do you have time to address the comments?

Yes, I will start to address these issues in the next two to three days.

@JasonJ2021
Contributor Author

def top_k_sampling_from_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    indices: Optional[torch.Tensor] = None,
    deterministic: bool = True,
    generator: Optional[torch.Generator] = None,
    check_nan: bool = False,
) -> torch.Tensor:
    ...
    # NOTE: radik_sampling_from_probs is non-deterministic
    use_radik_impl = not deterministic and (
        (isinstance(top_k, int) and top_k <= 100)
        or (isinstance(top_k, torch.Tensor) and top_k.max() <= 100)
    )

    if use_radik_impl:
        return radik_sampling_from_probs(
            probs,
            top_k,
            indices,
            deterministic,
            generator,
            selected_probs=None,
            check_nan=check_nan,
        )

    return get_sampling_module().top_k_sampling_from_probs(
        probs, indices, *_to_tensor_scalar_tuple(top_k), deterministic, generator
    )

I found that the Radik implementation for top_k_sampling_from_probs should only be dispatched when deterministic output is not required.

The root cause is that Radik's parallel algorithm for selecting the top-k elements is non-deterministic. It uses multiple threadblocks with atomic operations to populate a top-k selection array, but the final order of elements in that array is not guaranteed between runs.

This unstable ordering directly affects the subsequent Cumulative Distribution Function (CDF) sampling step, leading to different results from the same input probabilities.

For example, for the same input, the Radik top-k selection might produce [0.4, 0.3, 0.2, 0.1] in one run and [0.1, 0.2, 0.3, 0.4] in another. Given the same random number for sampling (e.g., 0.35), the CDF-based logic would select index 0 (prob=0.4) in the first case, but index 2 (prob=0.3) in the second, leading to a different result.

@ECMGit

ECMGit commented Nov 17, 2025

Hi experts, when will this commit be merged? It looks like we are facing an underperformance issue due to the sampling strategy. If this commit can be merged, we can get rid of our change of replacing flashinfer.sampling.top_k_top_p_sampling_from_logits with flashinfer.sampling.top_k_top_p_sampling_from_probs.

@yzh119
Collaborator

yzh119 commented Dec 12, 2025

Hi @ECMGit @JasonJ2021 can you try #2119 ?
