feat: add sink to flashinfer decode #2087
base: main
Changes from 2 commits
```diff
@@ -355,6 +355,14 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par
   // sync local state of all warps inside a threadblock
   sync_state<vec_size, bdx, bdy, bdz>(variant, st_local, reinterpret_cast<float*>(smem), smem_md,
                                       tx, ty, tz);
+  // Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
+      float s_aux_val = params.maybe_s_aux[qo_head_idx];
+      st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
+    }
+  }
```
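In equation form, writing $a$ for `maybe_s_aux[qo_head_idx]` and $m$, $d$ for the running max and denominator of the online-softmax state, the added update is

$$
d \;\leftarrow\; d + 2^{(a - m)\log_2 e} \;=\; d + e^{\,a - m},
$$

i.e., one extra denominator term for the per-head sink logit, added once after all KV tiles are processed. Whether $a$ should also be scaled like the regular logits is the subject of the review comments below.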
Comment on lines +358 to +365 (Contributor):
This logic for adding the sink contribution is duplicated in `BatchDecodeWithPagedKVCacheDevice` below, and the constant `LOG2_E` is redefined locally in each kernel. Both could be factored into a shared helper. For example, you could add at the top of the file:

```cpp
namespace flashinfer {
namespace {  // anonymous namespace

static constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)

template <typename AttentionVariant, typename State, typename Params>
__device__ __forceinline__ void AddSinkContribution(AttentionVariant variant, State& st,
                                                    const Params& params,
                                                    uint32_t qo_head_idx) {
  if constexpr (variant.use_softmax) {
    if (params.maybe_s_aux != nullptr) {
      float s_aux_val = params.maybe_s_aux[qo_head_idx];
      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
    }
  }
}

}  // anonymous namespace
// ... rest of the file
```

Then you could replace this block and the one in `BatchDecodeWithPagedKVCacheDevice` with:

```cpp
AddSinkContribution(variant, st_local, params, qo_head_idx);
```
Comment on lines +358 to +365 (Contributor):
Fix `s_aux` scaling to match the logits path. `s[j]` is scaled by `variant.sm_scale_log2` before softmax; `s_aux` currently uses `(s_aux - m) * LOG2_E`, which mismatches and yields incorrect normalization. Scale `s_aux` with `variant.sm_scale_log2` and drop `LOG2_E`. Apply:

```diff
-  if constexpr (variant.use_softmax) {
-    if (params.maybe_s_aux != nullptr) {
-      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
-      float s_aux_val = params.maybe_s_aux[qo_head_idx];
-      st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
-    }
-  }
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+      st_local.d += math::ptx_exp2(s_aux_scaled - st_local.m);
+    }
+  }
```
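To spell the argument out, assume the usual definition $\texttt{sm\_scale\_log2} = \texttt{sm\_scale} \cdot \log_2 e$ (an assumption here, not stated in the diff), so the stored logits are $x_j = s_j \cdot \texttt{sm\_scale\_log2}$ with running state

$$
m = \max_j x_j, \qquad d = \sum_j 2^{\,x_j - m}.
$$

A sink term expressed in the same domain is $2^{\,a \cdot \texttt{sm\_scale\_log2} - m}$, which is what the suggested change computes; the current $2^{\,(a - m)\log_2 e}$ also multiplies the base-2 quantity $m$ by $\log_2 e$, so the sink term is normalized on a different scale than the other denominator terms.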
```diff
   #pragma unroll
   for (size_t i = 0; i < vec_size; ++i) {
     st_local.o[i] = variant.OutputTransform(params, st_local.o[i], /*batch_idx=*/0, /*qo_idx=*/0,
```
```diff
@@ -589,6 +597,14 @@ __device__ __inline__ void BatchDecodeWithPagedKVCacheDevice(const Params& param
   // sync local state of all warps inside a threadblock
   sync_state<vec_size, bdx, bdy, bdz>(variant, st, reinterpret_cast<float*>(smem), smem_md, tx, ty,
                                       tz);
+  // Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
+      float s_aux_val = params.maybe_s_aux[qo_head_idx];
+      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
+    }
+  }
```
Comment on lines +601 to +607 (Contributor):
Same scaling fix for the batch kernel. Mirror the `s_aux` scaling correction here to keep behavior consistent across kernels.

```diff
-  if constexpr (variant.use_softmax) {
-    if (params.maybe_s_aux != nullptr) {
-      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
-      float s_aux_val = params.maybe_s_aux[qo_head_idx];
-      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
-    }
-  }
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+      st.d += math::ptx_exp2(s_aux_scaled - st.m);
+    }
+  }
```
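If both review suggestions are adopted, the shared helper from the first comment would carry the scaling fix as well. A sketch under those assumptions (reusing the names from the comments above; not a committed change):

```cpp
// Sketch: the AddSinkContribution helper from the refactor suggestion, with the
// sm_scale_log2 scaling proposed in the two scaling comments folded in.
template <typename AttentionVariant, typename State, typename Params>
__device__ __forceinline__ void AddSinkContribution(AttentionVariant variant, State& st,
                                                    const Params& params,
                                                    uint32_t qo_head_idx) {
  if constexpr (variant.use_softmax) {
    if (params.maybe_s_aux != nullptr) {
      // Keep the sink logit in the same base-2, sm_scale-scaled domain as the other logits.
      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
      st.d += math::ptx_exp2(s_aux_scaled - st.m);
    }
  }
}
```

Both call sites would then reduce to `AddSinkContribution(variant, st_local, params, qo_head_idx);` in the single-decode kernel and `AddSinkContribution(variant, st, params, qo_head_idx);` in the batch kernel, right after `sync_state`.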
```diff
   #pragma unroll
   for (size_t i = 0; i < vec_size; ++i) {
     st.o[i] = variant.OutputTransform(params, st.o[i], bx, /*qo_idx=*/0, qo_head_idx, st.m, st.d,
```
Avoid breaking single-decode JIT: remove maybe_s_aux here (or plumb it through Python).
The single-decode Python wrapper doesn’t pass maybe_s_aux. Adding it here shifts the C++ run signature and will misalign subsequent args (logits_soft_cap becomes the second tensor param), breaking calls.
Minimal fix: remove `maybe_s_aux` here so the existing Python call sites keep lining up.
If you want single-decode sink support, also update flashinfer/decode.py:get_single_decode_module wrappers to accept and pass maybe_s_aux in the correct position.