 import os
 import shutil
 from collections.abc import Callable
-from typing import Any, NoReturn
+from typing import Any, NoReturn, Tuple

 import requests
 import torch
@@ -254,17 +254,82 @@ def has_nvidia_artifactory() -> bool:


 @functools.cache
-def supports_trtllm_attention() -> bool:
+def is_sm90_supported() -> bool:
+    return current_platform.is_device_capability(90)
+
+
+@functools.cache
+def is_sm100f_supported() -> bool:
+    return any(current_platform.is_device_capability(cap) for cap in [100, 103])
+
+
+@functools.cache
+def check_trtllm_attention_support(
+    is_prefill: bool,
+    num_qo_heads: int | None = None,
+    num_kv_heads: int | None = None,
+    dcp_world_size: int | None = None,
+    kv_cache_dtype: str | None = None,
+    q_data_type: torch.dtype | None = None,
+    has_sinks: bool | None = None,
+    has_spec: bool | None = None,
+) -> Tuple[bool | None, str]:
     """
-    TRTLLM attention is supported if the platform is SM100,
-    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
+    Check whether the given config and the current platform are supported by TRTLLM attention.
+
+    Args:
+        is_prefill: Whether this is the prefill phase.
+        num_qo_heads: Number of query heads.
+        num_kv_heads: Number of key/value heads.
+        dcp_world_size: World size of decode context parallel (DCP).
+        kv_cache_dtype: Data type of the key/value cache. May be "auto".
+        q_data_type: Data type of the query.
+        has_sinks: Whether attention sinks are being used.
+        has_spec: Whether speculative decoding is being used.
+
+    If any argument (except is_prefill) is None, the check for that argument is skipped.
+
+    Returns:
+        A tuple of (bool | None, str). If the first element is:
+        - True: TRTLLM attention must be used.
+        - False: TRTLLM attention must not be used.
+        - None: TRTLLM attention can be used, but is not required.
+        The second element is the reason why it must or must not be used; empty string when the first element is None.
     """
-    # Batch-invariant mode disables TRTLLM attention
+
     if vllm_is_batch_invariant():
-        return False
+        return False, "Batch-invariant mode is enabled."
+
+    if not has_nvidia_artifactory():
+        return False, "NVIDIA artifactory is not accessible."
+
+    if is_sm90_supported():
+        if is_prefill:
+            return False, "SM90 is not supported for prefill."
+        if q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            return False, "xqa does not support FP8-Q."
+    elif is_sm100f_supported():
+        if (is_prefill
+                and kv_cache_dtype is not None
+                and not kv_cache_dtype.startswith("fp8")
+                and q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]):
+            return False, "trtllm-gen prefill does not support FP8-Q with a non-FP8 KV cache."
+    else:
+        return False, "SMs other than 90/100/103 are not supported."
+
+    if dcp_world_size is not None and dcp_world_size > 1:
+        return False, "DCP is not supported due to lack of LSE return support."
+
+    if num_qo_heads is not None and num_kv_heads is not None and num_qo_heads % num_kv_heads != 0:
+        return False, "num_qo_heads must be a multiple of num_kv_heads."
+
+    if has_spec and not is_prefill:
+        return True, "Has speculative decoding in decode phase."
+
+    if has_sinks:
+        return True, "Has attention sinks."

-    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    return current_platform.is_device_capability(100) and has_nvidia_artifactory()
+    return None, ""


 def force_use_trtllm_attention() -> bool | None:
@@ -281,96 +346,90 @@ def force_use_trtllm_attention() -> bool | None:
         return vllm_config.attention_config.use_trtllm_attention


-def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
-    """Check if the current configuration supports TRTLLM attention."""
-    if force_use_trtllm_attention() is False:
-        return False
-    has_trtllm = supports_trtllm_attention()
-    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
-
-
 def use_trtllm_attention(
-    num_qo_heads: int,
-    num_kv_heads: int,
-    num_tokens: int,
-    max_seq_len: int,
-    dcp_world_size: int,
-    kv_cache_dtype: str,
-    q_dtype: torch.dtype,
     is_prefill: bool,
-    # None means auto-detection, True means force on, False means force off
-    force_use_trtllm: bool | None = None,
-    has_sinks: bool = False,
-    has_spec: bool = False,
+    num_qo_heads: int | None = None,
+    num_kv_heads: int | None = None,
+    dcp_world_size: int | None = None,
+    kv_cache_dtype: str | None = None,
+    q_data_type: torch.dtype | None = None,
+    has_sinks: bool | None = None,
+    has_spec: bool | None = None,
+    silent: bool = False,
 ) -> bool:
-    """Return `True` if TRTLLM attention is used."""
-
-    # CLI argument is set to 0 - respect it
-    if force_use_trtllm is not None and not force_use_trtllm:
-        return False
-
-    # Decode context parallel is not supported
-    if dcp_world_size > 1:
-        logger.warning_once(
-            "Trtllm does not support returning LSE and as a result "
-            "does not support DCP, reverting to FlashInfer"
-        )
-        return False
-
-    # The platform is not supported
-    if not supports_trtllm_attention():
-        if force_use_trtllm:
-            logger.warning_once(
-                "TRTLLM attention is not supported on this platform, "
-                "but --attention-config.use_trtllm_attention is set to 1"
+    """
+    Decide whether to use TRTLLM attention based on two functions:
+    - check_trtllm_attention_support(): whether TRTLLM attention must, must not, or may be used.
+    - force_use_trtllm_attention(): whether the user wants to force or disable TRTLLM attention.
+    If the decision does not match the user's preference, a warning is logged.
+
+    Args:
+        is_prefill: Whether this is the prefill phase.
+        num_qo_heads: Number of query heads.
+        num_kv_heads: Number of key/value heads.
+        dcp_world_size: World size of decode context parallel (DCP).
+        kv_cache_dtype: Data type of the key/value cache. May be "auto".
+        q_data_type: Data type of the query.
+        has_sinks: Whether attention sinks are being used.
+        has_spec: Whether speculative decoding is being used.
+        silent: If True, suppress the warning/info log messages.
+
+    If any argument (except is_prefill) is None, the check for that argument is skipped.
+
+    Returns: Whether to use TRTLLM attention.
+    """
+    supports_trtllm, reason = check_trtllm_attention_support(
+        is_prefill, num_qo_heads, num_kv_heads, dcp_world_size, kv_cache_dtype, q_data_type, has_sinks, has_spec
+    )
+    force_use_trtllm = force_use_trtllm_attention()
+    phase_str = "prefill" if is_prefill else "decode"
+    prefix = "[FlashInfer Attention]"
+
+    # Helper functions to print warning/info if not silent.
+    def print_warning(msg: str):
+        if not silent:
+            logger.warning_once(msg)
+
+    def print_info(msg: str):
+        if not silent:
+            logger.info_once(msg)
+
+    # Follow users' preference if supports_trtllm is None.
+    if supports_trtllm is None:
+        if force_use_trtllm is True:
+            print_info(
+                f"{prefix} Using TRTLLM for {phase_str} (--attention-config.use_trtllm_attention is set to 1)."
             )
-        return False
-
-    # The combination of query and key heads is not supported
-    if num_qo_heads % num_kv_heads != 0:
-        if force_use_trtllm:
-            logger.warning_once(
-                "TRTLLM attention is not supported for this combination of "
-                "query and key heads, but --attention-config.use_trtllm_attention is "
-                "set to 1"
+            return True
+        elif force_use_trtllm is False:
+            print_info(
+                f"{prefix} Using non-TRTLLM for {phase_str} (--attention-config.use_trtllm_attention is set to 0)."
+            )
+            return False
+        else:
+            print_info(
+                f"{prefix} Using TRTLLM for {phase_str} (auto-detected)."
+            )
+            return True
+    # Print warning if supports_trtllm does not match force_use_trtllm.
+    elif supports_trtllm is False:
+        if force_use_trtllm is True:
+            print_warning(
+                f"{prefix} Using non-TRTLLM for {phase_str} even though --attention-config.use_trtllm_attention is set to 1. Reason: {reason}"
+            )
+        else:
+            print_info(
+                f"{prefix} Using non-TRTLLM for {phase_str}. Reason: {reason}"
             )
         return False
-
-    if has_spec and not is_prefill:
-        # Speculative decoding requires TRTLLM attention for decodes
-        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
-        return True
-
-    # Must use TRTLLM attention if query is FP8 quantized
-    if q_dtype == current_platform.fp8_dtype():
-        logger.info_once("Using TRTLLM attention (query is quantized).")
-        return True
-
-    # If sinks are being used, we must use TRTLLM attention as it's
-    # the only backend that supports them
-    if has_sinks:
-        logger.info_once("Using TRTLLM attention (required for attention sinks).")
-        return True
-
-    if force_use_trtllm is None:
-        # CLI argument not set - use auto-detection
-        if is_prefill:
-            # Prefill auto-detection
-            use_trtllm = kv_cache_dtype == "auto"
-            if use_trtllm:
-                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
+    else:
+        if force_use_trtllm is False:
+            print_warning(
+                f"{prefix} Using TRTLLM for {phase_str} even though --attention-config.use_trtllm_attention is set to 0. Reason: {reason}"
+            )
         else:
-            # Decode auto-detection
-            use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
-            if use_trtllm:
-                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
-        return use_trtllm
-
-    # CLI argument is set to 1 - respect it
-    logger.info_once(
-        "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
-    )
-    return True
+            print_info(f"{prefix} Using TRTLLM for {phase_str}. Reason: {reason}")
+        return True


 if has_flashinfer():
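
For reference, below is a minimal usage sketch of the two new helpers introduced by this diff. It is not part of the diff: the import path (vllm.utils.flashinfer) is an assumption based on the functions shown here, and the head counts and dtypes are illustrative values only.

import torch

# Assumed import path; the diff suggests these helpers live alongside
# has_flashinfer() and has_nvidia_artifactory(), likely vllm.utils.flashinfer.
from vllm.utils.flashinfer import (
    check_trtllm_attention_support,
    use_trtllm_attention,
)

# Capability-only check: returns (True/False/None, reason). None means
# TRTLLM attention is allowed for this configuration but not required.
must_use, reason = check_trtllm_attention_support(
    is_prefill=False,           # decode phase
    num_qo_heads=32,            # example values, not from the diff
    num_kv_heads=8,
    dcp_world_size=1,
    kv_cache_dtype="auto",
    q_data_type=torch.bfloat16,
    has_sinks=False,
    has_spec=False,
)
print(must_use, reason)         # e.g. (None, "")

# Full decision: combines the capability check with the user's
# --attention-config.use_trtllm_attention preference; logs the outcome
# unless silent=True.
use_trtllm = use_trtllm_attention(
    is_prefill=False,
    num_qo_heads=32,
    num_kv_heads=8,
    dcp_world_size=1,
    kv_cache_dtype="auto",
    q_data_type=torch.bfloat16,
    has_sinks=False,
    has_spec=False,
    silent=True,
)

Keeping the hard capability check separate from the preference-aware wrapper lets callers query requirements such as attention sinks or speculative decoding without emitting any log output.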