diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 9a66049350cd..dd518bdf2e65 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -254,17 +254,90 @@ def has_nvidia_artifactory() -> bool:
 
 
 @functools.cache
-def supports_trtllm_attention() -> bool:
+def is_sm90_supported() -> bool:
+    return current_platform.is_device_capability(90)
+
+
+@functools.cache
+def is_sm100f_supported() -> bool:
+    return any(current_platform.is_device_capability(cap) for cap in [100, 103])
+
+
+@functools.cache
+def check_trtllm_attention_support(
+    is_prefill: bool,
+    num_qo_heads: int | None = None,
+    num_kv_heads: int | None = None,
+    dcp_world_size: int | None = None,
+    kv_cache_dtype: str | None = None,
+    q_data_type: torch.dtype | None = None,
+    has_sinks: bool | None = None,
+    has_spec: bool | None = None,
+) -> tuple[bool | None, str]:
     """
-    TRTLLM attention is supported if the platform is SM100,
-    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
+    Check whether the provided config + current platform is supported by TRTLLM
+    attention.
+
+    Args:
+        is_prefill: Whether it is prefill.
+        num_qo_heads: Number of query heads.
+        num_kv_heads: Number of key/value heads.
+        dcp_world_size: World size of decode context parallel.
+        kv_cache_dtype: Data type of the key/value cache. Could be "auto".
+        q_data_type: Data type of the query.
+        has_sinks: Whether attention sinks are being used.
+        has_spec: Whether speculative decoding is being used.
+
+    If any arg (except for is_prefill) is set to None, the check for that arg is
+    skipped.
+
+    Returns:
+        A tuple of (bool | None, str). If the bool is:
+        - True: TRTLLM attention must be used.
+        - False: TRTLLM attention must not be used.
+        - None: TRTLLM attention can be used.
+        The str is the reason why it must or must not be used. Empty string if it
+        can be used.
     """
-    # Batch-invariant mode disables TRTLLM attention
+
     if vllm_is_batch_invariant():
-        return False
+        return False, "Batch-invariant mode is enabled."
 
-    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
-    return current_platform.is_device_capability(100) and has_nvidia_artifactory()
+    if not has_nvidia_artifactory():
+        return False, "NVIDIA artifactory is not accessible."
+
+    if is_sm90_supported():
+        if is_prefill:
+            return False, "SM90 is not supported for prefill."
+        if q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            return False, "xqa does not support FP8-Q."
+    elif is_sm100f_supported():
+        if (
+            is_prefill
+            and kv_cache_dtype is not None
+            and not kv_cache_dtype.startswith("fp8")
+            and q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]
+        ):
+            return False, (
+                "trtllm-gen prefill does not support FP8-Q with a BF16/FP16 KV cache."
+            )
+    else:
+        return False, "SMs other than 90/100/103 are not supported."
+
+    if dcp_world_size is not None and dcp_world_size > 1:
+        return False, "DCP is not supported due to lack of LSE return support."
+
+    if (
+        num_qo_heads is not None
+        and num_kv_heads is not None
+        and num_qo_heads % num_kv_heads != 0
+    ):
+        return False, "num_qo_heads must be a multiple of num_kv_heads."
+
+    if has_spec and not is_prefill:
+        return True, "Has speculative decoding in decode phase."
+
+    if has_sinks:
+        return True, "Has attention sinks."
+
+    return None, ""
 
 
 def force_use_trtllm_attention() -> bool | None:
@@ -281,96 +354,102 @@ def force_use_trtllm_attention() -> bool | None:
     return vllm_config.attention_config.use_trtllm_attention
 
 
-def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
-    """Check if the current configuration supports TRTLLM attention."""
-    if force_use_trtllm_attention() is False:
-        return False
-    has_trtllm = supports_trtllm_attention()
-    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
-
-
 def use_trtllm_attention(
-    num_qo_heads: int,
-    num_kv_heads: int,
-    num_tokens: int,
-    max_seq_len: int,
-    dcp_world_size: int,
-    kv_cache_dtype: str,
-    q_dtype: torch.dtype,
     is_prefill: bool,
-    # None means auto-detection, True means force on, False means force off
-    force_use_trtllm: bool | None = None,
-    has_sinks: bool = False,
-    has_spec: bool = False,
+    num_qo_heads: int | None = None,
+    num_kv_heads: int | None = None,
+    dcp_world_size: int | None = None,
+    kv_cache_dtype: str | None = None,
+    q_data_type: torch.dtype | None = None,
+    has_sinks: bool | None = None,
+    has_spec: bool | None = None,
+    silent: bool = False,
 ) -> bool:
-    """Return `True` if TRTLLM attention is used."""
-
-    # CLI argument is set to 0 - respect it
-    if force_use_trtllm is not None and not force_use_trtllm:
-        return False
-
-    # Decode context parallel is not supported
-    if dcp_world_size > 1:
-        logger.warning_once(
-            "Trtllm does not support returning LSE and as a result "
-            "does not support DCP, reverting to FlashInfer"
-        )
-        return False
-
-    # The platform is not supported
-    if not supports_trtllm_attention():
-        if force_use_trtllm:
-            logger.warning_once(
-                "TRTLLM attention is not supported on this platform, "
-                "but --attention-config.use_trtllm_attention is set to 1"
+    """
+    Decide whether to use TRTLLM attention based on these two functions:
+    - check_trtllm_attention_support(): whether TRTLLM attention must or must not be
+      used.
+    - force_use_trtllm_attention(): whether the user wants to force/disable TRTLLM
+      attention.
+    If the decision does not match the user's preference, a warning message is printed.
+
+    Args:
+        is_prefill: Whether it is prefill.
+        num_qo_heads: Number of query heads.
+        num_kv_heads: Number of key/value heads.
+        dcp_world_size: World size of decode context parallel.
+        kv_cache_dtype: Data type of the key/value cache. Could be "auto".
+        q_data_type: Data type of the query.
+        has_sinks: Whether attention sinks are being used.
+        has_spec: Whether speculative decoding is being used.
+        silent: Whether to suppress the warning/info messages.
+
+    If any arg (except for is_prefill) is set to None, the check for that arg is
+    skipped.
+
+    Returns: whether to use TRTLLM attention.
+    """
+    supports_trtllm, reason = check_trtllm_attention_support(
+        is_prefill,
+        num_qo_heads,
+        num_kv_heads,
+        dcp_world_size,
+        kv_cache_dtype,
+        q_data_type,
+        has_sinks,
+        has_spec,
+    )
+    force_use_trtllm = force_use_trtllm_attention()
+    phase_str = "prefill" if is_prefill else "decode"
+    prefix = "[FlashInfer Attention]"
+
+    # Helper functions to print warning/info only when not silent.
+    def print_warning(msg: str):
+        if not silent:
+            logger.warning_once(msg)
+
+    def print_info(msg: str):
+        if not silent:
+            logger.info_once(msg)
+
+    if force_use_trtllm is True:
+        if supports_trtllm is False:
+            print_warning(
+                f"{prefix} Using non-TRTLLM for {phase_str} even though --attention-"
+                f"config.use_trtllm_attention is set to 1. (Reason: {reason})"
             )
-        return False
-
-    # The combination of query and key heads is not supported
-    if num_qo_heads % num_kv_heads != 0:
-        if force_use_trtllm:
-            logger.warning_once(
-                "TRTLLM attention is not supported for this combination of "
-                "query and key heads, but --attention-config.use_trtllm_attention is "
-                "set to 1"
+            return False
+        else:
+            print_info(
+                f"{prefix} Using TRTLLM for {phase_str}. (Reason: --attention-config."
+                f"use_trtllm_attention is set to 1.)"
            )
-        return False
-
-    if has_spec and not is_prefill:
-        # Speculative decoding requires TRTLLM attention for decodes
-        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
-        return True
-
-    # Must use TRTLLM attention if query is FP8 quantized
-    if q_dtype == current_platform.fp8_dtype():
-        logger.info_once("Using TRTLLM attention (query is quantized).")
-        return True
-
-    # If sinks are being used, we must use TRTLLM attention as it's
-    # the only backend that supports them
-    if has_sinks:
-        logger.info_once("Using TRTLLM attention (required for attention sinks).")
-        return True
-
-    if force_use_trtllm is None:
-        # CLI argument not set - use auto-detection
-        if is_prefill:
-            # Prefill auto-detection
-            use_trtllm = kv_cache_dtype == "auto"
-            if use_trtllm:
-                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
+            return True
+    elif force_use_trtllm is False:
+        if supports_trtllm is True:
+            print_warning(
+                f"{prefix} Using TRTLLM for {phase_str} even though --attention-config."
+                f"use_trtllm_attention is set to 0. (Reason: {reason})"
+            )
+            return True
         else:
-            # Decode auto-detection
-            use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
-            if use_trtllm:
-                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
-        return use_trtllm
-
-    # CLI argument is set to 1 - respect it
-    logger.info_once(
-        "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
-    )
-    return True
+            print_info(
+                f"{prefix} Using non-TRTLLM for {phase_str}. (Reason: --attention-"
+                f"config.use_trtllm_attention is set to 0.)"
+            )
+            return False
+    else:
+        if supports_trtllm is True:
+            print_info(f"{prefix} Using TRTLLM for {phase_str}. (Reason: {reason})")
+            return True
+        elif supports_trtllm is False:
+            print_info(f"{prefix} Using non-TRTLLM for {phase_str}. (Reason: {reason})")
+            return False
+        else:
+            print_info(
+                f"{prefix} Using TRTLLM for {phase_str}. (Reason: auto-detected.)"
            )
+            return True
 
 
 if has_flashinfer():
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 4174b80ee312..3d4fdfe61472 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -17,6 +17,7 @@
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
 
+from vllm import _custom_ops as custom_ops
 from vllm import envs
 from vllm.attention.backends.abstract import (
     AttentionBackend,
@@ -42,7 +43,10 @@
 from vllm.platforms.interface import DeviceCapability
 from vllm.triton_utils import tl, triton
 from vllm.utils.flashinfer import (
-    can_use_trtllm_attention,
+    check_trtllm_attention_support,
+    force_use_trtllm_attention,
+    is_sm90_supported,
+    is_sm100f_supported,
     use_trtllm_attention,
 )
 from vllm.utils.math_utils import cdiv
@@ -355,26 +359,28 @@ def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
 
     @classmethod
     def supports_sink(cls) -> bool:
-        """FlashInfer supports sinks when TRTLLM attention is available (SM100)."""
-        from vllm.utils.flashinfer import (
-            force_use_trtllm_attention,
-            supports_trtllm_attention,
-        )
+        """
+        FlashInfer supports sinks when TRTLLM attention is available for both prefill
+        and decode.
+        """
 
-        # Respect explicit disable flag (e.g.,
-        # --attention-config.use_trtllm_attention=0)
+        # If TRTLLM attention is explicitly disabled, sinks are not supported.
         if force_use_trtllm_attention() is False:
             return False
 
-        # Check if TRTLLM is supported on this platform
-        return supports_trtllm_attention()
+        # Check if TRTLLM is supported on this platform for both prefill and decode.
+        prefill_supported, _ = check_trtllm_attention_support(
+            is_prefill=True, has_sinks=True
+        )
+        decode_supported, _ = check_trtllm_attention_support(
+            is_prefill=False, has_sinks=True
+        )
+        return bool(prefill_supported) and bool(decode_supported)
 
     @classmethod
     def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
-        from vllm.platforms import current_platform
-
-        capability = current_platform.get_device_capability()
-        if capability is not None and capability.major == 10:
+        # TRTLLM-GEN attention requires the HND layout.
+        if is_sm100f_supported():
             return "HND"
         return None
 
@@ -383,8 +389,10 @@ def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
 class FlashInferMetadata:
     num_actual_tokens: int  # Number of tokens excluding padding.
 
-    # The data type of the query
-    q_data_type: torch.dtype
+    # The data types of the query for prefill and decode.
+    # On SM90, these two data types may be different.
+    q_data_type_prefill: torch.dtype
+    q_data_type_decode: torch.dtype
 
     slot_mapping: torch.Tensor
@@ -429,7 +437,6 @@ def __init__(
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.cache_config = vllm_config.cache_config
         self.model_config = vllm_config.model_config
-        self.attention_config = vllm_config.attention_config
         self._workspace_buffer = None
         self._prefill_wrapper: (
             BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
@@ -471,52 +478,29 @@ def __init__(
             self.compilation_config.max_cudagraph_capture_size,
         )
 
+        self.dcp_world_size = self.get_dcp_world_size()
         try:
-            self.dcp_world_size = get_dcp_group().world_size
             self.dcp_rank = get_dcp_group().rank_in_group
             self.dcp_kv_cache_interleave_size = (
                 vllm_config.parallel_config.dcp_kv_cache_interleave_size
             )
         except AssertionError:
             # DCP might not be initialized in testing
-            self.dcp_world_size = 1
             self.dcp_rank = 0
             self.dcp_kv_cache_interleave_size = 1
 
-        self.num_qo_heads = self.model_config.get_num_attention_heads(
-            self.vllm_config.parallel_config
-        )
-
-        self.num_kv_heads = self.kv_cache_spec.num_kv_heads
+        self.num_qo_heads = self.get_num_qo_heads(vllm_config)
+        self.num_kv_heads = self.get_num_kv_heads(self.kv_cache_spec)
         self.head_dim = self.kv_cache_spec.head_size
         self.page_size = self.kv_cache_spec.block_size
 
-        self.cache_dtype = self.cache_config.cache_dtype
-        if self.cache_dtype.startswith("fp8"):
-            self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
-                self.cache_dtype
-            )
-        else:
-            assert self.kv_cache_spec.dtype == self.model_config.dtype
-            self.kv_cache_dtype = self.kv_cache_spec.dtype
-
-        # Use model dtype as q dtype when TRTLLM attn is not supported, or
-        # --attention-config.disable_flashinfer_q_quantization is set to 1. Otherwise,
-        # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
-        # if TRTLLM attention kernel is not used when building attn metadata
-        can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
-        if (
-            can_use_trtllm
-            and not vllm_config.attention_config.disable_flashinfer_q_quantization
-        ):
-            self.q_data_type = self.kv_cache_dtype
-        else:
-            self.q_data_type = self.model_config.dtype
-
-        # Prefer TRTLLM attention for decoding in all cases.
-        # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
-        self.use_trtllm_decode_attention = can_use_trtllm
-        self._init_reorder_batch_threshold(1, supports_spec_as_decode=can_use_trtllm)
+        self.kv_cache_dtype = self.get_kv_cache_dtype(vllm_config, self.kv_cache_spec)
+        self.q_data_type_prefill = self.get_q_data_type(
+            vllm_config, self.kv_cache_spec, is_prefill=True
+        )
+        self.q_data_type_decode = self.get_q_data_type(
+            vllm_config, self.kv_cache_spec, is_prefill=False
+        )
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
@@ -529,12 +513,35 @@ def __init__(
         self.window_left = self.global_hyperparameters.window_left
         self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
         self.has_sinks = self.global_hyperparameters.has_sinks
-        if self.has_sinks and not can_use_trtllm:
-            raise NotImplementedError(
-                "FlashInfer backend currently does not support attention "
-                "sinks, please use trtllm on blackwell or flash attention on "
-                "earlier GPUs."
-            )
+        self.has_spec = self.get_has_spec(vllm_config)
+
+        # Decide whether to use TRTLLM attention for prefill and decode.
+        self.prefill_use_trtllm = use_trtllm_attention(
+            is_prefill=True,
+            num_qo_heads=self.num_qo_heads,
+            num_kv_heads=self.num_kv_heads,
+            dcp_world_size=self.dcp_world_size,
+            kv_cache_dtype=self.cache_config.cache_dtype,
+            q_data_type=self.q_data_type_prefill,
+            has_sinks=self.has_sinks,
+            has_spec=self.has_spec,
+        )
+        self.decode_use_trtllm = use_trtllm_attention(
+            is_prefill=False,
+            num_qo_heads=self.num_qo_heads,
+            num_kv_heads=self.num_kv_heads,
+            dcp_world_size=self.dcp_world_size,
+            kv_cache_dtype=self.cache_config.cache_dtype,
+            q_data_type=self.q_data_type_decode,
+            has_sinks=self.has_sinks,
+            has_spec=self.has_spec,
+        )
+
+        # Only TRTLLM attention supports spec-as-decode.
+        self._init_reorder_batch_threshold(
+            1, supports_spec_as_decode=self.decode_use_trtllm
+        )
+
         # Preparing persistent buffers (device-side)
         self.paged_kv_indptr = torch.zeros(
             max_num_reqs + 1, dtype=torch.int32, device=self.device
@@ -564,14 +571,68 @@ def __init__(
         )
         self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
 
-        if self.head_dim == 256 and current_platform.is_device_capability(100):
-            # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
-            # head size 256 and block size 16 is not supported on blackwell.
-            assert kv_cache_spec.block_size != 16, (
-                "There is a bug in FlashInfer "
-                "block_size 16 head size 256 support. Please avoid this combination by "
-                "passing --block-size 32 or --block-size 64."
-            )
+    # The class methods below are helper functions that extract config values
+    # from the configs to avoid duplicated code between __init__() and
+    # get_cudagraph_support().
+    @classmethod
+    def get_num_qo_heads(cls, vllm_config: VllmConfig) -> int:
+        return vllm_config.model_config.get_num_attention_heads(
+            vllm_config.parallel_config
+        )
+
+    @classmethod
+    def get_num_kv_heads(cls, kv_cache_spec: AttentionSpec) -> int:
+        return kv_cache_spec.num_kv_heads
+
+    @classmethod
+    def get_dcp_world_size(cls) -> int:
+        try:
+            return get_dcp_group().world_size
+        except AssertionError:
+            # DCP might not be initialized in testing
+            return 1
+
+    @classmethod
+    def get_kv_cache_dtype(
+        cls, vllm_config: VllmConfig, kv_cache_spec: AttentionSpec
+    ) -> torch.dtype:
+        cache_dtype = vllm_config.cache_config.cache_dtype
+        if cache_dtype.startswith("fp8"):
+            return FlashInferBackend.get_fp8_dtype_for_flashinfer(cache_dtype)
+
+        assert kv_cache_spec.dtype == vllm_config.model_config.dtype
+        return kv_cache_spec.dtype
+
+    @classmethod
+    def get_q_data_type(
+        cls, vllm_config: VllmConfig, kv_cache_spec: AttentionSpec, is_prefill: bool
+    ) -> torch.dtype:
+        # If the user explicitly sets
+        # --attention-config.disable_flashinfer_q_quantization to 1, use the model
+        # dtype for the query.
+        if vllm_config.attention_config.disable_flashinfer_q_quantization:
+            return vllm_config.model_config.dtype
+
+        # On SM90, if the kv-cache is FP8, use FP8-Q for prefill and BF16/FP16-Q for
+        # decode, except when TRTLLM is disabled.
+        cache_dtype = vllm_config.cache_config.cache_dtype
+        if (
+            is_sm90_supported()
+            and not is_prefill
+            and force_use_trtllm_attention() is not False
+            and cache_dtype.startswith("fp8")
+        ):
+            return vllm_config.model_config.dtype
+
+        # Otherwise, always match q dtype to kv cache dtype.
+        return cls.get_kv_cache_dtype(vllm_config, kv_cache_spec)
+
+    @classmethod
+    def get_has_spec(cls, vllm_config: VllmConfig) -> bool:
+        speculative_config = vllm_config.speculative_config
+        return (
+            speculative_config is not None
+            and speculative_config.num_speculative_tokens is not None
+        )
 
     @classmethod
     def get_cudagraph_support(
@@ -579,13 +640,30 @@ def get_cudagraph_support(
         vllm_config: VllmConfig,
         kv_cache_spec: AttentionSpec,
     ) -> AttentionCGSupport:
-        has_trtllm_support = can_use_trtllm_attention(
-            num_qo_heads=vllm_config.model_config.get_num_attention_heads(
-                vllm_config.parallel_config
-            ),
-            num_kv_heads=kv_cache_spec.num_kv_heads,
+        # Extract the information needed to check TRTLLM attention support.
+        num_qo_heads = cls.get_num_qo_heads(vllm_config)
+        num_kv_heads = cls.get_num_kv_heads(kv_cache_spec)
+        dcp_world_size = cls.get_dcp_world_size()
+        kv_cache_dtype = vllm_config.cache_config.cache_dtype
+        q_data_type = cls.get_q_data_type(vllm_config, kv_cache_spec, is_prefill=False)
+        # Setting has_sinks to True would force using TRTLLM attention, so be
+        # conservative here and assume there are no sinks.
+        has_sinks = False
+        has_spec = cls.get_has_spec(vllm_config)
+        decode_use_trtllm = use_trtllm_attention(
+            is_prefill=False,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            dcp_world_size=dcp_world_size,
+            kv_cache_dtype=kv_cache_dtype,
+            q_data_type=q_data_type,
+            has_sinks=has_sinks,
+            has_spec=has_spec,
+            silent=True,
         )
-        if has_trtllm_support:
+        # The difference between UNIFORM_BATCH and UNIFORM_SINGLE_TOKEN_DECODE only
+        # concerns the decode phase.
+        if decode_use_trtllm:
             return AttentionCGSupport.UNIFORM_BATCH
         else:
             return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
@@ -771,29 +849,13 @@ def build(
         )
         uses_spec_reorder = self.reorder_batch_threshold > 1
 
-        prefill_use_trtllm = use_trtllm_attention(
-            self.num_qo_heads,
-            self.num_kv_heads,
-            num_prefill_tokens,
-            max_seq_len,
-            self.dcp_world_size,
-            self.cache_dtype,
-            self.q_data_type,
-            is_prefill=True,
-            force_use_trtllm=self.attention_config.use_trtllm_attention,
-            has_sinks=self.has_sinks,
-            has_spec=uses_spec_reorder,
-        )
-        decode_use_trtllm = (
-            self.use_trtllm_decode_attention and self.dcp_world_size <= 1
-        )
-
-        if not (prefill_use_trtllm and decode_use_trtllm):
+        both_use_trtllm = self.prefill_use_trtllm and self.decode_use_trtllm
+        if not both_use_trtllm:
             if self.has_sinks:
                 raise NotImplementedError(
-                    "FlashInfer backend currently does not support attention "
-                    "sinks, please use trtllm on blackwell or flash attention "
-                    "on earlier GPUs."
+                    "Non-TRTLLM attention backend currently does not support attention "
+                    "sinks, please use a platform supported by TRTLLM attention or use "
+                    "flash attention instead."
                 )
 
             if not self.global_hyperparameters.has_same_window_lefts:
@@ -803,27 +865,24 @@ def build(
                 )
 
             assert self.global_hyperparameters.has_same_all_params, (
-                "FlashInfer backend currently only supports models in which "
+                "Non-TRTLLM attention backend currently only supports models in which "
                 "all layers share the same values for the following "
                 "hyperparameters: `window_left`, `logits_soft_cap`, "
                 "`sm_scale`."
             )
 
-            # The q quantization is not supported for non-trtllm attention,
-            # fall back to model dtype.
-            self.q_data_type = self.model_config.dtype
-
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
-            q_data_type=self.q_data_type,
+            q_data_type_prefill=self.q_data_type_prefill,
+            q_data_type_decode=self.q_data_type_decode,
             slot_mapping=common_attn_metadata.slot_mapping,
             max_q_len=max_q_len,
             max_q_len_prefill=max_q_len,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
-            prefill_use_trtllm=prefill_use_trtllm,
-            decode_use_trtllm=decode_use_trtllm,
+            prefill_use_trtllm=self.prefill_use_trtllm,
+            decode_use_trtllm=self.decode_use_trtllm,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
             num_prefills=num_prefills,
@@ -836,6 +895,9 @@ def build(
 
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
+            # Cascade attention must use the same q dtype for prefill and decode
+            # because it does not support FP8 kv-cache or FP8 query yet.
+            assert self.q_data_type_prefill == self.q_data_type_decode
             attn_metadata.cascade_wrapper.plan(
                 [shared_qo_indptr_cpu, qo_indptr_cpu],
                 [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu],
@@ -849,7 +911,7 @@ def build(
                 sm_scale=self.sm_scale,
                 window_left=self.window_left,
                 logits_soft_cap=self.logits_soft_cap,
-                q_data_type=self.q_data_type,
+                q_data_type=self.q_data_type_prefill,
                 kv_data_type=self.kv_cache_dtype,
             )
         else:
@@ -900,7 +962,7 @@ def build(
                     sm_scale=self.sm_scale,
                     window_left=self.window_left,
                     logits_soft_cap=self.logits_soft_cap,
-                    q_data_type=self.q_data_type,
+                    q_data_type=self.q_data_type_prefill,
                     kv_cache_dtype=self.kv_cache_dtype,
                     prefill_fixed_split_size=self.prefill_fixed_split_size,
                     disable_split_kv=self.disable_split_kv,
@@ -923,8 +985,9 @@ def build(
                     sm_scale=self.sm_scale,
                     window_left=self.window_left,
                     logits_soft_cap=self.logits_soft_cap,
-                    q_data_type=self.q_data_type,
+                    q_data_type=self.q_data_type_prefill,
                     kv_data_type=self.kv_cache_dtype,
+                    o_data_type=self.model_config.dtype,
                     fixed_split_size=self.prefill_fixed_split_size,
                     disable_split_kv=self.disable_split_kv,
                 )
@@ -967,8 +1030,9 @@ def build(
                     sm_scale=self.sm_scale,
                     window_left=self.window_left,
                     logits_soft_cap=self.logits_soft_cap,
-                    q_data_type=self.q_data_type,
+                    q_data_type=self.q_data_type_decode,
                     kv_data_type=self.kv_cache_dtype,
+                    o_data_type=self.model_config.dtype,
                     fixed_split_size=self.decode_fixed_split_size,
                     disable_split_kv=self.disable_split_kv,
                 )
@@ -1039,10 +1103,12 @@ def __init__(
         )
         self.sinks = sinks
 
-        self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
         vllm_config = get_current_vllm_config()
         self.supports_quant_query_input = (
-            self.support_trtllm_attn
+            self.kv_cache_dtype.startswith("fp8")
+            # On SM90, prefill needs an FP8 query but decode needs a BF16/FP16 query,
+            # so this stays False there and the quantization is done inside forward().
+            and is_sm100f_supported()
             and not vllm_config.attention_config.disable_flashinfer_q_quantization
         )
         self.bmm1_scale: float | None = None
@@ -1050,9 +1116,26 @@ def __init__(
         self.o_sf_scale: float | None = None
 
     def fused_output_quant_supported(self, quant_key: QuantKey):
+        prefill_use_trtllm = use_trtllm_attention(
+            is_prefill=True,
+            num_qo_heads=self.num_heads,
+            num_kv_heads=self.num_kv_heads,
+            kv_cache_dtype=self.kv_cache_dtype,
+        )
+        decode_use_trtllm = use_trtllm_attention(
+            is_prefill=False,
+            num_qo_heads=self.num_heads,
+            num_kv_heads=self.num_kv_heads,
+            kv_cache_dtype=self.kv_cache_dtype,
+        )
         return (
-            self.support_trtllm_attn
+            # Only TRTLLM attention supports FP8/NVFP4 output.
+            prefill_use_trtllm
+            and decode_use_trtllm
+            # kv-cache must be FP8.
             and self.kv_cache_dtype.startswith("fp8")
+            # XQA does not support FP8/NVFP4 output.
+            and is_sm100f_supported()
             and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
         )
 
@@ -1061,6 +1144,29 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
         if self.sinks is not None and self.sinks.dtype != torch.float32:
             self.sinks = self.sinks.to(torch.float32)
 
+    # Helper function to quantize the query to the expected dtype if needed.
+    # In general, this quantization should be handled outside of forward(), and
+    # self.supports_quant_query_input should be set to True so that the query arrives
+    # in forward() already quantized. However, if prefill and decode require different
+    # query dtypes, we need to quantize the query inside forward() instead.
+    def maybe_quant_query(
+        self, query: torch.Tensor, q_data_type: torch.dtype, scale: torch.Tensor
+    ) -> torch.Tensor:
+        if query.dtype != q_data_type:
+            assert query.dtype in [torch.float16, torch.bfloat16]
+            assert q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]
+            assert query.is_contiguous()
+            assert query.dim() == 3
+            num_tokens = query.shape[0]
+            num_heads = query.shape[1]
+            head_size = query.shape[2]
+            query_quantized, _ = custom_ops.scaled_fp8_quant(
+                query.view(num_tokens, num_heads * head_size), scale=scale
+            )
+            return query_quantized.view(num_tokens, num_heads, head_size)
+
+        return query
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -1092,12 +1198,6 @@ def forward(
             # Profiling run.
             return output.fill_(0)
 
-        # Ensure query dtype matches the expected dtype from attention metadata
-        assert attn_metadata.q_data_type == query.dtype, (
-            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
-            f"got {query.dtype}"
-        )
-
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
 
@@ -1110,7 +1210,10 @@ def forward(
                     "output_block_scale is not supported when fusion has not happened"
                 )
             else:
-                assert attn_metadata.q_data_type == FP8_DTYPE, (
+                assert attn_metadata.q_data_type_prefill == FP8_DTYPE, (
+                    "Query must be FP8 when attn+quant fusion happened."
+                )
+                assert attn_metadata.q_data_type_decode == FP8_DTYPE, (
                     "Query must be FP8 when attn+quant fusion happened."
                 )
                 assert (
@@ -1205,6 +1308,11 @@ def forward(
             assert prefill_query.shape[0] == num_prefill_tokens
             assert prefill_wrapper is not None
 
+            # Convert query to the expected dtype for prefill if needed.
+            prefill_query = self.maybe_quant_query(
+                prefill_query, attn_metadata.q_data_type_prefill, layer._q_scale
+            )
+
             if not attn_metadata.prefill_use_trtllm:
                 if self.dcp_world_size > 1:
                     assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
@@ -1242,6 +1350,7 @@ def forward(
                    prefill_wrapper.run(
                        prefill_query,
                        kv_cache_permute,
+                        q_scale=layer._q_scale_float,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output[num_decode_tokens:],
@@ -1274,7 +1383,7 @@ def forward(
                 out = output[num_decode_tokens:]
 
                 if (
-                    attn_metadata.q_data_type != FP8_DTYPE
+                    attn_metadata.q_data_type_prefill != FP8_DTYPE
                     and self.kv_cache_dtype.startswith("fp8")
                 ):
                     # TRTLLM prefill attention does not support BF16 Q
@@ -1286,7 +1395,7 @@ def forward(
                         block_tables_prefill,
                         layer._k_scale,
                         layer._v_scale,
-                        attn_metadata.q_data_type,
+                        attn_metadata.q_data_type_prefill,
                     )
                 else:
                     mock_kv_cache = kv_cache_permute
@@ -1317,6 +1426,11 @@ def forward(
             assert decode_query.shape[0] == num_decode_tokens
             assert decode_wrapper is not None
 
+            # Convert query to the expected dtype for decode if needed.
+            decode_query = self.maybe_quant_query(
+                decode_query, attn_metadata.q_data_type_decode, layer._q_scale
+            )
+
             if not attn_metadata.decode_use_trtllm:
                 assert decode_wrapper._window_left == self.window_left
                 assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
@@ -1335,6 +1449,7 @@ def forward(
                    decode_wrapper.run(
                        decode_query,
                        kv_cache_permute,
+                        q_scale=layer._q_scale_float,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output_tmp,
@@ -1351,6 +1466,7 @@ def forward(
                    decode_wrapper.run(
                        decode_query,
                        kv_cache_permute,
+                        q_scale=layer._q_scale_float,
                        k_scale=layer._k_scale_float,
                        v_scale=layer._v_scale_float,
                        out=output[:num_decode_tokens],
@@ -1364,8 +1480,10 @@ def forward(
                 ]
                 seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
 
-                # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
-                assert get_kv_cache_layout() == "HND"
+                # TRTLLM attention needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
+                # on sm100f GPUs.
+                if is_sm100f_supported():
+                    assert get_kv_cache_layout() == "HND"
                 assert decode_query.is_contiguous()
                 assert kv_cache_permute.is_contiguous()
                 assert workspace_buffer.is_contiguous()
@@ -1404,6 +1522,7 @@ def forward(
                     sinks=self.sinks,
                     o_sf_scale=self.o_sf_scale,
                     out=out,
+                    kv_layout=get_kv_cache_layout(),
                     q_len_per_req=q_len_per_req,
                 )
         return output_padded
@@ -1424,6 +1543,7 @@ def fast_plan_decode(
     logits_soft_cap: float | None = None,
    q_data_type: str | torch.dtype | None = "float16",
    kv_data_type: str | torch.dtype | None = None,
+    o_data_type: str | torch.dtype | None = None,
    data_type: str | torch.dtype | None = None,
    sm_scale: float | None = None,
    rope_scale: float | None = None,
@@ -1462,6 +1582,7 @@ def fast_plan_decode(
             logits_soft_cap,
             q_data_type,
             kv_data_type,
+            o_data_type,
             data_type,
             sm_scale,
             rope_scale,
@@ -1481,24 +1602,6 @@ def fast_plan_decode(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
 
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-    q_data_type = (
-        getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-    )
-    kv_data_type = (
-        getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
-    )
-
     if batch_size != self._fixed_batch_size:
         raise ValueError(
             "The batch size should be fixed in cudagraph mode, the runtime "
@@ -1518,8 +1621,9 @@ def fast_plan_decode(
     qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
     try:
-        # Make sure we pass exactly 19 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Make sure we pass exactly 19 arguments for the fa2 backend and 16 arguments
+        # for the fa3 backend.
+        args = [
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
            qo_indptr_host,
            indptr_cpu,
            seq_lens_cpu,
            batch_size,  # batch_size
            num_qo_heads,
            num_kv_heads,
            page_size,
            self.is_cuda_graph_enabled,
            head_dim,
            head_dim,
            False,  # causal
            window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
        )
    except Exception as e:
        raise RuntimeError(f"Error in tensor core plan: {e}") from e
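
For reference, the force/support resolution implemented by use_trtllm_attention() reduces to a small truth table. The sketch below is illustrative only and assumes the semantics documented in the PR: the standalone resolve() helper and its argument names are not part of this change; "force" stands for --attention-config.use_trtllm_attention (None = unset) and "supports" for the first element returned by check_trtllm_attention_support() (True = must use, False = must not use, None = can use).

# Illustrative sketch only: mirrors the decision logic of use_trtllm_attention().
# resolve() is a hypothetical helper, not part of this PR.


def resolve(force: bool | None, supports: bool | None) -> bool:
    if force is True:
        # Respect the flag unless the config makes TRTLLM attention impossible.
        return supports is not False
    if force is False:
        # Respect the flag unless the config makes TRTLLM attention mandatory.
        return supports is True
    # Flag unset: follow the support check, defaulting to TRTLLM attention.
    return supports is not False


assert resolve(None, None) is True    # auto-detected default
assert resolve(None, False) is False  # e.g. dcp_world_size > 1
assert resolve(False, True) is True   # e.g. attention sinks override the flag
assert resolve(True, False) is False  # e.g. an unsupported SM overrides the flag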