diff --git a/README.md b/README.md index 13c1ac75..0e67e552 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ In theory, vllm-plugin-FL can support all models available in vLLM, as long as n ### Setup -1. Install vllm from the official [v0.13.0](https://github.com/vllm-project/vllm/tree/v0.13.0) (optional if the correct version is installed) or from the fork [vllm-FL](https://github.com/flagos-ai/vllm-FL). +1. Install vllm from the official [v0.18.1](https://github.com/vllm-project/vllm/tree/v0.18.1) (optional if the correct version is installed) or from the fork [vllm-FL](https://github.com/flagos-ai/vllm-FL). 2. Install vllm-plugin-FL @@ -66,6 +66,7 @@ In theory, vllm-plugin-FL can support all models available in vLLM, as long as n ```sh git clone https://github.com/flagos-ai/FlagGems cd FlagGems + git checkout v5.0.0 pip install --no-build-isolation . # or editble install diff --git a/tests/unit_tests/ops/test_layernorm.py b/tests/unit_tests/ops/test_layernorm.py index 4119a3d0..04acfe5a 100644 --- a/tests/unit_tests/ops/test_layernorm.py +++ b/tests/unit_tests/ops/test_layernorm.py @@ -13,6 +13,11 @@ class TestRMSNormFL: """Test RMSNormFL class behavior.""" + @pytest.fixture(autouse=True) + def _vllm_config(self): + from vllm.config import VllmConfig, set_current_vllm_config + with set_current_vllm_config(VllmConfig()): + yield @pytest.fixture def mock_call_op(self): with patch("vllm_fl.ops.layernorm.call_op") as mock: diff --git a/tests/unit_tests/worker/test_model_runner.py b/tests/unit_tests/worker/test_model_runner.py index 2e844470..d1e8f998 100644 --- a/tests/unit_tests/worker/test_model_runner.py +++ b/tests/unit_tests/worker/test_model_runner.py @@ -60,6 +60,7 @@ def test_fields_match_expected_contract(self): "aux_hidden_states", "ec_connector_output", "cudagraph_stats", + "slot_mappings", ) assert ExecuteModelState._fields == expected_fields, ( "ExecuteModelState fields changed - this may break execute_model consumers" ) @@ -79,6 +80,7 @@ def
test_immutability_prevents_accidental_mutation(self): aux_hidden_states=None, ec_connector_output=None, cudagraph_stats=None, + slot_mappings=None, ) with pytest.raises(AttributeError): @@ -101,6 +103,7 @@ def test_unpacking_for_downstream_processing(self): aux_hidden_states=None, ec_connector_output=None, cudagraph_stats=None, + slot_mappings=None, ) # Simulate downstream unpacking diff --git a/vllm_fl/__init__.py b/vllm_fl/__init__.py index cb612ed9..f53973ce 100644 --- a/vllm_fl/__init__.py +++ b/vllm_fl/__init__.py @@ -45,75 +45,19 @@ def register(): def register_model(): - """Register the FL model.""" - from vllm import ModelRegistry - import vllm.model_executor.models.qwen3_next as qwen3_next_module + """Register FL-specific models not yet upstream.""" + # Models now upstream in vLLM v0.18.1 (no longer need plugin registration): + # Qwen3NextForCausalLM, Qwen3_5MoeForConditionalGeneration, + # MiniCPMO, KimiK25ForConditionalGeneration, Qwen3_5MoeConfig - # Register Qwen3.5 MoE config - try: - from vllm.transformers_utils.config import _CONFIG_REGISTRY - from vllm_fl.configs.qwen3_5_moe import Qwen3_5MoeConfig - _CONFIG_REGISTRY["qwen3_5_moe"] = Qwen3_5MoeConfig - except Exception as e: - logger.error(f"Register Qwen3.5 MoE config error: {str(e)}") - - # Register Qwen3Next model - try: - from vllm_fl.models.qwen3_next import Qwen3NextForCausalLM # noqa: F401 - - qwen3_next_module.Qwen3NextForCausalLM = Qwen3NextForCausalLM - logger.warning( - "Qwen3NextForCausalLM has been patched to use vllm_fl.models.qwen3_next, " - "original vLLM implementation is overridden" - ) - - ModelRegistry.register_model( - "Qwen3NextForCausalLM", - "vllm_fl.models.qwen3_next:Qwen3NextForCausalLM" - ) - except Exception as e: - logger.error(f"Register Qwen3Next model error: {str(e)}") - - # Register Qwen3.5 MoE model - try: - ModelRegistry.register_model( - "Qwen3_5MoeForConditionalGeneration", - "vllm_fl.models.qwen3_5:Qwen3_5MoeForConditionalGeneration" - ) - except Exception as 
e: - logger.error(f"Register Qwen3.5 MoE model error: {str(e)}") - - # Register MiniCPMO model - try: - ModelRegistry.register_model( - "MiniCPMO", - "vllm_fl.models.minicpmo:MiniCPMO" - ) - except Exception as e: - logger.error(f"Register MiniCPMO model error: {str(e)}") - - # Register Kimi-K2.5 model - try: - ModelRegistry.register_model( - "KimiK25ForConditionalGeneration", - "vllm_fl.models.kimi_k25:KimiK25ForConditionalGeneration", - ) - except Exception as e: - logger.error(f"Register KimiK25 model error: {str(e)}") - - # Register GLM-5 (GlmMoeDsa) model + # Register GLM-5 (GlmMoeDsa) — config not yet upstream try: from vllm.transformers_utils.config import _CONFIG_REGISTRY from vllm_fl.configs.glm_moe_dsa import GlmMoeDsaConfig _CONFIG_REGISTRY["glm_moe_dsa"] = GlmMoeDsaConfig - from vllm_fl.patches.glm_moe_dsa import apply_model_patches as glm5_model - glm5_model() - - ModelRegistry.register_model( - "GlmMoeDsaForCausalLM", - "vllm_fl.models.glm_moe_dsa:GlmMoeDsaForCausalLM" - ) + #from vllm_fl.patches.glm_moe_dsa import apply_model_patches as glm5_model + #glm5_model() except Exception as e: logger.error(f"Register GlmMoeDsa model error: {str(e)}") diff --git a/vllm_fl/attention/utils.py b/vllm_fl/attention/utils.py index 642e7800..88982dfc 100644 --- a/vllm_fl/attention/utils.py +++ b/vllm_fl/attention/utils.py @@ -13,8 +13,8 @@ def patch_mm_encoder_attention(): FLASH_ATTN branch to import directly from vllm.vllm_flash_attn with a fallback to flash_attn. 
""" - import vllm.attention.layers.mm_encoder_attention as mm_mod - from vllm.attention.backends.registry import AttentionBackendEnum + import vllm.model_executor.layers.attention.mm_encoder_attention as mm_mod + from vllm.v1.attention.backends.registry import AttentionBackendEnum def _patched_maybe_get_vit_flash_attn_backend(attn_backend): if attn_backend == AttentionBackendEnum.FLASH_ATTN: diff --git a/vllm_fl/compilation/graph.py b/vllm_fl/compilation/graph.py index 2ff4eb8e..111aa9d6 100644 --- a/vllm_fl/compilation/graph.py +++ b/vllm_fl/compilation/graph.py @@ -1,14 +1,15 @@ # Copyright (c) 2025 BAAI. All rights reserved. -# Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/compilation/cuda_graph.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/compilation/cuda_graph.py # Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +import weakref from collections import Counter from collections.abc import Callable from contextlib import ExitStack -from typing import Any, Optional +from typing import Any, ClassVar from unittest.mock import patch import torch @@ -18,13 +19,18 @@ from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id -from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.forward_context import ( + BatchDescriptor, + get_forward_context, + is_forward_context_available, +) from vllm.logger import init_logger from vllm.platforms import current_platform logger = init_logger(__name__) +# FL-specific: platform-agnostic weak_ref_tensors def weak_ref_tensors(tensor: Any) -> Any: if current_platform.device_type == "cuda": from vllm.utils.torch_utils import weak_ref_tensors @@ -34,6 +40,7 @@ def weak_ref_tensors(tensor: Any) -> Any: 
return tensor +# FL-specific: platform-agnostic graph class selection class Graph: if current_platform.device_type == "cuda": graph = torch.cuda.CUDAGraph @@ -44,15 +51,21 @@ class Graph: else: raise NotImplementedError("not support graph") + +# Re-export CUDAGraphStat for compatibility +from vllm.compilation.cuda_graph import CUDAGraphStat # noqa: F401, E402 + + @dataclasses.dataclass class GraphEntry: batch_descriptor: BatchDescriptor - graph: Optional[Graph] = None - output: Optional[Any] = None + graph: Any | None = None + output: Any | None = None # for graph debugging, track the input addresses # during capture, and check if they are the same during replay - input_addresses: Optional[list[int]] = None + input_addresses: list[int] | None = None + @dataclasses.dataclass class GraphOptions: @@ -62,11 +75,22 @@ class GraphOptions: class GraphWrapper: + """FL-specific graph wrapper that supports multiple device types (CUDA, NPU). + Adapted from upstream CUDAGraphWrapper with platform-agnostic graph capture.""" + + _all_instances: ClassVar[weakref.WeakSet["GraphWrapper"]] = weakref.WeakSet() + + @classmethod + def clear_all_graphs(cls) -> None: + """Clear captured graphs from all GraphWrapper instances.""" + for instance in list(cls._all_instances): + instance.clear_graphs() + def __init__(self, runnable: Callable, vllm_config: VllmConfig, runtime_mode: CUDAGraphMode, - cudagraph_options: Optional[GraphOptions] = None): + cudagraph_options: GraphOptions | None = None): self.runnable = runnable self.vllm_config = vllm_config self.runtime_mode = runtime_mode @@ -74,9 +98,10 @@ def __init__(self, self.first_run_finished = False self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + self._runnable_str = str(runnable) if self.is_debugging_mode else None # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't - # need to initialize a CUDAGraphWrapper. + # need to initialize a GraphWrapper. 
assert self.runtime_mode != CUDAGraphMode.NONE # TODO: in the future, if we want to use multiple # streams, it might not be safe to share a global pool. @@ -85,25 +110,41 @@ def __init__(self, if cudagraph_options is None: cudagraph_options = GraphOptions() - self.graph_options = cudagraph_options + self.cudagraph_options = cudagraph_options # the entries for different batch descriptors that we need to capture # cudagraphs for. self.concrete_graph_entries: dict[BatchDescriptor, GraphEntry] = {} - def __getattr__(self, key: str): + GraphWrapper._all_instances.add(self) + + def __getattr__(self, key: str) -> Any: # allow accessing the attributes of the runnable. if hasattr(self.runnable, key): return getattr(self.runnable, key) - raise AttributeError( - f"Attribute {key} not exists in the runnable of " - f"cudagraph wrapper: {self.runnable}" - ) + if self.is_debugging_mode: + raise AttributeError( + f"Attribute {key} not exists in the runnable of " + f"cudagraph wrapper: {self._runnable_str}" + ) + raise AttributeError def unwrap(self) -> Callable: # in case we need to access the original runnable. return self.runnable + @property + def cudagraph_wrapper(self) -> "GraphWrapper": + return self + + def clear_graphs(self) -> None: + self.concrete_graph_entries.clear() + def __call__(self, *args, **kwargs): + if not is_forward_context_available(): + # No forward context means we are outside the normal + # inference path (e.g. a vision encoder forward pass). + return self.runnable(*args, **kwargs) + forward_context = get_forward_context() batch_descriptor = forward_context.batch_descriptor graph_runtime_mode = forward_context.cudagraph_runtime_mode @@ -112,14 +153,9 @@ def __call__(self, *args, **kwargs): graph_runtime_mode == CUDAGraphMode.NONE or graph_runtime_mode != self.runtime_mode ): - # CUDAGraphMode.NONE could mean the profile run, a warmup run, or - # running without cudagraphs. - # We do not trigger capture/replay if the runtime mode is not - # matches. 
This enables properly dispatching to the correct - # CUDAGraphWrapper when nesting multiple instances with different - # runtime modes. return self.runnable(*args, **kwargs) + assert batch_descriptor is not None if batch_descriptor not in self.concrete_graph_entries: # create a new entry for this batch descriptor self.concrete_graph_entries[batch_descriptor] = GraphEntry( @@ -129,11 +165,7 @@ def __call__(self, *args, **kwargs): entry = self.concrete_graph_entries[batch_descriptor] if entry.graph is None: - if self.graph_options.debug_log_enable: - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every - # shape. E.g. we only log it for the first subgraph in - # piecewise mode. + if self.cudagraph_options.debug_log_enable: logger.debug( "Capturing a cudagraph on (%s,%s)", self.runtime_mode.name, @@ -149,32 +181,40 @@ def __call__(self, *args, **kwargs): graph = Graph.graph() with ExitStack() as stack: - if self.graph_options.gc_disable: - # during every model forward for piecewise graph - # mode, we will capture many pieces of graphs - # (roughly one per layer). running gc again and again - # across layers will make the graph capture very slow. - # therefore, we only run gc for the first graph, - # and disable gc for the rest of the graphs. + if self.cudagraph_options.gc_disable: stack.enter_context(patch("gc.collect", lambda: None)) + # FL-specific: patch our platform's empty_cache stack.enter_context( - patch("vllm_fl.platform.PlatformFL.empty_cache", lambda: None) + patch("vllm_fl.platform.PlatformFL.empty_cache", + lambda: None) ) - set_graph_pool_id(self.graph_pool) - - # mind-exploding: carefully manage the reference and memory. 
- with current_platform.torch_device_fn.graph(graph, pool=self.graph_pool): + if self.graph_pool is not None: + set_graph_pool_id(self.graph_pool) + else: + set_graph_pool_id(current_platform.graph_pool_handle()) + + # Sync offloader's copy stream before capture if available. + try: + from vllm.model_executor.offloader.base import get_offloader + get_offloader().sync_prev_onload() + except (ImportError, RuntimeError): + pass + + # FL-specific: use platform-agnostic graph capture + with current_platform.torch_device_fn.graph( + graph, pool=self.graph_pool + ): # `output` is managed by pytorch's cudagraph pool output = self.runnable(*args, **kwargs) - if self.graph_options.weak_ref_output: - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. It is only safe to do this for - # the last graph in piecewise cuadgraph mode, because - # the output of the last graph will not be used by - # any other cuda graph. - output = weak_ref_tensors(output) + # Join offloader's copy stream after forward if available + try: + from vllm.model_executor.offloader.base import get_offloader + get_offloader().join_after_forward() + except (ImportError, RuntimeError): + pass + if self.cudagraph_options.weak_ref_output: + output = weak_ref_tensors(output) entry.output = weak_ref_tensors(output) entry.graph = graph @@ -197,6 +237,13 @@ def __call__(self, *args, **kwargs): f"got {new_input_addresses}" ) + # Sync offloader before replay if available + try: + from vllm.model_executor.offloader.base import get_offloader + get_offloader().sync_prev_onload() + except (ImportError, RuntimeError): + pass + current_platform.torch_device_fn.synchronize() entry.graph.replay() return entry.output diff --git a/vllm_fl/configs/qwen3_5_moe.py b/vllm_fl/configs/qwen3_5_moe.py deleted file mode 100644 index 48d68865..00000000 --- a/vllm_fl/configs/qwen3_5_moe.py +++ /dev/null @@ -1,185 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright 2025 The 
Qwen Team and The HuggingFace Inc. team. -# All rights reserved. -"""Qwen3.5-MoE model configuration for vLLM plugin.""" - -from transformers.configuration_utils import PretrainedConfig - - -def _layer_type_validation(layer_types, num_hidden_layers): - if layer_types is not None and num_hidden_layers is not None: - if len(layer_types) != num_hidden_layers: - raise ValueError( - f"Length of layer_types ({len(layer_types)}) must match " - f"num_hidden_layers ({num_hidden_layers})" - ) - - -class Qwen3_5MoeTextConfig(PretrainedConfig): - model_type = "qwen3_5_moe_text" - keys_to_ignore_at_inference = ["past_key_values"] - base_config_key = "text_config" - - def __init__( - self, - vocab_size=248320, - hidden_size=2048, - num_hidden_layers=40, - num_attention_heads=16, - num_key_value_heads=2, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_parameters=None, - attention_bias=False, - attention_dropout=0.0, - head_dim=256, - linear_conv_kernel_dim=4, - linear_key_head_dim=128, - linear_value_head_dim=128, - linear_num_key_heads=16, - linear_num_value_heads=32, - moe_intermediate_size=512, - shared_expert_intermediate_size=512, - num_experts_per_tok=8, - num_experts=256, - norm_topk_prob=True, - output_router_logits=False, - router_aux_loss_coef=0.001, - layer_types=None, - full_attention_interval=4, - attn_output_gate=True, - pad_token_id=None, - bos_token_id=None, - eos_token_id=None, - **kwargs, - ): - kwargs.pop("ignore_keys_at_rope_validation", None) - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - 
self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.head_dim = head_dim - self.rope_parameters = rope_parameters - self.attn_output_gate = attn_output_gate - - self.layer_types = layer_types - if self.layer_types is None: - self.layer_types = [ - "linear_attention" - if bool((i + 1) % full_attention_interval) - else "full_attention" - for i in range(self.num_hidden_layers) - ] - _layer_type_validation(self.layer_types, self.num_hidden_layers) - - self.linear_conv_kernel_dim = linear_conv_kernel_dim - self.linear_key_head_dim = linear_key_head_dim - self.linear_value_head_dim = linear_value_head_dim - self.linear_num_key_heads = linear_num_key_heads - self.linear_num_value_heads = linear_num_value_heads - self.moe_intermediate_size = moe_intermediate_size - self.shared_expert_intermediate_size = shared_expert_intermediate_size - self.num_experts_per_tok = num_experts_per_tok - self.num_experts = num_experts - self.norm_topk_prob = norm_topk_prob - self.output_router_logits = output_router_logits - self.router_aux_loss_coef = router_aux_loss_coef - self.full_attention_interval = full_attention_interval - - # partial_rotary_factor is needed by rope - kwargs.setdefault("partial_rotary_factor", 0.25) - - super().__init__(**kwargs) - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.tie_word_embeddings = tie_word_embeddings - - -class Qwen3_5MoeVisionConfig(PretrainedConfig): - model_type = "qwen3_5_moe" - base_config_key = "vision_config" - - def __init__( - self, - depth=27, - hidden_size=1152, - hidden_act="gelu_pytorch_tanh", - intermediate_size=4304, - num_heads=16, - in_channels=3, - patch_size=16, - spatial_merge_size=2, - temporal_patch_size=2, - out_hidden_size=3584, - num_position_embeddings=2304, - initializer_range=0.02, - deepstack_visual_indexes=None, - **kwargs, - ): - super().__init__(**kwargs) - self.depth = depth - self.hidden_size = hidden_size - 
self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.num_heads = num_heads - self.in_channels = in_channels - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size - self.out_hidden_size = out_hidden_size - self.num_position_embeddings = num_position_embeddings - self.initializer_range = initializer_range - self.deepstack_visual_indexes = deepstack_visual_indexes or [] - - -class Qwen3_5MoeConfig(PretrainedConfig): - model_type = "qwen3_5_moe" - sub_configs = { - "vision_config": Qwen3_5MoeVisionConfig, - "text_config": Qwen3_5MoeTextConfig, - } - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - text_config=None, - vision_config=None, - image_token_id=248056, - video_token_id=248057, - vision_start_token_id=248053, - vision_end_token_id=248054, - tie_word_embeddings=False, - **kwargs, - ): - if isinstance(vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**vision_config) - elif vision_config is None: - self.vision_config = self.sub_configs["vision_config"]() - - if isinstance(text_config, dict): - self.text_config = self.sub_configs["text_config"](**text_config) - elif text_config is None: - self.text_config = self.sub_configs["text_config"]() - - self.image_token_id = image_token_id - self.video_token_id = video_token_id - self.vision_start_token_id = vision_start_token_id - self.vision_end_token_id = vision_end_token_id - super().__init__(**kwargs) - self.tie_word_embeddings = tie_word_embeddings - - -__all__ = ["Qwen3_5MoeConfig", "Qwen3_5MoeTextConfig", "Qwen3_5MoeVisionConfig"] diff --git a/vllm_fl/dispatch/backends/flaggems/flaggems.py b/vllm_fl/dispatch/backends/flaggems/flaggems.py index e12afaf7..a7b166ea 100644 --- a/vllm_fl/dispatch/backends/flaggems/flaggems.py +++ b/vllm_fl/dispatch/backends/flaggems/flaggems.py @@ -144,7 +144,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool 
= False) -> Returns: Fully qualified class path string """ - from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.v1.attention.backends.registry import AttentionBackendEnum # TritonAttentionBackend requires CUDA, check if available if not torch.cuda.is_available(): diff --git a/vllm_fl/dispatch/backends/flaggems/impl/attention.py b/vllm_fl/dispatch/backends/flaggems/impl/attention.py index 44c751c4..d7e14dea 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/attention.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/attention.py @@ -11,27 +11,29 @@ import torch from vllm import envs -from vllm.attention.backends.abstract import ( +from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionType, MultipleOf, is_quantized_kv_cache, ) -from vllm.attention.layer import Attention -from vllm.attention.ops.common import cp_lse_ag_out_rs -from vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.model_executor.layers.attention.attention import Attention +from vllm.v1.attention.ops.common import cp_lse_ag_out_rs +from vllm.v1.attention.ops.merge_attn_states import merge_attn_states from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger -from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant +from vllm.model_executor.layers.batch_invariant import _batch_invariant_MODE as _bi_mode from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import ( +from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, +) +from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, get_dcp_local_seq_lens, get_kv_cache_layout, @@ -457,7 +459,7 @@ def __init__( self.attn_type = attn_type self.vllm_flash_attn_version = 3 # 2 #get_flash_attn_version() # Cache the batch invariant 
result for use in forward passes - self.batch_invariant_enabled = vllm_is_batch_invariant() + self.batch_invariant_enabled = _bi_mode if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( diff --git a/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py b/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py index 5aa2860d..f6dc37e3 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py @@ -1,4 +1,4 @@ -from vllm.attention.backends.registry import ( +from vllm.v1.attention.backends.registry import ( AttentionBackendEnum, register_backend, ) diff --git a/vllm_fl/dispatch/backends/flaggems/impl/mla.py b/vllm_fl/dispatch/backends/flaggems/impl/mla.py index 866dcd51..1789498d 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/mla.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/mla.py @@ -8,7 +8,7 @@ import torch -from vllm.attention.backends.abstract import ( +from vllm.v1.attention.backend import ( AttentionLayer, AttentionType, is_quantized_kv_cache, @@ -17,7 +17,7 @@ # from vllm.attention.ops.triton_decode_attention import decode_attention_fwd # from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import ( +from vllm.model_executor.layers.attention.mla_attention import ( MLACommonBackend, MLACommonImpl, MLACommonMetadata, diff --git a/vllm_fl/dispatch/backends/reference/reference.py b/vllm_fl/dispatch/backends/reference/reference.py index 7f20276c..9c10a2e5 100644 --- a/vllm_fl/dispatch/backends/reference/reference.py +++ b/vllm_fl/dispatch/backends/reference/reference.py @@ -150,7 +150,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) -> Fully qualified class path string (vLLM native backend) """ # Return vLLM's native flash attention backend as reference - from vllm.attention.backends.registry import 
AttentionBackendEnum + from vllm.v1.attention.backends.registry import AttentionBackendEnum if use_mla: # vLLM native MLA backend diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py index 4dcdd7d7..494cf508 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py @@ -28,7 +28,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import ( +from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionLayer, @@ -36,7 +36,8 @@ ) from vllm.config import VllmConfig, get_current_vllm_config from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import AttentionCGSupport, CommonAttentionMetadata +from vllm.v1.attention.backend import AttentionCGSupport +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm_fl.dispatch.backends.vendor.ascend.impl.attention_mask import ( AttentionMaskBuilder, diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py index 2aad980a..a6fa90f2 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py @@ -15,7 +15,7 @@ import torch.nn.functional as F import triton import triton.language as tl -from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.v1.attention.backends.utils import PAD_SLOT_ID def causal_conv1d_ref( diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py index 4e8c0758..78dfcd24 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py @@ -21,7 +21,7 @@ import torch import torch.nn.functional as F import torch_npu 
-from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention +from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention from vllm.config import MultiModalConfig MIN_PAD_SIZE = 64 # min_size to pad weight diff --git a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py index 8260f067..e504559d 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -174,7 +174,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) -> Returns: Fully qualified class path string """ - from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.v1.attention.backends.registry import AttentionBackendEnum if use_mla: if use_sparse: diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py index 20f67053..049deb3d 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py @@ -22,11 +22,9 @@ def silu_and_mul_cuda(obj, x: torch.Tensor) -> torch.Tensor: Returns: Output tensor of shape [..., d] """ - from vllm._custom_ops import silu_and_mul as vllm_silu_and_mul - d = x.shape[-1] // 2 out = torch.empty(*x.shape[:-1], d, dtype=x.dtype, device=x.device) - vllm_silu_and_mul(out, x) + torch.ops._C.silu_and_mul(out, x) return out diff --git a/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py b/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py index 4d7d1449..d4896cd4 100644 --- a/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py +++ b/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py @@ -154,7 +154,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) -> Returns: Fully qualified class path string """ - from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.v1.attention.backends.registry import 
AttentionBackendEnum if use_mla: if use_sparse: diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py index 35cdd6a9..7d10e52f 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py @@ -9,15 +9,15 @@ import numpy as np import torch -from vllm.attention.backends.abstract import ( +from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionType, MultipleOf, is_quantized_kv_cache, ) -from vllm.attention.layer import Attention -from vllm.attention.ops.common import cp_lse_ag_out_rs +from vllm.model_executor.layers.attention.attention import Attention +from vllm.v1.attention.ops.common import cp_lse_ag_out_rs # -------------------------------------------------------------- # Note: use Maca's merge_attn_states to get cuda kernel invoked @@ -42,13 +42,15 @@ from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( - vllm_is_batch_invariant, + _batch_invariant_MODE as _bi_mode, ) from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import ( +from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, +) +from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, get_dcp_local_seq_lens, get_kv_cache_layout, @@ -56,7 +58,7 @@ reshape_attn_output_for_spec_decode, # used for prefill decode split with mtp reshape_query_for_spec_decode, # used for prefill decode split with mtp ) -from vllm.attention.backends.registry import AttentionBackendEnum, register_backend +from vllm.v1.attention.backends.registry import AttentionBackendEnum, register_backend from vllm.v1.kv_cache_interface import AttentionSpec # 
-------------------------------------------------------------- @@ -450,7 +452,7 @@ def build( prefill_block_table_tensor = None # \------------------------- Metax Modification -------------------------/ - if vllm_is_batch_invariant(): + if _bi_mode: max_num_splits = 1 def schedule( @@ -645,7 +647,7 @@ def __init__( self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() # Cache the batch invariant result for use in forward passes - self.batch_invariant_enabled = vllm_is_batch_invariant() + self.batch_invariant_enabled = _bi_mode if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): raise NotImplementedError( diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py index 3d06c6c3..94ec2e51 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py @@ -198,13 +198,13 @@ from vllm import _custom_ops as ops from vllm import envs -from vllm.attention.backends.abstract import ( +from vllm.v1.attention.backend import ( AttentionBackend, AttentionLayer, MLAAttentionImpl, ) -from vllm.attention.backends.utils import get_mla_dims -from vllm.attention.ops.common import cp_lse_ag_out_rs +from vllm.model_executor.layers.attention.mla_attention import get_mla_dims +from vllm.v1.attention.ops.common import cp_lse_ag_out_rs # -------------------------------------------------------------- # Note: use Maca's merge_attn_states to get cuda kernel invoked @@ -214,7 +214,7 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import get_dcp_group from vllm.model_executor.layers.batch_invariant import ( - vllm_is_batch_invariant, + _batch_invariant_MODE as _bi_mode, ) from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -272,7 +272,9 @@ class QueryLenSupport(Enum): 
flashinfer_available = False -from vllm.v1.attention.backends.mla.common import logger +from vllm.logger import init_logger + +logger = init_logger(__name__) CUDNN_WORKSPACE_SIZE = 12800 @@ -1276,7 +1278,7 @@ def _flash_attn_varlen_diff_headdims( # ROCm leverages the upstream flash_attn, which takes a parameter # called "return_attn_probs" instead of return_softmax_lse kwargs["return_attn_probs"] = return_softmax_lse - if vllm_is_batch_invariant(): + if _bi_mode: kwargs["num_splits"] = 1 attn_out = self.flash_attn_varlen_func( diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py index 823ff4ba..6f10c44e 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py @@ -7,7 +7,7 @@ import torch -from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf +from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf from ..ops.flashmla import ( flash_mla_with_kvcache, get_mla_metadata, @@ -17,7 +17,7 @@ from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( - vllm_is_batch_invariant, + _batch_invariant_MODE as _bi_mode, ) from vllm.platforms.interface import DeviceCapability from .common import ( @@ -28,13 +28,13 @@ MLACommonMetadataBuilder, QueryLenSupport, ) +from vllm.v1.attention.backend import AttentionCGSupport from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, reshape_attn_output_for_spec_decode, reshape_query_for_spec_decode, ) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.attention.backends.registry import AttentionBackendEnum, register_backend +from vllm.v1.attention.backends.registry import AttentionBackendEnum, register_backend logger = init_logger(__name__) @@ -271,7 +271,7 @@ def _forward_decode( 
tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata num_splits = attn_metadata.decode.num_splits - if vllm_is_batch_invariant(): + if _bi_mode: device = q.device dtype = torch.int32 diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py index 446772c2..c2c51159 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py @@ -43,7 +43,7 @@ def supported_headdim(o: torch.Tensor) -> bool: output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse ) else: - from vllm.attention.ops.triton_merge_attn_states import merge_attn_states + from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states return merge_attn_states( output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py index 56d2c787..b21344c5 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py @@ -2,7 +2,9 @@ # 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. 
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.attention.utils.fa_utils import logger +import logging + +logger = logging.getLogger(__name__) from vllm.platforms import current_platform diff --git a/vllm_fl/dispatch/backends/vendor/metax/metax.py b/vllm_fl/dispatch/backends/vendor/metax/metax.py index 66de575f..d9dee3c9 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/metax.py +++ b/vllm_fl/dispatch/backends/vendor/metax/metax.py @@ -14,7 +14,7 @@ from vllm_fl.dispatch.backends.base import Backend -from vllm.attention.backends.registry import AttentionBackendEnum, register_backend +from vllm.v1.attention.backends.registry import AttentionBackendEnum, register_backend # Register attention backends for MACA @@ -147,7 +147,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) -> Returns: Fully qualified class path string """ - from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.v1.attention.backends.registry import AttentionBackendEnum # register before selection register_attention_backends() diff --git a/vllm_fl/dispatch/builtin_ops.py b/vllm_fl/dispatch/builtin_ops.py index ec207a5b..c8a8c990 100644 --- a/vllm_fl/dispatch/builtin_ops.py +++ b/vllm_fl/dispatch/builtin_ops.py @@ -21,7 +21,6 @@ # Directory containing vendor backends _VENDOR_BACKENDS_DIR = os.path.join(os.path.dirname(__file__), "backends", "vendor") - def _find_vendor_backend_dir( vendor_name: str, available_vendor_dirs: set[str], @@ -61,7 +60,6 @@ def _get_current_vendor_backend_dirs(available_vendor_dirs: set[str]) -> set[str "Failed to detect current vendor backend from current_platform." ) from exc - def _register_vendor_backends(registry: OpRegistry) -> None: """ Auto-discover and register all vendor backends. 
diff --git a/vllm_fl/dispatch/config/__init__.py b/vllm_fl/dispatch/config/__init__.py index bafd6d4e..cc4e8b49 100644 --- a/vllm_fl/dispatch/config/__init__.py +++ b/vllm_fl/dispatch/config/__init__.py @@ -14,8 +14,8 @@ get_oot_blacklist, get_per_op_order, get_platform_name, - get_vendor_device_map, load_platform_config, + get_vendor_device_map, ) __all__ = [ diff --git a/vllm_fl/dispatch/config/nvidia.yaml b/vllm_fl/dispatch/config/nvidia.yaml index 3c8676ff..0b06a8a3 100644 --- a/vllm_fl/dispatch/config/nvidia.yaml +++ b/vllm_fl/dispatch/config/nvidia.yaml @@ -10,8 +10,8 @@ prefer: flagos strict: false # Vendor Whitelist (Optional, allows all if not set) -# allow_vendors: -# - cuda +allow_vendors: + - cuda # Vendor Blacklist (Optional) # deny_vendors: diff --git a/vllm_fl/dispatch/config/utils.py b/vllm_fl/dispatch/config/utils.py index 25f1fbfa..087239a1 100644 --- a/vllm_fl/dispatch/config/utils.py +++ b/vllm_fl/dispatch/config/utils.py @@ -208,7 +208,6 @@ def get_effective_config() -> dict[str, Any]: # Return empty config return {} - def get_vendor_device_map() -> dict[str, dict[str, str]]: """Load vendor mapping from Python config module. 
diff --git a/vllm_fl/dispatch/logger_manager.py b/vllm_fl/dispatch/logger_manager.py index 37867e8e..57e4a4f0 100644 --- a/vllm_fl/dispatch/logger_manager.py +++ b/vllm_fl/dispatch/logger_manager.py @@ -38,7 +38,7 @@ def get_logger(name: str = "vllm_fl.dispatch") -> logging.Logger: if not logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter( - "[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s", + "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) handler.setFormatter(formatter) diff --git a/vllm_fl/models/glm_moe_dsa.py b/vllm_fl/models/glm_moe_dsa.py deleted file mode 100644 index 4b88c4d0..00000000 --- a/vllm_fl/models/glm_moe_dsa.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Inference-only GLM-5 (GlmMoeDsa) model. - -GLM-5 uses a DeepSeek V2/V3-style architecture with MLA (Multi-head Latent -Attention) and Mixture of Experts. The HF model type is ``glm_moe_dsa`` and -the architecture class is ``GlmMoeDsaForCausalLM``. - -This thin wrapper inherits from vLLM's ``DeepseekV2ForCausalLM`` which already -handles MLA, MoE, the DSA Indexer, and MTP speculative decoding layers. -""" - -from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM - - -class GlmMoeDsaForCausalLM(DeepseekV2ForCausalLM): - """GLM-5 model for causal language modelling.""" - pass diff --git a/vllm_fl/models/kimi_k25.py b/vllm_fl/models/kimi_k25.py deleted file mode 100644 index 894839e5..00000000 --- a/vllm_fl/models/kimi_k25.py +++ /dev/null @@ -1,245 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Minimal Kimi-K2.5 model support for text-only inference. - -This is a simplified implementation that wraps DeepseekV3 for text-only -benchmarking. Vision components are not included. 
-""" - -import copy -from collections.abc import Iterable - -import torch -from torch import nn - -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import SharedFusedMoE -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) -from vllm.model_executor.models.deepseek_v2 import ( - DeepseekV2Model, - get_spec_layer_idx_from_weight_name, -) -from vllm.model_executor.models.interfaces import SupportsPP -from vllm.model_executor.models.utils import ( - PPMissingLayer, - is_pp_missing_parameter, - maybe_prefix, -) -from vllm.sequence import IntermediateTensors - -logger = init_logger(__name__) - - -class KimiK25ForConditionalGeneration(nn.Module, SupportsPP): - """Kimi-K2.5 model for text-only conditional generation. - - This is a minimal implementation that uses DeepseekV2Model as the - language backbone. Vision components are not included for simplicity. 
- """ - - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: - super().__init__() - model_config = vllm_config.model_config - config = model_config.hf_config - self.config = config - quant_config = vllm_config.quant_config - - # Extract text config from KimiK25Config - # The text_config is a DeepseekV3Config - text_config = getattr(config, "text_config", config) - self.hidden_size = text_config.hidden_size - - # Create a modified vllm_config with text_config as hf_config - sub_vllm_config = copy.deepcopy(vllm_config) - sub_vllm_config.model_config.hf_config = text_config - - # Build language model using DeepseekV2Model - self.language_model = DeepseekV2Model( - vllm_config=sub_vllm_config, - prefix=maybe_prefix(prefix, "language_model"), - ) - - # Build lm_head - if get_pp_group().is_last_rank: - vocab_size = getattr(config, "vocab_size", text_config.vocab_size) - self.lm_head = ParallelLMHead( - vocab_size, - text_config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - else: - self.lm_head = PPMissingLayer() - - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors - ) - logit_scale = getattr(config, "logit_scale", 1.0) - vocab_size = getattr(config, "vocab_size", text_config.vocab_size) - self.logits_processor = LogitsProcessor(vocab_size, scale=logit_scale) - - def get_language_model(self) -> nn.Module: - return self.language_model - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - """Apply token embeddings to input_ids.""" - return self.language_model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs, - ) -> IntermediateTensors: - if intermediate_tensors is not None: - inputs_embeds = None - hidden_states = self.language_model( - input_ids=input_ids, - 
positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - return hidden_states - - def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, **kwargs) - return logits - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - """Get expert parameter mapping for MoE layers.""" - text_config = getattr(self.config, "text_config", self.config) - if not getattr(text_config, "n_routed_experts", None): - return [] - return SharedFusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=text_config.n_routed_experts, - num_redundant_experts=0, - ) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - """Load weights with proper name remapping for Kimi-K2.5.""" - text_config = getattr(self.config, "text_config", self.config) - - # Weight name remapping for Kimi-K2.5 -> DeepseekV2 - _KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", - } - - stacked_params_mapping = [ - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - if getattr(text_config, "kv_lora_rank", None) and getattr( - text_config, "q_lora_rank", None - ): - stacked_params_mapping += [ - (".fused_qkv_a_proj", ".q_a_proj", 0), - (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1), - ] - expert_params_mapping = self.get_expert_mapping() - - params_dict = dict(self.named_parameters()) - - for args in weights: - name, loaded_weight = args[:2] - kwargs = args[2] if len(args) > 2 else {} - - # Skip rotary embedding cached values - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - continue - - # Skip speculative decode layers - spec_layer = get_spec_layer_idx_from_weight_name(text_config, name) - if spec_layer is not 
None: - continue - - # Skip vision tower weights (not needed for text-only inference) - if "vision_tower" in name or "mm_projector" in name: - continue - - # Apply key remapping - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - - use_default_weight_loading = False - - # Handle stacked parameters - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - if ("mlp.experts." in name) and name not in params_dict: - continue - name = name.replace(weight_name, param_name) - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id, **kwargs) - break - else: - # Handle expert parameters - for _, ( - param_name, - weight_name, - expert_id, - shard_id, - ) in enumerate(expert_params_mapping): - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( - param, - loaded_weight, - name, - expert_id=expert_id, - shard_id=shard_id, - **kwargs, - ) - break - else: - use_default_weight_loading = True - - if use_default_weight_loading: - if name.endswith(".bias") and name not in params_dict: - continue - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict.get(name) - if param is None: - continue - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight, **kwargs) diff --git a/vllm_fl/models/minicpmo.py b/vllm_fl/models/minicpmo.py deleted file mode 100644 index 5affd08d..00000000 --- a/vllm_fl/models/minicpmo.py +++ /dev/null @@ -1,854 +0,0 @@ 
-# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" - -from collections.abc import Callable, Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, TypeAlias - -import torch -from torch import nn -from transformers import BatchFeature -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.whisper.modeling_whisper import ( - ACT2FN, - WhisperAttention, - WhisperConfig, - WhisperEncoder, -) - -from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems -from vllm.multimodal.inputs import ( - MultiModalDataDict, - MultiModalFieldConfig, - NestedTensors, -) -from vllm.multimodal.parse import ( - AudioItem, - AudioProcessorItems, - DictEmbeddingItems, - ModalityData, - ModalityDataItems, - MultiModalDataItems, - MultiModalDataParser, -) -from vllm.multimodal.processing import ( - PromptReplacement, - PromptUpdate, - PromptUpdateDetails, -) -from vllm.utils.tensor_schema import TensorSchema, TensorShape - -from vllm.model_executor.models.minicpmv import ( - _MAX_FRAMES_PER_VIDEO, - MiniCPMV2_6, - MiniCPMV4_5, - MiniCPMVDummyInputsBuilder, - MiniCPMVMultiModalDataParser, - MiniCPMVMultiModalProcessor, - MiniCPMVProcessingInfo, - _minicpmv_field_config, -) -from vllm.model_executor.models.utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix - -CPU_DEVICE = torch.device("cpu") - - -class MiniCPMOAudioFeatureInputs(TensorSchema): - """ - Dimensions: - - bns: Batch size * number of audios * number of slices - - bn: Batch size * number of audios - - c: Number of channels - - l: Length - - s: Number of slices - """ - - type: Literal["audio_features"] = "audio_features" - - audio_features: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bns", "c", "l", dynamic_dims={"l"}), - ] - """ - Slice here means chunk. 
Audio that is too long will be split into slices, - which is the same as image. Padding is used therefore `audio_features` is - `torch.Tensor`. - """ - - audio_feature_lens: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bn", "s"), - ] - """ - This should be feature length of each audio slice, - which equals to `audio_features.shape[-1]` - """ - - -class MiniCPMOAudioEmbeddingInputs(TensorSchema): - """ - Dimensions: - - bn: Batch size * number of audios - - s: Number of slices - - h: Hidden size (must match language model backbone) - - Length of each slice may vary, so pass it as a list. - """ - - type: Literal["audio_embeds"] = "audio_embeds" - - audio_embeds: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bn", "s", "h", dynamic_dims={"s"}), - ] - - -MiniCPMOAudioInputs: TypeAlias = ( - MiniCPMOAudioFeatureInputs | MiniCPMOAudioEmbeddingInputs -) - - -def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): - return dict( - **_minicpmv_field_config(hf_inputs), - audio_features=MultiModalFieldConfig.batched("audio"), - audio_feature_lens=MultiModalFieldConfig.batched("audio"), - audio_embeds=MultiModalFieldConfig.batched("audio"), - ) - - -class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems): - def __init__( - self, - data: Mapping[str, torch.Tensor], - fields_factory: Callable[ - [Mapping[str, torch.Tensor]], - Mapping[str, MultiModalFieldConfig], - ], - ) -> None: - super().__init__( - data, - modality="image", - required_fields={"audio_embeds"}, - fields_factory=fields_factory, - ) - - -class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): - def _parse_audio_data( - self, - data: dict[str, torch.Tensor] | ModalityData[AudioItem], - ) -> ModalityDataItems[Any, Any] | None: - if isinstance(data, dict): - return MiniCPMOAudioEmbeddingItems( - data, - fields_factory=_minicpmo_field_config, - ) - - return super()._parse_audio_data(data) - - -class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): - audio_pattern = 
"()" - - def get_supported_mm_limits(self) -> Mapping[str, int | None]: - return {**super().get_supported_mm_limits(), "audio": None} - - def get_audio_placeholder( - self, - audio_lens: int, - chunk_input: bool = True, - chunk_length: int = 1, - ) -> str: - hf_processor = self.get_hf_processor() - - return hf_processor.get_audio_placeholder( - audio_lens, - chunk_input=chunk_input, - chunk_length=chunk_length, - ) - - def get_default_audio_pool_step(self) -> int: - hf_config = self.get_hf_config() - # MiniCPM-o 4.5 uses pool_step=5, older versions use 2 - return getattr(hf_config, "audio_pool_step", 2) - - def get_default_audio_sampling_rate(self) -> int: - return 16000 - - def get_chunk_length(self) -> int: - return self.get_hf_config().audio_chunk_length - - def get_max_audio_tokens_per_chunk(self) -> int: - pool_step = self.get_default_audio_pool_step() - fbank_feat_in_chunk = 100 - cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1 - return (cnn_feat_in_chunk - pool_step) // pool_step + 1 - - def get_max_audio_chunks_with_most_features(self) -> int: - return 30 - - def get_max_audio_tokens(self) -> int: - num_chunks = self.get_max_audio_chunks_with_most_features() - return self.get_max_audio_tokens_per_chunk() * num_chunks - - def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: - sampling_rate = self.get_default_audio_sampling_rate() - num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1 - - def get_num_frames_with_most_features( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> int: - max_images = mm_counts.get("image", 0) - max_videos = mm_counts.get("video", 0) - max_audios = mm_counts.get("audio", 0) - - max_image_tokens = self.get_max_image_tokens() * max_images - max_audio_tokens = self.get_max_audio_tokens() * max_audios - max_total_frames = self.get_max_video_frames( - seq_len - max_image_tokens - max_audio_tokens - ) - max_frames_per_video = min( - 
max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO - ) - - return max(max_frames_per_video, 1) - - -class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - num_audios = mm_counts.get("audio", 0) - - audio_prompt_texts = self.info.audio_pattern * num_audios - - return super().get_dummy_text(mm_counts) + audio_prompt_texts - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_options: Mapping[str, BaseDummyOptions] | None = None, - ) -> MultiModalDataDict: - num_audios = mm_counts.get("audio", 0) - audio_len = ( - self.info.get_max_audio_chunks_with_most_features() - * self.info.get_default_audio_sampling_rate() - ) - - audio_overrides = mm_options.get("audio") if mm_options else None - - audio_mm_data = { - "audio": self._get_dummy_audios( - length=audio_len, num_audios=num_audios, overrides=audio_overrides - ) - } - - return { - **super().get_dummy_mm_data(seq_len, mm_counts, mm_options), - **audio_mm_data, - } - - -class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]): - def _get_data_parser(self) -> MultiModalDataParser: - return MiniCPMOMultiModalDataParser( - target_sr=self.info.get_default_audio_sampling_rate() - ) - - def get_audio_prompt_texts( - self, - audio_lens: int, - chunk_input: bool = True, - chunk_length: int = 1, - ) -> str: - return self.info.get_audio_placeholder( - audio_lens, - chunk_input=chunk_input, - chunk_length=chunk_length, - ) - - def process_audios( - self, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: - if (audios := mm_data.get("audios")) is None: - return {} - - parsed_audios = ( - self._get_data_parser() - .parse_mm_data({"audio": audios}) - .get_items("audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)) - ) - - if isinstance(parsed_audios, 
MiniCPMOAudioEmbeddingItems): - audio_inputs = {} - else: - audio_inputs = self._base_call_hf_processor( - prompts=[self.info.audio_pattern] * len(parsed_audios), - mm_data={"audios": [[audio] for audio in parsed_audios]}, - mm_kwargs={**mm_kwargs, "chunk_input": True}, - tok_kwargs=tok_kwargs, - out_keys={"audio_features", "audio_feature_lens"}, - ) - - # Avoid padding since we need the output for each audio to be - # independent of other audios for the cache to work correctly - unpadded_audio_features = [ - feat[:, :feature_len] - for feat, feature_len in zip( - audio_inputs["audio_features"], - audio_inputs["audio_feature_lens"], - ) - ] - audio_inputs["audio_features"] = unpadded_audio_features - - return audio_inputs - - def process_mm_inputs( - self, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: - return { - **super().process_mm_inputs(mm_data, mm_kwargs, tok_kwargs), - **self.process_audios(mm_data, mm_kwargs, tok_kwargs), - } - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - base_updates = super()._get_prompt_updates( - mm_items=mm_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - out_mm_kwargs=out_mm_kwargs, - ) - - audio_placeholder = self.info.audio_pattern - - def get_audio_replacement(item_idx: int): - audios = mm_items.get_items( - "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems) - ) - - if isinstance(audios, MiniCPMOAudioEmbeddingItems): - single_audio_embeds = audios.get(item_idx)["audio_embeds"] - audio_len = self.info.get_audio_len_by_num_chunks( - sum(map(len, single_audio_embeds)) - ) - else: - audio_len = audios.get_audio_length(item_idx) - - return PromptUpdateDetails.select_text( - self.get_audio_prompt_texts(audio_len), - "", - ) - - return [ - *base_updates, - PromptReplacement( - 
modality="audio", - target=audio_placeholder, - replacement=get_audio_replacement, - ), - ] - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return _minicpmo_field_config(hf_inputs) - - -class MultiModalProjector(nn.Module): - def __init__(self, in_dim: int, out_dim: int): - super().__init__() - self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True) - self.relu = nn.ReLU() - self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True) - - def forward(self, audio_features: torch.Tensor) -> torch.Tensor: - hidden_states = self.relu(self.linear1(audio_features)) - hidden_states = self.linear2(hidden_states) - return hidden_states - - -class MiniCPMWhisperEncoderLayer(nn.Module): - def __init__(self, config: WhisperConfig, layer_idx: int): - super().__init__() - self.embed_dim = config.d_model - self.self_attn = WhisperAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - config=config, - layer_idx=layer_idx, - ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - hidden_states = residual + hidden_states - - residual = hidden_states 
- hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout( - hidden_states, p=self.activation_dropout, training=self.training - ) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - hidden_states = residual + hidden_states - - if hidden_states.dtype == torch.float16: - hidden_states = cast_overflow_tensors(hidden_states) - - outputs = (hidden_states,) - - return outputs - - -class MiniCPMWhisperEncoder(WhisperEncoder): - def __init__(self, config: WhisperConfig): - super().__init__(config) - self.layers = nn.ModuleList( - [ - MiniCPMWhisperEncoderLayer(config, layer_idx=i) - for i in range(config.encoder_layers) - ] - ) - - def forward( - self, - input_features: torch.Tensor, - attention_mask: torch.Tensor | None = None, - ) -> BaseModelOutputWithPast: - # Ignore copy - input_features = input_features.to( - dtype=self.conv1.weight.dtype, device=self.conv1.weight.device - ) - - inputs_embeds = nn.functional.gelu(self.conv1(input_features)) - inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) - - inputs_embeds = inputs_embeds.permute(0, 2, 1) - - embed_pos = self.embed_positions.weight - - embed_pos = embed_pos[: inputs_embeds.shape[1], :] - - hidden_states = inputs_embeds + embed_pos - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - - encoder_states = () - - for idx, encoder_layer in enumerate(self.layers): - encoder_states = encoder_states + (hidden_states,) - to_drop = False - if self.training: - dropout_probability = torch.rand([]) - if dropout_probability < self.layerdrop: # skip the layer - to_drop = True - - # Ignore copy - if to_drop: - layer_outputs = (None, None) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - ) - - hidden_states = layer_outputs[0] - - hidden_states = 
self.layer_norm(hidden_states) - encoder_states = encoder_states + (hidden_states,) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - hidden_states=encoder_states, - ) - - -class MiniCPMOBaseModel: - """Base mixin class for MiniCPM-O models with audio support.""" - - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> str | None: - if modality.startswith("image"): - return "(./)" - if modality.startswith("video"): - return "()" - if modality.startswith("audio"): - return "()" - - raise ValueError("Only image, video or audio modality is supported") - - def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Do not use parameters temporarily - audio_config = self.config.audio_config - model = MiniCPMWhisperEncoder(audio_config) - audio_output_dim = int(audio_config.encoder_ffn_dim // 4) - self.audio_avg_pooler = nn.AvgPool1d( - self.config.audio_pool_step, stride=self.config.audio_pool_step - ) - self.audio_projection_layer = MultiModalProjector( - in_dim=audio_output_dim, out_dim=self.embed_dim - ) - self.audio_encoder_layer = -1 - return model - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, skip_prefixes=["tts"]) - return loader.load_weights(weights) - - def subsequent_chunk_mask( - self, - size: int, - chunk_size: int, - num_left_chunks: int = -1, - device: torch.device = CPU_DEVICE, - num_lookhead: int = 0, - ) -> torch.Tensor: - ret = torch.zeros(size, size, device=device, dtype=torch.bool) - # Vectorized computation of row indices and chunk boundaries - row_indices = torch.arange(size, device=device) - chunk_indices = row_indices // chunk_size - if num_left_chunks < 0: - # If num_left_chunks < 0, start is always 0 for all rows - start_indices = torch.zeros_like(row_indices) - else: - # 
Compute start indices vectorially - start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, min=0) - start_indices = start_chunk_indices * chunk_size - # Compute ending indices vectorially - end_chunk_indices = chunk_indices + 1 - end_indices = torch.clamp( - end_chunk_indices * chunk_size + num_lookhead, max=size - ) - # Create column indices for broadcasting - col_indices = torch.arange(size, device=device).unsqueeze(0) - start_indices = start_indices.unsqueeze(1) - end_indices = end_indices.unsqueeze(1) - # Vectorized mask creation - ret = (col_indices >= start_indices) & (col_indices < end_indices) - return ret - - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): - input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 - input_lengths_after_pooling = ( - input_lengths_after_cnn - self.config.audio_pool_step - ) // self.config.audio_pool_step + 1 - input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) - - return input_lengths_after_cnn, input_lengths_after_pooling - - def get_audio_hidden_states( - self, data: MiniCPMOAudioFeatureInputs - ) -> list[torch.Tensor]: - chunk_length = self.config.audio_chunk_length - - # (bs, 80, frames) or [], multi audios need filled in advance - wavforms_raw = data["audio_features"] - if isinstance(wavforms_raw, list): - B = len(wavforms_raw) - C = wavforms_raw[0].shape[-2] - L = max(item.shape[-1] for item in wavforms_raw) - device = wavforms_raw[0].device - dtype = wavforms_raw[0].dtype - - wavforms = torch.zeros((B, C, L), dtype=dtype, device=device) - for i, wavforms_item in enumerate(wavforms_raw): - L_item = wavforms_item.shape[-1] - wavforms[i, ..., :L_item] = wavforms_item - else: - wavforms = wavforms_raw - - # list, [[x1, x2], [y1], [z1]] - audio_feature_lens_raw = data["audio_feature_lens"] - if isinstance(audio_feature_lens_raw, torch.Tensor): - audio_feature_lens_raw = audio_feature_lens_raw.unbind(0) - - audio_feature_lens = 
torch.hstack(audio_feature_lens_raw) - batch_size, _, max_mel_seq_len = wavforms.shape - max_seq_len = (max_mel_seq_len - 1) // 2 + 1 - - # Create a sequence tensor of shape (batch_size, max_seq_len) - seq_range = ( - torch.arange( - 0, - max_seq_len, - dtype=audio_feature_lens.dtype, - device=audio_feature_lens.device, - ) - .unsqueeze(0) - .expand(batch_size, max_seq_len) - ) - lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len) - # Create mask - padding_mask = seq_range >= lengths_expand # 1 for padded values - - audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( - batch_size, 1, max_seq_len, max_seq_len - ) - audio_attention_mask = audio_attention_mask_.to( - dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device - ) - - if chunk_length > 0: - chunk_num_frame = int(chunk_length * 50) - chunk_mask = self.subsequent_chunk_mask( - size=max_seq_len, - chunk_size=chunk_num_frame, - num_left_chunks=-1, - device=audio_attention_mask_.device, - ) - audio_attention_mask_ = torch.logical_or( - audio_attention_mask_, torch.logical_not(chunk_mask) - ) - - audio_attention_mask[audio_attention_mask_] = float("-inf") - audio_states = self.apm( - wavforms, attention_mask=audio_attention_mask - ).hidden_states[self.audio_encoder_layer] - audio_embeds = self.audio_projection_layer(audio_states) - - audio_embeds = audio_embeds.transpose(1, 2) - audio_embeds = self.audio_avg_pooler(audio_embeds) - audio_embeds = audio_embeds.transpose(1, 2) - - _, feature_lens_after_pooling = self._get_feat_extract_output_lengths( - audio_feature_lens - ) - - num_audio_tokens = feature_lens_after_pooling - - final_audio_embeds = list[torch.Tensor]() - idx = 0 - for i in range(len(audio_feature_lens_raw)): - target_audio_embeds_lst = list[torch.Tensor]() - for _ in range(len(audio_feature_lens_raw[i])): - target_audio_embeds_lst.append( - audio_embeds[idx, : num_audio_tokens[idx], :] - ) - idx += 1 - - 
final_audio_embeds.append(torch.cat(target_audio_embeds_lst)) - - return final_audio_embeds - - def _parse_and_validate_audio_input( - self, **kwargs: object - ) -> MiniCPMOAudioInputs | None: - audio_features = kwargs.pop("audio_features", None) - audio_embeds = kwargs.pop("audio_embeds", None) - - if audio_features is None and audio_embeds is None: - return None - - if audio_embeds is not None: - return MiniCPMOAudioEmbeddingInputs( - type="audio_embeds", - audio_embeds=audio_embeds, - ) - - audio_feature_lens = kwargs.pop("audio_feature_lens") - - return MiniCPMOAudioFeatureInputs( - type="audio_features", - audio_features=audio_features, - audio_feature_lens=audio_feature_lens, - ) - - def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: - modalities = super()._parse_and_validate_multimodal_inputs(**kwargs) - - # Preserve the order of modalities if there are multiple of them - # from the order of kwargs. - for input_key in kwargs: - if ( - input_key in ("audio_features", "audio_embeds") - and "audios" not in modalities - ): - modalities["audios"] = self._parse_and_validate_audio_input(**kwargs) - - return modalities - - def _process_audio_input( - self, - audio_input: MiniCPMOAudioInputs, - ) -> torch.Tensor | list[torch.Tensor]: - if audio_input["type"] == "audio_embeds": - return audio_input["audio_embeds"] - - return self.get_audio_hidden_states(audio_input) - - def _process_multimodal_inputs(self, modalities: dict): - multimodal_embeddings = super()._process_multimodal_inputs(modalities) - - for modality in modalities: - if modality == "audios": - audio_input = modalities["audios"] - audio_embeddings = self._process_audio_input(audio_input) - multimodal_embeddings += tuple(audio_embeddings) - - return multimodal_embeddings - - -class MiniCPMO2_6(MiniCPMOBaseModel, MiniCPMV2_6): - """MiniCPM-O model based on MiniCPMV 2.6 (Qwen2 backbone).""" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Skip 
MiniCPMV2_6.__init__ version assertion, call MiniCPMVBaseModel directly - from vllm.model_executor.models.minicpmv import MiniCPMVBaseModel - MiniCPMVBaseModel.__init__(self, vllm_config=vllm_config, prefix=prefix) - # Override version for MiniCPM-O 2.6 - self.version = (2, 6) - self.apm = self.init_audio_module( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm") - ) - - -class MiniCPMO4_5(MiniCPMOBaseModel, MiniCPMV4_5): - """MiniCPM-O 4.5 model based on MiniCPMV 4.5 (Qwen3 backbone).""" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Skip MiniCPMV4_5.__init__ version assertion, call MiniCPMVBaseModel directly - from vllm.model_executor.models.minicpmv import MiniCPMVBaseModel - MiniCPMVBaseModel.__init__(self, vllm_config=vllm_config, prefix=prefix) - # Override version for MiniCPM-O 4.5 - self.version = (4, 5) - self.apm = self.init_audio_module( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm") - ) - - -_MINICPMO_SUPPORT_VERSION = { - (2, 6): MiniCPMO2_6, - (4, 5): MiniCPMO4_5, -} - - -@MULTIMODAL_REGISTRY.register_processor( - MiniCPMOMultiModalProcessor, - info=MiniCPMOProcessingInfo, - dummy_inputs=MiniCPMODummyInputsBuilder, -) -class MiniCPMO(MiniCPMOBaseModel, MiniCPMV2_6): - """ - MiniCPM-O model with audio support. 
- Different versions use different LLM backbones: - - Version 2.6: Uses Qwen2 - - Version 4.5: Uses Qwen3 - """ - - def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - - # Determine version from config - if hasattr(config, "version"): - version = str(config.version).split(".") - version = tuple([int(x) for x in version]) - else: - # Auto-detect version based on config features: - # - MiniCPM-o 4.5 (Qwen3 backbone): has head_dim attribute, - # hidden_size=4096, num_hidden_layers=36 - # - MiniCPM-o 2.6 (Qwen2 backbone): no head_dim, different arch - has_head_dim = hasattr(config, "head_dim") - is_qwen3_like = ( - has_head_dim - and getattr(config, "hidden_size", 0) == 4096 - and getattr(config, "num_hidden_layers", 0) == 36 - ) - if is_qwen3_like: - version = (4, 5) - else: - # Default to 2.6 for backward compatibility - version = (2, 6) - - # Dispatch class based on version - instance_cls = _MINICPMO_SUPPORT_VERSION.get(version) - if instance_cls is None: - supported_versions = ", ".join( - [f"{v[0]}.{v[1]}" for v in sorted(_MINICPMO_SUPPORT_VERSION.keys())] - ) - raise ValueError( - f"Currently, MiniCPMO only supports versions " - f"{supported_versions}. Got version: {version}" - ) - - return instance_cls(vllm_config=vllm_config, prefix=prefix) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # This __init__ won't be called due to __new__ returning a different class - super().__init__(vllm_config=vllm_config, prefix=prefix) - self.apm = self.init_audio_module( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm") - ) diff --git a/vllm_fl/models/qwen3_5.py b/vllm_fl/models/qwen3_5.py deleted file mode 100644 index ad3ec5f9..00000000 --- a/vllm_fl/models/qwen3_5.py +++ /dev/null @@ -1,951 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright 2025 The vLLM team. -# Copyright 2025 The Qwen Team. 
-"""Inference-only Qwen3.5 MoE model compatible with HuggingFace weights.""" - -import typing -from collections.abc import Callable, Iterable - -import torch -from einops import rearrange -from torch import nn - -from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group -from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import ( - GemmaRMSNorm as Qwen3_5RMSNorm, -) -from vllm.model_executor.layers.linear import MergedColumnParallelLinear -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, - MambaStateShapeCalculator, -) -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.sequence import IntermediateTensors - -from vllm.model_executor.models.interfaces import ( - HasInnerState, - IsHybrid, - MixtureOfExperts, - MultiModalEmbeddings, - SupportsLoRA, - SupportsPP, - _require_is_multimodal, -) -from vllm.model_executor.models.utils import ( - AutoWeightsLoader, - PPMissingLayer, - _merge_multimodal_embeddings, - extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, - make_layers, - maybe_prefix, -) -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.model_executor.models.qwen3_vl import ( - Qwen3_VisionTransformer, - Qwen3VLDummyInputsBuilder, - Qwen3VLForConditionalGeneration, - Qwen3VLMultiModalProcessor, - Qwen3VLProcessingInfo, -) - -# from vllm_fl.models.qwen3_next import ( -from vllm.model_executor.models.qwen3_next import ( - Qwen3NextAttention, - Qwen3NextDecoderLayer, - Qwen3NextGatedDeltaNet, - Qwen3NextModel, - Qwen3NextSparseMoeBlock, - QwenNextMixtureOfExperts, -) -from 
vllm_fl.configs.qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig -import vllm_fl.models.qwen3_next # for error ''_OpNamespace' 'vllm' object has no attribute 'gdn_attention_core'' - -logger = init_logger(__name__) - - -class Qwen3_5SparseMoeBlock(Qwen3NextSparseMoeBlock): - """Override SparseMoeBlock to read config from hf_text_config.""" - - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): - # Temporarily patch hf_config to point to hf_text_config - # so the parent class reads the right config - original_hf_config = vllm_config.model_config.hf_config - vllm_config.model_config.hf_config = ( - vllm_config.model_config.hf_text_config - ) - try: - super().__init__(vllm_config=vllm_config, prefix=prefix) - finally: - vllm_config.model_config.hf_config = original_hf_config - - -class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): - """Qwen3.5 uses MergedColumnParallelLinear for qkvz projection - and has a different forward pass without fix_query_key_value_ordering.""" - - def fix_query_key_value_ordering( - self, - mixed_qkvz: torch.Tensor, - mixed_ba: torch.Tensor, - ): - raise NotImplementedError( - "Qwen3.5 Series dont need to fix query key value ordering" - ) - - def create_qkvz_proj( - self, - hidden_size: int, - key_dim: int, - value_dim: int, - quant_config: QuantizationConfig | None, - prefix: str, - ) -> MergedColumnParallelLinear: - return MergedColumnParallelLinear( - input_size=hidden_size, - output_sizes=[key_dim, key_dim, value_dim, value_dim], - bias=False, - quant_config=quant_config, - prefix=prefix, - ) - - def __init__(self, config, model_config=None, cache_config=None, - quant_config=None, speculative_config=None, prefix=""): - # Call grandparent init to skip Qwen3NextGatedDeltaNet.__init__ - # but set up the same attributes - nn.Module.__init__(self) - from vllm.distributed import ( - divide, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - ) - from vllm.config import get_current_vllm_config - from 
vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - RowParallelLinear, - ) - from vllm.model_executor.layers.layernorm import RMSNormGated - from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - mamba_v2_sharded_weight_loader, - ) - from vllm.model_executor.model_loader.weight_utils import ( - sharded_weight_loader, - ) - from vllm.model_executor.utils import set_weight_attrs - from vllm.platforms import current_platform - from transformers.activations import ACT2FN - - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.hidden_size = config.hidden_size - self.num_v_heads = config.linear_num_value_heads - self.num_k_heads = config.linear_num_key_heads - self.head_k_dim = config.linear_key_head_dim - self.head_v_dim = config.linear_value_head_dim - self.key_dim = self.head_k_dim * self.num_k_heads - self.value_dim = self.head_v_dim * self.num_v_heads - - self.conv_kernel_size = config.linear_conv_kernel_dim - self.layer_idx = extract_layer_index(prefix) - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] - self.layer_norm_epsilon = config.rms_norm_eps - self.prefix = prefix - - self.config = config - self.model_config = model_config - self.cache_config = cache_config - self.quant_config = quant_config - self.speculative_config = speculative_config - self.num_spec = ( - self.speculative_config.num_speculative_tokens - if self.speculative_config - else 0 - ) - - # Convolution - self.conv_dim = self.key_dim * 2 + self.value_dim - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.conv_dim, - bias=False, - prefix=f"{prefix}.conv1d", - ) - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - query_key_settings = (self.key_dim, 0, False) - value_settings = (self.value_dim, 0, False) - - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( - self.conv1d.weight, - { - "weight_loader": 
mamba_v2_sharded_weight_loader( - [query_key_settings, query_key_settings, value_settings], - self.tp_size, - self.tp_rank, - ) - }, - ) - - # Use MergedColumnParallelLinear for qkvz projection - self.in_proj_qkvz = self.create_qkvz_proj( - hidden_size=self.hidden_size, - key_dim=self.key_dim, - value_dim=self.value_dim, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_qkvz", - ) - - self.projection_size_ba = self.num_v_heads * 2 - self.in_proj_ba = MergedColumnParallelLinear( - input_size=self.hidden_size, - output_sizes=[self.num_v_heads, self.num_v_heads], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_ba", - ) - - self.dt_bias = nn.Parameter( - torch.ones(self.num_v_heads // self.tp_size), - ) - self.A_log = nn.Parameter( - torch.empty(divide(self.num_v_heads, self.tp_size)), - ) - - set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) - set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - - self.norm = RMSNormGated( - self.head_v_dim, - eps=self.layer_norm_epsilon, - group_size=None, - norm_before_gate=True, - device=current_platform.current_device(), - dtype=getattr(config, "dtype", torch.bfloat16), - ) - - self.out_proj = RowParallelLinear( - self.value_dim, - self.hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - - - def forward( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - ): - num_tokens = hidden_states.size(0) - - # Input Projection - Qwen3.5 style - mixed_qkvz, _ = self.in_proj_qkvz(hidden_states) - qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size - z_size = self.value_dim // self.tp_size - mixed_qkv, z = mixed_qkvz.split([qkv_size, 
z_size], dim=-1) - z = z.reshape(z.size(0), -1, self.head_v_dim) - ba, _ = self.in_proj_ba(hidden_states) - b, a = ba.chunk(2, dim=-1) - - b = b.contiguous() - a = a.contiguous() - - # Core Attention - core_attn_out = torch.zeros( - (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - torch.ops.vllm.gdn_attention_core( - mixed_qkv, - b, - a, - core_attn_out, - self.prefix, - ) - - # Output Projection - z_shape_og = z.shape - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") - output[:num_tokens], _ = self.out_proj(core_attn_out) - - -class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer): - def __init__( - self, - vllm_config: VllmConfig, - layer_type: str, - prefix: str = "", - ) -> None: - # Call nn.Module.__init__ directly, skipping Qwen3NextDecoderLayer's - super(Qwen3NextDecoderLayer, self).__init__() - - config = vllm_config.model_config.hf_text_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - speculative_config = vllm_config.speculative_config - - self.layer_type = layer_type - self.layer_idx = extract_layer_index(prefix) - - if self.layer_type == "linear_attention": - self.linear_attn = Qwen3_5GatedDeltaNet( - config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - speculative_config=speculative_config, - prefix=f"{prefix}.linear_attn", - ) - elif self.layer_type == "full_attention": - self.self_attn = Qwen3NextAttention( - config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - else: - raise ValueError(f"Invalid layer_type {self.layer_type}") - - # Qwen3.5 MoE: all layers 
use sparse MoE blocks - self.mlp = Qwen3_5SparseMoeBlock( - vllm_config=vllm_config, - prefix=f"{prefix}.mlp", - ) - - self.input_layernorm = Qwen3_5RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = Qwen3_5RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - self.layer_scale = getattr(config, "layer_scale", False) - if self.layer_scale: - self.attn_layer_scale = torch.nn.Parameter( - torch.zeros( - 1, 1, config.hidden_size, - dtype=getattr(config, "dtype", torch.bfloat16), - ), - ) - self.ffn_layer_scale = torch.nn.Parameter( - torch.zeros( - 1, 1, config.hidden_size, - dtype=getattr(config, "dtype", torch.bfloat16), - ), - ) - - -@support_torch_compile( - dynamic_arg_dims={ - "input_ids": 0, - "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - } -) -class Qwen3_5Model(Qwen3NextModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super(Qwen3NextModel, self).__init__() - - config = vllm_config.model_config.hf_text_config - parallel_config = vllm_config.parallel_config - - eplb_config = parallel_config.eplb_config - self.num_redundant_experts = eplb_config.num_redundant_experts - - self.config = config - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - ) - - def get_layer(prefix: str): - return Qwen3_5DecoderLayer( - vllm_config, - layer_type=config.layer_types[extract_layer_index(prefix)], - prefix=prefix, - ) - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" - ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) - ) - - if get_pp_group().is_last_rank: - self.norm = Qwen3_5RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - else: - self.norm = PPMissingLayer() - - def load_fused_expert_weights( - self, - name: str, - 
params_dict: dict, - loaded_weight: torch.Tensor, - shard_id: str, - num_experts: int, - ) -> bool: - param = params_dict[name] - weight_loader = typing.cast(Callable[..., bool], param.weight_loader) - loaded_local_expert = False - for expert_id in range(num_experts): - curr_expert_weight = loaded_weight[expert_id] - success = weight_loader( - param, - curr_expert_weight, - name, - shard_id, - expert_id, - return_success=True, - ) - if success: - loaded_local_expert = True - return loaded_local_expert - - def load_weights( - self, weights: Iterable[tuple[str, torch.Tensor]] - ) -> set[str]: - stacked_params_mapping = [ - # self attention - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - # mlp - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - # GDN - Qwen3.5 style - ("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)), - ("in_proj_qkvz", "in_proj_z", 3), - ("in_proj_ba", "in_proj_b", 0), - ("in_proj_ba", "in_proj_a", 1), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - expert_params_mapping = self.get_expert_mapping() - is_fused_expert = False - fused_expert_params_mapping = [ - ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"), - ("experts.w2_weight", "experts.down_proj", 0, "w2"), - ] - num_experts = ( - self.config.num_experts - if hasattr(self.config, "num_experts") - else 0 - ) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if name.startswith("mtp."): - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if ( - "experts.gate_up_proj" in name - or "experts.down_proj" in name - ): - is_fused_expert = True - expert_params_mapping = fused_expert_params_mapping - - if weight_name not in name: - continue - if "mlp.experts" in name: - continue - - name = name.replace(weight_name, param_name) - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - 
continue - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - if isinstance(shard_id, tuple): - # Multi-shard: split loaded_weight and load each - layer = weight_loader.__self__ - split_sizes = [ - layer.output_sizes[s] for s in shard_id - ] - output_dim = getattr(param, "output_dim", 0) - parts = loaded_weight.split( - split_sizes, dim=output_dim - ) - for s, part in zip(shard_id, parts): - weight_loader(param, part, s) - else: - weight_loader(param, loaded_weight, shard_id) - break - else: - is_expert_weight = False - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - is_expert_weight = True - name_mapped = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name_mapped, self): - continue - if is_fused_expert: - if "experts.gate_up_proj" in name: - loaded_weight = loaded_weight.chunk(2, dim=-2) - success_w1 = self.load_fused_expert_weights( - name_mapped, params_dict, - loaded_weight[0], "w1", num_experts, - ) - success_w3 = self.load_fused_expert_weights( - name_mapped, params_dict, - loaded_weight[1], "w3", num_experts, - ) - success = success_w1 and success_w3 - else: - success = self.load_fused_expert_weights( - name_mapped, params_dict, - loaded_weight, shard_id, num_experts, - ) - if success: - name = name_mapped - break - else: - if ( - name_mapped.endswith(".bias") - or name_mapped.endswith("_bias") - ) and name_mapped not in params_dict: - continue - param = params_dict[name_mapped] - weight_loader = param.weight_loader - success = weight_loader( - param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True, - ) - if success: - name = name_mapped - break - else: - if is_expert_weight: - continue - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - if name not in params_dict: - 
logger.warning_once( - f"Parameter {name} not found in params_dict, " - "skip loading" - ) - continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts): - def _get_moe_model_layers(self): - """Get the model layers, handling both CausalLM and VL wrapper.""" - if hasattr(self, "language_model"): - return self.language_model.model.layers - return self.model.layers - - def update_physical_experts_metadata( - self, - num_physical_experts: int, - num_local_physical_experts: int, - ) -> None: - assert self.num_local_physical_experts == num_local_physical_experts - self.num_physical_experts = num_physical_experts - self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = ( - num_physical_experts - self.num_logical_experts - ) - for layer in self._get_moe_model_layers(): - if isinstance(layer.mlp, Qwen3_5SparseMoeBlock): - moe = layer.mlp - moe.n_local_physical_experts = num_local_physical_experts - moe.n_physical_experts = num_physical_experts - moe.n_redundant_experts = self.num_redundant_experts - moe.experts.update_expert_map() - - def set_moe_parameters(self): - self.expert_weights = [] - self.moe_layers = [] - example_moe = None - for layer in self._get_moe_model_layers(): - if isinstance(layer, Qwen3_5DecoderLayer) and isinstance( - layer.mlp, (Qwen3NextSparseMoeBlock, Qwen3_5SparseMoeBlock) - ): - example_moe = layer.mlp - self.moe_layers.append(layer.mlp.experts) - - if example_moe is None: - raise RuntimeError( - "No Qwen3_5 MoE layer found in the model.layers." 
- ) - - self.num_moe_layers = len(self.moe_layers) - self.num_expert_groups = 1 - self.num_shared_experts = 0 - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - self.num_redundant_experts = example_moe.n_redundant_experts - - -class Qwen3_5MoeForCausalLM( - nn.Module, - HasInnerState, - SupportsLoRA, - SupportsPP, - Qwen3_5_MoeMixtureOfExperts, - IsHybrid, -): - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_text_config - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - - scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, ( - "Qwen3.5 currently does not support prefix caching" - ) - self.quant_config = vllm_config.quant_config - - super().__init__() - self.config = config - self.scheduler_config = scheduler_config - self.model = Qwen3_5Model( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) - - if get_pp_group().is_last_rank: - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head"), - ) - else: - self.lm_head = PPMissingLayer() - - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) - - # Set MoE hyperparameters - self.set_moe_parameters() - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - 
positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ): - hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds - ) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.logits_processor(self.lm_head, hidden_states) - - def load_weights( - self, weights: Iterable[tuple[str, torch.Tensor]] - ) -> set[str]: - - def _remap_weights( - weights: Iterable[tuple[str, torch.Tensor]] - ) -> Iterable[tuple[str, torch.Tensor]]: - for name, tensor in weights: - # The HF checkpoint for Qwen3_5MoeForConditionalGeneration - # uses model.language_model.* for the text model, but our - # model structure uses model.* directly (no language_model - # wrapper level). - name = name.replace("model.language_model.", "model.") - yield name, tensor - - loader = AutoWeightsLoader( - self, - skip_prefixes=["mtp.", "model.visual."], - ) - return loader.load_weights(_remap_weights(weights)) - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() - - @classmethod - def get_mamba_state_dtype_from_config( - cls, - vllm_config: "VllmConfig", - ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - vllm_config.model_config.dtype, - vllm_config.cache_config.mamba_cache_dtype, - ) - - @classmethod - def get_mamba_state_shape_from_config( - cls, vllm_config: "VllmConfig" - ) -> tuple[tuple[int, int], tuple[int, int]]: - parallel_config = vllm_config.parallel_config - hf_config = vllm_config.model_config.hf_text_config - tp_size = parallel_config.tensor_parallel_size - num_spec = ( - vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config - else 0 - ) - return MambaStateShapeCalculator.gated_delta_net_state_shape( - tp_size, - hf_config.linear_num_key_heads, - 
hf_config.linear_num_value_heads, - hf_config.linear_key_head_dim, - hf_config.linear_value_head_dim, - hf_config.linear_conv_kernel_dim, - num_spec, - ) - - -######################################################## -# Qwen3_5-MoE Multimodal (VL) Wrapper -######################################################## - - -class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo): - """Processing info that uses Qwen3_5MoeConfig instead of Qwen3VLConfig.""" - - def get_hf_config(self): - return self.ctx.get_hf_config(Qwen3_5MoeConfig) - - -@MULTIMODAL_REGISTRY.register_processor( - Qwen3VLMultiModalProcessor, - info=Qwen3_5MoeProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder, -) -class Qwen3_5MoeForConditionalGeneration( - Qwen3VLForConditionalGeneration, - HasInnerState, - Qwen3_5_MoeMixtureOfExperts, - IsHybrid, -): - """Multimodal Qwen3.5-MoE model wrapping VisionTransformer + CausalLM.""" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): - nn.Module.__init__(self) - config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config - - self.config = config - self.multimodal_config = multimodal_config - self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.video_pruning_rate = multimodal_config.video_pruning_rate - self.is_multimodal_pruning_enabled = ( - multimodal_config.is_multimodal_pruning_enabled() - ) - - # Vision encoder - if not multimodal_config.get_limit_per_prompt( - "image" - ) and not multimodal_config.get_limit_per_prompt("video"): - self.visual = None - else: - self.visual = Qwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - multimodal_config=multimodal_config, - prefix=maybe_prefix(prefix, "visual"), - ) - - # Language model (MoE CausalLM) - self.language_model = Qwen3_5MoeForCausalLM( - vllm_config=vllm_config, - 
prefix=maybe_prefix(prefix, "language_model"), - ) - - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors - ) - - # Deepstack support - self.use_deepstack = hasattr( - config.vision_config, "deepstack_visual_indexes" - ) and bool(config.vision_config.deepstack_visual_indexes) - self.deepstack_num_level = ( - len(config.vision_config.deepstack_visual_indexes) - if self.use_deepstack - else 0 - ) - if self.use_deepstack and self.visual is not None: - self.deepstack_input_embeds = [ - torch.zeros( - vllm_config.scheduler_config.max_num_batched_tokens, - config.text_config.hidden_size, - ) - for _ in range(self.deepstack_num_level) - ] - else: - self.deepstack_input_embeds = None - self.visual_dim = config.vision_config.out_hidden_size - self.multiscale_dim = self.visual_dim * self.deepstack_num_level - - # Set MoE hyperparameters - self.set_moe_parameters() - - def embed_input_ids( - self, - input_ids: torch.Tensor, - multimodal_embeddings: MultiModalEmbeddings | None = None, - *, - is_multimodal: torch.Tensor | None = None, - handle_oov_mm_token: bool = False, - ) -> torch.Tensor: - inputs_embeds = self._embed_text_input_ids( - input_ids, - self.language_model.embed_input_ids, - is_multimodal=is_multimodal, - handle_oov_mm_token=handle_oov_mm_token, - ) - - if multimodal_embeddings is None or len(multimodal_embeddings) == 0: - return inputs_embeds - - is_multimodal = _require_is_multimodal(is_multimodal) - - inputs_embeds = _merge_multimodal_embeddings( - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - is_multimodal=is_multimodal, - ) - - return inputs_embeds - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ) -> torch.Tensor | IntermediateTensors: - if intermediate_tensors is not None: - inputs_embeds = None - - hidden_states = 
self.language_model.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.language_model.compute_logits(hidden_states) - - def load_weights( - self, weights: Iterable[tuple[str, torch.Tensor]] - ) -> set[str]: - skip_prefixes = ["mtp."] - if self.visual is None: - skip_prefixes.append("visual.") - loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.language_model.model.get_expert_mapping() - - @classmethod - def get_mamba_state_dtype_from_config( - cls, - vllm_config: "VllmConfig", - ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - vllm_config.model_config.dtype, - vllm_config.cache_config.mamba_cache_dtype, - ) - - @classmethod - def get_mamba_state_shape_from_config( - cls, vllm_config: "VllmConfig" - ) -> tuple[tuple[int, int], tuple[int, int]]: - parallel_config = vllm_config.parallel_config - hf_config = vllm_config.model_config.hf_text_config - tp_size = parallel_config.tensor_parallel_size - num_spec = ( - vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config - else 0 - ) - return MambaStateShapeCalculator.gated_delta_net_state_shape( - tp_size, - hf_config.linear_num_key_heads, - hf_config.linear_num_value_heads, - hf_config.linear_key_head_dim, - hf_config.linear_value_head_dim, - hf_config.linear_conv_kernel_dim, - num_spec, - ) diff --git a/vllm_fl/models/qwen3_next.py b/vllm_fl/models/qwen3_next.py deleted file mode 100644 index f3afa3fe..00000000 --- a/vllm_fl/models/qwen3_next.py +++ /dev/null @@ -1,1396 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project -"""Inference-only Qwen3Next model.""" - -from collections.abc import Iterable -from itertools import islice - -import torch -from einops import rearrange -from torch import nn -from transformers.activations import ACT2FN -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.layer import Attention -from vllm.compilation.decorators import support_torch_compile -from vllm.config import ( - CacheConfig, - ModelConfig, - SpeculativeConfig, - VllmConfig, - get_current_vllm_config, -) -from vllm.distributed import ( - divide, - get_ep_group, - get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather, -) -from vllm.forward_context import ForwardContext, get_forward_context -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import SharedFusedMoE -from vllm.model_executor.layers.fused_moe.config import RoutingMethodType -from vllm.model_executor.layers.layernorm import ( - GemmaRMSNorm as Qwen3NextRMSNorm, -) -from vllm.model_executor.layers.layernorm import RMSNormGated -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader -from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, - MambaStateShapeCalculator, -) -from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, - causal_conv1d_update, -) -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from 
vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - sharded_weight_loader, -) -from vllm.model_executor.models.interfaces import ( - HasInnerState, - IsHybrid, - MixtureOfExperts, - SupportsLoRA, - SupportsPP, -) -from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP -from vllm.model_executor.models.utils import ( - AutoWeightsLoader, - PPMissingLayer, - extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, - make_layers, - maybe_prefix, - sequence_parallel_chunk, -) -from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform -from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs import Qwen3NextConfig -from vllm.triton_utils import tl, triton -from vllm.utils.torch_utils import direct_register_custom_op -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata -from vllm_fl.ops.fla import ChunkGatedDeltaRuleOp, FusedRecurrentGatedDeltaRuleOp - -logger = init_logger(__name__) - -KVCache = tuple[torch.Tensor, torch.Tensor] - -class Qwen3NextSparseMoeBlock(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - parallel_config = vllm_config.parallel_config - quant_config = vllm_config.quant_config - - self.tp_size = get_tensor_model_parallel_world_size() - - self.ep_group = get_ep_group().device_group - self.ep_rank = get_ep_group().rank_in_group - self.ep_size = self.ep_group.size() - self.n_routed_experts = config.num_experts - - self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe - - if self.tp_size > config.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}." - ) - - # Load balancing settings. 
- vllm_config = get_current_vllm_config() - eplb_config = vllm_config.parallel_config.eplb_config - self.enable_eplb = parallel_config.enable_eplb - - self.n_logical_experts = self.n_routed_experts - self.n_redundant_experts = eplb_config.num_redundant_experts - self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts - self.n_local_physical_experts = self.n_physical_experts // self.ep_size - - self.physical_expert_start = self.ep_rank * self.n_local_physical_experts - self.physical_expert_end = ( - self.physical_expert_start + self.n_local_physical_experts - ) - - self.gate = ReplicatedLinear( - config.hidden_size, - config.num_experts, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate", - ) - - self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) - - if config.shared_expert_intermediate_size > 0: - self.shared_expert = Qwen3NextMLP( - hidden_size=config.hidden_size, - intermediate_size=config.shared_expert_intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - expert_gate=self.shared_expert_gate, - prefix=f"{prefix}.shared_expert", - ) - else: - self.shared_expert = None - - self.experts = SharedFusedMoE( - shared_experts=self.shared_expert, - gate=self.gate, - num_experts=self.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel, - routing_method_type=RoutingMethodType.Renormalize, - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # NOTE: hidden_states can have either 1D or 2D shape. 
- orig_shape = hidden_states.shape - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - - if self.is_sequence_parallel: - hidden_states = sequence_parallel_chunk(hidden_states) - - if self.experts.is_internal_router: - # In this case, the gate/router runs inside the FusedMoE class - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=hidden_states - ) - else: - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) - - if self.shared_expert is not None: - final_hidden_states = final_hidden_states[0] + final_hidden_states[1] - - if self.is_sequence_parallel: - final_hidden_states = tensor_model_parallel_all_gather( - final_hidden_states, 0 - ) - final_hidden_states = final_hidden_states[:num_tokens] - elif self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 - final_hidden_states - ) - - return final_hidden_states.view(orig_shape) - - -class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): - @property - def mamba_type(self) -> str: - return "gdn_attention" - - def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - self.model_config.dtype, self.cache_config.mamba_cache_dtype - ) - - def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: - return MambaStateShapeCalculator.gated_delta_net_state_shape( - self.tp_size, - self.num_k_heads, - self.num_v_heads, - self.head_k_dim, - self.head_v_dim, - self.conv_kernel_size, - self.num_spec, - ) - - def __init__( - self, - config: Qwen3NextConfig, - model_config: ModelConfig | None = None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - speculative_config: SpeculativeConfig | None = None, - prefix: str = "", - ) -> None: - 
super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.hidden_size = config.hidden_size - self.num_v_heads = config.linear_num_value_heads - self.num_k_heads = config.linear_num_key_heads - self.head_k_dim = config.linear_key_head_dim - self.head_v_dim = config.linear_value_head_dim - self.key_dim = self.head_k_dim * self.num_k_heads - self.value_dim = self.head_v_dim * self.num_v_heads - - self.conv_kernel_size = config.linear_conv_kernel_dim - self.layer_idx = extract_layer_index(prefix) - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] - self.layer_norm_epsilon = config.rms_norm_eps - self.prefix = prefix - - self.config = config - self.model_config = model_config - self.cache_config = cache_config - self.quant_config = quant_config - self.speculative_config = speculative_config - self.num_spec = ( - self.speculative_config.num_speculative_tokens - if self.speculative_config - else 0 - ) - - # QKV - self.conv_dim = self.key_dim * 2 + self.value_dim - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.conv_dim, - bias=False, - prefix=f"{prefix}.conv1d", - ) - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - # projection of the input hidden states - self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 - self.projection_size_ba = self.num_v_heads * 2 - self.in_proj_qkvz = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.projection_size_qkvz, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_qkvz", - ) - # ba_proj doesn't support blockwise fp8 quantization. 
- self.in_proj_ba = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.projection_size_ba, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_ba", - ) - - query_key_settings = (self.key_dim, 0, False) - value_settings = (self.value_dim, 0, False) - - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( - self.conv1d.weight, - { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - query_key_settings, - query_key_settings, - value_settings, - ], - self.tp_size, - self.tp_rank, - ) - }, - ) - - # selective projection used to make dt, B and C input dependant - - # time step projection (discretization) - # instantiate once and copy inv_dt in init_weights of PretrainedModel - self.dt_bias = nn.Parameter( - torch.ones(self.num_v_heads // self.tp_size), - ) - self.A_log = nn.Parameter( - torch.empty( - divide(self.num_v_heads, self.tp_size), - ) - ) - - set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) - set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - - self.norm = RMSNormGated( - self.head_v_dim, - eps=self.layer_norm_epsilon, - group_size=None, - norm_before_gate=True, - device=current_platform.current_device(), - dtype=config.dtype, - ) - - self.out_proj = RowParallelLinear( - self.value_dim, - self.hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - self.chunk_gated_delta_rule = ChunkGatedDeltaRuleOp( - output_final_state = True, - use_qk_l2norm_in_kernel=True, - ) - - self.fused_recurrent_gated_delta_rule_multi_query = FusedRecurrentGatedDeltaRuleOp( - inplace_final_state=True, - use_qk_l2norm_in_kernel=True, - ) - 
self.fused_recurrent_gated_delta_rule_remain_query = FusedRecurrentGatedDeltaRuleOp( - inplace_final_state=True, - use_qk_l2norm_in_kernel=True, - ) - - def fix_query_key_value_ordering( - self, - mixed_qkvz, - mixed_ba, - ): - """ - Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. - """ - new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( - self.num_k_heads // self.tp_size, - ( - self.head_k_dim - + self.head_k_dim - + (self.head_v_dim + self.head_v_dim) - * self.num_v_heads - // self.num_k_heads - ), - ) - new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( - self.num_k_heads // self.tp_size, - 2 * self.num_v_heads // self.num_k_heads, - ) - - mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) - mixed_ba = mixed_ba.view(*new_tensor_shape_ba) - - split_arg_list_qkvz = [ - self.head_k_dim, - self.head_k_dim, - (self.num_v_heads // self.num_k_heads * self.head_v_dim), - (self.num_v_heads // self.num_k_heads * self.head_v_dim), - ] - split_arg_list_ba = [ - self.num_v_heads // self.num_k_heads, - self.num_v_heads // self.num_k_heads, - ] - - # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] - # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], - # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] - (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) - (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) - - # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] - value = value.reshape(value.size(0), -1, self.head_v_dim) - z = z.reshape(z.size(0), -1, self.head_v_dim) - b = b.reshape(b.size(0), self.num_v_heads // self.tp_size) - a = a.reshape(a.size(0), self.num_v_heads // self.tp_size) - - return query, key, value, z, b, a - - def rearrange_mixed_qkv(self, mixed_qkv): - if mixed_qkv is None: - return None, None, None - query, key, value = torch.split( - mixed_qkv, - [ - self.key_dim // self.tp_size, - self.key_dim // self.tp_size, - self.value_dim // self.tp_size, - ], - dim=-1, - ) - query, key = map( 
- lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), - (query, key), - ) - value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) - return query, key, value - - def forward( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - ): - """ - Forward pass with three parts: - 1. Input projection - 2. Core attention (custom op) - 3. Output projection - """ - num_tokens = hidden_states.size(0) - - # ============================================================ - # Part 1: Input Projection - # ============================================================ - projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) - projected_states_ba, _ = self.in_proj_ba(hidden_states) - query, key, value, z, b, a = self.fix_query_key_value_ordering( - projected_states_qkvz, projected_states_ba - ) - query, key, value = map( - lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value) - ) - mixed_qkv = torch.cat((query, key, value), dim=-1) - - # ============================================================ - # Part 2: Core Attention (Custom Op) - # ============================================================ - # Note: we should not use torch.empty here like other attention backends, - # see discussions in https://github.com/vllm-project/vllm/pull/28182 - core_attn_out = torch.zeros( - (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - torch.ops.vllm.gdn_attention_core( - mixed_qkv, - b, - a, - core_attn_out, - self.prefix, - ) - - # ============================================================ - # Part 3: Output Projection - # ============================================================ - z_shape_og = z.shape - # Reshape input data into 2D tensor - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = 
rearrange(core_attn_out, "... h d -> ... (h d)") - output[:num_tokens], _ = self.out_proj(core_attn_out) - - def _forward_core( - self, - mixed_qkv: torch.Tensor, - b: torch.Tensor, - a: torch.Tensor, - core_attn_out: torch.Tensor, - ): - """ - Core attention computation (called by custom op). - """ - forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata - - if attn_metadata is None: - # V1 profile run - return - - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, GDNAttentionMetadata) - has_initial_state = attn_metadata.has_initial_state - spec_query_start_loc = attn_metadata.spec_query_start_loc - non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc - spec_sequence_masks = attn_metadata.spec_sequence_masks - spec_token_indx = attn_metadata.spec_token_indx - non_spec_token_indx = attn_metadata.non_spec_token_indx - spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 - non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - num_actual_tokens = attn_metadata.num_actual_tokens - num_accepted_tokens = attn_metadata.num_accepted_tokens - - mixed_qkv = mixed_qkv[:num_actual_tokens] - b = b[:num_actual_tokens] - a = a[:num_actual_tokens] - - # 1. 
Convolution sequence transformation - conv_weights = self.conv1d.weight.view( - self.conv1d.weight.size(0), self.conv1d.weight.size(2) - ) - - if spec_sequence_masks is not None: - if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: - mixed_qkv_spec = mixed_qkv - mixed_qkv_non_spec = None - else: - mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) - mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) - else: - mixed_qkv_spec = None - mixed_qkv_non_spec = mixed_qkv - - # 1.1: Process the multi-query part - if spec_sequence_masks is not None: - mixed_qkv_spec = causal_conv1d_update( - mixed_qkv_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=spec_state_indices_tensor[:, 0][ - : attn_metadata.num_spec_decodes - ], - num_accepted_tokens=num_accepted_tokens, - query_start_loc=spec_query_start_loc, - max_query_len=spec_state_indices_tensor.size(-1), - validate_data=False, - ) - - # 1.2: Process the remaining part - if attn_metadata.num_prefills > 0: - mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) - # - "cache_indices" updates the conv_state cache in positions - # pointed to by "state_indices_tensor" - mixed_qkv_non_spec = causal_conv1d_fn( - mixed_qkv_non_spec_T, - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=conv_state, - has_initial_state=has_initial_state, - cache_indices=non_spec_state_indices_tensor, - query_start_loc=non_spec_query_start_loc, - metadata=attn_metadata, - ).transpose(0, 1) - elif attn_metadata.num_decodes > 0: - mixed_qkv_non_spec = causal_conv1d_update( - mixed_qkv_non_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=non_spec_state_indices_tensor[ - : attn_metadata.num_actual_tokens - ], - validate_data=True, - ) - else: - mixed_qkv_non_spec = None - - query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) - query_non_spec, key_non_spec, 
value_non_spec = self.rearrange_mixed_qkv( - mixed_qkv_non_spec - ) - - g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) - - if spec_sequence_masks is not None: - if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: - g_spec = g - beta_spec = beta - g_non_spec = None - beta_non_spec = None - else: - g_spec = g.index_select(1, spec_token_indx) - beta_spec = beta.index_select(1, spec_token_indx) - g_non_spec = g.index_select(1, non_spec_token_indx) - beta_non_spec = beta.index_select(1, non_spec_token_indx) - else: - g_spec = None - beta_spec = None - g_non_spec = g - beta_non_spec = beta - - # 2. Recurrent attention - - # 2.1: Process the multi-query part - if spec_sequence_masks is not None: - core_attn_out_spec, last_recurrent_state = self.fused_recurrent_gated_delta_rule_multi_query( - q=query_spec, - k=key_spec, - v=value_spec, - g=g_spec, - beta=beta_spec, - initial_state=ssm_state, - cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], - ssm_state_indices=spec_state_indices_tensor, - num_accepted_tokens=num_accepted_tokens, - ) - else: - core_attn_out_spec, last_recurrent_state = None, None - - # 2.2: Process the remaining part - if attn_metadata.num_prefills > 0: - initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() - initial_state[~has_initial_state, ...] 
= 0 - ( - core_attn_out_non_spec, - last_recurrent_state, - ) = self.chunk_gated_delta_rule( - q=query_non_spec.contiguous(), - k=key_non_spec.contiguous(), - v=value_non_spec.contiguous(), - g=g_non_spec, - beta=beta_non_spec, - initial_state=initial_state, - cu_seqlens=non_spec_query_start_loc, - ) - # Init cache - ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( - ssm_state.dtype - ) - elif attn_metadata.num_decodes > 0: - core_attn_out_non_spec, last_recurrent_state = ( - self.fused_recurrent_gated_delta_rule_remain_query( - q=query_non_spec, - k=key_non_spec, - v=value_non_spec, - g=g_non_spec, - beta=beta_non_spec, - initial_state=ssm_state, - cu_seqlens=non_spec_query_start_loc[ - : attn_metadata.num_decodes + 1 - ], - ssm_state_indices=non_spec_state_indices_tensor, - ) - ) - else: - core_attn_out_non_spec, last_recurrent_state = None, None - - # 3. Merge core attention output - if spec_sequence_masks is not None and core_attn_out_non_spec is not None: - merged_out = torch.empty( - (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), - dtype=core_attn_out_non_spec.dtype, - device=core_attn_out_non_spec.device, - ) - merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) - merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) - core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) - elif spec_sequence_masks is not None: - core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) - else: - core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) - - -class Qwen3NextAttention(nn.Module): - def __init__( - self, - config: Qwen3NextConfig, - model_config: ModelConfig | None = None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = config.num_attention_heads - 
assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = config.num_key_value_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.dual_chunk_attention_config = getattr( - config, "dual_chunk_attention_config", None - ) - self.attn_output_gate = getattr(config, "attn_output_gate", True) - - self.qkv_proj = QKVParallelLinear( - config.hidden_size, - self.head_dim, - self.total_num_heads * (1 + self.attn_output_gate), - self.total_num_kv_heads, - bias=getattr(config, "qkv_bias", False), - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - - self.rotary_emb = get_rope( - head_size=self.head_dim, - max_position=config.max_position_embeddings, - rope_parameters=config.rope_parameters, - dual_chunk_attention_config=self.dual_chunk_attention_config, - ) - - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - **{ - "layer_idx": extract_layer_index(prefix), - "dual_chunk_attention_config": self.dual_chunk_attention_config, - } - if self.dual_chunk_attention_config - else {}, - ) - - 
self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - output: torch.Tensor, - hidden_states: torch.Tensor, - ): - qkv, _ = self.qkv_proj(hidden_states) - - if self.attn_output_gate: - q_gate, k, v = qkv.split( - [self.q_size * 2, self.kv_size, self.kv_size], dim=-1 - ) - orig_shape = q_gate.shape[:-1] - q_gate = q_gate.view(*orig_shape, self.num_heads, -1) - q, gate = torch.chunk(q_gate, 2, dim=-1) - q = q.reshape(*orig_shape, -1) - gate = gate.reshape(*orig_shape, -1) - else: - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( - -1, self.num_heads * self.head_dim - ) - k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( - -1, self.num_kv_heads * self.head_dim - ) - - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - - if self.attn_output_gate: - gate = torch.sigmoid(gate) - attn_output = attn_output * gate - - output[:], _ = self.o_proj(attn_output) - - -class Qwen3NextDecoderLayer(nn.Module): - def __init__( - self, - vllm_config: VllmConfig, - layer_type: str, - prefix: str = "", - ) -> None: - super().__init__() - - config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - speculative_config = vllm_config.speculative_config - - self.layer_type = layer_type - self.layer_idx = extract_layer_index(prefix) - - if self.layer_type == "linear_attention": - self.linear_attn = Qwen3NextGatedDeltaNet( - config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - speculative_config=speculative_config, - prefix=f"{prefix}.linear_attn", - ) - elif self.layer_type == "full_attention": - self.self_attn = Qwen3NextAttention( - config, - 
model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - else: - raise ValueError(f"Invalid layer_type {self.layer_type}") - - mlp_only_layers = ( - [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers - ) - if (self.layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 - and (self.layer_idx + 1) % config.decoder_sparse_step == 0 - ): - self.mlp = Qwen3NextSparseMoeBlock( - vllm_config=vllm_config, - prefix=f"{prefix}.mlp", - ) - else: - self.mlp = Qwen3NextMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - ) - - self.input_layernorm = Qwen3NextRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = Qwen3NextRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - self.layer_scale = getattr(config, "layer_scale", False) - if self.layer_scale: - self.attn_layer_scale = torch.nn.Parameter( - torch.zeros( - 1, - 1, - config.hidden_size, - dtype=config.dtype, - ), - ) - self.ffn_layer_scale = torch.nn.Parameter( - torch.zeros( - 1, - 1, - config.hidden_size, - dtype=config.dtype, - ), - ) - - def forward( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor | None, - positions: torch.Tensor = None, - **kwargs: object, - ): - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) - - self_attention_output = torch.empty_like(hidden_states) - if self.layer_type == "linear_attention": - self.linear_attn( - hidden_states=hidden_states, - output=self_attention_output, - ) - elif self.layer_type == "full_attention": - self.self_attn( - hidden_states=hidden_states, - output=self_attention_output, - positions=positions, - ) - else: - raise ValueError("Invalid layer_type") - hidden_states = 
self_attention_output - - if self.layer_scale: - if len(hidden_states.shape) == 2: - hidden_states = hidden_states * ( - self.attn_layer_scale.to(hidden_states.dtype)[0] + 1 - ) - else: - hidden_states = hidden_states * ( - self.attn_layer_scale.to(hidden_states.dtype) + 1 - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - - hidden_states = self.mlp(hidden_states) - - if self.layer_scale: - if len(hidden_states.shape) == 2: - hidden_states = hidden_states * ( - self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1 - ) - else: - assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), ( - f"shape must be the same {len(hidden_states.shape)}, " - f"{len(self.ffn_layer_scale.shape)}" - ) - hidden_states = hidden_states * ( - self.ffn_layer_scale.to(hidden_states.dtype) + 1 - ) - - return hidden_states, residual - - -@support_torch_compile -class Qwen3NextModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config: Qwen3NextConfig = vllm_config.model_config.hf_config - parallel_config = vllm_config.parallel_config - - eplb_config = parallel_config.eplb_config - self.num_redundant_experts = eplb_config.num_redundant_experts - - self.config = config - - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - ) - - def get_layer(prefix: str): - return Qwen3NextDecoderLayer( - vllm_config, - layer_type=config.layer_types[extract_layer_index(prefix)], - prefix=prefix, - ) - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" - ) - self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) - - if get_pp_group().is_last_rank: - self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = 
PPMissingLayer() - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_input_ids(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - ) - - if not get_pp_group().is_last_rank: - return IntermediateTensors( - {"hidden_states": hidden_states, "residual": residual} - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts, - num_redundant_experts=self.num_redundant_experts, - ) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - expert_params_mapping = self.get_expert_mapping() - for name, loaded_weight in weights: - if 
"rotary_emb.inv_freq" in name: - continue - - if name.startswith("mtp."): - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - - if "mlp.experts" in name: - continue - - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # name = apply_attn_prefix(name, params_dict) - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Skip loading extra bias for GPTQ models. - if ( - name.endswith(".bias") or name.endswith("_bias") - ) and name not in params_dict: - continue - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( - param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id, - ) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - if name not in params_dict: - logger.warning_once( - f"Parameter {name} not found in params_dict, skip loading" - ) - continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class QwenNextMixtureOfExperts(MixtureOfExperts): - def update_physical_experts_metadata( - self, - num_physical_experts: int, - num_local_physical_experts: int, - ) -> None: - assert self.num_local_physical_experts == num_local_physical_experts - self.num_physical_experts = num_physical_experts - self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = num_physical_experts - self.num_logical_experts - for layer in self.model.layers: - if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): - moe = layer.mlp - moe.n_local_physical_experts = num_local_physical_experts - moe.n_physical_experts = num_physical_experts - moe.n_redundant_experts = self.num_redundant_experts - moe.experts.update_expert_map() - - def set_moe_parameters(self): - self.expert_weights = [] - - self.moe_layers = [] - example_moe = None - for layer in self.model.layers: - if isinstance(layer, Qwen3NextDecoderLayer) and isinstance( - layer.mlp, Qwen3NextSparseMoeBlock - ): - example_moe = layer.mlp - self.moe_layers.append(layer.mlp.experts) - - if example_moe is None: - raise RuntimeError("No Qwen3Next layer found in the model.layers.") - - # Set MoE hyperparameters - self.num_moe_layers = len(self.moe_layers) - self.num_expert_groups = 1 - self.num_shared_experts = 0 - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - 
self.num_redundant_experts = example_moe.n_redundant_experts - - -class Qwen3NextForCausalLM( - nn.Module, - HasInnerState, - SupportsLoRA, - SupportsPP, - QwenNextMixtureOfExperts, - IsHybrid, -): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": ["gate_proj", "up_proj"], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - - scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, ( - "Qwen3Next currently does not support prefix caching" - ) - self.quant_config = vllm_config.quant_config - - super().__init__() - self.config = config - self.scheduler_config = scheduler_config - self.model = Qwen3NextModel( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) - - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head"), - ) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) - - # Set MoE hyperparameters - self.set_moe_parameters() - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ): - hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds - ) - - return hidden_states - - @classmethod - def get_mamba_state_dtype_from_config( - cls, - vllm_config: "VllmConfig", - ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - vllm_config.model_config.dtype, 
vllm_config.cache_config.mamba_cache_dtype - ) - - @classmethod - def get_mamba_state_shape_from_config( - cls, vllm_config: "VllmConfig" - ) -> tuple[tuple[int, int], tuple[int, int]]: - parallel_config = vllm_config.parallel_config - hf_config = vllm_config.model_config.hf_config - tp_size = parallel_config.tensor_parallel_size - num_spec = ( - vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config - else 0 - ) - return MambaStateShapeCalculator.gated_delta_net_state_shape( - tp_size, - hf_config.linear_num_key_heads, - hf_config.linear_num_value_heads, - hf_config.linear_key_head_dim, - hf_config.linear_value_head_dim, - hf_config.linear_conv_kernel_dim, - num_spec, - ) - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.logits_processor(self.lm_head, hidden_states) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=["mtp."], - ) - return loader.load_weights(weights) - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() - - -def gdn_attention_core( - mixed_qkv: torch.Tensor, - b: torch.Tensor, - a: torch.Tensor, - core_attn_out: torch.Tensor, - layer_name: str, -) -> None: - """ - Custom op for the core attention computation. - Only handles the convolution + recurrent attention part. - Input/output projections are handled outside this op. 
- """ - forward_context: ForwardContext = get_forward_context() - self = forward_context.no_compile_layers[layer_name] - self._forward_core( - mixed_qkv=mixed_qkv, - b=b, - a=a, - core_attn_out=core_attn_out, - ) - - -def gdn_attention_core_fake( - mixed_qkv: torch.Tensor, - b: torch.Tensor, - a: torch.Tensor, - core_attn_out: torch.Tensor, - layer_name: str, -) -> None: - """Fake implementation for torch.compile.""" - return - - -if not hasattr(torch.ops.vllm, "gdn_attention_core"): - direct_register_custom_op( - op_name="gdn_attention_core", - op_func=gdn_attention_core, - mutates_args=["core_attn_out"], - fake_impl=gdn_attention_core_fake, - ) - - -@triton.jit -def fused_gdn_gating_kernel( - g, - beta_output, - A_log, - a, - b, - dt_bias, - seq_len, - NUM_HEADS: tl.constexpr, - beta: tl.constexpr, - threshold: tl.constexpr, - BLK_HEADS: tl.constexpr, -): - i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) - head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) - off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off - mask = head_off < NUM_HEADS - blk_A_log = tl.load(A_log + head_off, mask=mask) - blk_a = tl.load(a + off, mask=mask) - blk_b = tl.load(b + off, mask=mask) - blk_bias = tl.load(dt_bias + head_off, mask=mask) - # If the model is loaded in fp16, without the .float() here, A might be -inf - x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) - softplus_x = tl.where( - beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x - ) - blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x - tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) - # compute beta_output = sigmoid(b) - blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) - tl.store( - beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask - ) - - -def fused_gdn_gating( - A_log: torch.Tensor, - a: torch.Tensor, - b: torch.Tensor, - dt_bias: torch.Tensor, - beta: float = 1.0, - threshold: float = 20.0, -) -> tuple[torch.Tensor, 
torch.Tensor]: - """ - Fused computation of g and beta for Gated Delta Net. - g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) - beta_output = b.sigmoid() - TODO maybe use torch.compile to replace this triton kernel - """ - batch, num_heads = a.shape - seq_len = 1 - grid = (batch, seq_len, triton.cdiv(num_heads, 8)) - g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) - beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device) - fused_gdn_gating_kernel[grid]( - g, - beta_output, - A_log, - a, - b, - dt_bias, - seq_len, - num_heads, - beta, - threshold, - 8, - num_warps=1, - ) - return g, beta_output diff --git a/vllm_fl/ops/fused_moe/fused_moe.py b/vllm_fl/ops/fused_moe/fused_moe.py index b35dc3c5..72465aa9 100644 --- a/vllm_fl/ops/fused_moe/fused_moe.py +++ b/vllm_fl/ops/fused_moe/fused_moe.py @@ -8,7 +8,6 @@ import functools import torch import torch.nn.functional as F -import vllm.envs as envs from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, @@ -17,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( _get_config_quant_dtype, try_get_optimal_moe_config, - invoke_fused_moe_kernel, + dispatch_fused_moe_kernel, ) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.triton_utils import tl @@ -109,7 +108,9 @@ def fused_experts_impl( top_k_num = topk_ids.size(1) # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + # Note: upstream removed VLLM_FUSED_MOE_CHUNK_SIZE in v0.18.1; + # keep the old default (64K) for the FL chunked implementation. 
+ CHUNK_SIZE = 65536 M = min(num_tokens, CHUNK_SIZE) config_dtype = _get_config_dtype_str( @@ -209,7 +210,7 @@ def fused_experts_impl( ignore_invalid_experts=True, ) - invoke_fused_moe_kernel( + dispatch_fused_moe_kernel( qcurr_hidden_states, w1, intermediate_cache1, @@ -235,6 +236,9 @@ def fused_experts_impl( # Activation function with multiplication # todo: dispatch to flag_gems and other backends + # v0.18.1: activation may be a MoEActivation enum; normalize to str + if hasattr(activation, "value"): + activation = activation.value if activation == "silu": intermediate_cache2 = call_op( "silu_and_mul", None, intermediate_cache1.view(-1, N) @@ -262,7 +266,7 @@ def fused_experts_impl( block_shape=block_shape, ) - invoke_fused_moe_kernel( + dispatch_fused_moe_kernel( qintermediate_cache2, w2, intermediate_cache3, diff --git a/vllm_fl/ops/fused_moe/layer.py b/vllm_fl/ops/fused_moe/layer.py index 3d3368b3..740dcd59 100644 --- a/vllm_fl/ops/fused_moe/layer.py +++ b/vllm_fl/ops/fused_moe/layer.py @@ -11,8 +11,8 @@ import vllm.envs as envs from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod -from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator -from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk +from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import RoutingSimulator +from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import grouped_topk from vllm.platforms import current_platform from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( @@ -21,7 +21,7 @@ if current_platform.is_cuda_alike(): - from vllm.model_executor.layers.fused_moe.fused_moe import ( + from vllm.model_executor.layers.fused_moe.router.base_router import ( eplb_map_to_physical_and_record, ) else: @@ -76,7 +76,30 @@ def forward_oot( else: return result - forward_native 
= forward_oot + def forward_native( + self, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + """v0.18.1 forward_native signature — called when custom ops are + disabled (e.g. under torch.compile / Inductor). + Note: shared experts are handled by the upstream runner + (_apply_quant_method), so we must NOT return a tuple here.""" + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=layer.activation, + quant_config=self.moe_quant_config, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + ) class FusedMoEFL(FusedMoE): @@ -127,7 +150,7 @@ def select_experts( plain MoE implementations without redundant experts. """ from vllm_fl.ops.fused_moe.fused_moe import fused_topk - from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk_bias + from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import fused_topk_bias if self.enable_eplb: if self.quant_method.supports_eplb: diff --git a/vllm_fl/patches/glm_moe_dsa.py b/vllm_fl/patches/glm_moe_dsa.py index 4d205537..f4a7dd67 100644 --- a/vllm_fl/patches/glm_moe_dsa.py +++ b/vllm_fl/patches/glm_moe_dsa.py @@ -52,52 +52,6 @@ def _patched_set(self, special_tokens=None): pass -def patch_is_deepseek_mla(): - """Patch ModelConfig.is_deepseek_mla to recognise glm_moe_dsa as MLA.""" - from vllm.config.model import ModelConfig - _orig_is_mla = ModelConfig.is_deepseek_mla.fget - - @property - def _patched_is_mla(self): - if ( - hasattr(self.hf_text_config, "model_type") - and self.hf_text_config.model_type == "glm_moe_dsa" - and getattr(self.hf_text_config, "kv_lora_rank", None) - is not None - ): - return True - return _orig_is_mla(self) - - 
ModelConfig.is_deepseek_mla = _patched_is_mla - - -def patch_fp8_mqa_logits_dim(): - """Fix k_scale dim mismatch for deep_gemm fp8_mqa_logits. - - vLLM 0.13.0 passes k_scale as [N, 1] but deep_gemm 2.3.0 expects [N]. - Upstream fix: https://github.com/vllm-project/vllm/pull/32652 - We wrap fp8_mqa_logits to flatten k_scale before calling the native impl. - """ - import vllm.utils.deep_gemm as dg_mod - - dg_mod._lazy_init() - _orig_impl = dg_mod._fp8_mqa_logits_impl - if _orig_impl is None: - return - - def _fixed_fp8_mqa_logits(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke): - k_fp8, k_scale = kv - return _orig_impl( - q, (k_fp8, k_scale.flatten()), weights, - cu_seqlen_ks, cu_seqlen_ke, - ) - - dg_mod._fp8_mqa_logits_impl = _fixed_fp8_mqa_logits - dg_mod._lazy_init = lambda: None - logger.info("Patched fp8_mqa_logits: flatten k_scale [N,1] -> [N] " - "for deep_gemm 2.3.0 compat") - - def patch_indexer_schedule_metadata(): """Fix schedule_metadata not computed when VLLM_USE_DEEP_GEMM=0. @@ -144,8 +98,6 @@ def _patched_build(self, common_prefix_len, common_attn_metadata, def apply_platform_patches(): """All GLM-5 patches needed at platform registration time.""" patch_tokenizer_compat() - patch_fp8_mqa_logits_dim() - def patch_indexer_rope_reshape(): """Fix RoPE output shape in Indexer.forward for DSA models. @@ -222,6 +174,5 @@ def _patched_forward(self, hidden_states, qr, positions, rotary_emb): def apply_model_patches(): """All GLM-5 patches needed at model registration time.""" - patch_is_deepseek_mla() patch_indexer_schedule_metadata() patch_indexer_rope_reshape() diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index 6a76600e..203e7a96 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -1,11 +1,11 @@ # Copyright (c) 2025 BAAI. All rights reserved. 
-# Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/platforms/cuda.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/platforms/cuda.py # Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import TYPE_CHECKING, Optional, TypeVar +from typing import TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec import torch @@ -16,15 +16,15 @@ except (ImportError, OSError): pass # NPU or other platforms may not have vllm._C -from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.platforms import Platform, PlatformEnum from vllm.platforms.interface import DeviceCapability +from vllm.v1.attention.backends.registry import AttentionBackendEnum if TYPE_CHECKING: - from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig from vllm.config.cache import CacheDType + from vllm.v1.attention.selector import AttentionSelectorConfig else: VllmConfig = None CacheDType = None @@ -48,6 +48,12 @@ class PlatformFL(Platform): vendor_name = device_info.vendor_name device_type = get_device_type(vendor_name) device_name = get_device_name(vendor_name) + # cuda_alike (nvidia/metax): device_name = vendor_name (not used in torch.device) + # non-cuda_alike (iluvatar/ascend): device_name = device_type (used in torch.device) + device_name = device_info.vendor_name if ( + device_info.device_type == "cuda" and device_info.vendor_name != "iluvatar" + ) else device_info.device_type + device_type = device_info.device_type dispatch_key = device_info.dispatch_key torch_device_fn = device_info.torch_device_fn ray_device_key: str = "GPU" @@ -88,7 +94,7 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype): @classmethod def get_current_memory_usage( - cls, device: Optional[torch.types.Device] = None + cls, device: torch.types.Device | None = None ) -> 
float: cls.torch_device_fn.empty_cache() cls.torch_device_fn.reset_peak_memory_stats(device) @@ -120,6 +126,9 @@ def is_pin_memory_available(cls): def import_kernels(cls) -> None: """Import device-specific kernels.""" logger.info(f"current vendor_name is: {cls.vendor_name}") + # Always load base vLLM C extensions + super().import_kernels() + if cls.vendor_name == "metax": try: import mcoplib._C # noqa: F401 @@ -209,9 +218,10 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: @classmethod def get_attn_backend_cls( cls, - selected_backend: "AttentionBackendEnum", + selected_backend: "AttentionBackendEnum | None", attn_selector_config: "AttentionSelectorConfig", - ) -> list[str]: + num_heads: int | None = None, + ) -> str: """Get the attention backend class path using the dispatch mechanism.""" from vllm_fl.dispatch import call_op @@ -245,8 +255,8 @@ def get_vit_attn_backend( cls, head_size: int, dtype: torch.dtype, - backend: Optional["AttentionBackendEnum"] = None, - ) -> list[str]: + backend: "AttentionBackendEnum | None" = None, + ) -> "AttentionBackendEnum": from vllm_fl.attention.utils import patch_mm_encoder_attention patch_mm_encoder_attention() @@ -335,10 +345,31 @@ def use_custom_allreduce(cls) -> bool: return True @classmethod - def pre_register_and_update(cls, parser = None) -> None: + def pre_register_and_update(cls, parser=None) -> None: if cls.device_name == "npu": import vllm_fl.dispatch.backends.vendor.ascend + def supports_fp8(cls) -> bool: + if cls.vendor_name == "nvidia": + return True + return False + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + return cls.torch_device_fn.get_device_properties( + device_id + ).total_memory + + @classmethod + def use_custom_op_collectives(cls) -> bool: + if cls.vendor_name == "nvidia": + return True + return False + + @classmethod + def num_compute_units(cls, device_id: int = 0) -> int: + return 
cls.torch_device_fn.get_device_properties(device_id).multi_processor_count + @classmethod def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: diff --git a/vllm_fl/utils.py b/vllm_fl/utils.py index bb2b2d38..73a68b93 100644 --- a/vllm_fl/utils.py +++ b/vllm_fl/utils.py @@ -10,7 +10,6 @@ _OP_CONFIG: Optional[dict[str, str]] = None - # Mapping used by dispatch registration to resolve the current runtime platform # into a backend directory under dispatch/backends/vendor. # diff --git a/vllm_fl/worker/model_runner.py b/vllm_fl/worker/model_runner.py index 155c0316..e5f552b7 100644 --- a/vllm_fl/worker/model_runner.py +++ b/vllm_fl/worker/model_runner.py @@ -1,19 +1,20 @@ -# Mainly adopted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/v1/worker/gpu_model_runner.py +# Copyright (c) 2025 BAAI. All rights reserved. +# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/v1/worker/gpu_model_runner.py # Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import functools import gc import itertools +import threading import time from collections import defaultdict -from collections.abc import Iterator, Sequence -from contextlib import contextmanager, nullcontext +from collections.abc import Callable, Iterable, Iterator, Sequence +from contextlib import contextmanager from copy import copy, deepcopy +from dataclasses import dataclass, replace from functools import reduce -from itertools import product from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast import numpy as np @@ -23,22 +24,18 @@ from tqdm import tqdm import vllm.envs as envs -from vllm.attention.backends.abstract import( - AttentionBackend, - AttentionMetadata, - AttentionType, - MultipleOf) -from vllm.attention.layer import Attention, MLAAttention from vllm.compilation.counter import compilation_counter -from vllm.compilation.cuda_graph import 
CUDAGraphStat #CUDAGraphWrapper +from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import ( CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, + set_current_vllm_config, update_config, ) +from vllm.config.cache import CacheConfig from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group @@ -47,22 +44,32 @@ get_dcp_group, get_pp_group, get_tp_group, + graph_capture, is_global_first_rank, prepare_communication_buffer_for_model, - GraphCaptureContext ) from vllm.forward_context import ( BatchDescriptor, set_forward_context, ) from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping, LoRAMappingType +from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + RoutedExpertsCapturer, +) from vllm.model_executor.layers.rotary_embedding import ( MRotaryEmbedding, XDRotaryEmbedding, ) -from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader +from vllm.model_executor.model_loader import get_model_loader +from vllm.model_executor.model_loader.reload import ( + finalize_layerwise_reload, + initialize_layerwise_reload, +) from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal, SupportsXDRoPE, @@ -70,6 +77,7 @@ supports_eagle3, supports_mrope, supports_multimodal_pruning, + supports_realtime, supports_transcription, supports_xdrope, ) @@ -78,26 +86,30 @@ is_pooling_model, is_text_generation_model, ) +from vllm.model_executor.offloader import ( + create_offloader, + get_offloader, + set_offloader, +) from vllm.multimodal import MULTIMODAL_REGISTRY 
+from vllm.multimodal.encoder_budget import MultiModalBudget from vllm.multimodal.inputs import ( BatchedTensorInputs, MultiModalKwargsItem, PlaceholderRange, ) -from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.multimodal.utils import group_and_batch_mm_kwargs from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask +from vllm.tracing import instrument from vllm.utils import length_from_prompt_token_ids_or_embeds -from vllm.utils.jsontree import json_map_leaves from vllm.utils.math_utils import cdiv, round_up -from vllm.utils.mem_constants import GiB_bytes -from vllm.utils.mem_utils import DeviceMemoryProfiler +from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.nvtx_pytorch_hooks import PytHooks -from vllm.utils.platform_utils import is_pin_memory_available - +from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units from vllm.platforms import current_platform if current_platform.dist_backend == "flagcx": @contextmanager @@ -135,18 +147,23 @@ def graph_capture(device: torch.device): from vllm.utils.torch_utils import ( get_dtype_size, kv_cache_dtype_str_to_dtype, - supports_dynamo, ) -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder -from vllm.v1.attention.backends.utils import ( +from vllm.v1.attention.backend import ( + AttentionBackend, AttentionCGSupport, + AttentionMetadata, AttentionMetadataBuilder, + AttentionType, CommonAttentionMetadata, +) +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( create_fast_prefill_custom_backend, get_dcp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills, - 
split_attn_metadata, ) +from vllm.v1.core.sched.output import NewRequestData from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import ( AttentionSpec, @@ -180,16 +197,28 @@ def graph_capture(device: torch.device): from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.ngram_proposer_gpu import ( + NgramProposerGPU, + copy_num_valid_draft_tokens, + update_ngram_gpu_tensors_incremental, + update_scheduler_for_invalid_drafts, +) from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext -from vllm.v1.worker.cp_utils import check_attention_cp_compatibility +from vllm.v1.worker import mamba_utils +from vllm.v1.worker.cp_utils import ( + check_attention_cp_compatibility, + get_total_cp_world_size, +) from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin +from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin @@ -198,22 +227,27 @@ def graph_capture(device: torch.device): UBatchSlices, check_ubatch_thresholds, maybe_create_ubatch_slices, + 
split_attn_metadata, ) from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.workspace import lock_workspace from vllm.v1.worker.utils import ( AttentionGroup, - MultiModalBudget, + KVBlockZeroer, add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, + prepare_kernel_block_sizes, sanity_check_mm_encoder_outputs, ) if TYPE_CHECKING: - from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput + from vllm.v1.spec_decode.ngram_proposer import NgramProposer + +logger = init_logger(__name__) +# FL-specific imports from vllm_fl.compilation.graph import GraphWrapper from vllm_fl.dispatch.io_common import managed_inference_mode from vllm_fl.dispatch.io_dumper import ( @@ -221,9 +255,7 @@ def graph_capture(device: torch.device): init_io_dump_from_env, register_io_module_hooks, ) - -logger = init_logger(__name__) - +GraphWrapper = GraphWrapper AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled @@ -238,7 +270,7 @@ def __init__( sampled_token_ids: torch.Tensor, logprobs_tensors: LogprobsTensors | None, invalid_req_indices: list[int], - async_output_copy_stream: current_platform.torch_device_fn.Stream | None, + async_output_copy_stream: current_platform.torch_device_fn.Stream, vocab_size: int, ): self._model_runner_output = model_runner_output @@ -258,7 +290,7 @@ def __init__( with current_platform.torch_device_fn.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) self.sampled_token_ids_cpu = self._sampled_token_ids.to( - 'cpu', non_blocking=True + "cpu", non_blocking=True ) self._logprobs_tensors_cpu = ( self._logprobs_tensors.to_cpu_nonblocking() @@ -282,22 +314,106 @@ def get_output(self) -> ModelRunnerOutput: valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() for i in self._invalid_req_indices: valid_sampled_token_ids[i].clear() - cu_num_tokens = None + logprobs_lists = None 
+ if self._logprobs_tensors_cpu is not None: + logprobs_lists = self._logprobs_tensors_cpu.tolists() else: - valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( + valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output( self.sampled_token_ids_cpu, self.vocab_size, self._invalid_req_indices, - return_cu_num_tokens=self._logprobs_tensors_cpu is not None, + logprobs_tensors=self._logprobs_tensors_cpu, ) output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids - if self._logprobs_tensors_cpu: - output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens) + output.logprobs = logprobs_lists return output +def _copy_pooler_output_to_cpu( + raw_pooler_output: PoolerOutput, finished_mask: list[bool] +) -> list[torch.Tensor | None]: + num_reqs = len(finished_mask) + + if isinstance(raw_pooler_output, torch.Tensor): + if raw_pooler_output.shape[0] != num_reqs: + raise ValueError( + "Pooler output batch size does not match finished mask size: " + f"{raw_pooler_output.shape[0]} != {num_reqs}." + ) + + num_finished = sum(finished_mask) + if num_finished == 0: + return [None] * num_reqs + if num_finished == num_reqs: + return list(raw_pooler_output.to("cpu", non_blocking=True)) + + # partial finished + finished_indices = [i for i, include in enumerate(finished_mask) if include] + index_tensor = torch.tensor( + finished_indices, device=raw_pooler_output.device, dtype=torch.long + ) + finished_outputs = raw_pooler_output.index_select(0, index_tensor).to( + "cpu", non_blocking=True + ) + partial_pooler_output: list[torch.Tensor | None] = [None] * num_reqs + for i, out in zip(finished_indices, finished_outputs): + partial_pooler_output[i] = out + return partial_pooler_output + + assert isinstance(raw_pooler_output, list) + if len(raw_pooler_output) != num_reqs: + raise ValueError( + "Pooler output batch size does not match finished mask size: " + f"{len(raw_pooler_output)} != {num_reqs}." 
+ ) + + pooler_output: list[torch.Tensor | None] = [None] * num_reqs + for i, (out, include) in enumerate(zip(raw_pooler_output, finished_mask)): + if include and out is not None: + pooler_output[i] = out.to("cpu", non_blocking=True) + return pooler_output + + +class AsyncGPUPoolingModelRunnerOutput(AsyncModelRunnerOutput): + def __init__( + self, + model_runner_output: ModelRunnerOutput, + raw_pooler_output: PoolerOutput, + finished_mask: list[bool], + async_output_copy_stream: current_platform.torch_device_fn.Stream, + ): + self._model_runner_output = model_runner_output + + # Event on the copy stream so we can synchronize the non-blocking copy. + self.async_copy_ready_event = torch.Event() + + # Keep a reference to the device tensors to avoid them being + # deallocated until we finish copying it to the host. + self._raw_pooler_output = raw_pooler_output + + # Initiate the copy on a separate stream, but do not synchronize it. + default_stream = current_platform.torch_device_fn.current_stream() + with current_platform.torch_device_fn.stream(async_output_copy_stream): + async_output_copy_stream.wait_stream(default_stream) + self._model_runner_output.pooler_output = _copy_pooler_output_to_cpu( + raw_pooler_output=self._raw_pooler_output, + finished_mask=finished_mask, + ) + self.async_copy_ready_event.record() + + def get_output(self) -> ModelRunnerOutput: + """Copy the device tensors to the host and return a ModelRunnerOutput. + This function blocks until the copy is finished. + """ + self.async_copy_ready_event.synchronize() + + # Release the device tensors once the copy has completed. 
+ del self._raw_pooler_output + return self._model_runner_output + + class ExecuteModelState(NamedTuple): """Ephemeral cached state transferred between execute_model() and sample_tokens(), after execute_model() returns None.""" @@ -311,6 +427,7 @@ class ExecuteModelState(NamedTuple): aux_hidden_states: list[torch.Tensor] | None ec_connector_output: ECConnectorOutput | None cudagraph_stats: CUDAGraphStat | None + slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None class ModelRunnerFL( @@ -324,6 +441,7 @@ def __init__( self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config + self.offload_config = vllm_config.offload_config self.compilation_config = vllm_config.compilation_config self.lora_config = vllm_config.lora_config self.load_config = vllm_config.load_config @@ -332,10 +450,6 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.observability_config = vllm_config.observability_config - from vllm.model_executor.models.utils import set_cpu_offload_max_bytes - - set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) - model_config = self.model_config cache_config = self.cache_config scheduler_config = self.scheduler_config @@ -355,6 +469,9 @@ def __init__( ) # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False + # Set to True after init_routed_experts_capturer() completes. + # Prevents routed experts code from running during profiling/dummy run. 
+ self.routed_experts_initialized = False self.max_model_len = model_config.max_model_len # Always set to false after the first forward pass @@ -366,11 +483,11 @@ def __init__( # Broadcast PP output for external_launcher (torchrun) # to make sure we are synced across pp ranks - # TODO: Support overlapping mirco-batches + # TODO: Support overlapping micro-batches # https://github.com/vllm-project/vllm/issues/18019 self.broadcast_pp_output = ( self.parallel_config.distributed_executor_backend == "external_launcher" - and len(get_pp_group().ranks) > 0 + and len(get_pp_group().ranks) > 1 ) # Model-related. @@ -398,10 +515,15 @@ def __init__( else: self.max_encoder_len = 0 + # Async scheduling + self.use_async_scheduling = self.scheduler_config.async_scheduling + # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.eplb_state: EplbState | None = None + # NOTE(yongji): flag to temporarily disable EPLB during scaling up/down + self.eep_eplb_suppressed = False """ State of the expert parallelism load balancer. @@ -421,6 +543,7 @@ def __init__( # mm_hash -> encoder_output self.encoder_cache: dict[str, torch.Tensor] = {} + self.late_interaction_runner = LateInteractionRunner() self.use_aux_hidden_state_outputs = False # Set up speculative decoding. @@ -429,10 +552,41 @@ def __init__( # layers in the draft model. 
if self.speculative_config and get_pp_group().is_last_rank: self.drafter: ( - NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer + NgramProposer # noqa: F823 + | NgramProposerGPU + | SuffixDecodingProposer + | EagleProposer + | DraftModelProposer + | MedusaProposer + | ExtractHiddenStatesProposer ) if self.speculative_config.method == "ngram": + from vllm.v1.spec_decode.ngram_proposer import NgramProposer + self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.uses_draft_model(): + self.drafter = DraftModelProposer( + vllm_config=self.vllm_config, + device=self.device, + runner=self, + ) + elif self.speculative_config.use_ngram_gpu(): + self.drafter = NgramProposerGPU(self.vllm_config, self.device, self) + self.num_tokens_no_spec_gpu = torch.zeros( + self.max_num_reqs, dtype=torch.int32, device=device + ) + self.token_ids_gpu_tensor = torch.zeros( + self.max_num_reqs, + self.max_model_len, + dtype=torch.int32, + device=device, + ) + self._ngram_pinned_idx_buf = torch.zeros( + self.max_num_reqs, dtype=torch.long, pin_memory=True + ) + self._ngram_pinned_val_buf = torch.zeros( + self.max_num_reqs, dtype=torch.int32, pin_memory=True + ) elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): @@ -445,15 +599,26 @@ def __init__( self.drafter = MedusaProposer( vllm_config=self.vllm_config, device=self.device ) + elif self.speculative_config.method == "extract_hidden_states": + self.drafter = ExtractHiddenStatesProposer( + vllm_config=self.vllm_config, device=self.device + ) + self.use_aux_hidden_state_outputs = True else: raise ValueError( "Unknown speculative decoding method: " f"{self.speculative_config.method}" ) self.rejection_sampler = RejectionSampler(self.sampler) + self.num_spec_tokens = 0 if self.speculative_config: self.num_spec_tokens = self.speculative_config.num_speculative_tokens + draft_config = 
self.speculative_config.draft_model_config + if draft_config is not None and draft_config.max_model_len is not None: + self.effective_drafter_max_model_len = draft_config.max_model_len + else: + self.effective_drafter_max_model_len = self.max_model_len # Request states. self.requests: dict[str, CachedRequestState] = {} @@ -475,17 +640,22 @@ def __init__( custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = ( tuple(logits_processors) if logits_processors is not None else () ) + placeholder_block_size = ( + self.cache_config.block_size or CacheConfig.DEFAULT_BLOCK_SIZE + ) + self._init_block_sizes = [placeholder_block_size] + self._init_kernel_block_sizes = [placeholder_block_size] self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - # We need to use the encoder length for encoder-decoer + # We need to use the encoder length for encoder-decoder # because of KV cache for cross-attention. max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), - block_sizes=[self.cache_config.block_size], - kernel_block_sizes=[self.cache_config.block_size], + block_sizes=[placeholder_block_size], + kernel_block_sizes=[placeholder_block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( self.vllm_config, @@ -501,13 +671,14 @@ def __init__( cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, ) - self.use_async_scheduling = self.scheduler_config.async_scheduling - self.async_output_copy_stream = current_platform.torch_device_fn.Stream() if \ - self.use_async_scheduling else None + # Separate cuda stream for overlapping transfer of sampled token ids from + # GPU to CPU when async scheduling is enabled. 
+ self.async_output_copy_stream: current_platform.torch_device_fn.Stream | None = None # cuda event to synchronize use of reused CPU tensors between steps # when async scheduling is enabled. self.prepare_inputs_event: torch.Event | None = None if self.use_async_scheduling: + self.async_output_copy_stream = current_platform.torch_device_fn.Stream() self.prepare_inputs_event = torch.Event() # self.cudagraph_batch_sizes sorts in ascending order. @@ -518,6 +689,16 @@ def __init__( self.cudagraph_batch_sizes = sorted( self.compilation_config.cudagraph_capture_sizes ) + else: + self.cudagraph_batch_sizes = [] + + # Cache the device properties. + self._init_device_properties() + + # Encoder timing registry for observability + self.encoder_timing_registry: dict[str, EncoderTimingStats] = {} + self._encoder_timing_lock = threading.Lock() + # Persistent buffers for CUDA graphs. self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) @@ -546,9 +727,16 @@ def __init__( self.num_accepted_tokens = self._make_buffer( self.max_num_reqs, dtype=torch.int64 ) + # Only relevant for multimodal models if self.supports_mm_inputs: - self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + # Double buffer to avoid race condition: previous iteration's async + # copy may still be reading from CPU while current iteration writes. 
+ self.is_mm_embed_buffers = [ + self._make_buffer(self.max_num_tokens, dtype=torch.bool), + self._make_buffer(self.max_num_tokens, dtype=torch.bool), + ] + self.is_mm_embed_idx = 0 # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -602,11 +790,7 @@ def __init__( self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) self.mm_budget = ( - MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) + MultiModalBudget(self.vllm_config, self.mm_registry) if self.supports_mm_inputs else None ) @@ -620,6 +804,22 @@ def __init__( # Cached outputs. self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + # N-gram GPU path: async D2H buffer/event for per-request valid draft counts. + self._num_valid_draft_tokens: torch.Tensor | None = None + self._num_valid_draft_tokens_cpu: torch.Tensor | None = None + self._num_valid_draft_tokens_event: current_platform.torch_device_fn.Event | None = None + self._num_valid_draft_tokens_copy_stream: current_platform.torch_device_fn.Stream | None = None + if ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ): + self._num_valid_draft_tokens_cpu = torch.empty( + self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory + ) + self._num_valid_draft_tokens_event = current_platform.torch_device_fn.Event() + self._num_valid_draft_tokens_copy_stream = current_platform.torch_device_fn.Stream() + + self._draft_token_req_ids: list[str] | None = None self.transfer_event = torch.Event() # TODO(yxa): NPU uses int32, CUDA uses int64 for sampled token ids sampled_ids_dtype = torch.int32 if current_platform.device_type == "npu" else torch.int64 @@ -629,28 +829,73 @@ def __init__( device="cpu", pin_memory=self.pin_memory, ) + # Pre-allocated tensor for copying valid sampled token counts to CPU, # with dedicated stream for overlapping and event for coordination. 
self.valid_sampled_token_count_event: torch.Event | None = None - self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None - if self.use_async_scheduling and self.num_spec_tokens: - self.valid_sampled_token_count_event = torch.Event() - self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() - self.valid_sampled_token_count_cpu = torch.empty( - self.max_num_reqs, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory, - ) + self.valid_sampled_token_count_copy_stream: current_platform.torch_device_fn.Stream | None = None + # We also copy the drafted tokens to the CPU asynchronously, + # in case we need them for structured outputs. + self.draft_token_ids_event: torch.Event | None = None + self.draft_token_ids_copy_stream: current_platform.torch_device_fn.Stream | None = None + self.valid_sampled_token_count_cpu: torch.Tensor | None = None + self.draft_token_ids_cpu: torch.Tensor | None = None + self.num_accepted_tokens_event: torch.Event | None = None + if self.num_spec_tokens: + self.draft_token_ids_event = torch.Event() + self.num_accepted_tokens_event = torch.Event() + self.draft_token_ids_copy_stream = current_platform.torch_device_fn.Stream() + self.draft_token_ids_cpu = torch.empty( + (self.max_num_reqs, self.num_spec_tokens), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + if self.use_async_scheduling: + self.valid_sampled_token_count_event = torch.Event() + self.valid_sampled_token_count_copy_stream = current_platform.torch_device_fn.Stream() + self.valid_sampled_token_count_cpu = torch.empty( + self.max_num_reqs, + dtype=sampled_ids_dtype, + device="cpu", + pin_memory=self.pin_memory, + ) + + # Model weight offloader + # Make sure this is called before any get_offloader call + set_offloader(create_offloader(self.offload_config)) # Ephemeral state transferred between execute_model() and sample_tokens(). 
self.execute_model_state: ExecuteModelState | None = None self.kv_connector_output: KVConnectorOutput | None = None + self.mamba_state_idx: dict[str, int] = {} + self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None self.layerwise_nvtx_hooks_registered = False + def update_max_model_len(self, max_model_len: int) -> None: + self.max_model_len = max_model_len + if self.speculative_config: + draft_config = self.speculative_config.draft_model_config + if draft_config is None or draft_config.max_model_len is None: + self.effective_drafter_max_model_len = self.max_model_len + def reset_mm_cache(self) -> None: + """ + Clear the multi-modal cache that was used during profiling, + but no longer needed during inference. + """ if self.mm_budget: self.mm_budget.reset_cache() + self.late_interaction_runner.clear() + + def reset_encoder_cache(self) -> None: + """Clear the GPU-side encoder cache storing vision embeddings. + + This should be called when model weights are updated to ensure + stale embeddings computed with old weights are not reused. + """ + self.encoder_cache.clear() + self.late_interaction_runner.clear() @managed_inference_mode() def init_fp8_kv_scales(self) -> None: @@ -721,7 +966,17 @@ def _make_buffer( with_numpy=numpy, ) - def _init_model_kwargs(self, num_tokens: int): + def _get_mamba_copy_bufs(self) -> mamba_utils.MambaCopyBuffers: + if self._mamba_copy_bufs is None: + self._mamba_copy_bufs = mamba_utils.MambaCopyBuffers.create( + self.max_num_reqs, + self.kv_cache_config, + self.model.get_mamba_state_copy_func(), + self._make_buffer, + ) + return self._mamba_copy_bufs + + def _init_model_kwargs(self): model_kwargs = dict[str, Any]() if not self.is_pooling_model: @@ -765,7 +1020,7 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: Args: scheduler_output: The scheduler output. 
""" - # Attention free models have zero kv_cache_goups, however models + # Attention free models have zero kv_cache_groups, however models # like Mamba are also attention free but use the kv_cache for # keeping its internal state. This is why we check the number # of kv_cache groups instead of solely checking @@ -780,6 +1035,32 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: decode_threshold=self.reorder_batch_threshold, ) + def _init_kv_zero_meta(self) -> None: + """One-time precomputation for _zero_block_ids. + + Delegates to KVBlockZeroer.init_meta with the runner's state. + Called from gpu_worker.py outside the CuMem pool context. + """ + self._kv_block_zeroer = KVBlockZeroer(self.device, self.pin_memory) + self._kv_block_zeroer.init_meta( + attn_groups_iter=self._kv_cache_spec_attn_group_iterator(), + kernel_block_sizes=self._kernel_block_sizes, + cache_dtype=self.cache_config.cache_dtype, + runner_only_attn_layers=self.runner_only_attn_layers, + static_forward_context=(self.compilation_config.static_forward_context), + ) + + def _zero_block_ids(self, block_ids: list[int]) -> None: + """Zero the KV cache memory for the given block IDs.""" + if hasattr(self, "_kv_block_zeroer"): + self._kv_block_zeroer.zero_block_ids(block_ids) + + # Note: used for model runner override. + def _init_device_properties(self) -> None: + """Initialize attributes from current_platform.torch_device_fn.get_device_properties""" + + self.num_sms = num_compute_units(self.device.index) + # Note: used for model runner override. 
def _sync_device(self) -> None: current_platform.torch_device_fn.synchronize() @@ -798,6 +1079,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) self.num_prompt_logprobs.pop(req_id, None) + self.late_interaction_runner.on_requests_finished( + scheduler_output.finished_req_ids + ) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -807,6 +1091,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in scheduler_output.finished_req_ids: self.input_batch.remove_request(req_id) + # Zero GPU memory for freshly allocated cache blocks to prevent + # stale NaN/data from corrupting attention or SSM computation. + if scheduler_output.new_block_ids_to_zero: + self._zero_block_ids(scheduler_output.new_block_ids_to_zero) + # Free the cached encoder outputs. for mm_hash in scheduler_output.free_encoder_mm_hashes: self.encoder_cache.pop(mm_hash, None) @@ -833,10 +1122,23 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in unscheduled_req_ids: self.input_batch.remove_request(req_id) + is_ngram_gpu = ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ) + if is_ngram_gpu: + ngram_gpu_new_reqs: list[CachedRequestState] = [] + reqs_to_add: list[CachedRequestState] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id + if req_id in self.requests: + # For streaming case only. 
+ req_state = self._update_streaming_request(req_id, new_req_data) + reqs_to_add.append(req_state) + continue + sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params @@ -872,6 +1174,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: lora_request=new_req_data.lora_request, ) self.requests[req_id] = req_state + self.late_interaction_runner.register_request(req_id, pooling_params) if sampling_params and sampling_params.prompt_logprobs is not None: self.num_prompt_logprobs[req_id] = ( @@ -889,10 +1192,30 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self._init_xdrope_positions(req_state) reqs_to_add.append(req_state) + # Track new requests for ngram_gpu full tensor copy + if is_ngram_gpu: + ngram_gpu_new_reqs.append(req_state) # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens + + # Save scheduler-allocated spec lengths before trimming so + # prev_num_draft_len keeps the optimistic count for rejection correction. + original_num_spec_per_req: dict[str, int] = {} + if ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ): + for req_id, toks in scheduled_spec_tokens.items(): + original_num_spec_per_req[req_id] = len(toks) + update_scheduler_for_invalid_drafts( + self._num_valid_draft_tokens_event, + self._num_valid_draft_tokens_cpu, + scheduler_output, + self.input_batch.req_id_to_index, + ) # Wait until valid_sampled_tokens_count is copied to cpu, # then use it to update actual num_computed_tokens of each request. 
@@ -906,20 +1229,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_output_tokens = req_data.num_output_tokens[i] req_index = self.input_batch.req_id_to_index.get(req_id) - # prev_num_draft_len is used in async scheduling mode with - # spec decode. it indicates if need to update num_computed_tokens - # of the request. for example: - # fist step: num_computed_tokens = 0, spec_tokens = [], - # prev_num_draft_len = 0. - # second step: num_computed_tokens = 100(prompt lenth), - # spec_tokens = [a,b], prev_num_draft_len = 0. - # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], - # prev_num_draft_len = 2. - # num_computed_tokens in first step and second step does't contain - # the spec tokens length, but in third step it contains the - # spec tokens length. we only need to update num_computed_tokens - # when prev_num_draft_len > 0. - if req_state.prev_num_draft_len: + if req_state.prev_num_draft_len and self.use_async_scheduling: + # prev_num_draft_len is used in async scheduling mode with + # spec decode. it indicates if need to update num_computed_tokens + # of the request. for example: + # first step: num_computed_tokens = 0, spec_tokens = [], + # prev_num_draft_len = 0. + # second step: num_computed_tokens = 100(prompt length), + # spec_tokens = [a,b], prev_num_draft_len = 0. + # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], + # prev_num_draft_len = 2. + # num_computed_tokens in first step and second step doesn't contain + # the spec tokens length, but in third step it contains the + # spec tokens length. we only need to update num_computed_tokens + # when prev_num_draft_len > 0. 
if req_index is None: req_state.prev_num_draft_len = 0 else: @@ -930,24 +1253,33 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_computed_tokens -= num_rejected req_state.output_token_ids.extend([-1] * num_accepted) + if is_ngram_gpu and num_accepted > 0 and req_index is not None: + self.input_batch.num_tokens_no_spec[req_index] += num_accepted + # Update the cached states. req_state.num_computed_tokens = num_computed_tokens if not is_last_rank: - # When using PP, the scheduler sends the sampled tokens back, - # because there's no direct communication between the first- - # stage worker and the last-stage worker. - new_token_ids = req_data.new_token_ids[i] - # Add the sampled token(s) from the previous step (if any). - # This doesn't include "unverified" tokens like spec tokens. - num_new_tokens = ( - num_computed_tokens + len(new_token_ids) - req_state.num_tokens - ) - if num_new_tokens == 1: - # Avoid slicing list in most common case. - req_state.output_token_ids.append(new_token_ids[-1]) - elif num_new_tokens > 0: - req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) + if not req_data.new_token_ids: + # Async scheduled PP: Sampled tokens propagated via GPU broadcast. + new_token_ids: list[int] = [] + else: + # Non-async scheduling with PP: The scheduler sends + # sampled token ids back because there's no direct communication + # between the first-stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) + if num_new_tokens == 1: + # Avoid slicing list in most common case. 
+ req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:] + ) elif num_output_tokens < len(req_state.output_token_ids): # Some output tokens were discarded due to a sync-KV-load # failure. Align the cached state. @@ -957,7 +1289,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.num_prompt_tokens[req_index] + num_output_tokens ) - self.input_batch.num_tokens[req_index] = end_idx self.input_batch.num_tokens_no_spec[req_index] = end_idx # Update the block IDs. @@ -985,6 +1316,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] reqs_to_add.append(req_state) + # Track resumed requests for ngram_gpu full tensor copy + if is_ngram_gpu: + ngram_gpu_new_reqs.append(req_state) continue # Update the persistent batch. @@ -1002,46 +1336,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_index, start_token_index:end_token_index ] = new_token_ids self.input_batch.num_tokens_no_spec[req_index] = end_token_index - self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, [] - ) - num_spec_tokens = len(spec_token_ids) - # For async scheduling, token_ids_cpu assigned from - # spec_token_ids are placeholders and will be overwritten in - # _prepare_input_ids. - if num_spec_tokens: - start_index = self.input_batch.num_tokens_no_spec[req_index] - end_token_index = start_index + num_spec_tokens - self.input_batch.token_ids_cpu[ - req_index, start_index:end_token_index - ] = spec_token_ids - # NOTE(woosuk): `num_tokens` here may include spec tokens. 
- self.input_batch.num_tokens[req_index] += num_spec_tokens - - # When speculative decoding is used with structured output, - # the scheduler can drop draft tokens that do not - # conform to the schema. This can result in - # scheduler_output.scheduled_spec_decode_tokens being empty, - # even when speculative decoding is enabled. - self.input_batch.spec_token_ids[req_index].clear() - self.input_batch.spec_token_ids[req_index].extend(spec_token_ids) - - # there are no draft tokens with async scheduling, - # we clear the spec_decoding info in scheduler_output and - # use normal sampling but rejection_sampling. - if self.use_async_scheduling: - req_state.prev_num_draft_len = num_spec_tokens - if num_spec_tokens and self._draft_token_ids is None: - scheduler_output.total_num_scheduled_tokens -= num_spec_tokens - scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens - scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None) + self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens) + # Restore scheduler-side draft count after ngram trimming. + if original_num_spec_per_req: + orig = original_num_spec_per_req.get(req_id, 0) + if orig != req_state.prev_num_draft_len: + req_state.prev_num_draft_len = orig + # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. for request in reqs_to_add: self.input_batch.add_request(request) + self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens) # Condense the batched states if there are gaps left by removed requests self.input_batch.condense() @@ -1050,8 +1358,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. 
self.input_batch.refresh_metadata() + # Incrementally update ngram_gpu tensors after batch is stable + if is_ngram_gpu: + update_ngram_gpu_tensors_incremental( + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + ngram_gpu_new_reqs, + self.device, + _pinned_idx_buf=self._ngram_pinned_idx_buf, + _pinned_val_buf=self._ngram_pinned_val_buf, + ) + def _update_states_after_model_execute( - self, output_token_ids: torch.Tensor + self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput" ) -> None: """Update the cached states after model execution. @@ -1061,17 +1381,18 @@ def _update_states_after_model_execute( each sequence, and a shifting is done during the next iteration based on the number of accepted tokens. """ - if not self.model_config.is_hybrid or not self.speculative_config: + if not self.speculative_config or not self.model_config.is_hybrid: return # Find the number of accepted tokens for each sequence. - num_accepted_tokens = ( + num_reqs = output_token_ids.size(0) + self.num_accepted_tokens.gpu[:num_reqs] = ( ( torch.cat( [ output_token_ids, torch.full( - (output_token_ids.size(0), 1), + (num_reqs, 1), -1, device=output_token_ids.device, ), @@ -1082,11 +1403,64 @@ def _update_states_after_model_execute( ) .int() .argmax(-1) - .cpu() - .numpy() ) - for i, num_tokens in enumerate(num_accepted_tokens): - self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + if self.cache_config.mamba_cache_mode == "align": + for i, num_tokens in enumerate( + self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy() + ): + self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + + mamba_utils.postprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.input_batch, + self.requests, + self.mamba_state_idx, + self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), + self._get_mamba_copy_bufs(), + ) + else: + self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_( + 
self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True + ) + assert self.num_accepted_tokens_event is not None + self.num_accepted_tokens_event.record() + + def _update_streaming_request( + self, req_id: str, new_req_data: NewRequestData + ) -> CachedRequestState: + """Updates streaming session request from `scheduled_new_reqs`. + + Removes the request from InputBatch (if present), updates the cached + state, and prepares it for re-addition to the batch. + + NOTE: prompt_token_ids includes intermediate output tokens - tokens + previously generated but now are input context (part of the prompt). + """ + self.input_batch.remove_request(req_id) + req_state = self.requests[req_id] + + req_state.prompt_token_ids = new_req_data.prompt_token_ids + req_state.mm_features = new_req_data.mm_features + req_state.prompt_embeds = new_req_data.prompt_embeds + req_state.sampling_params = new_req_data.sampling_params + req_state.pooling_params = new_req_data.pooling_params + self.late_interaction_runner.register_request(req_id, req_state.pooling_params) + req_state.block_ids = new_req_data.block_ids + req_state.num_computed_tokens = new_req_data.num_computed_tokens + req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + req_state.prompt_token_ids, req_state.prompt_embeds + ) + + # Clear `output_token_ids` as previous output tokens are now part of + # `prompt_token_ids`. 
+ req_state.output_token_ids.clear() + + if self.uses_mrope: + self._init_mrope_positions(req_state) + + return req_state def _init_mrope_positions(self, req_state: CachedRequestState): model = self.get_model() @@ -1123,22 +1497,20 @@ def _extract_mm_kwargs( if not scheduler_output or not self.is_multimodal_raw_input_only_model: return {} - mm_kwargs = list[MultiModalKwargsItem]() + mm_kwargs = list[tuple[str, MultiModalKwargsItem]]() for req in scheduler_output.scheduled_new_reqs: for feature in req.mm_features: if feature.data is not None: - mm_kwargs.append(feature.data) + mm_kwargs.append((feature.modality, feature.data)) # Input all modalities at once - model = cast(SupportsMultiModal, self.model) mm_kwargs_combined: BatchedTensorInputs = {} - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + for _, _, mm_kwargs_batch in group_and_batch_mm_kwargs( mm_kwargs, device=self.device, pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, ): - mm_kwargs_combined.update(mm_kwargs_group) + mm_kwargs_combined.update(mm_kwargs_batch) return mm_kwargs_combined @@ -1149,6 +1521,9 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: mm_budget = self.mm_budget assert mm_budget is not None + if not mm_budget.mm_max_toks_per_item: + return {} # No tower modalities (embed-only mode) + dummy_modality = mm_budget.get_modality_with_max_tokens() return self._get_mm_dummy_batch(dummy_modality, num_seqs) @@ -1231,30 +1606,30 @@ def _prepare_input_ids( prev_draft_token_indices.extend(range(start, start + draft_len)) indices_match &= prev_index == flattened_index max_flattened_index = max(max_flattened_index, flattened_index) - num_commmon_tokens = len(sample_flattened_indices) + num_common_tokens = len(sample_flattened_indices) total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens - if num_commmon_tokens < total_without_spec: + if num_common_tokens < total_without_spec: # If not all requests are decodes from the 
last iteration, # We need to copy the input_ids_cpu to the GPU first. self.input_ids.copy_to_gpu(total_num_scheduled_tokens) if self.enable_prompt_embeds: self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) - if num_commmon_tokens == 0: + if num_common_tokens == 0: # No requests in common with the previous iteration # So input_ids.cpu will have all the input ids. return - if indices_match and max_flattened_index == (num_commmon_tokens - 1): + if indices_match and max_flattened_index == (num_common_tokens - 1): # Common-case optimization: the batch is unchanged # and no reordering happened. # The indices are both the same permutation of 0..N-1 so # we can copy directly using a single slice. - self.input_ids.gpu[:num_commmon_tokens].copy_( - self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], + self.input_ids.gpu[:num_common_tokens].copy_( + self.input_batch.prev_sampled_token_ids[:num_common_tokens, 0], non_blocking=True, ) if self.enable_prompt_embeds: - self.is_token_ids.gpu[:num_commmon_tokens] = True + self.is_token_ids.gpu[:num_common_tokens] = True return # Upload the index tensors asynchronously so the scatter can be non-blocking. sampled_tokens_index_tensor = torch.tensor( @@ -1286,7 +1661,6 @@ def _prepare_input_ids( # because input_ids dtype is torch.int32, # so convert draft_token_ids to torch.int32 here. 
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32) - self._draft_token_ids = None self.input_ids.gpu.scatter_( dim=0, @@ -1299,12 +1673,14 @@ def _get_encoder_seq_lens( num_scheduled_tokens: dict[str, int], kv_cache_spec: KVCacheSpec, num_reqs: int, + for_cudagraph_capture: bool = False, ) -> tuple[torch.Tensor | None, np.ndarray | None]: if not isinstance(kv_cache_spec, CrossAttentionSpec): return None, None # Zero out buffer for padding requests that are not actually scheduled (CGs) self.encoder_seq_lens.np[:num_reqs] = 0 + # Build encoder_seq_lens array mapping request indices to # encoder lengths for inputs scheduled in this batch for req_id in num_scheduled_tokens: @@ -1321,6 +1697,15 @@ def _get_encoder_seq_lens( feature.mm_position.length for feature in req_state.mm_features ) self.encoder_seq_lens.np[req_index] = encoder_input_tokens + if for_cudagraph_capture: + # During CUDA graph capture, we need to use realistic encoder lengths + # so that max_seqlen_k is captured with the correct value. + max_encoder_len = getattr( + self.model_config.hf_config, + "max_source_positions", + self.max_encoder_len, + ) + self.encoder_seq_lens.np[:num_reqs] = max_encoder_len self.encoder_seq_lens.copy_to_gpu(num_reqs) encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs] @@ -1501,7 +1886,6 @@ def _prepare_inputs( # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. 
logits_indices = query_start_loc[1:] - 1 - num_draft_tokens = None spec_decode_metadata = None num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) else: @@ -1518,14 +1902,11 @@ def _prepare_inputs( ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] num_draft_tokens[req_idx] = len(draft_token_ids) - num_decode_draft_tokens[req_idx] = ( - len(draft_token_ids) - if ( - self.input_batch.num_computed_tokens_cpu[req_idx] - >= self.input_batch.num_prompt_tokens[req_idx] - ) - else -1 - ) + if ( + self.input_batch.num_computed_tokens_cpu[req_idx] + >= self.input_batch.num_prompt_tokens[req_idx] + ): + num_decode_draft_tokens[req_idx] = len(draft_token_ids) spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens ) @@ -1564,6 +1945,7 @@ def _build_attention_metadata( for_cudagraph_capture: bool = False, num_scheduled_tokens: dict[str, int] | None = None, cascade_attn_prefix_lens: list[list[int]] | None = None, + slot_mappings: dict[int, torch.Tensor] | None = None, ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: """ :return: tuple[attn_metadata, spec_decode_common_attn_metadata] @@ -1578,7 +1960,7 @@ def _build_attention_metadata( attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: - attn_metadata = [{} for _ in range(len(ubatch_slices))] + attn_metadata = [dict() for _ in range(len(ubatch_slices))] if for_cudagraph_capture: # For some attention backends (e.g. 
FA) with sliding window models we need @@ -1589,6 +1971,8 @@ def _build_attention_metadata( max_seq_len = self.seq_lens.np[:num_reqs].max().item() if use_spec_decode: + if self.num_accepted_tokens_event is not None: + self.num_accepted_tokens_event.synchronize() self.num_accepted_tokens.np[:num_reqs] = ( self.input_batch.num_accepted_tokens_cpu[:num_reqs] ) @@ -1597,7 +1981,7 @@ def _build_attention_metadata( kv_cache_groups = self.kv_cache_config.kv_cache_groups - def _get_block_table_and_slot_mapping(kv_cache_gid: int): + def _get_block_table(kv_cache_gid: int): assert num_reqs_padded is not None and num_tokens_padded is not None kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): @@ -1606,24 +1990,23 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): dtype=torch.int32, device=self.device, ) - slot_mapping = torch.zeros( - (num_tokens_padded,), - dtype=torch.int64, - device=self.device, - ) else: blk_table = self.input_batch.block_table[kv_cache_gid] blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded) - slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. 
`blk_table_tensor` -1 to match mamba PAD_SLOT_ID - slot_mapping[num_tokens:num_tokens_padded].fill_(-1) blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) + return blk_table_tensor - return blk_table_tensor, slot_mapping + assert slot_mappings is not None + block_table_gid_0 = _get_block_table(0) + slot_mapping_gid_0 = slot_mappings[0] - block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0) + if self.routed_experts_initialized: + attn_gid = self.routed_experts_attn_gid + slot_mapping_attn = slot_mappings[attn_gid] + self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy() cm_base = CommonAttentionMetadata( query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], @@ -1662,6 +2045,15 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): logits_indices ) + # Cache attention metadata builds across hybrid KV-cache groups + # The only thing that changes between different hybrid KV-cache groups when the + # same metadata builder and KVCacheSpec is the same is the block table, so we + # can cache the attention metadata builds and just update the block table using + # `builder.update_block_table` if the builder supports it. 
+ cached_attn_metadata: dict[ + tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata + ] = {} + def _build_attn_group_metadata( kv_cache_gid: int, attn_gid: int, @@ -1669,15 +2061,22 @@ def _build_attn_group_metadata( ubid: int | None = None, ) -> None: attn_group = self.attn_groups[kv_cache_gid][attn_gid] + builder = attn_group.get_metadata_builder(ubid or 0) + kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec + if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): + kv_cache_spec = kv_cache_spec.kv_cache_specs[attn_group.layer_names[0]] + cache_key = (kv_cache_spec, type(builder)) + cascade_attn_prefix_len = ( cascade_attn_prefix_lens[kv_cache_gid][attn_gid] if cascade_attn_prefix_lens else 0 ) - builder = attn_group.get_metadata_builder(ubid or 0) extra_attn_metadata_args = {} - if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): + if use_spec_decode and isinstance( + builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder) + ): assert ubid is None, "UBatching not supported with GDN yet" extra_attn_metadata_args = dict( num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded], @@ -1690,12 +2089,23 @@ def _build_attn_group_metadata( attn_metadata_i = builder.build_for_cudagraph_capture( common_attn_metadata ) + elif ( + cache_key in cached_attn_metadata + and builder.supports_update_block_table + ): + attn_metadata_i = builder.update_block_table( + cached_attn_metadata[cache_key], + common_attn_metadata.block_table_tensor, + common_attn_metadata.slot_mapping, + ) else: attn_metadata_i = builder.build( common_prefix_len=cascade_attn_prefix_len, common_attn_metadata=common_attn_metadata, **extra_attn_metadata_args, ) + if builder.supports_update_block_table: + cached_attn_metadata[cache_key] = attn_metadata_i if ubid is None: assert isinstance(attn_metadata, dict) @@ -1719,15 +2129,15 @@ def _build_attn_group_metadata( num_scheduled_tokens or {}, kv_cache_group.kv_cache_spec, num_reqs_padded, 
+ for_cudagraph_capture=for_cudagraph_capture, ) if kv_cache_gid > 0: - cm.block_table_tensor, cm.slot_mapping = ( - _get_block_table_and_slot_mapping(kv_cache_gid) - ) + cm.block_table_tensor = _get_block_table(kv_cache_gid) + cm.slot_mapping = slot_mappings[kv_cache_gid] if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): - if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: + if self.drafter.kv_cache_gid == kv_cache_gid: spec_decode_common_attn_metadata = cm else: spec_decode_common_attn_metadata = cm @@ -1808,7 +2218,6 @@ def _compute_cascade_attn_prefix_lens( return cascade_attn_prefix_lens if use_cascade_attn else None - def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, @@ -1834,6 +2243,7 @@ def _compute_cascade_attn_prefix_len( Returns: int: Length of common prefix in tokens. """ + common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size if common_prefix_len == 0: # Common case. 
@@ -1892,8 +2302,6 @@ def _compute_cascade_attn_prefix_len( and kv_cache_spec.attention_chunk_size is not None ) assert isinstance(kv_cache_spec, AttentionSpec) - - ###TODO(lms): fix use of num_sms use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, @@ -1902,7 +2310,7 @@ def _compute_cascade_attn_prefix_len( use_alibi=self.use_alibi, use_sliding_window=use_sliding_window, use_local_attention=use_local_attention, - num_sms=1, + num_sms=self.num_sms, dcp_world_size=self.dcp_world_size, ) return common_prefix_len if use_cascade else 0 @@ -2076,6 +2484,7 @@ def _calc_spec_decode_metadata( draft_token_ids=draft_token_ids, num_draft_tokens=num_draft_tokens.tolist(), cu_num_draft_tokens=cu_num_draft_tokens, + cu_num_sampled_tokens=cu_num_sampled_tokens, target_logits_indices=target_logits_indices, bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, @@ -2096,42 +2505,45 @@ def _prepare_kv_sharing_fast_prefill( self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( logits_indices[-1].item() ) - if ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1] - ): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) - else: - num_logits_padded = num_logits + # Dispatch for the decoder portion of the model. + _, batch_desc = self.cudagraph_dispatcher.dispatch( + num_logits, invalid_modes={CUDAGraphMode.FULL} + ) + num_logits_padded = batch_desc.num_tokens logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ :num_logits_padded ] return logits_indices_padded - def _batch_mm_kwargs_from_scheduler( + def _batch_mm_inputs_from_scheduler( self, scheduler_output: "SchedulerOutput", - ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: - """Batch multimodal kwargs from scheduled encoder inputs. 
+ ) -> tuple[ + list[str], + list[tuple[str, MultiModalKwargsItem]], + list[tuple[str, PlaceholderRange]], + ]: + """Batch multimodal inputs from scheduled encoder inputs. Args: scheduler_output: The scheduler output containing scheduled encoder inputs. Returns: - A tuple of (mm_kwargs, req_ids_pos) where: - - mm_kwargs: List of multimodal kwargs items to be batched - - mm_hashes_pos: List of (mm_hash, position_info) tuples + A tuple of (mm_hashes, mm_kwargs, mm_lora_refs) where: + - mm_hashes: List of multimodal hashes for each item + - mm_kwargs: List of multimodal kwargs for each item + - mm_lora_refs: List of (req_id, placeholder_range) for each item """ scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: - return [], [] - # Batch the multi-modal inputs. - mm_kwargs = list[MultiModalKwargsItem]() - # list of tuple (mm_hash, position_info) - mm_hashes_pos = list[tuple[str, PlaceholderRange]]() + return [], [], [] + + mm_hashes = list[str]() + mm_kwargs = list[tuple[str, MultiModalKwargsItem]]() + # Multimodal LoRA reference info to map each multimodal item + # back to its request & position + mm_lora_refs = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] @@ -2139,23 +2551,29 @@ def _batch_mm_kwargs_from_scheduler( mm_feature = req_state.mm_features[mm_input_id] if mm_feature.data is None: continue - mm_hash = mm_feature.identifier - mm_kwargs.append(mm_feature.data) - mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) - return mm_kwargs, mm_hashes_pos + mm_hashes.append(mm_feature.identifier) + mm_kwargs.append((mm_feature.modality, mm_feature.data)) + mm_lora_refs.append((req_id, mm_feature.mm_position)) + + return mm_hashes, mm_kwargs, mm_lora_refs def _execute_mm_encoder( self, scheduler_output: "SchedulerOutput" ) -> list[torch.Tensor]: - # Batch the multi-modal inputs using the helper method. 
- mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( + mm_hashes, mm_kwargs, mm_lora_refs = self._batch_mm_inputs_from_scheduler( scheduler_output ) if not mm_kwargs: return [] + should_time = bool( + self.observability_config + and self.observability_config.enable_mm_processor_stats + and scheduler_output.scheduled_encoder_inputs + ) + # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, # we process it separately to preserve item order. @@ -2164,13 +2582,72 @@ def _execute_mm_encoder( # multimodal inputs. The proper solution should be reordering the # encoder outputs. model = cast(SupportsMultiModal, self.model) + + if self.lora_config and self.lora_manager.supports_tower_connector_lora(): + # Build LoRA mappings independently for encoder inputs + # (encoder batch structure is different from main batch) + prompt_lora_mapping = [] + token_lora_mapping = [] + lora_requests = set() + encoder_token_counts = [] + + for req_id, pos_info in mm_lora_refs: + req_idx = self.input_batch.req_id_to_index[req_id] + lora_id = int(self.input_batch.request_lora_mapping[req_idx]) + + # Prefer pos_info.get_num_embeds to count precise MM embedding tokens. 
+ num_tokens = self.model.get_num_mm_encoder_tokens( # type: ignore[attr-defined] + pos_info.get_num_embeds() + ) + prompt_lora_mapping.append(lora_id) + token_lora_mapping.extend([lora_id] * num_tokens) + encoder_token_counts.append(num_tokens) + + if lora_id > 0: + lora_request = self.input_batch.lora_id_to_lora_request.get(lora_id) + if lora_request is not None: + lora_requests.add(lora_request) + + # Set tower adapter mapping + tower_mapping = LoRAMapping( + tuple(token_lora_mapping), + tuple(prompt_lora_mapping), + is_prefill=True, + type=LoRAMappingType.TOWER, + ) + self.lora_manager.set_active_adapters(lora_requests, tower_mapping) + + if hasattr(self.model, "get_num_mm_connector_tokens"): + post_op_counts = [ + self.model.get_num_mm_connector_tokens(num_tokens) # type: ignore[attr-defined] + for num_tokens in encoder_token_counts + ] + + connector_token_mapping = np.repeat( + np.array(prompt_lora_mapping, dtype=np.int32), + np.array(post_op_counts, dtype=np.int32), + ) + connector_mapping = LoRAMapping( + index_mapping=tuple(connector_token_mapping.tolist()), + prompt_mapping=tuple(prompt_lora_mapping), + is_prefill=True, + type=LoRAMappingType.CONNECTOR, + ) + + self.lora_manager.set_active_adapters( + lora_requests, + connector_mapping, + ) + encoder_outputs: list[torch.Tensor] = [] - for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( + # Track the current index in mm_kwargs/mm_lora_refs to map groups to request IDs + current_item_idx = 0 + for modality, num_items, mm_kwargs_batch in group_and_batch_mm_kwargs( mm_kwargs, device=self.device, pin_memory=self.pin_memory, ): - curr_group_outputs: list[torch.Tensor] = [] + batch_outputs: MultiModalEmbeddings # EVS-related change. 
# (ekhvedchenia): Temporary hack to limit peak memory usage when @@ -2186,40 +2663,48 @@ def _execute_mm_encoder( and modality == "video" and num_items > 1 ): - for video_mm_kwargs_item in filter( - lambda item: item.modality == "video", mm_kwargs - ): - _, _, micro_batch_mm_inputs = next( - group_mm_kwargs_by_modality( - [video_mm_kwargs_item], - device=self.device, - pin_memory=self.pin_memory, + batch_outputs_lst = list[torch.Tensor]() + for video_idx in range(num_items): + video_mm_kwargs_item = mm_kwargs[current_item_idx + video_idx] + with self.timed_encoder_operation( + should_time, mm_lora_refs, current_item_idx + video_idx, 1 + ): + _, _, micro_batch_mm_inputs = next( + group_and_batch_mm_kwargs( + [video_mm_kwargs_item], + device=self.device, + pin_memory=self.pin_memory, + ) ) - ) - micro_batch_outputs = model.embed_multimodal( - **micro_batch_mm_inputs - ) + micro_batch_outputs = model.embed_multimodal( + **micro_batch_mm_inputs + ) + + batch_outputs_lst.extend(micro_batch_outputs) - curr_group_outputs.extend(micro_batch_outputs) + batch_outputs = batch_outputs_lst else: # Run the encoder. - # `curr_group_outputs` is either of the following: + # `batch_outputs` is either of the following: # 1. A tensor of shape (num_items, feature_size, hidden_size) # in case feature_size is fixed across all multimodal items. # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. 
- curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment] - sanity_check_mm_encoder_outputs( - curr_group_outputs, - expected_num_items=num_items, - ) - encoder_outputs.extend(curr_group_outputs) + with self.timed_encoder_operation( + should_time, mm_lora_refs, current_item_idx, num_items + ): + batch_outputs = model.embed_multimodal(**mm_kwargs_batch) + + sanity_check_mm_encoder_outputs(batch_outputs, expected_num_items=num_items) + encoder_outputs.extend(batch_outputs) + + current_item_idx += num_items # Cache the encoder outputs by mm_hash - for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + for mm_hash, output in zip(mm_hashes, encoder_outputs): self.encoder_cache[mm_hash] = output logger.debug("Finish execute for mm hash %s", mm_hash) self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) @@ -2233,8 +2718,13 @@ def _gather_mm_embeddings( ) -> tuple[list[torch.Tensor], torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + # Swap to the other buffer to avoid race condition with previous + # iteration's async copy that may still be reading from CPU. 
+ self.is_mm_embed_idx = 1 - self.is_mm_embed_idx + is_mm_embed_buf = self.is_mm_embed_buffers[self.is_mm_embed_idx] + mm_embeds = list[torch.Tensor]() - is_mm_embed = self.is_mm_embed.cpu + is_mm_embed = is_mm_embed_buf.cpu is_mm_embed[:total_num_scheduled_tokens] = False req_start_idx = 0 @@ -2290,9 +2780,15 @@ def _gather_mm_embeddings( mm_embeds_item = encoder_output[start_idx:end_idx] req_start_pos = req_start_idx + start_pos - num_computed_tokens - is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( - True if is_embed is None else is_embed - ) + # OR mask for overlapping mm_features (use_audio_in_video) + if is_embed is None: + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( + True + ) + else: + is_mm_embed[ + req_start_pos + start_idx : req_start_pos + end_idx + ] |= is_embed mm_embeds_req.append(mm_embeds_item) if self.is_multimodal_pruning_enabled and self.uses_mrope: @@ -2312,7 +2808,7 @@ def _gather_mm_embeddings( mm_embeds.extend(mm_embeds_req) req_start_idx += num_scheduled_tokens - is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens) + is_mm_embed = is_mm_embed_buf.copy_to_gpu(total_num_scheduled_tokens) if should_sync_mrope_positions: self._calc_mrope_positions(scheduler_output) @@ -2325,8 +2821,10 @@ def _gather_mm_embeddings( return mm_embeds, is_mm_embed def get_model(self) -> nn.Module: - # get raw model out of the cudagraph wrapper. + if not hasattr(self, "model"): + raise ValueError("Cannot get model before model has been initialized") if isinstance(self.model, (GraphWrapper, UBatchWrapper)): + # get raw model out of the cudagraph wrapper. 
return self.model.unwrap() return self.model @@ -2343,6 +2841,9 @@ def get_supported_generation_tasks(self) -> list[GenerationTask]: supported_tasks.append("transcription") + if supports_realtime(model): + supported_tasks.append("realtime") + return supported_tasks def get_supported_pooling_tasks(self) -> list[PoolingTask]: @@ -2405,7 +2906,7 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: """ Step for the EPLB (Expert Parallelism Load Balancing) state. """ - if not self.parallel_config.enable_eplb: + if not self.parallel_config.enable_eplb or self.eep_eplb_suppressed: return assert self.eplb_state is not None @@ -2417,59 +2918,104 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: log_stats=self.parallel_config.eplb_config.log_balancedness, ) + def setup_eplb_from_mapping( + self, + expanded_physical_to_logical: torch.Tensor, + old_num_physical_experts: int, + ) -> None: + model = self.get_model() + assert is_mixture_of_experts(model) + + self.eplb_state = EplbState.from_mapping( + model=model, + model_config=self.model_config, + device=self.device, + parallel_config=self.parallel_config, + expanded_physical_to_logical=expanded_physical_to_logical, + num_valid_physical_experts=old_num_physical_experts, + ) + def _pool( self, hidden_states: torch.Tensor, num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, - ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), ( + kv_connector_output: KVConnectorOutput | None, + ) -> ModelRunnerOutput | AsyncModelRunnerOutput: + num_reqs = self.input_batch.num_reqs + assert num_reqs == len(self.input_batch.pooling_params), ( "Either all or none of the requests in a batch must be pooling request" ) hidden_states = hidden_states[:num_scheduled_tokens] - seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] + seq_lens_cpu = self.seq_lens.cpu[:num_reqs] pooling_metadata = 
self.input_batch.get_pooling_metadata() pooling_metadata.build_pooling_cursor( - num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device + num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device ) model = cast(VllmModelForPooling, self.model) raw_pooler_output: PoolerOutput = model.pooler( - hidden_states=hidden_states, - pooling_metadata=pooling_metadata, - ) - raw_pooler_output = json_map_leaves( - lambda x: x.to("cpu", non_blocking=True) if x is not None else x, - raw_pooler_output, + hidden_states=hidden_states, pooling_metadata=pooling_metadata ) - self._sync_device() - - pooler_output: list[torch.Tensor | None] = [] - for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens - ): - output = raw_output if seq_len == prompt_len else None - pooler_output.append(output) - return ModelRunnerOutput( + finished_mask = [ + seq_len == prompt_len + for seq_len, prompt_len in zip(seq_lens_cpu, pooling_metadata.prompt_lens) + ] + raw_pooler_output = self.late_interaction_runner.postprocess_pooler_output( + raw_pooler_output=raw_pooler_output, + pooling_params=pooling_metadata.pooling_params, req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=[], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=pooler_output, + finished_mask=finished_mask, ) - def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: - # Pad tokens to multiple of tensor_parallel_size when - # enabled collective fusion for SP - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.compilation_config.pass_config.enable_sp and tp_size > 1: - return round_up(num_scheduled_tokens, tp_size) + model_runner_output = ModelRunnerOutput( + req_ids=self.input_batch.req_ids.copy(), + req_id_to_index=self.input_batch.req_id_to_index.copy(), + kv_connector_output=kv_connector_output, + ) + + if raw_pooler_output is None or not 
any(finished_mask): + model_runner_output.pooler_output = [None] * num_reqs + return model_runner_output + + if self.use_async_scheduling: + return AsyncGPUPoolingModelRunnerOutput( + model_runner_output=model_runner_output, + raw_pooler_output=raw_pooler_output, + finished_mask=finished_mask, + async_output_copy_stream=self.async_output_copy_stream, + ) + + model_runner_output.pooler_output = _copy_pooler_output_to_cpu( + raw_pooler_output=raw_pooler_output, + finished_mask=finished_mask, + ) + self._sync_device() + + return model_runner_output + + def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: + # Pad tokens to multiple of tensor_parallel_size when + # enabled collective fusion for SP + tp_size = self.vllm_config.parallel_config.tensor_parallel_size + if self.compilation_config.pass_config.enable_sp and tp_size > 1: + return round_up(num_scheduled_tokens, tp_size) return num_scheduled_tokens + def _prepare_mm_inputs( + self, num_tokens: int + ) -> tuple[torch.Tensor | None, torch.Tensor]: + if self.model.requires_raw_input_tokens: + input_ids = self.input_ids.gpu[:num_tokens] + else: + input_ids = None + + inputs_embeds = self.inputs_embeds.gpu[:num_tokens] + return input_ids, inputs_embeds + def _preprocess( self, scheduler_output: "SchedulerOutput", @@ -2512,10 +3058,9 @@ def _preprocess( # TODO(woosuk): Avoid the copy. Optimize. 
self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) - input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens) model_kwargs = { - **self._init_model_kwargs(num_scheduled_tokens), + **self._init_model_kwargs(), **self._extract_mm_kwargs(scheduler_output), } elif self.enable_prompt_embeds and is_first_rank: @@ -2543,7 +3088,7 @@ def _preprocess( self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] - model_kwargs = self._init_model_kwargs(num_input_tokens) + model_kwargs = self._init_model_kwargs() input_ids = None else: # For text-only models, we use token ids as input. @@ -2552,7 +3097,7 @@ def _preprocess( # then the embedding layer is not included in the CUDA graph. input_ids = self.input_ids.gpu[:num_input_tokens] inputs_embeds = None - model_kwargs = self._init_model_kwargs(num_input_tokens) + model_kwargs = self._init_model_kwargs() if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_input_tokens] @@ -2594,22 +3139,27 @@ def _sample( ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata + # Update output token ids with tokens sampled in last step + # if async scheduling and required by current sampling params. + self.input_batch.update_async_output_token_ids() if spec_decode_metadata is None: - # Update output token ids with tokens sampled in last step - # if async scheduling and required by current sampling params. - self.input_batch.update_async_output_token_ids() return self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) + # Update spec_token_ids with real draft tokens from pre step only when + # output_token_ids is needed (penalties or bad_words are in use). 
+ if self.use_async_scheduling and self._draft_token_req_ids is not None: + draft_token_ids_cpu, _ = self._get_draft_token_ids_cpu() + self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu) + sampler_output = self.rejection_sampler( spec_decode_metadata, None, # draft_probs logits, sampling_metadata, ) - self._update_states_after_model_execute(sampler_output.sampled_token_ids) return sampler_output def _bookkeeping_sync( @@ -2651,7 +3201,7 @@ def _bookkeeping_sync( sampled_token_ids = sampler_output.sampled_token_ids logprobs_tensors = sampler_output.logprobs_tensors invalid_req_indices = [] - cu_num_tokens: list[int] | None = None + logprobs_lists = None if not self.use_async_scheduling: # Get the valid generated tokens. max_gen_len = sampled_token_ids.shape[-1] @@ -2661,13 +3211,16 @@ def _bookkeeping_sync( # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: valid_sampled_token_ids[int(i)].clear() + + if logprobs_tensors is not None: + logprobs_lists = logprobs_tensors.tolists() else: # Includes spec decode tokens. 
- valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( + valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output( sampled_token_ids, self.input_batch.vocab_size, discard_sampled_tokens_req_indices, - return_cu_num_tokens=logprobs_tensors is not None, + logprobs_tensors=logprobs_tensors, ) else: valid_sampled_token_ids = [] @@ -2715,18 +3268,11 @@ def _bookkeeping_sync( self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx - self.input_batch.num_tokens[req_idx] = end_idx req_id = req_ids[req_idx] req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - logprobs_lists = ( - logprobs_tensors.tolists(cu_num_tokens) - if not self.use_async_scheduling and logprobs_tensors is not None - else None - ) - # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( hidden_states[:num_scheduled_tokens], @@ -2790,6 +3336,27 @@ def _model_forward( **model_kwargs, ) + @staticmethod + def _is_uniform_decode( + max_num_scheduled_tokens: int, + uniform_decode_query_len: int, + num_tokens: int, + num_reqs: int, + force_uniform_decode: bool | None = None, + ) -> bool: + """ + Checks if it's a decode batch with same amount scheduled tokens + across all requests. 
+ """ + return ( + ( + (max_num_scheduled_tokens == uniform_decode_query_len) + and (num_tokens == max_num_scheduled_tokens * num_reqs) + ) + if force_uniform_decode is None + else force_uniform_decode + ) + def _determine_batch_execution_and_padding( self, num_tokens: int, @@ -2803,6 +3370,7 @@ def _determine_batch_execution_and_padding( # be improved in model runner v2) force_uniform_decode: bool | None = None, force_has_lora: bool | None = None, + force_num_active_loras: int | None = None, num_encoder_reqs: int = 0, ) -> tuple[ CUDAGraphMode, @@ -2811,14 +3379,12 @@ def _determine_batch_execution_and_padding( torch.Tensor | None, CUDAGraphStat | None, ]: - num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) - uniform_decode = ( - ( - (max_num_scheduled_tokens == self.uniform_decode_query_len) - and (num_tokens_padded == max_num_scheduled_tokens * num_reqs) - ) - if force_uniform_decode is None - else force_uniform_decode + uniform_decode = self._is_uniform_decode( + max_num_scheduled_tokens=max_num_scheduled_tokens, + uniform_decode_query_len=self.uniform_decode_query_len, + num_tokens=num_tokens, + num_reqs=num_reqs, + force_uniform_decode=force_uniform_decode, ) # Encoder-decoder models only support CG for decoder_step > 0 (no enc_output # is present). Also, chunked-prefill is disabled, so batch are uniform. 
@@ -2826,46 +3392,49 @@ def _determine_batch_execution_and_padding( self.model_config.is_encoder_decoder and num_encoder_reqs > 0 ) - has_lora = ( - len(self.input_batch.lora_id_to_lora_request) > 0 - if force_has_lora is None - else force_has_lora + # Compute LoRA state for cudagraph dispatch + num_active_loras = ( + force_num_active_loras + if force_num_active_loras is not None + else len(self.input_batch.lora_id_to_lora_request) ) + has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora + + num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) - dispatch_cudagraph = ( - lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch( + def dispatch_cudagraph(num_tokens, disable_full=False, valid_modes=None): + return self.cudagraph_dispatcher.dispatch( num_tokens=num_tokens, has_lora=has_lora, uniform_decode=uniform_decode, - disable_full=disable_full, + num_active_loras=num_active_loras, + valid_modes={CUDAGraphMode.NONE} if force_eager else valid_modes, + invalid_modes={CUDAGraphMode.FULL} if disable_full else None, ) - if not force_eager - else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded)) - ) cudagraph_mode, batch_descriptor = dispatch_cudagraph( - num_tokens_padded, use_cascade_attn or has_encoder_output + num_tokens_padded, disable_full=use_cascade_attn or has_encoder_output ) num_tokens_padded = batch_descriptor.num_tokens + if self.compilation_config.pass_config.enable_sp: + assert ( + batch_descriptor.num_tokens + % self.vllm_config.parallel_config.tensor_parallel_size + == 0 + ), ( + "Sequence parallelism requires num_tokens to be " + "a multiple of tensor parallel size" + ) # Extra coordination when running data-parallel since we need to coordinate # across ranks should_ubatch, num_tokens_across_dp = False, None if self.vllm_config.parallel_config.data_parallel_size > 1: - # Disable DP padding when running eager to avoid excessive padding when - # running prefills. 
This lets us set cudagraph_mode="NONE" on the prefiller - # in a P/D setup and still use CUDA graphs (enabled by this padding) on the - # decoder. - allow_dp_padding = ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - ) - should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = ( coordinate_batch_across_dp( num_tokens_unpadded=num_tokens, parallel_config=self.parallel_config, allow_microbatching=allow_microbatching, - allow_dp_padding=allow_dp_padding, num_tokens_padded=num_tokens_padded, uniform_decode=uniform_decode, num_scheduled_tokens_per_request=num_scheduled_tokens_np, @@ -2880,7 +3449,7 @@ def _determine_batch_execution_and_padding( # Re-dispatch with DP padding so we have the correct batch_descriptor cudagraph_mode, batch_descriptor = dispatch_cudagraph( num_tokens_padded, - disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value, + valid_modes={CUDAGraphMode(synced_cudagraph_mode)}, ) # Assert to make sure the agreed upon token count is correct otherwise # num_tokens_across_dp will no-longer be valid @@ -2939,152 +3508,281 @@ def _register_layerwise_nvtx_hooks(self) -> None: pyt_hooks.register_hooks(self.model, self.model.__class__.__name__) self.layerwise_nvtx_hooks_registered = True + def _get_slot_mappings( + self, + num_tokens_padded: int, + num_reqs_padded: int, + num_tokens_unpadded: int, + ubatch_slices: "UBatchSlices | None" = None, + ) -> tuple[ + dict[int, torch.Tensor] | None, + dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, + ]: + """ + Build slot mappings in both formats needed by the system. 
+ + Args: + num_tokens_padded: Total number of tokens (padded) + num_reqs_padded: Total number of requests (padded) + num_tokens_unpadded: Actual number of tokens (unpadded) + ubatch_slices: Optional ubatch slicing info for DBO + + Returns: + A tuple of: + - slot_mappings_by_gid: dict[int, torch.Tensor] for attention metadata + - slot_mappings_by_layer: dict[str, torch.Tensor] or list for ForwardContext + """ + if not ( + hasattr(self, "kv_cache_config") + and self.kv_cache_config is not None + and len(self.kv_cache_config.kv_cache_groups) > 0 + ): + return None, None + + def _get_slot_mapping(kv_cache_gid: int): + assert num_reqs_padded is not None and num_tokens_padded is not None + kv_cache_spec = self.kv_cache_config.kv_cache_groups[ + kv_cache_gid + ].kv_cache_spec + if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): + slot_mapping = torch.zeros( + (num_tokens_padded,), + dtype=torch.int64, + device=self.device, + ) + else: + blk_table = self.input_batch.block_table[kv_cache_gid] + slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded] + + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. 
`blk_table_tensor` -1 to match mamba PAD_SLOT_ID + slot_mapping[num_tokens_unpadded:num_tokens_padded].fill_(-1) + + return slot_mapping + + slot_mappings_by_gid = { + gid: _get_slot_mapping(gid) + for gid, _ in enumerate(self.kv_cache_config.kv_cache_groups) + } + + slot_mappings_by_layer: dict[str, torch.Tensor] = {} + for gid, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups): + slot_mapping = slot_mappings_by_gid[gid] + for layer_name in kv_cache_group.layer_names: + slot_mappings_by_layer[layer_name] = slot_mapping + + if ubatch_slices is not None: + result: list[dict[str, torch.Tensor]] = [] + for ubatch in ubatch_slices: + sliced_mappings: dict[str, torch.Tensor] = {} + for layer_name, slot_mapping in slot_mappings_by_layer.items(): + sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice] + result.append(sliced_mappings) + return slot_mappings_by_gid, result + + return slot_mappings_by_gid, slot_mappings_by_layer + @managed_inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: IntermediateTensors | None = None, - ) -> ModelRunnerOutput | IntermediateTensors | None: - + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None: if self.execute_model_state is not None: raise RuntimeError( "State error: sample_tokens() must be called " "after execute_model() returns None." ) - # self._draft_token_ids is None when `input_fits_in_drafter=False` - # and there is no draft tokens scheduled. so it need to update the - # spec_decoding info in scheduler_output with async_scheduling. - # use deepcopy to avoid the modification has influence on the - # scheduler_output in engine core process. - # TODO(Ronald1995): deepcopy is expensive when there is a large - # number of requests, optimize it later. 
+ if self.routed_experts_initialized: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.clear_buffer() # noqa + else: + logger.error("RoutedExpertsCapturer not initialized.") + + # If ngram_gpu is used, we need to copy the scheduler_output to avoid + # the modification has influence on the scheduler_output in engine core process. + # The replace is much faster than deepcopy. if ( - self.use_async_scheduling - and self.num_spec_tokens - and self._draft_token_ids is None + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() ): - scheduler_output = deepcopy(scheduler_output) + num_scheduled_tokens_copy = scheduler_output.num_scheduled_tokens.copy() + spec_decode_tokens_copy = ( + scheduler_output.scheduled_spec_decode_tokens.copy() + ) + scheduler_output = replace( + scheduler_output, + num_scheduled_tokens=num_scheduled_tokens_copy, + scheduled_spec_decode_tokens=spec_decode_tokens_copy, + ) - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - with record_function_or_nullcontext("gpu_model_runner: preprocess"): - with self.synchronize_input_prep(): - # Update persistent batch states. - self._update_states(scheduler_output) - - if has_ec_transfer() and get_ec_transfer().is_producer: - with self.maybe_get_ec_connector_output( - scheduler_output, - encoder_cache=self.encoder_cache, - ) as ec_connector_output: - self._execute_mm_encoder(scheduler_output) - return make_empty_encoder_model_runner_output(scheduler_output) - - if not num_scheduled_tokens: - if ( - self.parallel_config.distributed_executor_backend - == "external_launcher" - and self.parallel_config.data_parallel_size > 1 - ): - # this is a corner case when both external launcher - # and DP are enabled, num_scheduled_tokens could be - # 0, and has_unfinished_requests in the outer loop - # returns True. 
before returning early here we call - # dummy run to ensure coordinate_batch_across_dp - # is called into to avoid out of sync issues. - self._dummy_run(1) - if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - return self.kv_connector_no_forward( - scheduler_output, self.vllm_config - ) - if self.cache_config.kv_sharing_fast_prefill: - assert not self.num_prompt_logprobs, ( - "--kv-sharing-fast-prefill produces incorrect " - "logprobs for prompt tokens, tokens, please disable " - "it when the requests need prompt logprobs" - ) + if scheduler_output.preempted_req_ids and has_kv_transfer_group(): + get_kv_transfer_group().handle_preemptions( + scheduler_output.preempted_req_ids + ) - num_reqs = self.input_batch.num_reqs - req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) - max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) - num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + with ( + record_function_or_nullcontext("gpu_model_runner: preprocess"), + self.synchronize_input_prep(), + ): + # Update persistent batch states. 
+ self._update_states(scheduler_output) - ( - logits_indices, - spec_decode_metadata, - ) = self._prepare_inputs( + if has_ec_transfer() and not get_ec_transfer().is_consumer: + with self.maybe_get_ec_connector_output( scheduler_output, - num_scheduled_tokens_np, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) + + if not num_scheduled_tokens: + if ( + self.parallel_config.distributed_executor_backend + == "external_launcher" + and self.parallel_config.data_parallel_size > 1 + ): + # this is a corner case when both external launcher + # and DP are enabled, num_scheduled_tokens could be + # 0, and has_unfinished_requests in the outer loop + # returns True. before returning early here we call + # dummy run to ensure coordinate_batch_across_dp + # is called into to avoid out of sync issues. + self._dummy_run(1) + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if no work to do. 
+ return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) + + if self.cache_config.kv_sharing_fast_prefill: + assert not self.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs" ) - cascade_attn_prefix_lens = None - # Disable cascade attention when using microbatching (DBO) - if self.cascade_attn_enabled and not self.parallel_config.enable_dbo: - # Pre-compute cascade attention prefix lengths - cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( - num_scheduled_tokens_np, - self.input_batch.num_computed_tokens_cpu[:num_reqs], - scheduler_output.num_common_prefix_blocks, - ) + num_reqs = self.input_batch.num_reqs + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens - ( - cudagraph_mode, - batch_desc, - should_ubatch, - num_tokens_across_dp, - cudagraph_stats, - ) = self._determine_batch_execution_and_padding( - num_tokens=num_tokens_unpadded, - num_reqs=num_reqs, - num_scheduled_tokens_np=num_scheduled_tokens_np, - max_num_scheduled_tokens=max_num_scheduled_tokens, - use_cascade_attn=cascade_attn_prefix_lens is not None, - num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), - ) + logits_indices, spec_decode_metadata = self._prepare_inputs( + scheduler_output, + num_scheduled_tokens_np, + ) - logger.debug( - "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " - "should_ubatch: %s, num_tokens_across_dp: %s", - cudagraph_mode, - batch_desc, - should_ubatch, - num_tokens_across_dp, + cascade_attn_prefix_lens = None + # Disable cascade attention when using microbatching (DBO) + if self.cascade_attn_enabled and not 
self.parallel_config.use_ubatching: + # Pre-compute cascade attention prefix lengths + cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( + num_scheduled_tokens_np, + self.input_batch.num_computed_tokens_cpu[:num_reqs], + scheduler_output.num_common_prefix_blocks, ) - num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = ( - batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ( + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + cudagraph_stats, + ) = self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens_np, + max_num_scheduled_tokens=max_num_scheduled_tokens, + use_cascade_attn=cascade_attn_prefix_lens is not None, + num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), + ) + + logger.debug( + "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " + "should_ubatch: %s, num_tokens_across_dp: %s", + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + ) + + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens_np, + num_tokens_padded, + num_reqs_padded, + self.parallel_config.num_ubatches, + ) + + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) + + # True if any attention backend handles KV cache update separately + # from forward() (i.e., forward_includes_kv_cache_update=False). When true, + # slot_mappings must use padded dimensions to match the key/value tensors. 
+ has_separate_kv_update = not all( + all( + g.backend.forward_includes_kv_cache_update + for g in self.attn_groups[id] ) - ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( - should_ubatch, - num_scheduled_tokens_np, - num_tokens_padded, - num_reqs_padded, + for id, spec in enumerate(self.kv_cache_config.kv_cache_groups) + if not isinstance(spec.kv_cache_spec, EncoderOnlyAttentionSpec) + ) + pad_attn = cudagraph_mode == CUDAGraphMode.FULL + + if self.cache_config.mamba_cache_mode == "align": + mamba_utils.preprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, + self.mamba_state_idx, + self.input_batch, + self.requests, + self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), + self._get_mamba_copy_bufs(), ) - pad_attn = cudagraph_mode == CUDAGraphMode.FULL - - use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 - ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices - - (attn_metadata, spec_decode_common_attn_metadata) = ( - self._build_attention_metadata( - num_tokens=num_tokens_unpadded, - num_tokens_padded=num_tokens_padded if pad_attn else None, - num_reqs=num_reqs, - num_reqs_padded=num_reqs_padded if pad_attn else None, - max_query_len=max_num_scheduled_tokens, - ubatch_slices=ubatch_slices_attn, - logits_indices=logits_indices, - use_spec_decode=use_spec_decode, - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - cascade_attn_prefix_lens=cascade_attn_prefix_lens, - ) + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices + + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens_padded + if pad_attn or has_separate_kv_update + else num_tokens_unpadded, + num_reqs_padded=( + num_reqs_padded if pad_attn or has_separate_kv_update else num_reqs + ), + num_tokens_unpadded=num_tokens_unpadded, + 
ubatch_slices=ubatch_slices_padded, + ) + + attn_metadata, spec_decode_common_attn_metadata = ( + self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs, + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, + slot_mappings=slot_mappings_by_group, ) + ) ( input_ids, @@ -3105,8 +3803,18 @@ def execute_model( # Mark KV scales as calculated after the first forward pass self.calculate_kv_scales = False + # Encoder-decoder models can only compile the pure decode steps where no + # encoder inputs are present. Use eager for the first pass. + num_encoder_reqs = len(scheduler_output.scheduled_encoder_inputs) + has_encoder_input = ( + self.model_config.is_encoder_decoder and num_encoder_reqs > 0 + ) + # Run the model. # Use persistent buffers for CUDA graphs. + # When spec decode is enabled, defer connector finalization + # (wait_for_save + clear metadata) until after draft model runs. + defer_kv_connector_finalize = self.speculative_config is not None with ( set_forward_context( attn_metadata, @@ -3116,9 +3824,14 @@ def execute_model( cudagraph_runtime_mode=cudagraph_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, + skip_compiled=has_encoder_input, ), record_function_or_nullcontext("gpu_model_runner: forward"), - self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + self.maybe_get_kv_connector_output( + scheduler_output, + defer_finalize=defer_kv_connector_finalize, + ) as kv_connector_output, ): model_output = self._model_forward( input_ids=input_ids, @@ -3148,11 +3861,12 @@ def execute_model( if self.is_pooling_model: # Return the pooling output. 
- output = self._pool( - hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + return self._pool( + hidden_states, + num_scheduled_tokens, + num_scheduled_tokens_np, + kv_connector_output, ) - output.kv_connector_output = kv_connector_output - return output sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states) @@ -3179,6 +3893,7 @@ def execute_model( model_output_broadcast_data: dict[str, Any] = {} if logits is not None: model_output_broadcast_data["logits"] = logits.contiguous() + broadcasted = get_pp_group().broadcast_tensor_dict( model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 ) @@ -3195,20 +3910,21 @@ def execute_model( aux_hidden_states, ec_connector_output, cudagraph_stats, + slot_mappings, ) self.kv_connector_output = kv_connector_output - return None @managed_inference_mode() def sample_tokens( self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: - kv_connector_output = self.kv_connector_output - self.kv_connector_output = None - if self.execute_model_state is None: - # Nothing to do (PP non-final rank case), output isn't used. + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + # receive sampled token ids from the last PP rank. + if self.use_async_scheduling and get_pp_group().world_size > 1: + self._pp_receive_prev_sampled_token_ids_to_input_batch() if not kv_connector_output: return None # type: ignore[return-value] @@ -3232,6 +3948,7 @@ def sample_tokens( aux_hidden_states, ec_connector_output, cudagraph_stats, + slot_mappings, ) = self.execute_model_state # Clear ephemeral state. 
self.execute_model_state = None @@ -3245,6 +3962,21 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) + self._update_states_after_model_execute( + sampler_output.sampled_token_ids, scheduler_output + ) + if self.use_async_scheduling: + pp = get_pp_group() + # For torchrun external_launcher PP mode with broadcast_pp_output=True, + # PP outputs have been broadcasted to all ranks at logits computation. + # Therefore, here is no need to send sampled token ids again in this case. + if not self.broadcast_pp_output and pp.world_size > 1 and pp.is_last_rank: + self._pp_broadcast_prev_sampled_token_ids( + sampler_output.sampled_token_ids + ) + + self._draft_token_ids = None + self._draft_token_req_ids = None self.input_batch.prev_sampled_token_ids = None def propose_draft_token_ids(sampled_token_ids): @@ -3259,51 +3991,80 @@ def propose_draft_token_ids(sampled_token_ids): aux_hidden_states, spec_decode_metadata, spec_decode_common_attn_metadata, + slot_mappings, ) + self._copy_draft_token_ids_to_cpu(scheduler_output) spec_config = self.speculative_config - use_padded_batch_for_eagle = ( - spec_config is not None - and spec_config.use_eagle() - and not spec_config.disable_padded_drafter_batch - ) - effective_drafter_max_model_len = self.max_model_len - if effective_drafter_max_model_len is None: - effective_drafter_max_model_len = self.model_config.max_model_len - if ( - spec_config is not None - and spec_config.draft_model_config is not None - and spec_config.draft_model_config.max_model_len is not None - ): - effective_drafter_max_model_len = ( - spec_config.draft_model_config.max_model_len + propose_drafts_after_bookkeeping = False + if spec_config is not None: + input_fits_in_drafter = spec_decode_common_attn_metadata is not None and ( + spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens + <= self.effective_drafter_max_model_len ) - input_fits_in_drafter = 
spec_decode_common_attn_metadata and ( - spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens - <= effective_drafter_max_model_len - ) - if use_padded_batch_for_eagle: - assert self.speculative_config is not None - assert isinstance(self.drafter, EagleProposer) - sampled_token_ids = sampler_output.sampled_token_ids - if input_fits_in_drafter: - # EAGLE speculative decoding can use the GPU sampled tokens + use_gpu_toks = ( + spec_config.use_eagle() + or spec_config.uses_draft_model() + or spec_config.uses_extract_hidden_states() + ) and not spec_config.disable_padded_drafter_batch + if use_gpu_toks: + # EAGLE/DraftModel speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. - propose_draft_token_ids(sampled_token_ids) - elif self.valid_sampled_token_count_event is not None: - assert spec_decode_common_attn_metadata is not None - next_token_ids, valid_sampled_tokens_count = ( - self.drafter.prepare_next_token_ids_padded( - spec_decode_common_attn_metadata, - sampled_token_ids, - self.requests, - self.input_batch, - self.discard_request_mask.gpu, - ) - ) - self._copy_valid_sampled_token_count( - next_token_ids, valid_sampled_tokens_count + assert isinstance( + self.drafter, + EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, ) + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + assert spec_decode_common_attn_metadata is not None + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + self._draft_token_ids = torch.zeros( + 1, device=self.device, dtype=torch.int32 + 
).expand(len(self.input_batch.req_ids), self.num_spec_tokens) + self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) + elif ( + spec_config.use_ngram_gpu() + and not spec_config.disable_padded_drafter_batch + ): + assert isinstance(self.drafter, NgramProposerGPU) + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + assert spec_decode_common_attn_metadata is not None + next_token_ids, valid_sampled_tokens_count, _ = ( + self.drafter.update_token_ids_ngram( + sampled_token_ids, + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + # Since we couldn't run the drafter, + # just use zeros for the draft tokens. + self._draft_token_ids = torch.zeros( + 1, device=self.device, dtype=torch.int32 + ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) + self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) + else: + propose_drafts_after_bookkeeping = input_fits_in_drafter with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( @@ -3323,25 +4084,38 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, ) - if ( - self.speculative_config - and not use_padded_batch_for_eagle - and input_fits_in_drafter - ): + if propose_drafts_after_bookkeeping: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) + # Finalize KV connector (wait_for_save + clear metadata) after + # draft model runs. Deferred from target model forward to allow + # draft model to also save its KV cache. 
+ if spec_config is not None: + self.finalize_kv_connector() + with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() + + # self.kv_connector_output may be modified during drafting + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + if self.routed_experts_initialized: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.save_captured_experts(indices=self.slot_mapping) # noqa + else: + logger.error("RoutedExpertsCapturer not initialized.") + output = ModelRunnerOutput( req_ids=req_ids_output_copy, req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], kv_connector_output=kv_connector_output, ec_connector_output=ec_connector_output if self.supports_mm_inputs @@ -3350,13 +4124,12 @@ def propose_draft_token_ids(sampled_token_ids): cudagraph_stats=cudagraph_stats, ) - # Advance IO step after the full inference cycle (forward + sampling) - # so that one step encompasses both execute_model and sample_tokens. + # FL: Advance IO step after the full inference cycle advance_io_step() - ### TODO(lms): abstract async schedule for all hardware if not self.use_async_scheduling: return output + with record_function_or_nullcontext( "gpu_model_runner: AsyncGPUModelRunnerOutput" ): @@ -3380,17 +4153,95 @@ def propose_draft_token_ids(sampled_token_ids): return async_output + def _pp_broadcast_prev_sampled_token_ids( + self, sampled_token_ids: torch.Tensor + ) -> None: + """Broadcast sampled token ids (GPU) from last PP stage""" + pp = get_pp_group() + assert pp.is_last_rank + # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1]. 
+ assert sampled_token_ids.dim() == 2 and sampled_token_ids.shape[-1] == 1, ( + "PP+async expects sampled_token_ids to have shape [num_reqs, 1]" + ) + torch.distributed.broadcast( + sampled_token_ids, src=pp.rank, group=pp.device_group + ) + + def _pp_receive_prev_sampled_token_ids_to_input_batch(self) -> None: + """Receive sampled token ids broadcast from last PP stage""" + pp = get_pp_group() + assert not pp.is_last_rank + num_reqs = self.input_batch.num_reqs + # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1]. + recv = torch.empty((num_reqs, 1), dtype=torch.int32, device=self.device) + torch.distributed.broadcast(recv, src=pp.last_rank, group=pp.device_group) + self.input_batch.prev_sampled_token_ids = recv + + # construct `prev_req_id_to_index` here so `_prepare_input_ids` + # can map req_id -> previous batch row + discard_req_indices = np.nonzero(self.discard_request_mask.np[:num_reqs])[0] + discard_req_indices_set = set(discard_req_indices) + prev_req_id_to_index: dict[str, int] = {} + for i, req_id in enumerate(self.input_batch.req_ids): + if i in discard_req_indices_set: + continue + prev_req_id_to_index[req_id] = i + # PP+async scheduling: advance per-request local cached output length by + # appending a placeholder (-1) token id. 
+ if (req_state := self.requests.get(req_id)) is not None: + req_state.output_token_ids.append(-1) + self.input_batch.prev_req_id_to_index = prev_req_id_to_index + def take_draft_token_ids(self) -> DraftTokenIds | None: - if self._draft_token_ids is None: + if not self.num_spec_tokens or not self._draft_token_req_ids: return None - req_ids = self.input_batch.req_ids - if isinstance(self._draft_token_ids, torch.Tensor): - draft_token_ids = self._draft_token_ids.tolist() - else: - draft_token_ids = self._draft_token_ids - self._draft_token_ids = None + draft_token_ids, req_ids = self._get_draft_token_ids_cpu() return DraftTokenIds(req_ids, draft_token_ids) + def _copy_draft_token_ids_to_cpu( + self, scheduler_output: "SchedulerOutput", zeros_only: bool = False + ) -> None: + # Check if we need to copy draft tokens to CPU. In async scheduling, + # we only copy when needed for structured output, penalties or bad_words. + if self.use_async_scheduling and not ( + scheduler_output.has_structured_output_requests + or self.input_batch.sampling_metadata.output_token_ids + ): + return + # We must also set the corresponding request ids. + self._draft_token_req_ids = self.input_batch.req_ids.copy() + + draft_token_ids: torch.Tensor = self._draft_token_ids + if not torch.is_tensor(draft_token_ids): + return + assert self.draft_token_ids_event is not None + assert self.draft_token_ids_copy_stream is not None + assert self.draft_token_ids_cpu is not None + default_stream = current_platform.torch_device_fn.current_stream() + num_reqs = draft_token_ids.shape[0] + with current_platform.torch_device_fn.stream(self.draft_token_ids_copy_stream): + if not zeros_only: + # Trigger async copy of draft token ids to cpu. + self.draft_token_ids_copy_stream.wait_stream(default_stream) + self.draft_token_ids_cpu[:num_reqs].copy_( + draft_token_ids, non_blocking=True + ) + else: + # No copy needed, just zero-out cpu tensor. 
+ self.draft_token_ids_cpu[:num_reqs] = 0 + self.draft_token_ids_event.record() + + def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]: + if isinstance(self._draft_token_ids, list): + return self._draft_token_ids, self.input_batch.req_ids + req_ids = self._draft_token_req_ids + if req_ids is None: + return [], [] + assert self.draft_token_ids_event is not None + assert self.draft_token_ids_cpu is not None + self.draft_token_ids_event.synchronize() + return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids + def _copy_valid_sampled_token_count( self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor ) -> None: @@ -3400,10 +4251,11 @@ def _copy_valid_sampled_token_count( default_stream = current_platform.torch_device_fn.current_stream() # Initialize a new stream to overlap the copy operation with # prepare_input of draft model. - with torch.cuda.stream(self.valid_sampled_token_count_copy_stream): + with current_platform.torch_device_fn.stream(self.valid_sampled_token_count_copy_stream): self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore counts = valid_sampled_tokens_count counts_cpu = self.valid_sampled_token_count_cpu + assert counts_cpu is not None counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True) self.valid_sampled_token_count_event.record() @@ -3412,14 +4264,13 @@ def _copy_valid_sampled_token_count( def _get_valid_sampled_token_count(self) -> list[int]: # Wait until valid_sampled_tokens_count is copied to cpu, prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids - if ( - self.valid_sampled_token_count_event is None - or prev_sampled_token_ids is None - ): + sampled_count_event = self.valid_sampled_token_count_event + if sampled_count_event is None or prev_sampled_token_ids is None: return [] counts_cpu = self.valid_sampled_token_count_cpu - self.valid_sampled_token_count_event.synchronize() + assert counts_cpu is not None + 
sampled_count_event.synchronize() return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist() def propose_draft_token_ids( @@ -3432,24 +4283,65 @@ def propose_draft_token_ids( aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, + slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None if spec_config.method == "ngram": + from vllm.v1.spec_decode.ngram_proposer import NgramProposer + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( sampled_token_ids, - self.input_batch.req_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs, + slot_mappings=slot_mappings, + ) + elif spec_config.use_ngram_gpu(): + assert isinstance(self.drafter, NgramProposerGPU) + ( + next_token_ids, + valid_sampled_tokens_count, + valid_sampled_token_ids_gpu, + ) = self.drafter.update_token_ids_ngram( + sampled_token_ids, + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + self.discard_request_mask.gpu, + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + batch_size = next_token_ids.shape[0] + + draft_token_ids, num_valid_draft_tokens = self.drafter.propose( + self.num_tokens_no_spec_gpu[:batch_size], + self.token_ids_gpu_tensor[:batch_size], + valid_sampled_token_ids_gpu, + valid_sampled_tokens_count, + ) + + # Cache valid draft counts for scheduler-side trimming. + self._num_valid_draft_tokens = num_valid_draft_tokens + + # Async D2H copy on a dedicated stream. 
+ copy_num_valid_draft_tokens( + self._num_valid_draft_tokens_cpu, + self._num_valid_draft_tokens_copy_stream, + self._num_valid_draft_tokens_event, + self._num_valid_draft_tokens, + self.input_batch.num_reqs, ) elif spec_config.method == "suffix": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer) - draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids) + draft_token_ids = self.drafter.propose( + self.input_batch, sampled_token_ids, slot_mappings=slot_mappings + ) elif spec_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -3474,9 +4366,41 @@ def propose_draft_token_ids( draft_token_ids = self.drafter.propose( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, + slot_mappings=slot_mappings, + ) + elif spec_config.uses_extract_hidden_states(): + assert isinstance(self.drafter, ExtractHiddenStatesProposer) + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor for " + "extract_hidden_states method." 
) - elif spec_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if not self.use_aux_hidden_state_outputs or aux_hidden_states is None: + raise ValueError( + "aux_hidden_states are required when using `extract_hidden_states`" + ) + target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states] + + draft_token_ids = self.drafter.propose( + sampled_token_ids=sampled_token_ids, + target_hidden_states=target_hidden_states, + common_attn_metadata=common_attn_metadata, + slot_mappings=slot_mappings, + ) + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + elif spec_config.use_eagle() or spec_config.uses_draft_model(): + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) if spec_config.disable_padded_drafter_batch: # When padded-batch is disabled, the sampled_token_ids should be @@ -3514,6 +4438,7 @@ def propose_draft_token_ids( next_token_ids, valid_sampled_tokens_count ) + num_rejected_tokens_gpu = None if spec_decode_metadata is None: token_indices_to_sample = None # input_ids can be None for multimodal models. 
@@ -3544,12 +4469,14 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[token_indices] else: - common_attn_metadata, token_indices_to_sample = ( - self.drafter.prepare_inputs_padded( - common_attn_metadata, - spec_decode_metadata, - valid_sampled_tokens_count, - ) + ( + common_attn_metadata, + token_indices_to_sample, + num_rejected_tokens_gpu, + ) = self.drafter.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count, ) total_num_tokens = common_attn_metadata.num_actual_tokens # When padding the batch, token_indices is just a range @@ -3563,7 +4490,7 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[:total_num_tokens] - if self.supports_mm_inputs: + if self.supports_mm_inputs and self.drafter.supports_mm_inputs: mm_embed_inputs = self._gather_mm_embeddings( scheduler_output, shift_computed_tokens=1, @@ -3576,10 +4503,12 @@ def propose_draft_token_ids( target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, - last_token_indices=token_indices_to_sample, + token_indices_to_sample=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, mm_embed_inputs=mm_embed_inputs, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + slot_mappings=slot_mappings, ) return draft_token_ids @@ -3595,34 +4524,30 @@ def update_config(self, overrides: dict[str, Any]) -> None: new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) - def load_model(self, eep_scale_up: bool = False) -> None: + @instrument(span_name="Loading (GPU)") + def load_model(self, load_dummy_weights: bool = False) -> None: """ Args: - eep_scale_up: the model loading is for elastic EP scale up. + load_dummy_weights: load dummy weights instead of real weights. 
""" logger.info_once( "Starting to load model %s...", self.model_config.model, scope="global", ) - global_expert_loads, old_global_expert_indices_per_model, rank_mapping = ( - EplbState.get_eep_state(self.parallel_config) - if eep_scale_up - else (None, None, None) - ) if self.parallel_config.enable_eplb: self.eplb_state = EplbState(self.parallel_config, self.device) eplb_models = 0 - # IO dumper is only supported in eager mode. In graph mode (torch.compile) - # TorchDispatchMode and the module forward hooks are incompatible with - # Dynamo tracing, so IO dumping is silently skipped. + # FL: IO dumper init (skipped under Dynamo tracing) init_io_dump_from_env(getattr(self.model_config, "enforce_eager", False)) try: with DeviceMemoryProfiler() as m: time_before_load = time.perf_counter() + if load_dummy_weights: + self.load_config.load_format = "dummy" model_loader = get_model_loader(self.load_config) self.model = model_loader.load_model( vllm_config=self.vllm_config, model_config=self.model_config @@ -3639,6 +4564,9 @@ def load_model(self, eep_scale_up: bool = False) -> None: and is_mixture_of_experts(self.drafter.model) and self.parallel_config.enable_eplb ): + assert not self.parallel_config.enable_elastic_ep, ( + "Elastic EP is not supported with drafter model." 
+ ) spec_config = self.vllm_config.speculative_config assert spec_config is not None assert spec_config.draft_model_config is not None @@ -3646,17 +4574,6 @@ def load_model(self, eep_scale_up: bool = False) -> None: "EPLB is enabled for drafter model %s.", spec_config.draft_model_config.model, ) - - global_expert_load = ( - global_expert_loads[eplb_models] - if global_expert_loads - else None - ) - old_global_expert_indices = ( - old_global_expert_indices_per_model[eplb_models] - if old_global_expert_indices_per_model - else None - ) if self.eplb_state is None: self.eplb_state = EplbState( self.parallel_config, self.device @@ -3664,9 +4581,6 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.eplb_state.add_model( self.drafter.model, spec_config.draft_model_config, - global_expert_load, - old_global_expert_indices, - rank_mapping, ) eplb_models += 1 @@ -3686,7 +4600,9 @@ def load_model(self, eep_scale_up: bool = False) -> None: aux_layers, ) else: - aux_layers = self.model.get_eagle3_aux_hidden_state_layers() + aux_layers = ( + self.model.get_eagle3_default_aux_hidden_state_layers() + ) self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() @@ -3706,20 +4622,19 @@ def load_model(self, eep_scale_up: bool = False) -> None: logger.error(combined_msg) raise e logger.info_once( - "Model loading took %.4f GiB memory and %.6f seconds", - self.model_memory_usage / GiB_bytes, + "Model loading took %s GiB memory and %.6f seconds", + format_gib(self.model_memory_usage), time_after_load - time_before_load, scope="local", ) - - # IO dumper: register module paths and install module context hooks. 
- register_io_module_hooks(self.model) - - prepare_communication_buffer_for_model(self.model) - if (drafter := getattr(self, "drafter", None)) and ( - drafter_model := getattr(drafter, "model", None) - ): - prepare_communication_buffer_for_model(drafter_model) + if not load_dummy_weights: + prepare_communication_buffer_for_model(self.model) + # FL: register IO dumper module hooks + register_io_module_hooks(self.model) + if (drafter := getattr(self, "drafter", None)) and ( + drafter_model := getattr(drafter, "model", None) + ): + prepare_communication_buffer_for_model(drafter_model) mm_config = self.model_config.multimodal_config self.is_multimodal_pruning_enabled = ( supports_multimodal_pruning(self.get_model()) @@ -3727,49 +4642,42 @@ def load_model(self, eep_scale_up: bool = False) -> None: and mm_config.is_multimodal_pruning_enabled() ) - if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: + if ( + is_mixture_of_experts(self.model) + and self.parallel_config.enable_eplb + and not load_dummy_weights + ): logger.info_once("EPLB is enabled for model %s.", self.model_config.model) - global_expert_load = ( - global_expert_loads[eplb_models] if global_expert_loads else None - ) - old_global_expert_indices = ( - old_global_expert_indices_per_model[eplb_models] - if old_global_expert_indices_per_model - else None - ) assert self.eplb_state is not None self.eplb_state.add_model( self.model, self.model_config, - global_expert_load, - old_global_expert_indices, - rank_mapping, ) if self.eplb_state.is_async: - self.eplb_state.start_async_loop(rank_mapping=rank_mapping) - - # print(f"{self.vllm_config.compilation_config.mode=}") + self.eplb_state.start_async_loop() if ( self.vllm_config.compilation_config.mode == CompilationMode.STOCK_TORCH_COMPILE - and supports_dynamo() ): backend = self.vllm_config.compilation_config.init_backend(self.vllm_config) compilation_counter.stock_torch_compile_count += 1 self.model.compile(fullgraph=True, backend=backend) 
return # for other compilation modes, cudagraph behavior is controlled by - # CudagraphWraper and CudagraphDispatcher of vllm. + # CudagraphWrapper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo: + if ( + cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.use_ubatching + ): self.model = GraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) - elif self.parallel_config.enable_dbo: + elif self.parallel_config.use_ubatching: if cudagraph_mode.has_full_cudagraphs(): self.model = UBatchWrapper( self.model, self.vllm_config, CUDAGraphMode.FULL, self.device @@ -3779,6 +4687,8 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model, self.vllm_config, CUDAGraphMode.NONE, self.device ) + get_offloader().post_init() + def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: """Extract Eagle3 auxiliary layer indices from speculative config. @@ -3803,23 +4713,89 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: return None - def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, ( - "Cannot reload weights before model is loaded." 
- ) - model_loader = get_model_loader(self.load_config) - logger.info("Reloading weights inplace...") - model_loader.load_weights(self.get_model(), model_config=self.model_config) - - def save_tensorized_model( + def reload_weights( self, - tensorizer_config: "TensorizerConfig", + weights_iterator: Iterable[tuple[str, torch.Tensor]] | None = None, + weights_path: str | None = None, + is_checkpoint_format: bool = True, ) -> None: - TensorizerLoader.save_model( - self.get_model(), - tensorizer_config=tensorizer_config, - model_config=self.model_config, + """ + Reload weights from a weights iterator or from disk + + :param weights_iterator: weights to load into model + :param weights_path: path to load weights from if weights_iterator is not + provided. Use path of original model if neither is provided. + :param is_checkpoint_format: set to False if weights have already been processed + into kernel format (repacking, renaming, etc.) + """ + # TODO(@kylesayrs): generalize to all runners and loaders + # argument validation + if weights_iterator is None and not is_checkpoint_format: + logger.warning( + "Reloading from disk means that weights will be in checkpoint format. 
" + "Please use `is_checkpoint_format=True` " + "to avoid weight reloading errors" + ) + + model = self.get_model() + weights_to_load = {name for name, _ in model.named_parameters()} + counter_before_reloading = time.perf_counter() + + # load weights from disk if none are provided + if weights_iterator is None: + model_loader = get_model_loader(self.load_config) + if not hasattr(model_loader, "get_all_weights"): + raise NotImplementedError( + f"Model reloading with `{self.load_config.load_format}` format" + ) + + if weights_path is not None: + self.model_config.model = weights_path + weights_iterator = model_loader.get_all_weights(self.model_config, model) + weights_iterator = cast( + Iterable[tuple[str, torch.Tensor]], weights_iterator + ) + + # begin loading weights + logger.info_once("Reloading weights inplace...", scope="local") + load_device = ( + self.vllm_config.load_config.device or self.vllm_config.device_config.device ) + with torch.device(load_device): + if is_checkpoint_format: + # load weights from checkpoint/ original model format + initialize_layerwise_reload(model) + loaded_weights = model.load_weights(weights_iterator) + finalize_layerwise_reload(model, self.model_config) + + else: + # load weights from kernel format + logger.warning_once( + "Reloading with `is_checkpoint_format=False` requires that " + "weights be in kernel format and already sharded", + scope="local", + ) + loaded_weights = set() + for name, loaded_weight in weights_iterator: + param = model.get_parameter(name) # TODO: buffers?
+ param.copy_(loaded_weight) + loaded_weights.add(name) + + # logging and validation + counter_after_reloading = time.perf_counter() + diff_seconds = counter_after_reloading - counter_before_reloading + logger.info_once( + "Reloading and processing weights took %.2f seconds", + diff_seconds, + scope="local", + ) + if self.model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + logger.warning( + "Following weights were not loaded from checkpoint: %s", + weights_not_loaded, + ) def _get_prompt_logprobs_dict( self, @@ -3900,7 +4876,7 @@ def _get_prompt_logprobs_dict( # Compute prompt logprobs. logprobs = self.sampler.compute_logprobs(logits) - token_ids, logprobs, ranks = self.sampler.gather_logprobs( + token_ids, logprobs, ranks, _ = self.sampler.gather_logprobs( logprobs, num_prompt_logprobs, tgt_token_ids ) @@ -3932,7 +4908,7 @@ def _get_nans_in_logits( ) -> dict[str, int]: try: if logits is None: - return dict.fromkeys(self.input_batch.req_ids, 0) + return {req_id: 0 for req_id in self.input_batch.req_ids} num_nans_in_logits = {} num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy() @@ -4000,22 +4976,22 @@ def _get_mm_dummy_batch( """Dummy data for profiling and precompiling multimodal models.""" assert self.mm_budget is not None - dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( - model_config=self.model_config, - seq_len=self.max_model_len, + # Don't use `max_items_per_batch` here to avoid redundant computation + dummy_mm_inputs = self.mm_registry.get_dummy_mm_inputs( + self.model_config, mm_counts={modality: 1}, cache=self.mm_budget.cache, ) - dummy_mm_data = dummy_decoder_data.multi_modal_data + dummy_mm_item = dummy_mm_inputs["mm_kwargs"][modality][0] - # Result in the maximum GPU consumption of the model - dummy_mm_item = dummy_mm_data[modality][0] - dummy_mm_items = [dummy_mm_item] * max_items_per_batch + # We use the cache so that the item is 
saved to the cache, + but not read from the cache + assert dummy_mm_item is not None, "Item should not already be cached" return next( - mm_kwargs_group - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - dummy_mm_items, + mm_kwargs_batch + for _, _, mm_kwargs_batch in group_and_batch_mm_kwargs( + [(modality, dummy_mm_item)] * max_items_per_batch, device=self.device, pin_memory=self.pin_memory, ) @@ -4033,12 +5009,13 @@ def _dummy_run( self, num_tokens: int, is_profile: bool = False, create_mixed_batch: bool = False, remove_lora: bool = True, - activate_lora: bool = False, is_graph_capturing: bool = False, + num_active_loras: int = 0, + profile_seq_lens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the - graph for the model. + CUDA graph for the model. Args: num_tokens: Number of tokens to run the dummy forward pass. @@ -4057,11 +5034,21 @@ def _dummy_run( create_mixed_batch: If True, create a mixed batch with both decode (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run - activate_lora: If False, dummy_run is performed without LoRAs. + num_active_loras: Number of distinct active LoRAs to capture for. + LoRA is activated when num_active_loras > 0. + profile_seq_lens: If provided, use this value for seq_lens instead + of max_query_len. Used to profile attention workspace that + scales with context length. """ + mm_config = self.vllm_config.model_config.multimodal_config + if mm_config and mm_config.mm_encoder_only: + # The current dummy run only covers LM execution, so we can skip it. + # mm encoder dummy run may need to be added in the future.
+ return torch.tensor([]), torch.tensor([]) + assert ( cudagraph_runtime_mode is None - or cudagraph_runtime_mode.valid_runtime_modes() + or cudagraph_runtime_mode.is_valid_runtime_mode() ) # If cudagraph_mode.decode_mode() == FULL and @@ -4082,7 +5069,7 @@ def _dummy_run( # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. - assert num_tokens <= self.scheduler_config.max_num_batched_tokens + assert num_tokens <= self.max_num_tokens max_num_reqs = self.scheduler_config.max_num_seqs if create_mixed_batch: assert not uniform_decode @@ -4133,7 +5120,10 @@ def _dummy_run( # `force_has_lora` is used for cudagraph capture; because LoRA is # activated later in the context manager, but we need to know the # LoRA state when determining the batch descriptor for capture - force_has_lora=activate_lora, + force_has_lora=num_active_loras > 0, + # `force_num_active_loras` is used for cudagraph capture; because we + # need to capture graphs for specific num_active_loras counts + force_num_active_loras=num_active_loras, ) ) @@ -4150,51 +5140,77 @@ def _dummy_run( batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs ) ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( - should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded + should_ubatch, + num_scheduled_tokens, + num_tokens_padded, + num_reqs_padded, + self.vllm_config.parallel_config.num_ubatches, + ) + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, ) attn_metadata: PerLayerAttnMetadata | None = None - # If force_attention is True, we always capture attention. Otherwise, - # it only happens for cudagraph_runtime_mode=FULL. - if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: - if create_mixed_batch: - # In the mixed batch mode (used for FI warmup), we use - # shorter sequence lengths to run faster. 
- # TODO(luka) better system for describing dummy batches - seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] - else: - seq_lens = max_query_len # type: ignore[assignment] - self.seq_lens.np[:num_reqs] = seq_lens - self.seq_lens.np[num_reqs:] = 0 - self.seq_lens.copy_to_gpu() + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens, + num_reqs_padded=num_reqs_padded, + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) - cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) - self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens - self.query_start_loc.copy_to_gpu() + # _dummy_run shares pinned CPU buffers (seq_lens, query_start_loc, + # etc.) with execute_model. It must participate in the same event + # protocol so that back-to-back dummy/real steps don't overwrite + # pinned memory while a prior non_blocking H2D DMA is still reading. + with self.synchronize_input_prep(): + # If force_attention is True, we always capture attention. + # Otherwise, it only happens for cudagraph_runtime_mode=FULL. + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + if profile_seq_lens is not None: + seq_lens = profile_seq_lens # type: ignore[assignment] + elif create_mixed_batch: + # In the mixed batch mode (used for FI warmup), we use + # shorter sequence lengths to run faster. 
+ # TODO(luka) better system for describing dummy batches + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment] + else: + seq_lens = max_query_len # type: ignore[assignment] + self.seq_lens.np[:num_reqs] = seq_lens + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() - pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL - attn_metadata, _ = self._build_attention_metadata( - num_tokens=num_tokens_unpadded, - num_reqs=num_reqs_padded, - max_query_len=max_query_len, - ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, - for_cudagraph_capture=is_graph_capturing, - ) + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens + self.query_start_loc.copy_to_gpu() + + pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL + attn_metadata, _ = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs_padded, + max_query_len=max_query_len, + ubatch_slices=(ubatch_slices_padded if pad_attn else ubatch_slices), + for_cudagraph_capture=is_graph_capturing, + slot_mappings=slot_mappings_by_group, + use_spec_decode=self.speculative_config is not None, + ) with self.maybe_dummy_run_with_lora( self.lora_config, num_scheduled_tokens, num_sampled_tokens, - activate_lora, remove_lora, + num_active_loras, ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens - model_kwargs = self._init_model_kwargs(num_tokens_padded) + model_kwargs = self._init_model_kwargs() if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: - input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded) + model_kwargs = { **model_kwargs, **self._dummy_mm_kwargs(num_reqs), @@ -4202,7 +5218,7 @@ def _dummy_run( elif 
self.enable_prompt_embeds: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] - model_kwargs = self._init_model_kwargs(num_tokens_padded) + model_kwargs = self._init_model_kwargs() else: input_ids = self.input_ids.gpu[:num_tokens_padded] inputs_embeds = None @@ -4237,6 +5253,7 @@ def _dummy_run( num_tokens_padded = ubatch_slices_padded[0].num_tokens if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_padded + with ( self.maybe_randomize_inputs(input_ids, inputs_embeds), set_forward_context( @@ -4247,6 +5264,7 @@ def _dummy_run( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, ), ): outputs = self.model( @@ -4262,8 +5280,16 @@ def _dummy_run( else: hidden_states = outputs - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + or self.speculative_config.uses_extract_hidden_states() + ): + assert isinstance( + self.drafter, + EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, + ) + assert self.speculative_config is not None # Eagle currently only supports PIECEWISE cudagraphs. # Therefore only use cudagraphs if the main model uses PIECEWISE # NOTE(lucas): this is a hack, need to clean up. @@ -4282,13 +5308,17 @@ def _dummy_run( # lora cases when cudagraph_specialize_lora is enabled. 
This is a # short term mitigation for issue mentioned in # https://github.com/vllm-project/vllm/issues/28334 - if self.compilation_config.cudagraph_specialize_lora and activate_lora: + if ( + self.compilation_config.cudagraph_specialize_lora + and num_active_loras > 0 + ): use_cudagraphs = False self.drafter.dummy_run( num_tokens, use_cudagraphs=use_cudagraphs, is_graph_capturing=is_graph_capturing, + slot_mappings=slot_mappings, ) # We register layerwise NVTX hooks here after the first dynamo tracing is @@ -4326,6 +5356,12 @@ def _dummy_sampler_run( # The dummy hidden states may contain special values, # like `inf` or `nan`. # To avoid breaking the sampler, we use a random tensor here instead. + + mm_config = self.vllm_config.model_config.multimodal_config + if mm_config and mm_config.mm_encoder_only: + # MM Encoder only model no need to run sampler. + return torch.tensor([]) + hidden_states = torch.rand_like(hidden_states) logits = self.model.compute_logits(hidden_states) @@ -4400,24 +5436,21 @@ def _dummy_pooler_run_task( max_num_reqs = self.scheduler_config.max_num_seqs num_reqs = min(num_tokens, max_num_reqs) min_tokens_per_req = num_tokens // num_reqs - num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs - num_scheduled_tokens_list[-1] += num_tokens % num_reqs - assert sum(num_scheduled_tokens_list) == num_tokens - assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens_np = np.full(num_reqs, min_tokens_per_req) + num_scheduled_tokens_np[-1] += num_tokens % num_reqs + assert np.sum(num_scheduled_tokens_np) == num_tokens + assert len(num_scheduled_tokens_np) == num_reqs req_num_tokens = num_tokens // num_reqs - dummy_prompt_lens = torch.tensor( - num_scheduled_tokens_list, - device="cpu", - ) + dummy_prompt_lens = torch.from_numpy(num_scheduled_tokens_np) dummy_token_ids = torch.zeros( (num_reqs, req_num_tokens), dtype=torch.int32, device=self.device ) model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = 
PoolingParams(task=task) - dummy_pooling_params.verify(task=task, model_config=self.model_config) + dummy_pooling_params.verify(self.model_config) to_update = model.pooler.get_pooling_updates(task) to_update.apply(dummy_pooling_params) @@ -4429,7 +5462,7 @@ def _dummy_pooler_run_task( ) dummy_metadata.build_pooling_cursor( - num_scheduled_tokens_list, + num_scheduled_tokens_np, seq_lens_cpu=dummy_prompt_lens, device=hidden_states.device, ) @@ -4454,6 +5487,11 @@ def _dummy_pooler_run( self, hidden_states: torch.Tensor, ) -> PoolerOutput: + mm_config = self.vllm_config.model_config.multimodal_config + if mm_config and mm_config.mm_encoder_only: + # MM Encoder only model not need to run pooler. + return torch.tensor([]) + # Find the task that has the largest output for subsequent steps supported_pooling_tasks = self.get_supported_pooling_tasks() @@ -4469,7 +5507,7 @@ def _dummy_pooler_run( for task in supported_pooling_tasks: # Run a full batch with each task to ensure none of them OOMs output = self._dummy_pooler_run_task(hidden_states, task) - output_size[task] = sum(o.nbytes for o in output) + output_size[task] = sum(o.nbytes for o in output if o is not None) del output # Allow GC max_task = max(output_size.items(), key=lambda x: x[1])[0] @@ -4489,40 +5527,51 @@ def profile_run(self) -> None: assert mm_budget is not None if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ - dummy_modality - ] + if not mm_budget.mm_max_toks_per_item: + # All modality limits are 0 — embedding-only mode. + # Budget is non-zero for embedding storage, but + # there's no encoder to profile. 
+ logger.info( + "Skipping encoder profiling for embedding-only " + "mode (all modality limits=0 with " + "enable_mm_embeds=True).", + ) + else: + # NOTE: Currently model is profiled with a single + # non-text modality with the max possible input + # tokens even when it supports multiple. + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget.mm_max_items_per_batch[ + dummy_modality + ] - logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the " - "maximum feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) + logger.info_once( + "Encoder cache will be initialized with a " + "budget of %s tokens, and profiled with " + "%s %s items of the maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + scope="local", + ) - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) - # Run multimodal encoder. - dummy_encoder_outputs = self.model.embed_multimodal( - **batched_dummy_mm_inputs - ) + # Run multimodal encoder. 
+ dummy_encoder_outputs = self.model.embed_multimodal( + **batched_dummy_mm_inputs + ) - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) - for i, output in enumerate(dummy_encoder_outputs): - self.encoder_cache[f"tmp_{i}"] = output + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + for i, output in enumerate(dummy_encoder_outputs): + self.encoder_cache[f"tmp_{i}"] = output # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states = self._dummy_run( @@ -4540,6 +5589,169 @@ def profile_run(self) -> None: self.encoder_cache.clear() gc.collect() + def _init_minimal_kv_cache_for_profiling(self) -> None: + from vllm.v1.core.kv_cache_utils import ( + get_kv_cache_config_from_groups, + get_kv_cache_groups, + ) + + kv_cache_spec = self.get_kv_cache_spec() + kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec) + min_blocks = self.compilation_config.max_cudagraph_capture_size or 1 + + # Temporarily change num_gpu_blocks_override to allocate a minimal KV cache + saved_override = self.cache_config.num_gpu_blocks_override + self.cache_config.num_gpu_blocks_override = min_blocks + minimal_config = get_kv_cache_config_from_groups( + self.vllm_config, kv_cache_groups, available_memory=0 + ) + self.cache_config.num_gpu_blocks_override = saved_override + + self.initialize_kv_cache(minimal_config) + self.cache_config.num_gpu_blocks = minimal_config.num_blocks + + logger.debug("Initialized minimal KV cache for CUDA graph profiling") + + @staticmethod + @contextmanager + def _freeze_gc(): + gc.collect() + should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC + if should_freeze: + gc.freeze() + try: + yield + finally: + if should_freeze: + gc.unfreeze() + gc.collect() + + def _cleanup_profiling_kv_cache(self) -> None: + torch.accelerator.synchronize() + if hasattr(self, "kv_caches") and self.kv_caches: + for i in 
range(len(self.kv_caches)): + self.kv_caches[i] = None # type: ignore + self.kv_caches.clear() + if hasattr(self, "cross_layers_kv_cache"): + self.cross_layers_kv_cache = None + self.cross_layers_attn_backend = None + if hasattr(self, "attn_groups"): + self.attn_groups.clear() + if hasattr(self, "kv_cache_config"): + delattr(self, "kv_cache_config") + self.cache_config.num_gpu_blocks = None + + for layer in self.compilation_config.static_forward_context.values(): + if hasattr(layer, "kv_cache"): + layer.kv_cache = [] + + gc.collect() + torch.accelerator.empty_cache() + + logger.debug("Cleaned up profiling KV cache and CUDA graphs") + + @torch.inference_mode() + def profile_cudagraph_memory(self) -> int: + with set_current_vllm_config(self.vllm_config): + self._init_minimal_kv_cache_for_profiling() + + saved_num_cudagraph_captured = compilation_counter.num_cudagraph_captured + + capture_descs = self.cudagraph_dispatcher.get_capture_descs() + + total_graphs = sum(len(descs) for _, descs in capture_descs) + if total_graphs == 0: + logger.debug("No CUDA graphs will be captured, skipping profiling") + self._cleanup_profiling_kv_cache() + return 0 + + logger.info( + "Profiling CUDA graph memory: %s", + ", ".join( + f"{mode.name}={len(descs)} (largest={descs[0].num_tokens})" + for mode, descs in capture_descs + if descs + ), + ) + + # Use a temporary pool for profiling to avoid fragmentation in the main pool. 
+ profiling_pool = current_platform.graph_pool_handle() + original_pools: dict[int, Any] = {} + for instance in list(GraphWrapper._all_instances): + original_pools[id(instance)] = instance.graph_pool + instance.graph_pool = profiling_pool + + set_cudagraph_capturing_enabled(True) + with self._freeze_gc(), graph_capture(device=self.device): + shared_memory_estimate = {} + per_graph_estimate = {} + torch.accelerator.synchronize() + torch.accelerator.empty_cache() + + for mode, descs in capture_descs: + profile_descs = descs[:2] + mem_samples: list[int] = [] + + for i, desc in enumerate(profile_descs): + mem_before = current_platform.torch_device_fn.mem_get_info()[0] + self._warmup_and_capture( + desc, + cudagraph_runtime_mode=mode, + profile_seq_lens=( + min( + self.max_model_len, + self.max_num_tokens // desc.num_tokens, + ) + if mode == CUDAGraphMode.FULL and i == 0 + else None + ), + ) + torch.accelerator.synchronize() + free_after = current_platform.torch_device_fn.mem_get_info()[0] + mem_samples.append(mem_before - free_after) + + first_capture = mem_samples[0] + # Use at least 1 MiB per graph for driver overhead + per_graph = max(mem_samples[1] if len(mem_samples) > 1 else 0, 1 << 20) + + shared_memory_estimate[mode] = first_capture + per_graph_estimate[mode] = per_graph * (len(descs) - 1) + + logger.debug( + "Estimated %s CUDA graph memory: " + "%.2f MiB first-capture + (%d-1) × %.2f MiB per-graph", + mode.name, + first_capture / (1 << 20), + len(descs), + per_graph / (1 << 20), + ) + + set_cudagraph_capturing_enabled(False) + GraphWrapper.clear_all_graphs() + for instance in list(GraphWrapper._all_instances): + if id(instance) in original_pools: + instance.graph_pool = original_pools[id(instance)] + for key_set in self.cudagraph_dispatcher.cudagraph_keys.values(): + key_set.clear() + self.cudagraph_dispatcher.keys_initialized = False + self.maybe_remove_all_loras(self.lora_config) + self._cleanup_profiling_kv_cache() + 
compilation_counter.num_cudagraph_captured = saved_num_cudagraph_captured + + # FULL and PIECEWISE graphs share the global pool at runtime and are + # never replayed concurrently, so the pool overlays their memory. + # Take the max to avoid double-counting the overlap. + total_estimate = max(shared_memory_estimate.values()) + sum( + per_graph_estimate.values() + ) + logger.info( + "Estimated CUDA graph memory: %.2f GiB total", + total_estimate / (1 << 30), + ) + + return int(total_estimate) + + @instrument(span_name="Capture model") def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( @@ -4552,75 +5764,28 @@ def capture_model(self) -> int: start_time = time.perf_counter() - @contextmanager - def freeze_gc(): - # Optimize garbage collection during CUDA graph capture. - # Clean up, then freeze all remaining objects from being included - # in future collections. - gc.collect() - should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC - if should_freeze: - gc.freeze() - try: - yield - finally: - if should_freeze: - gc.unfreeze() - gc.collect() - # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. 
set_cudagraph_capturing_enabled(True) - with freeze_gc(), graph_capture(device=self.device): + with self._freeze_gc(), graph_capture(device=self.device): + torch.accelerator.synchronize() + torch.accelerator.empty_cache() start_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0] - cudagraph_mode = self.compilation_config.cudagraph_mode - assert cudagraph_mode is not None - - if self.lora_config: - if self.compilation_config.cudagraph_specialize_lora: - lora_cases = [True, False] - else: - lora_cases = [True] - else: - lora_cases = [False] - if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - cudagraph_runtime_mode = cudagraph_mode.mixed_mode() - # make sure we capture the largest batch size first - compilation_cases = list( - product(reversed(self.cudagraph_batch_sizes), lora_cases) - ) - self._capture_cudagraphs( - compilation_cases, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False, - ) - # Capture full cudagraph for uniform decode batches if we - # don't already have full mixed prefill-decode cudagraphs. 
- if ( - cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and cudagraph_mode.separate_routine() - ): - max_num_tokens = ( - self.scheduler_config.max_num_seqs * self.uniform_decode_query_len - ) - decode_cudagraph_batch_sizes = [ - x - for x in self.cudagraph_batch_sizes - if max_num_tokens >= x >= self.uniform_decode_query_len - ] - compilation_cases_decode = list( - product(reversed(decode_cudagraph_batch_sizes), lora_cases) - ) + for ( + runtime_mode, + batch_descs, + ) in self.cudagraph_dispatcher.get_capture_descs(): self._capture_cudagraphs( - compilation_cases=compilation_cases_decode, - cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True, + batch_descriptors=batch_descs, + cudagraph_runtime_mode=runtime_mode, ) + torch.accelerator.synchronize() - current_platform.torch_device_fn.synchronize() + torch.accelerator.synchronize() end_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0] + # Disable cudagraph capturing globally, so any unexpected cudagraph # capturing will be detected and raise an error after here. # Note: We don't put it into graph_capture context manager because @@ -4628,6 +5793,9 @@ def freeze_gc(): # after here. set_cudagraph_capturing_enabled(False) + torch.accelerator.synchronize() + torch.accelerator.empty_cache() + # Lock workspace to prevent resizing during execution. # Max workspace sizes should have been captured during warmup/profiling. 
lock_workspace() @@ -4644,21 +5812,59 @@ def freeze_gc(): ) return cuda_graph_size + def _warmup_and_capture( + self, + desc: BatchDescriptor, + cudagraph_runtime_mode: CUDAGraphMode, + profile_seq_lens: int | None = None, + allow_microbatching: bool = False, + num_warmups: int | None = None, + ): + if num_warmups is None: + num_warmups = self.compilation_config.cudagraph_num_of_warmups + force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + for _ in range(num_warmups): + self._dummy_run( + desc.num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=desc.uniform, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + num_active_loras=desc.num_active_loras, + ) + self._dummy_run( + desc.num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=desc.uniform, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + num_active_loras=desc.num_active_loras, + is_graph_capturing=True, + profile_seq_lens=profile_seq_lens, + ) + def _capture_cudagraphs( self, - compilation_cases: list[tuple[int, bool]], + batch_descriptors: list[BatchDescriptor], cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool, ): assert ( cudagraph_runtime_mode != CUDAGraphMode.NONE - and cudagraph_runtime_mode.valid_runtime_modes() + and cudagraph_runtime_mode.is_valid_runtime_mode() ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" + if not batch_descriptors: + return + + uniform_decode = batch_descriptors[0].uniform + # Only rank 0 should print progress bar during capture if is_global_first_rank(): - compilation_cases = tqdm( - compilation_cases, + batch_descriptors = tqdm( + batch_descriptors, disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", @@ -4667,49 +5873,27 @@ def _capture_cudagraphs( ) # We skip EPLB here since we don't want to 
record dummy metrics - for num_tokens, activate_lora in compilation_cases: + for batch_desc in batch_descriptors: # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched # version of the graph allow_microbatching = ( - self.parallel_config.enable_dbo + self.parallel_config.use_ubatching and cudagraph_runtime_mode == CUDAGraphMode.FULL and uniform_decode and check_ubatch_thresholds( config=self.vllm_config.parallel_config, - num_tokens=num_tokens, + num_tokens=batch_desc.num_tokens, uniform_decode=uniform_decode, ) ) - - for _ in range(self.compilation_config.cudagraph_num_of_warmups): - # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. - # But be careful, warm up with `NONE`is orthogonal to - # if we want to warm up attention or not. This is - # different from the case where `FULL` implies capture - # attention while `PIECEWISE` implies no attention. - force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL - self._dummy_run( - num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False, - activate_lora=activate_lora, - ) - self._dummy_run( - num_tokens, + self._warmup_and_capture( + batch_desc, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False, - activate_lora=activate_lora, - is_graph_capturing=True, ) + torch.accelerator.synchronize() self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: @@ -4808,14 +5992,22 @@ def initialize_metadata_builders( if kv_cache_group_id < len(kernel_block_sizes) else None, num_metadata_builders=1 - if not self.parallel_config.enable_dbo - else 2, + if not 
self.parallel_config.use_ubatching + else self.parallel_config.num_ubatches, ) # Calculate reorder batch threshold (if needed) # Note (tdoublep): do this *after* constructing builders, # because some of them change the threshold at init time. self.calculate_reorder_batch_threshold() + # Initialize drafter attention backend + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ): + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) + self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes) + def _check_and_update_cudagraph_mode( self, attention_backends: list[set[type[AttentionBackend]]], @@ -4959,10 +6151,22 @@ def _check_and_update_cudagraph_mode( self.compilation_config.adjust_cudagraph_sizes_for_spec_decode( self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size ) - capture_sizes = self.compilation_config.cudagraph_capture_sizes - self.cudagraph_batch_sizes = ( - capture_sizes if capture_sizes is not None else [] + + # If the model has Mamba layers and cudagraph mode includes FULL + # decode, cap cudagraph capture sizes to the number of available + # Mamba cache blocks. Each decode request needs one conv_state + # cache line, so capture batch sizes cannot exceed num_blocks. + # Only FULL decode graphs are affected because PIECEWISE captures + # run GDN/Mamba ops eagerly (prefill path, no causal_conv1d_update). + # See: https://github.com/vllm-project/vllm/issues/34094 + if cudagraph_mode.has_full_cudagraphs(): + has_mamba = any( + isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups ) + if has_mamba and self.kv_cache_config is not None: + self.compilation_config.adjust_cudagraph_sizes_for_mamba_cache( + self.kv_cache_config.num_blocks + ) # Trigger cudagraph dispatching keys initialization after # resolved cudagraph mode. 
@@ -4971,6 +6175,14 @@ def _check_and_update_cudagraph_mode( cudagraph_mode, self.uniform_decode_query_len ) + # Initialize drafter's cudagraph dispatcher if using spec decode. + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_extract_hidden_states() + ): + assert isinstance(self.drafter, EagleProposer | ExtractHiddenStatesProposer) + self.drafter.initialize_cudagraph_keys(cudagraph_mode) + def calculate_reorder_batch_threshold(self) -> None: """ Choose the minimum reorder batch threshold from all attention groups. @@ -4991,120 +6203,75 @@ def calculate_reorder_batch_threshold(self) -> None: return self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment] - @staticmethod - def select_common_block_size( - kv_manager_block_size: int, attn_groups: list[AttentionGroup] - ) -> int: - """ - Select a block size that is supported by all backends and is a factor of - kv_manager_block_size. - - If kv_manager_block_size is supported by all backends, return it directly. - Otherwise, return the max supported size. - - Args: - kv_manager_block_size: Block size of KV cache - attn_groups: List of attention groups - - Returns: - The selected block size - - Raises: - ValueError: If no valid block size found - """ - - def block_size_is_supported( - backends: list[type[AttentionBackend]], block_size: int - ) -> bool: - """ - Check if the block size is supported by all backends. 
- """ - for backend in backends: - is_supported = False - for supported_size in backend.get_supported_kernel_block_sizes(): - if isinstance(supported_size, int): - if block_size == supported_size: - is_supported = True - elif isinstance(supported_size, MultipleOf): - if block_size % supported_size.base == 0: - is_supported = True - else: - raise ValueError(f"Unknown supported size: {supported_size}") - if not is_supported: - return False - return True - - backends = [group.backend for group in attn_groups] - - # Case 1: if the block_size of kv cache manager is supported by all backends, - # return it directly - if block_size_is_supported(backends, kv_manager_block_size): - return kv_manager_block_size - - # Case 2: otherwise, the block_size must be an `int`-format supported size of - # at least one backend. Iterate over all `int`-format supported sizes in - # descending order and return the first one that is supported by all backends. - # Simple proof: - # If the supported size b is in MultipleOf(x_i) format for all attention - # backends i, and b a factor of kv_manager_block_size, then - # kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will - # return kv_manager_block_size in case 1. - all_int_supported_sizes = set( - supported_size - for backend in backends - for supported_size in backend.get_supported_kernel_block_sizes() - if isinstance(supported_size, int) - ) - - for supported_size in sorted(all_int_supported_sizes, reverse=True): - if kv_manager_block_size % supported_size != 0: - continue - if block_size_is_supported(backends, supported_size): - return supported_size - raise ValueError(f"No common block size for {kv_manager_block_size}. ") - def may_reinitialize_input_batch( self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] ) -> None: """ Re-initialize the input batch if the block sizes are different from - `[self.cache_config.block_size]`. This usually happens when there - are multiple KV cache groups. 
+ what it was originally created with. This happens when the final + block size (determined after model loading) differs from the + placeholder used during __init__, or when there are multiple + KV cache groups. Args: kv_cache_config: The KV cache configuration. kernel_block_sizes: The kernel block sizes for each KV cache group. """ - block_sizes = [ - kv_cache_group.kv_cache_spec.block_size - for kv_cache_group in kv_cache_config.kv_cache_groups - if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) - ] + block_sizes = [] + max_num_blocks = [] + max_model_len = max(self.max_model_len, self.max_encoder_len) + for kv_cache_group in kv_cache_config.kv_cache_groups: + if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): + continue + block_size = kv_cache_group.kv_cache_spec.block_size + block_sizes.append(block_size) + max_num_blocks_per_req = cdiv( + max_model_len, block_size * get_total_cp_world_size() + ) + if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + max_num_blocks_per_req = ( + max_num_blocks_per_req + if self.cache_config.enable_prefix_caching + else 1 + ) + kv_cache_group.kv_cache_spec.num_speculative_blocks + max_num_blocks.append(max_num_blocks_per_req) - if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ - self.cache_config.block_size - ]: - assert self.cache_config.cpu_offload_gb == 0, ( + if ( + block_sizes != self._init_block_sizes + or kernel_block_sizes != self._init_kernel_block_sizes + ): + assert self.offload_config.uva.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 "for more details." 
) + self._init_block_sizes = block_sizes + self._init_kernel_block_sizes = kernel_block_sizes self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=max(self.max_model_len, self.max_encoder_len), + max_model_len=max_model_len, max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, kernel_block_sizes=kernel_block_sizes, + max_num_blocks_per_req=max_num_blocks, is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, - num_speculative_tokens=self.num_spec_tokens, ) + assert self._init_block_sizes == block_sizes, ( + f"InputBatch block_sizes {self._init_block_sizes} != " + f"kv_cache block_sizes {block_sizes}" + ) + assert self._init_kernel_block_sizes == kernel_block_sizes, ( + f"InputBatch kernel_block_sizes {self._init_kernel_block_sizes} " + f"!= kv_cache kernel_block_sizes {kernel_block_sizes}" + ) + def _allocate_kv_cache_tensors( self, kv_cache_config: KVCacheConfig ) -> dict[str, torch.Tensor]: @@ -5146,49 +6313,6 @@ def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: for attn_groups in self.attn_groups: yield from attn_groups - def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]: - """ - Generate kernel_block_sizes that matches each block_size. - - For attention backends that support virtual block splitting, - use the supported block sizes from the backend. - For other backends (like Mamba), use the same block size (no splitting). - - Args: - kv_cache_config: The KV cache configuration. - - Returns: - list[int]: List of kernel block sizes for each cache group. 
- """ - kernel_block_sizes = [] - for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group.kv_cache_spec - if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): - # All layers in the UniformTypeKVCacheSpecs have the same type, - # Pick an arbitrary one to dispatch. - kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values())) - if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): - continue - elif isinstance(kv_cache_spec, AttentionSpec): - # This is an attention backend that supports virtual - # block splitting. Get the supported block sizes from - # all backends in the group. - attn_groups = self.attn_groups[kv_cache_gid] - kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size - selected_kernel_size = self.select_common_block_size( - kv_manager_block_size, attn_groups - ) - kernel_block_sizes.append(selected_kernel_size) - elif isinstance(kv_cache_spec, MambaSpec): - # This is likely Mamba or other non-attention cache, - # no splitting. - kernel_block_sizes.append(kv_cache_spec.block_size) - else: - raise NotImplementedError( - f"unknown kv cache spec {kv_cache_group.kv_cache_spec}" - ) - return kernel_block_sizes - def _reshape_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, @@ -5252,7 +6376,8 @@ def _reshape_kv_cache_tensors( ) # Maintain original KV shape view. 
inv_order = [ - kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) + kv_cache_stride_order.index(i) + for i in range(len(kv_cache_stride_order)) ] kv_caches[layer_name] = ( kv_cache_raw_tensors[layer_name] @@ -5411,6 +6536,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config + self._mamba_copy_bufs = None self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) @@ -5419,7 +6545,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # backends for that group only supports block_size 64, we will return # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64 # tokens each. - kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) + kernel_block_sizes = prepare_kernel_block_sizes( + kv_cache_config, self.attn_groups + ) + self._kernel_block_sizes = kernel_block_sizes # create metadata builders self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes) @@ -5430,8 +6559,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config, kernel_block_sizes ) - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if ( + self.speculative_config + and self.speculative_config.uses_extract_hidden_states() + ): + assert isinstance(self.drafter, ExtractHiddenStatesProposer) # validate all draft model layers belong to the same kv cache # group self.drafter.validate_same_kv_cache_group(kv_cache_config) @@ -5447,6 +6579,58 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_transfer_group.register_kv_caches(kv_caches) kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) + def _get_attention_kv_cache_gid(self) -> int: + """Find the KV cache group index for 
attention layers.""" + for gid, group in enumerate(self.kv_cache_config.kv_cache_groups): + if isinstance(group.kv_cache_spec, AttentionSpec): + return gid + return 0 + + def init_routed_experts_capturer(self): + logger.info( + "Initializing routed experts capturer, enable_return_routed_experts: %s", + self.model_config.enable_return_routed_experts, + ) + routed_experts_capturer = RoutedExpertsCapturer.create() + self.routed_experts_attn_gid = self._get_attention_kv_cache_gid() + min_block_size = min( + [ + group.kv_cache_spec.block_size + for group in self.kv_cache_config.kv_cache_groups + ] + ) + num_groups = len(self.kv_cache_config.kv_cache_groups) + self.max_num_kv_tokens = ( + self.kv_cache_config.num_blocks // num_groups + ) * min_block_size + dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size + pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size + if pcp_size * dcp_size > 1: + self.max_num_kv_tokens *= pcp_size * dcp_size + + routed_experts_capturer.init_buffer( + max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens, + max_num_kv_tokens=self.max_num_kv_tokens, + vllm_config=self.vllm_config, + ) + self._bind_routed_experts_capturer(routed_experts_capturer) + self.routed_experts_initialized = True + + def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None: + from vllm.model_executor.layers.fused_moe.layer import FusedMoE + from vllm.model_executor.layers.fused_moe.router.base_router import ( + BaseRouter, + ) + + for module in self.compilation_config.static_forward_context.values(): + if isinstance(module, FusedMoE) and isinstance(module.router, BaseRouter): + layer_id = module.layer_id + + def _capture_fn(topk_ids, _layer_id=layer_id, _capturer=capturer): + _capturer.capture(_layer_id, topk_ids) + + module.router.set_capture_fn(_capture_fn) + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
@@ -5481,7 +6665,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: KVCacheSpec: A dictionary mapping layer names to their KV cache format. Layers that do not need KV cache are not included. """ - if has_ec_transfer() and get_ec_transfer().is_producer: + if has_ec_transfer() and not get_ec_transfer().is_consumer: return {} kv_cache_spec: dict[str, KVCacheSpec] = {} layer_type = cast(type[Any], AttentionLayerBase) @@ -5519,3 +6703,79 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: self.transfer_event.record() self.transfer_event.synchronize() return pinned.tolist() + + def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]: + """ + Get encoder timing stats for all requests and clear the registry. + + Returns: + Dictionary mapping request_id to stats dict. + """ + with self._encoder_timing_lock: + stats = { + req_id: stats_obj.to_dict() + for req_id, stats_obj in self.encoder_timing_registry.items() + } + self.encoder_timing_registry.clear() + return stats + + @contextmanager + def timed_encoder_operation( + self, + should_time: bool, + group_lora_refs: list[tuple[str, Any]], + current_item_idx: int, + num_items: int, + ): + """ + Context manager to time encoder forward operations. 
+ + Args: + should_time: Whether timing is enabled + group_lora_refs: Full list of (request_id, pos_info) tuples + current_item_idx: Starting index for this group + num_items: Number of items in this group + """ + if not should_time: + yield + return + + group_refs = group_lora_refs[current_item_idx : current_item_idx + num_items] + group_request_ids = {req_id for req_id, _ in group_refs} + + torch.accelerator.synchronize() + start_time = time.perf_counter() + + try: + yield + finally: + torch.accelerator.synchronize() + elapsed = time.perf_counter() - start_time + + per_request_time = elapsed / max(len(group_request_ids), 1) + + with self._encoder_timing_lock: + for req_id in group_request_ids: + if req_id not in self.encoder_timing_registry: + self.encoder_timing_registry[req_id] = EncoderTimingStats() + + stats = self.encoder_timing_registry[req_id] + stats.encoder_forward_secs += per_request_time + stats.num_encoder_calls += 1 + + +@dataclass +class EncoderTimingStats: + """Per-request timing statistics for encoder forward pass.""" + + encoder_forward_secs: float = 0.0 + """Time spent in vision encoder forward pass (seconds).""" + + num_encoder_calls: int = 0 + """Number of times encoder was called for this request.""" + + def to_dict(self) -> dict[str, float | int]: + return { + "encoder_forward_secs": self.encoder_forward_secs, + "num_encoder_calls": self.num_encoder_calls, + } diff --git a/vllm_fl/worker/worker.py b/vllm_fl/worker/worker.py index de6dbb0f..4cd7631f 100644 --- a/vllm_fl/worker/worker.py +++ b/vllm_fl/worker/worker.py @@ -1,5 +1,5 @@ # Copyright (c) 2025 BAAI. All rights reserved. 
-# Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/v1/worker/gpu_model_runner.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/v1/worker/gpu_model_runner.py # Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project @@ -16,7 +16,7 @@ import torch.distributed import torch.nn as nn -from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config from vllm.config.compilation import CompilationMode from vllm.distributed import ( ensure_model_parallel_initialized, @@ -47,7 +47,7 @@ def kernel_warmup(worker): ) from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed +from vllm.utils.torch_utils import set_random_seed from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.platforms import current_platform from vllm.profiler.wrapper import TorchProfilerWrapper @@ -206,11 +206,6 @@ def __init__( distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils.import_utils import init_cached_hf_modules - - init_cached_hf_modules() # Buffers saved before sleep self._sleep_saved_buffers: dict[str, torch.Tensor] = {} @@ -429,10 +424,10 @@ def init_device(self): # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool # to hijack tensor allocation. - def load_model(self) -> None: - eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1" + def load_model(self, *, load_dummy_weights: bool = False) -> None: ### TODO(lms): support manages a memory pool for device tensors. 
- self.model_runner.load_model(eep_scale_up=eep_scale_up) + with set_current_vllm_config(self.vllm_config): + self.model_runner.load_model(load_dummy_weights=load_dummy_weights) # with self._maybe_get_memory_pool_context(tag="weights"): # self.model_runner.load_model(eep_scale_up=eep_scale_up) @@ -546,6 +541,11 @@ def get_kv_connector_handshake_metadata(self) -> dict | None: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() + def update_max_model_len(self, max_model_len: int) -> None: + """Update max_model_len after auto-fit to GPU memory.""" + self.model_config.max_model_len = max_model_len + self.model_runner.max_model_len = max_model_len + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" # Init kv cache connector here, because it requires @@ -565,7 +565,7 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: # context = nullcontext() self.model_runner.initialize_kv_cache(kv_cache_config) - def compile_or_warm_up_model(self) -> None: + def compile_or_warm_up_model(self) -> float: warmup_sizes = [] if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE: # warm up sizes that are not in cudagraph capture sizes, @@ -688,15 +688,24 @@ def compile_or_warm_up_model(self) -> None: # the model initialization and profiling. 
set_random_seed(self.model_config.seed) + return self.compilation_config.compilation_time + def reset_mm_cache(self) -> None: self.model_runner.reset_mm_cache() + def reset_encoder_cache(self) -> None: + self.model_runner.reset_encoder_cache() + def get_model(self) -> nn.Module: return self.model_runner.get_model() def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self.model_runner.get_supported_tasks() + def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]: + """Get encoder timing stats from model runner.""" + return self.model_runner.get_encoder_timing_stats() + def annotate_profile(self, scheduler_output): # add trace annotation so that we can easily distinguish # new/cached request numbers in each iteration @@ -789,7 +798,7 @@ def execute_model( def take_draft_token_ids(self) -> Optional[DraftTokenIds]: return self.model_runner.take_draft_token_ids() - def profile(self, is_start: bool = True): + def profile(self, is_start: bool = True, profile_prefix: str | None = None): if self.profiler is None: raise RuntimeError("Profiling is not enabled.") if is_start: