Commit bc47808

Enable FlashInfer Hopper FP8 attention
Signed-off-by: Po-Han Huang <[email protected]>
1 parent 446ee64 commit bc47808
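
In short: on Hopper, the FlashInfer backend now keeps the query in the FP8 KV-cache dtype instead of falling back to the model dtype, threads the matching q_scale into the prefill/decode wrapper run() calls, and passes an explicit o_data_type so the output stays in the model dtype. A minimal, illustrative sketch of the dtype selection (select_q_dtype is a hypothetical helper, not vLLM code; it assumes "fp8" maps to torch.float8_e4m3fn):

import torch

def select_q_dtype(kv_cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    # Hypothetical helper: with this change, an FP8 KV cache also implies an
    # FP8 query dtype; otherwise the query stays in the model dtype.
    if kv_cache_dtype.startswith("fp8"):
        return torch.float8_e4m3fn  # assumed concrete fp8 format
    return model_dtype

print(select_q_dtype("fp8", torch.bfloat16))   # torch.float8_e4m3fn
print(select_q_dtype("auto", torch.bfloat16))  # torch.bfloat16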


vllm/v1/attention/backends/flashinfer.py

Lines changed: 21 additions & 26 deletions
@@ -514,7 +514,8 @@ def __init__(
         ):
             self.q_data_type = self.kv_cache_dtype
         else:
-            self.q_data_type = self.model_config.dtype
+            # self.q_data_type = self.model_config.dtype
+            self.q_data_type = self.kv_cache_dtype
 
         # Prefer TRTLLM attention for decoding in all cases.
         # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
@@ -814,7 +815,7 @@ def build(
 
             # The q quantization is not supported for non-trtllm attention,
             # fall back to model dtype.
-            self.q_data_type = self.model_config.dtype
+            # self.q_data_type = self.model_config.dtype
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -928,6 +929,7 @@ def build(
                     logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
+                    o_data_type=self.model_config.dtype,
                     fixed_split_size=self.prefill_fixed_split_size,
                     disable_split_kv=self.disable_split_kv,
                 )
@@ -972,6 +974,7 @@ def build(
                     logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
+                    o_data_type=self.model_config.dtype,
                     fixed_split_size=self.decode_fixed_split_size,
                     disable_split_kv=self.disable_split_kv,
                 )
@@ -1045,7 +1048,7 @@ def __init__(
         self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
         vllm_config = get_current_vllm_config()
         self.supports_quant_query_input = (
-            self.support_trtllm_attn
+            self.kv_cache_dtype.startswith("fp8")
             and not vllm_config.attention_config.disable_flashinfer_q_quantization
         )
         self.bmm1_scale: float | None = None
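
Note on the hunk above: quantized query input is now gated on the KV cache being FP8 rather than on TRT-LLM attention support. A standalone restatement of the condition (hypothetical helper, not the vLLM API):

def supports_quant_query_input(kv_cache_dtype: str, disable_q_quant: bool) -> bool:
    # Hypothetical restatement of the new gate: enable query quantization
    # whenever the KV cache is FP8 and it has not been explicitly disabled.
    return kv_cache_dtype.startswith("fp8") and not disable_q_quant

assert supports_quant_query_input("fp8_e4m3", disable_q_quant=False)
assert not supports_quant_query_input("auto", disable_q_quant=False)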
@@ -1245,6 +1248,7 @@ def forward(
                 prefill_wrapper.run(
                     prefill_query,
                     kv_cache_permute,
+                    q_scale=layer._q_scale_float,
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
                     out=output[num_decode_tokens:],
@@ -1338,6 +1342,7 @@ def forward(
                 decode_wrapper.run(
                     decode_query,
                     kv_cache_permute,
+                    q_scale=layer._q_scale_float,
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
                     out=output_tmp,
@@ -1354,6 +1359,7 @@ def forward(
                 decode_wrapper.run(
                     decode_query,
                     kv_cache_permute,
+                    q_scale=layer._q_scale_float,
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
                     out=output[:num_decode_tokens],
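
The q_scale, k_scale, and v_scale values passed to run() in the hunks above are per-tensor dequantization scales for the FP8 inputs. As a rough, self-contained illustration of where such a scale comes from (quantize_fp8 is an assumed per-tensor e4m3 scheme, not a vLLM or FlashInfer function):

import torch

def quantize_fp8(x: torch.Tensor) -> tuple[torch.Tensor, float]:
    # Assumed per-tensor scheme: map the max abs value onto the e4m3 range;
    # the returned float scale is the kind of value fed to run(q_scale=...).
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().max().item() / finfo.max
    return (x / scale).to(torch.float8_e4m3fn), scale

q = torch.randn(8, 128, dtype=torch.float32)
q_fp8, q_scale = quantize_fp8(q)
# The attention kernel would dequantize internally, roughly q_fp8.float() * q_scale.
print(q_fp8.dtype, q_scale)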
@@ -1427,6 +1433,7 @@ def fast_plan_decode(
     logits_soft_cap: float | None = None,
     q_data_type: str | torch.dtype | None = "float16",
     kv_data_type: str | torch.dtype | None = None,
+    o_data_type: str | torch.dtype | None = None,
     data_type: str | torch.dtype | None = None,
     sm_scale: float | None = None,
     rope_scale: float | None = None,
@@ -1465,6 +1472,7 @@ def fast_plan_decode(
             logits_soft_cap,
             q_data_type,
             kv_data_type,
+            o_data_type,
             data_type,
             sm_scale,
             rope_scale,
@@ -1484,24 +1492,6 @@ def fast_plan_decode(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
 
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-    q_data_type = (
-        getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-    )
-    kv_data_type = (
-        getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
-    )
-
     if batch_size != self._fixed_batch_size:
         raise ValueError(
             "The batch size should be fixed in cudagraph mode, the runtime "
@@ -1521,8 +1511,9 @@ def fast_plan_decode(
     qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
     try:
-        # Make sure we pass exactly 19 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Make sure we pass exactly 19 arguments for fa2 backend and 16 arguments for
+        # fa3 backend
+        args = [
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -1539,9 +1530,13 @@ def fast_plan_decode(
             head_dim,
             False,  # causal
             window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
         )
     except Exception as e:
         raise RuntimeError(f"Error in tensor core plan: {e}") from e
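
The last two hunks restructure the cached-module plan() call: the three trailing arguments (fixed_split_size, disable_split_kv, and a zero placeholder) are appended only when the backend is fa2, which yields the 19-argument fa2 call and the 16-argument fa3 call named in the new comment. A toy sketch of just that argument-count logic (build_plan_args is illustrative; the real call passes workspace buffers and plan parameters rather than None placeholders):

def build_plan_args(backend: str, base_args: list,
                    fixed_split_size: int, disable_split_kv: bool) -> list:
    # Sketch: the fa2 planner takes three extra trailing arguments that fa3's does not.
    args = list(base_args)
    if backend == "fa2":
        args += [fixed_split_size, disable_split_kv, 0]  # 0: placeholder per the diff
    return args

assert len(build_plan_args("fa2", [None] * 16, 4096, False)) == 19
assert len(build_plan_args("fa3", [None] * 16, 4096, False)) == 16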
