7 changes: 6 additions & 1 deletion flashinfer/decode.py
@@ -51,6 +51,7 @@
_check_pos_encoding_mode,
check_shape_dtype_device,
_get_cache_alibi_slopes_buf,
_get_sink_buf,
_get_cache_buf,
_get_range_buf,
_unpack_paged_kv_cache,
@@ -242,6 +243,7 @@ def run_batch_decode(
window_left: int,
enable_pdl: bool,
alibi_slopes: Optional[torch.Tensor],
maybe_s_aux: Optional[torch.Tensor],
logits_soft_cap: float,
sm_scale: float,
rope_scale: float,
@@ -263,6 +265,7 @@
window_left,
enable_pdl,
alibi_slopes,
maybe_s_aux,
logits_soft_cap,
sm_scale,
1.0 / rope_scale, # rope_rcp_scale
@@ -286,6 +289,7 @@ def _fake_run_batch_decode(
window_left: int,
enable_pdl: bool,
alibi_slopes: Optional[torch.Tensor],
maybe_s_aux: Optional[torch.Tensor],
logits_soft_cap: float,
sm_scale: float,
rope_scale: float,
@@ -1330,7 +1334,7 @@ def run(
self._kv_lens_buffer,
page_size,
self._max_kv_len,
sinks,
_get_sink_buf(sinks),
]

self._cached_module.paged_run(*run_args)
@@ -1364,6 +1368,7 @@ def run(
else:
run_args += [
_get_cache_alibi_slopes_buf(q.shape[1], q.device),
_get_sink_buf(sinks),
logits_soft_cap,
sm_scale,
rope_scale,
8 changes: 4 additions & 4 deletions flashinfer/jit/attention/modules.py
@@ -467,8 +467,8 @@ def gen_single_decode_module(
dtype_o,
head_dim_qk,
head_dim_vo,
["maybe_alibi_slopes"], # additional_tensor_names
["float"], # additional_tensor_dtypes
["maybe_alibi_slopes", "maybe_s_aux"], # additional_tensor_names
["float", "float"], # additional_tensor_dtypes
[

⚠️ Potential issue | 🔴 Critical

Avoid breaking single-decode JIT: remove maybe_s_aux here (or plumb it through Python).

The single-decode Python wrapper doesn't pass maybe_s_aux. Adding it here shifts the generated C++ run signature and misaligns every subsequent argument (logits_soft_cap would be bound to the second additional-tensor slot), breaking calls.

Minimal fix:

-        ["maybe_alibi_slopes", "maybe_s_aux"],  # additional_tensor_names
-        ["float", "float"],  # additional_tensor_dtypes
+        ["maybe_alibi_slopes"],  # additional_tensor_names
+        ["float"],  # additional_tensor_dtypes

If you want single-decode sink support, also update flashinfer/decode.py:get_single_decode_module wrappers to accept and pass maybe_s_aux in the correct position.

🤖 Prompt for AI Agents
In flashinfer/jit/attention/modules.py around lines 470-472, adding
"maybe_s_aux" to the additional tensor names shifts the C++ JIT run signature
and breaks the single-decode wrapper which does not pass that arg; remove
"maybe_s_aux" from the additional_tensor_names and its dtype from
additional_tensor_dtypes so the signature remains unchanged, or alternatively
update flashinfer/decode.py:get_single_decode_module to accept a maybe_s_aux
parameter and forward it in the exact positional order expected by the JIT run
(ensure both the names and dtypes lists and all wrapper call sites stay
consistent).
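
A purely illustrative sketch of the misalignment, using hypothetical simplified signatures rather than the real generated binding:

```python
# Hypothetical stand-in for the generated run() after adding the extra tensor name;
# only the positional layout matters here, not the actual flashinfer signature.
def run_new(q, maybe_alibi_slopes, maybe_s_aux, logits_soft_cap, sm_scale):
    return {"maybe_s_aux": maybe_s_aux,
            "logits_soft_cap": logits_soft_cap,
            "sm_scale": sm_scale}

# A wrapper written for the old layout passes no tensor for maybe_s_aux:
old_style_args = ("q_tensor", "alibi_buf", 30.0, 0.125, 1.0)

print(run_new(*old_style_args))
# {'maybe_s_aux': 30.0, 'logits_soft_cap': 0.125, 'sm_scale': 1.0}
# logits_soft_cap lands in the tensor slot and every later scalar shifts by one.
```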

"logits_soft_cap",
"sm_scale",
@@ -760,8 +760,8 @@ def gen_batch_decode_module(
dtype_idx,
head_dim_qk,
head_dim_vo,
["maybe_alibi_slopes"], # additional_tensor_names
["float"], # additional_tensor_dtypes
["maybe_alibi_slopes", "maybe_s_aux"], # additional_tensor_names
["float", "float"], # additional_tensor_dtypes
[
"logits_soft_cap",
"sm_scale",
17 changes: 17 additions & 0 deletions flashinfer/utils.py
@@ -237,6 +237,23 @@ def _get_cache_alibi_slopes_buf(
return buf


def _get_sink_buf(
sinks: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
"""Convert sinks tensor to proper format for CUDA kernels.

Args:
sinks: Optional tensor of shape [num_qo_heads] with sink values per head

Returns:
Contiguous float32 tensor or None if sinks is None
"""
if sinks is None:
return None
# Ensure it's float32 and contiguous as expected by CUDA kernels
return sinks.to(torch.float32).contiguous()


def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
if isinstance(dtype, str):
return getattr(torch, dtype)
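A minimal usage sketch of the new helper (values made up; _get_sink_buf is a private utility and is imported here only for illustration):

```python
import torch
from flashinfer.utils import _get_sink_buf

# hypothetical per-head sink logits for 4 query heads
sinks = torch.tensor([0.1, -0.3, 0.0, 0.25], dtype=torch.bfloat16)

buf = _get_sink_buf(sinks)
assert buf.dtype == torch.float32 and buf.is_contiguous()  # layout the kernels expect

assert _get_sink_buf(None) is None  # no sinks: the kernel sees a null maybe_s_aux
```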
16 changes: 16 additions & 0 deletions include/flashinfer/attention/decode.cuh
@@ -355,6 +355,14 @@ __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params par
// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz>(variant, st_local, reinterpret_cast<float*>(smem), smem_md,
tx, ty, tz);
// Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
if constexpr (variant.use_softmax) {
if (params.maybe_s_aux != nullptr) {
constexpr float LOG2_E = 1.4426950408889634f; // log2(e)
float s_aux_val = params.maybe_s_aux[qo_head_idx];
st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
}
}

medium

This logic for adding the sink contribution is duplicated in BatchDecodeWithPagedKVCacheDevice (lines 600-607). To improve maintainability and reduce code duplication, consider extracting this block into a helper function.

Also, the constant LOG2_E is defined inline here and in the other location. It would be better to define it once at the top of the file in an anonymous namespace to avoid magic numbers and ensure consistency.

For example, you could add at the top of the file:

namespace flashinfer {

namespace { // anonymous namespace

static constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)

template <typename AttentionVariant, typename State, typename Params>
__device__ __forceinline__ void AddSinkContribution(AttentionVariant variant, State& st,
                                                    const Params& params,
                                                    uint32_t qo_head_idx) {
  if constexpr (variant.use_softmax) {
    if (params.maybe_s_aux != nullptr) {
      float s_aux_val = params.maybe_s_aux[qo_head_idx];
      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
    }
  }
}

} // anonymous namespace

// ... rest of the file

Then you could replace this block and the one in BatchDecodeWithPagedKVCacheDevice with a call to this helper function:

AddSinkContribution(variant, st_local, params, qo_head_idx);

⚠️ Potential issue | 🔴 Critical

Fix s_aux scaling to match logits path.

s[j] is scaled by variant.sm_scale_log2 before softmax, but s_aux currently uses (s_aux - m) * LOG2_E, which does not match and yields an incorrect normalization. Scale s_aux with variant.sm_scale_log2 and drop LOG2_E.

Apply:

-  if constexpr (variant.use_softmax) {
-    if (params.maybe_s_aux != nullptr) {
-      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
-      float s_aux_val = params.maybe_s_aux[qo_head_idx];
-      st_local.d += math::ptx_exp2((s_aux_val - st_local.m) * LOG2_E);
-    }
-  }
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+      st_local.d += math::ptx_exp2(s_aux_scaled - st_local.m);
+    }
+  }
🤖 Prompt for AI Agents
In include/flashinfer/attention/decode.cuh around lines 358 to 365, the s_aux
contribution is added as (s_aux_val - st_local.m) * LOG2_E, which does not match
the logits-path scaling; scale s_aux_val by variant.sm_scale_log2 and subtract the
running max instead, so the call becomes
math::ptx_exp2(s_aux_val * variant.sm_scale_log2 - st_local.m), and remove LOG2_E;
keep the same null check and use_softmax guard.
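
For reference, a sketch of the normalization the suggested change produces, assuming the variant defines sm_scale_log2 as sm_scale · log2(e) (as in the standard softmax path) and writing m for the running max over the already-scaled logits:

$$
d \;=\; \sum_{j} 2^{\,s_j \cdot \mathrm{sm\_scale\_log2} \;-\; m}
\;+\; 2^{\,s_{\mathrm{aux}} \cdot \mathrm{sm\_scale\_log2} \;-\; m}
$$

so the sink logit receives the same scale the kernel already applies to each s_j before the running max and the sum are formed.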

#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
st_local.o[i] = variant.OutputTransform(params, st_local.o[i], /*batch_idx=*/0, /*qo_idx=*/0,
@@ -589,6 +597,14 @@ __device__ __inline__ void BatchDecodeWithPagedKVCacheDevice(const Params& param
// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz>(variant, st, reinterpret_cast<float*>(smem), smem_md, tx, ty,
tz);
// Add s_aux (learnable sink) contribution to softmax denominator after all tiles processed
if constexpr (variant.use_softmax) {
if (params.maybe_s_aux != nullptr) {
constexpr float LOG2_E = 1.4426950408889634f; // log2(e)
float s_aux_val = params.maybe_s_aux[qo_head_idx];
st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
}
}

⚠️ Potential issue | 🔴 Critical

Same scaling fix for batch kernel.

Mirror the s_aux scaling correction here to keep behavior consistent across kernels.

-  if constexpr (variant.use_softmax) {
-    if (params.maybe_s_aux != nullptr) {
-      constexpr float LOG2_E = 1.4426950408889634f;  // log2(e)
-      float s_aux_val = params.maybe_s_aux[qo_head_idx];
-      st.d += math::ptx_exp2((s_aux_val - st.m) * LOG2_E);
-    }
-  }
+  if constexpr (variant.use_softmax) {
+    if (params.maybe_s_aux != nullptr) {
+      float s_aux_scaled = params.maybe_s_aux[qo_head_idx] * variant.sm_scale_log2;
+      st.d += math::ptx_exp2(s_aux_scaled - st.m);
+    }
+  }

🤖 Prompt for AI Agents
In include/flashinfer/attention/decode.cuh around lines 601-607, the batch-kernel
branch needs the same s_aux scaling fix as the single-decode path: read s_aux_val =
params.maybe_s_aux[qo_head_idx], scale it by variant.sm_scale_log2 instead of
LOG2_E, subtract st.m, pass the result to math::ptx_exp2, and add it to st.d so the
auxiliary scaling matches the other kernel.

#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
st.o[i] = variant.OutputTransform(params, st.o[i], bx, /*qo_idx=*/0, qo_head_idx, st.m, st.d,
6 changes: 6 additions & 0 deletions include/flashinfer/attention/default_decode_params.cuh
@@ -37,6 +37,7 @@ struct SingleDecodeParams {
DTypeO* o;
float* lse;
float* maybe_alibi_slopes;
float* maybe_s_aux;
uint32_t kv_len;
uint32_t num_qo_heads;
uint32_t num_kv_heads;
@@ -58,6 +59,7 @@
o(nullptr),
lse(nullptr),
maybe_alibi_slopes(nullptr),
maybe_s_aux(nullptr),
kv_len(0),
num_qo_heads(0),
num_kv_heads(0),
@@ -84,6 +86,7 @@
o(o),
lse(nullptr),
maybe_alibi_slopes(maybe_alibi_slopes),
maybe_s_aux(nullptr),
kv_len(seq_len),
num_qo_heads(num_qo_heads),
num_kv_heads(num_kv_heads),
@@ -118,6 +121,7 @@ struct BatchDecodeParams {
DTypeO* o;
float* lse;
float* maybe_alibi_slopes;
float* maybe_s_aux;
uint32_t padded_batch_size;
uint32_t num_qo_heads;
IdType q_stride_n;
@@ -142,6 +146,7 @@
o(nullptr),
lse(nullptr),
maybe_alibi_slopes(nullptr),
maybe_s_aux(nullptr),
padded_batch_size(0),
num_qo_heads(0),
q_stride_n(0),
@@ -170,6 +175,7 @@
o(o),
lse(lse),
maybe_alibi_slopes(maybe_alibi_slopes),
maybe_s_aux(nullptr),
padded_batch_size(0),
num_qo_heads(num_qo_heads),
q_stride_n(q_stride_n),