40 changes: 40 additions & 0 deletions benchmarks/bench_sampling.py
@@ -50,6 +50,11 @@ def init_seed_top_p_sampling(*args, **kwargs):
return flashinfer.sampling.top_p_sampling_from_probs(*args, **kwargs)


def init_seed_radik_sampling(*args, **kwargs):
torch.manual_seed(42)
return flashinfer.sampling.radik_sampling_from_probs(*args, **kwargs)


@torch.inference_mode()
def main():
print("---")
@@ -119,6 +124,41 @@ def main():
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
)

print("---")
print("radik sampling")
for vocab_size in [128512]:
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
for distrib in [
normal_distribution(1),
normal_distribution(5),
gumbel_distribution(0.1),
gumbel_distribution(1),
]:
for deterministic in [True, False]:
for k in [10, 100, 1000]:
logits = distrib((batch_size, vocab_size), device="cuda")
probs = torch.softmax(logits, dim=-1)
samples = torch.zeros(
batch_size, dtype=torch.int32, device=probs.device
)
measurements = bench_gpu_time(
lambda: init_seed_radik_sampling(
probs, k, deterministic=deterministic
),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
ms = np.median(measurements)

io = (
probs.numel() * probs.element_size()
+ samples.numel() * samples.element_size()
)
bandwidth = io * 1e-6 / ms
print(
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
)

print("---")
print("top-p sampling")

8 changes: 8 additions & 0 deletions csrc/flashinfer_sampling_ops.cu
@@ -63,6 +63,12 @@ void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_i
at::Tensor output_emitted_draft_token_num, bool deterministic,
std::optional<at::Generator> gen);

void radik_sampling_from_probs(at::Tensor workspace_buffer, at::Tensor probs, at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
bool deterministic, std::optional<at::Tensor> maybe_selected_probs,
std::optional<at::Generator> gen);

TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
// Softmax
m.def("softmax", softmax);
@@ -86,4 +92,6 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
m.def("top_k_mask_logits", top_k_mask_logits);
// Speculative sampling from sequence of probabilities
m.def("chain_speculative_sampling", chain_speculative_sampling);
// Top-k sampling with Radix Sorting
m.def("radik_sampling_from_probs", radik_sampling_from_probs);
}
42 changes: 42 additions & 0 deletions csrc/sampling.cu
@@ -279,3 +279,45 @@ void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_i
TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
std::string(cudaGetErrorString(status)));
}

void radik_sampling_from_probs(at::Tensor workspace_buffer, at::Tensor probs, at::Tensor output,
std::optional<at::Tensor> maybe_indices,
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
bool deterministic, std::optional<at::Tensor> maybe_selected_probs,
std::optional<at::Generator> gen_) {
CHECK_INPUT(workspace_buffer);
CHECK_INPUT(probs);
CHECK_INPUT(output);
CHECK_GE(1024, top_k_val); // only support top-k <= 1024 currently
auto device = probs.device();
CHECK_EQ(output.device(), device);
CHECK_DIM(2, probs);
CHECK_DIM(1, output);
unsigned int batch_size = output.size(0);
unsigned int vocab_size = probs.size(1);

bool has_top_k_arr = maybe_top_k_arr.has_value();

uint64_t philox_seed, philox_offset;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
std::lock_guard<std::mutex> lock(gen->mutex_);
at::PhiloxCudaState rng_engine_inputs = gen->philox_cuda_state(32 * batch_size);
philox_seed = rng_engine_inputs.seed_.val;
philox_offset = rng_engine_inputs.offset_.val;

const c10::cuda::OptionalCUDAGuard device_guard(device);
auto stream = at::cuda::getCurrentCUDAStream();
cudaError_t status = sampling::RadiKSamplingFromProb<float, int>(
static_cast<float*>(probs.data_ptr()), static_cast<int*>(output.data_ptr()),
maybe_indices.has_value() ? static_cast<int*>(maybe_indices->data_ptr()) : nullptr,
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, probs.size(0),
maybe_indices.has_value() ? maybe_indices->size(0) : batch_size, top_k_val, vocab_size,
philox_seed, philox_offset, deterministic, workspace_buffer.data_ptr(),
workspace_buffer.element_size() * workspace_buffer.size(0),
maybe_selected_probs.has_value() ? static_cast<float*>(maybe_selected_probs->data_ptr())
: nullptr,
stream);
TORCH_CHECK(status == cudaSuccess, "RadiKSamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));
}
122 changes: 122 additions & 0 deletions flashinfer/sampling.py
@@ -28,6 +28,8 @@
device_support_pdl,
register_custom_op,
register_fake_op,
get_radik_workspace_size,
_is_buf_cached,
)


@@ -214,6 +216,37 @@ def top_k_sampling_from_probs(
)
return samples

@register_custom_op(
"flashinfer::radik_sampling_from_probs", mutates_args=("workspace_buffer",)
)
def radik_sampling_from_probs(
workspace_buffer: torch.Tensor,
probs: torch.Tensor,
indices: Optional[torch.Tensor],
maybe_top_k_arr: Optional[torch.Tensor],
top_k_val: int,
deterministic: bool,
selected_probs: Optional[torch.Tensor],
generator: Optional[torch.Generator],
) -> torch.Tensor:
device = probs.device
probs = probs.float()
batch_size = indices.size(0) if indices is not None else probs.size(0)
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
samples = torch.empty(batch_size, dtype=torch.int32, device=device)
module.radik_sampling_from_probs.default(
workspace_buffer,
probs,
samples,
indices,
maybe_top_k_arr,
top_k_val,
deterministic,
selected_probs,
generator,
)
return samples

@register_fake_op("flashinfer::top_k_sampling_from_probs")
def _fake_top_k_sampling_from_probs(
probs: torch.Tensor,
@@ -453,6 +486,7 @@ def _fake_chain_speculative_sampling(
top_k_renorm_probs=top_k_renorm_probs,
top_k_mask_logits=top_k_mask_logits,
chain_speculative_sampling=chain_speculative_sampling,
radik_sampling_from_probs=radik_sampling_from_probs,
)


@@ -801,11 +835,99 @@ def top_k_sampling_from_probs(
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")

# dispatch non-deterministic and small top-k requests to radik_sampling_from_probs
use_radik_impl = not deterministic and (
(isinstance(top_k, int) and top_k <= 100)
or (isinstance(top_k, torch.Tensor) and top_k.max() <= 100)
)
# Check if GPU memory is available for radik_sampling_from_probs
is_radik_buf_cached, radik_buf_bytes = _is_buf_cached(
"radik_sampling_from_probs_workspace", probs.device
)
required_radik_buf_bytes = get_radik_workspace_size(probs, top_k)
memory_available = (
is_radik_buf_cached and radik_buf_bytes >= required_radik_buf_bytes
) or (
not is_radik_buf_cached
# compare against free (not total) device memory when no workspace is cached yet
and torch.cuda.mem_get_info()[0] >= required_radik_buf_bytes
)

use_radik_impl = use_radik_impl and memory_available

if use_radik_impl:
return radik_sampling_from_probs(
probs,
top_k,
indices,
deterministic,
generator,
selected_probs=None,
check_nan=check_nan,
)

return get_sampling_module().top_k_sampling_from_probs(
probs, indices, *_to_tensor_scalar_tuple(top_k), deterministic, generator
)


def radik_sampling_from_probs(
Collaborator review comment: I would also encourage dispatching top-k sampling to this function for small K's.

probs: torch.Tensor,
top_k: Union[torch.Tensor, int],
indices: Optional[torch.Tensor] = None,
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
selected_probs: Optional[torch.Tensor] = None,
check_nan: bool = False,
) -> torch.Tensor:
r"""GPU kernel for radix top-k sampling from probability distributions,
utilizing radix selection to efficiently identify top-k elements followed by sampling from the selected subset.
Check the `radik paper <https://arxiv.org/abs/2501.14336>`_ for more details.

Parameters
----------
probs: torch.Tensor
Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)``
and the i-th output will be sampled from the i-th row of probabilities. When indices is provided,
shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique
probability distributions.
top_k: Union[torch.Tensor, int]
Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
If a scalar, the same threshold is used for all requests.
If a tensor, each request has its own threshold.
indices: Optional[torch.Tensor]
Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs.
For example, if indices[i] = j, then the i-th output will be sampled from probs[j].
This allows reusing the same probability distribution for multiple outputs.
If indices is not provided, the i-th output will be sampled from the i-th row of probs.
deterministic: bool
Whether to use the deterministic kernel implementation; default is ``True``. Note that the radix-select sampling process itself is inherently non-deterministic.
generator: Optional[torch.Generator]
A random number generator for the operation.
selected_probs: Optional[torch.Tensor]
Optional tensor of shape ``(batch_size, top_k)`` that stores the top-k selected probabilities.
check_nan: bool
Whether to check nan in :attr:`probs`, default is ``False``.
"""
if check_nan:
if torch.any(torch.isnan(probs)):
raise ValueError("Input probs contains NaN.")

workspace_buffer = _get_cache_buf(
"radik_sampling_from_probs_workspace", 64 * 1024 * 1024, probs.device
)
Contributor review comment on lines +916 to +918 (severity: high):

The function allocates a fixed-size 64MB workspace buffer. However, the underlying CUDA kernel might require a much larger buffer depending on batch_size and vocab_size (e.g., ~1GB for a use case mentioned in the PR description). The C++ implementation handles this by allocating a new buffer with cudaMalloc if the provided one is insufficient. This hidden, potentially slow, allocation can be a significant performance bottleneck, especially if this function is called in a loop.

It would be better to provide a way for the user to calculate the required workspace size and pre-allocate it. For example, you could expose a helper function like get_radik_sampling_workspace_size(batch_size, vocab_size, max_top_k).
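
A minimal caller-side sketch of that suggestion, reusing the get_radik_workspace_size helper and the _get_cache_buf cache key this PR already adds; the pre-warming pattern itself is an assumption, not part of the diff:

import torch
from flashinfer.utils import _get_cache_buf, get_radik_workspace_size

# Hypothetical pre-allocation: size the cached workspace once, before a sampling
# loop, so the kernel does not fall back to an internal cudaMalloc. Assumes
# _get_cache_buf keeps and reuses the first allocation made for a given key.
probs = torch.rand(256, 128512, device="cuda")  # only shape/dtype/device matter here
required = get_radik_workspace_size(probs, top_k=100)
_get_cache_buf("radik_sampling_from_probs_workspace", required, probs.device)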


return get_sampling_module().radik_sampling_from_probs(
workspace_buffer,
probs,
indices,
*_to_tensor_scalar_tuple(top_k),
deterministic,
selected_probs,
generator,
)
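
For reference, a small usage sketch of the new public API; batch size, vocabulary size, and k are arbitrary example values:

import torch
import flashinfer

# Four requests over a 128512-entry vocabulary; each row is a valid distribution.
probs = torch.softmax(torch.randn(4, 128512, device="cuda"), dim=-1)
# Draw one token id per row, restricted to each row's top-50 probabilities.
samples = flashinfer.sampling.radik_sampling_from_probs(
    probs, top_k=50, deterministic=False
)
print(samples.shape, samples.dtype)  # torch.Size([4]) torch.int32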


def min_p_sampling_from_probs(
probs: torch.Tensor,
min_p: Union[torch.Tensor, float],
40 changes: 40 additions & 0 deletions flashinfer/utils.py
@@ -195,6 +195,15 @@ def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
return buf


def _is_buf_cached(name: str, device: torch.device) -> Tuple[bool, int]:
key = (name, device)
if key in _cache_buf:
buf = _cache_buf[key]
bytes_size = buf.numel() * buf.element_size()
return (True, bytes_size)
return (False, 0)


# find the least power of 2 that is greater than or equal to x
def _ceil_pow2(x: int) -> int:
return 1 << (x - 1).bit_length()
@@ -737,3 +746,34 @@ def get_shuffle_matrix_sf_a_row_indices(
row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m)

return row_indices


def get_radik_workspace_size(probs: torch.Tensor, top_k: Union[torch.Tensor, int]):
"""
Calculate the workspace size required for the radix select algorithm

Args:
probs: The input probabilities
top_k: The k value in top-k selection

Returns:
size_in_bytes: Required workspace size in bytes
"""
k = int(top_k.max().item()) if isinstance(top_k, torch.Tensor) else top_k
task_num = probs.size(0)
vocab_size = probs.size(1)

sizeof_CompT = 4
sizeof_int = 4
sizeof_T = probs.element_size()
sizeof_IdxType = 4

size_in_bytes = task_num * (
sizeof_CompT * vocab_size * 2 # buffer for val
+ sizeof_int * (1 << 12) # buffer for hist (4096 = 2^12)
+ sizeof_int * 5 # buffer for globalCount,old_taskLen,new_taskLen,K,binId
+ sizeof_T * k # buffer for top-k select result
+ sizeof_IdxType * k # buffer for top-k select result
)

return size_in_bytes
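
As a sanity check on this formula, a small worked example; the shapes mirror the benchmark above, the import path follows the file location in this PR, and the commented figure is just the formula evaluated by hand:

import torch
from flashinfer.utils import get_radik_workspace_size

# Only the shape and element size of probs matter, so a CPU tensor is enough here.
probs = torch.empty(16, 128512, dtype=torch.float32)
nbytes = get_radik_workspace_size(probs, top_k=100)
# 16 * (4*128512*2 + 4*4096 + 4*5 + 4*100 + 4*100) = 16,724,800 bytes (~16 MB)
print(nbytes)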