
Commit bfabf86

Enable XQA kernels on Hopper
Signed-off-by: Po-Han Huang <[email protected]>
1 parent 8d18837 commit bfabf86

2 files changed: +344 −202 lines changed
vllm/utils/flashinfer.py

Lines changed: 150 additions & 91 deletions
@@ -12,7 +12,7 @@
 import os
 import shutil
 from collections.abc import Callable
-from typing import Any, NoReturn
+from typing import Any, NoReturn, Tuple
 
 import requests
 import torch
@@ -254,17 +254,82 @@ def has_nvidia_artifactory() -> bool:
 
 
 @functools.cache
-def supports_trtllm_attention() -> bool:
+def is_sm90_supported() -> bool:
+    return current_platform.is_device_capability(90)
+
+
+@functools.cache
+def is_sm100f_supported() -> bool:
+    return any(current_platform.is_device_capability(cap) for cap in [100, 103])
+
+
+@functools.cache
+def check_trtllm_attention_support(
+    is_prefill: bool,
+    num_qo_heads: int | None = None,
+    num_kv_heads: int | None = None,
+    dcp_world_size: int | None = None,
+    kv_cache_dtype: str | None = None,
+    q_data_type: torch.dtype | None = None,
+    has_sinks: bool | None = None,
+    has_spec: bool | None = None,
+) -> Tuple[bool | None, str]:
     """
-    TRTLLM attention is supported if the platform is SM100,
-    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
+    Check whether the provided config and current platform are supported by TRTLLM attention.
+
+    Args:
+        is_prefill: Whether this is the prefill phase.
+        num_qo_heads: Number of query heads.
+        num_kv_heads: Number of key/value heads.
+        dcp_world_size: World size of decode context parallel.
+        kv_cache_dtype: Data type of the key/value cache. Could be "auto".
+        q_data_type: Data type of the query.
+        has_sinks: Whether attention sinks are being used.
+        has_spec: Whether speculative decoding is being used.
+
+    If any arg (except is_prefill) is set to None, the check for that arg is skipped.
+
+    Returns:
+        A tuple of (bool | None, str). If the first element is:
+        - True: TRTLLM attention must be used.
+        - False: TRTLLM attention must not be used.
+        - None: TRTLLM attention can be used.
+        The str is the reason why it must or must not be used; empty if it can be used.
     """
-    # Batch-invariant mode disables TRTLLM attention
+
     if vllm_is_batch_invariant():
-        return False
+        return False, "Batch-invariant mode is enabled."
+
+    if not has_nvidia_artifactory():
+        return False, "NVIDIA artifactory is not accessible."
+
+    if is_sm90_supported():
+        if is_prefill:
+            return False, "SM90 is not supported for prefill."
+        if q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            return False, "xqa does not support FP8-Q."
+    elif is_sm100f_supported():
+        if is_prefill and \
+                kv_cache_dtype is not None and \
+                not kv_cache_dtype.startswith("fp8") and \
+                q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            return False, "trtllm-gen prefill does not support FP8-Q with BF16/FP16-KV."
+    else:
+        return False, "SMs other than 90/100/103 are not supported."
+
+    if dcp_world_size is not None and dcp_world_size > 1:
+        return False, "DCP is not supported due to lack of LSE return support."
+
+    if num_qo_heads is not None and num_kv_heads is not None and num_qo_heads % num_kv_heads != 0:
+        return False, "num_qo_heads must be a multiple of num_kv_heads."
+
+    if has_spec and not is_prefill:
+        return True, "Has speculative decoding in decode phase."
+
+    if has_sinks:
+        return True, "Has attention sinks."
 
-    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    return current_platform.is_device_capability(100) and has_nvidia_artifactory()
+    return None, ""
 
 
 def force_use_trtllm_attention() -> bool | None:
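
Note: the new check_trtllm_attention_support() helper returns a tri-state result rather than a plain bool. The sketch below is illustrative only and not part of this commit; it assumes vLLM is importable as vllm.utils.flashinfer (matching the file path above) on a machine with a supported GPU, and the argument values are made up for the example.

import torch
from vllm.utils.flashinfer import check_trtllm_attention_support

# Hypothetical decode-phase query; values are illustrative only.
must_use, reason = check_trtllm_attention_support(
    is_prefill=False,
    num_qo_heads=32,
    num_kv_heads=8,
    dcp_world_size=1,
    kv_cache_dtype="auto",
    q_data_type=torch.bfloat16,
    has_sinks=False,
    has_spec=False,
)
if must_use is True:
    print(f"TRTLLM attention is required: {reason}")   # e.g. sinks or spec decode
elif must_use is False:
    print(f"TRTLLM attention is unavailable: {reason}")  # e.g. unsupported SM
else:
    print("TRTLLM attention is optional; user preference decides.")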
@@ -281,96 +346,90 @@ def force_use_trtllm_attention() -> bool | None:
     return vllm_config.attention_config.use_trtllm_attention
 
 
-def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
-    """Check if the current configuration supports TRTLLM attention."""
-    if force_use_trtllm_attention() is False:
-        return False
-    has_trtllm = supports_trtllm_attention()
-    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
-
-
 def use_trtllm_attention(
-    num_qo_heads: int,
-    num_kv_heads: int,
-    num_tokens: int,
-    max_seq_len: int,
-    dcp_world_size: int,
-    kv_cache_dtype: str,
-    q_dtype: torch.dtype,
     is_prefill: bool,
-    # None means auto-detection, True means force on, False means force off
-    force_use_trtllm: bool | None = None,
-    has_sinks: bool = False,
-    has_spec: bool = False,
+    num_qo_heads: int | None = None,
+    num_kv_heads: int | None = None,
+    dcp_world_size: int | None = None,
+    kv_cache_dtype: str | None = None,
+    q_data_type: torch.dtype | None = None,
+    has_sinks: bool | None = None,
+    has_spec: bool | None = None,
+    silent: bool = False,
 ) -> bool:
-    """Return `True` if TRTLLM attention is used."""
-
-    # CLI argument is set to 0 - respect it
-    if force_use_trtllm is not None and not force_use_trtllm:
-        return False
-
-    # Decode context parallel is not supported
-    if dcp_world_size > 1:
-        logger.warning_once(
-            "Trtllm does not support returning LSE and as a result "
-            "does not support DCP, reverting to FlashInfer"
-        )
-        return False
-
-    # The platform is not supported
-    if not supports_trtllm_attention():
-        if force_use_trtllm:
-            logger.warning_once(
-                "TRTLLM attention is not supported on this platform, "
-                "but --attention-config.use_trtllm_attention is set to 1"
-            )
-        return False
-
-    # The combination of query and key heads is not supported
-    if num_qo_heads % num_kv_heads != 0:
-        if force_use_trtllm:
-            logger.warning_once(
-                "TRTLLM attention is not supported for this combination of "
-                "query and key heads, but --attention-config.use_trtllm_attention is "
-                "set to 1"
-            )
-        return False
-
-    if has_spec and not is_prefill:
-        # Speculative decoding requires TRTLLM attention for decodes
-        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
-        return True
-
-    # Must use TRTLLM attention if query is FP8 quantized
-    if q_dtype == current_platform.fp8_dtype():
-        logger.info_once("Using TRTLLM attention (query is quantized).")
-        return True
-
-    # If sinks are being used, we must use TRTLLM attention as it's
-    # the only backend that supports them
-    if has_sinks:
-        logger.info_once("Using TRTLLM attention (required for attention sinks).")
-        return True
-
-    if force_use_trtllm is None:
-        # CLI argument not set - use auto-detection
-        if is_prefill:
-            # Prefill auto-detection
-            use_trtllm = kv_cache_dtype == "auto"
-            if use_trtllm:
-                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
-        else:
-            # Decode auto-detection
-            use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
-            if use_trtllm:
-                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
-        return use_trtllm
-
-    # CLI argument is set to 1 - respect it
-    logger.info_once(
-        "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
-    )
-    return True
+    """
+    Decides whether to use TRTLLM attention based on these two functions:
+    - check_trtllm_attention_support(): whether TRTLLM attention must or must not be used.
+    - force_use_trtllm_attention(): whether the user wants to force or disable TRTLLM attention.
+    If the decision does not match the user's preference, a warning is printed.
+
+    Args:
+        is_prefill: Whether this is the prefill phase.
+        num_qo_heads: Number of query heads.
+        num_kv_heads: Number of key/value heads.
+        dcp_world_size: World size of decode context parallel.
+        kv_cache_dtype: Data type of the key/value cache. Could be "auto".
+        q_data_type: Data type of the query.
+        has_sinks: Whether attention sinks are being used.
+        has_spec: Whether speculative decoding is being used.
+        silent: Whether to suppress the warning/info messages.
+
+    If any arg (except is_prefill) is set to None, the check for that arg is skipped.
+
+    Returns: whether to use TRTLLM attention.
+    """
+    supports_trtllm, reason = check_trtllm_attention_support(
+        is_prefill, num_qo_heads, num_kv_heads, dcp_world_size, kv_cache_dtype, q_data_type, has_sinks, has_spec
+    )
+    force_use_trtllm = force_use_trtllm_attention()
+    phase_str = "prefill" if is_prefill else "decode"
+    prefix = "[FlashInfer Attention]"
+
+    # Helper functions to print warning/info if not silent.
+    def print_warning(msg: str):
+        if not silent:
+            logger.warning_once(msg)
+
+    def print_info(msg: str):
+        if not silent:
+            logger.info_once(msg)
+
+    # Follow the user's preference if supports_trtllm is None.
+    if supports_trtllm is None:
+        if force_use_trtllm is True:
+            print_info(
+                f"{prefix} Using TRTLLM for {phase_str} (--attention-config.use_trtllm_attention is set to 1)."
+            )
+            return True
+        elif force_use_trtllm is False:
+            print_info(
+                f"{prefix} Using non-TRTLLM for {phase_str} (--attention-config.use_trtllm_attention is set to 0)."
+            )
+            return False
+        else:
+            print_info(
+                f"{prefix} Using TRTLLM for {phase_str} (auto-detected)."
+            )
+            return True
+    # Print a warning if supports_trtllm does not match force_use_trtllm.
+    elif supports_trtllm is False:
+        if force_use_trtllm is True:
+            print_warning(
+                f"{prefix} Using non-TRTLLM for {phase_str} even though --attention-config.use_trtllm_attention is set to 1. Reason: {reason}"
+            )
+        else:
+            print_info(
+                f"{prefix} Using non-TRTLLM for {phase_str}. Reason: {reason}"
+            )
+        return False
+    else:
+        if force_use_trtllm is False:
+            print_warning(
+                f"{prefix} Using TRTLLM for {phase_str} even though --attention-config.use_trtllm_attention is set to 0. Reason: {reason}"
+            )
+        else:
+            print_info(f"{prefix} Using TRTLLM for {phase_str}. Reason: {reason}")
+        return True
 
 
 if has_flashinfer():
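
Note: the resolution logic above reduces to a small decision matrix over the support check and the user preference. The standalone resolve() helper below is a hypothetical restatement for clarity, not code from this commit; it drops the logging to show only the decision rule.

def resolve(supports: bool | None, force: bool | None) -> bool:
    # supports: result of check_trtllm_attention_support()
    #   True -> TRTLLM must be used, False -> must not, None -> may be used.
    # force: user preference from force_use_trtllm_attention()
    #   True/False/None for force-on/force-off/auto.
    if supports is None:
        return force is not False  # auto-detection defaults to TRTLLM
    return supports  # a hard requirement overrides the preference

assert resolve(None, None) is True     # auto-detected
assert resolve(None, False) is False   # user opted out
assert resolve(False, True) is False   # forced on, but unsupported (warns)
assert resolve(True, False) is True    # forced off, but required (warns)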
