-
Notifications
You must be signed in to change notification settings - Fork 597
feat: support radix-based top-k sampling algorithm #1561
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
48332b8
f92a214
4e5f598
d00fab2
f47a98e
9223421
5c2bef3
8ef840b
7bd7c9d
4a55fa8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,8 @@ | |
| device_support_pdl, | ||
| register_custom_op, | ||
| register_fake_op, | ||
| get_radik_workspace_size, | ||
| _is_buf_cached, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -214,6 +216,37 @@ def top_k_sampling_from_probs( | |
| ) | ||
| return samples | ||
|
|
||
| @register_custom_op( | ||
| "flashinfer::radik_sampling_from_probs", mutates_args=("workspace_buffer",) | ||
| ) | ||
| def radik_sampling_from_probs( | ||
| workspace_buffer: torch.Tensor, | ||
| probs: torch.Tensor, | ||
| indices: Optional[torch.Tensor], | ||
| maybe_top_k_arr: Optional[torch.Tensor], | ||
| top_k_val: int, | ||
| deterministic: bool, | ||
| selected_probs: Optional[torch.Tensor], | ||
| generator: Optional[torch.Generator], | ||
| ) -> torch.Tensor: | ||
| device = probs.device | ||
| probs = probs.float() | ||
| batch_size = indices.size(0) if indices is not None else probs.size(0) | ||
| maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None | ||
| samples = torch.empty(batch_size, dtype=torch.int32, device=device) | ||
| module.radik_sampling_from_probs.default( | ||
| workspace_buffer, | ||
| probs, | ||
| samples, | ||
| indices, | ||
| maybe_top_k_arr, | ||
| top_k_val, | ||
| deterministic, | ||
| selected_probs, | ||
| generator, | ||
| ) | ||
| return samples | ||
|
|
||
| @register_fake_op("flashinfer::top_k_sampling_from_probs") | ||
| def _fake_top_k_sampling_from_probs( | ||
| probs: torch.Tensor, | ||
|
|
@@ -453,6 +486,7 @@ def _fake_chain_speculative_sampling( | |
| top_k_renorm_probs=top_k_renorm_probs, | ||
| top_k_mask_logits=top_k_mask_logits, | ||
| chain_speculative_sampling=chain_speculative_sampling, | ||
| radik_sampling_from_probs=radik_sampling_from_probs, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -801,11 +835,99 @@ def top_k_sampling_from_probs( | |
| if check_nan: | ||
| if torch.any(torch.isnan(probs)): | ||
| raise ValueError("Input probs contains NaN.") | ||
|
|
||
| # dispatch non-deterministic and small top-k requests to radik_sampling_from_probs | ||
| use_radik_impl = not deterministic and ( | ||
| (isinstance(top_k, int) and top_k <= 100) | ||
| or (isinstance(top_k, torch.Tensor) and top_k.max() <= 100) | ||
| ) | ||
| # Check if GPU memory is available for radik_sampling_from_probs | ||
| is_radik_buf_cached, radik_buf_bytes = _is_buf_cached( | ||
| "radik_sampling_from_probs_workspace", probs.device | ||
| ) | ||
| required_radik_buf_bytes = get_radik_workspace_size(probs, top_k) | ||
| memory_available = ( | ||
| is_radik_buf_cached and radik_buf_bytes >= required_radik_buf_bytes | ||
| ) or ( | ||
| not is_radik_buf_cached | ||
| and torch.cuda.mem_get_info()[1] >= required_radik_buf_bytes | ||
| ) | ||
|
|
||
| use_radik_impl = use_radik_impl and memory_available | ||
|
|
||
| if use_radik_impl: | ||
| return radik_sampling_from_probs( | ||
| probs, | ||
| top_k, | ||
| indices, | ||
| deterministic, | ||
| generator, | ||
| selected_probs=None, | ||
| check_nan=check_nan, | ||
| ) | ||
|
|
||
| return get_sampling_module().top_k_sampling_from_probs( | ||
| probs, indices, *_to_tensor_scalar_tuple(top_k), deterministic, generator | ||
| ) | ||
|
|
||
|
|
||
| def radik_sampling_from_probs( | ||
| probs: torch.Tensor, | ||
| top_k: Union[torch.Tensor, int], | ||
| indices: Optional[torch.Tensor] = None, | ||
| deterministic: bool = True, | ||
| generator: Optional[torch.Generator] = None, | ||
| selected_probs: Optional[torch.Tensor] = None, | ||
| check_nan: bool = False, | ||
| ) -> torch.Tensor: | ||
| r"""GPU kernel for radix top-k sampling from probability distributions, | ||
| utilizing radix selection to efficiently identify top-k elements followed by sampling from the selected subset. | ||
| Check the `radik paper <https://arxiv.org/abs/2501.14336>`_ for more details. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| probs: torch.Tensor | ||
| Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)`` | ||
| and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, | ||
| shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique | ||
| probability distributions. | ||
| top_k: Union[torch.Tensor, int] | ||
| Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling. | ||
| If a scalar, the same threshold is used for all requests. | ||
| If a tensor, each request has its own threshold. | ||
| indices: Optional[torch.Tensor] | ||
| Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. | ||
| For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. | ||
| This allows reusing the same probability distribution for multiple outputs. | ||
| If indices is not provided, the i-th output will be sampled from the i-th row of probs. | ||
| deterministic: bool | ||
| Whether to use deterministic kernel implementation, default is ``True``. However, the radix sampling process itself is inherently non-deterministic. | ||
| generator: Optional[torch.Generator] | ||
| A random number generator for the operation. | ||
| selected_probs: Optional[torch.Tensor] | ||
| Optional tensor of shape ``(batch_size, top_k)`` that stores the top-k selected probabilities. | ||
| check_nan: bool | ||
| Whether to check nan in :attr:`probs`, default is ``False``. | ||
| """ | ||
| if check_nan: | ||
| if torch.any(torch.isnan(probs)): | ||
| raise ValueError("Input probs contains NaN.") | ||
|
|
||
| workspace_buffer = _get_cache_buf( | ||
| "radik_sampling_from_probs_workspace", 64 * 1024 * 1024, probs.device | ||
| ) | ||
|
Comment on lines
+916
to
+918
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The function allocates a fixed-size 64 MB workspace buffer. However, the underlying CUDA kernel might require a much larger buffer depending on the input size (e.g. `batch_size`, vocabulary size, and `top_k`). It would be better to provide a way for the user to calculate the required workspace size and pre-allocate it. For example, you could expose a helper function like `get_radik_workspace_size`. |
||
|
|
||
| return get_sampling_module().radik_sampling_from_probs( | ||
| workspace_buffer, | ||
| probs, | ||
| indices, | ||
| *_to_tensor_scalar_tuple(top_k), | ||
| deterministic, | ||
| selected_probs, | ||
| generator, | ||
| ) | ||
|
|
||
|
|
||
| def min_p_sampling_from_probs( | ||
| probs: torch.Tensor, | ||
| min_p: Union[torch.Tensor, float], | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would also encourage dispatching topk sampling to this function for small K's.