6 changes: 4 additions & 2 deletions csrc/flashinfer_sampling_binding.cu
@@ -52,10 +52,12 @@ void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
                         Optional<TensorView> maybe_top_p_arr, double top_p_val);
 
 void top_k_renorm_probs(TensorView probs, TensorView renorm_probs,
-                        Optional<TensorView> maybe_top_k_arr, int64_t top_k_val);
+                        Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
+                        TensorView row_states_buffer);
 
 void top_k_mask_logits(TensorView logits, TensorView mask_logits,
-                       Optional<TensorView> maybe_top_k_arr, int64_t top_k_val);
+                       Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
+                       TensorView row_states_buffer);
 
 void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_ids,
                                 TensorView target_probs, TensorView output_token_ids,
25 changes: 25 additions & 0 deletions csrc/flashinfer_topk_binding.cu
@@ -0,0 +1,25 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tvm_ffi_utils.h"

using tvm::ffi::Optional;

void radix_topk(TensorView input, TensorView output_indices,
Optional<TensorView> maybe_output_values,
Optional<TensorView> maybe_row_states_buffer, int64_t top_k);

// Radix-based Top-K selection
TVM_FFI_DLL_EXPORT_TYPED_FUNC(radix_topk, radix_topk);
43 changes: 33 additions & 10 deletions csrc/renorm.cu
@@ -42,8 +42,10 @@ void top_p_renorm_probs(TensorView probs, TensorView renorm_probs,
 }
 
 void top_k_renorm_probs(TensorView probs, TensorView renorm_probs,
-                        Optional<TensorView> maybe_top_k_arr, int64_t top_k_val) {
+                        Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
+                        TensorView row_states_buffer) {
   CHECK_INPUT(probs);
+  CHECK_INPUT(row_states_buffer);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
   unsigned int batch_size = probs.size(0);
   unsigned int vocab_size = probs.size(1);
@@ -52,18 +54,29 @@ void top_k_renorm_probs(TensorView probs, TensorView renorm_probs,
 
   ffi::CUDADeviceGuard device_guard(probs.device().device_id);
   auto stream = get_stream(probs.device());
-  cudaError_t status = sampling::TopKRenormProb<float>(
-      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
-      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
-      top_k_val, vocab_size, stream);
+
+  cudaError_t status;
+  auto dtype = probs.dtype();
+
+  // Use radix-based top-k with dtype dispatch for FP32/FP16/BF16
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
+    status = sampling::RadixTopKRenormProbMultiCTA<c_type, int>(
+        static_cast<c_type*>(probs.data_ptr()), static_cast<c_type*>(renorm_probs.data_ptr()),
+        has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
+        top_k_val, vocab_size, static_cast<sampling::RadixRowState*>(row_states_buffer.data_ptr()),
+        stream);
+    return true;
+  });
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "TopKRenormProb failed with error code " << cudaGetErrorString(status);
 }
 
 void top_k_mask_logits(TensorView logits, TensorView mask_logits,
-                       Optional<TensorView> maybe_top_k_arr, int64_t top_k_val) {
+                       Optional<TensorView> maybe_top_k_arr, int64_t top_k_val,
+                       TensorView row_states_buffer) {
   CHECK_INPUT(logits);
+  CHECK_INPUT(row_states_buffer);
   CHECK_DIM(2, logits);  // logits: (batch_size, vocab_size)
   unsigned int batch_size = logits.size(0);
   unsigned int vocab_size = logits.size(1);
@@ -72,10 +85,20 @@ void top_k_mask_logits(TensorView logits, TensorView mask_logits,
 
   ffi::CUDADeviceGuard device_guard(logits.device().device_id);
   auto stream = get_stream(logits.device());
-  cudaError_t status = sampling::TopKMaskLogits<float>(
-      static_cast<float*>(logits.data_ptr()), static_cast<float*>(mask_logits.data_ptr()),
-      has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
-      top_k_val, vocab_size, stream);
+
+  cudaError_t status;
+  auto dtype = logits.dtype();
+
+  // Use radix-based top-k with auto-selection (single-CTA for small vocab, multi-CTA for large
+  // vocab)
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
+    status = sampling::RadixTopKMaskLogitsMultiCTA<c_type, int>(
+        static_cast<c_type*>(logits.data_ptr()), static_cast<c_type*>(mask_logits.data_ptr()),
+        has_top_k_arr ? static_cast<int*>(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size,
+        top_k_val, vocab_size, static_cast<sampling::RadixRowState*>(row_states_buffer.data_ptr()),
+        stream);
+    return true;
+  });
 
   TVM_FFI_ICHECK(status == cudaSuccess)
       << "TopKMaskLogits failed with error code " << cudaGetErrorString(status);
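For context, a minimal host-side sketch of driving RadixTopKRenormProbMultiCTA directly, with the argument order taken from the call above. The one-RadixRowState-per-row workspace sizing and the helper itself are illustrative assumptions, not part of this diff; the authoritative workspace contract lives with the kernel in flashinfer/sampling.cuh.

// Hypothetical driver (not part of this PR): renormalize with a per-request k.
// Assumes one sampling::RadixRowState per batch row; consult
// flashinfer/sampling.cuh for the real workspace contract.
#include <flashinfer/sampling.cuh>

cudaError_t renorm_top_k_per_row(float* d_probs, float* d_renormed, int* d_top_k_arr,
                                 uint32_t batch_size, uint32_t vocab_size,
                                 cudaStream_t stream) {
  namespace sampling = flashinfer::sampling;
  sampling::RadixRowState* d_states = nullptr;
  cudaError_t err = cudaMalloc(&d_states, batch_size * sizeof(*d_states));
  if (err != cudaSuccess) return err;
  // d_top_k_arr supplies k per row; the scalar argument is the fallback used
  // when the array pointer is null, matching how the binding forwards both.
  err = sampling::RadixTopKRenormProbMultiCTA<float, int>(
      d_probs, d_renormed, d_top_k_arr, batch_size,
      /*top_k_val=*/1, vocab_size, d_states, stream);
  cudaFree(d_states);
  return err;
}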
65 changes: 65 additions & 0 deletions csrc/topk.cu
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/sampling.cuh>

#include "tvm_ffi_utils.h"

using namespace flashinfer;

using tvm::ffi::Optional;

void radix_topk(TensorView input, TensorView output_indices,
Optional<TensorView> maybe_output_values,
Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
CHECK_INPUT(input);
CHECK_INPUT(output_indices);
CHECK_DIM(2, input); // input: (batch_size, d)
CHECK_DIM(2, output_indices); // output_indices: (batch_size, top_k)

unsigned int batch_size = input.size(0);
unsigned int d = input.size(1);

cudaSetDevice(input.device().device_id);
auto stream = get_stream(input.device());

cudaError_t status;
auto dtype = input.dtype();

// Get row_states_buffer if provided (for multi-CTA path)
sampling::RadixRowState* row_states_ptr = nullptr;
if (maybe_row_states_buffer.has_value()) {
row_states_ptr =
static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
}

DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
c_type* output_values_ptr = nullptr;
if (maybe_output_values.has_value()) {
CHECK_INPUT(maybe_output_values.value());
CHECK_DIM(2, maybe_output_values.value());
output_values_ptr = static_cast<c_type*>(maybe_output_values.value().data_ptr());
}
status = sampling::RadixTopKMultiCTA<c_type, int32_t>(
static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
output_values_ptr, // output_values (nullptr if not writing values)
nullptr, // top_k_arr
batch_size, static_cast<uint32_t>(top_k), d, row_states_ptr, stream);
return true;
});

TVM_FFI_ICHECK(status == cudaSuccess)
<< "RadixTopK failed with error code " << cudaGetErrorString(status);
}
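Likewise, a minimal sketch of invoking sampling::RadixTopKMultiCTA without the FFI wrapper, showing uniform k and the optional values output. The workspace sizing is again an assumption made for illustration, since radix_topk above receives its row-states buffer preallocated by the caller.

// Hypothetical standalone use of the kernel behind radix_topk (illustration only).
#include <flashinfer/sampling.cuh>

cudaError_t run_radix_top_k(float* d_input, int32_t* d_indices, float* d_values,
                            uint32_t batch_size, uint32_t d, uint32_t top_k,
                            cudaStream_t stream) {
  namespace sampling = flashinfer::sampling;
  sampling::RadixRowState* d_states = nullptr;  // multi-CTA coordination scratch
  cudaError_t err = cudaMalloc(&d_states, batch_size * sizeof(*d_states));  // assumed sizing
  if (err != cudaSuccess) return err;
  err = sampling::RadixTopKMultiCTA<float, int32_t>(
      d_input, d_indices,
      /*output_values=*/d_values,  // pass nullptr to skip writing values
      /*top_k_arr=*/nullptr,       // uniform k across the batch
      batch_size, top_k, d, d_states, stream);
  cudaFree(d_states);
  return err;
}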
17 changes: 17 additions & 0 deletions csrc/tvm_ffi_utils.h
@@ -93,6 +93,23 @@ constexpr DLDevice cpu = DLDevice{kDLCPU, 0};
     }                                                                        \
   }()
 
+// Dispatcher for FP32/FP16/BF16 data types
+#define DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dlpack_dtype, c_type, ...)  \
+  [&]() -> bool {                                                            \
+    switch (encode_dlpack_dtype(dlpack_dtype)) {                             \
+      case float32_code: {                                                   \
+        using c_type = float;                                                \
+        return __VA_ARGS__();                                                \
+      }                                                                      \
+      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                \
+      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                               \
+      default:                                                               \
+        TVM_FFI_ICHECK(false) << __PRETTY_FUNCTION__ << " failed to dispatch data type " \
+                              << (dlpack_dtype).code << " " << (dlpack_dtype).bits; \
+        return false;                                                        \
+    }                                                                        \
+  }()
+
 #define _DISPATCH_CASE_I32(c_type, ...) \
   case int32_code: {                    \
     using c_type = int32_t;             \
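A quick hedged illustration of how the new dispatcher is consumed, mirroring the renorm.cu call sites above; the TensorView probs and elem_size are stand-ins invented for the example.

// Inside a binding, with a TensorView probs in scope. The lambda sees c_type
// bound to float, half, or bfloat16; unsupported dtypes fall through to the
// TVM_FFI_ICHECK in the default case.
size_t elem_size = 0;
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(probs.dtype(), c_type, [&] {
  elem_size = sizeof(c_type);  // 4 for float, 2 for the 16-bit types
  return true;                 // the dispatch lambda reports success as a bool
});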
2 changes: 2 additions & 0 deletions flashinfer/__init__.py
@@ -141,6 +141,8 @@
 from .sampling import top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs
 from .sampling import top_p_renorm_probs as top_p_renorm_probs
 from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs
+from . import topk as topk
+from .topk import top_k as top_k
 from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
 from .sparse import (
     VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper,
5 changes: 4 additions & 1 deletion flashinfer/jit/core.py
@@ -122,7 +122,10 @@ def clear_cache_dir():
     "-gencode=arch=compute_89,code=sm_89",
     "-DFLASHINFER_ENABLE_FP8_E8M0",
 ]
-sm90a_nvcc_flags = ["-gencode=arch=compute_90a,code=sm_90a"] + common_nvcc_flags
+sm90a_nvcc_flags = [
+    "-gencode=arch=compute_90a,code=sm_90a",
+    "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
+] + common_nvcc_flags
 sm100a_nvcc_flags = ["-gencode=arch=compute_100a,code=sm_100a"] + common_nvcc_flags
 sm103a_nvcc_flags = ["-gencode=arch=compute_103a,code=sm_103a"] + common_nvcc_flags
 sm100f_nvcc_flags = ["-gencode=arch=compute_100f,code=sm_100f"] + common_nvcc_flags
28 changes: 28 additions & 0 deletions flashinfer/jit/topk.py
@@ -0,0 +1,28 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from . import env as jit_env
from .core import JitSpec, gen_jit_spec


def gen_topk_module() -> JitSpec:
return gen_jit_spec(
"topk",
[
jit_env.FLASHINFER_CSRC_DIR / "topk.cu",
jit_env.FLASHINFER_CSRC_DIR / "flashinfer_topk_binding.cu",
],
)