
Commit a691699

Enable XQA kernels on Hopper
Signed-off-by: Po-Han Huang <[email protected]>
1 parent bc47808 commit a691699

File tree

2 files changed: +236, -170 lines


vllm/utils/flashinfer.py

Lines changed: 120 additions & 84 deletions
@@ -12,7 +12,7 @@
 import os
 import shutil
 from collections.abc import Callable
-from typing import Any, NoReturn
+from typing import Any, NoReturn, Tuple
 
 import requests
 import torch
@@ -254,17 +254,64 @@ def has_nvidia_artifactory() -> bool:
 
 
 @functools.cache
-def supports_trtllm_attention() -> bool:
+def check_trtllm_attention_support(
+    is_prefill: bool,
+    num_qo_heads: int | None = None,
+    num_kv_heads: int | None = None,
+    dcp_world_size: int | None = None,
+    kv_cache_dtype: str | None = None,
+    q_data_type: torch.dtype | None = None,
+    has_sinks: bool | None = None,
+    has_spec: bool | None = None,
+) -> Tuple[bool, str]:
     """
-    TRTLLM attention is supported if the platform is SM100,
-    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
+    Check if the provided config + current platform is supported by TRTLLM attention.
+
+    Args:
+        is_prefill: Whether it is prefill.
+        num_qo_heads: Number of query heads.
+        num_kv_heads: Number of key/value heads.
+        dcp_world_size: World size of decode context parallel.
+        kv_cache_dtype: Data type of the key/value cache. Could be "auto".
+        q_data_type: Data type of the query.
+        has_sinks: Whether sinks are being used.
+        has_spec: Whether speculative decoding is being used.
+
+    If any arg (except for is_prefill) is set to None, the check for that arg is skipped.
+
+    Returns:
+        A tuple of (bool, str). If the bool is:
+        - True: TRTLLM attention must be used.
+        - False: TRTLLM attention must not be used.
+        - None: TRTLLM attention can be used.
+        The str is the reason why it must or must not be used. Empty string if it can be used.
     """
-    # Batch-invariant mode disables TRTLLM attention
+
     if vllm_is_batch_invariant():
-        return False
+        return False, "Batch-invariant mode is enabled."
+
+    if not has_nvidia_artifactory():
+        return False, "NVIDIA artifactory is not accessible."
+
+    if current_platform.is_device_capability(90):
+        if is_prefill:
+            return False, "SM90 is not supported for prefill."
+    elif not any(current_platform.is_device_capability(cap) for cap in [100, 103, 110]):
+        return False, "SMs other than 90/100/103/110 are not supported."
 
-    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    return current_platform.is_device_capability(100) and has_nvidia_artifactory()
+    if dcp_world_size is not None and dcp_world_size > 1:
+        return False, "DCP is not supported due to lack of LSE return support."
+
+    if num_qo_heads is not None and num_kv_heads is not None and num_qo_heads % num_kv_heads != 0:
+        return False, "num_qo_heads must be a multiple of num_kv_heads."
+
+    if has_spec is not None and has_spec and not is_prefill:
+        return True, "Has speculative decoding in decode phase."
+
+    if has_sinks is not None and has_sinks:
+        return True, "Has attention sinks."
+
+    return None, ""
 
 
 def force_use_trtllm_attention() -> bool | None:
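
The new helper returns a tri-state verdict plus a human-readable reason. Below is a minimal sketch of how a caller might consume that contract, assuming check_trtllm_attention_support is importable from vllm.utils.flashinfer (as the file path suggests); the head counts, dtypes, and printed outcomes are illustrative placeholders, and the actual result depends on the GPU, the artifactory check, and batch-invariant mode at runtime.

import torch

from vllm.utils.flashinfer import check_trtllm_attention_support

# Illustrative GQA config: 32 query heads sharing 8 KV heads, bf16 queries,
# attention sinks enabled, no decode context parallelism, no spec decoding.
must_use, reason = check_trtllm_attention_support(
    is_prefill=False,
    num_qo_heads=32,
    num_kv_heads=8,
    dcp_world_size=1,
    kv_cache_dtype="auto",
    q_data_type=torch.bfloat16,
    has_sinks=True,
    has_spec=False,
)

if must_use is True:
    # e.g. (True, "Has attention sinks.") on a supported SM with the artifactory reachable
    print(f"TRTLLM attention is required: {reason}")
elif must_use is False:
    # e.g. (False, "SMs other than 90/100/103/110 are not supported.") on other GPUs
    print(f"TRTLLM attention cannot be used: {reason}")
else:
    # (None, ""): either backend works; defer to the user's setting or auto-detection
    print("TRTLLM attention is optional for this config")

Passing None for an argument skips that particular check, and because the function is wrapped in functools.cache, repeated probes with the same arguments are cheap.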
@@ -281,96 +328,85 @@ def force_use_trtllm_attention() -> bool | None:
     return vllm_config.attention_config.use_trtllm_attention
 
 
-def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
-    """Check if the current configuration supports TRTLLM attention."""
-    if force_use_trtllm_attention() is False:
-        return False
-    has_trtllm = supports_trtllm_attention()
-    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
-
-
 def use_trtllm_attention(
     num_qo_heads: int,
     num_kv_heads: int,
-    num_tokens: int,
-    max_seq_len: int,
     dcp_world_size: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
     is_prefill: bool,
-    # None means auto-detection, True means force on, False means force off
-    force_use_trtllm: bool | None = None,
     has_sinks: bool = False,
     has_spec: bool = False,
+    silent: bool = False,
 ) -> bool:
-    """Return `True` if TRTLLM attention is used."""
-
-    # CLI argument is set to 0 - respect it
-    if force_use_trtllm is not None and not force_use_trtllm:
-        return False
-
-    # Decode context parallel is not supported
-    if dcp_world_size > 1:
-        logger.warning_once(
-            "Trtllm does not support returning LSE and as a result "
-            "does not support DCP, reverting to FlashInfer"
-        )
-        return False
-
-    # The platform is not supported
-    if not supports_trtllm_attention():
-        if force_use_trtllm:
-            logger.warning_once(
-                "TRTLLM attention is not supported on this platform, "
-                "but --attention-config.use_trtllm_attention is set to 1"
+    """
+    Decides whether to use TRTLLM attention based on these two functions:
+    - check_trtllm_attention_support(): whether TRTLLM attention must or must not be used.
+    - force_use_trtllm_attention(): whether the user wants to force/disable TRTLLM attention.
+    If the decision does not match the user's preference, print the warning messages.
+
+    Args:
+        is_prefill: Whether it is prefill.
+        num_qo_heads: Number of query heads.
+        num_kv_heads: Number of key/value heads.
+        dcp_world_size: World size of decode context parallel.
+        kv_cache_dtype: Data type of the key/value cache. Could be "auto".
+        q_dtype: Data type of the query.
+        has_sinks: Whether sinks are being used.
+        has_spec: Whether speculative decoding is being used.
+        silent: Whether to print the warning/info messages.
+
+    Returns: whether to use TRTLLM attention.
+    """
+    supports_trtllm, reason = check_trtllm_attention_support(
+        is_prefill, num_qo_heads, num_kv_heads, dcp_world_size, kv_cache_dtype, q_dtype, has_sinks, has_spec
+    )
+    force_use_trtllm = force_use_trtllm_attention()
+    phase_str = "prefill" if is_prefill else "decode"
+
+    # Helper functions to print warning/info if not silent.
+    def print_warning(msg: str):
+        if not silent:
+            logger.warning_once(msg)
+
+    def print_info(msg: str):
+        if not silent:
+            logger.info_once(msg)
+
+    # Follow users' preference if supports_trtllm is None.
+    if supports_trtllm is None:
+        if force_use_trtllm is True:
+            print_info(
+                f"Using TRTLLM attention for {phase_str} (--attention-config.use_trtllm_attention is set to 1)."
             )
-        return False
-
-    # The combination of query and key heads is not supported
-    if num_qo_heads % num_kv_heads != 0:
-        if force_use_trtllm:
-            logger.warning_once(
-                "TRTLLM attention is not supported for this combination of "
-                "query and key heads, but --attention-config.use_trtllm_attention is "
-                "set to 1"
+            return True
+        elif force_use_trtllm is False:
+            print_info(
+                f"Not using TRTLLM attention for {phase_str} (--attention-config.use_trtllm_attention is set to 0)."
             )
+            return False
+        else:
+            print_info(
+                f"Using TRTLLM attention for {phase_str} (auto-detected)."
+            )
+            return True
+    # Print warning if supports_trtllm does not match force_use_trtllm.
+    elif supports_trtllm is False:
+        if force_use_trtllm is True:
+            print_warning(
+                f"Not using TRTLLM attention for {phase_str} even though --attention-config.use_trtllm_attention is set to 1. Reason: {reason}"
+            )
+        else:
+            print_info(f"Not using TRTLLM attention for {phase_str}. Reason: {reason}")
         return False
-
-    if has_spec and not is_prefill:
-        # Speculative decoding requires TRTLLM attention for decodes
-        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
-        return True
-
-    # Must use TRTLLM attention if query is FP8 quantized
-    if q_dtype == current_platform.fp8_dtype():
-        logger.info_once("Using TRTLLM attention (query is quantized).")
-        return True
-
-    # If sinks are being used, we must use TRTLLM attention as it's
-    # the only backend that supports them
-    if has_sinks:
-        logger.info_once("Using TRTLLM attention (required for attention sinks).")
-        return True
-
-    if force_use_trtllm is None:
-        # CLI argument not set - use auto-detection
-        if is_prefill:
-            # Prefill auto-detection
-            use_trtllm = kv_cache_dtype == "auto"
-            if use_trtllm:
-                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
+    else:
+        if force_use_trtllm is False:
+            print_warning(
+                f"Using TRTLLM attention for {phase_str} even though --attention-config.use_trtllm_attention is set to 0. Reason: {reason}"
+            )
         else:
-            # Decode auto-detection
-            use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
-            if use_trtllm:
-                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
-        return use_trtllm
-
-    # CLI argument is set to 1 - respect it
-    logger.info_once(
-        "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
-    )
-    return True
+            print_info(f"Using TRTLLM attention for {phase_str}. Reason: {reason}")
+        return True
 
 
 if has_flashinfer():
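
For the refactored entry point, a hedged usage sketch of calling use_trtllm_attention after this change; the configuration values below are placeholders rather than actual vLLM call sites, and the import path is assumed from the file location.

import torch

from vllm.utils.flashinfer import use_trtllm_attention

# Placeholder decode-phase config: GQA with 32 query / 8 KV heads, bf16 queries,
# no decode context parallelism, no attention sinks, no speculative decoding.
selected = use_trtllm_attention(
    num_qo_heads=32,
    num_kv_heads=8,
    dcp_world_size=1,
    kv_cache_dtype="auto",
    q_dtype=torch.bfloat16,
    is_prefill=False,
    has_sinks=False,
    has_spec=False,
    silent=True,  # suppress the info/warning logs while probing
)

# With this commit, SM90 (Hopper) decode is accepted alongside SM100/103/110, so on
# those GPUs (with the NVIDIA artifactory reachable and batch-invariant mode off) the
# call follows --attention-config.use_trtllm_attention, defaulting to True when the
# flag is unset; on unsupported platforms it returns False.
print("TRTLLM attention selected:", selected)

Note that num_tokens and max_seq_len are no longer part of the signature, and the former force_use_trtllm argument is now read internally via force_use_trtllm_attention().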
