Merged
6 changes: 4 additions & 2 deletions csrc/flashinfer_sampling_binding.cu
@@ -52,10 +52,12 @@ void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
                         Optional<TensorView> maybe_top_p_arr, double top_p_val);
 
 void top_k_renorm_probs(TensorView probs, TensorView renorm_probs,
-                        Optional<TensorView> maybe_top_k_arr, int64_t top_k_val);
+                        Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
+                        TensorView row_states_buffer);
 
 void top_k_mask_logits(TensorView logits, TensorView mask_logits,
-                       Optional<TensorView> maybe_top_k_arr, int64_t top_k_val);
+                       Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
+                       TensorView row_states_buffer);
 
 void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_ids,
                                 TensorView target_probs, TensorView output_token_ids,
24 changes: 24 additions & 0 deletions csrc/flashinfer_topk_binding.cu
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tvm_ffi_utils.h"
+
+using tvm::ffi::Optional;
+
+void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
+                Optional<TensorView> maybe_row_states_buffer, int64_t top_k);
+
+// Radix-based Top-K selection
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(radix_topk, radix_topk);
43 changes: 33 additions & 10 deletions csrc/renorm.cu
@@ -42,8 +42,10 @@ void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
 }
 
 void top_k_renorm_probs(TensorView probs, TensorView renorm_probs,
-                        Optional<TensorView> maybe_top_k_arr, int64_t top_k_val) {
+                        Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
+                        TensorView row_states_buffer) {
   CHECK_INPUT(probs);
+  CHECK_INPUT(row_states_buffer);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
   unsigned int batch_size = probs.size(0);
   unsigned int vocab_size = probs.size(1);
@@ -52,18 +54,29 @@ void top_k_renorm_probs(TensorView probs, TensorView renorm_probs,
 
   ffi::CUDADeviceGuard device_guard(probs.device().device_id);
   auto stream = get_stream(probs.device());
-  cudaError_t status = sampling::TopKRenormProb<float>(
-      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
-      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
-      top_k_val, vocab_size, stream);
+
+  cudaError_t status;
+  auto dtype = probs.dtype();
+
+  // Use radix-based top-k with dtype dispatch for FP32/FP16/BF16
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
+    status = sampling::RadixTopKRenormProbMultiCTA<c_type, int>(
+        static_cast<c_type*>(probs.data_ptr()), static_cast<c_type*>(renorm_probs.data_ptr()),
+        has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
+        top_k_val, vocab_size, static_cast<sampling::RadixRowState*>(row_states_buffer.data_ptr()),
+        stream);
+    return true;
+  });
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "TopKRenormProb failed with error code " << cudaGetErrorString(status);
 }
 
 void top_k_mask_logits(TensorView logits, TensorView mask_logits,
-                       Optional<TensorView> maybe_top_k_arr, int64_t top_k_val) {
+                       Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
+                       TensorView row_states_buffer) {
   CHECK_INPUT(logits);
+  CHECK_INPUT(row_states_buffer);
   CHECK_DIM(2, logits);  // logits: (batch_size, vocab_size)
   unsigned int batch_size = logits.size(0);
   unsigned int vocab_size = logits.size(1);
@@ -72,10 +85,20 @@ void top_k_mask_logits(TensorView logits, TensorView mask_logits,
 
   ffi::CUDADeviceGuard device_guard(logits.device().device_id);
   auto stream = get_stream(logits.device());
-  cudaError_t status = sampling::TopKMaskLogits<float>(
-      static_cast<float*>(logits.data_ptr()), static_cast<float*>(mask_logits.data_ptr()),
-      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
-      top_k_val, vocab_size, stream);
+
+  cudaError_t status;
+  auto dtype = logits.dtype();
+
+  // Use radix-based top-k with auto-selection (single-CTA for small vocab, multi-CTA for large
+  // vocab)
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
+    status = sampling::RadixTopKMaskLogitsMultiCTA<c_type, int>(
+        static_cast<c_type*>(logits.data_ptr()), static_cast<c_type*>(mask_logits.data_ptr()),
+        has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
+        top_k_val, vocab_size, static_cast<sampling::RadixRowState*>(row_states_buffer.data_ptr()),
+        stream);
+    return true;
+  });
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "TopKMaskLogits failed with error code " << cudaGetErrorString(status);
60 changes: 60 additions & 0 deletions csrc/topk.cu
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <flashinfer/sampling.cuh>
+
+#include "tvm_ffi_utils.h"
+
+using namespace flashinfer;
+
+using tvm::ffi::Optional;
+
+void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
+                Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(output_indices);
+  CHECK_INPUT(output_values);
+  CHECK_DIM(2, input);           // input: (batch_size, d)
+  CHECK_DIM(2, output_indices);  // output_indices: (batch_size, top_k)
+  CHECK_DIM(2, output_values);   // output_values: (batch_size, top_k)
+
+  unsigned int batch_size = input.size(0);
+  unsigned int d = input.size(1);
+
Comment on lines +24 to +35
⚠️ Potential issue | 🟠 Major

Add basic shape + top_k range validation (prevents kernel OOB / nonsense launches).
Only rank is validated today. At minimum, enforce:

  • top_k > 0
  • top_k <= input.size(1)
  • output_{indices,values}.size(1) == top_k
  • batch dims match across tensors

This avoids silent misbehavior when callers pass a k that does not match the output shapes.

πŸ€– Prompt for AI Agents
In csrc/topk.cu around lines 24 to 35, add explicit shape and top_k range
validation: check top_k > 0 and top_k <= input.size(1), verify
output_indices.size(0) and output_values.size(0) equal input.size(0), verify
output_indices.size(1) and output_values.size(1) equal top_k, and ensure
maybe_row_states_buffer (if present) has a compatible batch dimension; use the
same CHECK_* macros (or error paths) already used in the file so failures are
reported consistently and prevent launching kernels with out-of-bounds sizes.
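A minimal sketch of the requested validation, in the TVM_FFI_ICHECK style this file already uses; the exact messages are illustrative assumptions, not part of the PR:

// Sketch of the suggested checks, to run before any kernel launch.
TVM_FFI_ICHECK(top_k > 0 && top_k <= static_cast<int64_t>(input.size(1)))
    << "top_k must be in (0, d], got " << top_k;
TVM_FFI_ICHECK(output_indices.size(0) == input.size(0) &&
               output_values.size(0) == input.size(0))
    << "output batch dims must match input batch dim";
TVM_FFI_ICHECK(output_indices.size(1) == top_k && output_values.size(1) == top_k)
    << "output_indices/output_values second dim must equal top_k";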

+  cudaSetDevice(input.device().device_id);
+  auto stream = get_stream(input.device());
+
+  cudaError_t status;
+  auto dtype = input.dtype();
+
+  // Get row_states_buffer if provided (for multi-CTA path)
+  sampling::RadixRowState* row_states_ptr = nullptr;
+  if (maybe_row_states_buffer.has_value()) {
+    row_states_ptr =
+        static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
+  }
+
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
+    status = sampling::RadixTopKMultiCTA<c_type, int32_t>(
+        static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
+        static_cast<c_type*>(output_values.data_ptr()),
+        nullptr,  // top_k_arr
+        batch_size, static_cast<uint32_t>(top_k), d, row_states_ptr, stream);
+    return true;
+  });
+
+  TVM_FFI_ICHECK(status == cudaSuccess)
+      << "RadixTopK failed with error code " << cudaGetErrorString(status);
+}
Comment on lines +24 to +60
⚠️ Potential issue | πŸ”΄ Critical

Fix BF16 dispatch + initialize status to avoid UB / wrong dtype behavior.
Right now the kernel dispatch only covers FP32/FP16 (Line 49), but the Python API advertises BF16 support. Also status is uninitialized if dispatch doesn’t run, and the final TVM_FFI_ICHECK(status == cudaSuccess) becomes undefined behavior.

 void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
                 Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
@@
-  cudaSetDevice(input.device().device_id);
+  TVM_FFI_ICHECK(cudaSetDevice(input.device().device_id) == cudaSuccess);
   auto stream = get_stream(input.device());
 
-  cudaError_t status;
+  cudaError_t status = cudaErrorInvalidValue;
   auto dtype = input.dtype();
@@
-  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16_BF16(dtype, c_type, [&] {
     status = sampling::RadixTopKMultiCTA<c_type, int32_t>(
@@
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "RadixTopK failed with error code " << cudaGetErrorString(status);
 }

If there is no ..._BF16 dispatch macro available, you should either add it (preferred, given nv_bfloat16 support in include/flashinfer/sampling.cuh) or explicitly reject bf16 here with a clear error before launch.

Committable suggestion skipped: line range outside the PR's diff.

πŸ€– Prompt for AI Agents
In csrc/topk.cu around lines 24 to 60, status is left uninitialized and the
dtype dispatch only covers FP32/FP16, so BF16 calls lead to UB or wrong
behavior; initialize status (e.g., to cudaErrorInvalidValue) before the dispatch
to avoid undefined reads, then extend the dispatch to include BF16 by using the
BF16-capable dispatch macro (or add a
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16_BF16 variant that maps nv_bfloat16 to
the correct c_type and launches sampling::RadixTopKMultiCTA for bf16), and if
you cannot add a BF16 dispatch, explicitly check for bf16 and return a clear
error (set status to an appropriate cudaError and fail fast) before the kernel
launch so the final TVM_FFI_ICHECK sees a defined value.
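Following the comment's fallback option, a hedged sketch of rejecting bf16 explicitly and giving status a defined value inside radix_topk; bfloat16_code is assumed to exist alongside float32_code in tvm_ffi_utils.h:

  // Sketch: fail fast on bf16 (per the review's premise that the dispatch
  // lacks a BF16 case) and make status well-defined if dispatch never runs.
  cudaError_t status = cudaErrorInvalidValue;
  auto dtype = input.dtype();
  TVM_FFI_ICHECK(encode_dlpack_dtype(dtype) != bfloat16_code)
      << "radix_topk: bf16 is not supported by this dispatch path";
  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
    status = sampling::RadixTopKMultiCTA<c_type, int32_t>(
        static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
        static_cast<c_type*>(output_values.data_ptr()), nullptr, batch_size,
        static_cast<uint32_t>(top_k), d, row_states_ptr, stream);
    return true;
  });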

17 changes: 17 additions & 0 deletions csrc/tvm_ffi_utils.h
@@ -93,6 +93,23 @@ constexpr DLDevice cpu = DLDevice{kDLCPU, 0};
     }                                                                                    \
   }()
 
+// Dispatcher for FP32/FP16/BF16 data types
+#define DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dlpack_dtype, c_type, ...)              \
+  [&]() -> bool {                                                                        \
+    switch (encode_dlpack_dtype(dlpack_dtype)) {                                         \
+      case float32_code: {                                                               \
+        using c_type = float;                                                            \
+        return __VA_ARGS__();                                                            \
+      }                                                                                  \
+        _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                          \
+        _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                         \
+      default:                                                                           \
+        TVM_FFI_ICHECK(false) << __PRETTY_FUNCTION__ << " failed to dispatch data type " \
+                              << (dlpack_dtype).code << " " << (dlpack_dtype).bits;      \
+        return false;                                                                    \
+    }                                                                                    \
+  }()
+
 #define _DISPATCH_CASE_I32(c_type, ...) \
   case int32_code: {                    \
     using c_type = int32_t;             \
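For context, a minimal standalone usage sketch of the new dispatcher; launch_my_kernel and run are hypothetical names for illustration, not FlashInfer APIs:

// Hypothetical templated launcher; stands in for a real FlashInfer kernel.
template <typename T>
cudaError_t launch_my_kernel(const T* in, T* out, unsigned n, cudaStream_t s);

cudaError_t run(TensorView x, TensorView y, cudaStream_t stream) {
  cudaError_t status = cudaErrorInvalidValue;  // defined even if dispatch fails
  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(x.dtype(), c_type, [&] {
    // c_type is float, half, or nv_bfloat16 depending on x.dtype().
    status = launch_my_kernel<c_type>(static_cast<const c_type*>(x.data_ptr()),
                                      static_cast<c_type*>(y.data_ptr()),
                                      static_cast<unsigned>(x.size(1)), stream);
    return true;  // the macro's lambda must return bool
  });
  return status;
}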
2 changes: 2 additions & 0 deletions flashinfer/__init__.py
@@ -141,6 +141,8 @@
 from .sampling import top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs
 from .sampling import top_p_renorm_probs as top_p_renorm_probs
 from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs
+from . import topk as topk
+from .topk import top_k as top_k
 from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
 from .sparse import (
     VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper,
2 changes: 2 additions & 0 deletions flashinfer/aot.py
@@ -67,6 +67,7 @@
 from .jit.quantization import gen_quantization_module
 from .jit.rope import gen_rope_module
 from .jit.sampling import gen_sampling_module
+from .jit.topk import gen_topk_module
 from .jit.tllm_utils import gen_trtllm_utils_module
 from .jit.xqa import gen_xqa_module, gen_xqa_module_mla
 from .jit.attention import (
@@ -528,6 +529,7 @@ def gen_all_modules(
         gen_quantization_module(),
         gen_rope_module(),
         gen_sampling_module(),
+        gen_topk_module(),
     ]
     if has_sm90:
         jit_specs.append(gen_trtllm_utils_module())
5 changes: 4 additions & 1 deletion flashinfer/jit/core.py
@@ -122,7 +122,10 @@ def clear_cache_dir():
     "-gencode=arch=compute_89,code=sm_89",
     "-DFLASHINFER_ENABLE_FP8_E8M0",
 ]
-sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"] + common_nvcc_flags
+sm90a_nvcc_flags = [
+    "-gencode=arch=compute_90a,code=sm_90a",
+    "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
+] + common_nvcc_flags
 sm100a_nvcc_flags = ["-gencode=arch=compute_100a,code=sm_100a"] + common_nvcc_flags
 sm103a_nvcc_flags = ["-gencode=arch=compute_103a,code=sm_103a"] + common_nvcc_flags
 sm100f_nvcc_flags = ["-gencode=arch=compute_100f,code=sm_100f"] + common_nvcc_flags
28 changes: 28 additions & 0 deletions flashinfer/jit/topk.py
@@ -0,0 +1,28 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from . import env as jit_env
from .core import JitSpec, gen_jit_spec


def gen_topk_module() -> JitSpec:
return gen_jit_spec(
"topk",
[
jit_env.FLASHINFER_CSRC_DIR / "topk.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_topk_binding.cu",
],
)
18 changes: 16 additions & 2 deletions flashinfer/logits_processor/operators.py
@@ -129,8 +129,15 @@ def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor:
         ):
             raise ValueError("top_k must be a positive integer or a tensor array")
 
+        # Allocate row_states buffer for multi-CTA kernel (1MB is enough for any GPU)
+        row_states_buffer = _get_cache_buf(
+            f"top_k_renorm_probs_row_states_{tensor.data.device}",
+            1024 * 1024,
+            tensor.data.device,
+            zero_init=True,
+        )
         renorm_probs = get_sampling_module().top_k_renorm_probs(
-            tensor.data, maybe_top_k_arr, top_k_val
+            tensor.data, maybe_top_k_arr, top_k_val, row_states_buffer
         )
 
         return TaggedTensor(renorm_probs, output_type)
@@ -168,8 +175,15 @@ def __call__(self, tensor: TaggedTensor, **kwargs: Any) -> TaggedTensor:
         ):
             raise ValueError("top_k must be a positive integer or a tensor array")
 
+        # Allocate row_states buffer for multi-CTA kernel (1MB is enough for any GPU)
+        row_states_buffer = _get_cache_buf(
+            f"top_k_mask_logits_row_states_{tensor.data.device}",
+            1024 * 1024,
+            tensor.data.device,
+            zero_init=True,
+        )
         masked_logits = get_sampling_module().top_k_mask_logits(
-            tensor.data, maybe_top_k_arr, top_k_val
+            tensor.data, maybe_top_k_arr, top_k_val, row_states_buffer
         )
         return TaggedTensor(masked_logits, output_type)
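A hedged sketch of what the 1MB figure implies on the C++ side, assuming the kernels consume the buffer as an array of sampling::RadixRowState with one state per concurrently processed row (the real bound lives in include/flashinfer/sampling.cuh):

#include <flashinfer/sampling.cuh>

// Sketch only: the Python side caches a zero-initialized 1MB buffer per device
// and per op; this helper checks a batch fits under the stated assumption.
constexpr size_t kRowStatesBytes = 1024 * 1024;

inline bool row_states_fit(size_t batch_size) {
  return batch_size * sizeof(flashinfer::sampling::RadixRowState) <= kRowStatesBytes;
}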
