-
Notifications
You must be signed in to change notification settings - Fork 596
perf: bunch of features and optimizations for top-k (sampling + sparse attention) #2119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
51af95c
7a55718
b391eb7
ca14df1
7d71381
b2b960d
051ccf9
d6fb90b
72fea03
4532d88
566b432
a105af4
c912dec
18bb844
00b8bcf
63f07f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| /* | ||
| * Copyright (c) 2024 by FlashInfer team. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #include "tvm_ffi_utils.h" | ||
|
|
||
| using tvm::ffi::Optional; | ||
|
|
||
| void radix_topk(TensorView input, TensorView output_indices, TensorView output_values, | ||
| Optional<TensorView> maybe_row_states_buffer, int64_t top_k); | ||
|
|
||
| // Radix-based Top-K selection | ||
| TVM_FFI_DLL_EXPORT_TYPED_FUNC(radix_topk, radix_topk); |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| /* | ||
| * Copyright (c) 2024 by FlashInfer team. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #include <flashinfer/sampling.cuh> | ||
|
|
||
| #include "tvm_ffi_utils.h" | ||
|
|
||
| using namespace flashinfer; | ||
|
|
||
| using tvm::ffi::Optional; | ||
|
|
||
| void radix_topk(TensorView input, TensorView output_indices, TensorView output_values, | ||
| Optional<TensorView> maybe_row_states_buffer, int64_t top_k) { | ||
| CHECK_INPUT(input); | ||
| CHECK_INPUT(output_indices); | ||
| CHECK_INPUT(output_values); | ||
| CHECK_DIM(2, input); // input: (batch_size, d) | ||
| CHECK_DIM(2, output_indices); // output_indices: (batch_size, top_k) | ||
| CHECK_DIM(2, output_values); // output_values: (batch_size, top_k) | ||
|
|
||
| unsigned int batch_size = input.size(0); | ||
| unsigned int d = input.size(1); | ||
|
|
||
| cudaSetDevice(input.device().device_id); | ||
| auto stream = get_stream(input.device()); | ||
|
|
||
| cudaError_t status; | ||
| auto dtype = input.dtype(); | ||
|
|
||
| // Get row_states_buffer if provided (for multi-CTA path) | ||
| sampling::RadixRowState* row_states_ptr = nullptr; | ||
| if (maybe_row_states_buffer.has_value()) { | ||
| row_states_ptr = | ||
| static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr()); | ||
| } | ||
|
|
||
| DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] { | ||
| status = sampling::RadixTopKMultiCTA<c_type, int32_t>( | ||
| static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()), | ||
| static_cast<c_type*>(output_values.data_ptr()), | ||
| nullptr, // top_k_arr | ||
| batch_size, static_cast<uint32_t>(top_k), d, row_states_ptr, stream); | ||
| return true; | ||
| }); | ||
|
|
||
| TVM_FFI_ICHECK(status == cudaSuccess) | ||
| << "RadixTopK failed with error code " << cudaGetErrorString(status); | ||
| } | ||
|
Comment on lines
+24
to
+60
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix BF16 dispatch + initialize void radix_topk(TensorView input, TensorView output_indices, TensorView output_values,
Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
@@
- cudaSetDevice(input.device().device_id);
+ TVM_FFI_ICHECK(cudaSetDevice(input.device().device_id) == cudaSuccess);
auto stream = get_stream(input.device());
- cudaError_t status;
+ cudaError_t status = cudaErrorInvalidValue;
auto dtype = input.dtype();
@@
- DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
+ DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16_BF16(dtype, c_type, [&] {
status = sampling::RadixTopKMultiCTA<c_type, int32_t>(
@@
TVM_FFI_ICHECK(status == cudaSuccess)
<< "RadixTopK failed with error code " << cudaGetErrorString(status);
}If there is no
π€ Prompt for AI Agents |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| """ | ||
| Copyright (c) 2024 by FlashInfer team. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ | ||
|
|
||
| from . import env as jit_env | ||
| from .core import JitSpec, gen_jit_spec | ||
|
|
||
|
|
||
| def gen_topk_module() -> JitSpec: | ||
| return gen_jit_spec( | ||
| "topk", | ||
| [ | ||
| jit_env.FLASHINFER_CSRC_DIR / "topk.cu", | ||
| jit_env.FLASHINFER_CSRC_DIR / "flashinfer_topk_binding.cu", | ||
| ], | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add basic shape +
top_krange validation (prevents kernel OOB / nonsense launches).Only rank is validated today. At minimum, enforce:
top_k > 0top_k <= input.size(1)output_{indices,values}.size(1) == top_kThis avoids silent misbehavior when callers pass a mismatched
kvs output shapes.π€ Prompt for AI Agents