diff --git a/benchmarks/bench_sampling.py b/benchmarks/bench_sampling.py index 2eb2de3875..6054e05d87 100644 --- a/benchmarks/bench_sampling.py +++ b/benchmarks/bench_sampling.py @@ -50,6 +50,11 @@ def init_seed_top_p_sampling(*args, **kwargs): return flashinfer.sampling.top_p_sampling_from_probs(*args, **kwargs) +def init_seed_radik_sampling(*args, **kwargs): + torch.manual_seed(42) + return flashinfer.sampling.radik_sampling_from_probs(*args, **kwargs) + + @torch.inference_mode() def main(): print("---") @@ -119,6 +124,41 @@ def main(): f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" ) + print("---") + print("radik sampling") + for vocab_size in [128512]: + for batch_size in [1, 16, 32, 64, 128, 256, 512]: + for distrib in [ + normal_distribution(1), + normal_distribution(5), + gumbel_distribution(0.1), + gumbel_distribution(1), + ]: + for deterministic in [True, False]: + for k in [10, 100, 1000]: + logits = distrib((batch_size, vocab_size), device="cuda") + probs = torch.softmax(logits, dim=-1) + samples = torch.zeros( + batch_size, dtype=torch.int32, device=probs.device + ) + measurements = bench_gpu_time( + lambda: init_seed_radik_sampling( + probs, k, deterministic=deterministic + ), + dry_run_time_ms=100, + repeat_time_ms=1000, + ) + ms = np.median(measurements) + + io = ( + probs.numel() * probs.element_size() + + samples.numel() * samples.element_size() + ) + bandwidth = io * 1e-6 / ms + print( + f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" + ) + print("---") print("top-p sampling") diff --git a/csrc/flashinfer_sampling_ops.cu b/csrc/flashinfer_sampling_ops.cu index d93057bd87..1cdb24c539 100644 --- a/csrc/flashinfer_sampling_ops.cu +++ b/csrc/flashinfer_sampling_ops.cu @@ -63,6 +63,12 @@ void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_i at::Tensor output_emitted_draft_token_num, bool deterministic, std::optional gen); +void radik_sampling_from_probs(at::Tensor workspace_buffer, at::Tensor probs, at::Tensor output, + std::optional maybe_indices, + std::optional maybe_top_k_arr, int64_t top_k_val, + bool deterministic, std::optional maybe_selected_probs, + std::optional gen); + TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { // Softmax m.def("softmax", softmax); @@ -86,4 +92,6 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { m.def("top_k_mask_logits", top_k_mask_logits); // Speculative sampling from sequence of probabilities m.def("chain_speculative_sampling", chain_speculative_sampling); + // Top-k sampling with Radix Sorting + m.def("radik_sampling_from_probs", radik_sampling_from_probs); } diff --git a/csrc/sampling.cu b/csrc/sampling.cu index 654bb582e4..348973eb02 100644 --- a/csrc/sampling.cu +++ b/csrc/sampling.cu @@ -279,3 +279,45 @@ void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_i TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " + std::string(cudaGetErrorString(status))); } + +void radik_sampling_from_probs(at::Tensor workspace_buffer, at::Tensor probs, at::Tensor output, + std::optional maybe_indices, + std::optional maybe_top_k_arr, int64_t top_k_val, + bool deterministic, std::optional maybe_selected_probs, + std::optional gen_) { + CHECK_INPUT(workspace_buffer); + 
CHECK_INPUT(probs); + CHECK_INPUT(output); + CHECK_GE(1024, top_k_val); // only support top-k <= 1024 currently + auto device = probs.device(); + CHECK_EQ(output.device(), device); + CHECK_DIM(2, probs); + CHECK_DIM(1, output); + unsigned int batch_size = output.size(0); + unsigned int vocab_size = probs.size(1); + + bool has_top_k_arr = maybe_top_k_arr.has_value(); + + uint64_t philox_seed, philox_offset; + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + std::lock_guard lock(gen->mutex_); + at::PhiloxCudaState rng_engine_inputs = gen->philox_cuda_state(32 * batch_size); + philox_seed = rng_engine_inputs.seed_.val; + philox_offset = rng_engine_inputs.offset_.val; + + const c10::cuda::OptionalCUDAGuard device_guard(device); + auto stream = at::cuda::getCurrentCUDAStream(); + cudaError_t status = sampling::RadiKSamplingFromProb( + static_cast(probs.data_ptr()), static_cast(output.data_ptr()), + maybe_indices.has_value() ? static_cast(maybe_indices->data_ptr()) : nullptr, + has_top_k_arr ? static_cast(maybe_top_k_arr->data_ptr()) : nullptr, probs.size(0), + maybe_indices.has_value() ? maybe_indices->size(0) : batch_size, top_k_val, vocab_size, + philox_seed, philox_offset, deterministic, workspace_buffer.data_ptr(), + workspace_buffer.element_size() * workspace_buffer.size(0), + maybe_selected_probs.has_value() ? static_cast(maybe_selected_probs->data_ptr()) + : nullptr, + stream); + TORCH_CHECK(status == cudaSuccess, "RadiKSamplingFromProbs failed with error code " + + std::string(cudaGetErrorString(status))); +} diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 4cd7e5bd5a..a601837a4c 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -28,6 +28,8 @@ device_support_pdl, register_custom_op, register_fake_op, + get_radik_workspace_size, + _is_buf_cached, ) @@ -214,6 +216,37 @@ def top_k_sampling_from_probs( ) return samples + @register_custom_op( + "flashinfer::radik_sampling_from_probs", mutates_args=("workspace_buffer",) + ) + def radik_sampling_from_probs( + workspace_buffer: torch.Tensor, + probs: torch.Tensor, + indices: Optional[torch.Tensor], + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + deterministic: bool, + selected_probs: Optional[torch.Tensor], + generator: Optional[torch.Generator], + ) -> torch.Tensor: + device = probs.device + probs = probs.float() + batch_size = indices.size(0) if indices is not None else probs.size(0) + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + samples = torch.empty(batch_size, dtype=torch.int32, device=device) + module.radik_sampling_from_probs.default( + workspace_buffer, + probs, + samples, + indices, + maybe_top_k_arr, + top_k_val, + deterministic, + selected_probs, + generator, + ) + return samples + @register_fake_op("flashinfer::top_k_sampling_from_probs") def _fake_top_k_sampling_from_probs( probs: torch.Tensor, @@ -453,6 +486,7 @@ def _fake_chain_speculative_sampling( top_k_renorm_probs=top_k_renorm_probs, top_k_mask_logits=top_k_mask_logits, chain_speculative_sampling=chain_speculative_sampling, + radik_sampling_from_probs=radik_sampling_from_probs, ) @@ -801,11 +835,99 @@ def top_k_sampling_from_probs( if check_nan: if torch.any(torch.isnan(probs)): raise ValueError("Input probs contains NaN.") + + # dispatch non-determinitic and small top-k requests to radik_sampling_from_probs + use_radik_impl = not deterministic and ( + (isinstance(top_k, int) and top_k <= 100) + or (isinstance(top_k, torch.Tensor) and 
top_k.max() <= 100) + ) + # Check if GPU memory is available for radik_sampling_from_probs + is_radik_buf_cached, radik_buf_bytes = _is_buf_cached( + "radik_sampling_from_probs_workspace", probs.device + ) + required_radik_buf_bytes = get_radik_workspace_size(probs, top_k) + memory_avaliable = ( + is_radik_buf_cached and radik_buf_bytes >= required_radik_buf_bytes + ) or ( + not is_radik_buf_cached + and torch.cuda.mem_get_info()[1] >= required_radik_buf_bytes + ) + + use_radik_impl = use_radik_impl and memory_avaliable + + if use_radik_impl: + return radik_sampling_from_probs( + probs, + top_k, + indices, + deterministic, + generator, + selected_probs=None, + check_nan=check_nan, + ) + return get_sampling_module().top_k_sampling_from_probs( probs, indices, *_to_tensor_scalar_tuple(top_k), deterministic, generator ) +def radik_sampling_from_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], + indices: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + selected_probs: Optional[torch.Tensor] = None, + check_nan: bool = False, +) -> torch.Tensor: + r"""GPU kernel for radix top-k sampling from probability distributions, + utilizing radix selection to efficiently identify top-k elements followed by sampling from the selected subset. + Check the `radik paper `_ for more details. + + Parameters + ---------- + probs: torch.Tensor + Probabilities for sampling. When indices is not provided, shape should be ``(batch_size, num_classes)`` + and the i-th output will be sampled from the i-th row of probabilities. When indices is provided, + shape should be ``(unique_batch_size, num_classes)`` where unique_batch_size is the number of unique + probability distributions. + top_k: Union[torch.Tensor, int] + Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling. + If a scalar, the same threshold is used for all requests. + If a tensor, each request has its own threshold. + indices: Optional[torch.Tensor] + Optional indices tensor of shape ``(batch_size,)`` that maps each output to a row in probs. + For example, if indices[i] = j, then the i-th output will be sampled from probs[j]. + This allows reusing the same probability distribution for multiple outputs. + If indices is not provided, the i-th output will be sampled from the i-th row of probs. + deterministic: bool + Whether to use deterministic kernel implementation, default is ``True``. However, the radix sampling process itself is inherently non-deterministic. + generator: Optional[torch.Generator] + A random number generator for the operation. + selected_probs: Optional[torch.Tensor] + Optional tensor of shape ``(batch_size, top_k)`` that stores the top-k selected probabilities. + check_nan: bool + Whether to check nan in :attr:`probs`, default is ``False``. 
+ """ + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + + workspace_buffer = _get_cache_buf( + "radik_sampling_from_probs_workspace", 64 * 1024 * 1024, probs.device + ) + + return get_sampling_module().radik_sampling_from_probs( + workspace_buffer, + probs, + indices, + *_to_tensor_scalar_tuple(top_k), + deterministic, + selected_probs, + generator, + ) + + def min_p_sampling_from_probs( probs: torch.Tensor, min_p: Union[torch.Tensor, float], diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 64d9e1d95f..322ddc1903 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -195,6 +195,15 @@ def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: return buf +def _is_buf_cached(name: str, device: torch.device) -> Tuple[bool, int]: + key = (name, device) + if key in _cache_buf: + buf = _cache_buf[key] + bytes_size = buf.numel() * buf.element_size() + return (True, bytes_size) + return (False, 0) + + # find the least power of 2 that is greater than or equal to x def _ceil_pow2(x: int) -> int: return 1 << (x - 1).bit_length() @@ -737,3 +746,34 @@ def get_shuffle_matrix_sf_a_row_indices( row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) return row_indices + + +def get_radik_workspace_size(probs: torch.Tensor, top_k: Union[torch.Tensor, int]): + """ + Calculate the workspace size required for the radix select algorithm + + Args: + probs: The input probabilities + top_k: The k value in top-k selection + + Returns: + size_in_bytes: Required workspace size in bytes + """ + k = top_k.max() if isinstance(top_k, torch.Tensor) else top_k + task_num = probs.size(0) + vocab_size = probs.size(1) + + sizeof_CompT = 4 + sizeof_int = 4 + sizeof_T = probs.element_size() + sizeof_IdxType = 4 + + size_in_bytes = task_num * ( + sizeof_CompT * vocab_size * 2 # buffer for val + + sizeof_int * (1**12) # buffer for hist (4096 = 2^12) + + sizeof_int * 5 # buffer for globalCount,old_taskLen,new_taskLen,K,binId + + sizeof_T * k # buffer for top-k select result + + sizeof_IdxType * k # buffer for top-k select result + ) + + return size_in_bytes diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 6b134630cf..ca9dded7d4 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -28,9 +28,11 @@ #include #include #include +#include #include "allocator.h" #include "math.cuh" +#include "pytorch_extension_utils.h" #include "utils.cuh" #include "vec_dtypes.cuh" @@ -77,6 +79,15 @@ using namespace cub; __VA_ARGS__ \ } +#define DISPATCH_RADIK_SHM_AWARE_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) 
\ + if (compute_capacity.first >= 8) { \ + constexpr uint32_t BLOCK_THREADS = 512; \ + __VA_ARGS__ \ + } else { \ + constexpr uint32_t BLOCK_THREADS = 256; \ + __VA_ARGS__ \ + } + constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; @@ -2227,6 +2238,711 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids }); } +template +struct ComputeT { + using type = ValType; +}; + +template +void getRadixSelectWorkSpaceSize(const int& K, const int& n, const int& taskNum, + size_t* sizeInBytes) { + using CompT = typename ComputeT::type; + *sizeInBytes = + taskNum * (sizeof(CompT) * n * 2 /* buffer for val */ + + sizeof(int) * (1 << 12) /* buffer for hist */ + + sizeof(int) * 5 /* buffer for globalCount,old_taskLen,new_taskLen,K,binId */ + + sizeof(T) * K /* buffer for top-k select result*/ + + sizeof(IdxType) * K); /* buffer for top-k select result*/ + return; +} + +static constexpr int MAX_GRID_SIZE = 1280; + +template +__device__ __forceinline__ int getBinId(const float& a) { + const uint32_t& u_a = reinterpret_cast(a); + uint32_t mask = ((~(u_a >> 31)) + 1) | 0x80000000; + return static_cast(((u_a ^ mask) << LEFT) >> RIGHT); +} + +template +__global__ void __launch_bounds__(BLOCK_THREADS) + countBinKernel(const float* dataIn, const int* taskLenPtr, int* histPtr, const int stride, + const int taskNum) { + constexpr int histLen = 1 << (8 * sizeof(float) - RIGHT); + const int taskId = blockIdx.y; + const int taskLen = taskLenPtr[taskId]; + const int tid = blockIdx.x * BLOCK_THREADS + threadIdx.x; + const int tx = threadIdx.x; + __shared__ int blockHist[histLen]; + +#pragma unroll + for (int i = tx; i < histLen; i += BLOCK_THREADS) { + blockHist[i] = 0; + } + __syncthreads(); + + if (tid < taskLen) { + const int binID = getBinId(dataIn[taskId * stride + tid]); + atomicAdd(&blockHist[binID], 1); + } + __syncthreads(); + +#pragma unroll + for (int i = tx; i < histLen; i += BLOCK_THREADS) { + if (blockHist[i] > 0) { + atomicAdd(&histPtr[taskId * histLen + i], blockHist[i]); + } + } + return; +} + +template +__global__ void __launch_bounds__(BLOCK_THREADS) + countBinExKernel(const T* dataIn, const int* taskLenPtr, int* histPtr, const int stride, + const int taskNum, IdType* indices) { + using CompT = typename ComputeT::type; + const int bx = blockIdx.x, tx = threadIdx.x; + const int taskId = blockIdx.y; + const int row_idx = indices == nullptr ? 
taskId : indices[taskId]; + + constexpr int histLen = 1 << (8 * sizeof(CompT) - RIGHT); + __shared__ int blockHist[histLen]; + for (int i = tx; i < histLen; i += BLOCK_THREADS) { + blockHist[i] = 0; + } + __syncthreads(); + + vec_t dataIn_vec; + const int taskLen = taskLenPtr[taskId]; + + if ((bx * BLOCK_THREADS + tx) * VEC_SIZE < taskLen) { + dataIn_vec.cast_load(dataIn + row_idx * stride + (bx * BLOCK_THREADS + tx) * VEC_SIZE); +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j++) { + const int binId = getBinId(dataIn_vec[j]); + atomicAdd(&blockHist[binId], 1); + } + } + __syncthreads(); + +#pragma unroll + for (int i = tx; i < histLen; i += BLOCK_THREADS) { + if (blockHist[i] > 0) { + atomicAdd(&histPtr[histLen * taskId + i], blockHist[i]); + } + } + return; +} + +template +struct SelectBinTempStorage { + union { + typename cub::BlockScan::TempStorage scan; + } block_prim; +}; + +template +__global__ void __launch_bounds__(BLOCK_THREADS) + selectBinKernel(const int* histPtr, int* binIdPtr, int* kPtr, int* taskLenPtr) { + static_assert(HISTLEN % BLOCK_THREADS == 0, "HISTLEN % BLOCK_THREADS != 0"); + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + const int taskId = bx; + const int taskLen = taskLenPtr[taskId]; + const int UNROLL = HISTLEN / BLOCK_THREADS; + + int count[UNROLL]; + int thread_sum = 0; + int oldK = kPtr[taskId]; + using TempStorage = SelectBinTempStorage; + extern __shared__ __align__(alignof(TempStorage)) uint8_t smem[]; + auto& temp_storage = reinterpret_cast(smem); + +#pragma unroll + for (int i = 0; i < UNROLL; i++) { + count[i] = histPtr[taskId * HISTLEN + tx * UNROLL + i]; + thread_sum += count[i]; + } + __syncthreads(); + + // Compute PrefixSum + int exclusive_prefix_sum = 0; + int block_sum = 0; + cub::BlockScan(temp_storage.block_prim.scan) + .ExclusiveSum(thread_sum, exclusive_prefix_sum, block_sum); + __syncthreads(); + + int inclusive_suffix_sum = block_sum - exclusive_prefix_sum; + int exclusive_suffix_sum = inclusive_suffix_sum - thread_sum; + + if (oldK > exclusive_suffix_sum && oldK <= inclusive_suffix_sum) { + oldK -= exclusive_suffix_sum; + +#pragma unroll + for (int i = UNROLL - 1; i >= 0; i--) { + if (count[i] >= oldK) { + binIdPtr[taskId] = tx * UNROLL + i; + kPtr[taskId] = oldK; + taskLenPtr[taskId] = count[i]; + break; + } + oldK -= count[i]; + } + } + return; +} + +template +__global__ void __launch_bounds__(BLOCK_THREADS) + selectCandidateKernel(float* dataIn, float* dataOut, int* globalCountPtr, const int* binIdPtr, + const int* taskLenPtr, const int stride) { + // TODO: Optimize kernel with double Block Buffer + __shared__ int blockCount[1]; + __shared__ float blockCache[BLOCK_THREADS]; + + const int taskId = blockIdx.y; + const int taskLen = taskLenPtr[taskId]; + const int mask = binIdPtr[taskId]; + int idx = blockIdx.x * BLOCK_THREADS + threadIdx.x; + + if (idx < taskLen && threadIdx.x == 0) { + blockCount[0] = 0; + } + __syncthreads(); + + if (idx < taskLen) { + float data = dataIn[taskId * stride + idx]; + if (mask == getBinId(data)) { + // printf("select task%d top-k candidate data: %.14f\n", taskId, data); + int pos = atomicAdd(blockCount, 1); + blockCache[pos] = data; + } + } + __syncthreads(); + + int count = blockCount[0]; + __syncthreads(); + + if (idx < taskLen && threadIdx.x == 0) { + blockCount[0] = atomicAdd(globalCountPtr + taskId, count); + } + __syncthreads(); + + if (idx < taskLen && threadIdx.x < count) { + dataOut[taskId * stride + blockCount[0] + threadIdx.x] = blockCache[threadIdx.x]; + } + return; +} + +template 
+__global__ void __launch_bounds__(BLOCK_THREADS) + selectCandidateExKernel(T* dataIn, T* dataOut, int* globalCountPtr, const int* binIdPtr, + const int* taskLenPtr, const int stride, IdType* indices) { + // TODO: Optimize kernel with double Block Buffer + using CompT = typename ComputeT::type; + const int bx = blockIdx.x, tx = threadIdx.x; + const int taskId = blockIdx.y; + const int taskLen = taskLenPtr[taskId]; + const int mask = binIdPtr[taskId]; + const int row_idx = indices == nullptr ? taskId : indices[taskId]; + vec_t dataIn_vec; + + __shared__ int blockCount[1]; + __shared__ T blockCache[BLOCK_THREADS * VEC_SIZE]; + + if (tx == 0) { + blockCount[0] = 0; + } + __syncthreads(); + + if ((bx * BLOCK_THREADS + tx) * VEC_SIZE < taskLen) { + dataIn_vec.cast_load(dataIn + row_idx * stride + (bx * BLOCK_THREADS + tx) * VEC_SIZE); +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j++) { + T data = dataIn_vec[j]; + const int binId = getBinId(data); + if (binId == mask) { + int pos = atomicAdd(blockCount, 1); + blockCache[pos] = data; + } + } + } + __syncthreads(); + + int count = blockCount[0]; + __syncthreads(); + + if (count > 0 && tx == 0) { + // printf("count: %d, globalcount: %d\n", count, globalCountPtr[taskId]); + blockCount[0] = atomicAdd(globalCountPtr + taskId, count); + } + __syncthreads(); + +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j++) { + if (tx * VEC_SIZE + j < count) { + dataOut[taskId * stride + blockCount[0] + tx * VEC_SIZE + j] = blockCache[tx * VEC_SIZE + j]; + } + } + return; +} + +template +__global__ void __launch_bounds__(BLOCK_THREADS) + filterKernel(const ValType* dataIn, const typename ComputeT::type* kThElePtr, + ValType* valOut, IdxType* idxOut, int* globalCount, const int* top_k_arr, + const int* taskLenPtr, const int stride, const int max_top_k_val, + const int taskNum, IdxType* indices) { + using CompT = typename ComputeT::type; + + __shared__ int blockCount[1]; + __shared__ ValType valBlockCache[CACHE_SIZE]; + __shared__ IdxType idxBlockCache[CACHE_SIZE]; + + const int tx = threadIdx.x, bx = blockIdx.x; + const int taskId = bx; + const int taskLen = taskLenPtr[taskId]; + const int row_idx = indices == nullptr ? taskId : indices[taskId]; + const int k = top_k_arr == nullptr ? 
max_top_k_val : top_k_arr[taskId]; + + if (tx == 0) { + blockCount[0] = 0; + } + __syncthreads(); + + vec_t dataIn_vec; + if (taskLen < k) { + // copy all +#pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(taskLen, BLOCK_THREADS * VEC_SIZE); i++) { + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < taskLen) { + dataIn_vec.cast_load(dataIn + row_idx * stride + (i * BLOCK_THREADS + tx) * VEC_SIZE); +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j++) { + int pos = atomicAdd(blockCount, 1); + valBlockCache[pos] = dataIn_vec[j]; + idxBlockCache[pos] = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; + } + } + } + } else { + // N > K, filter by k-th element + const CompT kThElem = *(kThElePtr + taskId * stride); + +#pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(taskLen, BLOCK_THREADS * VEC_SIZE); i++) { + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < taskLen) { + dataIn_vec.cast_load(dataIn + row_idx * stride + (i * BLOCK_THREADS + tx) * VEC_SIZE); +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j++) { + if (dataIn_vec[j] > kThElem) { + int pos = atomicAdd(blockCount, 1); + if (pos < k) { + valBlockCache[pos] = dataIn_vec[j]; + idxBlockCache[pos] = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; + } + } + } + } + } + + // NOTE: selecting elements >= kThElem simultaneously might lead to an incorrect result. + // because #(elements >= kThElem) can be larger than k +#pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(taskLen, BLOCK_THREADS * VEC_SIZE); i++) { + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < taskLen) { + dataIn_vec.cast_load(dataIn + row_idx * stride + (i * BLOCK_THREADS + tx) * VEC_SIZE); +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j++) { + if (dataIn_vec[j] == kThElem) { + // printf("task%d filter data: %.14f\n", taskId, dataIn_vec[j]); + int pos = atomicAdd(blockCount, 1); + if (pos < k) { + valBlockCache[pos] = dataIn_vec[j]; + idxBlockCache[pos] = (i * BLOCK_THREADS + tx) * VEC_SIZE + j; + } + } + } + } + } + } + __syncthreads(); + + // Note: BlockCount[0] can be large than k, because of the k-th element is not unique + for (int i = tx; i < std::min(k, blockCount[0]); i += BLOCK_THREADS) { + valOut[taskId * max_top_k_val + i] = valBlockCache[i]; + idxOut[taskId * max_top_k_val + i] = idxBlockCache[i]; + } + + return; +} + +template +__global__ void SamplingFromRadiKSelectKernel(DType* select_probs, IdType* select_indices, + IdType* output, int* top_k_arr, IdType* indices, + uint32_t max_top_k_val, uint64_t philox_seed, + uint64_t philox_offset) { + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + const int row_idx = indices == nullptr ? bx : indices[bx]; + const int K = top_k_arr == nullptr ? 
max_top_k_val : top_k_arr[bx]; + curandStatePhilox4_32_10_t state; + curand_init(philox_seed, bx, philox_offset, &state); + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + temp_storage.sampled_id = K; + __syncthreads(); + __shared__ DType block_sum; + + DType thread_sum = 0.f; + vec_t probs_vec; + + for (uint32_t i = 0; i < ceil_div(K, BLOCK_THREADS * VEC_SIZE); ++i) { + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < K) { + probs_vec.cast_load(select_probs + row_idx * max_top_k_val + + (i * BLOCK_THREADS + tx) * VEC_SIZE); +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + thread_sum += probs_vec[j]; + } + } + } + __syncthreads(); + + DType _block_sum = + cub::BlockReduce(temp_storage.block_prim.reduce).Sum(thread_sum); + if (tx == 0) { + block_sum = _block_sum; + } + __syncthreads(); + + DType renorm_factor = math::ptx_rcp(max(block_sum, 1e-8)); + + // SamplingFrom Renorm probs + float aggregate(0); + float u = curand_uniform(&state); + +#pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(K, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < K) { + probs_vec.cast_load(select_probs + row_idx * max_top_k_val + i * BLOCK_THREADS * VEC_SIZE + + tx * VEC_SIZE); + } + +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_vec[j] = probs_vec[j] * renorm_factor; + } + + DeviceSamplingFromProb( + i, K, [](float x) { return x > 0; }, u, probs_vec, aggregate, &temp_storage); + if (float(aggregate) > u) { + break; + } + } + + int sampled_id = temp_storage.sampled_id; + if (sampled_id == K) { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + output[bx] = select_indices[row_idx * max_top_k_val + sampled_id]; + return; +} + +template +cudaError_t RadiKSamplingFromProb(T* probs, IdType* output, IdType* indices, int* top_k_arr, + uint32_t batch_size, uint32_t output_batch_size, + uint32_t top_k_val, uint32_t d, uint64_t philox_seed, + uint64_t philox_offset, bool deterministic, + void* workspace_buffer, size_t workspace_buffer_size_in_bytes, + T* selected_probs, cudaStream_t stream = 0) { + using CompT = typename ComputeT::type; + uint32_t vec_size = std::gcd(16 / sizeof(T), d); + auto compute_capacity = GetCudaComputeCapability(); + + std::vector tmpK(batch_size, top_k_val); + if (top_k_arr != nullptr) { + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(tmpK.data(), top_k_arr, sizeof(int) * batch_size, + cudaMemcpyDefault, stream)); + } + uint32_t max_top_k_val = + top_k_arr == nullptr ? 
top_k_val : *std::max_element(tmpK.begin(), tmpK.end()); + + size_t workSpaceSize = 0; + getRadixSelectWorkSpaceSize(max_top_k_val, d, batch_size, &workSpaceSize); + void* workSpace = 0; + bool workspace_allocated = false; + if (workSpaceSize > workspace_buffer_size_in_bytes) { + auto const cu_malloc_status = cudaMalloc(&workSpace, workSpaceSize); + TORCH_CHECK(cu_malloc_status == cudaSuccess, + "CUDA out of memory when allocating workspace for radix select."); + workspace_allocated = true; + } else { + workSpace = workspace_buffer; + } + + CompT* valBuffer[2]{static_cast(workSpace), + static_cast(workSpace) + batch_size * d}; + int* histPtr = reinterpret_cast(valBuffer[1] + batch_size * d); + int* globalCountPtr = histPtr + (1 << 12) * batch_size; + + int* taskLenPtr[2]{globalCountPtr + batch_size, globalCountPtr + 2 * batch_size}; + std::vector tmpTaskLen(2 * batch_size, d); + + int* kPtr = taskLenPtr[1] + batch_size; + + int* binIdPtr = kPtr + batch_size; + T* top_k_select_result = reinterpret_cast(binIdPtr + batch_size); + IdType* top_k_select_idx = + reinterpret_cast(top_k_select_result + batch_size * max_top_k_val); + std::vector taskLenHost(batch_size); + + // TODO: use a standalone kernel to initialize taskLenPtr, kPtr, histPtr... + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(taskLenPtr[0], tmpTaskLen.data(), + sizeof(int) * batch_size * 2, cudaMemcpyDefault, stream)); + + FLASHINFER_CUDA_CALL( + cudaMemcpyAsync(kPtr, tmpK.data(), sizeof(int) * batch_size, cudaMemcpyDefault, stream)); + + // clear hist and globalCount + FLASHINFER_CUDA_CALL( + cudaMemsetAsync(histPtr, 0, sizeof(int) * batch_size * ((1 << 12) + 1), stream)); + + // === Iter 1 === + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + // iter1: countBin + dim3 countbin_iter1_nblks((d + BLOCK_THREADS * VEC_SIZE - 1) / (BLOCK_THREADS * VEC_SIZE), + batch_size); + dim3 countbin_iter1_nthrs(BLOCK_THREADS); + + auto countbin_iter1_kernel = countBinExKernel; + void* countbin_iter1_args[] = {&probs, &taskLenPtr[0], &histPtr, &d, &batch_size, &indices}; + + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)countbin_iter1_kernel, countbin_iter1_nblks, + countbin_iter1_nthrs, countbin_iter1_args, 0, stream)); + + // iter1: selectBin + dim3 selectbin_iter1_nblks(batch_size); + dim3 selectbin_iter1_nthrs(BLOCK_THREADS); + size_t smem_size = sizeof(SelectBinTempStorage); + + auto selectbin_iter1_kernel = selectBinKernel; + void* selectbin_iter1_args[] = {&histPtr, &binIdPtr, &kPtr, &taskLenPtr[0]}; + + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + selectbin_iter1_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)selectbin_iter1_kernel, selectbin_iter1_nblks, + selectbin_iter1_nthrs, selectbin_iter1_args, smem_size, + stream)); + }); + + DISPATCH_RADIK_SHM_AWARE_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + // iter1: selectCandidate + dim3 selectcan_iter1_nblks((d + BLOCK_THREADS * VEC_SIZE - 1) / (BLOCK_THREADS * VEC_SIZE), + batch_size); + dim3 selectcan_iter1_nthrs(BLOCK_THREADS); + + auto selectcan_iter1_kernel = + selectCandidateExKernel; + void* selectcan_iter1_args[] = { + &probs, &valBuffer[0], &globalCountPtr, &binIdPtr, &taskLenPtr[1], &d, &indices}; + + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)selectcan_iter1_kernel, selectcan_iter1_nblks, + selectcan_iter1_nthrs, selectcan_iter1_args, 0, + stream)); + }); + }); + + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(taskLenHost.data(), taskLenPtr[0], 
sizeof(int) * batch_size, + cudaMemcpyDefault, stream)); + + // === Iter 2 === + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream)); + int maxTaskLen = *std::max_element(taskLenHost.begin(), taskLenHost.end()); + int flag = 0; + + if (maxTaskLen != 1) { + // clear hist and globalCount + FLASHINFER_CUDA_CALL( + cudaMemsetAsync(histPtr, 0, sizeof(int) * ((1 << 12) + 1) * batch_size, stream)); + + // iter2: countBin + dim3 countbin_iter2_nblks((maxTaskLen + BLOCK_THREADS - 1) / BLOCK_THREADS, batch_size); + dim3 countbin_iter2_nthrs(BLOCK_THREADS); + + auto countbin_iter2_kernel = countBinKernel; + void* countbin_iter2_args[] = {&valBuffer[flag], &taskLenPtr[flag], &histPtr, &d, + &batch_size}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)countbin_iter2_kernel, countbin_iter2_nblks, + countbin_iter2_nthrs, countbin_iter2_args, 0, stream)); + + // iter2: selectBin + dim3 selectbin_iter2_nblks(batch_size); + dim3 selectbin_iter2_nthrs(BLOCK_THREADS); + size_t smem_size = sizeof(SelectBinTempStorage); + + auto selectbin_iter2_kernel = selectBinKernel; + void* selectbin_iter2_args[] = {&histPtr, &binIdPtr, &kPtr, &taskLenPtr[flag ^ 1]}; + + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + selectbin_iter2_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)selectbin_iter2_kernel, selectbin_iter2_nblks, + selectbin_iter2_nthrs, selectbin_iter2_args, smem_size, + stream)); + + // iter2: selectCandidate + dim3 selectcan_iter2_nblks((maxTaskLen + BLOCK_THREADS - 1) / BLOCK_THREADS, batch_size); + dim3 selectcan_iter2_nthrs(BLOCK_THREADS); + + auto selectcan_iter2_kernel = selectCandidateKernel; + void* selectcan_iter2_args[] = {&valBuffer[flag], &valBuffer[flag ^ 1], &globalCountPtr, + &binIdPtr, &taskLenPtr[flag], &d}; + + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)selectcan_iter2_kernel, selectcan_iter2_nblks, + selectcan_iter2_nthrs, selectcan_iter2_args, 0, + stream)); + + // update taskLen + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(taskLenHost.data(), taskLenPtr[flag ^ 1], + sizeof(int) * batch_size, cudaMemcpyDefault, stream)); + flag ^= 1; + } + + // === Iter 3 === + FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream)); + maxTaskLen = *std::max_element(taskLenHost.begin(), taskLenHost.end()); + if (maxTaskLen != 1) { + // clear hist and globalCount + int* new_histPtr = histPtr; + FLASHINFER_CUDA_CALL( + cudaMemsetAsync(new_histPtr, 0, sizeof(int) * ((1 << 12) + 1) * batch_size, stream)); + // iter3: countBin + dim3 countbin_iter3_nblks((maxTaskLen + BLOCK_THREADS - 1) / BLOCK_THREADS, batch_size); + dim3 countbin_iter3_nthrs(BLOCK_THREADS); + auto countbin_iter3_kernel = countBinKernel; + void* countbin_iter3_args[] = {&valBuffer[flag], &taskLenPtr[flag], &new_histPtr, &d, + &batch_size}; + + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)countbin_iter3_kernel, countbin_iter3_nblks, + countbin_iter3_nthrs, countbin_iter3_args, 0, stream)); + + // iter3: selectBin + // NOTE: setting Third iter selectBin BLOCK_THREADS=256 + dim3 selectbin_iter3_nblks(batch_size); + dim3 selectbin_iter3_nthrs(256); + size_t smem_size = sizeof(SelectBinTempStorage<256>); + + auto selectbin_iter3_kernel = selectBinKernel<256, (1 << 8)>; + void* selectbin_iter3_args[] = {&new_histPtr, &binIdPtr, &kPtr, &taskLenPtr[flag ^ 1]}; + + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( + selectbin_iter3_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + 
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)selectbin_iter3_kernel, selectbin_iter3_nblks, + selectbin_iter3_nthrs, selectbin_iter3_args, smem_size, + stream)); + + // iter3: selectCandidate + dim3 selectcan_iter3_nblks((maxTaskLen + BLOCK_THREADS - 1) / BLOCK_THREADS, batch_size); + dim3 selectcan_iter3_nthrs(BLOCK_THREADS); + + auto selectcan_iter3_kernel = selectCandidateKernel; + void* selectcan_iter3_args[] = {&valBuffer[flag], &valBuffer[flag ^ 1], &globalCountPtr, + &binIdPtr, &taskLenPtr[flag], &d}; + + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)selectcan_iter3_kernel, selectcan_iter3_nblks, + selectcan_iter3_nthrs, selectcan_iter3_args, 0, + stream)); + + flag ^= 1; + } + // clear globalCount + FLASHINFER_CUDA_CALL(cudaMemsetAsync(globalCountPtr, 0, sizeof(int) * batch_size, stream)); + // reset taskLenPtr + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(taskLenPtr[0], tmpTaskLen.data(), sizeof(int) * batch_size, + cudaMemcpyDefault, stream)); + +#define RADIX_TOPK_CALL_FILTER(CACHE_SIZE) \ + do { \ + auto filter_kernel = filterKernel; \ + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)filter_kernel, filter_nblks, filter_nthrs, \ + filter_args, 0, stream)); \ + } while (0) + + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + dim3 filter_nblks(batch_size); + dim3 filter_nthrs(BLOCK_THREADS); + void* filter_args[] = {&probs, + &valBuffer[flag], + &top_k_select_result, + &top_k_select_idx, + &globalCountPtr, + &top_k_arr, + &taskLenPtr[0], + &d, + &max_top_k_val, + &batch_size, + &indices}; + if (max_top_k_val <= 128) { + RADIX_TOPK_CALL_FILTER(128); + } else if (max_top_k_val <= 256) { + RADIX_TOPK_CALL_FILTER(256); + } else if (max_top_k_val <= 512) { + RADIX_TOPK_CALL_FILTER(512); + } else if (max_top_k_val <= 1024) { + RADIX_TOPK_CALL_FILTER(1024); + } + }); + + // TODO: Fix misaligned access to top_k_select_result when VEC_SIZE > 1 and top_k_arr == nullptr + vec_size = 1; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + dim3 sample_nblks(output_batch_size); + dim3 sample_nthrs(BLOCK_THREADS); + auto sample_kernel = SamplingFromRadiKSelectKernel; + const uint32_t sample_smem_size = + sizeof(SamplingTempStorage); + void* sample_args[] = {&top_k_select_result, &top_k_select_idx, &output, + &top_k_arr, &indices, &max_top_k_val, + &philox_seed, &philox_offset}; + + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)sample_kernel, sample_nblks, sample_nthrs, + sample_args, sample_smem_size, stream)); + }); + }); + }); + + // check global histgram + FLASHINFER_CUDA_CALL(cudaStreamSynchronize(stream)); + + if (selected_probs != nullptr) { + FLASHINFER_CUDA_CALL(cudaMemcpy(selected_probs, top_k_select_result, + sizeof(T) * batch_size * top_k_val, cudaMemcpyDefault)); + } + + if (workspace_allocated) { + FLASHINFER_CUDA_CALL(cudaFree(workSpace)) + } + return cudaSuccess; +} + } // namespace sampling } // namespace flashinfer diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 333a24bce8..26f159a8b7 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -557,6 +557,94 @@ def test_chain_speculative_sampling( assert torch.all(emitted_num + 1 == (output_token_ids != -1).sum(dim=1)) +@pytest.mark.parametrize("batch_size", [1, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) +@pytest.mark.parametrize( + "distribution", + [ + normal_distribution(1), + normal_distribution(5), + gumbel_distribution(0.1), + ], +) +@pytest.mark.parametrize("k", [10, 100, 500]) +def test_radik_top_k_sampling(batch_size, 
vocab_size, distribution, k):
+    if k > vocab_size:
+        pytest.skip("k should be less than vocab_size")
+    torch.manual_seed(42)
+    pre_norm_prob = distribution((batch_size, vocab_size), "cuda:0")
+    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
+    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
+    pivot = sorted_prob[:, k - 1]
+    mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
+
+    num_trials = 1000
+    for _ in range(num_trials):
+        samples = flashinfer.sampling.radik_sampling_from_probs(normalized_prob, k)
+        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
+        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
+            torch.arange(batch_size), samples
+        ]
+
+
+@pytest.mark.parametrize("batch_size", [1, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
+@pytest.mark.parametrize("k", [10, 100, 500])
+def test_radik_top_k_sampling_with_variable_k(batch_size, vocab_size, k):
+    if k > vocab_size:
+        pytest.skip("k should be less than vocab_size")
+    torch.manual_seed(42)
+    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
+    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
+    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
+    k = torch.randint(1, k + 1, (batch_size,), device="cuda:0")
+    pivot = sorted_prob[torch.arange(batch_size), k - 1]
+    mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
+
+    num_trials = 1000
+    for _ in range(num_trials):
+        samples = flashinfer.sampling.radik_sampling_from_probs(normalized_prob, k)
+        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
+        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
+            torch.arange(batch_size), samples
+        ]
+
+
+@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
+@pytest.mark.parametrize(
+    "distribution",
+    [
+        normal_distribution(1),
+        normal_distribution(5),
+        gumbel_distribution(0.1),
+    ],
+)
+@pytest.mark.parametrize("k", [10, 100, 500])
+def test_radik_top_k_sampling_freq(vocab_size, distribution, k):
+    if k > vocab_size:
+        pytest.skip("k should be less than vocab_size")
+    torch.manual_seed(42)
+    logits = distribution((1, vocab_size), "cuda:0")
+    probs = torch.softmax(logits, dim=-1)
+    sorted_prob, _ = torch.sort(probs, descending=True)
+    pivot = sorted_prob[:, k - 1]
+    mask = (probs >= pivot.unsqueeze(-1)).int()
+
+    renorm_probs = flashinfer.sampling.top_k_renorm_probs(probs, k)
+    counter = torch.zeros(vocab_size, dtype=torch.int32, device=logits.device)
+    num_trials = 5000000
+    samples = flashinfer.sampling.radik_sampling_from_probs(
+        probs,
+        k,
+        indices=torch.zeros(num_trials, dtype=torch.int32, device=logits.device),
+    )
+    counter.scatter_add_(0, samples.long(), torch.ones_like(samples))
+    freq = counter.float() / num_trials
+    assert torch.all(mask[torch.arange(1), samples] == 1)
+    similarity = torch.cosine_similarity(freq, renorm_probs)
+    assert similarity > 0.99, f"similarity: {similarity}"
+
+
 if __name__ == "__main__":
     # test_sampling_freq(128256, gumbel_distribution(0.1), 0.5)
     test_sampling_from_logits_freq(128256, gumbel_distribution(0.1))
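
A minimal usage sketch, not part of the patch itself: it calls the new `flashinfer.sampling.radik_sampling_from_probs` entry point the same way the added benchmark and tests do, and the helper `estimate_radik_workspace_bytes` is an illustrative mirror of `getRadixSelectWorkSpaceSize` in `include/flashinfer/sampling.cuh` (not a function added by this diff). It assumes this branch of flashinfer is installed and a CUDA device is available.

```python
import torch
import flashinfer


def estimate_radik_workspace_bytes(probs: torch.Tensor, top_k) -> int:
    # Mirrors getRadixSelectWorkSpaceSize: per task, two float candidate buffers of
    # length vocab_size, a 4096-bin (2^12) histogram, five int scalars
    # (globalCount, old_taskLen, new_taskLen, K, binId), and k slots each for the
    # selected top-k values and indices.
    k = int(top_k.max()) if isinstance(top_k, torch.Tensor) else int(top_k)
    task_num, vocab_size = probs.shape
    sizeof_comp = sizeof_int = sizeof_idx = 4
    sizeof_val = probs.element_size()
    return task_num * (
        sizeof_comp * vocab_size * 2
        + sizeof_int * (1 << 12)
        + sizeof_int * 5
        + (sizeof_val + sizeof_idx) * k
    )


torch.manual_seed(0)
batch_size, vocab_size, k = 8, 128256, 50  # the kernel supports top-k <= 1024
probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)
print(f"estimated workspace: {estimate_radik_workspace_bytes(probs, k) / 2**20:.2f} MiB")

# Scalar k. As the docstring notes, the radix selection path is not bitwise
# deterministic even when deterministic=True is passed.
samples = flashinfer.sampling.radik_sampling_from_probs(probs, k, deterministic=False)
assert samples.shape == (batch_size,) and samples.dtype == torch.int32

# Per-request k, plus `indices` reusing a single probability row for every output.
per_request_k = torch.randint(1, k + 1, (batch_size,), device="cuda", dtype=torch.int32)
row_indices = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
samples_shared = flashinfer.sampling.radik_sampling_from_probs(
    probs, per_request_k, indices=row_indices, deterministic=False
)
```

Note that `top_k_sampling_from_probs` in this patch only dispatches to the radix path when `deterministic=False`, every requested k is at most 100, and enough device memory is available for the workspace; otherwise it falls back to the existing top-k sampling kernel.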