79 changes: 59 additions & 20 deletions flashinfer/decode.py
@@ -63,6 +63,7 @@
_get_range_buf,
_unpack_paged_kv_cache,
canonicalize_torch_dtype,
+ determine_attention_backend,
device_support_pdl,
get_device_sm_count,
is_float8,
@@ -721,7 +722,7 @@ def __init__(
self._jit_module = get_batch_prefill_jit_module(
jit_args[0],
gen_customize_batch_prefill_module(
"fa2", *jit_args
backend, *jit_args
).build_and_load(),
)
else:
@@ -834,6 +835,7 @@ def plan(
logits_soft_cap: Optional[float] = None,
q_data_type: Optional[Union[str, torch.dtype]] = "float16",
kv_data_type: Optional[Union[str, torch.dtype]] = None,
+ o_data_type: Optional[Union[str, torch.dtype]] = None,
data_type: Optional[Union[str, torch.dtype]] = None,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
@@ -881,6 +883,9 @@ def plan(
kv_data_type : Optional[Union[str, torch.dtype]]
The data type of the key/value tensor. If None, will be set to
``q_data_type``. Defaults to ``None``.
+ o_data_type : Optional[Union[str, torch.dtype]]
+ The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+ For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
data_type: Optional[Union[str, torch.dtype]]
The data type of both the query and key/value tensors. Defaults to torch.float16.
data_type is deprecated, please use q_data_type and kv_data_type instead.
@@ -966,6 +971,10 @@ def plan(
if kv_data_type is None:
kv_data_type = q_data_type
kv_data_type = canonicalize_torch_dtype(kv_data_type)
+ if o_data_type is None:
+ o_data_type = q_data_type
+ o_data_type = canonicalize_torch_dtype(o_data_type)
+
if fixed_split_size is not None and not self.use_tensor_cores:
raise ValueError(
"fixed_split_size is only supported by tensor core decode for now."
@@ -975,6 +984,7 @@

self._cached_q_data_type = q_data_type
self._cached_kv_data_type = kv_data_type
+ self._cached_o_data_type = o_data_type
self._batch_size = batch_size
self._num_qo_heads = num_qo_heads
self._num_kv_heads = num_kv_heads
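Taken together, the hunks above add an `o_data_type` knob to `plan()` so FP8 decode can return a higher-precision output. A minimal usage sketch against `BatchDecodeWithPagedKVCacheWrapper`; the tensor sizes and page layout are illustrative assumptions, not part of this PR:

```python
import torch
import flashinfer

# Illustrative sizes (assumptions for the sketch).
batch_size, num_qo_heads, num_kv_heads = 4, 32, 8
head_dim, page_size, pages_per_seq = 128, 16, 4

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace, "NHD", use_tensor_cores=True
)

kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_seq
kv_indices = torch.arange(batch_size * pages_per_seq, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

# FP8 query/KV with an fp16 output; without o_data_type the output
# dtype would fall back to q_data_type (i.e. FP8).
wrapper.plan(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    q_data_type=torch.float8_e4m3fn,
    kv_data_type=torch.float8_e4m3fn,
    o_data_type=torch.float16,
)
```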
@@ -1014,7 +1024,7 @@ def plan(
self._cached_module = get_trtllm_gen_decode_module(
q_data_type,
kv_data_type,
- q_data_type,
+ o_data_type,
indptr.dtype,
head_dim,
head_dim,
@@ -1029,11 +1039,20 @@
if self._jit_module is not None:
self._cached_module = self._jit_module
else:
if self._backend == "auto":
self._backend = determine_attention_backend(
self.device,
PosEncodingMode[pos_encoding_mode].value,
False, # use_fp16_qk_reduction
False, # use_custom_mask
q_data_type,
kv_data_type,
)
self._cached_module = get_batch_prefill_module(
"fa2",
self._backend,
q_data_type,
kv_data_type,
- q_data_type,
+ o_data_type,
indptr.dtype,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
@@ ... @@
False, # use_fp16_qk_reduction
)

- self._plan_info = self._cached_module.plan(
+ args = [
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
@@ ... @@
head_dim,
False, # causal
window_left,
- fixed_split_size,
- disable_split_kv,
- 0, # num_colocated_ctas
+ ]
+ if self._backend == "fa2":
+ args.append(fixed_split_size)
+ args.append(disable_split_kv)
+ args.append(0) # num_colocated_ctas
+ self._plan_info = self._cached_module.plan(
+ *args,
)
else:
if self._jit_module is not None:
@@ ... @@
self._cached_module = get_batch_decode_module(
q_data_type,
kv_data_type,
- q_data_type,
+ o_data_type,
indptr.dtype,
head_dim, # head_dim_qk
head_dim, # head_dim_vo
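When the wrapper was constructed with `backend="auto"`, `plan()` now resolves a concrete backend via `determine_attention_backend` before fetching the prefill module, instead of hard-coding `"fa2"`. A sketch of the equivalent standalone call; the import path and the Hopper/fa3 expectation are assumptions inferred from the call site above:

```python
import torch
from flashinfer.utils import PosEncodingMode, determine_attention_backend

# The same six inputs the plan() call site passes decide fa2 vs fa3.
backend = determine_attention_backend(
    torch.device("cuda:0"),
    PosEncodingMode["NONE"].value,
    False,  # use_fp16_qk_reduction
    False,  # use_custom_mask
    torch.float16,  # q_data_type
    torch.float16,  # kv_data_type
)
print(backend)  # presumably "fa3" on SM90 GPUs, "fa2" otherwise
```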
@@ -1281,9 +1304,13 @@ def run(
)

if out is None:
- out = torch.empty_like(q)
+ out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+ out = torch.empty(
+ q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
+ )
else:
- check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
+ out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+ check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")

if self._backend == "trtllm-gen":
q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
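The allocation path above now sizes `out` from v's head dimension and the dtype cached at plan time rather than mirroring `q` exactly. A standalone sketch of the rule with illustrative shapes (CPU tensors so it runs anywhere; the shape logic is device-agnostic):

```python
import torch

# Illustrative shapes: head_dim_qk == head_dim_vo == 128 here,
# but only v_cache's last dim feeds the output width.
q = torch.randn(8, 32, 128, dtype=torch.float16)       # [tokens, qo_heads, head_dim_qk]
v_cache = torch.randn(64, 16, 8, 128, dtype=torch.float16)  # [..., head_dim_vo]

cached_o_dtype = torch.bfloat16  # what plan(..., o_data_type=...) stored; None falls back to q.dtype
out_dtype = cached_o_dtype or q.dtype
out = torch.empty(q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device)
print(out.shape, out.dtype)  # torch.Size([8, 32, 128]) torch.bfloat16
```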
@@ -1311,6 +1338,14 @@
if self._jit_module is not None:
run_args.extend(list(args))
else:
+ # Extract FP8 scale tensors from *args if q is FP8
+ fp8_scale_q = None
+ fp8_scale_k = None
+ fp8_scale_v = None
+ if is_float8(q) and len(args) >= 3:
+ fp8_scale_q = args[0]
+ fp8_scale_k = args[1]
+ fp8_scale_v = args[2]
run_args += [
None, # packed_custom_mask
None, # mask_indptr_buf
@@ ... @@
None, # maybe_max_item_len_ptr
logits_soft_cap,
sm_scale,
- None, # scale_q, not supported yet
- None, # scale_k
- None, # scale_v
+ fp8_scale_q,
+ fp8_scale_k,
+ fp8_scale_v,
rope_scale,
rope_theta,
0, # token_pos_in_items_len
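With an FP8 query, `run()` now forwards the first three extra positional arguments as dequantization scales instead of passing `None` placeholders. A hedged sketch continuing the decode setup from the earlier `plan()` example; scalar fp32 scale tensors are an assumption about the expected layout:

```python
# Continues the earlier sketch: wrapper was planned with FP8 q/kv dtypes.
q_fp8 = torch.randn(
    batch_size, num_qo_heads, head_dim, device="cuda", dtype=torch.float16
).to(torch.float8_e4m3fn)
paged_kv_cache = torch.randn(
    batch_size * pages_per_seq, 2, page_size, num_kv_heads, head_dim,
    device="cuda", dtype=torch.float16,
).to(torch.float8_e4m3fn)

scale_q = torch.tensor([1.0], dtype=torch.float32, device="cuda")
scale_k = torch.tensor([1.0], dtype=torch.float32, device="cuda")
scale_v = torch.tensor([1.0], dtype=torch.float32, device="cuda")

# The three scales ride in through *args; they are only consumed when q is FP8.
out = wrapper.run(q_fp8, paged_kv_cache, scale_q, scale_k, scale_v)
print(out.dtype)  # torch.float16, per o_data_type from plan()
```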
@@ -1375,7 +1410,7 @@
]

self._cached_module.run(*run_args)
- if v_scale is not None:
+ if v_scale is not None and v_scale != 1.0:
# TODO(Zihao): fused into kernel
if is_float8(out):
out = (out.to(torch.float32) * v_scale).to(out.dtype)
@@ -2597,8 +2632,8 @@ def fast_decode_plan(
kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)

try:
- # Make sure we pass exactly 16 arguments for tensor core version
- self._plan_info = self._cached_module.plan(
+ # Make sure we pass exactly 19 arguments for the fa2 backend and 16 for the fa3 backend
+ args = [
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
@@ ... @@
head_dim,
False, # causal
window_left,
- fixed_split_size,
- disable_split_kv,
- 0, # num_colocated_ctas
+ ]
+ if self._backend == "fa2":
+ args.append(fixed_split_size)
+ args.append(disable_split_kv)
+ args.append(0) # num_colocated_ctas
+ self._plan_info = self._cached_module.plan(
+ *args,
)
except Exception as e:
raise RuntimeError(f"Error in standard plan: {e}") from e
7 changes: 7 additions & 0 deletions flashinfer/jit/attention/modules.py
@@ -984,6 +984,13 @@ def gen_batch_prefill_module(
# KV-only quant is not influenced by this flag
fp8_enabled = dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2]

assert backend in ["fa2", "fa3"], (
f"backend must be fa2 or fa3 in gen_batch_prefill_module(), got: {backend}"
)
assert dtype_o not in [torch.float8_e4m3fn, torch.float8_e5m2], (
"FP8 output is not supported in fa2/fa3 backends yet"
)

if backend == "fa2":
assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend"
additional_tensor_names = [
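The new assertions make `gen_batch_prefill_module` fail fast on unsupported combinations instead of surfacing a confusing error deep in JIT codegen. A sketch of the FP8-output guard firing; the import path and positional argument order are assumptions mirroring the `get_batch_prefill_module` call in decode.py above:

```python
import torch
from flashinfer.jit.attention import gen_batch_prefill_module

try:
    gen_batch_prefill_module(
        "fa2",
        torch.float16,        # dtype_q
        torch.float16,        # dtype_kv
        torch.float8_e4m3fn,  # dtype_o: rejected by the new assert
        torch.int32,          # index dtype
        128,                  # head_dim_qk
        128,                  # head_dim_vo
        0,                    # pos_encoding_mode (NONE)
        False,                # use_sliding_window
        False,                # use_logits_soft_cap
        False,                # use_fp16_qk_reduction
    )
except AssertionError as err:
    print(err)  # "FP8 output is not supported in fa2/fa3 backends yet"
```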
13 changes: 10 additions & 3 deletions flashinfer/prefill.py
@@ -1701,7 +1701,7 @@ def plan(
The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
o_data_type : Optional[Union[str, torch.dtype]]
The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
- For FP8 inputs, this should typically be set to torch.float16.
+ For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
non_blocking : bool
Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
prefix_len_ptr : Optional[torch.Tensor]
@@ -2077,6 +2077,8 @@ def run(

*args
Additional arguments for custom kernels.
+ q_scale : Optional[float]
+ The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
k_scale : Optional[float]
The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
v_scale : Optional[float]
@@ -2105,6 +2107,11 @@
_check_cached_qkv_data_type(
q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
)
+ o_dtype = self._cached_o_data_type
+ if out is not None and out.dtype != o_dtype:
+ raise ValueError(
+ f"The dtype of out {out.dtype} does not match the o_data_type {o_dtype} specified in the plan function."
+ )

if self._kv_layout == "NHD":
page_size = k_cache.shape[1]
Expand Down Expand Up @@ -2258,7 +2265,7 @@ def run(

assert self._cached_module is not None, "cached module is not initialized"
self._cached_module.paged_run(*run_args)
- if v_scale is not None:
+ if v_scale is not None and v_scale != 1.0:
# TODO(Zihao): fused into kernel
if is_float8(out):
out = (out.to(torch.float32) * v_scale).to(out.dtype)
@@ -2646,7 +2653,7 @@ def plan(
The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
o_data_type : Optional[Union[str, torch.dtype]]
The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
- For FP8 inputs, this should typically be set to torch.float16.
+ For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
non_blocking : bool
Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
prefix_len_ptr : Optional[torch.Tensor]