
Commit ec4797a

Enable FlashInfer Hopper FP8 attention
Signed-off-by: Po-Han Huang <[email protected]>
1 parent be493e0 commit ec4797a

1 file changed

vllm/v1/attention/backends/flashinfer.py

Lines changed: 20 additions & 26 deletions
@@ -504,7 +504,8 @@ def __init__(
         if can_use_trtllm and not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
-            self.q_data_type = self.model_config.dtype
+            # self.q_data_type = self.model_config.dtype
+            self.q_data_type = self.kv_cache_dtype

         # Prefer TRTLLM attention for decoding in all cases.
         # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
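The effect of this hunk is that the query dtype now follows the KV-cache dtype even when the TRTLLM kernels are not selected, instead of falling back to the model dtype. A minimal sketch of the selection logic, using stand-in arguments rather than the builder's real attributes:

    import torch

    # Minimal sketch (not the vLLM code): how the query dtype is now chosen.
    # `kv_cache_dtype`, `model_dtype`, and `q_quant_disabled` stand in for the
    # builder attributes referenced in the hunk above.
    def pick_q_dtype(kv_cache_dtype: torch.dtype,
                     model_dtype: torch.dtype,
                     q_quant_disabled: bool) -> torch.dtype:
        if q_quant_disabled:
            # Query stays in the model's activation dtype (e.g. bfloat16).
            return model_dtype
        # Otherwise the query is quantized to the KV-cache dtype, e.g. FP8,
        # on both the TRTLLM path and the plain FlashInfer path.
        return kv_cache_dtype

    assert pick_q_dtype(torch.float8_e4m3fn, torch.bfloat16, False) is torch.float8_e4m3fn
    assert pick_q_dtype(torch.float8_e4m3fn, torch.bfloat16, True) is torch.bfloat16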
@@ -803,7 +804,7 @@ def build(

         # The q quantization is not supported for non-trtllm attention,
         # fall back to model dtype.
-        self.q_data_type = self.model_config.dtype
+        # self.q_data_type = self.model_config.dtype

         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -917,6 +918,7 @@ def build(
             logits_soft_cap=self.logits_soft_cap,
             q_data_type=self.q_data_type,
             kv_data_type=self.kv_cache_dtype,
+            o_data_type=self.model_config.dtype,
             fixed_split_size=self.prefill_fixed_split_size,
             disable_split_kv=self.disable_split_kv,
         )
@@ -961,6 +963,7 @@ def build(
             logits_soft_cap=self.logits_soft_cap,
             q_data_type=self.q_data_type,
             kv_data_type=self.kv_cache_dtype,
+            o_data_type=self.model_config.dtype,
             fixed_split_size=self.decode_fixed_split_size,
             disable_split_kv=self.disable_split_kv,
         )
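The new o_data_type field makes the output dtype explicit now that the query and KV cache may both be FP8: the attention output is still produced in the model's activation dtype. A hedged illustration of the dtype split these two call sites set up (values are examples, not taken from a real config):

    import torch

    # Example dtype configuration mirroring the fields shown above
    # (illustrative values only).
    q_data_type = torch.float8_e4m3fn   # query quantized to the KV-cache dtype
    kv_data_type = torch.float8_e4m3fn  # FP8 KV cache
    o_data_type = torch.bfloat16        # model_config.dtype: output is not FP8

    # The tensor handed back to the rest of the model is allocated in the
    # output dtype, so no extra cast is needed after the kernel runs.
    out = torch.empty(4, 32, 128, dtype=o_data_type)
    assert out.dtype is o_data_type and out.dtype is not q_data_type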
@@ -1047,7 +1050,7 @@ def supports_quant_query_input(self) -> bool:
         if flashinfer_disable_q_quantization():
             return False

-        return self.support_trtllm_attn
+        return self.kv_cache_dtype.startswith("fp8")

     # FlashInfer requires attention sinks to be float32
     def process_weights_after_loading(self, act_dtype: torch.dtype):
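supports_quant_query_input is no longer tied to the TRTLLM attention path; it now reports True whenever the KV cache uses an FP8 format, which is what enables FP8 query quantization on Hopper. A standalone sketch of the new predicate (the cache-dtype strings are examples):

    def supports_quant_query_input(kv_cache_dtype: str, q_quant_disabled: bool) -> bool:
        # Sketch of the predicate after this change: query quantization is
        # offered for any FP8 KV-cache layout, independent of the kernel path.
        if q_quant_disabled:
            return False
        return kv_cache_dtype.startswith("fp8")

    assert supports_quant_query_input("fp8_e4m3", q_quant_disabled=False)
    assert not supports_quant_query_input("auto", q_quant_disabled=False)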
@@ -1235,6 +1238,7 @@ def forward(
             prefill_wrapper.run(
                 prefill_query,
                 kv_cache_permute,
+                q_scale=layer._q_scale_float,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
@@ -1328,6 +1332,7 @@ def forward(
             decode_wrapper.run(
                 decode_query,
                 kv_cache_permute,
+                q_scale=layer._q_scale_float,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output_tmp,
@@ -1341,6 +1346,7 @@ def forward(
             decode_wrapper.run(
                 decode_query,
                 kv_cache_permute,
+                q_scale=layer._q_scale_float,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[:num_decode_tokens],
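All three wrapper.run call sites now forward layer._q_scale_float alongside the K/V scales, so the kernel can fold the query's dequantization scale into its softmax scaling. A hedged sketch of the per-tensor FP8 query quantization that produces such a scale (illustrative only; vLLM computes and applies these scales elsewhere):

    import torch

    def quantize_q_fp8(q: torch.Tensor, q_scale: float) -> torch.Tensor:
        # Per-tensor FP8 quantization sketch: store q / q_scale in e4m3; the
        # attention kernel later multiplies q_scale (and k_scale) back in when
        # forming q @ k^T, e.g. by adjusting the softmax scale.
        return (q / q_scale).to(torch.float8_e4m3fn)

    q = torch.randn(4, 8, 128, dtype=torch.bfloat16)
    q_scale = q.abs().amax().item() / 448.0  # 448 ~= max normal value of e4m3
    q_fp8 = quantize_q_fp8(q, q_scale)
    # De-scaling approximately recovers the original query values.
    err = (q_fp8.to(torch.bfloat16) * q_scale - q).abs().max()
    assert err < 0.5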
@@ -1414,6 +1420,7 @@ def fast_plan_decode(
     logits_soft_cap: float | None = None,
     q_data_type: str | torch.dtype | None = "float16",
     kv_data_type: str | torch.dtype | None = None,
+    o_data_type: str | torch.dtype | None = None,
     data_type: str | torch.dtype | None = None,
     sm_scale: float | None = None,
     rope_scale: float | None = None,
@@ -1452,6 +1459,7 @@ def fast_plan_decode(
             logits_soft_cap,
             q_data_type,
             kv_data_type,
+            o_data_type,
             data_type,
             sm_scale,
             rope_scale,
@@ -1471,24 +1479,6 @@ def fast_plan_decode(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0

-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-    q_data_type = (
-        getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-    )
-    kv_data_type = (
-        getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
-    )
-
     if batch_size != self._fixed_batch_size:
         raise ValueError(
             "The batch size should be fixed in cudagraph mode, the runtime "
@@ -1508,8 +1498,8 @@ def fast_plan_decode(
     qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

     try:
-        # Make sure we pass exactly 19 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Make sure we pass exactly 19 arguments for fa2 backend and 16 arguments for fa3 backend
+        args = [
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -1526,9 +1516,13 @@ def fast_plan_decode(
             head_dim,
             False,  # causal
             window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
         )
     except Exception as e:
         raise RuntimeError(f"Error in tensor core plan: {e}") from e
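The plan call now assembles its argument list incrementally so that the fa2 backend receives its three extra trailing arguments while fa3 gets only the common ones (19 vs. 16 arguments, per the updated comment). A standalone sketch of that pattern (argument names follow the diff; the actual plan signature is FlashInfer-internal):

    def build_plan_args(backend: str, common_args: list,
                        fixed_split_size: int, disable_split_kv: bool) -> list:
        # Sketch: fa2 takes the 16 common arguments plus the three below,
        # fa3 takes only the common 16.
        args = list(common_args)
        if backend == "fa2":
            args.append(fixed_split_size)
            args.append(disable_split_kv)
            args.append(0)  # num_colocated_ctas
        return args

    assert len(build_plan_args("fa2", [None] * 16, 0, False)) == 19
    assert len(build_plan_args("fa3", [None] * 16, 0, False)) == 16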
