import os
import shutil
from collections.abc import Callable
- from typing import Any, NoReturn
+ from typing import Any, NoReturn, Tuple

import requests
import torch
@@ -254,17 +254,64 @@ def has_nvidia_artifactory() -> bool:


@functools.cache
- def supports_trtllm_attention() -> bool:
+ def check_trtllm_attention_support(
+     is_prefill: bool,
+     num_qo_heads: int | None = None,
+     num_kv_heads: int | None = None,
+     dcp_world_size: int | None = None,
+     kv_cache_dtype: str | None = None,
+     q_data_type: torch.dtype | None = None,
+     has_sinks: bool | None = None,
+     has_spec: bool | None = None,
+ ) -> Tuple[bool | None, str]:
    """
-     TRTLLM attention is supported if the platform is SM100,
-     NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
+     Check if the provided config and the current platform are supported by TRTLLM attention.
+
+     Args:
+         is_prefill: Whether this is the prefill phase.
+         num_qo_heads: Number of query heads.
+         num_kv_heads: Number of key/value heads.
+         dcp_world_size: World size of decode context parallel.
+         kv_cache_dtype: Data type of the key/value cache. Can be "auto".
+         q_data_type: Data type of the query.
+         has_sinks: Whether attention sinks are being used.
+         has_spec: Whether speculative decoding is being used.
+
+     If any argument (except is_prefill) is None, the check for that argument is skipped.
+
+     Returns:
+         A tuple of (bool | None, str). If the first element is:
+         - True: TRTLLM attention must be used.
+         - False: TRTLLM attention must not be used.
+         - None: TRTLLM attention can be used.
+         The str is the reason why it must or must not be used (empty if it can be used).
    """
-     # Batch-invariant mode disables TRTLLM attention
+
    if vllm_is_batch_invariant():
-         return False
+         return False, "Batch-invariant mode is enabled."
+
+     if not has_nvidia_artifactory():
+         return False, "NVIDIA artifactory is not accessible."
+
+     if current_platform.is_device_capability(90):
+         if is_prefill:
+             return False, "SM90 is not supported for prefill."
+     elif not any(current_platform.is_device_capability(cap) for cap in [100, 103, 110]):
+         return False, "SMs other than 90/100/103/110 are not supported."

-     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-     return current_platform.is_device_capability(100) and has_nvidia_artifactory()
+     if dcp_world_size is not None and dcp_world_size > 1:
+         return False, "DCP is not supported due to lack of LSE return support."
+
+     if (
+         num_qo_heads is not None
+         and num_kv_heads is not None
+         and num_qo_heads % num_kv_heads != 0
+     ):
+         return False, "num_qo_heads must be a multiple of num_kv_heads."
+
+     if has_spec is not None and has_spec and not is_prefill:
+         return True, "Speculative decoding requires TRTLLM attention for decode."
+
+     if has_sinks is not None and has_sinks:
+         return True, "Attention sinks require TRTLLM attention."
+
+     return None, ""

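For reference, a minimal sketch of how the new three-way check might be called. The head counts and dtypes below are made-up example values, not defaults from any real model config, and the actual result still depends on the GPU and artifactory reachability:

must_use, reason = check_trtllm_attention_support(
    is_prefill=False,            # decode phase
    num_qo_heads=32,             # hypothetical example value
    num_kv_heads=8,              # hypothetical example value
    dcp_world_size=1,
    kv_cache_dtype="auto",
    q_data_type=torch.bfloat16,
    has_sinks=True,
    has_spec=False,
)
# On a supported GPU with the artifactory reachable, has_sinks=True makes the
# check return (True, reason); an unsupported platform returns (False, reason)
# earlier, and (None, "") means "allowed but not required".
if must_use is True:
    print(f"TRTLLM attention is required: {reason}")
elif must_use is False:
    print(f"TRTLLM attention cannot be used: {reason}")
else:
    print("TRTLLM attention may be used; defer to the user's preference.")
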

def force_use_trtllm_attention() -> bool | None:
@@ -281,96 +328,85 @@ def force_use_trtllm_attention() -> bool | None:
    return vllm_config.attention_config.use_trtllm_attention


- def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
-     """Check if the current configuration supports TRTLLM attention."""
-     if force_use_trtllm_attention() is False:
-         return False
-     has_trtllm = supports_trtllm_attention()
-     return has_trtllm and (num_qo_heads % num_kv_heads == 0)
-
-
def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
-     num_tokens: int,
-     max_seq_len: int,
    dcp_world_size: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
-     # None means auto-detection, True means force on, False means force off
-     force_use_trtllm: bool | None = None,
    has_sinks: bool = False,
    has_spec: bool = False,
+     silent: bool = False,
) -> bool:
-     """Return `True` if TRTLLM attention is used."""
-
-     # CLI argument is set to 0 - respect it
-     if force_use_trtllm is not None and not force_use_trtllm:
-         return False
-
-     # Decode context parallel is not supported
-     if dcp_world_size > 1:
-         logger.warning_once(
-             "Trtllm does not support returning LSE and as a result "
-             "does not support DCP, reverting to FlashInfer"
-         )
-         return False
-
-     # The platform is not supported
-     if not supports_trtllm_attention():
-         if force_use_trtllm:
-             logger.warning_once(
-                 "TRTLLM attention is not supported on this platform, "
-                 "but --attention-config.use_trtllm_attention is set to 1"
+     """
+     Decide whether to use TRTLLM attention based on two functions:
+     - check_trtllm_attention_support(): whether TRTLLM attention must or must not be used.
+     - force_use_trtllm_attention(): whether the user wants to force/disable TRTLLM attention.
+     If the decision does not match the user's preference, a warning is logged.
+
+     Args:
+         num_qo_heads: Number of query heads.
+         num_kv_heads: Number of key/value heads.
+         dcp_world_size: World size of decode context parallel.
+         kv_cache_dtype: Data type of the key/value cache. Can be "auto".
+         q_dtype: Data type of the query.
+         is_prefill: Whether this is the prefill phase.
+         has_sinks: Whether attention sinks are being used.
+         has_spec: Whether speculative decoding is being used.
+         silent: If True, suppress the warning/info log messages.
+
+     Returns:
+         Whether to use TRTLLM attention.
+     """
+     supports_trtllm, reason = check_trtllm_attention_support(
+         is_prefill, num_qo_heads, num_kv_heads, dcp_world_size,
+         kv_cache_dtype, q_dtype, has_sinks, has_spec,
+     )
+     force_use_trtllm = force_use_trtllm_attention()
+     phase_str = "prefill" if is_prefill else "decode"
+
+     # Helper functions that log a warning/info message unless silent.
+     def print_warning(msg: str):
+         if not silent:
+             logger.warning_once(msg)
+
+     def print_info(msg: str):
+         if not silent:
+             logger.info_once(msg)
+
+     # Follow the user's preference when supports_trtllm is None (can be used).
+     if supports_trtllm is None:
+         if force_use_trtllm is True:
+             print_info(
+                 f"Using TRTLLM attention for {phase_str} "
+                 "(--attention-config.use_trtllm_attention is set to 1)."
            )
-         return False
-
-     # The combination of query and key heads is not supported
-     if num_qo_heads % num_kv_heads != 0:
-         if force_use_trtllm:
-             logger.warning_once(
-                 "TRTLLM attention is not supported for this combination of "
-                 "query and key heads, but --attention-config.use_trtllm_attention is "
-                 "set to 1"
+             return True
+         elif force_use_trtllm is False:
+             print_info(
+                 f"Not using TRTLLM attention for {phase_str} "
+                 "(--attention-config.use_trtllm_attention is set to 0)."
            )
+             return False
+         else:
+             print_info(
+                 f"Using TRTLLM attention for {phase_str} (auto-detected)."
+             )
+             return True
+     # Log a warning if supports_trtllm does not match force_use_trtllm.
+     elif supports_trtllm is False:
+         if force_use_trtllm is True:
+             print_warning(
+                 f"Not using TRTLLM attention for {phase_str} even though "
+                 f"--attention-config.use_trtllm_attention is set to 1. Reason: {reason}"
+             )
+         else:
+             print_info(f"Not using TRTLLM attention for {phase_str}. Reason: {reason}")
        return False
-
-     if has_spec and not is_prefill:
-         # Speculative decoding requires TRTLLM attention for decodes
-         logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
-         return True
-
-     # Must use TRTLLM attention if query is FP8 quantized
-     if q_dtype == current_platform.fp8_dtype():
-         logger.info_once("Using TRTLLM attention (query is quantized).")
-         return True
-
-     # If sinks are being used, we must use TRTLLM attention as it's
-     # the only backend that supports them
-     if has_sinks:
-         logger.info_once("Using TRTLLM attention (required for attention sinks).")
-         return True
-
-     if force_use_trtllm is None:
-         # CLI argument not set - use auto-detection
-         if is_prefill:
-             # Prefill auto-detection
-             use_trtllm = kv_cache_dtype == "auto"
-             if use_trtllm:
-                 logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
+     else:
+         if force_use_trtllm is False:
+             print_warning(
+                 f"Using TRTLLM attention for {phase_str} even though "
+                 f"--attention-config.use_trtllm_attention is set to 0. Reason: {reason}"
+             )
        else:
-             # Decode auto-detection
-             use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
-             if use_trtllm:
-                 logger.warning_once("Using TRTLLM decode attention (auto-detected).")
-         return use_trtllm
-
-     # CLI argument is set to 1 - respect it
-     logger.info_once(
-         "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
-     )
-     return True
+             print_info(f"Using TRTLLM attention for {phase_str}. Reason: {reason}")
+         return True


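As a usage sketch only (the argument values are illustrative assumptions, not taken from a real attention backend), the decision helper might be called once per phase, with silent=True available when the caller just wants the boolean:

prefill_use_trtllm = use_trtllm_attention(
    num_qo_heads=32,             # hypothetical example value
    num_kv_heads=8,              # hypothetical example value
    dcp_world_size=1,
    kv_cache_dtype="auto",
    q_dtype=torch.bfloat16,
    is_prefill=True,
)
# silent=True suppresses the info/warning logs, e.g. when the caller only
# wants the decision and reports the outcome itself.
decode_use_trtllm = use_trtllm_attention(
    num_qo_heads=32,
    num_kv_heads=8,
    dcp_world_size=1,
    kv_cache_dtype="auto",
    q_dtype=torch.bfloat16,
    is_prefill=False,
    has_spec=True,               # on supported platforms, spec decode requires TRTLLM decode
    silent=True,
)
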
if has_flashinfer():