261 changes: 170 additions & 91 deletions vllm/utils/flashinfer.py
@@ -254,17 +254,90 @@


@functools.cache
def supports_trtllm_attention() -> bool:
def is_sm90_supported() -> bool:
return current_platform.is_device_capability(90)


@functools.cache
def is_sm100f_supported() -> bool:
return any(current_platform.is_device_capability(cap) for cap in [100, 103])


@functools.cache
def check_trtllm_attention_support(
is_prefill: bool,
num_qo_heads: int | None = None,
num_kv_heads: int | None = None,
dcp_world_size: int | None = None,
kv_cache_dtype: str | None = None,
q_data_type: torch.dtype | None = None,
has_sinks: bool | None = None,
has_spec: bool | None = None,
) -> tuple[bool | None, str]:
"""
TRTLLM attention is supported if the platform is SM100,
NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
Check whether the provided config and the current platform are supported by TRTLLM attention.

Args:
is_prefill: Whether it is prefill.
num_qo_heads: Number of query heads.
num_kv_heads: Number of key/value heads.
dcp_world_size: World size of decode context parallel.
kv_cache_dtype: Data type of the key/value cache. Could be "auto".
q_data_type: Data type of the query.
has_sinks: Whether sinks are being used.
has_spec: Whether speculative decoding is being used.

If any arg (except is_prefill) is set to None, the check for that arg is skipped.

Returns:
A tuple of (bool | None, str). If the bool is:
- True: TRTLLM attention must be used.
- False: TRTLLM attention must not be used.
- None: TRTLLM attention can be used.
The str is the reason why it must or must not be used, or an empty string if it can be used.
"""
# Batch-invariant mode disables TRTLLM attention

if vllm_is_batch_invariant():
return False
return False, "Batch-invariant mode is enabled."

# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
if not has_nvidia_artifactory():
return False, "NVIDIA artifactory is not accessible."

if is_sm90_supported():
if is_prefill:
return False, "SM90 is not supported for prefill."
if q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]:
return False, "xqa does not support FP8-Q."
elif is_sm100f_supported():
if (
is_prefill
and kv_cache_dtype is not None
and not kv_cache_dtype.startswith("fp8")
and q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]
):
return False, "trtllm-gen prefill does not support FP8-Q with BF16/FP16-Q."
else:
return False, "SMs other than 90/100/103 are not supported."

if dcp_world_size is not None and dcp_world_size > 1:
return False, "DCP is not supported due to lack of LSE return support."

if (
num_qo_heads is not None
and num_kv_heads is not None
and num_qo_heads % num_kv_heads != 0
):
return False, "num_qo_heads must be a multiple of num_kv_heads."

if has_spec and not is_prefill:
return True, "Has speculative decoding in decode phase."

if has_sinks:
return True, "Has attention sinks."

return None, ""
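
For reference, here is a minimal caller-side sketch (not part of the diff above) showing how the tri-state result of check_trtllm_attention_support() can be interpreted; the import path follows this file's location, and the concrete head counts and dtypes are only illustrative.

```python
# Illustrative only: interpreting the (bool | None, str) result.
# The head counts and dtypes below are made-up example values.
import torch

from vllm.utils.flashinfer import check_trtllm_attention_support

must_use, reason = check_trtllm_attention_support(
    is_prefill=False,
    num_qo_heads=32,
    num_kv_heads=8,
    dcp_world_size=1,
    kv_cache_dtype="auto",
    q_data_type=torch.bfloat16,
    has_sinks=False,
    has_spec=False,
)

if must_use is True:
    # This config requires TRTLLM attention (e.g. sinks, or spec decode in decode).
    print(f"TRTLLM attention required: {reason}")
elif must_use is False:
    # This config or platform rules TRTLLM attention out.
    print(f"TRTLLM attention unavailable: {reason}")
else:
    # None: either backend works, so the caller is free to auto-detect.
    print("TRTLLM attention is optional for this config")
```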


def force_use_trtllm_attention() -> bool | None:
@@ -281,96 +354,102 @@
return vllm_config.attention_config.use_trtllm_attention


def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
"""Check if the current configuration supports TRTLLM attention."""
if force_use_trtllm_attention() is False:
return False
has_trtllm = supports_trtllm_attention()
return has_trtllm and (num_qo_heads % num_kv_heads == 0)


def use_trtllm_attention(
num_qo_heads: int,
num_kv_heads: int,
num_tokens: int,
max_seq_len: int,
dcp_world_size: int,
kv_cache_dtype: str,
q_dtype: torch.dtype,
is_prefill: bool,
# None means auto-detection, True means force on, False means force off
force_use_trtllm: bool | None = None,
has_sinks: bool = False,
has_spec: bool = False,
num_qo_heads: int | None = None,
num_kv_heads: int | None = None,
dcp_world_size: int | None = None,
kv_cache_dtype: str | None = None,
q_data_type: torch.dtype | None = None,
has_sinks: bool | None = None,
has_spec: bool | None = None,
silent: bool = False,
) -> bool:
"""Return `True` if TRTLLM attention is used."""

# CLI argument is set to 0 - respect it
if force_use_trtllm is not None and not force_use_trtllm:
return False

# Decode context parallel is not supported
if dcp_world_size > 1:
logger.warning_once(
"Trtllm does not support returning LSE and as a result "
"does not support DCP, reverting to FlashInfer"
)
return False

# The platform is not supported
if not supports_trtllm_attention():
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported on this platform, "
"but --attention-config.use_trtllm_attention is set to 1"
"""
Decides whether to use TRTLLM attention based on these two functions:
- check_trtllm_attention_support(): whether TRTLLM attention must or must not be
used.
- force_use_trtllm_attention(): whether the user wants to force/disable TRTLLM
attention.
If the decision does not match the user's preference, a warning message is printed.

Args:
is_prefill: Whether it is prefill.
num_qo_heads: Number of query heads.
num_kv_heads: Number of key/value heads.
dcp_world_size: World size of decode context parallel.
kv_cache_dtype: Data type of the key/value cache. Could be "auto".
q_data_type: Data type of the query.
has_sinks: Whether sinks are being used.
has_spec: Whether speculative decoding is being used.
silent: Whether to suppress the warning/info messages.

If any arg (except is_prefill) is set to None, the check for that arg is skipped.

Returns: whether to use TRTLLM attention.
"""
supports_trtllm, reason = check_trtllm_attention_support(
is_prefill,
num_qo_heads,
num_kv_heads,
dcp_world_size,
kv_cache_dtype,
q_data_type,
has_sinks,
has_spec,
)
force_use_trtllm = force_use_trtllm_attention()
phase_str = "prefill" if is_prefill else "decode"
prefix = "[FlashInfer Attention]"

# Helper functions to print warning/info if not silent.
def print_warning(msg: str):
if not silent:
logger.warning_once(msg)

def print_info(msg: str):
if not silent:
logger.info_once(msg)

if force_use_trtllm is True:
if supports_trtllm is False:
print_warning(
f"{prefix} Using non-TRTLLM for {phase_str} even though --attention-"
f"config.use_trtllm_attention is set to 1. (Reason: {reason})"
)
return False

# The combination of query and key heads is not supported
if num_qo_heads % num_kv_heads != 0:
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported for this combination of "
"query and key heads, but --attention-config.use_trtllm_attention is "
"set to 1"
return False
else:
print_info(
f"{prefix} Using TRTLLM for {phase_str}. (Reason: --attention-config."
f"use_trtllm_attention is set to 1.)"
)
return False

if has_spec and not is_prefill:
# Speculative decoding requires TRTLLM attention for decodes
logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
return True

# Must use TRTLLM attention if query is FP8 quantized
if q_dtype == current_platform.fp8_dtype():
logger.info_once("Using TRTLLM attention (query is quantized).")
return True

# If sinks are being used, we must use TRTLLM attention as it's
# the only backend that supports them
if has_sinks:
logger.info_once("Using TRTLLM attention (required for attention sinks).")
return True

if force_use_trtllm is None:
# CLI argument not set - use auto-detection
if is_prefill:
# Prefill auto-detection
use_trtllm = kv_cache_dtype == "auto"
if use_trtllm:
logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
return True
elif force_use_trtllm is False:
if supports_trtllm is True:
print_warning(
f"{prefix} Using TRTLLM for {phase_str} even though --attention-config."
f"use_trtllm_attention is set to 0. (Reason: {reason})"
)
return True
else:
# Decode auto-detection
use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
if use_trtllm:
logger.warning_once("Using TRTLLM decode attention (auto-detected).")
return use_trtllm

# CLI argument is set to 1 - respect it
logger.info_once(
"Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
)
return True
print_info(
f"{prefix} Using non-TRTLLM for {phase_str}. (Reason: --attention-"
f"config.use_trtllm_attention is set to 0.)"
)
return False
else:
if supports_trtllm is True:
print_info(f"{prefix} Using TRTLLM for {phase_str}. (Reason: {reason})")
return True
elif supports_trtllm is False:
print_info(f"{prefix} Using non-TRTLLM for {phase_str}. (Reason: {reason})")
return False
else:
print_info(
f"{prefix} Using TRTLLM for {phase_str}. (Reason: auto-detected.)"
)
return True
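
As a usage sketch (again not part of the diff), here is a decode-phase call with keyword arguments matching the docstring above; the values are illustrative, and silent=True suppresses the log lines.

```python
# Illustrative only: use_trtllm_attention() combines check_trtllm_attention_support()
# with the user's --attention-config.use_trtllm_attention preference.
import torch

from vllm.utils.flashinfer import use_trtllm_attention

use_trtllm = use_trtllm_attention(
    is_prefill=False,
    num_qo_heads=32,
    num_kv_heads=8,
    dcp_world_size=1,
    kv_cache_dtype="fp8",
    q_data_type=torch.bfloat16,
    has_sinks=True,   # attention sinks force the TRTLLM path for decode
    has_spec=False,
    silent=True,      # do not emit the warning/info log lines
)
```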


if has_flashinfer():
@@ -518,8 +597,8 @@
"has_flashinfer_cutlass_fused_moe",
"has_flashinfer_cutedsl_grouped_gemm_nt_masked",
"has_nvidia_artifactory",
"supports_trtllm_attention",

Check failure on line 600 in vllm/utils/flashinfer.py (GitHub Actions / pre-commit, Ruff): vllm/utils/flashinfer.py:600:5: F822 Undefined name `supports_trtllm_attention` in `__all__`
"can_use_trtllm_attention",

Check failure on line 601 in vllm/utils/flashinfer.py (GitHub Actions / pre-commit, Ruff): vllm/utils/flashinfer.py:601:5: F822 Undefined name `can_use_trtllm_attention` in `__all__`
"use_trtllm_attention",
"flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm",