Commit 09a1ece

Enable Hopper FA3 FP8 attention
Signed-off-by: Po-Han Huang <[email protected]>
1 parent dc0ade7

3 files changed, +76 -23 lines

flashinfer/decode.py

Lines changed: 59 additions & 20 deletions
@@ -63,6 +63,7 @@
     _get_range_buf,
     _unpack_paged_kv_cache,
     canonicalize_torch_dtype,
+    determine_attention_backend,
     device_support_pdl,
     get_device_sm_count,
     is_float8,
@@ -721,7 +722,7 @@ def __init__(
                 self._jit_module = get_batch_prefill_jit_module(
                     jit_args[0],
                     gen_customize_batch_prefill_module(
-                        "fa2", *jit_args
+                        backend, *jit_args
                     ).build_and_load(),
                 )
             else:
@@ -834,6 +835,7 @@ def plan(
         logits_soft_cap: Optional[float] = None,
         q_data_type: Optional[Union[str, torch.dtype]] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         data_type: Optional[Union[str, torch.dtype]] = None,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
@@ -881,6 +883,9 @@ def plan(
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to
             ``q_data_type``. Defaults to ``None``.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         data_type: Optional[Union[str, torch.dtype]]
             The data type of both the query and key/value tensors. Defaults to torch.float16.
             data_type is deprecated, please use q_data_type and kv_data_type instead.
@@ -966,6 +971,10 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
+
         if fixed_split_size is not None and not self.use_tensor_cores:
             raise ValueError(
                 "fixed_split_size is only supported by tensor core decode for now."
@@ -975,6 +984,7 @@ def plan(

         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
         self._batch_size = batch_size
         self._num_qo_heads = num_qo_heads
         self._num_kv_heads = num_kv_heads
@@ -1014,7 +1024,7 @@ def plan(
             self._cached_module = get_trtllm_gen_decode_module(
                 q_data_type,
                 kv_data_type,
-                q_data_type,
+                o_data_type,
                 indptr.dtype,
                 head_dim,
                 head_dim,
@@ -1029,11 +1039,20 @@ def plan(
             if self._jit_module is not None:
                 self._cached_module = self._jit_module
             else:
+                if self._backend == "auto":
+                    self._backend = determine_attention_backend(
+                        self.device,
+                        PosEncodingMode[pos_encoding_mode].value,
+                        False,  # use_fp16_qk_reduction
+                        False,  # use_custom_mask
+                        q_data_type,
+                        kv_data_type,
+                    )
                 self._cached_module = get_batch_prefill_module(
-                    "fa2",
+                    self._backend,
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1043,7 +1062,7 @@ def plan(
                     False,  # use_fp16_qk_reduction
                 )

-            self._plan_info = self._cached_module.plan(
+            args = [
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
@@ -1060,9 +1079,13 @@ def plan(
                 head_dim,
                 False,  # causal
                 window_left,
-                fixed_split_size,
-                disable_split_kv,
-                0,  # num_colocated_ctas
+            ]
+            if self._backend == "fa2":
+                args.append(fixed_split_size)
+                args.append(disable_split_kv)
+                args.append(0)  # num_colocated_ctas
+            self._plan_info = self._cached_module.plan(
+                *args,
             )
         else:
             if self._jit_module is not None:
@@ -1071,7 +1094,7 @@ def plan(
                 self._cached_module = get_batch_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim, # head_dim_qk
                     head_dim, # head_dim_vo
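
The plan() hunks above add an o_data_type argument and let the tensor-core decode path auto-select between the fa2 and fa3 prefill backends. Below is a minimal usage sketch, not code from this commit: it assumes a Hopper (SM90) GPU, a flashinfer build that includes this change, and made-up paged-KV metadata; the positional plan() arguments follow the current public decode API.

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace, "NHD", use_tensor_cores=True  # tensor-core decode reuses the prefill modules above
)

# Illustrative paged-KV metadata: batch of 4 requests, 2 full pages each.
kv_indptr = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(8, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((4,), page_size, dtype=torch.int32, device="cuda")

decode.plan(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    q_data_type=torch.float8_e4m3fn,   # FP8 query
    kv_data_type=torch.float8_e4m3fn,  # FP8 KV cache
    o_data_type=torch.bfloat16,        # new: 16-bit output dtype for FP8 inputs
)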
@@ -1281,9 +1304,13 @@ def run(
             )

         if out is None:
-            out = torch.empty_like(q)
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            out = torch.empty(
+                q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
+            )
         else:
-            check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")

         if self._backend == "trtllm-gen":
             q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
@@ -1311,6 +1338,14 @@ def run(
         if self._jit_module is not None:
             run_args.extend(list(args))
         else:
+            # Extract FP8 scale tensors from *args if q is FP8
+            fp8_scale_q = None
+            fp8_scale_k = None
+            fp8_scale_v = None
+            if is_float8(q) and len(args) >= 3:
+                fp8_scale_q = args[0]
+                fp8_scale_k = args[1]
+                fp8_scale_v = args[2]
             run_args += [
                 None,  # packed_custom_mask
                 None,  # mask_indptr_buf
@@ -1320,9 +1355,9 @@ def run(
                 None,  # maybe_max_item_len_ptr
                 logits_soft_cap,
                 sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
+                fp8_scale_q,
+                fp8_scale_k,
+                fp8_scale_v,
                 rope_scale,
                 rope_theta,
                 0,  # token_pos_in_items_len
@@ -1375,7 +1410,7 @@ def run(
             ]

         self._cached_module.run(*run_args)
-        if v_scale is not None:
+        if v_scale is not None and v_scale != 1.0:
             # TODO(Zihao): fused into kernel
             if is_float8(out):
                 out = (out.to(torch.float32) * v_scale).to(out.dtype)
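
With the run() hunks above, the output is allocated with the o_data_type cached at plan time, and when q is FP8 up to three extra positional tensors from *args are forwarded as the scale_q / scale_k / scale_v kernel arguments. Continuing the decode sketch from the plan() section (same wrapper and shapes; scale values and tensor contents are placeholders):

q_fp8 = torch.randn(4, 32, 128, device="cuda").to(torch.float8_e4m3fn)               # (batch, H_qo, D)
kv_cache_fp8 = torch.randn(8, 2, 16, 8, 128, device="cuda").to(torch.float8_e4m3fn)  # (pages, 2, page_size, H_kv, D)

# Per-tensor dequantization scales; run() picks these up from *args when q is FP8.
scale_q = torch.ones(1, dtype=torch.float32, device="cuda")
scale_k = torch.ones(1, dtype=torch.float32, device="cuda")
scale_v = torch.ones(1, dtype=torch.float32, device="cuda")

out = decode.run(q_fp8, kv_cache_fp8, scale_q, scale_k, scale_v)
assert out.dtype == torch.bfloat16  # allocated with the o_data_type cached by plan()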
@@ -2597,8 +2632,8 @@ def fast_decode_plan(
     kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)

     try:
-        # Make sure we pass exactly 16 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Make sure we pass exactly 19 arguments for fa2 backend and 16 arguments for fa3 backend
+        args = [
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -2615,9 +2650,13 @@ def fast_decode_plan(
             head_dim,
             False,  # causal
             window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,  # num_colocated_ctas
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
         )
     except Exception as e:
         raise RuntimeError(f"Error in standard plan: {e}") from e
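
Both plan() (tensor-core path) and fast_decode_plan() now build the plan arguments as a list and append the fa2-only trailing arguments conditionally, since the fa2 plan entry point takes 19 arguments while the fa3 one takes 16. A standalone sketch of that pattern, with placeholder names:

def call_backend_plan(cached_module, backend, common_args, fixed_split_size, disable_split_kv):
    # common_args: the 16 plan arguments shared by the fa2 and fa3 backends
    args = list(common_args)
    if backend == "fa2":
        args.append(fixed_split_size)   # fa2-only: fixed split-KV size
        args.append(disable_split_kv)   # fa2-only: disable split-KV
        args.append(0)                  # fa2-only: num_colocated_ctas
    return cached_module.plan(*args)    # 19 arguments for fa2, 16 for fa3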

flashinfer/jit/attention/modules.py

Lines changed: 7 additions & 0 deletions
@@ -984,6 +984,13 @@ def gen_batch_prefill_module(
     # KV-only quant is not influenced by this flag
     fp8_enabled = dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2]

+    assert backend in ["fa2", "fa3"], (
+        f"backend must be fa2 or fa3 in gen_batch_prefill_module(), got: {backend}"
+    )
+    assert dtype_o not in [torch.float8_e4m3fn, torch.float8_e5m2], (
+        "FP8 output is not supported in fa2/fa3 backends yet"
+    )
+
     if backend == "fa2":
         assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend"
         additional_tensor_names = [
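
A standalone sketch of the validation added above; this is not the library function itself (gen_batch_prefill_module() takes many more parameters), just the dtype/backend rules it enforces: FP8 inputs require the fa3 backend, and FP8 output is still rejected.

import torch

FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def check_prefill_module_dtypes(backend: str, dtype_q: torch.dtype, dtype_o: torch.dtype) -> None:
    # Mirrors the asserts in gen_batch_prefill_module() above.
    assert backend in ("fa2", "fa3"), f"backend must be fa2 or fa3, got: {backend}"
    assert dtype_o not in FP8_DTYPES, "FP8 output is not supported in fa2/fa3 backends yet"
    if backend == "fa2":
        assert dtype_q not in FP8_DTYPES, "fp8 tensor core is not supported in fa2 backend"

check_prefill_module_dtypes("fa3", torch.float8_e4m3fn, torch.bfloat16)  # OK: Hopper FA3 FP8 path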

flashinfer/prefill.py

Lines changed: 10 additions & 3 deletions
@@ -1701,7 +1701,7 @@ def plan(
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
         o_data_type : Optional[Union[str, torch.dtype]]
             The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
-            For FP8 inputs, this should typically be set to torch.float16.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]
@@ -2077,6 +2077,8 @@ def run(

         *args
             Additional arguments for custom kernels.
+        q_scale : Optional[float]
+            The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
         k_scale : Optional[float]
             The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
         v_scale : Optional[float]
@@ -2105,6 +2107,11 @@ def run(
         _check_cached_qkv_data_type(
             q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
         )
+        o_dtype = self._cached_o_data_type
+        if out is not None and out.dtype != o_dtype:
+            raise ValueError(
+                f"The dtype of out {out.dtype} does not match the o_data_type {o_dtype} specified in plan function."
+            )

         if self._kv_layout == "NHD":
             page_size = k_cache.shape[1]
@@ -2258,7 +2265,7 @@ def run(

         assert self._cached_module is not None, "cached module is not initialized"
         self._cached_module.paged_run(*run_args)
-        if v_scale is not None:
+        if v_scale is not None and v_scale != 1.0:
             # TODO(Zihao): fused into kernel
             if is_float8(out):
                 out = (out.to(torch.float32) * v_scale).to(out.dtype)
@@ -2646,7 +2653,7 @@ def plan(
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
         o_data_type : Optional[Union[str, torch.dtype]]
             The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
-            For FP8 inputs, this should typically be set to torch.float16.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]
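
For the prefill wrapper, q_scale joins k_scale and v_scale in run(), and run() now rejects an out tensor whose dtype differs from the o_data_type passed to plan(). A minimal FP8 prefill sketch under the same assumptions as the decode example (Hopper GPU, a build with these changes, illustrative metadata and scale values; keyword names follow the docstrings above):

import torch
import flashinfer

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
prefill = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD", backend="fa3")

# One request of 128 query tokens attending to 8 full pages of 16 tokens each.
qo_indptr = torch.tensor([0, 128], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 8], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(8, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([16], dtype=torch.int32, device="cuda")

prefill.plan(
    qo_indptr,
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    32,    # num_qo_heads
    8,     # num_kv_heads
    128,   # head_dim
    16,    # page_size
    causal=True,
    q_data_type=torch.float8_e4m3fn,
    kv_data_type=torch.float8_e4m3fn,
    o_data_type=torch.bfloat16,  # run() checks out.dtype against this when out is provided
)

q_fp8 = torch.randn(128, 32, 128, device="cuda").to(torch.float8_e4m3fn)
kv_cache_fp8 = torch.randn(8, 2, 16, 8, 128, device="cuda").to(torch.float8_e4m3fn)
out = prefill.run(q_fp8, kv_cache_fp8, q_scale=1.0, k_scale=1.0, v_scale=1.0)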
