From d5bb6d68b95a535cd49af6e39db3739d41c49c93 Mon Sep 17 00:00:00 2001 From: cyberpioneer Date: Tue, 31 Mar 2026 07:16:29 +0000 Subject: [PATCH 1/7] upgrade vllm --- README.md | 3 +- vllm_fl/__init__.py | 66 +- vllm_fl/attention/utils.py | 4 +- vllm_fl/compilation/graph.py | 137 +- vllm_fl/configs/qwen3_5_moe.py | 185 - vllm_fl/dispatch/README.md | 4 +- vllm_fl/dispatch/__init__.py | 2 +- vllm_fl/dispatch/backends/__init__.py | 28 +- .../dispatch/backends/flaggems/flaggems.py | 2 +- .../backends/flaggems/impl/attention.py | 16 +- .../flaggems/impl/custom_attention.py | 2 +- .../dispatch/backends/flaggems/impl/mla.py | 4 +- .../dispatch/backends/reference/reference.py | 2 +- vllm_fl/dispatch/backends/vendor/__init__.py | 35 +- .../backends/vendor/ascend/impl/attention.py | 5 +- .../vendor/ascend/impl/causal_conv1d.py | 2 +- .../ascend/impl/mm_encoder_attention.py | 2 +- vllm_fl/dispatch/backends/vendor/cuda/cuda.py | 2 +- .../backends/vendor/cuda/impl/activation.py | 4 +- .../backends/vendor/iluvatar/iluvatar.py | 2 +- .../vendor/metax/impl/attention/flash_attn.py | 18 +- .../vendor/metax/impl/attention/mla/common.py | 14 +- .../metax/impl/attention/mla/flashmla.py | 10 +- .../impl/attention/ops/merge_attn_states.py | 2 +- .../metax/impl/attention/utils/fa_utils.py | 4 +- .../dispatch/backends/vendor/metax/metax.py | 4 +- vllm_fl/dispatch/builtin_ops.py | 69 +- vllm_fl/dispatch/config/__init__.py | 2 - vllm_fl/dispatch/config/nvidia.yaml | 4 +- vllm_fl/dispatch/config/utils.py | 32 +- vllm_fl/dispatch/logger_manager.py | 2 +- vllm_fl/dispatch/manager.py | 13 +- vllm_fl/dispatch/policy.py | 13 +- vllm_fl/models/glm_moe_dsa.py | 17 - vllm_fl/models/kimi_k25.py | 245 -- vllm_fl/models/minicpmo.py | 854 ---- vllm_fl/models/qwen3_5.py | 951 ----- vllm_fl/models/qwen3_next.py | 1396 ------- vllm_fl/ops/custom_ops.py | 5 - vllm_fl/ops/fused_moe/fused_moe.py | 14 +- vllm_fl/ops/fused_moe/layer.py | 33 +- vllm_fl/patches/glm_moe_dsa.py | 20 - vllm_fl/platform.py | 57 +- 
vllm_fl/utils.py | 58 - vllm_fl/worker/model_runner.py | 3675 ++++++++++++----- vllm_fl/worker/worker.py | 35 +- 46 files changed, 2920 insertions(+), 5134 deletions(-) delete mode 100644 vllm_fl/configs/qwen3_5_moe.py delete mode 100644 vllm_fl/models/glm_moe_dsa.py delete mode 100644 vllm_fl/models/kimi_k25.py delete mode 100644 vllm_fl/models/minicpmo.py delete mode 100644 vllm_fl/models/qwen3_5.py delete mode 100644 vllm_fl/models/qwen3_next.py diff --git a/README.md b/README.md index 61835f50..e4e31094 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ In theory, vllm-plugin-FL can support all models available in vLLM, as long as n ### Setup -1. Install vllm from the official [v0.13.0](https://github.com/vllm-project/vllm/tree/v0.13.0) (optional if the correct version is installed) or from the fork [vllm-FL](https://github.com/flagos-ai/vllm-FL). +1. Install vllm from the official [v0.18.1](https://github.com/vllm-project/vllm/tree/v0.18.1) (optional if the correct version is installed) or from the fork [vllm-FL](https://github.com/flagos-ai/vllm-FL). 2. Install vllm-plugin-FL @@ -65,6 +65,7 @@ In theory, vllm-plugin-FL can support all models available in vLLM, as long as n ```sh git clone https://github.com/flagos-ai/FlagGems cd FlagGems + git checkout v5.0.0 pip install --no-build-isolation . 
# or editble install diff --git a/vllm_fl/__init__.py b/vllm_fl/__init__.py index 2c823db4..56e3ece4 100644 --- a/vllm_fl/__init__.py +++ b/vllm_fl/__init__.py @@ -45,63 +45,12 @@ def register(): def register_model(): - """Register the FL model.""" - from vllm import ModelRegistry - import vllm.model_executor.models.qwen3_next as qwen3_next_module + """Register FL-specific models not yet upstream.""" + # Models now upstream in vLLM v0.18.1 (no longer need plugin registration): + # Qwen3NextForCausalLM, Qwen3_5MoeForConditionalGeneration, + # MiniCPMO, KimiK25ForConditionalGeneration, Qwen3_5MoeConfig - # Register Qwen3.5 MoE config - try: - from vllm.transformers_utils.config import _CONFIG_REGISTRY - from vllm_fl.configs.qwen3_5_moe import Qwen3_5MoeConfig - _CONFIG_REGISTRY["qwen3_5_moe"] = Qwen3_5MoeConfig - except Exception as e: - logger.error(f"Register Qwen3.5 MoE config error: {str(e)}") - - # Register Qwen3Next model - try: - from vllm_fl.models.qwen3_next import Qwen3NextForCausalLM # noqa: F401 - - qwen3_next_module.Qwen3NextForCausalLM = Qwen3NextForCausalLM - logger.warning( - "Qwen3NextForCausalLM has been patched to use vllm_fl.models.qwen3_next, " - "original vLLM implementation is overridden" - ) - - ModelRegistry.register_model( - "Qwen3NextForCausalLM", - "vllm_fl.models.qwen3_next:Qwen3NextForCausalLM" - ) - except Exception as e: - logger.error(f"Register Qwen3Next model error: {str(e)}") - - # Register Qwen3.5 MoE model - try: - ModelRegistry.register_model( - "Qwen3_5MoeForConditionalGeneration", - "vllm_fl.models.qwen3_5:Qwen3_5MoeForConditionalGeneration" - ) - except Exception as e: - logger.error(f"Register Qwen3.5 MoE model error: {str(e)}") - - # Register MiniCPMO model - try: - ModelRegistry.register_model( - "MiniCPMO", - "vllm_fl.models.minicpmo:MiniCPMO" - ) - except Exception as e: - logger.error(f"Register MiniCPMO model error: {str(e)}") - - # Register Kimi-K2.5 model - try: - ModelRegistry.register_model( - 
"KimiK25ForConditionalGeneration", - "vllm_fl.models.kimi_k25:KimiK25ForConditionalGeneration", - ) - except Exception as e: - logger.error(f"Register KimiK25 model error: {str(e)}") - - # Register GLM-5 (GlmMoeDsa) model + # Register GLM-5 (GlmMoeDsa) — config not yet upstream try: from vllm.transformers_utils.config import _CONFIG_REGISTRY from vllm_fl.configs.glm_moe_dsa import GlmMoeDsaConfig @@ -109,10 +58,5 @@ def register_model(): from vllm_fl.patches.glm_moe_dsa import apply_model_patches as glm5_model glm5_model() - - ModelRegistry.register_model( - "GlmMoeDsaForCausalLM", - "vllm_fl.models.glm_moe_dsa:GlmMoeDsaForCausalLM" - ) except Exception as e: logger.error(f"Register GlmMoeDsa model error: {str(e)}") diff --git a/vllm_fl/attention/utils.py b/vllm_fl/attention/utils.py index 642e7800..88982dfc 100644 --- a/vllm_fl/attention/utils.py +++ b/vllm_fl/attention/utils.py @@ -13,8 +13,8 @@ def patch_mm_encoder_attention(): FLASH_ATTN branch to import directly from vllm.vllm_flash_attn with a fallback to flash_attn. """ - import vllm.attention.layers.mm_encoder_attention as mm_mod - from vllm.attention.backends.registry import AttentionBackendEnum + import vllm.model_executor.layers.attention.mm_encoder_attention as mm_mod + from vllm.v1.attention.backends.registry import AttentionBackendEnum def _patched_maybe_get_vit_flash_attn_backend(attn_backend): if attn_backend == AttentionBackendEnum.FLASH_ATTN: diff --git a/vllm_fl/compilation/graph.py b/vllm_fl/compilation/graph.py index e674ec90..74e515a5 100644 --- a/vllm_fl/compilation/graph.py +++ b/vllm_fl/compilation/graph.py @@ -1,14 +1,15 @@ # Copyright (c) 2025 BAAI. All rights reserved. 
-# Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/compilation/cuda_graph.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/compilation/cuda_graph.py # Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +import weakref from collections import Counter from collections.abc import Callable from contextlib import ExitStack -from typing import Any, Optional +from typing import Any, ClassVar from unittest.mock import patch import torch @@ -18,13 +19,18 @@ from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id -from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.forward_context import ( + BatchDescriptor, + get_forward_context, + is_forward_context_available, +) from vllm.logger import init_logger from vllm.platforms import current_platform logger = init_logger(__name__) +# FL-specific: platform-agnostic weak_ref_tensors def weak_ref_tensors(tensor: Any) -> Any: if current_platform.device_type == "cuda": from vllm.utils.torch_utils import weak_ref_tensors @@ -34,6 +40,7 @@ def weak_ref_tensors(tensor: Any) -> Any: return tensor +# FL-specific: platform-agnostic graph class selection class Graph: if current_platform.device_type == "cuda": graph = torch.cuda.CUDAGraph @@ -42,15 +49,21 @@ class Graph: else: raise NotImplementedError("not support graph") + +# Re-export CUDAGraphStat for compatibility +from vllm.compilation.cuda_graph import CUDAGraphStat # noqa: F401, E402 + + @dataclasses.dataclass class GraphEntry: batch_descriptor: BatchDescriptor - graph: Optional[Graph] = None - output: Optional[Any] = None + graph: Any | None = None + output: Any | None = None # for graph debugging, track the input addresses # during capture, and check if 
they are the same during replay - input_addresses: Optional[list[int]] = None + input_addresses: list[int] | None = None + @dataclasses.dataclass class GraphOptions: @@ -60,11 +73,22 @@ class GraphOptions: class GraphWrapper: + """FL-specific graph wrapper that supports multiple device types (CUDA, NPU). + Adapted from upstream CUDAGraphWrapper with platform-agnostic graph capture.""" + + _all_instances: ClassVar[weakref.WeakSet["GraphWrapper"]] = weakref.WeakSet() + + @classmethod + def clear_all_graphs(cls) -> None: + """Clear captured graphs from all GraphWrapper instances.""" + for instance in list(cls._all_instances): + instance.clear_graphs() + def __init__(self, runnable: Callable, vllm_config: VllmConfig, runtime_mode: CUDAGraphMode, - cudagraph_options: Optional[GraphOptions] = None): + cudagraph_options: GraphOptions | None = None): self.runnable = runnable self.vllm_config = vllm_config self.runtime_mode = runtime_mode @@ -72,9 +96,10 @@ def __init__(self, self.first_run_finished = False self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + self._runnable_str = str(runnable) if self.is_debugging_mode else None # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't - # need to initialize a CUDAGraphWrapper. + # need to initialize a GraphWrapper. assert self.runtime_mode != CUDAGraphMode.NONE # TODO: in the future, if we want to use multiple # streams, it might not be safe to share a global pool. @@ -83,25 +108,41 @@ def __init__(self, if cudagraph_options is None: cudagraph_options = GraphOptions() - self.graph_options = cudagraph_options + self.cudagraph_options = cudagraph_options # the entries for different batch descriptors that we need to capture # cudagraphs for. self.concrete_graph_entries: dict[BatchDescriptor, GraphEntry] = {} - def __getattr__(self, key: str): + GraphWrapper._all_instances.add(self) + + def __getattr__(self, key: str) -> Any: # allow accessing the attributes of the runnable. 
if hasattr(self.runnable, key): return getattr(self.runnable, key) - raise AttributeError( - f"Attribute {key} not exists in the runnable of " - f"cudagraph wrapper: {self.runnable}" - ) + if self.is_debugging_mode: + raise AttributeError( + f"Attribute {key} not exists in the runnable of " + f"cudagraph wrapper: {self._runnable_str}" + ) + raise AttributeError def unwrap(self) -> Callable: # in case we need to access the original runnable. return self.runnable + @property + def cudagraph_wrapper(self) -> "GraphWrapper": + return self + + def clear_graphs(self) -> None: + self.concrete_graph_entries.clear() + def __call__(self, *args, **kwargs): + if not is_forward_context_available(): + # No forward context means we are outside the normal + # inference path (e.g. a vision encoder forward pass). + return self.runnable(*args, **kwargs) + forward_context = get_forward_context() batch_descriptor = forward_context.batch_descriptor graph_runtime_mode = forward_context.cudagraph_runtime_mode @@ -110,14 +151,9 @@ def __call__(self, *args, **kwargs): graph_runtime_mode == CUDAGraphMode.NONE or graph_runtime_mode != self.runtime_mode ): - # CUDAGraphMode.NONE could mean the profile run, a warmup run, or - # running without cudagraphs. - # We do not trigger capture/replay if the runtime mode is not - # matches. This enables properly dispatching to the correct - # CUDAGraphWrapper when nesting multiple instances with different - # runtime modes. 
return self.runnable(*args, **kwargs) + assert batch_descriptor is not None if batch_descriptor not in self.concrete_graph_entries: # create a new entry for this batch descriptor self.concrete_graph_entries[batch_descriptor] = GraphEntry( @@ -127,11 +163,7 @@ def __call__(self, *args, **kwargs): entry = self.concrete_graph_entries[batch_descriptor] if entry.graph is None: - if self.graph_options.debug_log_enable: - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every - # shape. E.g. we only log it for the first subgraph in - # piecewise mode. + if self.cudagraph_options.debug_log_enable: logger.debug( "Capturing a cudagraph on (%s,%s)", self.runtime_mode.name, @@ -147,32 +179,40 @@ def __call__(self, *args, **kwargs): graph = Graph.graph() with ExitStack() as stack: - if self.graph_options.gc_disable: - # during every model forward for piecewise graph - # mode, we will capture many pieces of graphs - # (roughly one per layer). running gc again and again - # across layers will make the graph capture very slow. - # therefore, we only run gc for the first graph, - # and disable gc for the rest of the graphs. + if self.cudagraph_options.gc_disable: stack.enter_context(patch("gc.collect", lambda: None)) + # FL-specific: patch our platform's empty_cache stack.enter_context( - patch("vllm_fl.platform.PlatformFL.empty_cache", lambda: None) + patch("vllm_fl.platform.PlatformFL.empty_cache", + lambda: None) ) - set_graph_pool_id(self.graph_pool) - - # mind-exploding: carefully manage the reference and memory. - with current_platform.torch_device_fn.graph(graph, pool=self.graph_pool): + if self.graph_pool is not None: + set_graph_pool_id(self.graph_pool) + else: + set_graph_pool_id(current_platform.graph_pool_handle()) + + # Sync offloader's copy stream before capture if available. 
+ try: + from vllm.model_executor.offloader.base import get_offloader + get_offloader().sync_prev_onload() + except (ImportError, RuntimeError): + pass + + # FL-specific: use platform-agnostic graph capture + with current_platform.torch_device_fn.graph( + graph, pool=self.graph_pool + ): # `output` is managed by pytorch's cudagraph pool output = self.runnable(*args, **kwargs) - if self.graph_options.weak_ref_output: - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. It is only safe to do this for - # the last graph in piecewise cuadgraph mode, because - # the output of the last graph will not be used by - # any other cuda graph. - output = weak_ref_tensors(output) + # Join offloader's copy stream after forward if available + try: + from vllm.model_executor.offloader.base import get_offloader + get_offloader().join_after_forward() + except (ImportError, RuntimeError): + pass + if self.cudagraph_options.weak_ref_output: + output = weak_ref_tensors(output) entry.output = weak_ref_tensors(output) entry.graph = graph @@ -195,6 +235,13 @@ def __call__(self, *args, **kwargs): f"got {new_input_addresses}" ) + # Sync offloader before replay if available + try: + from vllm.model_executor.offloader.base import get_offloader + get_offloader().sync_prev_onload() + except (ImportError, RuntimeError): + pass + current_platform.torch_device_fn.synchronize() entry.graph.replay() return entry.output diff --git a/vllm_fl/configs/qwen3_5_moe.py b/vllm_fl/configs/qwen3_5_moe.py deleted file mode 100644 index 48d68865..00000000 --- a/vllm_fl/configs/qwen3_5_moe.py +++ /dev/null @@ -1,185 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. -# All rights reserved. 
-"""Qwen3.5-MoE model configuration for vLLM plugin.""" - -from transformers.configuration_utils import PretrainedConfig - - -def _layer_type_validation(layer_types, num_hidden_layers): - if layer_types is not None and num_hidden_layers is not None: - if len(layer_types) != num_hidden_layers: - raise ValueError( - f"Length of layer_types ({len(layer_types)}) must match " - f"num_hidden_layers ({num_hidden_layers})" - ) - - -class Qwen3_5MoeTextConfig(PretrainedConfig): - model_type = "qwen3_5_moe_text" - keys_to_ignore_at_inference = ["past_key_values"] - base_config_key = "text_config" - - def __init__( - self, - vocab_size=248320, - hidden_size=2048, - num_hidden_layers=40, - num_attention_heads=16, - num_key_value_heads=2, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_parameters=None, - attention_bias=False, - attention_dropout=0.0, - head_dim=256, - linear_conv_kernel_dim=4, - linear_key_head_dim=128, - linear_value_head_dim=128, - linear_num_key_heads=16, - linear_num_value_heads=32, - moe_intermediate_size=512, - shared_expert_intermediate_size=512, - num_experts_per_tok=8, - num_experts=256, - norm_topk_prob=True, - output_router_logits=False, - router_aux_loss_coef=0.001, - layer_types=None, - full_attention_interval=4, - attn_output_gate=True, - pad_token_id=None, - bos_token_id=None, - eos_token_id=None, - **kwargs, - ): - kwargs.pop("ignore_keys_at_rope_validation", None) - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.attention_bias = attention_bias - self.attention_dropout = 
attention_dropout - self.head_dim = head_dim - self.rope_parameters = rope_parameters - self.attn_output_gate = attn_output_gate - - self.layer_types = layer_types - if self.layer_types is None: - self.layer_types = [ - "linear_attention" - if bool((i + 1) % full_attention_interval) - else "full_attention" - for i in range(self.num_hidden_layers) - ] - _layer_type_validation(self.layer_types, self.num_hidden_layers) - - self.linear_conv_kernel_dim = linear_conv_kernel_dim - self.linear_key_head_dim = linear_key_head_dim - self.linear_value_head_dim = linear_value_head_dim - self.linear_num_key_heads = linear_num_key_heads - self.linear_num_value_heads = linear_num_value_heads - self.moe_intermediate_size = moe_intermediate_size - self.shared_expert_intermediate_size = shared_expert_intermediate_size - self.num_experts_per_tok = num_experts_per_tok - self.num_experts = num_experts - self.norm_topk_prob = norm_topk_prob - self.output_router_logits = output_router_logits - self.router_aux_loss_coef = router_aux_loss_coef - self.full_attention_interval = full_attention_interval - - # partial_rotary_factor is needed by rope - kwargs.setdefault("partial_rotary_factor", 0.25) - - super().__init__(**kwargs) - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.tie_word_embeddings = tie_word_embeddings - - -class Qwen3_5MoeVisionConfig(PretrainedConfig): - model_type = "qwen3_5_moe" - base_config_key = "vision_config" - - def __init__( - self, - depth=27, - hidden_size=1152, - hidden_act="gelu_pytorch_tanh", - intermediate_size=4304, - num_heads=16, - in_channels=3, - patch_size=16, - spatial_merge_size=2, - temporal_patch_size=2, - out_hidden_size=3584, - num_position_embeddings=2304, - initializer_range=0.02, - deepstack_visual_indexes=None, - **kwargs, - ): - super().__init__(**kwargs) - self.depth = depth - self.hidden_size = hidden_size - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size 
- self.num_heads = num_heads - self.in_channels = in_channels - self.patch_size = patch_size - self.spatial_merge_size = spatial_merge_size - self.temporal_patch_size = temporal_patch_size - self.out_hidden_size = out_hidden_size - self.num_position_embeddings = num_position_embeddings - self.initializer_range = initializer_range - self.deepstack_visual_indexes = deepstack_visual_indexes or [] - - -class Qwen3_5MoeConfig(PretrainedConfig): - model_type = "qwen3_5_moe" - sub_configs = { - "vision_config": Qwen3_5MoeVisionConfig, - "text_config": Qwen3_5MoeTextConfig, - } - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - text_config=None, - vision_config=None, - image_token_id=248056, - video_token_id=248057, - vision_start_token_id=248053, - vision_end_token_id=248054, - tie_word_embeddings=False, - **kwargs, - ): - if isinstance(vision_config, dict): - self.vision_config = self.sub_configs["vision_config"](**vision_config) - elif vision_config is None: - self.vision_config = self.sub_configs["vision_config"]() - - if isinstance(text_config, dict): - self.text_config = self.sub_configs["text_config"](**text_config) - elif text_config is None: - self.text_config = self.sub_configs["text_config"]() - - self.image_token_id = image_token_id - self.video_token_id = video_token_id - self.vision_start_token_id = vision_start_token_id - self.vision_end_token_id = vision_end_token_id - super().__init__(**kwargs) - self.tie_word_embeddings = tie_word_embeddings - - -__all__ = ["Qwen3_5MoeConfig", "Qwen3_5MoeTextConfig", "Qwen3_5MoeVisionConfig"] diff --git a/vllm_fl/dispatch/README.md b/vllm_fl/dispatch/README.md index 51fb5a84..6a382901 100644 --- a/vllm_fl/dispatch/README.md +++ b/vllm_fl/dispatch/README.md @@ -480,7 +480,7 @@ export VLLM_FL_FLAGOS_BLACKLIST="custom_op1,custom_op2" - **Environment variables override, not merge**: Setting an env var replaces the config value entirely - **`VLLM_FL_PREFER` sets preference, not exclusivity**: It 
defines the selection order but will fall back to other backends if the preferred one is unavailable -- **`VLLM_FL_STRICT=1`**: Enables strict mode — fails immediately if the primary implementation fails, no fallback is attempted +- **`VLLM_FL_STRICT=1`**: Enables strict mode — fails immediately if the primary implementation fails at runtime, no fallback is attempted #### Backend Priority Values @@ -536,7 +536,7 @@ Currently supported operators: ## Fallback Mechanism -When `VLLM_FL_STRICT=0` (default), if the primary implementation fails, the system automatically tries other available implementations: +When `VLLM_FL_STRICT=0` (the default), if the primary implementation fails, the system automatically tries other available implementations: ``` Op 'rms_norm' using 'default.flagos' (kind=flagos, vendor=None) diff --git a/vllm_fl/dispatch/__init__.py b/vllm_fl/dispatch/__init__.py index a2f10206..1eaea799 100644 --- a/vllm_fl/dispatch/__init__.py +++ b/vllm_fl/dispatch/__init__.py @@ -21,7 +21,7 @@ Environment Variables: VLLM_FL_CONFIG: Path to YAML configuration file (highest priority, overrides env vars) VLLM_FL_PREFER: Preferred backend ("flagos", "vendor", "reference") - VLLM_FL_STRICT: Strict mode: "1" = fail immediately on error (no fallback), "0" = try fallback (default) + VLLM_FL_STRICT: Enable strict mode ("1" or "0") VLLM_FL_DENY_VENDORS: Comma-separated list of denied vendors VLLM_FL_ALLOW_VENDORS: Comma-separated list of allowed vendors VLLM_FL_PER_OP: Per-operator order (format: op1=a|b|c;op2=x|y) diff --git a/vllm_fl/dispatch/backends/__init__.py b/vllm_fl/dispatch/backends/__init__.py index 078452c4..ae8c3144 100644 --- a/vllm_fl/dispatch/backends/__init__.py +++ b/vllm_fl/dispatch/backends/__init__.py @@ -2,10 +2,6 @@ """ Backend implementations for vllm-plugin-FL dispatch. - -Vendor backends are dynamically discovered and loaded by builtin_ops.py -based on the current platform. 
This package does not eagerly import vendor -backends to avoid loading unnecessary dependencies at startup. """ from .base import Backend @@ -13,3 +9,27 @@ from .reference import ReferenceBackend __all__ = ["Backend", "FlagGemsBackend", "ReferenceBackend"] + +# Try to import vendor backends +try: + from .vendor.ascend import AscendBackend + + __all__.append("AscendBackend") +except ImportError: + AscendBackend = None + +# Add more vendor backends here as they become available +try: + from .vendor.cuda import CudaBackend + + __all__.append("CudaBackend") +except ImportError: + CudaBackend = None + +# Import MACA backend +try: + from .vendor.maca import MacaBackend + + __all__.append("MacaBackend") +except ImportError: + pass diff --git a/vllm_fl/dispatch/backends/flaggems/flaggems.py b/vllm_fl/dispatch/backends/flaggems/flaggems.py index e12afaf7..a7b166ea 100644 --- a/vllm_fl/dispatch/backends/flaggems/flaggems.py +++ b/vllm_fl/dispatch/backends/flaggems/flaggems.py @@ -144,7 +144,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) -> Returns: Fully qualified class path string """ - from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.v1.attention.backends.registry import AttentionBackendEnum # TritonAttentionBackend requires CUDA, check if available if not torch.cuda.is_available(): diff --git a/vllm_fl/dispatch/backends/flaggems/impl/attention.py b/vllm_fl/dispatch/backends/flaggems/impl/attention.py index 44c751c4..d7e14dea 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/attention.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/attention.py @@ -11,27 +11,29 @@ import torch from vllm import envs -from vllm.attention.backends.abstract import ( +from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionType, MultipleOf, is_quantized_kv_cache, ) -from vllm.attention.layer import Attention -from vllm.attention.ops.common import cp_lse_ag_out_rs -from 
vllm.attention.ops.merge_attn_states import merge_attn_states +from vllm.model_executor.layers.attention.attention import Attention +from vllm.v1.attention.ops.common import cp_lse_ag_out_rs +from vllm.v1.attention.ops.merge_attn_states import merge_attn_states from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger -from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant +from vllm.model_executor.layers.batch_invariant import _batch_invariant_MODE as _bi_mode from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import ( +from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, +) +from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, get_dcp_local_seq_lens, get_kv_cache_layout, @@ -457,7 +459,7 @@ def __init__( self.attn_type = attn_type self.vllm_flash_attn_version = 3 # 2 #get_flash_attn_version() # Cache the batch invariant result for use in forward passes - self.batch_invariant_enabled = vllm_is_batch_invariant() + self.batch_invariant_enabled = _bi_mode if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( diff --git a/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py b/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py index 5aa2860d..f6dc37e3 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py +++ b/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py @@ -1,4 +1,4 @@ -from vllm.attention.backends.registry import ( +from vllm.v1.attention.backends.registry import ( AttentionBackendEnum, register_backend, ) diff --git a/vllm_fl/dispatch/backends/flaggems/impl/mla.py b/vllm_fl/dispatch/backends/flaggems/impl/mla.py index 866dcd51..1789498d 100644 --- a/vllm_fl/dispatch/backends/flaggems/impl/mla.py +++ 
b/vllm_fl/dispatch/backends/flaggems/impl/mla.py @@ -8,7 +8,7 @@ import torch -from vllm.attention.backends.abstract import ( +from vllm.v1.attention.backend import ( AttentionLayer, AttentionType, is_quantized_kv_cache, @@ -17,7 +17,7 @@ # from vllm.attention.ops.triton_decode_attention import decode_attention_fwd # from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import ( +from vllm.model_executor.layers.attention.mla_attention import ( MLACommonBackend, MLACommonImpl, MLACommonMetadata, diff --git a/vllm_fl/dispatch/backends/reference/reference.py b/vllm_fl/dispatch/backends/reference/reference.py index 7f20276c..9c10a2e5 100644 --- a/vllm_fl/dispatch/backends/reference/reference.py +++ b/vllm_fl/dispatch/backends/reference/reference.py @@ -150,7 +150,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) -> Fully qualified class path string (vLLM native backend) """ # Return vLLM's native flash attention backend as reference - from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.v1.attention.backends.registry import AttentionBackendEnum if use_mla: # vLLM native MLA backend diff --git a/vllm_fl/dispatch/backends/vendor/__init__.py b/vllm_fl/dispatch/backends/vendor/__init__.py index 6818887b..bdeacfbe 100644 --- a/vllm_fl/dispatch/backends/vendor/__init__.py +++ b/vllm_fl/dispatch/backends/vendor/__init__.py @@ -8,10 +8,6 @@ Available vendor backends: - ascend: Huawei Ascend NPU backend -This package intentionally avoids eager imports of vendor subpackages. -Importing a specific backend such as ``vllm_fl.dispatch.backends.vendor.ascend`` -should not pull in other vendor branches. - To add a new vendor backend: 1. Create a subdirectory: vendor// 2. 
Implement the backend class inheriting from Backend @@ -22,3 +18,34 @@ """ __all__ = [] + +# Import Ascend backend +try: + from .ascend import AscendBackend + + __all__.append("AscendBackend") +except ImportError: + pass + +# Import CUDA backend +try: + from .cuda import CudaBackend + + __all__.append("CudaBackend") +except ImportError: + pass + +# Import MACA backend +try: + from .maca import MacaBackend + + __all__.append("MacaBackend") +except ImportError: + pass + +# Add more vendor backends here as they become available: +# try: +# from .rocm import RocmBackend +# __all__.append("RocmBackend") +# except ImportError: +# pass diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py index 4dcdd7d7..494cf508 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py @@ -28,7 +28,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import ( +from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionLayer, @@ -36,7 +36,8 @@ ) from vllm.config import VllmConfig, get_current_vllm_config from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import AttentionCGSupport, CommonAttentionMetadata +from vllm.v1.attention.backend import AttentionCGSupport +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm_fl.dispatch.backends.vendor.ascend.impl.attention_mask import ( AttentionMaskBuilder, diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py index 2aad980a..a6fa90f2 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py @@ -15,7 +15,7 @@ import torch.nn.functional as F import triton import triton.language as tl -from vllm.attention.backends.utils import 
PAD_SLOT_ID +from vllm.v1.attention.backends.utils import PAD_SLOT_ID def causal_conv1d_ref( diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py index 4e8c0758..78dfcd24 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py @@ -21,7 +21,7 @@ import torch import torch.nn.functional as F import torch_npu -from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention +from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention from vllm.config import MultiModalConfig MIN_PAD_SIZE = 64 # min_size to pad weight diff --git a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py index 8260f067..e504559d 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py @@ -174,7 +174,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) -> Returns: Fully qualified class path string """ - from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.v1.attention.backends.registry import AttentionBackendEnum if use_mla: if use_sparse: diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py index 20f67053..049deb3d 100644 --- a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py +++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py @@ -22,11 +22,9 @@ def silu_and_mul_cuda(obj, x: torch.Tensor) -> torch.Tensor: Returns: Output tensor of shape [..., d] """ - from vllm._custom_ops import silu_and_mul as vllm_silu_and_mul - d = x.shape[-1] // 2 out = torch.empty(*x.shape[:-1], d, dtype=x.dtype, device=x.device) - vllm_silu_and_mul(out, x) + torch.ops._C.silu_and_mul(out, x) return out diff --git 
a/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py b/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py index dacf6753..e54f7987 100644 --- a/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py +++ b/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py @@ -154,7 +154,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) -> Returns: Fully qualified class path string """ - from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.v1.attention.backends.registry import AttentionBackendEnum if use_mla: if use_sparse: diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py index 35cdd6a9..7d10e52f 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py @@ -9,15 +9,15 @@ import numpy as np import torch -from vllm.attention.backends.abstract import ( +from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, AttentionType, MultipleOf, is_quantized_kv_cache, ) -from vllm.attention.layer import Attention -from vllm.attention.ops.common import cp_lse_ag_out_rs +from vllm.model_executor.layers.attention.attention import Attention +from vllm.v1.attention.ops.common import cp_lse_ag_out_rs # -------------------------------------------------------------- # Note: use Maca's merge_attn_states to get cuda kernel invoked @@ -42,13 +42,15 @@ from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( - vllm_is_batch_invariant, + _batch_invariant_MODE as _bi_mode, ) from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv -from vllm.v1.attention.backends.utils import ( +from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, +) +from vllm.v1.attention.backends.utils 
import ( CommonAttentionMetadata, get_dcp_local_seq_lens, get_kv_cache_layout, @@ -56,7 +58,7 @@ reshape_attn_output_for_spec_decode, # used for prefill decode split with mtp reshape_query_for_spec_decode, # used for prefill decode split with mtp ) -from vllm.attention.backends.registry import AttentionBackendEnum, register_backend +from vllm.v1.attention.backends.registry import AttentionBackendEnum, register_backend from vllm.v1.kv_cache_interface import AttentionSpec # -------------------------------------------------------------- @@ -450,7 +452,7 @@ def build( prefill_block_table_tensor = None # \------------------------- Metax Modification -------------------------/ - if vllm_is_batch_invariant(): + if _bi_mode: max_num_splits = 1 def schedule( @@ -645,7 +647,7 @@ def __init__( self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() # Cache the batch invariant result for use in forward passes - self.batch_invariant_enabled = vllm_is_batch_invariant() + self.batch_invariant_enabled = _bi_mode if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): raise NotImplementedError( diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py index 3d06c6c3..94ec2e51 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py @@ -198,13 +198,13 @@ from vllm import _custom_ops as ops from vllm import envs -from vllm.attention.backends.abstract import ( +from vllm.v1.attention.backend import ( AttentionBackend, AttentionLayer, MLAAttentionImpl, ) -from vllm.attention.backends.utils import get_mla_dims -from vllm.attention.ops.common import cp_lse_ag_out_rs +from vllm.model_executor.layers.attention.mla_attention import get_mla_dims +from vllm.v1.attention.ops.common import cp_lse_ag_out_rs # 
-------------------------------------------------------------- # Note: use Maca's merge_attn_states to get cuda kernel invoked @@ -214,7 +214,7 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import get_dcp_group from vllm.model_executor.layers.batch_invariant import ( - vllm_is_batch_invariant, + _batch_invariant_MODE as _bi_mode, ) from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -272,7 +272,9 @@ class QueryLenSupport(Enum): flashinfer_available = False -from vllm.v1.attention.backends.mla.common import logger +from vllm.logger import init_logger + +logger = init_logger(__name__) CUDNN_WORKSPACE_SIZE = 12800 @@ -1276,7 +1278,7 @@ def _flash_attn_varlen_diff_headdims( # ROCm leverages the upstream flash_attn, which takes a parameter # called "return_attn_probs" instead of return_softmax_lse kwargs["return_attn_probs"] = return_softmax_lse - if vllm_is_batch_invariant(): + if _bi_mode: kwargs["num_splits"] = 1 attn_out = self.flash_attn_varlen_func( diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py index 823ff4ba..6f10c44e 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py @@ -7,7 +7,7 @@ import torch -from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf +from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf from ..ops.flashmla import ( flash_mla_with_kvcache, get_mla_metadata, @@ -17,7 +17,7 @@ from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( - vllm_is_batch_invariant, + _batch_invariant_MODE as _bi_mode, ) from vllm.platforms.interface import DeviceCapability from .common import ( @@ -28,13 +28,13 @@ MLACommonMetadataBuilder, 
QueryLenSupport, ) +from vllm.v1.attention.backend import AttentionCGSupport from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, reshape_attn_output_for_spec_decode, reshape_query_for_spec_decode, ) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.attention.backends.registry import AttentionBackendEnum, register_backend +from vllm.v1.attention.backends.registry import AttentionBackendEnum, register_backend logger = init_logger(__name__) @@ -271,7 +271,7 @@ def _forward_decode( tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata num_splits = attn_metadata.decode.num_splits - if vllm_is_batch_invariant(): + if _bi_mode: device = q.device dtype = torch.int32 diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py index 446772c2..4395bb25 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py @@ -43,7 +43,7 @@ def supported_headdim(o: torch.Tensor) -> bool: output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse ) else: - from vllm.attention.ops.triton_merge_attn_states import merge_attn_states + from vllm.v1.attention.ops.merge_attn_states import merge_attn_states return merge_attn_states( output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py index 56d2c787..b21344c5 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py @@ -2,7 +2,9 @@ # 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. 
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.attention.utils.fa_utils import logger +import logging + +logger = logging.getLogger(__name__) from vllm.platforms import current_platform diff --git a/vllm_fl/dispatch/backends/vendor/metax/metax.py b/vllm_fl/dispatch/backends/vendor/metax/metax.py index 66de575f..d9dee3c9 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/metax.py +++ b/vllm_fl/dispatch/backends/vendor/metax/metax.py @@ -14,7 +14,7 @@ from vllm_fl.dispatch.backends.base import Backend -from vllm.attention.backends.registry import AttentionBackendEnum, register_backend +from vllm.v1.attention.backends.registry import AttentionBackendEnum, register_backend # Register attention backends for MACA @@ -147,7 +147,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) -> Returns: Fully qualified class path string """ - from vllm.attention.backends.registry import AttentionBackendEnum + from vllm.v1.attention.backends.registry import AttentionBackendEnum # register before selection register_attention_backends() diff --git a/vllm_fl/dispatch/builtin_ops.py b/vllm_fl/dispatch/builtin_ops.py index ec207a5b..426c6019 100644 --- a/vllm_fl/dispatch/builtin_ops.py +++ b/vllm_fl/dispatch/builtin_ops.py @@ -12,7 +12,6 @@ import importlib import os -from .config import get_vendor_device_map from .registry import OpRegistry from .logger_manager import get_logger @@ -22,46 +21,6 @@ _VENDOR_BACKENDS_DIR = os.path.join(os.path.dirname(__file__), "backends", "vendor") -def _find_vendor_backend_dir( - vendor_name: str, - available_vendor_dirs: set[str], -) -> str | None: - """Return the backend directory for *vendor_name*, or None if not found. - - Resolves *vendor_name* against get_vendor_device_map() and picks the first - candidate directory that exists in *available_vendor_dirs*. The "maca" alias - is treated as "metax" for MetaX runtime compatibility. - """ - # Keep compatibility with MetaX runtime naming. 
- if vendor_name == "maca": - vendor_name = "metax" - vendor_map = get_vendor_device_map() - if vendor_name not in vendor_map: - return None - value = vendor_map[vendor_name] - device_type = value.get("device_type") - device_name = value.get("device_name") - return next( - (c for c in (vendor_name, device_name, device_type) if c in available_vendor_dirs), - None, - ) - - -def _get_current_vendor_backend_dirs(available_vendor_dirs: set[str]) -> set[str]: - """Detect current platform vendor name and return its backend directory.""" - try: - from vllm.platforms import current_platform - - vendor_name = getattr(current_platform, "vendor_name", None) - if not isinstance(vendor_name, str) or not vendor_name: - return None - return _find_vendor_backend_dir(vendor_name, available_vendor_dirs) - except Exception as exc: - raise RuntimeError( - "Failed to detect current vendor backend from current_platform." - ) from exc - - def _register_vendor_backends(registry: OpRegistry) -> None: """ Auto-discover and register all vendor backends. 
@@ -76,33 +35,11 @@ def _register_vendor_backends(registry: OpRegistry) -> None: logger.debug(f"Vendor backends directory not found: {_VENDOR_BACKENDS_DIR}") return - available_vendor_dirs = { - vendor_name - for vendor_name in os.listdir(_VENDOR_BACKENDS_DIR) - if os.path.isdir(os.path.join(_VENDOR_BACKENDS_DIR, vendor_name)) - and not vendor_name.startswith("_") - } - - current_vendor_dir = _get_current_vendor_backend_dirs(available_vendor_dirs) - if not current_vendor_dir: - logger.warning( - "Unable to detect current vendor backend; skipping vendor backend registration" - ) - return - - logger.info( - "Registering vendor backends for current platform: %s", - current_vendor_dir, - ) - - for vendor_name in available_vendor_dirs: + for vendor_name in os.listdir(_VENDOR_BACKENDS_DIR): vendor_path = os.path.join(_VENDOR_BACKENDS_DIR, vendor_name) - if vendor_name != current_vendor_dir: - logger.debug( - "Skipping vendor backend '%s' for current platform", - vendor_name, - ) + # Skip non-directories and special files + if not os.path.isdir(vendor_path) or vendor_name.startswith("_"): continue # Skip if no register_ops.py exists diff --git a/vllm_fl/dispatch/config/__init__.py b/vllm_fl/dispatch/config/__init__.py index bafd6d4e..c60b1ce8 100644 --- a/vllm_fl/dispatch/config/__init__.py +++ b/vllm_fl/dispatch/config/__init__.py @@ -14,7 +14,6 @@ get_oot_blacklist, get_per_op_order, get_platform_name, - get_vendor_device_map, load_platform_config, ) @@ -26,5 +25,4 @@ 'get_flagos_blacklist', 'get_oot_blacklist', 'get_effective_config', - 'get_vendor_device_map', ] diff --git a/vllm_fl/dispatch/config/nvidia.yaml b/vllm_fl/dispatch/config/nvidia.yaml index 3c8676ff..0b06a8a3 100644 --- a/vllm_fl/dispatch/config/nvidia.yaml +++ b/vllm_fl/dispatch/config/nvidia.yaml @@ -10,8 +10,8 @@ prefer: flagos strict: false # Vendor Whitelist (Optional, allows all if not set) -# allow_vendors: -# - cuda +allow_vendors: + - cuda # Vendor Blacklist (Optional) # deny_vendors: diff --git 
a/vllm_fl/dispatch/config/utils.py b/vllm_fl/dispatch/config/utils.py index 2a6fd181..411cf943 100644 --- a/vllm_fl/dispatch/config/utils.py +++ b/vllm_fl/dispatch/config/utils.py @@ -11,7 +11,7 @@ 1. VLLM_FL_CONFIG: User-specified config file path (complete override) 2. Environment variables: Override specific items from platform config - VLLM_FL_PREFER: Backend preference (flagos, vendor, reference) - - VLLM_FL_STRICT: Strict mode: 1 = fail immediately on error (no fallback), 0 = try fallback (default) + - VLLM_FL_STRICT: Strict mode (1 or 0) - VLLM_FL_PER_OP: Per-operator backend order - VLLM_FL_FLAGOS_BLACKLIST: FlagOS operator blacklist - VLLM_FL_OOT_BLACKLIST: OOT operator blacklist @@ -35,7 +35,6 @@ from typing import Any, Optional import yaml -from vllm_fl.utils import VENDOR_DEVICE_MAP # Directory containing config files (config/) _CONFIG_DIR = Path(__file__).parent @@ -208,32 +207,3 @@ def get_effective_config() -> dict[str, Any]: # Return empty config return {} - -def get_vendor_device_map() -> dict[str, dict[str, str]]: - """Load vendor mapping from Python config module. - - Returns: - Mapping where key is vendor_name and value is - {"device_type": ..., "device_name": ...}. 
- """ - if not isinstance(VENDOR_DEVICE_MAP, dict): - return {} - - result: dict[str, dict[str, str]] = {} - for vendor_name, value in VENDOR_DEVICE_MAP.items(): - if not isinstance(vendor_name, str) or not isinstance(value, dict): - continue - - device_type = value.get("device_type") - device_name = value.get("device_name") - if not isinstance(device_type, str) or not device_type: - continue - if not isinstance(device_name, str) or not device_name: - continue - - result[vendor_name] = { - "device_type": device_type, - "device_name": device_name, - } - - return result \ No newline at end of file diff --git a/vllm_fl/dispatch/logger_manager.py b/vllm_fl/dispatch/logger_manager.py index 37867e8e..57e4a4f0 100644 --- a/vllm_fl/dispatch/logger_manager.py +++ b/vllm_fl/dispatch/logger_manager.py @@ -38,7 +38,7 @@ def get_logger(name: str = "vllm_fl.dispatch") -> logging.Logger: if not logger.handlers: handler = logging.StreamHandler() formatter = logging.Formatter( - "[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s", + "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) handler.setFormatter(formatter) diff --git a/vllm_fl/dispatch/manager.py b/vllm_fl/dispatch/manager.py index c886940e..cee87132 100644 --- a/vllm_fl/dispatch/manager.py +++ b/vllm_fl/dispatch/manager.py @@ -483,11 +483,8 @@ def call(self, op_name: str, *args, **kwargs): """ Resolve and call an operator implementation with optional fallback support. - Behavior is controlled by the active policy's strict flag (VLLM_FL_STRICT): - - VLLM_FL_STRICT=0 (default): fallback mode — if the primary implementation - fails, the system automatically tries the next available implementation. - - VLLM_FL_STRICT=1: strict mode — fail immediately on the first error, - no fallback is attempted. + When VLLM_FL_STRICT=1, this method will try alternative implementations + if the primary one fails. Otherwise, it behaves like the original implementation. 
Logs on first call or when the implementation changes (e.g., backend switch). @@ -499,10 +496,10 @@ def call(self, op_name: str, *args, **kwargs): Result from the implementation Raises: - RuntimeError: If all implementations fail (fallback mode) or - if the primary implementation fails (strict mode) + RuntimeError: If all implementations fail (when fallback enabled) or + if the primary implementation fails (when fallback disabled) """ - enable_fallback = not get_policy().strict + enable_fallback = os.getenv("VLLM_FL_STRICT", "1") != "0" if not enable_fallback: # Original behavior: use cached resolve() and fast-fail diff --git a/vllm_fl/dispatch/policy.py b/vllm_fl/dispatch/policy.py index e7f06139..f1523e80 100644 --- a/vllm_fl/dispatch/policy.py +++ b/vllm_fl/dispatch/policy.py @@ -386,7 +386,7 @@ def _policy_from_env(self) -> SelectionPolicy: Environment variables: - VLLM_FL_CONFIG: Path to YAML configuration file (complete override) - VLLM_FL_PREFER: Preference (flagos, vendor, reference) - - VLLM_FL_STRICT: Strict mode: 1 = fail immediately on error (no fallback), 0 = try fallback (default) + - VLLM_FL_STRICT: Enable strict mode (1 or 0) - VLLM_FL_DENY_VENDORS: Comma-separated list of denied vendors - VLLM_FL_ALLOW_VENDORS: Comma-separated list of allowed vendors - VLLM_FL_PER_OP: Per-op order (format: op1=a|b|c;op2=x|y) @@ -409,7 +409,7 @@ def _policy_from_env(self) -> SelectionPolicy: # Priority 2: Environment variables override platform config # Get values from environment variables env_prefer_str = os.environ.get("VLLM_FL_PREFER", "").strip().lower() - env_strict_str = os.environ.get("VLLM_FL_STRICT", "0").strip() + env_strict_str = os.environ.get("VLLM_FL_STRICT", "").strip() env_deny_str = os.environ.get("VLLM_FL_DENY_VENDORS", "").strip() env_allow_str = os.environ.get("VLLM_FL_ALLOW_VENDORS", "").strip() env_per_op_str = os.environ.get("VLLM_FL_PER_OP", "").strip() @@ -423,14 +423,7 @@ def _policy_from_env(self) -> SelectionPolicy: prefer_str = 
PREFER_DEFAULT if env_strict_str: - if env_strict_str not in ("0", "1"): - logger.warning( - f"Invalid VLLM_FL_STRICT value '{env_strict_str}', " - f"expected '0' or '1'. Defaulting to '0' (fallback mode)." - ) - strict = False - else: - strict = env_strict_str == "1" + strict = env_strict_str == "1" elif platform_policy: strict = platform_policy.strict else: diff --git a/vllm_fl/models/glm_moe_dsa.py b/vllm_fl/models/glm_moe_dsa.py deleted file mode 100644 index 4b88c4d0..00000000 --- a/vllm_fl/models/glm_moe_dsa.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Inference-only GLM-5 (GlmMoeDsa) model. - -GLM-5 uses a DeepSeek V2/V3-style architecture with MLA (Multi-head Latent -Attention) and Mixture of Experts. The HF model type is ``glm_moe_dsa`` and -the architecture class is ``GlmMoeDsaForCausalLM``. - -This thin wrapper inherits from vLLM's ``DeepseekV2ForCausalLM`` which already -handles MLA, MoE, the DSA Indexer, and MTP speculative decoding layers. -""" - -from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM - - -class GlmMoeDsaForCausalLM(DeepseekV2ForCausalLM): - """GLM-5 model for causal language modelling.""" - pass diff --git a/vllm_fl/models/kimi_k25.py b/vllm_fl/models/kimi_k25.py deleted file mode 100644 index 894839e5..00000000 --- a/vllm_fl/models/kimi_k25.py +++ /dev/null @@ -1,245 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Minimal Kimi-K2.5 model support for text-only inference. - -This is a simplified implementation that wraps DeepseekV3 for text-only -benchmarking. Vision components are not included. 
-""" - -import copy -from collections.abc import Iterable - -import torch -from torch import nn - -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import SharedFusedMoE -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - maybe_remap_kv_scale_name, -) -from vllm.model_executor.models.deepseek_v2 import ( - DeepseekV2Model, - get_spec_layer_idx_from_weight_name, -) -from vllm.model_executor.models.interfaces import SupportsPP -from vllm.model_executor.models.utils import ( - PPMissingLayer, - is_pp_missing_parameter, - maybe_prefix, -) -from vllm.sequence import IntermediateTensors - -logger = init_logger(__name__) - - -class KimiK25ForConditionalGeneration(nn.Module, SupportsPP): - """Kimi-K2.5 model for text-only conditional generation. - - This is a minimal implementation that uses DeepseekV2Model as the - language backbone. Vision components are not included for simplicity. 
- """ - - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: - super().__init__() - model_config = vllm_config.model_config - config = model_config.hf_config - self.config = config - quant_config = vllm_config.quant_config - - # Extract text config from KimiK25Config - # The text_config is a DeepseekV3Config - text_config = getattr(config, "text_config", config) - self.hidden_size = text_config.hidden_size - - # Create a modified vllm_config with text_config as hf_config - sub_vllm_config = copy.deepcopy(vllm_config) - sub_vllm_config.model_config.hf_config = text_config - - # Build language model using DeepseekV2Model - self.language_model = DeepseekV2Model( - vllm_config=sub_vllm_config, - prefix=maybe_prefix(prefix, "language_model"), - ) - - # Build lm_head - if get_pp_group().is_last_rank: - vocab_size = getattr(config, "vocab_size", text_config.vocab_size) - self.lm_head = ParallelLMHead( - vocab_size, - text_config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - else: - self.lm_head = PPMissingLayer() - - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors - ) - logit_scale = getattr(config, "logit_scale", 1.0) - vocab_size = getattr(config, "vocab_size", text_config.vocab_size) - self.logits_processor = LogitsProcessor(vocab_size, scale=logit_scale) - - def get_language_model(self) -> nn.Module: - return self.language_model - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - """Apply token embeddings to input_ids.""" - return self.language_model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs, - ) -> IntermediateTensors: - if intermediate_tensors is not None: - inputs_embeds = None - hidden_states = self.language_model( - input_ids=input_ids, - 
positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - return hidden_states - - def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, **kwargs) - return logits - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - """Get expert parameter mapping for MoE layers.""" - text_config = getattr(self.config, "text_config", self.config) - if not getattr(text_config, "n_routed_experts", None): - return [] - return SharedFusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=text_config.n_routed_experts, - num_redundant_experts=0, - ) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - """Load weights with proper name remapping for Kimi-K2.5.""" - text_config = getattr(self.config, "text_config", self.config) - - # Weight name remapping for Kimi-K2.5 -> DeepseekV2 - _KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", - } - - stacked_params_mapping = [ - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] - if getattr(text_config, "kv_lora_rank", None) and getattr( - text_config, "q_lora_rank", None - ): - stacked_params_mapping += [ - (".fused_qkv_a_proj", ".q_a_proj", 0), - (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1), - ] - expert_params_mapping = self.get_expert_mapping() - - params_dict = dict(self.named_parameters()) - - for args in weights: - name, loaded_weight = args[:2] - kwargs = args[2] if len(args) > 2 else {} - - # Skip rotary embedding cached values - if "rotary_emb.inv_freq" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - continue - - # Skip speculative decode layers - spec_layer = get_spec_layer_idx_from_weight_name(text_config, name) - if spec_layer is not 
None: - continue - - # Skip vision tower weights (not needed for text-only inference) - if "vision_tower" in name or "mm_projector" in name: - continue - - # Apply key remapping - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - - use_default_weight_loading = False - - # Handle stacked parameters - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - if ("mlp.experts." in name) and name not in params_dict: - continue - name = name.replace(weight_name, param_name) - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id, **kwargs) - break - else: - # Handle expert parameters - for _, ( - param_name, - weight_name, - expert_id, - shard_id, - ) in enumerate(expert_params_mapping): - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( - param, - loaded_weight, - name, - expert_id=expert_id, - shard_id=shard_id, - **kwargs, - ) - break - else: - use_default_weight_loading = True - - if use_default_weight_loading: - if name.endswith(".bias") and name not in params_dict: - continue - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict.get(name) - if param is None: - continue - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight, **kwargs) diff --git a/vllm_fl/models/minicpmo.py b/vllm_fl/models/minicpmo.py deleted file mode 100644 index b0491510..00000000 --- a/vllm_fl/models/minicpmo.py +++ /dev/null @@ -1,854 +0,0 @@ 
-# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" - -from collections.abc import Callable, Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, TypeAlias - -import torch -from torch import nn -from transformers import BatchFeature -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.whisper.modeling_whisper import ( - ACT2FN, - WhisperAttention, - WhisperConfig, - WhisperEncoder, -) - -from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems -from vllm.multimodal.inputs import ( - MultiModalDataDict, - MultiModalFieldConfig, - NestedTensors, -) -from vllm.multimodal.parse import ( - AudioItem, - AudioProcessorItems, - DictEmbeddingItems, - ModalityData, - ModalityDataItems, - MultiModalDataItems, - MultiModalDataParser, -) -from vllm.multimodal.processing import ( - PromptReplacement, - PromptUpdate, - PromptUpdateDetails, -) -from vllm.utils.tensor_schema import TensorSchema, TensorShape - -from vllm.model_executor.models.minicpmv import ( - _MAX_FRAMES_PER_VIDEO, - MiniCPMV2_6, - MiniCPMV4_5, - MiniCPMVDummyInputsBuilder, - MiniCPMVMultiModalDataParser, - MiniCPMVMultiModalProcessor, - MiniCPMVProcessingInfo, - _minicpmv_field_config, -) -from vllm.model_executor.models.utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix - -CPU_DEVICE = torch.device("cpu") - - -class MiniCPMOAudioFeatureInputs(TensorSchema): - """ - Dimensions: - - bns: Batch size * number of audios * number of slices - - bn: Batch size * number of audios - - c: Number of channels - - l: Length - - s: Number of slices - """ - - type: Literal["audio_features"] = "audio_features" - - audio_features: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bns", "c", "l", dynamic_dims={"l"}), - ] - """ - Slice here means chunk. 
Audio that is too long will be split into slices, - which is the same as image. Padding is used therefore `audio_features` is - `torch.Tensor`. - """ - - audio_feature_lens: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bn", "s"), - ] - """ - This should be feature length of each audio slice, - which equals to `audio_features.shape[-1]` - """ - - -class MiniCPMOAudioEmbeddingInputs(TensorSchema): - """ - Dimensions: - - bn: Batch size * number of audios - - s: Number of slices - - h: Hidden size (must match language model backbone) - - Length of each slice may vary, so pass it as a list. - """ - - type: Literal["audio_embeds"] = "audio_embeds" - - audio_embeds: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bn", "s", "h", dynamic_dims={"s"}), - ] - - -MiniCPMOAudioInputs: TypeAlias = ( - MiniCPMOAudioFeatureInputs | MiniCPMOAudioEmbeddingInputs -) - - -def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): - return dict( - **_minicpmv_field_config(hf_inputs), - audio_features=MultiModalFieldConfig.batched("audio"), - audio_feature_lens=MultiModalFieldConfig.batched("audio"), - audio_embeds=MultiModalFieldConfig.batched("audio"), - ) - - -class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems): - def __init__( - self, - data: Mapping[str, torch.Tensor], - fields_factory: Callable[ - [Mapping[str, torch.Tensor]], - Mapping[str, MultiModalFieldConfig], - ], - ) -> None: - super().__init__( - data, - modality="image", - required_fields={"audio_embeds"}, - fields_factory=fields_factory, - ) - - -class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): - def _parse_audio_data( - self, - data: dict[str, torch.Tensor] | ModalityData[AudioItem], - ) -> ModalityDataItems[Any, Any] | None: - if isinstance(data, dict): - return MiniCPMOAudioEmbeddingItems( - data, - fields_factory=_minicpmo_field_config, - ) - - return super()._parse_audio_data(data) - - -class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): - audio_pattern = 
"()" - - def get_supported_mm_limits(self) -> Mapping[str, int | None]: - return {**super().get_supported_mm_limits(), "audio": None} - - def get_audio_placeholder( - self, - audio_lens: int, - chunk_input: bool = True, - chunk_length: int = 1, - ) -> str: - hf_processor = self.get_hf_processor() - - return hf_processor.get_audio_placeholder( - audio_lens, - chunk_input=chunk_input, - chunk_length=chunk_length, - ) - - def get_default_audio_pool_step(self) -> int: - hf_config = self.get_hf_config() - # MiniCPM-o 4.5 uses pool_step=5, older versions use 2 - return getattr(hf_config, "audio_pool_step", 2) - - def get_default_audio_sampling_rate(self) -> int: - return 16000 - - def get_chunk_length(self) -> int: - return self.get_hf_config().audio_chunk_length - - def get_max_audio_tokens_per_chunk(self) -> int: - pool_step = self.get_default_audio_pool_step() - fbank_feat_in_chunk = 100 - cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1 - return (cnn_feat_in_chunk - pool_step) // pool_step + 1 - - def get_max_audio_chunks_with_most_features(self) -> int: - return 30 - - def get_max_audio_tokens(self) -> int: - num_chunks = self.get_max_audio_chunks_with_most_features() - return self.get_max_audio_tokens_per_chunk() * num_chunks - - def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: - sampling_rate = self.get_default_audio_sampling_rate() - num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1 - - def get_num_frames_with_most_features( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> int: - max_images = mm_counts.get("image", 0) - max_videos = mm_counts.get("video", 0) - max_audios = mm_counts.get("audio", 0) - - max_image_tokens = self.get_max_image_tokens() * max_images - max_audio_tokens = self.get_max_audio_tokens() * max_audios - max_total_frames = self.get_max_video_frames( - seq_len - max_image_tokens - max_audio_tokens - ) - max_frames_per_video = min( - 
max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO - ) - - return max(max_frames_per_video, 1) - - -class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - num_audios = mm_counts.get("audio", 0) - - audio_prompt_texts = self.info.audio_pattern * num_audios - - return super().get_dummy_text(mm_counts) + audio_prompt_texts - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_options: Mapping[str, BaseDummyOptions] | None = None, - ) -> MultiModalDataDict: - num_audios = mm_counts.get("audio", 0) - audio_len = ( - self.info.get_max_audio_chunks_with_most_features() - * self.info.get_default_audio_sampling_rate() - ) - - audio_overrides = mm_options.get("audio") if mm_options else None - - audio_mm_data = { - "audio": self._get_dummy_audios( - length=audio_len, num_audios=num_audios, overrides=audio_overrides - ) - } - - return { - **super().get_dummy_mm_data(seq_len, mm_counts, mm_options), - **audio_mm_data, - } - - -class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]): - def _get_data_parser(self) -> MultiModalDataParser: - return MiniCPMOMultiModalDataParser( - target_sr=self.info.get_default_audio_sampling_rate() - ) - - def get_audio_prompt_texts( - self, - audio_lens: int, - chunk_input: bool = True, - chunk_length: int = 1, - ) -> str: - return self.info.get_audio_placeholder( - audio_lens, - chunk_input=chunk_input, - chunk_length=chunk_length, - ) - - def process_audios( - self, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: - if (audios := mm_data.get("audios")) is None: - return {} - - parsed_audios = ( - self._get_data_parser() - .parse_mm_data({"audio": audios}) - .get_items("audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)) - ) - - if isinstance(parsed_audios, 
MiniCPMOAudioEmbeddingItems): - audio_inputs = {} - else: - audio_inputs = self._base_call_hf_processor( - prompts=[self.info.audio_pattern] * len(parsed_audios), - mm_data={"audios": [[audio] for audio in parsed_audios]}, - mm_kwargs={**mm_kwargs, "chunk_input": True}, - tok_kwargs=tok_kwargs, - out_keys={"audio_features", "audio_feature_lens"}, - ) - - # Avoid padding since we need the output for each audio to be - # independent of other audios for the cache to work correctly - unpadded_audio_features = [ - feat[:, :feature_len] - for feat, feature_len in zip( - audio_inputs["audio_features"], - audio_inputs["audio_feature_lens"], - ) - ] - audio_inputs["audio_features"] = unpadded_audio_features - - return audio_inputs - - def process_mm_inputs( - self, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> Mapping[str, NestedTensors]: - return { - **super().process_mm_inputs(mm_data, mm_kwargs, tok_kwargs), - **self.process_audios(mm_data, mm_kwargs, tok_kwargs), - } - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - base_updates = super()._get_prompt_updates( - mm_items=mm_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - out_mm_kwargs=out_mm_kwargs, - ) - - audio_placeholder = self.info.audio_pattern - - def get_audio_replacement(item_idx: int): - audios = mm_items.get_items( - "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems) - ) - - if isinstance(audios, MiniCPMOAudioEmbeddingItems): - single_audio_embeds = audios.get(item_idx)["audio_embeds"] - audio_len = self.info.get_audio_len_by_num_chunks( - sum(map(len, single_audio_embeds)) - ) - else: - audio_len = audios.get_audio_length(item_idx) - - return PromptUpdateDetails.select_text( - self.get_audio_prompt_texts(audio_len), - "", - ) - - return [ - *base_updates, - PromptReplacement( - 
modality="audio", - target=audio_placeholder, - replacement=get_audio_replacement, - ), - ] - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return _minicpmo_field_config(hf_inputs) - - -class MultiModalProjector(nn.Module): - def __init__(self, in_dim: int, out_dim: int): - super().__init__() - self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True) - self.relu = nn.ReLU() - self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True) - - def forward(self, audio_features: torch.Tensor) -> torch.Tensor: - hidden_states = self.relu(self.linear1(audio_features)) - hidden_states = self.linear2(hidden_states) - return hidden_states - - -class MiniCPMWhisperEncoderLayer(nn.Module): - def __init__(self, config: WhisperConfig, layer_idx: int): - super().__init__() - self.embed_dim = config.d_model - self.self_attn = WhisperAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - config=config, - layer_idx=layer_idx, - ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - ) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - hidden_states = residual + hidden_states - - residual = hidden_states 
- hidden_states = self.final_layer_norm(hidden_states) - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout( - hidden_states, p=self.activation_dropout, training=self.training - ) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - hidden_states = residual + hidden_states - - if hidden_states.dtype == torch.float16: - hidden_states = cast_overflow_tensors(hidden_states) - - outputs = (hidden_states,) - - return outputs - - -class MiniCPMWhisperEncoder(WhisperEncoder): - def __init__(self, config: WhisperConfig): - super().__init__(config) - self.layers = nn.ModuleList( - [ - MiniCPMWhisperEncoderLayer(config, layer_idx=i) - for i in range(config.encoder_layers) - ] - ) - - def forward( - self, - input_features: torch.Tensor, - attention_mask: torch.Tensor | None = None, - ) -> BaseModelOutputWithPast: - # Ignore copy - input_features = input_features.to( - dtype=self.conv1.weight.dtype, device=self.conv1.weight.device - ) - - inputs_embeds = nn.functional.gelu(self.conv1(input_features)) - inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) - - inputs_embeds = inputs_embeds.permute(0, 2, 1) - - embed_pos = self.embed_positions.weight - - embed_pos = embed_pos[: inputs_embeds.shape[1], :] - - hidden_states = inputs_embeds + embed_pos - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - - encoder_states = () - - for idx, encoder_layer in enumerate(self.layers): - encoder_states = encoder_states + (hidden_states,) - to_drop = False - if self.training: - dropout_probability = torch.rand([]) - if dropout_probability < self.layerdrop: # skip the layer - to_drop = True - - # Ignore copy - if to_drop: - layer_outputs = (None, None) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - ) - - hidden_states = layer_outputs[0] - - hidden_states = 
self.layer_norm(hidden_states) - encoder_states = encoder_states + (hidden_states,) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - hidden_states=encoder_states, - ) - - -class MiniCPMOBaseModel: - """Base mixin class for MiniCPM-O models with audio support.""" - - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> str | None: - if modality.startswith("image"): - return "(./)" - if modality.startswith("video"): - return "()" - if modality.startswith("audio"): - return "()" - - raise ValueError("Only image, video or audio modality is supported") - - def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Do not use parameters temporarily - audio_config = self.config.audio_config - model = MiniCPMWhisperEncoder(audio_config) - audio_output_dim = int(audio_config.encoder_ffn_dim // 4) - self.audio_avg_pooler = nn.AvgPool1d( - self.config.audio_pool_step, stride=self.config.audio_pool_step - ) - self.audio_projection_layer = MultiModalProjector( - in_dim=audio_output_dim, out_dim=self.embed_dim - ) - self.audio_encoder_layer = -1 - return model - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, skip_prefixes=["tts"]) - return loader.load_weights(weights) - - def subsequent_chunk_mask( - self, - size: int, - chunk_size: int, - num_left_chunks: int = -1, - device: torch.device = CPU_DEVICE, - num_lookhead: int = 0, - ) -> torch.Tensor: - ret = torch.zeros(size, size, device=device, dtype=torch.bool) - # Vectorized computation of row indices and chunk boundaries - row_indices = torch.arange(size, device=device) - chunk_indices = row_indices // chunk_size - if num_left_chunks < 0: - # If num_left_chunks < 0, start is always 0 for all rows - start_indices = torch.zeros_like(row_indices) - else: - # 
Compute start indices vectorially - start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, min=0) - start_indices = start_chunk_indices * chunk_size - # Compute ending indices vectorially - end_chunk_indices = chunk_indices + 1 - end_indices = torch.clamp( - end_chunk_indices * chunk_size + num_lookhead, max=size - ) - # Create column indices for broadcasting - col_indices = torch.arange(size, device=device).unsqueeze(0) - start_indices = start_indices.unsqueeze(1) - end_indices = end_indices.unsqueeze(1) - # Vectorized mask creation - ret = (col_indices >= start_indices) & (col_indices < end_indices) - return ret - - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): - input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 - input_lengths_after_pooling = ( - input_lengths_after_cnn - self.config.audio_pool_step - ) // self.config.audio_pool_step + 1 - input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) - - return input_lengths_after_cnn, input_lengths_after_pooling - - def get_audio_hidden_states( - self, data: MiniCPMOAudioFeatureInputs - ) -> list[torch.Tensor]: - chunk_length = self.config.audio_chunk_length - - # (bs, 80, frames) or [], multi audios need filled in advance - wavforms_raw = data["audio_features"] - if isinstance(wavforms_raw, list): - B = len(wavforms_raw) - C = wavforms_raw[0].shape[-2] - L = max(item.shape[-1] for item in wavforms_raw) - device = wavforms_raw[0].device - dtype = wavforms_raw[0].dtype - - wavforms = torch.zeros((B, C, L), dtype=dtype, device=device) - for i, wavforms_item in enumerate(wavforms_raw): - L_item = wavforms_item.shape[-1] - wavforms[i, ..., :L_item] = wavforms_item - else: - wavforms = wavforms_raw - - # list, [[x1, x2], [y1], [z1]] - audio_feature_lens_raw = data["audio_feature_lens"] - if isinstance(audio_feature_lens_raw, torch.Tensor): - audio_feature_lens_raw = audio_feature_lens_raw.unbind(0) - - audio_feature_lens = 
torch.hstack(audio_feature_lens_raw) - batch_size, _, max_mel_seq_len = wavforms.shape - max_seq_len = (max_mel_seq_len - 1) // 2 + 1 - - # Create a sequence tensor of shape (batch_size, max_seq_len) - seq_range = ( - torch.arange( - 0, - max_seq_len, - dtype=audio_feature_lens.dtype, - device=audio_feature_lens.device, - ) - .unsqueeze(0) - .expand(batch_size, max_seq_len) - ) - lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len) - # Create mask - padding_mask = seq_range >= lengths_expand # 1 for padded values - - audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( - batch_size, 1, max_seq_len, max_seq_len - ) - audio_attention_mask = audio_attention_mask_.to( - dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device - ) - - if chunk_length > 0: - chunk_num_frame = int(chunk_length * 50) - chunk_mask = self.subsequent_chunk_mask( - size=max_seq_len, - chunk_size=chunk_num_frame, - num_left_chunks=-1, - device=audio_attention_mask_.device, - ) - audio_attention_mask_ = torch.logical_or( - audio_attention_mask_, torch.logical_not(chunk_mask) - ) - - audio_attention_mask[audio_attention_mask_] = float("-inf") - audio_states = self.apm( - wavforms, attention_mask=audio_attention_mask - ).hidden_states[self.audio_encoder_layer] - audio_embeds = self.audio_projection_layer(audio_states) - - audio_embeds = audio_embeds.transpose(1, 2) - audio_embeds = self.audio_avg_pooler(audio_embeds) - audio_embeds = audio_embeds.transpose(1, 2) - - _, feature_lens_after_pooling = self._get_feat_extract_output_lengths( - audio_feature_lens - ) - - num_audio_tokens = feature_lens_after_pooling - - final_audio_embeds = list[torch.Tensor]() - idx = 0 - for i in range(len(audio_feature_lens_raw)): - target_audio_embeds_lst = list[torch.Tensor]() - for _ in range(len(audio_feature_lens_raw[i])): - target_audio_embeds_lst.append( - audio_embeds[idx, : num_audio_tokens[idx], :] - ) - idx += 1 - - 
final_audio_embeds.append(torch.cat(target_audio_embeds_lst)) - - return final_audio_embeds - - def _parse_and_validate_audio_input( - self, **kwargs: object - ) -> MiniCPMOAudioInputs | None: - audio_features = kwargs.pop("audio_features", None) - audio_embeds = kwargs.pop("audio_embeds", None) - - if audio_features is None and audio_embeds is None: - return None - - if audio_embeds is not None: - return MiniCPMOAudioEmbeddingInputs( - type="audio_embeds", - audio_embeds=audio_embeds, - ) - - audio_feature_lens = kwargs.pop("audio_feature_lens") - - return MiniCPMOAudioFeatureInputs( - type="audio_features", - audio_features=audio_features, - audio_feature_lens=audio_feature_lens, - ) - - def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: - modalities = super()._parse_and_validate_multimodal_inputs(**kwargs) - - # Preserve the order of modalities if there are multiple of them - # from the order of kwargs. - for input_key in kwargs: - if ( - input_key in ("audio_features", "audio_embeds") - and "audios" not in modalities - ): - modalities["audios"] = self._parse_and_validate_audio_input(**kwargs) - - return modalities - - def _process_audio_input( - self, - audio_input: MiniCPMOAudioInputs, - ) -> torch.Tensor | list[torch.Tensor]: - if audio_input["type"] == "audio_embeds": - return audio_input["audio_embeds"] - - return self.get_audio_hidden_states(audio_input) - - def _process_multimodal_inputs(self, modalities: dict): - multimodal_embeddings = super()._process_multimodal_inputs(modalities) - - for modality in modalities: - if modality == "audios": - audio_input = modalities["audios"] - audio_embeddings = self._process_audio_input(audio_input) - multimodal_embeddings += tuple(audio_embeddings) - - return multimodal_embeddings - - -class MiniCPMO2_6(MiniCPMOBaseModel, MiniCPMV2_6): - """MiniCPM-O model based on MiniCPMV 2.6 (Qwen2 backbone).""" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Skip 
MiniCPMV2_6.__init__ version assertion, call MiniCPMVBaseModel directly - from vllm.model_executor.models.minicpmv import MiniCPMVBaseModel - MiniCPMVBaseModel.__init__(self, vllm_config=vllm_config, prefix=prefix) - # Override version for MiniCPM-O 2.6 - self.version = (2, 6) - self.apm = self.init_audio_module( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm") - ) - - -class MiniCPMO4_5(MiniCPMOBaseModel, MiniCPMV4_5): - """MiniCPM-O 4.5 model based on MiniCPMV 4.5 (Qwen3 backbone).""" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Skip MiniCPMV4_5.__init__ version assertion, call MiniCPMVBaseModel directly - from vllm.model_executor.models.minicpmv import MiniCPMVBaseModel - MiniCPMVBaseModel.__init__(self, vllm_config=vllm_config, prefix=prefix) - # Override version for MiniCPM-O 4.5 - self.version = (4, 5) - self.apm = self.init_audio_module( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm") - ) - - -_MINICPMO_SUPPORT_VERSION = { - (2, 6): MiniCPMO2_6, - (4, 5): MiniCPMO4_5, -} - - -@MULTIMODAL_REGISTRY.register_processor( - MiniCPMOMultiModalProcessor, - info=MiniCPMOProcessingInfo, - dummy_inputs=MiniCPMODummyInputsBuilder, -) -class MiniCPMO(MiniCPMOBaseModel, MiniCPMV2_6): - """ - MiniCPM-O model with audio support. 
- Different versions use different LLM backbones: - - Version 2.6: Uses Qwen2 - - Version 4.5: Uses Qwen3 - """ - - def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - - # Determine version from config - if hasattr(config, "version"): - version = str(config.version).split(".") - version = tuple([int(x) for x in version]) - else: - # Auto-detect version based on config features: - # - MiniCPM-o 4.5 (Qwen3 backbone): has head_dim attribute, - # hidden_size=4096, num_hidden_layers=36 - # - MiniCPM-o 2.6 (Qwen2 backbone): no head_dim, different arch - has_head_dim = hasattr(config, "head_dim") - is_qwen3_like = ( - has_head_dim - and getattr(config, "hidden_size", 0) == 4096 - and getattr(config, "num_hidden_layers", 0) == 36 - ) - if is_qwen3_like: - version = (4, 5) - else: - # Default to 2.6 for backward compatibility - version = (2, 6) - - # Dispatch class based on version - instance_cls = _MINICPMO_SUPPORT_VERSION.get(version) - if instance_cls is None: - supported_versions = ", ".join( - [f"{v[0]}.{v[1]}" for v in sorted(_MINICPMO_SUPPORT_VERSION.keys())] - ) - raise ValueError( - f"Currently, MiniCPMO only supports versions " - f"{supported_versions}. Got version: {version}" - ) - - return instance_cls(vllm_config=vllm_config, prefix=prefix) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # This __init__ won't be called due to __new__ returning a different class - super().__init__(vllm_config=vllm_config, prefix=prefix) - self.apm = self.init_audio_module( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm") - ) diff --git a/vllm_fl/models/qwen3_5.py b/vllm_fl/models/qwen3_5.py deleted file mode 100644 index ad3ec5f9..00000000 --- a/vllm_fl/models/qwen3_5.py +++ /dev/null @@ -1,951 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright 2025 The vLLM team. -# Copyright 2025 The Qwen Team. 
-"""Inference-only Qwen3.5 MoE model compatible with HuggingFace weights.""" - -import typing -from collections.abc import Callable, Iterable - -import torch -from einops import rearrange -from torch import nn - -from vllm.compilation.decorators import support_torch_compile -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group -from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import ( - GemmaRMSNorm as Qwen3_5RMSNorm, -) -from vllm.model_executor.layers.linear import MergedColumnParallelLinear -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, - MambaStateShapeCalculator, -) -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.sequence import IntermediateTensors - -from vllm.model_executor.models.interfaces import ( - HasInnerState, - IsHybrid, - MixtureOfExperts, - MultiModalEmbeddings, - SupportsLoRA, - SupportsPP, - _require_is_multimodal, -) -from vllm.model_executor.models.utils import ( - AutoWeightsLoader, - PPMissingLayer, - _merge_multimodal_embeddings, - extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, - make_layers, - maybe_prefix, -) -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.model_executor.models.qwen3_vl import ( - Qwen3_VisionTransformer, - Qwen3VLDummyInputsBuilder, - Qwen3VLForConditionalGeneration, - Qwen3VLMultiModalProcessor, - Qwen3VLProcessingInfo, -) - -# from vllm_fl.models.qwen3_next import ( -from vllm.model_executor.models.qwen3_next import ( - Qwen3NextAttention, - Qwen3NextDecoderLayer, - Qwen3NextGatedDeltaNet, - Qwen3NextModel, - Qwen3NextSparseMoeBlock, - QwenNextMixtureOfExperts, -) -from 
vllm_fl.configs.qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig -import vllm_fl.models.qwen3_next # for error ''_OpNamespace' 'vllm' object has no attribute 'gdn_attention_core'' - -logger = init_logger(__name__) - - -class Qwen3_5SparseMoeBlock(Qwen3NextSparseMoeBlock): - """Override SparseMoeBlock to read config from hf_text_config.""" - - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): - # Temporarily patch hf_config to point to hf_text_config - # so the parent class reads the right config - original_hf_config = vllm_config.model_config.hf_config - vllm_config.model_config.hf_config = ( - vllm_config.model_config.hf_text_config - ) - try: - super().__init__(vllm_config=vllm_config, prefix=prefix) - finally: - vllm_config.model_config.hf_config = original_hf_config - - -class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet): - """Qwen3.5 uses MergedColumnParallelLinear for qkvz projection - and has a different forward pass without fix_query_key_value_ordering.""" - - def fix_query_key_value_ordering( - self, - mixed_qkvz: torch.Tensor, - mixed_ba: torch.Tensor, - ): - raise NotImplementedError( - "Qwen3.5 Series dont need to fix query key value ordering" - ) - - def create_qkvz_proj( - self, - hidden_size: int, - key_dim: int, - value_dim: int, - quant_config: QuantizationConfig | None, - prefix: str, - ) -> MergedColumnParallelLinear: - return MergedColumnParallelLinear( - input_size=hidden_size, - output_sizes=[key_dim, key_dim, value_dim, value_dim], - bias=False, - quant_config=quant_config, - prefix=prefix, - ) - - def __init__(self, config, model_config=None, cache_config=None, - quant_config=None, speculative_config=None, prefix=""): - # Call grandparent init to skip Qwen3NextGatedDeltaNet.__init__ - # but set up the same attributes - nn.Module.__init__(self) - from vllm.distributed import ( - divide, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - ) - from vllm.config import get_current_vllm_config - from 
vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - RowParallelLinear, - ) - from vllm.model_executor.layers.layernorm import RMSNormGated - from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - mamba_v2_sharded_weight_loader, - ) - from vllm.model_executor.model_loader.weight_utils import ( - sharded_weight_loader, - ) - from vllm.model_executor.utils import set_weight_attrs - from vllm.platforms import current_platform - from transformers.activations import ACT2FN - - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.hidden_size = config.hidden_size - self.num_v_heads = config.linear_num_value_heads - self.num_k_heads = config.linear_num_key_heads - self.head_k_dim = config.linear_key_head_dim - self.head_v_dim = config.linear_value_head_dim - self.key_dim = self.head_k_dim * self.num_k_heads - self.value_dim = self.head_v_dim * self.num_v_heads - - self.conv_kernel_size = config.linear_conv_kernel_dim - self.layer_idx = extract_layer_index(prefix) - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] - self.layer_norm_epsilon = config.rms_norm_eps - self.prefix = prefix - - self.config = config - self.model_config = model_config - self.cache_config = cache_config - self.quant_config = quant_config - self.speculative_config = speculative_config - self.num_spec = ( - self.speculative_config.num_speculative_tokens - if self.speculative_config - else 0 - ) - - # Convolution - self.conv_dim = self.key_dim * 2 + self.value_dim - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.conv_dim, - bias=False, - prefix=f"{prefix}.conv1d", - ) - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - query_key_settings = (self.key_dim, 0, False) - value_settings = (self.value_dim, 0, False) - - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( - self.conv1d.weight, - { - "weight_loader": 
mamba_v2_sharded_weight_loader( - [query_key_settings, query_key_settings, value_settings], - self.tp_size, - self.tp_rank, - ) - }, - ) - - # Use MergedColumnParallelLinear for qkvz projection - self.in_proj_qkvz = self.create_qkvz_proj( - hidden_size=self.hidden_size, - key_dim=self.key_dim, - value_dim=self.value_dim, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_qkvz", - ) - - self.projection_size_ba = self.num_v_heads * 2 - self.in_proj_ba = MergedColumnParallelLinear( - input_size=self.hidden_size, - output_sizes=[self.num_v_heads, self.num_v_heads], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_ba", - ) - - self.dt_bias = nn.Parameter( - torch.ones(self.num_v_heads // self.tp_size), - ) - self.A_log = nn.Parameter( - torch.empty(divide(self.num_v_heads, self.tp_size)), - ) - - set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) - set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - - self.norm = RMSNormGated( - self.head_v_dim, - eps=self.layer_norm_epsilon, - group_size=None, - norm_before_gate=True, - device=current_platform.current_device(), - dtype=getattr(config, "dtype", torch.bfloat16), - ) - - self.out_proj = RowParallelLinear( - self.value_dim, - self.hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - - - def forward( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - ): - num_tokens = hidden_states.size(0) - - # Input Projection - Qwen3.5 style - mixed_qkvz, _ = self.in_proj_qkvz(hidden_states) - qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size - z_size = self.value_dim // self.tp_size - mixed_qkv, z = mixed_qkvz.split([qkv_size, 
z_size], dim=-1) - z = z.reshape(z.size(0), -1, self.head_v_dim) - ba, _ = self.in_proj_ba(hidden_states) - b, a = ba.chunk(2, dim=-1) - - b = b.contiguous() - a = a.contiguous() - - # Core Attention - core_attn_out = torch.zeros( - (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - torch.ops.vllm.gdn_attention_core( - mixed_qkv, - b, - a, - core_attn_out, - self.prefix, - ) - - # Output Projection - z_shape_og = z.shape - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") - output[:num_tokens], _ = self.out_proj(core_attn_out) - - -class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer): - def __init__( - self, - vllm_config: VllmConfig, - layer_type: str, - prefix: str = "", - ) -> None: - # Call nn.Module.__init__ directly, skipping Qwen3NextDecoderLayer's - super(Qwen3NextDecoderLayer, self).__init__() - - config = vllm_config.model_config.hf_text_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - speculative_config = vllm_config.speculative_config - - self.layer_type = layer_type - self.layer_idx = extract_layer_index(prefix) - - if self.layer_type == "linear_attention": - self.linear_attn = Qwen3_5GatedDeltaNet( - config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - speculative_config=speculative_config, - prefix=f"{prefix}.linear_attn", - ) - elif self.layer_type == "full_attention": - self.self_attn = Qwen3NextAttention( - config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - else: - raise ValueError(f"Invalid layer_type {self.layer_type}") - - # Qwen3.5 MoE: all layers 
use sparse MoE blocks - self.mlp = Qwen3_5SparseMoeBlock( - vllm_config=vllm_config, - prefix=f"{prefix}.mlp", - ) - - self.input_layernorm = Qwen3_5RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = Qwen3_5RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - self.layer_scale = getattr(config, "layer_scale", False) - if self.layer_scale: - self.attn_layer_scale = torch.nn.Parameter( - torch.zeros( - 1, 1, config.hidden_size, - dtype=getattr(config, "dtype", torch.bfloat16), - ), - ) - self.ffn_layer_scale = torch.nn.Parameter( - torch.zeros( - 1, 1, config.hidden_size, - dtype=getattr(config, "dtype", torch.bfloat16), - ), - ) - - -@support_torch_compile( - dynamic_arg_dims={ - "input_ids": 0, - "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - } -) -class Qwen3_5Model(Qwen3NextModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super(Qwen3NextModel, self).__init__() - - config = vllm_config.model_config.hf_text_config - parallel_config = vllm_config.parallel_config - - eplb_config = parallel_config.eplb_config - self.num_redundant_experts = eplb_config.num_redundant_experts - - self.config = config - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - ) - - def get_layer(prefix: str): - return Qwen3_5DecoderLayer( - vllm_config, - layer_type=config.layer_types[extract_layer_index(prefix)], - prefix=prefix, - ) - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" - ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) - ) - - if get_pp_group().is_last_rank: - self.norm = Qwen3_5RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - else: - self.norm = PPMissingLayer() - - def load_fused_expert_weights( - self, - name: str, - 
params_dict: dict, - loaded_weight: torch.Tensor, - shard_id: str, - num_experts: int, - ) -> bool: - param = params_dict[name] - weight_loader = typing.cast(Callable[..., bool], param.weight_loader) - loaded_local_expert = False - for expert_id in range(num_experts): - curr_expert_weight = loaded_weight[expert_id] - success = weight_loader( - param, - curr_expert_weight, - name, - shard_id, - expert_id, - return_success=True, - ) - if success: - loaded_local_expert = True - return loaded_local_expert - - def load_weights( - self, weights: Iterable[tuple[str, torch.Tensor]] - ) -> set[str]: - stacked_params_mapping = [ - # self attention - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - # mlp - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - # GDN - Qwen3.5 style - ("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)), - ("in_proj_qkvz", "in_proj_z", 3), - ("in_proj_ba", "in_proj_b", 0), - ("in_proj_ba", "in_proj_a", 1), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - expert_params_mapping = self.get_expert_mapping() - is_fused_expert = False - fused_expert_params_mapping = [ - ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"), - ("experts.w2_weight", "experts.down_proj", 0, "w2"), - ] - num_experts = ( - self.config.num_experts - if hasattr(self.config, "num_experts") - else 0 - ) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if name.startswith("mtp."): - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if ( - "experts.gate_up_proj" in name - or "experts.down_proj" in name - ): - is_fused_expert = True - expert_params_mapping = fused_expert_params_mapping - - if weight_name not in name: - continue - if "mlp.experts" in name: - continue - - name = name.replace(weight_name, param_name) - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - 
continue - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - if isinstance(shard_id, tuple): - # Multi-shard: split loaded_weight and load each - layer = weight_loader.__self__ - split_sizes = [ - layer.output_sizes[s] for s in shard_id - ] - output_dim = getattr(param, "output_dim", 0) - parts = loaded_weight.split( - split_sizes, dim=output_dim - ) - for s, part in zip(shard_id, parts): - weight_loader(param, part, s) - else: - weight_loader(param, loaded_weight, shard_id) - break - else: - is_expert_weight = False - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - is_expert_weight = True - name_mapped = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name_mapped, self): - continue - if is_fused_expert: - if "experts.gate_up_proj" in name: - loaded_weight = loaded_weight.chunk(2, dim=-2) - success_w1 = self.load_fused_expert_weights( - name_mapped, params_dict, - loaded_weight[0], "w1", num_experts, - ) - success_w3 = self.load_fused_expert_weights( - name_mapped, params_dict, - loaded_weight[1], "w3", num_experts, - ) - success = success_w1 and success_w3 - else: - success = self.load_fused_expert_weights( - name_mapped, params_dict, - loaded_weight, shard_id, num_experts, - ) - if success: - name = name_mapped - break - else: - if ( - name_mapped.endswith(".bias") - or name_mapped.endswith("_bias") - ) and name_mapped not in params_dict: - continue - param = params_dict[name_mapped] - weight_loader = param.weight_loader - success = weight_loader( - param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True, - ) - if success: - name = name_mapped - break - else: - if is_expert_weight: - continue - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - if name not in params_dict: - 
logger.warning_once( - f"Parameter {name} not found in params_dict, " - "skip loading" - ) - continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts): - def _get_moe_model_layers(self): - """Get the model layers, handling both CausalLM and VL wrapper.""" - if hasattr(self, "language_model"): - return self.language_model.model.layers - return self.model.layers - - def update_physical_experts_metadata( - self, - num_physical_experts: int, - num_local_physical_experts: int, - ) -> None: - assert self.num_local_physical_experts == num_local_physical_experts - self.num_physical_experts = num_physical_experts - self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = ( - num_physical_experts - self.num_logical_experts - ) - for layer in self._get_moe_model_layers(): - if isinstance(layer.mlp, Qwen3_5SparseMoeBlock): - moe = layer.mlp - moe.n_local_physical_experts = num_local_physical_experts - moe.n_physical_experts = num_physical_experts - moe.n_redundant_experts = self.num_redundant_experts - moe.experts.update_expert_map() - - def set_moe_parameters(self): - self.expert_weights = [] - self.moe_layers = [] - example_moe = None - for layer in self._get_moe_model_layers(): - if isinstance(layer, Qwen3_5DecoderLayer) and isinstance( - layer.mlp, (Qwen3NextSparseMoeBlock, Qwen3_5SparseMoeBlock) - ): - example_moe = layer.mlp - self.moe_layers.append(layer.mlp.experts) - - if example_moe is None: - raise RuntimeError( - "No Qwen3_5 MoE layer found in the model.layers." 
- ) - - self.num_moe_layers = len(self.moe_layers) - self.num_expert_groups = 1 - self.num_shared_experts = 0 - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - self.num_redundant_experts = example_moe.n_redundant_experts - - -class Qwen3_5MoeForCausalLM( - nn.Module, - HasInnerState, - SupportsLoRA, - SupportsPP, - Qwen3_5_MoeMixtureOfExperts, - IsHybrid, -): - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_text_config - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - - scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, ( - "Qwen3.5 currently does not support prefix caching" - ) - self.quant_config = vllm_config.quant_config - - super().__init__() - self.config = config - self.scheduler_config = scheduler_config - self.model = Qwen3_5Model( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) - - if get_pp_group().is_last_rank: - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head"), - ) - else: - self.lm_head = PPMissingLayer() - - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) - - # Set MoE hyperparameters - self.set_moe_parameters() - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - 
positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ): - hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds - ) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.logits_processor(self.lm_head, hidden_states) - - def load_weights( - self, weights: Iterable[tuple[str, torch.Tensor]] - ) -> set[str]: - - def _remap_weights( - weights: Iterable[tuple[str, torch.Tensor]] - ) -> Iterable[tuple[str, torch.Tensor]]: - for name, tensor in weights: - # The HF checkpoint for Qwen3_5MoeForConditionalGeneration - # uses model.language_model.* for the text model, but our - # model structure uses model.* directly (no language_model - # wrapper level). - name = name.replace("model.language_model.", "model.") - yield name, tensor - - loader = AutoWeightsLoader( - self, - skip_prefixes=["mtp.", "model.visual."], - ) - return loader.load_weights(_remap_weights(weights)) - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() - - @classmethod - def get_mamba_state_dtype_from_config( - cls, - vllm_config: "VllmConfig", - ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - vllm_config.model_config.dtype, - vllm_config.cache_config.mamba_cache_dtype, - ) - - @classmethod - def get_mamba_state_shape_from_config( - cls, vllm_config: "VllmConfig" - ) -> tuple[tuple[int, int], tuple[int, int]]: - parallel_config = vllm_config.parallel_config - hf_config = vllm_config.model_config.hf_text_config - tp_size = parallel_config.tensor_parallel_size - num_spec = ( - vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config - else 0 - ) - return MambaStateShapeCalculator.gated_delta_net_state_shape( - tp_size, - hf_config.linear_num_key_heads, - 
hf_config.linear_num_value_heads, - hf_config.linear_key_head_dim, - hf_config.linear_value_head_dim, - hf_config.linear_conv_kernel_dim, - num_spec, - ) - - -######################################################## -# Qwen3_5-MoE Multimodal (VL) Wrapper -######################################################## - - -class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo): - """Processing info that uses Qwen3_5MoeConfig instead of Qwen3VLConfig.""" - - def get_hf_config(self): - return self.ctx.get_hf_config(Qwen3_5MoeConfig) - - -@MULTIMODAL_REGISTRY.register_processor( - Qwen3VLMultiModalProcessor, - info=Qwen3_5MoeProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder, -) -class Qwen3_5MoeForConditionalGeneration( - Qwen3VLForConditionalGeneration, - HasInnerState, - Qwen3_5_MoeMixtureOfExperts, - IsHybrid, -): - """Multimodal Qwen3.5-MoE model wrapping VisionTransformer + CausalLM.""" - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): - nn.Module.__init__(self) - config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config - - self.config = config - self.multimodal_config = multimodal_config - self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.video_pruning_rate = multimodal_config.video_pruning_rate - self.is_multimodal_pruning_enabled = ( - multimodal_config.is_multimodal_pruning_enabled() - ) - - # Vision encoder - if not multimodal_config.get_limit_per_prompt( - "image" - ) and not multimodal_config.get_limit_per_prompt("video"): - self.visual = None - else: - self.visual = Qwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - multimodal_config=multimodal_config, - prefix=maybe_prefix(prefix, "visual"), - ) - - # Language model (MoE CausalLM) - self.language_model = Qwen3_5MoeForCausalLM( - vllm_config=vllm_config, - 
prefix=maybe_prefix(prefix, "language_model"), - ) - - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors - ) - - # Deepstack support - self.use_deepstack = hasattr( - config.vision_config, "deepstack_visual_indexes" - ) and bool(config.vision_config.deepstack_visual_indexes) - self.deepstack_num_level = ( - len(config.vision_config.deepstack_visual_indexes) - if self.use_deepstack - else 0 - ) - if self.use_deepstack and self.visual is not None: - self.deepstack_input_embeds = [ - torch.zeros( - vllm_config.scheduler_config.max_num_batched_tokens, - config.text_config.hidden_size, - ) - for _ in range(self.deepstack_num_level) - ] - else: - self.deepstack_input_embeds = None - self.visual_dim = config.vision_config.out_hidden_size - self.multiscale_dim = self.visual_dim * self.deepstack_num_level - - # Set MoE hyperparameters - self.set_moe_parameters() - - def embed_input_ids( - self, - input_ids: torch.Tensor, - multimodal_embeddings: MultiModalEmbeddings | None = None, - *, - is_multimodal: torch.Tensor | None = None, - handle_oov_mm_token: bool = False, - ) -> torch.Tensor: - inputs_embeds = self._embed_text_input_ids( - input_ids, - self.language_model.embed_input_ids, - is_multimodal=is_multimodal, - handle_oov_mm_token=handle_oov_mm_token, - ) - - if multimodal_embeddings is None or len(multimodal_embeddings) == 0: - return inputs_embeds - - is_multimodal = _require_is_multimodal(is_multimodal) - - inputs_embeds = _merge_multimodal_embeddings( - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - is_multimodal=is_multimodal, - ) - - return inputs_embeds - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ) -> torch.Tensor | IntermediateTensors: - if intermediate_tensors is not None: - inputs_embeds = None - - hidden_states = 
self.language_model.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.language_model.compute_logits(hidden_states) - - def load_weights( - self, weights: Iterable[tuple[str, torch.Tensor]] - ) -> set[str]: - skip_prefixes = ["mtp."] - if self.visual is None: - skip_prefixes.append("visual.") - loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.language_model.model.get_expert_mapping() - - @classmethod - def get_mamba_state_dtype_from_config( - cls, - vllm_config: "VllmConfig", - ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - vllm_config.model_config.dtype, - vllm_config.cache_config.mamba_cache_dtype, - ) - - @classmethod - def get_mamba_state_shape_from_config( - cls, vllm_config: "VllmConfig" - ) -> tuple[tuple[int, int], tuple[int, int]]: - parallel_config = vllm_config.parallel_config - hf_config = vllm_config.model_config.hf_text_config - tp_size = parallel_config.tensor_parallel_size - num_spec = ( - vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config - else 0 - ) - return MambaStateShapeCalculator.gated_delta_net_state_shape( - tp_size, - hf_config.linear_num_key_heads, - hf_config.linear_num_value_heads, - hf_config.linear_key_head_dim, - hf_config.linear_value_head_dim, - hf_config.linear_conv_kernel_dim, - num_spec, - ) diff --git a/vllm_fl/models/qwen3_next.py b/vllm_fl/models/qwen3_next.py deleted file mode 100644 index f3afa3fe..00000000 --- a/vllm_fl/models/qwen3_next.py +++ /dev/null @@ -1,1396 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project -"""Inference-only Qwen3Next model.""" - -from collections.abc import Iterable -from itertools import islice - -import torch -from einops import rearrange -from torch import nn -from transformers.activations import ACT2FN -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.layer import Attention -from vllm.compilation.decorators import support_torch_compile -from vllm.config import ( - CacheConfig, - ModelConfig, - SpeculativeConfig, - VllmConfig, - get_current_vllm_config, -) -from vllm.distributed import ( - divide, - get_ep_group, - get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather, -) -from vllm.forward_context import ForwardContext, get_forward_context -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import SharedFusedMoE -from vllm.model_executor.layers.fused_moe.config import RoutingMethodType -from vllm.model_executor.layers.layernorm import ( - GemmaRMSNorm as Qwen3NextRMSNorm, -) -from vllm.model_executor.layers.layernorm import RMSNormGated -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader -from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, - MambaStateShapeCalculator, -) -from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, - causal_conv1d_update, -) -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from 
vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, - sharded_weight_loader, -) -from vllm.model_executor.models.interfaces import ( - HasInnerState, - IsHybrid, - MixtureOfExperts, - SupportsLoRA, - SupportsPP, -) -from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP -from vllm.model_executor.models.utils import ( - AutoWeightsLoader, - PPMissingLayer, - extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, - make_layers, - maybe_prefix, - sequence_parallel_chunk, -) -from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform -from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs import Qwen3NextConfig -from vllm.triton_utils import tl, triton -from vllm.utils.torch_utils import direct_register_custom_op -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata -from vllm_fl.ops.fla import ChunkGatedDeltaRuleOp, FusedRecurrentGatedDeltaRuleOp - -logger = init_logger(__name__) - -KVCache = tuple[torch.Tensor, torch.Tensor] - -class Qwen3NextSparseMoeBlock(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - parallel_config = vllm_config.parallel_config - quant_config = vllm_config.quant_config - - self.tp_size = get_tensor_model_parallel_world_size() - - self.ep_group = get_ep_group().device_group - self.ep_rank = get_ep_group().rank_in_group - self.ep_size = self.ep_group.size() - self.n_routed_experts = config.num_experts - - self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe - - if self.tp_size > config.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}." - ) - - # Load balancing settings. 
- vllm_config = get_current_vllm_config() - eplb_config = vllm_config.parallel_config.eplb_config - self.enable_eplb = parallel_config.enable_eplb - - self.n_logical_experts = self.n_routed_experts - self.n_redundant_experts = eplb_config.num_redundant_experts - self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts - self.n_local_physical_experts = self.n_physical_experts // self.ep_size - - self.physical_expert_start = self.ep_rank * self.n_local_physical_experts - self.physical_expert_end = ( - self.physical_expert_start + self.n_local_physical_experts - ) - - self.gate = ReplicatedLinear( - config.hidden_size, - config.num_experts, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate", - ) - - self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) - - if config.shared_expert_intermediate_size > 0: - self.shared_expert = Qwen3NextMLP( - hidden_size=config.hidden_size, - intermediate_size=config.shared_expert_intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - expert_gate=self.shared_expert_gate, - prefix=f"{prefix}.shared_expert", - ) - else: - self.shared_expert = None - - self.experts = SharedFusedMoE( - shared_experts=self.shared_expert, - gate=self.gate, - num_experts=self.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel, - routing_method_type=RoutingMethodType.Renormalize, - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # NOTE: hidden_states can have either 1D or 2D shape. 
- orig_shape = hidden_states.shape - num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - - if self.is_sequence_parallel: - hidden_states = sequence_parallel_chunk(hidden_states) - - if self.experts.is_internal_router: - # In this case, the gate/router runs inside the FusedMoE class - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=hidden_states - ) - else: - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) - - if self.shared_expert is not None: - final_hidden_states = final_hidden_states[0] + final_hidden_states[1] - - if self.is_sequence_parallel: - final_hidden_states = tensor_model_parallel_all_gather( - final_hidden_states, 0 - ) - final_hidden_states = final_hidden_states[:num_tokens] - elif self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 - final_hidden_states - ) - - return final_hidden_states.view(orig_shape) - - -class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): - @property - def mamba_type(self) -> str: - return "gdn_attention" - - def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - self.model_config.dtype, self.cache_config.mamba_cache_dtype - ) - - def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: - return MambaStateShapeCalculator.gated_delta_net_state_shape( - self.tp_size, - self.num_k_heads, - self.num_v_heads, - self.head_k_dim, - self.head_v_dim, - self.conv_kernel_size, - self.num_spec, - ) - - def __init__( - self, - config: Qwen3NextConfig, - model_config: ModelConfig | None = None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - speculative_config: SpeculativeConfig | None = None, - prefix: str = "", - ) -> None: - 
super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.hidden_size = config.hidden_size - self.num_v_heads = config.linear_num_value_heads - self.num_k_heads = config.linear_num_key_heads - self.head_k_dim = config.linear_key_head_dim - self.head_v_dim = config.linear_value_head_dim - self.key_dim = self.head_k_dim * self.num_k_heads - self.value_dim = self.head_v_dim * self.num_v_heads - - self.conv_kernel_size = config.linear_conv_kernel_dim - self.layer_idx = extract_layer_index(prefix) - self.activation = config.hidden_act - self.act = ACT2FN[config.hidden_act] - self.layer_norm_epsilon = config.rms_norm_eps - self.prefix = prefix - - self.config = config - self.model_config = model_config - self.cache_config = cache_config - self.quant_config = quant_config - self.speculative_config = speculative_config - self.num_spec = ( - self.speculative_config.num_speculative_tokens - if self.speculative_config - else 0 - ) - - # QKV - self.conv_dim = self.key_dim * 2 + self.value_dim - self.conv1d = ColumnParallelLinear( - input_size=self.conv_kernel_size, - output_size=self.conv_dim, - bias=False, - prefix=f"{prefix}.conv1d", - ) - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - - # projection of the input hidden states - self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 - self.projection_size_ba = self.num_v_heads * 2 - self.in_proj_qkvz = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.projection_size_qkvz, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_qkvz", - ) - # ba_proj doesn't support blockwise fp8 quantization. 
- self.in_proj_ba = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.projection_size_ba, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_ba", - ) - - query_key_settings = (self.key_dim, 0, False) - value_settings = (self.value_dim, 0, False) - - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( - self.conv1d.weight, - { - "weight_loader": mamba_v2_sharded_weight_loader( - [ - query_key_settings, - query_key_settings, - value_settings, - ], - self.tp_size, - self.tp_rank, - ) - }, - ) - - # selective projection used to make dt, B and C input dependant - - # time step projection (discretization) - # instantiate once and copy inv_dt in init_weights of PretrainedModel - self.dt_bias = nn.Parameter( - torch.ones(self.num_v_heads // self.tp_size), - ) - self.A_log = nn.Parameter( - torch.empty( - divide(self.num_v_heads, self.tp_size), - ) - ) - - set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) - set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) - - self.norm = RMSNormGated( - self.head_v_dim, - eps=self.layer_norm_epsilon, - group_size=None, - norm_before_gate=True, - device=current_platform.current_device(), - dtype=config.dtype, - ) - - self.out_proj = RowParallelLinear( - self.value_dim, - self.hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - self.chunk_gated_delta_rule = ChunkGatedDeltaRuleOp( - output_final_state = True, - use_qk_l2norm_in_kernel=True, - ) - - self.fused_recurrent_gated_delta_rule_multi_query = FusedRecurrentGatedDeltaRuleOp( - inplace_final_state=True, - use_qk_l2norm_in_kernel=True, - ) - 
self.fused_recurrent_gated_delta_rule_remain_query = FusedRecurrentGatedDeltaRuleOp( - inplace_final_state=True, - use_qk_l2norm_in_kernel=True, - ) - - def fix_query_key_value_ordering( - self, - mixed_qkvz, - mixed_ba, - ): - """ - Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. - """ - new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( - self.num_k_heads // self.tp_size, - ( - self.head_k_dim - + self.head_k_dim - + (self.head_v_dim + self.head_v_dim) - * self.num_v_heads - // self.num_k_heads - ), - ) - new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( - self.num_k_heads // self.tp_size, - 2 * self.num_v_heads // self.num_k_heads, - ) - - mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) - mixed_ba = mixed_ba.view(*new_tensor_shape_ba) - - split_arg_list_qkvz = [ - self.head_k_dim, - self.head_k_dim, - (self.num_v_heads // self.num_k_heads * self.head_v_dim), - (self.num_v_heads // self.num_k_heads * self.head_v_dim), - ] - split_arg_list_ba = [ - self.num_v_heads // self.num_k_heads, - self.num_v_heads // self.num_k_heads, - ] - - # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] - # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], - # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] - (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) - (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) - - # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] - value = value.reshape(value.size(0), -1, self.head_v_dim) - z = z.reshape(z.size(0), -1, self.head_v_dim) - b = b.reshape(b.size(0), self.num_v_heads // self.tp_size) - a = a.reshape(a.size(0), self.num_v_heads // self.tp_size) - - return query, key, value, z, b, a - - def rearrange_mixed_qkv(self, mixed_qkv): - if mixed_qkv is None: - return None, None, None - query, key, value = torch.split( - mixed_qkv, - [ - self.key_dim // self.tp_size, - self.key_dim // self.tp_size, - self.value_dim // self.tp_size, - ], - dim=-1, - ) - query, key = map( 
- lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), - (query, key), - ) - value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) - return query, key, value - - def forward( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - ): - """ - Forward pass with three parts: - 1. Input projection - 2. Core attention (custom op) - 3. Output projection - """ - num_tokens = hidden_states.size(0) - - # ============================================================ - # Part 1: Input Projection - # ============================================================ - projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) - projected_states_ba, _ = self.in_proj_ba(hidden_states) - query, key, value, z, b, a = self.fix_query_key_value_ordering( - projected_states_qkvz, projected_states_ba - ) - query, key, value = map( - lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value) - ) - mixed_qkv = torch.cat((query, key, value), dim=-1) - - # ============================================================ - # Part 2: Core Attention (Custom Op) - # ============================================================ - # Note: we should not use torch.empty here like other attention backends, - # see discussions in https://github.com/vllm-project/vllm/pull/28182 - core_attn_out = torch.zeros( - (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - torch.ops.vllm.gdn_attention_core( - mixed_qkv, - b, - a, - core_attn_out, - self.prefix, - ) - - # ============================================================ - # Part 3: Output Projection - # ============================================================ - z_shape_og = z.shape - # Reshape input data into 2D tensor - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = 
rearrange(core_attn_out, "... h d -> ... (h d)") - output[:num_tokens], _ = self.out_proj(core_attn_out) - - def _forward_core( - self, - mixed_qkv: torch.Tensor, - b: torch.Tensor, - a: torch.Tensor, - core_attn_out: torch.Tensor, - ): - """ - Core attention computation (called by custom op). - """ - forward_context = get_forward_context() - attn_metadata: AttentionMetadata = forward_context.attn_metadata - - if attn_metadata is None: - # V1 profile run - return - - assert isinstance(attn_metadata, dict) - attn_metadata = attn_metadata[self.prefix] - assert isinstance(attn_metadata, GDNAttentionMetadata) - has_initial_state = attn_metadata.has_initial_state - spec_query_start_loc = attn_metadata.spec_query_start_loc - non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc - spec_sequence_masks = attn_metadata.spec_sequence_masks - spec_token_indx = attn_metadata.spec_token_indx - non_spec_token_indx = attn_metadata.non_spec_token_indx - spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501 - non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501 - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - conv_state = self_kv_cache[0].transpose(-1, -2) - ssm_state = self_kv_cache[1] - num_actual_tokens = attn_metadata.num_actual_tokens - num_accepted_tokens = attn_metadata.num_accepted_tokens - - mixed_qkv = mixed_qkv[:num_actual_tokens] - b = b[:num_actual_tokens] - a = a[:num_actual_tokens] - - # 1. 
Convolution sequence transformation - conv_weights = self.conv1d.weight.view( - self.conv1d.weight.size(0), self.conv1d.weight.size(2) - ) - - if spec_sequence_masks is not None: - if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: - mixed_qkv_spec = mixed_qkv - mixed_qkv_non_spec = None - else: - mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx) - mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx) - else: - mixed_qkv_spec = None - mixed_qkv_non_spec = mixed_qkv - - # 1.1: Process the multi-query part - if spec_sequence_masks is not None: - mixed_qkv_spec = causal_conv1d_update( - mixed_qkv_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=spec_state_indices_tensor[:, 0][ - : attn_metadata.num_spec_decodes - ], - num_accepted_tokens=num_accepted_tokens, - query_start_loc=spec_query_start_loc, - max_query_len=spec_state_indices_tensor.size(-1), - validate_data=False, - ) - - # 1.2: Process the remaining part - if attn_metadata.num_prefills > 0: - mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1) - # - "cache_indices" updates the conv_state cache in positions - # pointed to by "state_indices_tensor" - mixed_qkv_non_spec = causal_conv1d_fn( - mixed_qkv_non_spec_T, - conv_weights, - self.conv1d.bias, - activation=self.activation, - conv_states=conv_state, - has_initial_state=has_initial_state, - cache_indices=non_spec_state_indices_tensor, - query_start_loc=non_spec_query_start_loc, - metadata=attn_metadata, - ).transpose(0, 1) - elif attn_metadata.num_decodes > 0: - mixed_qkv_non_spec = causal_conv1d_update( - mixed_qkv_non_spec, - conv_state, - conv_weights, - self.conv1d.bias, - self.activation, - conv_state_indices=non_spec_state_indices_tensor[ - : attn_metadata.num_actual_tokens - ], - validate_data=True, - ) - else: - mixed_qkv_non_spec = None - - query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) - query_non_spec, key_non_spec, 
value_non_spec = self.rearrange_mixed_qkv( - mixed_qkv_non_spec - ) - - g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias) - - if spec_sequence_masks is not None: - if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: - g_spec = g - beta_spec = beta - g_non_spec = None - beta_non_spec = None - else: - g_spec = g.index_select(1, spec_token_indx) - beta_spec = beta.index_select(1, spec_token_indx) - g_non_spec = g.index_select(1, non_spec_token_indx) - beta_non_spec = beta.index_select(1, non_spec_token_indx) - else: - g_spec = None - beta_spec = None - g_non_spec = g - beta_non_spec = beta - - # 2. Recurrent attention - - # 2.1: Process the multi-query part - if spec_sequence_masks is not None: - core_attn_out_spec, last_recurrent_state = self.fused_recurrent_gated_delta_rule_multi_query( - q=query_spec, - k=key_spec, - v=value_spec, - g=g_spec, - beta=beta_spec, - initial_state=ssm_state, - cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], - ssm_state_indices=spec_state_indices_tensor, - num_accepted_tokens=num_accepted_tokens, - ) - else: - core_attn_out_spec, last_recurrent_state = None, None - - # 2.2: Process the remaining part - if attn_metadata.num_prefills > 0: - initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() - initial_state[~has_initial_state, ...] 
= 0 - ( - core_attn_out_non_spec, - last_recurrent_state, - ) = self.chunk_gated_delta_rule( - q=query_non_spec.contiguous(), - k=key_non_spec.contiguous(), - v=value_non_spec.contiguous(), - g=g_non_spec, - beta=beta_non_spec, - initial_state=initial_state, - cu_seqlens=non_spec_query_start_loc, - ) - # Init cache - ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( - ssm_state.dtype - ) - elif attn_metadata.num_decodes > 0: - core_attn_out_non_spec, last_recurrent_state = ( - self.fused_recurrent_gated_delta_rule_remain_query( - q=query_non_spec, - k=key_non_spec, - v=value_non_spec, - g=g_non_spec, - beta=beta_non_spec, - initial_state=ssm_state, - cu_seqlens=non_spec_query_start_loc[ - : attn_metadata.num_decodes + 1 - ], - ssm_state_indices=non_spec_state_indices_tensor, - ) - ) - else: - core_attn_out_non_spec, last_recurrent_state = None, None - - # 3. Merge core attention output - if spec_sequence_masks is not None and core_attn_out_non_spec is not None: - merged_out = torch.empty( - (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), - dtype=core_attn_out_non_spec.dtype, - device=core_attn_out_non_spec.device, - ) - merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) - merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) - core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) - elif spec_sequence_masks is not None: - core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) - else: - core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) - - -class Qwen3NextAttention(nn.Module): - def __init__( - self, - config: Qwen3NextConfig, - model_config: ModelConfig | None = None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = config.num_attention_heads - 
assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = config.num_key_value_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.dual_chunk_attention_config = getattr( - config, "dual_chunk_attention_config", None - ) - self.attn_output_gate = getattr(config, "attn_output_gate", True) - - self.qkv_proj = QKVParallelLinear( - config.hidden_size, - self.head_dim, - self.total_num_heads * (1 + self.attn_output_gate), - self.total_num_kv_heads, - bias=getattr(config, "qkv_bias", False), - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - - self.rotary_emb = get_rope( - head_size=self.head_dim, - max_position=config.max_position_embeddings, - rope_parameters=config.rope_parameters, - dual_chunk_attention_config=self.dual_chunk_attention_config, - ) - - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - **{ - "layer_idx": extract_layer_index(prefix), - "dual_chunk_attention_config": self.dual_chunk_attention_config, - } - if self.dual_chunk_attention_config - else {}, - ) - - 
self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - output: torch.Tensor, - hidden_states: torch.Tensor, - ): - qkv, _ = self.qkv_proj(hidden_states) - - if self.attn_output_gate: - q_gate, k, v = qkv.split( - [self.q_size * 2, self.kv_size, self.kv_size], dim=-1 - ) - orig_shape = q_gate.shape[:-1] - q_gate = q_gate.view(*orig_shape, self.num_heads, -1) - q, gate = torch.chunk(q_gate, 2, dim=-1) - q = q.reshape(*orig_shape, -1) - gate = gate.reshape(*orig_shape, -1) - else: - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( - -1, self.num_heads * self.head_dim - ) - k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( - -1, self.num_kv_heads * self.head_dim - ) - - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - - if self.attn_output_gate: - gate = torch.sigmoid(gate) - attn_output = attn_output * gate - - output[:], _ = self.o_proj(attn_output) - - -class Qwen3NextDecoderLayer(nn.Module): - def __init__( - self, - vllm_config: VllmConfig, - layer_type: str, - prefix: str = "", - ) -> None: - super().__init__() - - config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - speculative_config = vllm_config.speculative_config - - self.layer_type = layer_type - self.layer_idx = extract_layer_index(prefix) - - if self.layer_type == "linear_attention": - self.linear_attn = Qwen3NextGatedDeltaNet( - config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - speculative_config=speculative_config, - prefix=f"{prefix}.linear_attn", - ) - elif self.layer_type == "full_attention": - self.self_attn = Qwen3NextAttention( - config, - 
model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - else: - raise ValueError(f"Invalid layer_type {self.layer_type}") - - mlp_only_layers = ( - [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers - ) - if (self.layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 - and (self.layer_idx + 1) % config.decoder_sparse_step == 0 - ): - self.mlp = Qwen3NextSparseMoeBlock( - vllm_config=vllm_config, - prefix=f"{prefix}.mlp", - ) - else: - self.mlp = Qwen3NextMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - ) - - self.input_layernorm = Qwen3NextRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = Qwen3NextRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - - self.layer_scale = getattr(config, "layer_scale", False) - if self.layer_scale: - self.attn_layer_scale = torch.nn.Parameter( - torch.zeros( - 1, - 1, - config.hidden_size, - dtype=config.dtype, - ), - ) - self.ffn_layer_scale = torch.nn.Parameter( - torch.zeros( - 1, - 1, - config.hidden_size, - dtype=config.dtype, - ), - ) - - def forward( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor | None, - positions: torch.Tensor = None, - **kwargs: object, - ): - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) - - self_attention_output = torch.empty_like(hidden_states) - if self.layer_type == "linear_attention": - self.linear_attn( - hidden_states=hidden_states, - output=self_attention_output, - ) - elif self.layer_type == "full_attention": - self.self_attn( - hidden_states=hidden_states, - output=self_attention_output, - positions=positions, - ) - else: - raise ValueError("Invalid layer_type") - hidden_states = 
self_attention_output - - if self.layer_scale: - if len(hidden_states.shape) == 2: - hidden_states = hidden_states * ( - self.attn_layer_scale.to(hidden_states.dtype)[0] + 1 - ) - else: - hidden_states = hidden_states * ( - self.attn_layer_scale.to(hidden_states.dtype) + 1 - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - - hidden_states = self.mlp(hidden_states) - - if self.layer_scale: - if len(hidden_states.shape) == 2: - hidden_states = hidden_states * ( - self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1 - ) - else: - assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), ( - f"shape must be the same {len(hidden_states.shape)}, " - f"{len(self.ffn_layer_scale.shape)}" - ) - hidden_states = hidden_states * ( - self.ffn_layer_scale.to(hidden_states.dtype) + 1 - ) - - return hidden_states, residual - - -@support_torch_compile -class Qwen3NextModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config: Qwen3NextConfig = vllm_config.model_config.hf_config - parallel_config = vllm_config.parallel_config - - eplb_config = parallel_config.eplb_config - self.num_redundant_experts = eplb_config.num_redundant_experts - - self.config = config - - self.vocab_size = config.vocab_size - - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - ) - - def get_layer(prefix: str): - return Qwen3NextDecoderLayer( - vllm_config, - layer_type=config.layer_types[extract_layer_index(prefix)], - prefix=prefix, - ) - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" - ) - self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) - - if get_pp_group().is_last_rank: - self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = 
PPMissingLayer() - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.embed_input_ids(input_ids) - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer( - positions=positions, - hidden_states=hidden_states, - residual=residual, - ) - - if not get_pp_group().is_last_rank: - return IntermediateTensors( - {"hidden_states": hidden_states, "residual": residual} - ) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts, - num_redundant_experts=self.num_redundant_experts, - ) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - expert_params_mapping = self.get_expert_mapping() - for name, loaded_weight in weights: - if 
"rotary_emb.inv_freq" in name: - continue - - if name.startswith("mtp."): - continue - - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - - if "mlp.experts" in name: - continue - - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # name = apply_attn_prefix(name, params_dict) - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip layers on other devices. - if is_pp_missing_parameter(name, self): - continue - # Skip loading extra bias for GPTQ models. - if ( - name.endswith(".bias") or name.endswith("_bias") - ) and name not in params_dict: - continue - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( - param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id, - ) - break - else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - if name not in params_dict: - logger.warning_once( - f"Parameter {name} not found in params_dict, skip loading" - ) - continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class QwenNextMixtureOfExperts(MixtureOfExperts): - def update_physical_experts_metadata( - self, - num_physical_experts: int, - num_local_physical_experts: int, - ) -> None: - assert self.num_local_physical_experts == num_local_physical_experts - self.num_physical_experts = num_physical_experts - self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = num_physical_experts - self.num_logical_experts - for layer in self.model.layers: - if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): - moe = layer.mlp - moe.n_local_physical_experts = num_local_physical_experts - moe.n_physical_experts = num_physical_experts - moe.n_redundant_experts = self.num_redundant_experts - moe.experts.update_expert_map() - - def set_moe_parameters(self): - self.expert_weights = [] - - self.moe_layers = [] - example_moe = None - for layer in self.model.layers: - if isinstance(layer, Qwen3NextDecoderLayer) and isinstance( - layer.mlp, Qwen3NextSparseMoeBlock - ): - example_moe = layer.mlp - self.moe_layers.append(layer.mlp.experts) - - if example_moe is None: - raise RuntimeError("No Qwen3Next layer found in the model.layers.") - - # Set MoE hyperparameters - self.num_moe_layers = len(self.moe_layers) - self.num_expert_groups = 1 - self.num_shared_experts = 0 - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - 
self.num_redundant_experts = example_moe.n_redundant_experts - - -class Qwen3NextForCausalLM( - nn.Module, - HasInnerState, - SupportsLoRA, - SupportsPP, - QwenNextMixtureOfExperts, - IsHybrid, -): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": ["gate_proj", "up_proj"], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - config = vllm_config.model_config.hf_config - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - - scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, ( - "Qwen3Next currently does not support prefix caching" - ) - self.quant_config = vllm_config.quant_config - - super().__init__() - self.config = config - self.scheduler_config = scheduler_config - self.model = Qwen3NextModel( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) - - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head"), - ) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) - - # Set MoE hyperparameters - self.set_moe_parameters() - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ): - hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds - ) - - return hidden_states - - @classmethod - def get_mamba_state_dtype_from_config( - cls, - vllm_config: "VllmConfig", - ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - vllm_config.model_config.dtype, 
vllm_config.cache_config.mamba_cache_dtype - ) - - @classmethod - def get_mamba_state_shape_from_config( - cls, vllm_config: "VllmConfig" - ) -> tuple[tuple[int, int], tuple[int, int]]: - parallel_config = vllm_config.parallel_config - hf_config = vllm_config.model_config.hf_config - tp_size = parallel_config.tensor_parallel_size - num_spec = ( - vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config - else 0 - ) - return MambaStateShapeCalculator.gated_delta_net_state_shape( - tp_size, - hf_config.linear_num_key_heads, - hf_config.linear_num_value_heads, - hf_config.linear_key_head_dim, - hf_config.linear_value_head_dim, - hf_config.linear_conv_kernel_dim, - num_spec, - ) - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.logits_processor(self.lm_head, hidden_states) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=["mtp."], - ) - return loader.load_weights(weights) - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() - - -def gdn_attention_core( - mixed_qkv: torch.Tensor, - b: torch.Tensor, - a: torch.Tensor, - core_attn_out: torch.Tensor, - layer_name: str, -) -> None: - """ - Custom op for the core attention computation. - Only handles the convolution + recurrent attention part. - Input/output projections are handled outside this op. 
- """ - forward_context: ForwardContext = get_forward_context() - self = forward_context.no_compile_layers[layer_name] - self._forward_core( - mixed_qkv=mixed_qkv, - b=b, - a=a, - core_attn_out=core_attn_out, - ) - - -def gdn_attention_core_fake( - mixed_qkv: torch.Tensor, - b: torch.Tensor, - a: torch.Tensor, - core_attn_out: torch.Tensor, - layer_name: str, -) -> None: - """Fake implementation for torch.compile.""" - return - - -if not hasattr(torch.ops.vllm, "gdn_attention_core"): - direct_register_custom_op( - op_name="gdn_attention_core", - op_func=gdn_attention_core, - mutates_args=["core_attn_out"], - fake_impl=gdn_attention_core_fake, - ) - - -@triton.jit -def fused_gdn_gating_kernel( - g, - beta_output, - A_log, - a, - b, - dt_bias, - seq_len, - NUM_HEADS: tl.constexpr, - beta: tl.constexpr, - threshold: tl.constexpr, - BLK_HEADS: tl.constexpr, -): - i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) - head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) - off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off - mask = head_off < NUM_HEADS - blk_A_log = tl.load(A_log + head_off, mask=mask) - blk_a = tl.load(a + off, mask=mask) - blk_b = tl.load(b + off, mask=mask) - blk_bias = tl.load(dt_bias + head_off, mask=mask) - # If the model is loaded in fp16, without the .float() here, A might be -inf - x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) - softplus_x = tl.where( - beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x - ) - blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x - tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) - # compute beta_output = sigmoid(b) - blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) - tl.store( - beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask - ) - - -def fused_gdn_gating( - A_log: torch.Tensor, - a: torch.Tensor, - b: torch.Tensor, - dt_bias: torch.Tensor, - beta: float = 1.0, - threshold: float = 20.0, -) -> tuple[torch.Tensor, 
torch.Tensor]: - """ - Fused computation of g and beta for Gated Delta Net. - g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) - beta_output = b.sigmoid() - TODO maybe use torch.compile to replace this triton kernel - """ - batch, num_heads = a.shape - seq_len = 1 - grid = (batch, seq_len, triton.cdiv(num_heads, 8)) - g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) - beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device) - fused_gdn_gating_kernel[grid]( - g, - beta_output, - A_log, - a, - b, - dt_bias, - seq_len, - num_heads, - beta, - threshold, - 8, - num_warps=1, - ) - return g, beta_output diff --git a/vllm_fl/ops/custom_ops.py b/vllm_fl/ops/custom_ops.py index 80ba8f07..0485156a 100644 --- a/vllm_fl/ops/custom_ops.py +++ b/vllm_fl/ops/custom_ops.py @@ -68,11 +68,6 @@ def register_oot_ops(whitelist: Optional[List[str]] = None) -> None: logger.warning(f"OOT op '{op_name}' not found in OOT_OPS, skipping.") continue - # unquantized_fused_moe_method only registers when use_flaggems_op is True - if op_name == "unquantized_fused_moe_method" and not use_flaggems_op(op_name): - logger.debug(f"Skipping '{op_name}': use_flaggems_op returned False") - continue - op_cls, registration_name = OOT_OPS[op_name] logger.info(f"Registering oot op: {op_name} as '{registration_name}'") CustomOp.register_oot(_decorated_op_cls=op_cls, name=registration_name) diff --git a/vllm_fl/ops/fused_moe/fused_moe.py b/vllm_fl/ops/fused_moe/fused_moe.py index b35dc3c5..72465aa9 100644 --- a/vllm_fl/ops/fused_moe/fused_moe.py +++ b/vllm_fl/ops/fused_moe/fused_moe.py @@ -8,7 +8,6 @@ import functools import torch import torch.nn.functional as F -import vllm.envs as envs from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, @@ -17,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( _get_config_quant_dtype, try_get_optimal_moe_config, - 
invoke_fused_moe_kernel, + dispatch_fused_moe_kernel, ) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.triton_utils import tl @@ -109,7 +108,9 @@ def fused_experts_impl( top_k_num = topk_ids.size(1) # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + # Note: upstream removed VLLM_FUSED_MOE_CHUNK_SIZE in v0.18.1; + # keep the old default (64K) for the FL chunked implementation. + CHUNK_SIZE = 65536 M = min(num_tokens, CHUNK_SIZE) config_dtype = _get_config_dtype_str( @@ -209,7 +210,7 @@ def fused_experts_impl( ignore_invalid_experts=True, ) - invoke_fused_moe_kernel( + dispatch_fused_moe_kernel( qcurr_hidden_states, w1, intermediate_cache1, @@ -235,6 +236,9 @@ def fused_experts_impl( # Activation function with multiplication # todo: dispatch to flag_gems and other backends + # v0.18.1: activation may be a MoEActivation enum; normalize to str + if hasattr(activation, "value"): + activation = activation.value if activation == "silu": intermediate_cache2 = call_op( "silu_and_mul", None, intermediate_cache1.view(-1, N) @@ -262,7 +266,7 @@ def fused_experts_impl( block_shape=block_shape, ) - invoke_fused_moe_kernel( + dispatch_fused_moe_kernel( qintermediate_cache2, w2, intermediate_cache3, diff --git a/vllm_fl/ops/fused_moe/layer.py b/vllm_fl/ops/fused_moe/layer.py index 3d3368b3..740dcd59 100644 --- a/vllm_fl/ops/fused_moe/layer.py +++ b/vllm_fl/ops/fused_moe/layer.py @@ -11,8 +11,8 @@ import vllm.envs as envs from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod -from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator -from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk +from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import RoutingSimulator +from 
vllm.model_executor.layers.fused_moe.router.grouped_topk_router import grouped_topk from vllm.platforms import current_platform from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( @@ -21,7 +21,7 @@ if current_platform.is_cuda_alike(): - from vllm.model_executor.layers.fused_moe.fused_moe import ( + from vllm.model_executor.layers.fused_moe.router.base_router import ( eplb_map_to_physical_and_record, ) else: @@ -76,7 +76,30 @@ def forward_oot( else: return result - forward_native = forward_oot + def forward_native( + self, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + """v0.18.1 forward_native signature — called when custom ops are + disabled (e.g. under torch.compile / Inductor). + Note: shared experts are handled by the upstream runner + (_apply_quant_method), so we must NOT return a tuple here.""" + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=layer.activation, + quant_config=self.moe_quant_config, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + ) class FusedMoEFL(FusedMoE): @@ -127,7 +150,7 @@ def select_experts( plain MoE implementations without redundant experts. 
""" from vllm_fl.ops.fused_moe.fused_moe import fused_topk - from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk_bias + from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import fused_topk_bias if self.enable_eplb: if self.quant_method.supports_eplb: diff --git a/vllm_fl/patches/glm_moe_dsa.py b/vllm_fl/patches/glm_moe_dsa.py index 4d205537..df8ec255 100644 --- a/vllm_fl/patches/glm_moe_dsa.py +++ b/vllm_fl/patches/glm_moe_dsa.py @@ -52,25 +52,6 @@ def _patched_set(self, special_tokens=None): pass -def patch_is_deepseek_mla(): - """Patch ModelConfig.is_deepseek_mla to recognise glm_moe_dsa as MLA.""" - from vllm.config.model import ModelConfig - _orig_is_mla = ModelConfig.is_deepseek_mla.fget - - @property - def _patched_is_mla(self): - if ( - hasattr(self.hf_text_config, "model_type") - and self.hf_text_config.model_type == "glm_moe_dsa" - and getattr(self.hf_text_config, "kv_lora_rank", None) - is not None - ): - return True - return _orig_is_mla(self) - - ModelConfig.is_deepseek_mla = _patched_is_mla - - def patch_fp8_mqa_logits_dim(): """Fix k_scale dim mismatch for deep_gemm fp8_mqa_logits. @@ -222,6 +203,5 @@ def _patched_forward(self, hidden_states, qr, positions, rotary_emb): def apply_model_patches(): """All GLM-5 patches needed at model registration time.""" - patch_is_deepseek_mla() patch_indexer_schedule_metadata() patch_indexer_rope_reshape() diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index 8665202e..45882cc8 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -1,11 +1,11 @@ # Copyright (c) 2025 BAAI. All rights reserved. 
-# Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/platforms/cuda.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/platforms/cuda.py # Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import TYPE_CHECKING, Optional, TypeVar +from typing import TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec import torch @@ -16,20 +16,20 @@ except (ImportError, OSError): pass # NPU or other platforms may not have vllm._C -from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.platforms import Platform, PlatformEnum from vllm.platforms.interface import DeviceCapability +from vllm.v1.attention.backends.registry import AttentionBackendEnum if TYPE_CHECKING: - from vllm.attention.selector import AttentionSelectorConfig from vllm.config import VllmConfig from vllm.config.cache import CacheDType + from vllm.v1.attention.selector import AttentionSelectorConfig else: VllmConfig = None CacheDType = None -from vllm_fl.utils import DeviceInfo, get_device_name, get_device_type +from vllm_fl.utils import DeviceInfo logger = init_logger(__name__) @@ -46,8 +46,12 @@ class PlatformFL(Platform): _enum = PlatformEnum.OOT device_info = DeviceInfo() vendor_name = device_info.vendor_name - device_type = get_device_type(vendor_name) - device_name = get_device_name(vendor_name) + # cuda_alike (nvidia/metax): device_name = vendor_name (not used in torch.device) + # non-cuda_alike (iluvatar/ascend): device_name = device_type (used in torch.device) + device_name = device_info.vendor_name if ( + device_info.device_type == "cuda" and device_info.vendor_name != "iluvatar" + ) else device_info.device_type + device_type = device_info.device_type dispatch_key = device_info.dispatch_key torch_device_fn = device_info.torch_device_fn ray_device_key: str = "GPU" @@ -80,7 +84,7 @@ def 
check_if_supports_dtype(cls, torch_dtype: torch.dtype): @classmethod def get_current_memory_usage( - cls, device: Optional[torch.types.Device] = None + cls, device: torch.types.Device | None = None ) -> float: cls.torch_device_fn.empty_cache() cls.torch_device_fn.reset_peak_memory_stats(device) @@ -112,6 +116,9 @@ def is_pin_memory_available(cls): def import_kernels(cls) -> None: """Import device-specific kernels.""" logger.info(f"current vendor_name is: {cls.vendor_name}") + # Always load base vLLM C extensions + super().import_kernels() + if cls.vendor_name == "metax": try: import mcoplib._C # noqa: F401 @@ -198,9 +205,10 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: @classmethod def get_attn_backend_cls( cls, - selected_backend: "AttentionBackendEnum", + selected_backend: "AttentionBackendEnum | None", attn_selector_config: "AttentionSelectorConfig", - ) -> list[str]: + num_heads: int | None = None, + ) -> str: """Get the attention backend class path using the dispatch mechanism.""" from vllm_fl.dispatch import call_op @@ -234,8 +242,8 @@ def get_vit_attn_backend( cls, head_size: int, dtype: torch.dtype, - backend: Optional["AttentionBackendEnum"] = None, - ) -> list[str]: + backend: "AttentionBackendEnum | None" = None, + ) -> "AttentionBackendEnum": from vllm_fl.attention.utils import patch_mm_encoder_attention patch_mm_encoder_attention() @@ -324,10 +332,33 @@ def use_custom_allreduce(cls) -> bool: return True @classmethod - def pre_register_and_update(cls, parser = None) -> None: + def pre_register_and_update(cls, parser=None) -> None: if cls.device_name == "npu": import vllm_fl.dispatch.backends.vendor.ascend + @classmethod + def supports_fp8(cls) -> bool: + return cls.has_device_capability(89) + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + if cls.device_type == "npu": + return cls.torch_device_fn.get_device_properties( + device_id + ).total_memory + device_props = 
torch.cuda.get_device_properties(device_id) + return device_props.total_memory + + @classmethod + def use_custom_op_collectives(cls) -> bool: + if cls.vendor_name == "nvidia": + return True + return False + + @classmethod + def num_compute_units(cls, device_id: int = 0) -> int: + return torch.cuda.get_device_properties(device_id).multi_processor_count + @classmethod def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: diff --git a/vllm_fl/utils.py b/vllm_fl/utils.py index 198da881..2b958cad 100644 --- a/vllm_fl/utils.py +++ b/vllm_fl/utils.py @@ -11,64 +11,6 @@ _OP_CONFIG: Optional[dict[str, str]] = None -# Mapping used by dispatch registration to resolve the current runtime platform -# into a backend directory under dispatch/backends/vendor. -# -# Field definitions and sources: -# - top-level key (vendor_name): normalized hardware vendor identifier. -# Source: runtime platform detection (current_platform.vendor_name) and -# fallback device probing (DeviceInfo.vendor_name). -# - device_type: compute class reported by runtime, such as "cuda" or "npu". -# Source: runtime platform detection (current_platform.device_type) and -# fallback device probing (DeviceInfo.device_type). -# - device_name: runtime device family/product alias used by vLLM platform. -# Source: runtime platform detection (current_platform.device_name). -# -# Values are normalized to lowercase and matched against available backend -# subdirectories (for example, cuda/ascend/metax/iluvatar). 
-VENDOR_DEVICE_MAP: dict[str, dict[str, str]] = { - # Registered backend: vendor/cuda - "nvidia": {"device_type": "cuda", "device_name": "nvidia"}, - # Registered backend: vendor/ascend - "ascend": {"device_type": "npu", "device_name": "npu"}, - # Registered backend: vendor/iluvatar - "iluvatar": {"device_type": "cuda", "device_name": "cuda"}, - # Registered backend: vendor/metax - "metax": {"device_type": "cuda", "device_name": "metax"}, -} - - -def _get_vendor_device_field(vendor_name: str, field: str) -> str: - """Get a required field from VENDOR_DEVICE_MAP for the specified vendor.""" - if not isinstance(vendor_name, str) or not vendor_name.strip(): - raise ValueError("vendor_name must be a non-empty string.") - - normalized_vendor = vendor_name - device_info = VENDOR_DEVICE_MAP.get(normalized_vendor) - if not isinstance(device_info, dict): - raise ValueError( - f"Vendor '{normalized_vendor}' not found in VENDOR_DEVICE_MAP." - ) - - value = device_info.get(field) - if not isinstance(value, str) or not value.strip(): - raise ValueError( - f"Field '{field}' for vendor '{normalized_vendor}' is missing " - "or empty in VENDOR_DEVICE_MAP." 
- ) - return value - - -def get_device_type(vendor_name: str) -> str: - """Return the configured device_type for the given vendor.""" - return _get_vendor_device_field(vendor_name, "device_type") - - -def get_device_name(vendor_name: str) -> str: - """Return the configured device_name for the given vendor.""" - return _get_vendor_device_field(vendor_name, "device_name") - - def use_flaggems(default: bool = True) -> bool: if os.environ.get("VLLM_FL_PREFER_ENABLED", "True").lower() not in ("true", "1"): return False diff --git a/vllm_fl/worker/model_runner.py b/vllm_fl/worker/model_runner.py index 155c0316..bb32fb24 100644 --- a/vllm_fl/worker/model_runner.py +++ b/vllm_fl/worker/model_runner.py @@ -1,19 +1,20 @@ -# Mainly adopted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/v1/worker/gpu_model_runner.py +# Copyright (c) 2025 BAAI. All rights reserved. +# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/v1/worker/gpu_model_runner.py # Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import functools import gc import itertools +import threading import time from collections import defaultdict -from collections.abc import Iterator, Sequence -from contextlib import contextmanager, nullcontext +from collections.abc import Callable, Iterable, Iterator, Sequence +from contextlib import contextmanager from copy import copy, deepcopy +from dataclasses import dataclass, replace from functools import reduce -from itertools import product from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast import numpy as np @@ -23,22 +24,18 @@ from tqdm import tqdm import vllm.envs as envs -from vllm.attention.backends.abstract import( - AttentionBackend, - AttentionMetadata, - AttentionType, - MultipleOf) -from vllm.attention.layer import Attention, MLAAttention from vllm.compilation.counter import compilation_counter -from 
vllm.compilation.cuda_graph import CUDAGraphStat #CUDAGraphWrapper +from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import ( CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, + set_current_vllm_config, update_config, ) +from vllm.config.cache import CacheConfig from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group @@ -47,22 +44,32 @@ get_dcp_group, get_pp_group, get_tp_group, + graph_capture, is_global_first_rank, prepare_communication_buffer_for_model, - GraphCaptureContext ) from vllm.forward_context import ( BatchDescriptor, set_forward_context, ) from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping, LoRAMappingType +from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( + RoutedExpertsCapturer, +) from vllm.model_executor.layers.rotary_embedding import ( MRotaryEmbedding, XDRotaryEmbedding, ) -from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader +from vllm.model_executor.model_loader import get_model_loader +from vllm.model_executor.model_loader.reload import ( + finalize_layerwise_reload, + initialize_layerwise_reload, +) from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal, SupportsXDRoPE, @@ -70,6 +77,7 @@ supports_eagle3, supports_mrope, supports_multimodal_pruning, + supports_realtime, supports_transcription, supports_xdrope, ) @@ -78,75 +86,50 @@ is_pooling_model, is_text_generation_model, ) +from vllm.model_executor.offloader import ( + create_offloader, + get_offloader, + set_offloader, +) from 
vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.encoder_budget import MultiModalBudget from vllm.multimodal.inputs import ( BatchedTensorInputs, MultiModalKwargsItem, PlaceholderRange, ) -from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.multimodal.utils import group_and_batch_mm_kwargs from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask +from vllm.tracing import instrument from vllm.utils import length_from_prompt_token_ids_or_embeds -from vllm.utils.jsontree import json_map_leaves from vllm.utils.math_utils import cdiv, round_up -from vllm.utils.mem_constants import GiB_bytes -from vllm.utils.mem_utils import DeviceMemoryProfiler +from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.nvtx_pytorch_hooks import PytHooks -from vllm.utils.platform_utils import is_pin_memory_available - -from vllm.platforms import current_platform -if current_platform.dist_backend == "flagcx": - @contextmanager - def graph_capture(device: torch.device): - """ - `graph_capture` is a context manager which should surround the code that - is capturing the NPU graph. Its main purpose is to ensure that the - some operations will be run after the graph is captured, before the graph - is replayed. It returns a `GraphCaptureContext` object which contains the - necessary data for the graph capture. Currently, it only contains the - stream that the graph capture is running on. This stream is set to the - current NPU stream when the context manager is entered and reset to the - default stream when the context manager is exited. 
This is to ensure that - the graph capture is running on a separate stream from the default stream, - in order to explicitly distinguish the kernels to capture - from other kernels possibly launched on background in the default stream. - """ - graph_capture_context = GraphCaptureContext( - current_platform.torch_device_fn.Stream(device=device)) - stream = graph_capture_context.stream - - # we use nullcontext now - maybe_ca_context = nullcontext() - - # ensure all initialization operations complete before attempting to - # capture the graph on another stream - curr_stream = current_platform.torch_device_fn.current_stream() - if curr_stream != stream: - stream.wait_stream(curr_stream) - - with current_platform.torch_device_fn.stream(stream), maybe_ca_context: - yield graph_capture_context -else: - from vllm.distributed.parallel_state import graph_capture +from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units from vllm.utils.torch_utils import ( get_dtype_size, kv_cache_dtype_str_to_dtype, - supports_dynamo, ) -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder -from vllm.v1.attention.backends.utils import ( +from vllm.v1.attention.backend import ( + AttentionBackend, AttentionCGSupport, + AttentionMetadata, AttentionMetadataBuilder, + AttentionType, CommonAttentionMetadata, +) +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( create_fast_prefill_custom_backend, get_dcp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills, - split_attn_metadata, ) +from vllm.v1.core.sched.output import NewRequestData from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import ( AttentionSpec, @@ -180,16 +163,29 @@ def graph_capture(device: torch.device): from vllm.v1.sample.metadata import SamplingMetadata from 
vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.ngram_proposer_gpu import ( + NgramProposerGPU, + copy_num_valid_draft_tokens, + update_ngram_gpu_tensors_incremental, + update_scheduler_for_invalid_drafts, +) from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer +from vllm.v1.spec_decode.utils import update_num_computed_tokens_for_batch_change from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext -from vllm.v1.worker.cp_utils import check_attention_cp_compatibility +from vllm.v1.worker import mamba_utils +from vllm.v1.worker.cp_utils import ( + check_attention_cp_compatibility, + get_total_cp_world_size, +) from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin +from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin @@ -198,22 +194,28 @@ def graph_capture(device: torch.device): UBatchSlices, check_ubatch_thresholds, maybe_create_ubatch_slices, + split_attn_metadata, ) from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.workspace import lock_workspace from vllm.v1.worker.utils import ( AttentionGroup, - MultiModalBudget, + KVBlockZeroer, 
add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, + prepare_kernel_block_sizes, sanity_check_mm_encoder_outputs, ) if TYPE_CHECKING: - from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput + from vllm.v1.spec_decode.ngram_proposer import NgramProposer + from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager + +logger = init_logger(__name__) +# FL-specific imports from vllm_fl.compilation.graph import GraphWrapper from vllm_fl.dispatch.io_common import managed_inference_mode from vllm_fl.dispatch.io_dumper import ( @@ -221,9 +223,7 @@ def graph_capture(device: torch.device): init_io_dump_from_env, register_io_module_hooks, ) - -logger = init_logger(__name__) - +CUDAGraphWrapper = GraphWrapper AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled @@ -238,7 +238,7 @@ def __init__( sampled_token_ids: torch.Tensor, logprobs_tensors: LogprobsTensors | None, invalid_req_indices: list[int], - async_output_copy_stream: current_platform.torch_device_fn.Stream | None, + async_output_copy_stream: torch.cuda.Stream, vocab_size: int, ): self._model_runner_output = model_runner_output @@ -254,11 +254,11 @@ def __init__( self._logprobs_tensors = logprobs_tensors # Initiate the copy on a separate stream, but do not synchronize it. 
- default_stream = current_platform.torch_device_fn.current_stream() - with current_platform.torch_device_fn.stream(async_output_copy_stream): + default_stream = torch.cuda.current_stream() + with torch.cuda.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) self.sampled_token_ids_cpu = self._sampled_token_ids.to( - 'cpu', non_blocking=True + "cpu", non_blocking=True ) self._logprobs_tensors_cpu = ( self._logprobs_tensors.to_cpu_nonblocking() @@ -282,22 +282,106 @@ def get_output(self) -> ModelRunnerOutput: valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist() for i in self._invalid_req_indices: valid_sampled_token_ids[i].clear() - cu_num_tokens = None + logprobs_lists = None + if self._logprobs_tensors_cpu is not None: + logprobs_lists = self._logprobs_tensors_cpu.tolists() else: - valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( + valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output( self.sampled_token_ids_cpu, self.vocab_size, self._invalid_req_indices, - return_cu_num_tokens=self._logprobs_tensors_cpu is not None, + logprobs_tensors=self._logprobs_tensors_cpu, ) output = self._model_runner_output output.sampled_token_ids = valid_sampled_token_ids - if self._logprobs_tensors_cpu: - output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens) + output.logprobs = logprobs_lists return output +def _copy_pooler_output_to_cpu( + raw_pooler_output: PoolerOutput, finished_mask: list[bool] +) -> list[torch.Tensor | None]: + num_reqs = len(finished_mask) + + if isinstance(raw_pooler_output, torch.Tensor): + if raw_pooler_output.shape[0] != num_reqs: + raise ValueError( + "Pooler output batch size does not match finished mask size: " + f"{raw_pooler_output.shape[0]} != {num_reqs}." 
+ ) + + num_finished = sum(finished_mask) + if num_finished == 0: + return [None] * num_reqs + if num_finished == num_reqs: + return list(raw_pooler_output.to("cpu", non_blocking=True)) + + # partial finished + finished_indices = [i for i, include in enumerate(finished_mask) if include] + index_tensor = torch.tensor( + finished_indices, device=raw_pooler_output.device, dtype=torch.long + ) + finished_outputs = raw_pooler_output.index_select(0, index_tensor).to( + "cpu", non_blocking=True + ) + partial_pooler_output: list[torch.Tensor | None] = [None] * num_reqs + for i, out in zip(finished_indices, finished_outputs): + partial_pooler_output[i] = out + return partial_pooler_output + + assert isinstance(raw_pooler_output, list) + if len(raw_pooler_output) != num_reqs: + raise ValueError( + "Pooler output batch size does not match finished mask size: " + f"{len(raw_pooler_output)} != {num_reqs}." + ) + + pooler_output: list[torch.Tensor | None] = [None] * num_reqs + for i, (out, include) in enumerate(zip(raw_pooler_output, finished_mask)): + if include and out is not None: + pooler_output[i] = out.to("cpu", non_blocking=True) + return pooler_output + + +class AsyncGPUPoolingModelRunnerOutput(AsyncModelRunnerOutput): + def __init__( + self, + model_runner_output: ModelRunnerOutput, + raw_pooler_output: PoolerOutput, + finished_mask: list[bool], + async_output_copy_stream: torch.cuda.Stream, + ): + self._model_runner_output = model_runner_output + + # Event on the copy stream so we can synchronize the non-blocking copy. + self.async_copy_ready_event = torch.Event() + + # Keep a reference to the device tensors to avoid them being + # deallocated until we finish copying it to the host. + self._raw_pooler_output = raw_pooler_output + + # Initiate the copy on a separate stream, but do not synchronize it. 
+ default_stream = torch.cuda.current_stream() + with torch.cuda.stream(async_output_copy_stream): + async_output_copy_stream.wait_stream(default_stream) + self._model_runner_output.pooler_output = _copy_pooler_output_to_cpu( + raw_pooler_output=self._raw_pooler_output, + finished_mask=finished_mask, + ) + self.async_copy_ready_event.record() + + def get_output(self) -> ModelRunnerOutput: + """Copy the device tensors to the host and return a ModelRunnerOutput. + This function blocks until the copy is finished. + """ + self.async_copy_ready_event.synchronize() + + # Release the device tensors once the copy has completed. + del self._raw_pooler_output + return self._model_runner_output + + class ExecuteModelState(NamedTuple): """Ephemeral cached state transferred between execute_model() and sample_tokens(), after execute_model() returns None.""" @@ -311,6 +395,7 @@ class ExecuteModelState(NamedTuple): aux_hidden_states: list[torch.Tensor] | None ec_connector_output: ECConnectorOutput | None cudagraph_stats: CUDAGraphStat | None + slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None class ModelRunnerFL( @@ -324,6 +409,7 @@ def __init__( self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config + self.offload_config = vllm_config.offload_config self.compilation_config = vllm_config.compilation_config self.lora_config = vllm_config.lora_config self.load_config = vllm_config.load_config @@ -332,10 +418,6 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.observability_config = vllm_config.observability_config - from vllm.model_executor.models.utils import set_cpu_offload_max_bytes - - set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) - model_config = self.model_config cache_config = self.cache_config scheduler_config = self.scheduler_config @@ -353,8 +435,12 @@ def __init__( self.is_multimodal_raw_input_only_model = ( 
model_config.is_multimodal_raw_input_only_model ) - # This will be overridden in load_model() + # These will be overridden in load_model() self.is_multimodal_pruning_enabled = False + self.requires_sequential_video_encoding = False + # Set to True after init_routed_experts_capturer() completes. + # Prevents routed experts code from running during profiling/dummy run. + self.routed_experts_initialized = False self.max_model_len = model_config.max_model_len # Always set to false after the first forward pass @@ -366,11 +452,11 @@ def __init__( # Broadcast PP output for external_launcher (torchrun) # to make sure we are synced across pp ranks - # TODO: Support overlapping mirco-batches + # TODO: Support overlapping micro-batches # https://github.com/vllm-project/vllm/issues/18019 self.broadcast_pp_output = ( self.parallel_config.distributed_executor_backend == "external_launcher" - and len(get_pp_group().ranks) > 0 + and len(get_pp_group().ranks) > 1 ) # Model-related. @@ -398,10 +484,15 @@ def __init__( else: self.max_encoder_len = 0 + # Async scheduling + self.use_async_scheduling = self.scheduler_config.async_scheduling + # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.eplb_state: EplbState | None = None + # NOTE(yongji): flag to temporarily disable EPLB during scaling up/down + self.eep_eplb_suppressed = False """ State of the expert parallelism load balancer. @@ -421,6 +512,10 @@ def __init__( # mm_hash -> encoder_output self.encoder_cache: dict[str, torch.Tensor] = {} + self.late_interaction_runner = LateInteractionRunner() + + # Encoder CUDA graph manager (initialized after model load if enabled) + self.encoder_cudagraph_manager: EncoderCudaGraphManager | None = None self.use_aux_hidden_state_outputs = False # Set up speculative decoding. @@ -429,10 +524,41 @@ def __init__( # layers in the draft model. 
if self.speculative_config and get_pp_group().is_last_rank: self.drafter: ( - NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer + NgramProposer # noqa: F823 + | NgramProposerGPU + | SuffixDecodingProposer + | EagleProposer + | DraftModelProposer + | MedusaProposer + | ExtractHiddenStatesProposer ) if self.speculative_config.method == "ngram": + from vllm.v1.spec_decode.ngram_proposer import NgramProposer + self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.uses_draft_model(): + self.drafter = DraftModelProposer( + vllm_config=self.vllm_config, + device=self.device, + runner=self, + ) + elif self.speculative_config.use_ngram_gpu(): + self.drafter = NgramProposerGPU(self.vllm_config, self.device, self) + self.num_tokens_no_spec_gpu = torch.zeros( + self.max_num_reqs, dtype=torch.int32, device=device + ) + self.token_ids_gpu_tensor = torch.zeros( + self.max_num_reqs, + self.max_model_len, + dtype=torch.int32, + device=device, + ) + self._ngram_pinned_idx_buf = torch.zeros( + self.max_num_reqs, dtype=torch.long, pin_memory=True + ) + self._ngram_pinned_val_buf = torch.zeros( + self.max_num_reqs, dtype=torch.int32, pin_memory=True + ) elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): @@ -445,22 +571,37 @@ def __init__( self.drafter = MedusaProposer( vllm_config=self.vllm_config, device=self.device ) + elif self.speculative_config.method == "extract_hidden_states": + self.drafter = ExtractHiddenStatesProposer( + vllm_config=self.vllm_config, device=self.device + ) + self.use_aux_hidden_state_outputs = True else: raise ValueError( "Unknown speculative decoding method: " f"{self.speculative_config.method}" ) self.rejection_sampler = RejectionSampler(self.sampler) + self.num_spec_tokens = 0 + self.valid_sampled_token_count_gpu: torch.Tensor | None = None if self.speculative_config: self.num_spec_tokens = 
self.speculative_config.num_speculative_tokens + draft_config = self.speculative_config.draft_model_config + if draft_config is not None and draft_config.max_model_len is not None: + self.effective_drafter_max_model_len = draft_config.max_model_len + else: + self.effective_drafter_max_model_len = self.max_model_len + self.use_async_spec_decode = ( + self.use_async_scheduling and self.num_spec_tokens > 0 + ) # Request states. self.requests: dict[str, CachedRequestState] = {} # NOTE(rob): num_prompt_logprobs only includes reqs # that are currently in the prefill phase. self.num_prompt_logprobs: dict[str, int] = {} - self.comm_stream = current_platform.torch_device_fn.Stream() + self.comm_stream = torch.cuda.Stream() # Input Batch # NOTE(Chen): Ideally, we should initialize the input batch inside @@ -475,17 +616,22 @@ def __init__( custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = ( tuple(logits_processors) if logits_processors is not None else () ) + placeholder_block_size = ( + self.cache_config.block_size or CacheConfig.DEFAULT_BLOCK_SIZE + ) + self._init_block_sizes = [placeholder_block_size] + self._init_kernel_block_sizes = [placeholder_block_size] self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - # We need to use the encoder length for encoder-decoer + # We need to use the encoder length for encoder-decoder # because of KV cache for cross-attention. 
max_model_len=max(self.max_model_len, self.max_encoder_len), max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), - block_sizes=[self.cache_config.block_size], - kernel_block_sizes=[self.cache_config.block_size], + block_sizes=[placeholder_block_size], + kernel_block_sizes=[placeholder_block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( self.vllm_config, @@ -496,18 +642,22 @@ def __init__( ), # We currently don't know whether a particular custom logits processor # uses output token ids so we set this conservatively. - logitsprocs_need_output_token_ids=bool(custom_logitsprocs), + # ThinkingTokenBudgetLogitsProcessor also needs output token ids to + # correctly track think start/end token sequences in async scheduling. + logitsprocs_need_output_token_ids=bool(custom_logitsprocs) + or self.vllm_config.reasoning_config is not None, is_pooling_model=self.is_pooling_model, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, ) - self.use_async_scheduling = self.scheduler_config.async_scheduling - self.async_output_copy_stream = current_platform.torch_device_fn.Stream() if \ - self.use_async_scheduling else None + # Separate cuda stream for overlapping transfer of sampled token ids from + # GPU to CPU when async scheduling is enabled. + self.async_output_copy_stream: torch.cuda.Stream | None = None # cuda event to synchronize use of reused CPU tensors between steps # when async scheduling is enabled. self.prepare_inputs_event: torch.Event | None = None if self.use_async_scheduling: + self.async_output_copy_stream = torch.cuda.Stream() self.prepare_inputs_event = torch.Event() # self.cudagraph_batch_sizes sorts in ascending order. 
@@ -518,13 +668,43 @@ def __init__( self.cudagraph_batch_sizes = sorted( self.compilation_config.cudagraph_capture_sizes ) + else: + self.cudagraph_batch_sizes = [] + + # Cache the device properties. + self._init_device_properties() + + # Encoder timing registry for observability + self.encoder_timing_registry: dict[str, EncoderTimingStats] = {} + self._encoder_timing_lock = threading.Lock() + # Persistent buffers for CUDA graphs. self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) - self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + self.positions = torch.zeros( + self.max_num_tokens, dtype=torch.int64, device=self.device + ) self.query_start_loc = self._make_buffer( self.max_num_reqs + 1, dtype=torch.int32 ) - self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.seq_lens = torch.zeros( + self.max_num_reqs, dtype=torch.int32, device=self.device + ) + self.optimistic_seq_lens_cpu = torch.zeros( + self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory + ) + self.num_computed_tokens = torch.zeros( + self.max_num_reqs, dtype=torch.int32, device=self.device + ) + self.prev_num_draft_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self.req_indices = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + # Maps current batch position -> previous batch position (-1 for new reqs) + self.prev_positions = self._make_buffer(self.max_num_reqs, dtype=torch.int64) + self.num_scheduled_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) if self.dcp_world_size > 1: self.dcp_local_seq_lens = self._make_buffer( @@ -544,11 +724,18 @@ def __init__( self.max_num_reqs, dtype=torch.int32 ) self.num_accepted_tokens = self._make_buffer( - self.max_num_reqs, dtype=torch.int64 + self.max_num_reqs, dtype=torch.int32 ) + # Only relevant for multimodal models if 
self.supports_mm_inputs: - self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + # Double buffer to avoid race condition: previous iteration's async + # copy may still be reading from CPU while current iteration writes. + self.is_mm_embed_buffers = [ + self._make_buffer(self.max_num_tokens, dtype=torch.bool), + self._make_buffer(self.max_num_tokens, dtype=torch.bool), + ] + self.is_mm_embed_idx = 0 # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -576,12 +763,14 @@ def __init__( # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: IntermediateTensors | None = None - # OPTIMIZATION: Cache the tensors rather than creating them every step. - # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange( - max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), - dtype=np.int64, - ) + # OPTIMIZATION: Cache the arange tensors rather than creating them + # every step. Keep in int64 to avoid overflow with long context. + # - arange_np: immutable [0, 1, 2, ...] used as source for batched computation + # - query_pos: CpuGpuBuffer for the computed batched arange result + arange_size = max(self.max_num_reqs + 1, self.max_num_tokens) + self.arange_np = np.arange(arange_size, dtype=np.int64) + self.query_pos = self._make_buffer(arange_size, dtype=torch.int64) + self._arange_scratch = np.empty(arange_size, dtype=np.int64) # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -602,11 +791,7 @@ def __init__( self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) self.mm_budget = ( - MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) + MultiModalBudget(self.vllm_config, self.mm_registry) if self.supports_mm_inputs else None ) @@ -620,37 +805,96 @@ def __init__( # Cached outputs. 
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + # N-gram GPU path: async D2H buffer/event for per-request valid draft counts. + self._num_valid_draft_tokens: torch.Tensor | None = None + self._num_valid_draft_tokens_cpu: torch.Tensor | None = None + self._num_valid_draft_tokens_event: torch.cuda.Event | None = None + self._num_valid_draft_tokens_copy_stream: torch.cuda.Stream | None = None + if ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ): + self._num_valid_draft_tokens_cpu = torch.empty( + self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory + ) + self._num_valid_draft_tokens_event = torch.cuda.Event() + self._num_valid_draft_tokens_copy_stream = torch.cuda.Stream() + + self._draft_token_req_ids: list[str] | None = None self.transfer_event = torch.Event() - # TODO(yxa): NPU uses int32, CUDA uses int64 for sampled token ids - sampled_ids_dtype = torch.int32 if current_platform.device_type == "npu" else torch.int64 self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_num_reqs, 1), - dtype=sampled_ids_dtype, + dtype=torch.int64, device="cpu", pin_memory=self.pin_memory, ) + # Pre-allocated tensor for copying valid sampled token counts to CPU, # with dedicated stream for overlapping and event for coordination. self.valid_sampled_token_count_event: torch.Event | None = None self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None - if self.use_async_scheduling and self.num_spec_tokens: - self.valid_sampled_token_count_event = torch.Event() - self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() - self.valid_sampled_token_count_cpu = torch.empty( - self.max_num_reqs, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory, - ) + # We also copy the drafted tokens to the CPU asynchronously, + # in case we need them for structured outputs. 
+ self.draft_token_ids_event: torch.Event | None = None + self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None + self.valid_sampled_token_count_cpu: torch.Tensor | None = None + self.draft_token_ids_cpu: torch.Tensor | None = None + self.num_accepted_tokens_event: torch.Event | None = None + if self.num_spec_tokens: + self.draft_token_ids_event = torch.Event() + self.num_accepted_tokens_event = torch.Event() + self.draft_token_ids_copy_stream = torch.cuda.Stream() + self.draft_token_ids_cpu = torch.empty( + (self.max_num_reqs, self.num_spec_tokens), + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory, + ) + if self.use_async_scheduling: + self.valid_sampled_token_count_event = torch.Event() + self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() + self.valid_sampled_token_count_cpu = torch.empty( + self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) + + # Model weight offloader + # Make sure this is called before any get_offloader call + set_offloader(create_offloader(self.offload_config)) # Ephemeral state transferred between execute_model() and sample_tokens(). self.execute_model_state: ExecuteModelState | None = None self.kv_connector_output: KVConnectorOutput | None = None + self.mamba_state_idx: dict[str, int] = {} + self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None self.layerwise_nvtx_hooks_registered = False + def update_max_model_len(self, max_model_len: int) -> None: + self.max_model_len = max_model_len + if self.speculative_config: + draft_config = self.speculative_config.draft_model_config + if draft_config is None or draft_config.max_model_len is None: + self.effective_drafter_max_model_len = self.max_model_len + def reset_mm_cache(self) -> None: + """ + Clear the multi-modal cache that was used during profiling, + but no longer needed during inference. 
+ """ if self.mm_budget: self.mm_budget.reset_cache() + self.late_interaction_runner.clear() + + def reset_encoder_cache(self) -> None: + """Clear the GPU-side encoder cache storing vision embeddings. + + This should be called when model weights are updated to ensure + stale embeddings computed with old weights are not reused. + """ + self.encoder_cache.clear() + self.late_interaction_runner.clear() @managed_inference_mode() def init_fp8_kv_scales(self) -> None: @@ -702,13 +946,13 @@ def _get_positions(self, num_tokens: Any): return self.mrope_positions.gpu[:, :num_tokens] if self.uses_xdrope_dim > 0: return self.xdrope_positions.gpu[:, :num_tokens] - return self.positions.gpu[:num_tokens] + return self.positions[:num_tokens] else: if self.uses_mrope: return self.mrope_positions.gpu[:, num_tokens] if self.uses_xdrope_dim > 0: return self.xdrope_positions.gpu[:, num_tokens] - return self.positions.gpu[num_tokens] + return self.positions[num_tokens] def _make_buffer( self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True @@ -721,7 +965,17 @@ def _make_buffer( with_numpy=numpy, ) - def _init_model_kwargs(self, num_tokens: int): + def _get_mamba_copy_bufs(self) -> mamba_utils.MambaCopyBuffers: + if self._mamba_copy_bufs is None: + self._mamba_copy_bufs = mamba_utils.MambaCopyBuffers.create( + self.max_num_reqs, + self.kv_cache_config, + self.model.get_mamba_state_copy_func(), + self._make_buffer, + ) + return self._mamba_copy_bufs + + def _init_model_kwargs(self): model_kwargs = dict[str, Any]() if not self.is_pooling_model: @@ -742,7 +996,7 @@ def _init_model_kwargs(self, num_tokens: int): if len(token_type_id_requests) == 0: return model_kwargs - seq_lens = self.seq_lens.gpu[:num_reqs] + seq_lens = self.seq_lens[:num_reqs] token_type_ids = [] for i in range(num_reqs): @@ -765,7 +1019,7 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: Args: scheduler_output: The scheduler output. 
""" - # Attention free models have zero kv_cache_goups, however models + # Attention free models have zero kv_cache_groups, however models # like Mamba are also attention free but use the kv_cache for # keeping its internal state. This is why we check the number # of kv_cache groups instead of solely checking @@ -780,11 +1034,37 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: decode_threshold=self.reorder_batch_threshold, ) + def _init_kv_zero_meta(self) -> None: + """One-time precomputation for _zero_block_ids. + + Delegates to KVBlockZeroer.init_meta with the runner's state. + Called from gpu_worker.py outside the CuMem pool context. + """ + self._kv_block_zeroer = KVBlockZeroer(self.device, self.pin_memory) + self._kv_block_zeroer.init_meta( + attn_groups_iter=self._kv_cache_spec_attn_group_iterator(), + kernel_block_sizes=self._kernel_block_sizes, + cache_dtype=self.cache_config.cache_dtype, + runner_only_attn_layers=self.runner_only_attn_layers, + static_forward_context=(self.compilation_config.static_forward_context), + ) + + def _zero_block_ids(self, block_ids: list[int]) -> None: + """Zero the KV cache memory for the given block IDs.""" + if hasattr(self, "_kv_block_zeroer"): + self._kv_block_zeroer.zero_block_ids(block_ids) + + # Note: used for model runner override. + def _init_device_properties(self) -> None: + """Initialize attributes from torch.cuda.get_device_properties""" + + self.num_sms = num_compute_units(self.device.index) + # Note: used for model runner override. def _sync_device(self) -> None: - current_platform.torch_device_fn.synchronize() + torch.accelerator.synchronize() - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None: """Update the cached states and the persistent batch with the scheduler output. 
@@ -798,6 +1078,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) self.num_prompt_logprobs.pop(req_id, None) + self.late_interaction_runner.on_requests_finished( + scheduler_output.finished_req_ids + ) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -807,6 +1090,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in scheduler_output.finished_req_ids: self.input_batch.remove_request(req_id) + # Zero GPU memory for freshly allocated cache blocks to prevent + # stale NaN/data from corrupting attention or SSM computation. + if scheduler_output.new_block_ids_to_zero: + self._zero_block_ids(scheduler_output.new_block_ids_to_zero) + # Free the cached encoder outputs. for mm_hash in scheduler_output.free_encoder_mm_hashes: self.encoder_cache.pop(mm_hash, None) @@ -833,10 +1121,25 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in unscheduled_req_ids: self.input_batch.remove_request(req_id) + is_ngram_gpu = ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ) + if is_ngram_gpu: + ngram_gpu_new_reqs: list[CachedRequestState] = [] + reqs_to_add: list[CachedRequestState] = [] + deferred_spec_decode_corrections = [] + # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id + if req_id in self.requests: + # For streaming case only. 
+ req_state = self._update_streaming_request(req_id, new_req_data) + reqs_to_add.append(req_state) + continue + sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params @@ -872,6 +1175,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: lora_request=new_req_data.lora_request, ) self.requests[req_id] = req_state + self.late_interaction_runner.register_request(req_id, pooling_params) if sampling_params and sampling_params.prompt_logprobs is not None: self.num_prompt_logprobs[req_id] = ( @@ -889,14 +1193,32 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self._init_xdrope_positions(req_state) reqs_to_add.append(req_state) + # Track new requests for ngram_gpu full tensor copy + if is_ngram_gpu: + ngram_gpu_new_reqs.append(req_state) # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens - # Wait until valid_sampled_tokens_count is copied to cpu, - # then use it to update actual num_computed_tokens of each request. - valid_sampled_token_count = self._get_valid_sampled_token_count() + # Save scheduler-allocated spec lengths before trimming so + # prev_num_draft_len keeps the optimistic count for rejection correction. 
+ original_num_spec_per_req: dict[str, int] = {} + if ( + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() + ): + for req_id, toks in scheduled_spec_tokens.items(): + original_num_spec_per_req[req_id] = len(toks) + update_scheduler_for_invalid_drafts( + self._num_valid_draft_tokens_event, + self._num_valid_draft_tokens_cpu, + scheduler_output, + self.input_batch.req_id_to_index, + ) + if self.use_async_spec_decode: + self.prev_num_draft_tokens.np.fill(0) for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] @@ -906,58 +1228,82 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: num_output_tokens = req_data.num_output_tokens[i] req_index = self.input_batch.req_id_to_index.get(req_id) - # prev_num_draft_len is used in async scheduling mode with - # spec decode. it indicates if need to update num_computed_tokens - # of the request. for example: - # fist step: num_computed_tokens = 0, spec_tokens = [], - # prev_num_draft_len = 0. - # second step: num_computed_tokens = 100(prompt lenth), - # spec_tokens = [a,b], prev_num_draft_len = 0. - # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], - # prev_num_draft_len = 2. - # num_computed_tokens in first step and second step does't contain - # the spec tokens length, but in third step it contains the - # spec tokens length. we only need to update num_computed_tokens - # when prev_num_draft_len > 0. - if req_state.prev_num_draft_len: + if req_state.prev_num_draft_len and self.use_async_scheduling: + # prev_num_draft_len is used in async scheduling mode with + # spec decode. it indicates if need to update num_computed_tokens + # of the request. for example: + # first step: num_computed_tokens = 0, spec_tokens = [], + # prev_num_draft_len = 0. + # second step: num_computed_tokens = 100(prompt length), + # spec_tokens = [a,b], prev_num_draft_len = 0. 
+ # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d], + # prev_num_draft_len = 2. + # num_computed_tokens in first step and second step doesn't contain + # the spec tokens length, but in third step it contains the + # spec tokens length. we only need to update num_computed_tokens + # when prev_num_draft_len > 0. if req_index is None: req_state.prev_num_draft_len = 0 else: - assert self.input_batch.prev_req_id_to_index is not None - prev_req_index = self.input_batch.prev_req_id_to_index[req_id] - num_accepted = valid_sampled_token_count[prev_req_index] - 1 - num_rejected = req_state.prev_num_draft_len - num_accepted - num_computed_tokens -= num_rejected - req_state.output_token_ids.extend([-1] * num_accepted) + # Optimistically assume all accepted; queue up a correction + # to be called after the model forward to preserve async + # scheduling. Corrected on GPU in _prepare_inputs. + optimistic_num_accepted = req_state.prev_num_draft_len + req_state.output_token_ids.extend([-1] * optimistic_num_accepted) + + deferred_spec_decode_corrections.append( + (req_id, optimistic_num_accepted, req_state) + ) + + prev_req_index = ( + self.input_batch.prev_req_id_to_index.get(req_id) + if self.input_batch.prev_req_id_to_index + else None + ) + if prev_req_index is not None: + self.prev_num_draft_tokens.np[prev_req_index] = ( + optimistic_num_accepted + ) + + if is_ngram_gpu and optimistic_num_accepted > 0: + self.input_batch.num_tokens_no_spec[req_index] += ( + optimistic_num_accepted + ) # Update the cached states. req_state.num_computed_tokens = num_computed_tokens if not is_last_rank: - # When using PP, the scheduler sends the sampled tokens back, - # because there's no direct communication between the first- - # stage worker and the last-stage worker. - new_token_ids = req_data.new_token_ids[i] - # Add the sampled token(s) from the previous step (if any). - # This doesn't include "unverified" tokens like spec tokens. 
- num_new_tokens = ( - num_computed_tokens + len(new_token_ids) - req_state.num_tokens - ) - if num_new_tokens == 1: - # Avoid slicing list in most common case. - req_state.output_token_ids.append(new_token_ids[-1]) - elif num_new_tokens > 0: - req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) + if not req_data.new_token_ids: + # Async scheduled PP: Sampled tokens propagated via GPU broadcast. + new_token_ids: list[int] = [] + else: + # Non-async scheduling with PP: The scheduler sends + # sampled token ids back because there's no direct communication + # between the first-stage worker and the last-stage worker. + new_token_ids = req_data.new_token_ids[i] + # Add the sampled token(s) from the previous step (if any). + # This doesn't include "unverified" tokens like spec tokens. + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) + if num_new_tokens == 1: + # Avoid slicing list in most common case. + req_state.output_token_ids.append(new_token_ids[-1]) + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:] + ) elif num_output_tokens < len(req_state.output_token_ids): # Some output tokens were discarded due to a sync-KV-load - # failure. Align the cached state. + # failure, or output_token_ids was inflated by the optimistic + # extend above (async spec decode). Align the cached state. del req_state.output_token_ids[num_output_tokens:] if req_index is not None: end_idx = ( self.input_batch.num_prompt_tokens[req_index] + num_output_tokens ) - self.input_batch.num_tokens[req_index] = end_idx self.input_batch.num_tokens_no_spec[req_index] = end_idx # Update the block IDs. 
@@ -985,6 +1331,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] reqs_to_add.append(req_state) + # Track resumed requests for ngram_gpu full tensor copy + if is_ngram_gpu: + ngram_gpu_new_reqs.append(req_state) continue # Update the persistent batch. @@ -1002,46 +1351,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_index, start_token_index:end_token_index ] = new_token_ids self.input_batch.num_tokens_no_spec[req_index] = end_token_index - self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. - spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, [] - ) - num_spec_tokens = len(spec_token_ids) - # For async scheduling, token_ids_cpu assigned from - # spec_token_ids are placeholders and will be overwritten in - # _prepare_input_ids. - if num_spec_tokens: - start_index = self.input_batch.num_tokens_no_spec[req_index] - end_token_index = start_index + num_spec_tokens - self.input_batch.token_ids_cpu[ - req_index, start_index:end_token_index - ] = spec_token_ids - # NOTE(woosuk): `num_tokens` here may include spec tokens. - self.input_batch.num_tokens[req_index] += num_spec_tokens - - # When speculative decoding is used with structured output, - # the scheduler can drop draft tokens that do not - # conform to the schema. This can result in - # scheduler_output.scheduled_spec_decode_tokens being empty, - # even when speculative decoding is enabled. - self.input_batch.spec_token_ids[req_index].clear() - self.input_batch.spec_token_ids[req_index].extend(spec_token_ids) - - # there are no draft tokens with async scheduling, - # we clear the spec_decoding info in scheduler_output and - # use normal sampling but rejection_sampling. 
- if self.use_async_scheduling: - req_state.prev_num_draft_len = num_spec_tokens - if num_spec_tokens and self._draft_token_ids is None: - scheduler_output.total_num_scheduled_tokens -= num_spec_tokens - scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens - scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None) + self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens) + # Restore scheduler-side draft count after ngram trimming. + if original_num_spec_per_req: + orig = original_num_spec_per_req.get(req_id, 0) + if orig != req_state.prev_num_draft_len: + req_state.prev_num_draft_len = orig + # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. for request in reqs_to_add: self.input_batch.add_request(request) + self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens) # Condense the batched states if there are gaps left by removed requests self.input_batch.condense() @@ -1050,8 +1373,54 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. 
self.input_batch.refresh_metadata() + # Incrementally update ngram_gpu tensors after batch is stable + if is_ngram_gpu: + update_ngram_gpu_tensors_incremental( + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + ngram_gpu_new_reqs, + self.device, + _pinned_idx_buf=self._ngram_pinned_idx_buf, + _pinned_val_buf=self._ngram_pinned_val_buf, + ) + + if deferred_spec_decode_corrections: + + def correct_spec_decode_token_counts(): + valid_sampled_token_count = self._get_valid_sampled_token_count() + if not valid_sampled_token_count: + return + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + if not prev_req_id_to_index: + return + for ( + req_id, + optimistic_num_accepted, + req_state, + ) in deferred_spec_decode_corrections: + prev_req_index = prev_req_id_to_index.get(req_id) + if prev_req_index is None: + continue + num_accepted = valid_sampled_token_count[prev_req_index] - 1 + correction = optimistic_num_accepted - num_accepted + req_state.num_computed_tokens -= correction + cur_req_index = self.input_batch.req_id_to_index.get(req_id) + if cur_req_index is None: + continue + self.input_batch.num_computed_tokens_cpu[cur_req_index] -= ( + correction + ) + if is_ngram_gpu and correction > 0: + self.input_batch.num_tokens_no_spec[cur_req_index] -= correction + self.num_tokens_no_spec_gpu[cur_req_index] -= correction + + return correct_spec_decode_token_counts + else: + return None + def _update_states_after_model_execute( - self, output_token_ids: torch.Tensor + self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput" ) -> None: """Update the cached states after model execution. @@ -1061,17 +1430,21 @@ def _update_states_after_model_execute( each sequence, and a shifting is done during the next iteration based on the number of accepted tokens. 
""" - if not self.model_config.is_hybrid or not self.speculative_config: + if not self.speculative_config or not self.model_config.is_hybrid: return + # TODO: Remove .cpu() sync to enable fully async for hybrid model; + # Use num_computed_tokens.gpu instead of req.num_computed_tokens to + # support aligned mamba cache mode. # Find the number of accepted tokens for each sequence. - num_accepted_tokens = ( + num_reqs = output_token_ids.size(0) + self.num_accepted_tokens.gpu[:num_reqs] = ( ( torch.cat( [ output_token_ids, torch.full( - (output_token_ids.size(0), 1), + (num_reqs, 1), -1, device=output_token_ids.device, ), @@ -1082,11 +1455,64 @@ def _update_states_after_model_execute( ) .int() .argmax(-1) - .cpu() - .numpy() ) - for i, num_tokens in enumerate(num_accepted_tokens): - self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + + if self.cache_config.mamba_cache_mode == "align": + for i, num_tokens in enumerate( + self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy() + ): + self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + mamba_utils.postprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.input_batch, + self.requests, + self.mamba_state_idx, + self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), + self._get_mamba_copy_bufs(), + ) + else: + self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_( + self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True + ) + assert self.num_accepted_tokens_event is not None + self.num_accepted_tokens_event.record() + + def _update_streaming_request( + self, req_id: str, new_req_data: NewRequestData + ) -> CachedRequestState: + """Updates streaming session request from `scheduled_new_reqs`. + + Removes the request from InputBatch (if present), updates the cached + state, and prepares it for re-addition to the batch. 
+ + NOTE: prompt_token_ids includes intermediate output tokens - tokens + previously generated but now are input context (part of the prompt). + """ + self.input_batch.remove_request(req_id) + req_state = self.requests[req_id] + + req_state.prompt_token_ids = new_req_data.prompt_token_ids + req_state.mm_features = new_req_data.mm_features + req_state.prompt_embeds = new_req_data.prompt_embeds + req_state.sampling_params = new_req_data.sampling_params + req_state.pooling_params = new_req_data.pooling_params + self.late_interaction_runner.register_request(req_id, req_state.pooling_params) + req_state.block_ids = new_req_data.block_ids + req_state.num_computed_tokens = new_req_data.num_computed_tokens + req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + req_state.prompt_token_ids, req_state.prompt_embeds + ) + + # Clear `output_token_ids` as previous output tokens are now part of + # `prompt_token_ids`. + req_state.output_token_ids.clear() + + if self.uses_mrope: + self._init_mrope_positions(req_state) + + return req_state def _init_mrope_positions(self, req_state: CachedRequestState): model = self.get_model() @@ -1123,22 +1549,20 @@ def _extract_mm_kwargs( if not scheduler_output or not self.is_multimodal_raw_input_only_model: return {} - mm_kwargs = list[MultiModalKwargsItem]() + mm_kwargs = list[tuple[str, MultiModalKwargsItem]]() for req in scheduler_output.scheduled_new_reqs: for feature in req.mm_features: if feature.data is not None: - mm_kwargs.append(feature.data) + mm_kwargs.append((feature.modality, feature.data)) # Input all modalities at once - model = cast(SupportsMultiModal, self.model) mm_kwargs_combined: BatchedTensorInputs = {} - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + for _, _, mm_kwargs_batch in group_and_batch_mm_kwargs( mm_kwargs, device=self.device, pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, ): - mm_kwargs_combined.update(mm_kwargs_group) + 
mm_kwargs_combined.update(mm_kwargs_batch) return mm_kwargs_combined @@ -1149,18 +1573,23 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: mm_budget = self.mm_budget assert mm_budget is not None + if not mm_budget.mm_max_toks_per_item: + return {} # No tower modalities (embed-only mode) + dummy_modality = mm_budget.get_modality_with_max_tokens() return self._get_mm_dummy_batch(dummy_modality, num_seqs) def _get_cumsum_and_arange( self, num_tokens: np.ndarray, + arange_out: np.ndarray, cumsum_dtype: np.dtype | None = None, - ) -> tuple[np.ndarray, np.ndarray]: + ) -> np.ndarray: """Get the cumulative sum and batched arange of the given array. - # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) - # Equivalent to but faster than: - # np.concatenate([np.arange(n) for n in num_tokens]) + E.g., [2, 5, 3] -> [2, 7, 10], arange written to + arange_out[:10] as [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]. + Equivalent to but faster than: + np.concatenate([np.arange(n) for n in num_tokens]) """ # Step 1. [2, 5, 3] -> [2, 7, 10] cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) @@ -1168,13 +1597,33 @@ def _get_cumsum_and_arange( # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - arange = self.arange_np[:total_num_tokens] - cumsums_offsets + np.subtract( + self.arange_np[:total_num_tokens], + cumsums_offsets, + out=arange_out[:total_num_tokens], + ) + + return cu_num_tokens + + def _compute_prev_positions(self, num_reqs: int) -> None: + """Build prev_positions mapping: current pos -> previous pos (-1 if new). + + Populates self.prev_positions.np[:num_reqs] with the mapping. 
+ """ + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + prev_positions = self.prev_positions.np[:num_reqs] + + if not prev_req_id_to_index: + prev_positions.fill(-1) + return - return cu_num_tokens, arange + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + prev_positions[i] = prev_req_id_to_index.get(req_id, -1) def _prepare_input_ids( self, scheduler_output: "SchedulerOutput", + num_reqs: int, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray, ) -> None: @@ -1182,7 +1631,11 @@ def _prepare_input_ids( Carefully handles the `prev_sampled_token_ids` which can be cached from the previous engine iteration, in which case those tokens on the - GPU need to be copied into the corresponding slots into input_ids.""" + GPU need to be copied into the corresponding slots into input_ids. + + Uses self.prev_positions[:num_reqs] which maps current pos -> prev pos + (-1 for new requests). + """ if self.input_batch.prev_sampled_token_ids is None: # Normal scheduling case @@ -1195,73 +1648,76 @@ def _prepare_input_ids( # Async scheduling case, where some decode requests from the previous # iteration won't have entries in input_ids_cpu and need to be copied # on the GPU from prev_sampled_token_ids. 
- prev_req_id_to_index = self.input_batch.prev_req_id_to_index - assert prev_req_id_to_index is not None + prev_positions = self.prev_positions.np[:num_reqs] + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens sample_flattened_indices: list[int] = [] spec_flattened_indices: list[int] = [] - prev_common_req_indices: list[int] = [] prev_draft_token_indices: list[int] = [] - indices_match = True + prev_indices: list[int] = [] + common_indices_match = True max_flattened_index = -1 total_num_spec_tokens = 0 - scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens - for req_id, cur_index in self.input_batch.req_id_to_index.items(): - if (prev_index := prev_req_id_to_index.get(req_id)) is not None: - prev_common_req_indices.append(prev_index) - # We need to compute the flattened input_ids index of the - # last token in each common request. - draft_len = len(scheduled_spec_tokens.get(req_id, ())) - total_num_spec_tokens += draft_len - flattened_index = cu_num_tokens[cur_index].item() - 1 - # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2] - # sample_flattened_indices = [0, 2, 5] - # spec_flattened_indices = [1, 3, 4, 6, 7] - sample_flattened_indices.append(flattened_index - draft_len) - spec_flattened_indices.extend( - range(flattened_index - draft_len + 1, flattened_index + 1) - ) - start = prev_index * self.num_spec_tokens - # prev_draft_token_indices is used to find which draft_tokens_id - # should be copied to input_ids - # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] - # flatten draft_tokens_id [1,2,3,4,5,6] - # draft_len of each request [1, 2, 1] - # then prev_draft_token_indices is [0, 2, 3, 4] - prev_draft_token_indices.extend(range(start, start + draft_len)) - indices_match &= prev_index == flattened_index - max_flattened_index = max(max_flattened_index, flattened_index) - num_commmon_tokens = len(sample_flattened_indices) + for cur_index in range(num_reqs): + prev_index = prev_positions[cur_index] + if 
prev_index < 0: + continue + prev_indices.append(prev_index) + req_id = self.input_batch.req_ids[cur_index] + # We need to compute the flattened input_ids index of the + # last token in each common request. + draft_len = len(scheduled_spec_tokens.get(req_id, ())) + total_num_spec_tokens += draft_len + flattened_index = cu_num_tokens[cur_index].item() - 1 + # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2] + # sample_flattened_indices = [0, 2, 5] + # spec_flattened_indices = [1, 3, 4, 6, 7] + sample_flattened_indices.append(flattened_index - draft_len) + spec_flattened_indices.extend( + range(flattened_index - draft_len + 1, flattened_index + 1) + ) + start = prev_index * self.num_spec_tokens + # prev_draft_token_indices is used to find which draft_tokens_id + # should be copied to input_ids + # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] + # flatten draft_tokens_id [1,2,3,4,5,6] + # draft_len of each request [1, 2, 1] + # then prev_draft_token_indices is [0, 2, 3, 4] + prev_draft_token_indices.extend(range(start, start + draft_len)) + common_indices_match &= prev_index == flattened_index + max_flattened_index = max(max_flattened_index, flattened_index) + + num_common_tokens = len(sample_flattened_indices) total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens - if num_commmon_tokens < total_without_spec: + if num_common_tokens < total_without_spec: # If not all requests are decodes from the last iteration, - # We need to copy the input_ids_cpu to the GPU first. + # we need to copy the input_ids_cpu to the GPU first. self.input_ids.copy_to_gpu(total_num_scheduled_tokens) if self.enable_prompt_embeds: self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) - if num_commmon_tokens == 0: + if num_common_tokens == 0: # No requests in common with the previous iteration # So input_ids.cpu will have all the input ids. 
return - if indices_match and max_flattened_index == (num_commmon_tokens - 1): + if common_indices_match and max_flattened_index == (num_common_tokens - 1): # Common-case optimization: the batch is unchanged # and no reordering happened. # The indices are both the same permutation of 0..N-1 so # we can copy directly using a single slice. - self.input_ids.gpu[:num_commmon_tokens].copy_( - self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], + self.input_ids.gpu[:num_common_tokens].copy_( + self.input_batch.prev_sampled_token_ids[:num_common_tokens, 0], non_blocking=True, ) if self.enable_prompt_embeds: - self.is_token_ids.gpu[:num_commmon_tokens] = True + self.is_token_ids.gpu[:num_common_tokens] = True return # Upload the index tensors asynchronously so the scatter can be non-blocking. sampled_tokens_index_tensor = torch.tensor( sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory ).to(self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( - prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory + prev_indices, dtype=torch.int64, pin_memory=self.pin_memory ).to(self.device, non_blocking=True) self.input_ids.gpu.scatter_( dim=0, @@ -1286,7 +1742,6 @@ def _prepare_input_ids( # because input_ids dtype is torch.int32, # so convert draft_token_ids to torch.int32 here. 
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32) - self._draft_token_ids = None self.input_ids.gpu.scatter_( dim=0, @@ -1299,12 +1754,14 @@ def _get_encoder_seq_lens( num_scheduled_tokens: dict[str, int], kv_cache_spec: KVCacheSpec, num_reqs: int, + for_cudagraph_capture: bool = False, ) -> tuple[torch.Tensor | None, np.ndarray | None]: if not isinstance(kv_cache_spec, CrossAttentionSpec): return None, None # Zero out buffer for padding requests that are not actually scheduled (CGs) self.encoder_seq_lens.np[:num_reqs] = 0 + # Build encoder_seq_lens array mapping request indices to # encoder lengths for inputs scheduled in this batch for req_id in num_scheduled_tokens: @@ -1321,6 +1778,15 @@ def _get_encoder_seq_lens( feature.mm_position.length for feature in req_state.mm_features ) self.encoder_seq_lens.np[req_index] = encoder_input_tokens + if for_cudagraph_capture: + # During CUDA graph capture, we need to use realistic encoder lengths + # so that max_seqlen_k is captured with the correct value. + max_encoder_len = getattr( + self.model_config.hf_config, + "max_source_positions", + self.max_encoder_len, + ) + self.encoder_seq_lens.np[:num_reqs] = max_encoder_len self.encoder_seq_lens.copy_to_gpu(num_reqs) encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs] @@ -1355,15 +1821,15 @@ def _prepare_inputs( req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] - # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) + # self.query_pos.np[:10]: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens = self._get_cumsum_and_arange( + num_scheduled_tokens, self.query_pos.np + ) # Get positions. 
- positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np, + positions_np = ( + self.input_batch.num_computed_tokens_cpu[req_indices] + + self.query_pos.np[: cu_num_tokens[-1]] ) # Calculate M-RoPE positions. @@ -1441,9 +1907,6 @@ def _prepare_inputs( output_idx += num_sched - self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) - # Prepare the attention metadata. self.query_start_loc.np[0] = 0 self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens @@ -1453,12 +1916,21 @@ def _prepare_inputs( self.query_start_loc.copy_to_gpu() query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] - self.seq_lens.np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + # Compute optimistic seq_lens (assumes all draft tokens from previous + # iteration accepted). Store in optimistic_seq_lens_cpu for use by + # _build_attention_metadata (max_seq_len) and discard_request_mask. + # seq_lens (GPU) will be computed later using the same optimistic values. + torch.add( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs], + torch.from_numpy(num_scheduled_tokens), + out=self.optimistic_seq_lens_cpu[:num_reqs], ) - # Fill unused with 0 for full cuda graph mode. - self.seq_lens.np[num_reqs:].fill(0) - self.seq_lens.copy_to_gpu() + self.optimistic_seq_lens_cpu[num_reqs:].fill_(0) + + # Build prev_positions mapping: current pos -> prev pos (-1 if new). + # Used for gathering from previous iteration's GPU tensors. 
+ prev_req_id_to_index = self.input_batch.prev_req_id_to_index + self._compute_prev_positions(num_reqs) num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens_np = np.array(num_tokens, dtype=np.int32) @@ -1466,13 +1938,78 @@ def _prepare_inputs( # Record which requests should not be sampled, # so that we could clear the sampled tokens before returning self.discard_request_mask.np[:num_reqs] = ( - self.seq_lens.np[:num_reqs] < num_tokens_np + self.optimistic_seq_lens_cpu[:num_reqs].numpy() < num_tokens_np ) self.discard_request_mask.copy_to_gpu(num_reqs) + # Sync num_accepted_tokens from CPU (set by + # _update_states_after_model_execute for hybrid models). + if self.num_accepted_tokens_event is not None: + self.num_accepted_tokens_event.synchronize() + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + else: + self.num_accepted_tokens.np.fill(1) + self.num_accepted_tokens.gpu.fill_(1) + + # Update num_computed_tokens on GPU. In async spec decode, + # CPU values are optimistic (all drafts accepted). The kernel + # corrects on GPU using the previous step's + # valid_sampled_token_count_gpu. Otherwise, just copy from CPU. 
+ if ( + self.use_async_spec_decode + and self.valid_sampled_token_count_gpu is not None + and prev_req_id_to_index + ): + self.prev_positions.copy_to_gpu(num_reqs) + self.prev_num_draft_tokens.copy_to_gpu() + cpu_values = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs].to( + device=self.device, non_blocking=True + ) + update_num_computed_tokens_for_batch_change( + self.num_computed_tokens, + self.num_accepted_tokens.gpu[:num_reqs], + self.prev_positions.gpu[:num_reqs], + self.valid_sampled_token_count_gpu, + self.prev_num_draft_tokens.gpu, + cpu_values, + ) + else: + self.num_computed_tokens[:num_reqs].copy_( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs], + non_blocking=True, + ) + + self.req_indices.np[:total_num_scheduled_tokens] = req_indices + self.req_indices.copy_to_gpu(total_num_scheduled_tokens) + req_indices_gpu = self.req_indices.gpu[:total_num_scheduled_tokens] + + self.query_pos.copy_to_gpu(total_num_scheduled_tokens) + self.num_scheduled_tokens.np[:num_reqs] = num_scheduled_tokens + self.num_scheduled_tokens.copy_to_gpu(num_reqs) + num_scheduled_tokens_gpu = self.num_scheduled_tokens.gpu[:num_reqs] + self.positions[:total_num_scheduled_tokens] = ( + self.num_computed_tokens[req_indices_gpu].to(torch.int64) + + self.query_pos.gpu[:total_num_scheduled_tokens] + ) + self.seq_lens[:num_reqs] = ( + self.num_computed_tokens[:num_reqs] + num_scheduled_tokens_gpu + ) + self.seq_lens[num_reqs:].fill_(0) + + self.input_batch.block_table.compute_slot_mapping( + num_reqs, + self.query_start_loc.gpu[: num_reqs + 1], + self.positions[:total_num_scheduled_tokens], + ) + # Copy the tensors to the GPU. 
self._prepare_input_ids( scheduler_output, + num_reqs, total_num_scheduled_tokens, cu_num_tokens, ) @@ -1489,9 +2026,14 @@ def _prepare_inputs( self.xdrope_positions.cpu[:, :total_num_scheduled_tokens], non_blocking=True, ) - else: - # Common case (1D positions) - self.positions.copy_to_gpu(total_num_scheduled_tokens) + if self.use_async_spec_decode and (self.uses_mrope or self.uses_xdrope_dim > 0): + drift = self.num_computed_tokens[req_indices_gpu].to( + torch.int64 + ) - self.input_batch.num_computed_tokens_cpu_tensor[req_indices].to( + device=self.device, dtype=torch.int64, non_blocking=True + ) + target = self.mrope_positions if self.uses_mrope else self.xdrope_positions + target.gpu[:, :total_num_scheduled_tokens] += drift use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -1501,7 +2043,6 @@ def _prepare_inputs( # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. logits_indices = query_start_loc[1:] - 1 - num_draft_tokens = None spec_decode_metadata = None num_sampled_tokens = np.ones(num_reqs, dtype=np.int32) else: @@ -1517,15 +2058,13 @@ def _prepare_inputs( draft_token_ids, ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] - num_draft_tokens[req_idx] = len(draft_token_ids) - num_decode_draft_tokens[req_idx] = ( - len(draft_token_ids) - if ( - self.input_batch.num_computed_tokens_cpu[req_idx] - >= self.input_batch.num_prompt_tokens[req_idx] - ) - else -1 - ) + draft_len = len(draft_token_ids) + num_draft_tokens[req_idx] = draft_len + if ( + self.input_batch.num_computed_tokens_cpu[req_idx] + >= self.input_batch.num_prompt_tokens[req_idx] + ): + num_decode_draft_tokens[req_idx] = draft_len spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens ) @@ -1564,6 +2103,7 @@ def _build_attention_metadata( for_cudagraph_capture: bool = False, num_scheduled_tokens: dict[str, int] | 
None = None, cascade_attn_prefix_lens: list[list[int]] | None = None, + slot_mappings: dict[int, torch.Tensor] | None = None, ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: """ :return: tuple[attn_metadata, spec_decode_common_attn_metadata] @@ -1578,7 +2118,7 @@ def _build_attention_metadata( attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: - attn_metadata = [{} for _ in range(len(ubatch_slices))] + attn_metadata = [dict() for _ in range(len(ubatch_slices))] if for_cudagraph_capture: # For some attention backends (e.g. FA) with sliding window models we need @@ -1586,18 +2126,11 @@ def _build_attention_metadata( # window size when capturing to make sure the correct kernel is selected. max_seq_len = self.max_model_len else: - max_seq_len = self.seq_lens.np[:num_reqs].max().item() - - if use_spec_decode: - self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs] - ) - self.num_accepted_tokens.np[num_reqs:].fill(1) - self.num_accepted_tokens.copy_to_gpu() + max_seq_len = self.optimistic_seq_lens_cpu.numpy()[:num_reqs].max().item() kv_cache_groups = self.kv_cache_config.kv_cache_groups - def _get_block_table_and_slot_mapping(kv_cache_gid: int): + def _get_block_table(kv_cache_gid: int): assert num_reqs_padded is not None and num_tokens_padded is not None kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): @@ -1606,32 +2139,47 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): dtype=torch.int32, device=self.device, ) - slot_mapping = torch.zeros( - (num_tokens_padded,), - dtype=torch.int64, - device=self.device, - ) else: blk_table = self.input_batch.block_table[kv_cache_gid] blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded) - slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. 
`blk_table_tensor` -1 to match mamba PAD_SLOT_ID - slot_mapping[num_tokens:num_tokens_padded].fill_(-1) blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) + return blk_table_tensor + + assert slot_mappings is not None + block_table_gid_0 = _get_block_table(0) + slot_mapping_gid_0 = slot_mappings[0] + + if self.routed_experts_initialized: + attn_gid = self.routed_experts_attn_gid + slot_mapping_attn = slot_mappings[attn_gid] + self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy() + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs_padded + ] + num_prompt_tokens_cpu = self.input_batch.num_prompt_tokens_cpu_tensor[ + :num_reqs_padded + ] + seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs_padded] - return blk_table_tensor, slot_mapping + # is_prefilling: True if request is still in prefill phase. + # Used by mamba backends to distinguish actual decodes from + # short extends. + is_prefilling = num_computed_tokens_cpu < num_prompt_tokens_cpu + + if self.use_async_spec_decode: + # GPU tensors are authoritative in async mode. 
+ seq_lens_cpu = None + num_computed_tokens_cpu = None - block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0) cm_base = CommonAttentionMetadata( query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], - seq_lens=self.seq_lens.gpu[:num_reqs_padded], - _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], - _num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ - :num_reqs_padded - ], + seq_lens=self.seq_lens[:num_reqs_padded], + _seq_lens_cpu=seq_lens_cpu, + _num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=num_reqs_padded, num_actual_tokens=num_tokens_padded, max_query_len=max_query_len, @@ -1639,11 +2187,12 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): block_table_tensor=block_table_gid_0, slot_mapping=slot_mapping_gid_0, causal=True, + is_prefilling=is_prefilling, ) if self.dcp_world_size > 1: self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( - self.seq_lens.cpu[:num_reqs], + self.optimistic_seq_lens_cpu[:num_reqs], self.dcp_world_size, self.dcp_rank, self.parallel_config.cp_kv_cache_interleave_size, @@ -1662,6 +2211,15 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int): logits_indices ) + # Cache attention metadata builds across hybrid KV-cache groups + # The only thing that changes between different hybrid KV-cache groups when the + # same metadata builder and KVCacheSpec is the same is the block table, so we + # can cache the attention metadata builds and just update the block table using + # `builder.update_block_table` if the builder supports it. 
+ cached_attn_metadata: dict[ + tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata + ] = {} + def _build_attn_group_metadata( kv_cache_gid: int, attn_gid: int, @@ -1669,15 +2227,22 @@ def _build_attn_group_metadata( ubid: int | None = None, ) -> None: attn_group = self.attn_groups[kv_cache_gid][attn_gid] + builder = attn_group.get_metadata_builder(ubid or 0) + kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec + if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): + kv_cache_spec = kv_cache_spec.kv_cache_specs[attn_group.layer_names[0]] + cache_key = (kv_cache_spec, type(builder)) + cascade_attn_prefix_len = ( cascade_attn_prefix_lens[kv_cache_gid][attn_gid] if cascade_attn_prefix_lens else 0 ) - builder = attn_group.get_metadata_builder(ubid or 0) extra_attn_metadata_args = {} - if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): + if use_spec_decode and isinstance( + builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder) + ): assert ubid is None, "UBatching not supported with GDN yet" extra_attn_metadata_args = dict( num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded], @@ -1690,12 +2255,23 @@ def _build_attn_group_metadata( attn_metadata_i = builder.build_for_cudagraph_capture( common_attn_metadata ) + elif ( + cache_key in cached_attn_metadata + and builder.supports_update_block_table + ): + attn_metadata_i = builder.update_block_table( + cached_attn_metadata[cache_key], + common_attn_metadata.block_table_tensor, + common_attn_metadata.slot_mapping, + ) else: attn_metadata_i = builder.build( common_prefix_len=cascade_attn_prefix_len, common_attn_metadata=common_attn_metadata, **extra_attn_metadata_args, ) + if builder.supports_update_block_table: + cached_attn_metadata[cache_key] = attn_metadata_i if ubid is None: assert isinstance(attn_metadata, dict) @@ -1719,15 +2295,15 @@ def _build_attn_group_metadata( num_scheduled_tokens or {}, kv_cache_group.kv_cache_spec, num_reqs_padded, 
+ for_cudagraph_capture=for_cudagraph_capture, ) if kv_cache_gid > 0: - cm.block_table_tensor, cm.slot_mapping = ( - _get_block_table_and_slot_mapping(kv_cache_gid) - ) + cm.block_table_tensor = _get_block_table(kv_cache_gid) + cm.slot_mapping = slot_mappings[kv_cache_gid] if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): - if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: + if self.drafter.kv_cache_gid == kv_cache_gid: spec_decode_common_attn_metadata = cm else: spec_decode_common_attn_metadata = cm @@ -1808,7 +2384,6 @@ def _compute_cascade_attn_prefix_lens( return cascade_attn_prefix_lens if use_cascade_attn else None - def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, @@ -1834,6 +2409,7 @@ def _compute_cascade_attn_prefix_len( Returns: int: Length of common prefix in tokens. """ + common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size if common_prefix_len == 0: # Common case. @@ -1892,8 +2468,6 @@ def _compute_cascade_attn_prefix_len( and kv_cache_spec.attention_chunk_size is not None ) assert isinstance(kv_cache_spec, AttentionSpec) - - ###TODO(lms): fix use of num_sms use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, @@ -1902,7 +2476,7 @@ def _compute_cascade_attn_prefix_len( use_alibi=self.use_alibi, use_sliding_window=use_sliding_window, use_local_attention=use_local_attention, - num_sms=1, + num_sms=self.num_sms, dcp_world_size=self.dcp_world_size, ) return common_prefix_len if use_cascade else 0 @@ -2022,33 +2596,34 @@ def _calc_spec_decode_metadata( # [4, 1, 3, 1, 2] num_sampled_tokens = num_draft_tokens + 1 - # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] - # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( - num_sampled_tokens, cumsum_dtype=np.int32 + # Step 1. 
+ # cu_num_sampled_tokens: [4, 5, 8, 9, 11] + # _arange_scratch[:11]: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + cu_num_sampled_tokens = self._get_cumsum_and_arange( + num_sampled_tokens, self._arange_scratch, cumsum_dtype=np.int32 ) # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat( cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens ) # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - logits_indices += arange + logits_indices += self._arange_scratch[: cu_num_sampled_tokens[-1]] # Compute the bonus logits indices. bonus_logits_indices = cu_num_sampled_tokens - 1 # Compute the draft logits indices. # cu_num_draft_tokens: [3, 3, 5, 5, 6] - # arange: [0, 1, 2, 0, 1, 0] - cu_num_draft_tokens, arange = self._get_cumsum_and_arange( - num_draft_tokens, cumsum_dtype=np.int32 + # _arange_scratch[:6]: [0, 1, 2, 0, 1, 0] + cu_num_draft_tokens = self._get_cumsum_and_arange( + num_draft_tokens, self._arange_scratch, cumsum_dtype=np.int32 ) # [0, 0, 0, 5, 5, 9] target_logits_indices = np.repeat( cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens ) # [0, 1, 2, 5, 6, 9] - target_logits_indices += arange + target_logits_indices += self._arange_scratch[: cu_num_draft_tokens[-1]] # TODO: Optimize the CPU -> GPU copy. cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( @@ -2076,6 +2651,7 @@ def _calc_spec_decode_metadata( draft_token_ids=draft_token_ids, num_draft_tokens=num_draft_tokens.tolist(), cu_num_draft_tokens=cu_num_draft_tokens, + cu_num_sampled_tokens=cu_num_sampled_tokens, target_logits_indices=target_logits_indices, bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, @@ -2096,42 +2672,45 @@ def _prepare_kv_sharing_fast_prefill( self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( logits_indices[-1].item() ) - if ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1] - ): - # Use piecewise CUDA graphs. 
- # Add padding to the batch size. - num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) - else: - num_logits_padded = num_logits + # Dispatch for the decoder portion of the model. + _, batch_desc = self.cudagraph_dispatcher.dispatch( + num_logits, invalid_modes={CUDAGraphMode.FULL} + ) + num_logits_padded = batch_desc.num_tokens logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ :num_logits_padded ] return logits_indices_padded - def _batch_mm_kwargs_from_scheduler( + def _batch_mm_inputs_from_scheduler( self, scheduler_output: "SchedulerOutput", - ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: - """Batch multimodal kwargs from scheduled encoder inputs. + ) -> tuple[ + list[str], + list[tuple[str, MultiModalKwargsItem]], + list[tuple[str, PlaceholderRange]], + ]: + """Batch multimodal inputs from scheduled encoder inputs. Args: scheduler_output: The scheduler output containing scheduled encoder inputs. Returns: - A tuple of (mm_kwargs, req_ids_pos) where: - - mm_kwargs: List of multimodal kwargs items to be batched - - mm_hashes_pos: List of (mm_hash, position_info) tuples + A tuple of (mm_hashes, mm_kwargs, mm_lora_refs) where: + - mm_hashes: List of multimodal hashes for each item + - mm_kwargs: List of multimodal kwargs for each item + - mm_lora_refs: List of (req_id, placeholder_range) for each item """ scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: - return [], [] - # Batch the multi-modal inputs. 
- mm_kwargs = list[MultiModalKwargsItem]() - # list of tuple (mm_hash, position_info) - mm_hashes_pos = list[tuple[str, PlaceholderRange]]() + return [], [], [] + + mm_hashes = list[str]() + mm_kwargs = list[tuple[str, MultiModalKwargsItem]]() + # Multimodal LoRA reference info to map each multimodal item + # back to its request & position + mm_lora_refs = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] @@ -2139,23 +2718,29 @@ def _batch_mm_kwargs_from_scheduler( mm_feature = req_state.mm_features[mm_input_id] if mm_feature.data is None: continue - mm_hash = mm_feature.identifier - mm_kwargs.append(mm_feature.data) - mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) - return mm_kwargs, mm_hashes_pos + mm_hashes.append(mm_feature.identifier) + mm_kwargs.append((mm_feature.modality, mm_feature.data)) + mm_lora_refs.append((req_id, mm_feature.mm_position)) + + return mm_hashes, mm_kwargs, mm_lora_refs def _execute_mm_encoder( self, scheduler_output: "SchedulerOutput" ) -> list[torch.Tensor]: - # Batch the multi-modal inputs using the helper method. - mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( + mm_hashes, mm_kwargs, mm_lora_refs = self._batch_mm_inputs_from_scheduler( scheduler_output ) if not mm_kwargs: return [] + should_time = bool( + self.observability_config + and self.observability_config.enable_mm_processor_stats + and scheduler_output.scheduled_encoder_inputs + ) + # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, # we process it separately to preserve item order. @@ -2164,62 +2749,147 @@ def _execute_mm_encoder( # multimodal inputs. The proper solution should be reordering the # encoder outputs. 
model = cast(SupportsMultiModal, self.model) + + if self.lora_config and self.lora_manager.supports_tower_connector_lora(): + # Build LoRA mappings independently for encoder inputs + # (encoder batch structure is different from main batch) + prompt_lora_mapping = [] + token_lora_mapping = [] + lora_requests = set() + encoder_token_counts = [] + + for req_id, pos_info in mm_lora_refs: + req_idx = self.input_batch.req_id_to_index[req_id] + lora_id = int(self.input_batch.request_lora_mapping[req_idx]) + + # Prefer pos_info.get_num_embeds to count precise MM embedding tokens. + num_tokens = self.model.get_num_mm_encoder_tokens( # type: ignore[attr-defined] + pos_info.get_num_embeds() + ) + prompt_lora_mapping.append(lora_id) + token_lora_mapping.extend([lora_id] * num_tokens) + encoder_token_counts.append(num_tokens) + + if lora_id > 0: + lora_request = self.input_batch.lora_id_to_lora_request.get(lora_id) + if lora_request is not None: + lora_requests.add(lora_request) + + # Set tower adapter mapping + tower_mapping = LoRAMapping( + tuple(token_lora_mapping), + tuple(prompt_lora_mapping), + is_prefill=True, + type=LoRAMappingType.TOWER, + ) + self.lora_manager.set_active_adapters(lora_requests, tower_mapping) + + if hasattr(self.model, "get_num_mm_connector_tokens"): + post_op_counts = [ + self.model.get_num_mm_connector_tokens(num_tokens) # type: ignore[attr-defined] + for num_tokens in encoder_token_counts + ] + + connector_token_mapping = np.repeat( + np.array(prompt_lora_mapping, dtype=np.int32), + np.array(post_op_counts, dtype=np.int32), + ) + connector_mapping = LoRAMapping( + index_mapping=tuple(connector_token_mapping.tolist()), + prompt_mapping=tuple(prompt_lora_mapping), + is_prefill=True, + type=LoRAMappingType.CONNECTOR, + ) + + self.lora_manager.set_active_adapters( + lora_requests, + connector_mapping, + ) + encoder_outputs: list[torch.Tensor] = [] - for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( + # Track the current index in 
mm_kwargs/mm_lora_refs to map groups to request IDs + current_item_idx = 0 + for modality, num_items, mm_kwargs_batch in group_and_batch_mm_kwargs( mm_kwargs, device=self.device, pin_memory=self.pin_memory, ): - curr_group_outputs: list[torch.Tensor] = [] + batch_outputs: MultiModalEmbeddings - # EVS-related change. + # EVS and dynamic res video related change. # (ekhvedchenia): Temporary hack to limit peak memory usage when # processing multimodal data. This solves the issue with scheduler # putting too many video samples into a single batch. Scheduler # uses pruned vision tokens count to compare it versus compute # budget which is incorrect (Either input media size or non-pruned # output vision tokens count should be considered) + # dynamic res video for nemotron temporarily uses this hack via + # requires_sequential_video_encoding + # because it doesn't yet support video batching. # TODO(ywang96): Fix memory profiling to take EVS into account and # remove this hack. if ( - self.is_multimodal_pruning_enabled + ( + self.is_multimodal_pruning_enabled + or self.requires_sequential_video_encoding + ) and modality == "video" and num_items > 1 ): - for video_mm_kwargs_item in filter( - lambda item: item.modality == "video", mm_kwargs - ): - _, _, micro_batch_mm_inputs = next( - group_mm_kwargs_by_modality( - [video_mm_kwargs_item], - device=self.device, - pin_memory=self.pin_memory, + batch_outputs_lst = list[torch.Tensor]() + for video_idx in range(num_items): + video_mm_kwargs_item = mm_kwargs[current_item_idx + video_idx] + with self.timed_encoder_operation( + should_time, mm_lora_refs, current_item_idx + video_idx, 1 + ): + _, _, micro_batch_mm_inputs = next( + group_and_batch_mm_kwargs( + [video_mm_kwargs_item], + device=self.device, + pin_memory=self.pin_memory, + ) ) - ) - micro_batch_outputs = model.embed_multimodal( - **micro_batch_mm_inputs - ) + micro_batch_outputs = model.embed_multimodal( + **micro_batch_mm_inputs + ) + + 
batch_outputs_lst.extend(micro_batch_outputs) - curr_group_outputs.extend(micro_batch_outputs) + batch_outputs = batch_outputs_lst else: # Run the encoder. - # `curr_group_outputs` is either of the following: + # `batch_outputs` is either of the following: # 1. A tensor of shape (num_items, feature_size, hidden_size) # in case feature_size is fixed across all multimodal items. # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. - curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment] - sanity_check_mm_encoder_outputs( - curr_group_outputs, - expected_num_items=num_items, - ) - encoder_outputs.extend(curr_group_outputs) + with self.timed_encoder_operation( + should_time, mm_lora_refs, current_item_idx, num_items + ): + cudagraph_output = None + if ( + self.encoder_cudagraph_manager is not None + and self.encoder_cudagraph_manager.supports_modality(modality) + ): + cudagraph_output = self.encoder_cudagraph_manager.execute( + mm_kwargs_batch, + ) + + if cudagraph_output is not None: + batch_outputs = cudagraph_output + else: + batch_outputs = model.embed_multimodal(**mm_kwargs_batch) + + sanity_check_mm_encoder_outputs(batch_outputs, expected_num_items=num_items) + encoder_outputs.extend(batch_outputs) + + current_item_idx += num_items # Cache the encoder outputs by mm_hash - for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + for mm_hash, output in zip(mm_hashes, encoder_outputs): self.encoder_cache[mm_hash] = output logger.debug("Finish execute for mm hash %s", mm_hash) self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) @@ -2233,8 +2903,13 @@ def _gather_mm_embeddings( ) -> tuple[list[torch.Tensor], torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + # Swap to the other buffer to avoid race condition with previous + # iteration's async copy 
that may still be reading from CPU. + self.is_mm_embed_idx = 1 - self.is_mm_embed_idx + is_mm_embed_buf = self.is_mm_embed_buffers[self.is_mm_embed_idx] + mm_embeds = list[torch.Tensor]() - is_mm_embed = self.is_mm_embed.cpu + is_mm_embed = is_mm_embed_buf.cpu is_mm_embed[:total_num_scheduled_tokens] = False req_start_idx = 0 @@ -2290,9 +2965,15 @@ def _gather_mm_embeddings( mm_embeds_item = encoder_output[start_idx:end_idx] req_start_pos = req_start_idx + start_pos - num_computed_tokens - is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( - True if is_embed is None else is_embed - ) + # OR mask for overlapping mm_features (use_audio_in_video) + if is_embed is None: + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( + True + ) + else: + is_mm_embed[ + req_start_pos + start_idx : req_start_pos + end_idx + ] |= is_embed mm_embeds_req.append(mm_embeds_item) if self.is_multimodal_pruning_enabled and self.uses_mrope: @@ -2312,7 +2993,7 @@ def _gather_mm_embeddings( mm_embeds.extend(mm_embeds_req) req_start_idx += num_scheduled_tokens - is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens) + is_mm_embed = is_mm_embed_buf.copy_to_gpu(total_num_scheduled_tokens) if should_sync_mrope_positions: self._calc_mrope_positions(scheduler_output) @@ -2325,8 +3006,10 @@ def _gather_mm_embeddings( return mm_embeds, is_mm_embed def get_model(self) -> nn.Module: - # get raw model out of the cudagraph wrapper. - if isinstance(self.model, (GraphWrapper, UBatchWrapper)): + if not hasattr(self, "model"): + raise ValueError("Cannot get model before model has been initialized") + if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)): + # get raw model out of the cudagraph wrapper. 
return self.model.unwrap() return self.model @@ -2343,6 +3026,9 @@ def get_supported_generation_tasks(self) -> list[GenerationTask]: supported_tasks.append("transcription") + if supports_realtime(model): + supported_tasks.append("realtime") + return supported_tasks def get_supported_pooling_tasks(self) -> list[PoolingTask]: @@ -2350,15 +3036,7 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: if not is_pooling_model(model): return [] - supported_tasks = list(model.pooler.get_supported_tasks()) - - if "score" in supported_tasks: - num_labels = getattr(self.model_config.hf_config, "num_labels", 0) - if num_labels != 1: - supported_tasks.remove("score") - logger.debug_once("Score API is only enabled for num_labels == 1.") - - return supported_tasks + return list(model.pooler.get_supported_tasks()) def get_supported_tasks(self) -> tuple[SupportedTask, ...]: tasks = list[SupportedTask]() @@ -2405,7 +3083,7 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: """ Step for the EPLB (Expert Parallelism Load Balancing) state. 
""" - if not self.parallel_config.enable_eplb: + if not self.parallel_config.enable_eplb or self.eep_eplb_suppressed: return assert self.eplb_state is not None @@ -2417,50 +3095,87 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: log_stats=self.parallel_config.eplb_config.log_balancedness, ) + def setup_eplb_from_mapping( + self, + expanded_physical_to_logical: torch.Tensor, + old_num_physical_experts: int, + ) -> None: + model = self.get_model() + assert is_mixture_of_experts(model) + + self.eplb_state = EplbState.from_mapping( + model=model, + model_config=self.model_config, + device=self.device, + parallel_config=self.parallel_config, + expanded_physical_to_logical=expanded_physical_to_logical, + num_valid_physical_experts=old_num_physical_experts, + ) + def _pool( self, hidden_states: torch.Tensor, num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, - ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), ( + kv_connector_output: KVConnectorOutput | None, + ) -> ModelRunnerOutput | AsyncModelRunnerOutput: + num_reqs = self.input_batch.num_reqs + assert num_reqs == len(self.input_batch.pooling_params), ( "Either all or none of the requests in a batch must be pooling request" ) hidden_states = hidden_states[:num_scheduled_tokens] - seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] + seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs] pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata.build_pooling_cursor( - num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device + num_scheduled_tokens_np, + seq_lens_cpu, + device=hidden_states.device, + query_start_loc_gpu=self.query_start_loc.gpu[: num_reqs + 1], ) model = cast(VllmModelForPooling, self.model) raw_pooler_output: PoolerOutput = model.pooler( - hidden_states=hidden_states, - pooling_metadata=pooling_metadata, + hidden_states=hidden_states, 
pooling_metadata=pooling_metadata + ) + + finished_mask = [ + seq_len == prompt_len + for seq_len, prompt_len in zip(seq_lens_cpu, pooling_metadata.prompt_lens) + ] + raw_pooler_output = self.late_interaction_runner.postprocess_pooler_output( + raw_pooler_output=raw_pooler_output, + pooling_params=pooling_metadata.pooling_params, + req_ids=self.input_batch.req_ids, + finished_mask=finished_mask, ) - raw_pooler_output = json_map_leaves( - lambda x: x.to("cpu", non_blocking=True) if x is not None else x, - raw_pooler_output, + + model_runner_output = ModelRunnerOutput( + req_ids=self.input_batch.req_ids.copy(), + req_id_to_index=self.input_batch.req_id_to_index.copy(), + kv_connector_output=kv_connector_output, ) - self._sync_device() - pooler_output: list[torch.Tensor | None] = [] - for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens - ): - output = raw_output if seq_len == prompt_len else None - pooler_output.append(output) + if raw_pooler_output is None or not any(finished_mask): + model_runner_output.pooler_output = [None] * num_reqs + return model_runner_output - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=[], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=pooler_output, + if self.use_async_scheduling: + return AsyncGPUPoolingModelRunnerOutput( + model_runner_output=model_runner_output, + raw_pooler_output=raw_pooler_output, + finished_mask=finished_mask, + async_output_copy_stream=self.async_output_copy_stream, + ) + + model_runner_output.pooler_output = _copy_pooler_output_to_cpu( + raw_pooler_output=raw_pooler_output, + finished_mask=finished_mask, ) + self._sync_device() + + return model_runner_output def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: # Pad tokens to multiple of tensor_parallel_size when @@ -2470,6 +3185,17 @@ def _pad_for_sequence_parallelism(self, 
num_scheduled_tokens: int) -> int: return round_up(num_scheduled_tokens, tp_size) return num_scheduled_tokens + def _prepare_mm_inputs( + self, num_tokens: int + ) -> tuple[torch.Tensor | None, torch.Tensor]: + if self.model.requires_raw_input_tokens: + input_ids = self.input_ids.gpu[:num_tokens] + else: + input_ids = None + + inputs_embeds = self.inputs_embeds.gpu[:num_tokens] + return input_ids, inputs_embeds + def _preprocess( self, scheduler_output: "SchedulerOutput", @@ -2512,10 +3238,9 @@ def _preprocess( # TODO(woosuk): Avoid the copy. Optimize. self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) - input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens) model_kwargs = { - **self._init_model_kwargs(num_scheduled_tokens), + **self._init_model_kwargs(), **self._extract_mm_kwargs(scheduler_output), } elif self.enable_prompt_embeds and is_first_rank: @@ -2543,7 +3268,7 @@ def _preprocess( self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] - model_kwargs = self._init_model_kwargs(num_input_tokens) + model_kwargs = self._init_model_kwargs() input_ids = None else: # For text-only models, we use token ids as input. @@ -2552,14 +3277,16 @@ def _preprocess( # then the embedding layer is not included in the CUDA graph. 
input_ids = self.input_ids.gpu[:num_input_tokens] inputs_embeds = None - model_kwargs = self._init_model_kwargs(num_input_tokens) + model_kwargs = self._init_model_kwargs() if self.uses_mrope: positions = self.mrope_positions.gpu[:, :num_input_tokens] elif self.uses_xdrope_dim > 0: positions = self.xdrope_positions.gpu[:, :num_input_tokens] else: - positions = self.positions.gpu[:num_input_tokens] + positions = self.positions[:num_input_tokens] + if num_input_tokens > num_scheduled_tokens: + self.positions[num_scheduled_tokens:num_input_tokens].zero_() if is_first_rank: intermediate_tensors = None @@ -2594,22 +3321,27 @@ def _sample( ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata + # Update output token ids with tokens sampled in last step + # if async scheduling and required by current sampling params. + self.input_batch.update_async_output_token_ids() if spec_decode_metadata is None: - # Update output token ids with tokens sampled in last step - # if async scheduling and required by current sampling params. - self.input_batch.update_async_output_token_ids() return self.sampler( logits=logits, sampling_metadata=sampling_metadata, ) + # Update spec_token_ids with real draft tokens from pre step only when + # output_token_ids is needed (penalties or bad_words are in use). 
+ if self.use_async_scheduling and self._draft_token_req_ids is not None: + draft_token_ids_cpu, _ = self._get_draft_token_ids_cpu() + self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu) + sampler_output = self.rejection_sampler( spec_decode_metadata, None, # draft_probs logits, sampling_metadata, ) - self._update_states_after_model_execute(sampler_output.sampled_token_ids) return sampler_output def _bookkeeping_sync( @@ -2651,7 +3383,7 @@ def _bookkeeping_sync( sampled_token_ids = sampler_output.sampled_token_ids logprobs_tensors = sampler_output.logprobs_tensors invalid_req_indices = [] - cu_num_tokens: list[int] | None = None + logprobs_lists = None if not self.use_async_scheduling: # Get the valid generated tokens. max_gen_len = sampled_token_ids.shape[-1] @@ -2661,13 +3393,16 @@ def _bookkeeping_sync( # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: valid_sampled_token_ids[int(i)].clear() + + if logprobs_tensors is not None: + logprobs_lists = logprobs_tensors.tolists() else: # Includes spec decode tokens. 
- valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output( + valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output( sampled_token_ids, self.input_batch.vocab_size, discard_sampled_tokens_req_indices, - return_cu_num_tokens=logprobs_tensors is not None, + logprobs_tensors=logprobs_tensors, ) else: valid_sampled_token_ids = [] @@ -2715,18 +3450,11 @@ def _bookkeeping_sync( self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx - self.input_batch.num_tokens[req_idx] = end_idx req_id = req_ids[req_idx] req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - logprobs_lists = ( - logprobs_tensors.tolists(cu_num_tokens) - if not self.use_async_scheduling and logprobs_tensors is not None - else None - ) - # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( hidden_states[:num_scheduled_tokens], @@ -2790,6 +3518,27 @@ def _model_forward( **model_kwargs, ) + @staticmethod + def _is_uniform_decode( + max_num_scheduled_tokens: int, + uniform_decode_query_len: int, + num_tokens: int, + num_reqs: int, + force_uniform_decode: bool | None = None, + ) -> bool: + """ + Checks if it's a decode batch with same amount scheduled tokens + across all requests. 
+ """ + return ( + ( + (max_num_scheduled_tokens == uniform_decode_query_len) + and (num_tokens == max_num_scheduled_tokens * num_reqs) + ) + if force_uniform_decode is None + else force_uniform_decode + ) + def _determine_batch_execution_and_padding( self, num_tokens: int, @@ -2803,6 +3552,7 @@ def _determine_batch_execution_and_padding( # be improved in model runner v2) force_uniform_decode: bool | None = None, force_has_lora: bool | None = None, + force_num_active_loras: int | None = None, num_encoder_reqs: int = 0, ) -> tuple[ CUDAGraphMode, @@ -2811,14 +3561,12 @@ def _determine_batch_execution_and_padding( torch.Tensor | None, CUDAGraphStat | None, ]: - num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) - uniform_decode = ( - ( - (max_num_scheduled_tokens == self.uniform_decode_query_len) - and (num_tokens_padded == max_num_scheduled_tokens * num_reqs) - ) - if force_uniform_decode is None - else force_uniform_decode + uniform_decode = self._is_uniform_decode( + max_num_scheduled_tokens=max_num_scheduled_tokens, + uniform_decode_query_len=self.uniform_decode_query_len, + num_tokens=num_tokens, + num_reqs=num_reqs, + force_uniform_decode=force_uniform_decode, ) # Encoder-decoder models only support CG for decoder_step > 0 (no enc_output # is present). Also, chunked-prefill is disabled, so batch are uniform. 
@@ -2826,46 +3574,49 @@ def _determine_batch_execution_and_padding( self.model_config.is_encoder_decoder and num_encoder_reqs > 0 ) - has_lora = ( - len(self.input_batch.lora_id_to_lora_request) > 0 - if force_has_lora is None - else force_has_lora + # Compute LoRA state for cudagraph dispatch + num_active_loras = ( + force_num_active_loras + if force_num_active_loras is not None + else len(self.input_batch.lora_id_to_lora_request) ) + has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora + + num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) - dispatch_cudagraph = ( - lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch( + def dispatch_cudagraph(num_tokens, disable_full=False, valid_modes=None): + return self.cudagraph_dispatcher.dispatch( num_tokens=num_tokens, has_lora=has_lora, uniform_decode=uniform_decode, - disable_full=disable_full, + num_active_loras=num_active_loras, + valid_modes={CUDAGraphMode.NONE} if force_eager else valid_modes, + invalid_modes={CUDAGraphMode.FULL} if disable_full else None, ) - if not force_eager - else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded)) - ) cudagraph_mode, batch_descriptor = dispatch_cudagraph( - num_tokens_padded, use_cascade_attn or has_encoder_output + num_tokens_padded, disable_full=use_cascade_attn or has_encoder_output ) num_tokens_padded = batch_descriptor.num_tokens + if self.compilation_config.pass_config.enable_sp: + assert ( + batch_descriptor.num_tokens + % self.vllm_config.parallel_config.tensor_parallel_size + == 0 + ), ( + "Sequence parallelism requires num_tokens to be " + "a multiple of tensor parallel size" + ) # Extra coordination when running data-parallel since we need to coordinate # across ranks should_ubatch, num_tokens_across_dp = False, None if self.vllm_config.parallel_config.data_parallel_size > 1: - # Disable DP padding when running eager to avoid excessive padding when - # running prefills. 
This lets us set cudagraph_mode="NONE" on the prefiller - # in a P/D setup and still use CUDA graphs (enabled by this padding) on the - # decoder. - allow_dp_padding = ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - ) - should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = ( coordinate_batch_across_dp( num_tokens_unpadded=num_tokens, parallel_config=self.parallel_config, allow_microbatching=allow_microbatching, - allow_dp_padding=allow_dp_padding, num_tokens_padded=num_tokens_padded, uniform_decode=uniform_decode, num_scheduled_tokens_per_request=num_scheduled_tokens_np, @@ -2880,7 +3631,7 @@ def _determine_batch_execution_and_padding( # Re-dispatch with DP padding so we have the correct batch_descriptor cudagraph_mode, batch_descriptor = dispatch_cudagraph( num_tokens_padded, - disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value, + valid_modes={CUDAGraphMode(synced_cudagraph_mode)}, ) # Assert to make sure the agreed upon token count is correct otherwise # num_tokens_across_dp will no-longer be valid @@ -2939,152 +3690,295 @@ def _register_layerwise_nvtx_hooks(self) -> None: pyt_hooks.register_hooks(self.model, self.model.__class__.__name__) self.layerwise_nvtx_hooks_registered = True + def _get_slot_mappings( + self, + num_tokens_padded: int, + num_reqs_padded: int, + num_tokens_unpadded: int, + ubatch_slices: "UBatchSlices | None" = None, + ) -> tuple[ + dict[int, torch.Tensor] | None, + dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, + ]: + """ + Build slot mappings in both formats needed by the system. 
+ + Args: + num_tokens_padded: Total number of tokens (padded) + num_reqs_padded: Total number of requests (padded) + num_tokens_unpadded: Actual number of tokens (unpadded) + ubatch_slices: Optional ubatch slicing info for DBO + + Returns: + A tuple of: + - slot_mappings_by_gid: dict[int, torch.Tensor] for attention metadata + - slot_mappings_by_layer: dict[str, torch.Tensor] or list for ForwardContext + """ + if not ( + hasattr(self, "kv_cache_config") + and self.kv_cache_config is not None + and len(self.kv_cache_config.kv_cache_groups) > 0 + ): + return None, None + + def _get_slot_mapping(kv_cache_gid: int): + assert num_reqs_padded is not None and num_tokens_padded is not None + kv_cache_spec = self.kv_cache_config.kv_cache_groups[ + kv_cache_gid + ].kv_cache_spec + if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): + slot_mapping = torch.zeros( + (num_tokens_padded,), + dtype=torch.int64, + device=self.device, + ) + else: + blk_table = self.input_batch.block_table[kv_cache_gid] + slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded] + + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. 
`blk_table_tensor` -1 to match mamba PAD_SLOT_ID + slot_mapping[num_tokens_unpadded:num_tokens_padded].fill_(-1) + + return slot_mapping + + slot_mappings_by_gid = { + gid: _get_slot_mapping(gid) + for gid, _ in enumerate(self.kv_cache_config.kv_cache_groups) + } + + slot_mappings_by_layer: dict[str, torch.Tensor] = {} + for gid, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups): + slot_mapping = slot_mappings_by_gid[gid] + for layer_name in kv_cache_group.layer_names: + slot_mappings_by_layer[layer_name] = slot_mapping + + if ubatch_slices is not None: + result: list[dict[str, torch.Tensor]] = [] + for ubatch in ubatch_slices: + sliced_mappings: dict[str, torch.Tensor] = {} + for layer_name, slot_mapping in slot_mappings_by_layer.items(): + sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice] + result.append(sliced_mappings) + return slot_mappings_by_gid, result + + return slot_mappings_by_gid, slot_mappings_by_layer + @managed_inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: IntermediateTensors | None = None, - ) -> ModelRunnerOutput | IntermediateTensors | None: - + ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None: if self.execute_model_state is not None: raise RuntimeError( "State error: sample_tokens() must be called " "after execute_model() returns None." ) - # self._draft_token_ids is None when `input_fits_in_drafter=False` - # and there is no draft tokens scheduled. so it need to update the - # spec_decoding info in scheduler_output with async_scheduling. - # use deepcopy to avoid the modification has influence on the - # scheduler_output in engine core process. - # TODO(Ronald1995): deepcopy is expensive when there is a large - # number of requests, optimize it later. 
+ if self.routed_experts_initialized: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.clear_buffer() # noqa + else: + logger.error("RoutedExpertsCapturer not initialized.") + + # If ngram_gpu is used, we need to copy the scheduler_output to avoid + # the modification has influence on the scheduler_output in engine core process. + # The replace is much faster than deepcopy. if ( - self.use_async_scheduling - and self.num_spec_tokens - and self._draft_token_ids is None + self.speculative_config is not None + and self.speculative_config.use_ngram_gpu() ): - scheduler_output = deepcopy(scheduler_output) + num_scheduled_tokens_copy = scheduler_output.num_scheduled_tokens.copy() + spec_decode_tokens_copy = ( + scheduler_output.scheduled_spec_decode_tokens.copy() + ) + scheduler_output = replace( + scheduler_output, + num_scheduled_tokens=num_scheduled_tokens_copy, + scheduled_spec_decode_tokens=spec_decode_tokens_copy, + ) - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - with record_function_or_nullcontext("gpu_model_runner: preprocess"): - with self.synchronize_input_prep(): - # Update persistent batch states. - self._update_states(scheduler_output) - - if has_ec_transfer() and get_ec_transfer().is_producer: - with self.maybe_get_ec_connector_output( - scheduler_output, - encoder_cache=self.encoder_cache, - ) as ec_connector_output: - self._execute_mm_encoder(scheduler_output) - return make_empty_encoder_model_runner_output(scheduler_output) - - if not num_scheduled_tokens: - if ( - self.parallel_config.distributed_executor_backend - == "external_launcher" - and self.parallel_config.data_parallel_size > 1 - ): - # this is a corner case when both external launcher - # and DP are enabled, num_scheduled_tokens could be - # 0, and has_unfinished_requests in the outer loop - # returns True. 
before returning early here we call - # dummy run to ensure coordinate_batch_across_dp - # is called into to avoid out of sync issues. - self._dummy_run(1) - if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - return self.kv_connector_no_forward( - scheduler_output, self.vllm_config - ) - if self.cache_config.kv_sharing_fast_prefill: - assert not self.num_prompt_logprobs, ( - "--kv-sharing-fast-prefill produces incorrect " - "logprobs for prompt tokens, tokens, please disable " - "it when the requests need prompt logprobs" - ) + if has_kv_transfer_group(): + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + get_kv_transfer_group().handle_preemptions(kv_connector_metadata) - num_reqs = self.input_batch.num_reqs - req_ids = self.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) - max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) - num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + with ( + record_function_or_nullcontext("gpu_model_runner: preprocess"), + self.synchronize_input_prep(), + ): + # Update persistent batch states. 
+ deferred_state_corrections_fn = self._update_states(scheduler_output) - ( - logits_indices, - spec_decode_metadata, - ) = self._prepare_inputs( + if has_ec_transfer() and not get_ec_transfer().is_consumer: + with self.maybe_get_ec_connector_output( scheduler_output, - num_scheduled_tokens_np, + encoder_cache=self.encoder_cache, + ) as ec_connector_output: + self._execute_mm_encoder(scheduler_output) + return make_empty_encoder_model_runner_output(scheduler_output) + + if not num_scheduled_tokens: + if ( + self.parallel_config.distributed_executor_backend + == "external_launcher" + and self.parallel_config.data_parallel_size > 1 + ): + # this is a corner case when both external launcher + # and DP are enabled, num_scheduled_tokens could be + # 0, and has_unfinished_requests in the outer loop + # returns True. before returning early here we call + # dummy run to ensure coordinate_batch_across_dp + # is called into to avoid out of sync issues. + self._dummy_run(1) + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if no work to do. 
+ return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) + + if self.cache_config.kv_sharing_fast_prefill: + assert not self.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs" ) - cascade_attn_prefix_lens = None - # Disable cascade attention when using microbatching (DBO) - if self.cascade_attn_enabled and not self.parallel_config.enable_dbo: - # Pre-compute cascade attention prefix lengths - cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( - num_scheduled_tokens_np, - self.input_batch.num_computed_tokens_cpu[:num_reqs], - scheduler_output.num_common_prefix_blocks, - ) + num_reqs = self.input_batch.num_reqs + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens_np = np.array(tokens, dtype=np.int32) + max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) + num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens - ( - cudagraph_mode, - batch_desc, - should_ubatch, - num_tokens_across_dp, - cudagraph_stats, - ) = self._determine_batch_execution_and_padding( - num_tokens=num_tokens_unpadded, - num_reqs=num_reqs, - num_scheduled_tokens_np=num_scheduled_tokens_np, - max_num_scheduled_tokens=max_num_scheduled_tokens, - use_cascade_attn=cascade_attn_prefix_lens is not None, - num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), - ) + logits_indices, spec_decode_metadata = self._prepare_inputs( + scheduler_output, + num_scheduled_tokens_np, + ) - logger.debug( - "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " - "should_ubatch: %s, num_tokens_across_dp: %s", - cudagraph_mode, - batch_desc, - should_ubatch, - num_tokens_across_dp, + cascade_attn_prefix_lens = None + # Disable cascade attention when using microbatching (DBO) + if self.cascade_attn_enabled and not 
self.parallel_config.use_ubatching: + # Pre-compute cascade attention prefix lengths + cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens( + num_scheduled_tokens_np, + self.input_batch.num_computed_tokens_cpu[:num_reqs], + scheduler_output.num_common_prefix_blocks, ) - num_tokens_padded = batch_desc.num_tokens - num_reqs_padded = ( - batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ( + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + cudagraph_stats, + ) = self._determine_batch_execution_and_padding( + num_tokens=num_tokens_unpadded, + num_reqs=num_reqs, + num_scheduled_tokens_np=num_scheduled_tokens_np, + max_num_scheduled_tokens=max_num_scheduled_tokens, + use_cascade_attn=cascade_attn_prefix_lens is not None, + num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs), + ) + + logger.debug( + "Running batch with cudagraph_mode: %s, batch_descriptor: %s, " + "should_ubatch: %s, num_tokens_across_dp: %s", + cudagraph_mode, + batch_desc, + should_ubatch, + num_tokens_across_dp, + ) + + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = ( + batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs + ) + ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( + should_ubatch, + num_scheduled_tokens_np, + num_tokens_padded, + num_reqs_padded, + self.parallel_config.num_ubatches, + ) + + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, + ) + + # True if any attention backend handles KV cache update separately + # from forward() (i.e., forward_includes_kv_cache_update=False). When true, + # slot_mappings must use padded dimensions to match the key/value tensors. 
+ has_separate_kv_update = not all( + all( + g.backend.forward_includes_kv_cache_update + for g in self.attn_groups[id] ) - ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( - should_ubatch, - num_scheduled_tokens_np, - num_tokens_padded, - num_reqs_padded, + for id, spec in enumerate(self.kv_cache_config.kv_cache_groups) + if not isinstance(spec.kv_cache_spec, EncoderOnlyAttentionSpec) + ) + pad_attn = cudagraph_mode == CUDAGraphMode.FULL + + if self.cache_config.mamba_cache_mode == "align": + # preprocess_mamba reads req_state.num_computed_tokens (CPU) + # to decide copy operations, so we must apply deferred + # corrections before it runs. + if deferred_state_corrections_fn: + deferred_state_corrections_fn() + deferred_state_corrections_fn = None + mamba_utils.preprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, + self.mamba_state_idx, + self.input_batch, + self.requests, + self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), + self._get_mamba_copy_bufs(), ) + # preprocess_mamba resets num_accepted_tokens_cpu to 1 + # for requests whose state was copied to a new block. + # Re-sync to GPU so the mamba kernel reads from the + # correct initial state slot (init_token_idx = 0). 
+ self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) + self.num_accepted_tokens.copy_to_gpu(num_reqs) - pad_attn = cudagraph_mode == CUDAGraphMode.FULL - - use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 - ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices - - (attn_metadata, spec_decode_common_attn_metadata) = ( - self._build_attention_metadata( - num_tokens=num_tokens_unpadded, - num_tokens_padded=num_tokens_padded if pad_attn else None, - num_reqs=num_reqs, - num_reqs_padded=num_reqs_padded if pad_attn else None, - max_query_len=max_num_scheduled_tokens, - ubatch_slices=ubatch_slices_attn, - logits_indices=logits_indices, - use_spec_decode=use_spec_decode, - num_scheduled_tokens=scheduler_output.num_scheduled_tokens, - cascade_attn_prefix_lens=cascade_attn_prefix_lens, - ) + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices + + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens_padded + if pad_attn or has_separate_kv_update + else num_tokens_unpadded, + num_reqs_padded=( + num_reqs_padded if pad_attn or has_separate_kv_update else num_reqs + ), + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) + + attn_metadata, spec_decode_common_attn_metadata = ( + self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs, + num_reqs_padded=num_reqs_padded if pad_attn else None, + max_query_len=max_num_scheduled_tokens, + ubatch_slices=ubatch_slices_attn, + logits_indices=logits_indices, + use_spec_decode=use_spec_decode, + num_scheduled_tokens=scheduler_output.num_scheduled_tokens, + cascade_attn_prefix_lens=cascade_attn_prefix_lens, + slot_mappings=slot_mappings_by_group, ) + ) ( input_ids, @@ -3105,8 +3999,18 @@ def 
execute_model( # Mark KV scales as calculated after the first forward pass self.calculate_kv_scales = False + # Encoder-decoder models can only compile the pure decode steps where no + # encoder inputs are present. Use eager for the first pass. + num_encoder_reqs = len(scheduler_output.scheduled_encoder_inputs) + has_encoder_input = ( + self.model_config.is_encoder_decoder and num_encoder_reqs > 0 + ) + # Run the model. # Use persistent buffers for CUDA graphs. + # When spec decode is enabled, defer connector finalization + # (wait_for_save + clear metadata) until after draft model runs. + defer_kv_connector_finalize = self.speculative_config is not None with ( set_forward_context( attn_metadata, @@ -3116,9 +4020,14 @@ def execute_model( cudagraph_runtime_mode=cudagraph_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, + skip_compiled=has_encoder_input, ), record_function_or_nullcontext("gpu_model_runner: forward"), - self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + self.maybe_get_kv_connector_output( + scheduler_output, + defer_finalize=defer_kv_connector_finalize, + ) as kv_connector_output, ): model_output = self._model_forward( input_ids=input_ids, @@ -3148,11 +4057,12 @@ def execute_model( if self.is_pooling_model: # Return the pooling output. 
- output = self._pool( - hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + return self._pool( + hidden_states, + num_scheduled_tokens, + num_scheduled_tokens_np, + kv_connector_output, ) - output.kv_connector_output = kv_connector_output - return output sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states) @@ -3179,6 +4089,7 @@ def execute_model( model_output_broadcast_data: dict[str, Any] = {} if logits is not None: model_output_broadcast_data["logits"] = logits.contiguous() + broadcasted = get_pp_group().broadcast_tensor_dict( model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 ) @@ -3195,20 +4106,27 @@ def execute_model( aux_hidden_states, ec_connector_output, cudagraph_stats, + slot_mappings, ) self.kv_connector_output = kv_connector_output + # Now the batch has been launched we can wait for corrections from the + # previous model forward without breaking async scheduling. + if deferred_state_corrections_fn: + deferred_state_corrections_fn() + return None @managed_inference_mode() def sample_tokens( self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: - kv_connector_output = self.kv_connector_output - self.kv_connector_output = None - if self.execute_model_state is None: - # Nothing to do (PP non-final rank case), output isn't used. + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + # receive sampled token ids from the last PP rank. + if self.use_async_scheduling and get_pp_group().world_size > 1: + self._pp_receive_prev_sampled_token_ids_to_input_batch() if not kv_connector_output: return None # type: ignore[return-value] @@ -3232,6 +4150,7 @@ def sample_tokens( aux_hidden_states, ec_connector_output, cudagraph_stats, + slot_mappings, ) = self.execute_model_state # Clear ephemeral state. 
self.execute_model_state = None @@ -3245,6 +4164,22 @@ def sample_tokens( with record_function_or_nullcontext("gpu_model_runner: sample"): sampler_output = self._sample(logits, spec_decode_metadata) + self._update_states_after_model_execute( + sampler_output.sampled_token_ids, scheduler_output + ) + if self.use_async_scheduling: + pp = get_pp_group() + # For torchrun external_launcher PP mode with broadcast_pp_output=True, + # PP outputs have been broadcasted to all ranks at logits computation. + # Therefore, here is no need to send sampled token ids again in this case. + if not self.broadcast_pp_output and pp.world_size > 1 and pp.is_last_rank: + self._pp_broadcast_prev_sampled_token_ids( + sampler_output.sampled_token_ids + ) + + self._draft_token_ids = None + self._draft_token_req_ids = None + self.valid_sampled_token_count_gpu = None self.input_batch.prev_sampled_token_ids = None def propose_draft_token_ids(sampled_token_ids): @@ -3259,51 +4194,80 @@ def propose_draft_token_ids(sampled_token_ids): aux_hidden_states, spec_decode_metadata, spec_decode_common_attn_metadata, + slot_mappings, ) + self._copy_draft_token_ids_to_cpu(scheduler_output) spec_config = self.speculative_config - use_padded_batch_for_eagle = ( - spec_config is not None - and spec_config.use_eagle() - and not spec_config.disable_padded_drafter_batch - ) - effective_drafter_max_model_len = self.max_model_len - if effective_drafter_max_model_len is None: - effective_drafter_max_model_len = self.model_config.max_model_len - if ( - spec_config is not None - and spec_config.draft_model_config is not None - and spec_config.draft_model_config.max_model_len is not None - ): - effective_drafter_max_model_len = ( - spec_config.draft_model_config.max_model_len - ) - input_fits_in_drafter = spec_decode_common_attn_metadata and ( - spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens - <= effective_drafter_max_model_len - ) - if use_padded_batch_for_eagle: - assert self.speculative_config is 
not None - assert isinstance(self.drafter, EagleProposer) - sampled_token_ids = sampler_output.sampled_token_ids - if input_fits_in_drafter: - # EAGLE speculative decoding can use the GPU sampled tokens + propose_drafts_after_bookkeeping = False + if spec_config is not None: + input_fits_in_drafter = spec_decode_common_attn_metadata is not None and ( + spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens + <= self.effective_drafter_max_model_len + ) + use_gpu_toks = ( + spec_config.use_eagle() + or spec_config.uses_draft_model() + or spec_config.uses_extract_hidden_states() + ) and not spec_config.disable_padded_drafter_batch + if use_gpu_toks: + # EAGLE/DraftModel speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. - propose_draft_token_ids(sampled_token_ids) - elif self.valid_sampled_token_count_event is not None: - assert spec_decode_common_attn_metadata is not None - next_token_ids, valid_sampled_tokens_count = ( - self.drafter.prepare_next_token_ids_padded( - spec_decode_common_attn_metadata, - sampled_token_ids, - self.requests, - self.input_batch, - self.discard_request_mask.gpu, - ) - ) - self._copy_valid_sampled_token_count( - next_token_ids, valid_sampled_tokens_count + assert isinstance( + self.drafter, + EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, ) + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + assert spec_decode_common_attn_metadata is not None + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + self.optimistic_seq_lens_cpu, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + self._draft_token_ids = torch.zeros( + 1, 
device=self.device, dtype=torch.int32 + ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) + self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) + elif ( + spec_config.use_ngram_gpu() + and not spec_config.disable_padded_drafter_batch + ): + assert isinstance(self.drafter, NgramProposerGPU) + sampled_token_ids = sampler_output.sampled_token_ids + if input_fits_in_drafter: + propose_draft_token_ids(sampled_token_ids) + elif self.valid_sampled_token_count_event is not None: + assert spec_decode_common_attn_metadata is not None + next_token_ids, valid_sampled_tokens_count, _ = ( + self.drafter.update_token_ids_ngram( + sampled_token_ids, + self.input_batch, + self.token_ids_gpu_tensor, + self.num_tokens_no_spec_gpu, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + # Since we couldn't run the drafter, + # just use zeros for the draft tokens. + self._draft_token_ids = torch.zeros( + 1, device=self.device, dtype=torch.int32 + ).expand(len(self.input_batch.req_ids), self.num_spec_tokens) + self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True) + else: + propose_drafts_after_bookkeeping = input_fits_in_drafter with record_function_or_nullcontext("gpu_model_runner: bookkeep"): ( @@ -3323,25 +4287,38 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, ) - if ( - self.speculative_config - and not use_padded_batch_for_eagle - and input_fits_in_drafter - ): + if propose_drafts_after_bookkeeping: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) + # Finalize KV connector (wait_for_save + clear metadata) after + # draft model runs. Deferred from target model forward to allow + # draft model to also save its KV cache. 
+ if spec_config is not None: + self.finalize_kv_connector() + with record_function_or_nullcontext("gpu_model_runner: eplb"): self.eplb_step() + + # self.kv_connector_output may be modified during drafting + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): + if self.routed_experts_initialized: + capturer = RoutedExpertsCapturer.get_instance() + if capturer is not None: + capturer.save_captured_experts(indices=self.slot_mapping) # noqa + else: + logger.error("RoutedExpertsCapturer not initialized.") + output = ModelRunnerOutput( req_ids=req_ids_output_copy, req_id_to_index=req_id_to_index_output_copy, sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], kv_connector_output=kv_connector_output, ec_connector_output=ec_connector_output if self.supports_mm_inputs @@ -3350,13 +4327,12 @@ def propose_draft_token_ids(sampled_token_ids): cudagraph_stats=cudagraph_stats, ) - # Advance IO step after the full inference cycle (forward + sampling) - # so that one step encompasses both execute_model and sample_tokens. + # FL: Advance IO step after the full inference cycle advance_io_step() - ### TODO(lms): abstract async schedule for all hardware if not self.use_async_scheduling: return output + with record_function_or_nullcontext( "gpu_model_runner: AsyncGPUModelRunnerOutput" ): @@ -3380,46 +4356,127 @@ def propose_draft_token_ids(sampled_token_ids): return async_output + def _pp_broadcast_prev_sampled_token_ids( + self, sampled_token_ids: torch.Tensor + ) -> None: + """Broadcast sampled token ids (GPU) from last PP stage""" + pp = get_pp_group() + assert pp.is_last_rank + # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1]. 
+ assert sampled_token_ids.dim() == 2 and sampled_token_ids.shape[-1] == 1, ( + "PP+async expects sampled_token_ids to have shape [num_reqs, 1]" + ) + torch.distributed.broadcast( + sampled_token_ids, src=pp.rank, group=pp.device_group + ) + + def _pp_receive_prev_sampled_token_ids_to_input_batch(self) -> None: + """Receive sampled token ids broadcast from last PP stage""" + pp = get_pp_group() + assert not pp.is_last_rank + num_reqs = self.input_batch.num_reqs + # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1]. + recv = torch.empty((num_reqs, 1), dtype=torch.int32, device=self.device) + torch.distributed.broadcast(recv, src=pp.last_rank, group=pp.device_group) + self.input_batch.prev_sampled_token_ids = recv + + # construct `prev_req_id_to_index` here so `_prepare_input_ids` + # can map req_id -> previous batch row + discard_req_indices = np.nonzero(self.discard_request_mask.np[:num_reqs])[0] + discard_req_indices_set = set(discard_req_indices) + prev_req_id_to_index: dict[str, int] = {} + for i, req_id in enumerate(self.input_batch.req_ids): + if i in discard_req_indices_set: + continue + prev_req_id_to_index[req_id] = i + # PP+async scheduling: advance per-request local cached output length by + # appending a placeholder (-1) token id. 
+ if (req_state := self.requests.get(req_id)) is not None: + req_state.output_token_ids.append(-1) + self.input_batch.prev_req_id_to_index = prev_req_id_to_index + def take_draft_token_ids(self) -> DraftTokenIds | None: - if self._draft_token_ids is None: + if not self.num_spec_tokens or not self._draft_token_req_ids: return None - req_ids = self.input_batch.req_ids - if isinstance(self._draft_token_ids, torch.Tensor): - draft_token_ids = self._draft_token_ids.tolist() - else: - draft_token_ids = self._draft_token_ids - self._draft_token_ids = None + draft_token_ids, req_ids = self._get_draft_token_ids_cpu() return DraftTokenIds(req_ids, draft_token_ids) + def _copy_draft_token_ids_to_cpu( + self, scheduler_output: "SchedulerOutput", zeros_only: bool = False + ) -> None: + # Check if we need to copy draft tokens to CPU. In async scheduling, + # we only copy when needed for structured output, penalties or bad_words. + if self.use_async_scheduling and not ( + scheduler_output.has_structured_output_requests + or self.input_batch.sampling_metadata.output_token_ids + ): + return + # We must also set the corresponding request ids. + self._draft_token_req_ids = self.input_batch.req_ids.copy() + + draft_token_ids: torch.Tensor = self._draft_token_ids + if not torch.is_tensor(draft_token_ids): + return + assert self.draft_token_ids_event is not None + assert self.draft_token_ids_copy_stream is not None + assert self.draft_token_ids_cpu is not None + default_stream = torch.cuda.current_stream() + num_reqs = draft_token_ids.shape[0] + with torch.cuda.stream(self.draft_token_ids_copy_stream): + if not zeros_only: + # Trigger async copy of draft token ids to cpu. + self.draft_token_ids_copy_stream.wait_stream(default_stream) + self.draft_token_ids_cpu[:num_reqs].copy_( + draft_token_ids, non_blocking=True + ) + else: + # No copy needed, just zero-out cpu tensor. 
+ self.draft_token_ids_cpu[:num_reqs] = 0 + self.draft_token_ids_event.record() + + def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]: + if isinstance(self._draft_token_ids, list): + return self._draft_token_ids, self.input_batch.req_ids + req_ids = self._draft_token_req_ids + if req_ids is None: + return [], [] + assert self.draft_token_ids_event is not None + assert self.draft_token_ids_cpu is not None + self.draft_token_ids_event.synchronize() + return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids + def _copy_valid_sampled_token_count( self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor ) -> None: if self.valid_sampled_token_count_event is None: return - default_stream = current_platform.torch_device_fn.current_stream() + default_stream = torch.cuda.current_stream() # Initialize a new stream to overlap the copy operation with # prepare_input of draft model. with torch.cuda.stream(self.valid_sampled_token_count_copy_stream): self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore counts = valid_sampled_tokens_count counts_cpu = self.valid_sampled_token_count_cpu + assert counts_cpu is not None counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True) self.valid_sampled_token_count_event.record() + if self.use_async_spec_decode: + # Stash for GPU-side correction in _prepare_inputs. 
+ self.valid_sampled_token_count_gpu = valid_sampled_tokens_count self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1) def _get_valid_sampled_token_count(self) -> list[int]: # Wait until valid_sampled_tokens_count is copied to cpu, prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids - if ( - self.valid_sampled_token_count_event is None - or prev_sampled_token_ids is None - ): + sampled_count_event = self.valid_sampled_token_count_event + if sampled_count_event is None or prev_sampled_token_ids is None: return [] counts_cpu = self.valid_sampled_token_count_cpu - self.valid_sampled_token_count_event.synchronize() + assert counts_cpu is not None + sampled_count_event.synchronize() return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist() def propose_draft_token_ids( @@ -3432,24 +4489,65 @@ def propose_draft_token_ids( aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, + slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config assert spec_config is not None if spec_config.method == "ngram": + from vllm.v1.spec_decode.ngram_proposer import NgramProposer + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( sampled_token_ids, - self.input_batch.req_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs, + slot_mappings=slot_mappings, + ) + elif spec_config.use_ngram_gpu(): + assert isinstance(self.drafter, NgramProposerGPU) + ( + next_token_ids, + valid_sampled_tokens_count, + valid_sampled_token_ids_gpu, + ) = self.drafter.update_token_ids_ngram( + sampled_token_ids, + self.input_batch, + self.token_ids_gpu_tensor, + 
self.num_tokens_no_spec_gpu, + self.discard_request_mask.gpu, + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) + + batch_size = next_token_ids.shape[0] + + draft_token_ids, num_valid_draft_tokens = self.drafter.propose( + self.num_tokens_no_spec_gpu[:batch_size], + self.token_ids_gpu_tensor[:batch_size], + valid_sampled_token_ids_gpu, + valid_sampled_tokens_count, + ) + + # Cache valid draft counts for scheduler-side trimming. + self._num_valid_draft_tokens = num_valid_draft_tokens + + # Async D2H copy on a dedicated stream. + copy_num_valid_draft_tokens( + self._num_valid_draft_tokens_cpu, + self._num_valid_draft_tokens_copy_stream, + self._num_valid_draft_tokens_event, + self._num_valid_draft_tokens, + self.input_batch.num_reqs, ) elif spec_config.method == "suffix": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer) - draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids) + draft_token_ids = self.drafter.propose( + self.input_batch, sampled_token_ids, slot_mappings=slot_mappings + ) elif spec_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -3474,9 +4572,41 @@ def propose_draft_token_ids( draft_token_ids = self.drafter.propose( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, + slot_mappings=slot_mappings, + ) + elif spec_config.uses_extract_hidden_states(): + assert isinstance(self.drafter, ExtractHiddenStatesProposer) + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor for " + "extract_hidden_states method." 
+ ) + if not self.use_aux_hidden_state_outputs or aux_hidden_states is None: + raise ValueError( + "aux_hidden_states are required when using `extract_hidden_states`" + ) + target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states] + + draft_token_ids = self.drafter.propose( + sampled_token_ids=sampled_token_ids, + target_hidden_states=target_hidden_states, + common_attn_metadata=common_attn_metadata, + slot_mappings=slot_mappings, + ) + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + self.optimistic_seq_lens_cpu, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count ) - elif spec_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + + elif spec_config.use_eagle() or spec_config.uses_draft_model(): + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) if spec_config.disable_padded_drafter_batch: # When padded-batch is disabled, the sampled_token_ids should be @@ -3503,7 +4633,7 @@ def propose_draft_token_ids( ) next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - common_attn_metadata, + self.optimistic_seq_lens_cpu, sampled_token_ids, self.requests, self.input_batch, @@ -3514,6 +4644,7 @@ def propose_draft_token_ids( next_token_ids, valid_sampled_tokens_count ) + num_rejected_tokens_gpu = None if spec_decode_metadata is None: token_indices_to_sample = None # input_ids can be None for multimodal models. 
@@ -3544,12 +4675,14 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[token_indices] else: - common_attn_metadata, token_indices_to_sample = ( - self.drafter.prepare_inputs_padded( - common_attn_metadata, - spec_decode_metadata, - valid_sampled_tokens_count, - ) + ( + common_attn_metadata, + token_indices_to_sample, + num_rejected_tokens_gpu, + ) = self.drafter.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count, ) total_num_tokens = common_attn_metadata.num_actual_tokens # When padding the batch, token_indices is just a range @@ -3563,7 +4696,7 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[:total_num_tokens] - if self.supports_mm_inputs: + if self.supports_mm_inputs and self.drafter.supports_mm_inputs: mm_embed_inputs = self._gather_mm_embeddings( scheduler_output, shift_computed_tokens=1, @@ -3576,10 +4709,12 @@ def propose_draft_token_ids( target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, - last_token_indices=token_indices_to_sample, + token_indices_to_sample=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, mm_embed_inputs=mm_embed_inputs, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + slot_mappings=slot_mappings, ) return draft_token_ids @@ -3595,34 +4730,30 @@ def update_config(self, overrides: dict[str, Any]) -> None: new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) - def load_model(self, eep_scale_up: bool = False) -> None: + @instrument(span_name="Loading (GPU)") + def load_model(self, load_dummy_weights: bool = False) -> None: """ Args: - eep_scale_up: the model loading is for elastic EP scale up. + load_dummy_weights: load dummy weights instead of real weights. 
""" logger.info_once( "Starting to load model %s...", self.model_config.model, scope="global", ) - global_expert_loads, old_global_expert_indices_per_model, rank_mapping = ( - EplbState.get_eep_state(self.parallel_config) - if eep_scale_up - else (None, None, None) - ) if self.parallel_config.enable_eplb: self.eplb_state = EplbState(self.parallel_config, self.device) eplb_models = 0 - # IO dumper is only supported in eager mode. In graph mode (torch.compile) - # TorchDispatchMode and the module forward hooks are incompatible with - # Dynamo tracing, so IO dumping is silently skipped. + # FL: IO dumper init (skipped under Dynamo tracing) init_io_dump_from_env(getattr(self.model_config, "enforce_eager", False)) try: with DeviceMemoryProfiler() as m: time_before_load = time.perf_counter() + if load_dummy_weights: + self.load_config.load_format = "dummy" model_loader = get_model_loader(self.load_config) self.model = model_loader.load_model( vllm_config=self.vllm_config, model_config=self.model_config @@ -3639,6 +4770,9 @@ def load_model(self, eep_scale_up: bool = False) -> None: and is_mixture_of_experts(self.drafter.model) and self.parallel_config.enable_eplb ): + assert not self.parallel_config.enable_elastic_ep, ( + "Elastic EP is not supported with drafter model." 
+ ) spec_config = self.vllm_config.speculative_config assert spec_config is not None assert spec_config.draft_model_config is not None @@ -3646,17 +4780,6 @@ def load_model(self, eep_scale_up: bool = False) -> None: "EPLB is enabled for drafter model %s.", spec_config.draft_model_config.model, ) - - global_expert_load = ( - global_expert_loads[eplb_models] - if global_expert_loads - else None - ) - old_global_expert_indices = ( - old_global_expert_indices_per_model[eplb_models] - if old_global_expert_indices_per_model - else None - ) if self.eplb_state is None: self.eplb_state = EplbState( self.parallel_config, self.device @@ -3664,9 +4787,6 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.eplb_state.add_model( self.drafter.model, spec_config.draft_model_config, - global_expert_load, - old_global_expert_indices, - rank_mapping, ) eplb_models += 1 @@ -3686,90 +4806,84 @@ def load_model(self, eep_scale_up: bool = False) -> None: aux_layers, ) else: - aux_layers = self.model.get_eagle3_aux_hidden_state_layers() + aux_layers = ( + self.model.get_eagle3_default_aux_hidden_state_layers() + ) self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - except Exception as e: - is_oom = 'out of memory' in str(e).lower() - - if is_oom: - msg = ( - "Failed to load model - not enough device memory. " - "Try lowering --gpu-memory-utilization to free memory for weights, " - "increasing --tensor-parallel-size, or using --quantization. " - "See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ " - "for more tips." - ) - combined_msg = f"{msg} (original error: {e})" - logger.error(combined_msg) + except torch.cuda.OutOfMemoryError as e: + msg = ( + "Failed to load model - not enough GPU memory. " + "Try lowering --gpu-memory-utilization to free memory for weights, " + "increasing --tensor-parallel-size, or using --quantization. 
" + "See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ " + "for more tips." + ) + combined_msg = f"{msg} (original error: {e})" + logger.error(combined_msg) raise e logger.info_once( - "Model loading took %.4f GiB memory and %.6f seconds", - self.model_memory_usage / GiB_bytes, + "Model loading took %s GiB memory and %.6f seconds", + format_gib(self.model_memory_usage), time_after_load - time_before_load, scope="local", ) - - # IO dumper: register module paths and install module context hooks. - register_io_module_hooks(self.model) - - prepare_communication_buffer_for_model(self.model) - if (drafter := getattr(self, "drafter", None)) and ( - drafter_model := getattr(drafter, "model", None) - ): - prepare_communication_buffer_for_model(drafter_model) + if not load_dummy_weights: + prepare_communication_buffer_for_model(self.model) + # FL: register IO dumper module hooks + register_io_module_hooks(self.model) + if (drafter := getattr(self, "drafter", None)) and ( + drafter_model := getattr(drafter, "model", None) + ): + prepare_communication_buffer_for_model(drafter_model) mm_config = self.model_config.multimodal_config self.is_multimodal_pruning_enabled = ( supports_multimodal_pruning(self.get_model()) and mm_config is not None and mm_config.is_multimodal_pruning_enabled() ) + self.requires_sequential_video_encoding = hasattr( + self.get_model(), "requires_sequential_video_encoding" + ) # Temporary hack for dynamic res video w/o support for bs>1 yet - if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: + if ( + is_mixture_of_experts(self.model) + and self.parallel_config.enable_eplb + and not load_dummy_weights + ): logger.info_once("EPLB is enabled for model %s.", self.model_config.model) - global_expert_load = ( - global_expert_loads[eplb_models] if global_expert_loads else None - ) - old_global_expert_indices = ( - old_global_expert_indices_per_model[eplb_models] - if old_global_expert_indices_per_model - else None - ) 
assert self.eplb_state is not None self.eplb_state.add_model( self.model, self.model_config, - global_expert_load, - old_global_expert_indices, - rank_mapping, ) if self.eplb_state.is_async: - self.eplb_state.start_async_loop(rank_mapping=rank_mapping) - - # print(f"{self.vllm_config.compilation_config.mode=}") + self.eplb_state.start_async_loop() if ( self.vllm_config.compilation_config.mode == CompilationMode.STOCK_TORCH_COMPILE - and supports_dynamo() ): backend = self.vllm_config.compilation_config.init_backend(self.vllm_config) compilation_counter.stock_torch_compile_count += 1 self.model.compile(fullgraph=True, backend=backend) return # for other compilation modes, cudagraph behavior is controlled by - # CudagraphWraper and CudagraphDispatcher of vllm. + # CudagraphWrapper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. cudagraph_mode = self.compilation_config.cudagraph_mode assert cudagraph_mode is not None - if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo: - self.model = GraphWrapper( + if ( + cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.use_ubatching + ): + self.model = CUDAGraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) - elif self.parallel_config.enable_dbo: + elif self.parallel_config.use_ubatching: if cudagraph_mode.has_full_cudagraphs(): self.model = UBatchWrapper( self.model, self.vllm_config, CUDAGraphMode.FULL, self.device @@ -3779,6 +4893,8 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model, self.vllm_config, CUDAGraphMode.NONE, self.device ) + get_offloader().post_init() + def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: """Extract Eagle3 auxiliary layer indices from speculative config. @@ -3803,23 +4919,89 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] 
| None: return None - def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, ( - "Cannot reload weights before model is loaded." - ) - model_loader = get_model_loader(self.load_config) - logger.info("Reloading weights inplace...") - model_loader.load_weights(self.get_model(), model_config=self.model_config) - - def save_tensorized_model( + def reload_weights( self, - tensorizer_config: "TensorizerConfig", + weights_iterator: Iterable[tuple[str, torch.Tensor]] | None = None, + weights_path: str | None = None, + is_checkpoint_format: bool = True, ) -> None: - TensorizerLoader.save_model( - self.get_model(), - tensorizer_config=tensorizer_config, - model_config=self.model_config, + """ + Reload weights from a weights iterator or from disk + + :param weights_iterator: weights to load into model + :param weights_path: path to load weights from if weights_iterator is not + provided. Use path of original model if neither is provided. + :param is_checkpoint_format: set to False if weights have already been processed + into kernel format (repacking, renaming, etc.) + """ + # TODO(@kylesayrs): generalize to all runners and loaders + # argument validation + if weights_iterator is None and not is_checkpoint_format: + logger.warning( + "Reloading from disk means that weights will be in checkpoint format. 
" + "Please use `is_checkpoint_format=True` " + "to avoid weight reloading errors" + ) + + model = self.get_model() + weights_to_load = {name for name, _ in model.named_parameters()} + counter_before_reloading = time.perf_counter() + + # load weights from disk if none are provided + if weights_iterator is None: + model_loader = get_model_loader(self.load_config) + if not hasattr(model_loader, "get_all_weights"): + raise NotImplementedError( + f"Model reloading with `{self.load_config.load_format}` format" + ) + + if weights_path is not None: + self.model_config.model = weights_path + weights_iterator = model_loader.get_all_weights(self.model_config, model) + weights_iterator = cast( + Iterable[tuple[str, torch.Tensor]], weights_iterator + ) + + # begin loading weights + logger.info_once("Reloading weights inplace...", scope="local") + load_device = ( + self.vllm_config.load_config.device or self.vllm_config.device_config.device + ) + with torch.device(load_device): + if is_checkpoint_format: + # load weights from checkpoint/ original model format + initialize_layerwise_reload(model) + loaded_weights = model.load_weights(weights_iterator) + finalize_layerwise_reload(model, self.model_config) + + else: + # load weights from kernel format + logger.warning_once( + "Reloading with `is_checkpoint_format=True` requires that " + "weights be in kernel format and already sharded", + scope="local", + ) + loaded_weights = set() + for name, loaded_weight in weights_iterator: + param = model.get_parameter(name) # TODO: buffers? 
+ param.copy_(loaded_weight) + loaded_weights.add(name) + + # logging and validation + counter_after_reloading = time.perf_counter() + diff_seconds = counter_after_reloading - counter_before_reloading + logger.info_once( + "Reloading and processing weights took %.2f seconds", + diff_seconds, + scope="local", ) + if self.model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + logger.warning( + "Following weights were not loaded from checkpoint: %s", + weights_not_loaded, + ) def _get_prompt_logprobs_dict( self, @@ -3900,7 +5082,7 @@ def _get_prompt_logprobs_dict( # Compute prompt logprobs. logprobs = self.sampler.compute_logprobs(logits) - token_ids, logprobs, ranks = self.sampler.gather_logprobs( + token_ids, logprobs, ranks, _ = self.sampler.gather_logprobs( logprobs, num_prompt_logprobs, tgt_token_ids ) @@ -3932,7 +5114,7 @@ def _get_nans_in_logits( ) -> dict[str, int]: try: if logits is None: - return dict.fromkeys(self.input_batch.req_ids, 0) + return {req_id: 0 for req_id in self.input_batch.req_ids} num_nans_in_logits = {} num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy() @@ -4000,22 +5182,22 @@ def _get_mm_dummy_batch( """Dummy data for profiling and precompiling multimodal models.""" assert self.mm_budget is not None - dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( - model_config=self.model_config, - seq_len=self.max_model_len, + # Don't use `max_items_per_batch` here to avoid redundant computation + dummy_mm_inputs = self.mm_registry.get_dummy_mm_inputs( + self.model_config, mm_counts={modality: 1}, cache=self.mm_budget.cache, ) - dummy_mm_data = dummy_decoder_data.multi_modal_data + dummy_mm_item = dummy_mm_inputs["mm_kwargs"][modality][0] - # Result in the maximum GPU consumption of the model - dummy_mm_item = dummy_mm_data[modality][0] - dummy_mm_items = [dummy_mm_item] * max_items_per_batch + # We use the cache so that the item is 
saved to the cache, + # but not read from the cache + assert dummy_mm_item is not None, "Item should not already be cached" return next( - mm_kwargs_group - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - dummy_mm_items, + mm_kwargs_batch + for _, _, mm_kwargs_batch in group_and_batch_mm_kwargs( + [(modality, dummy_mm_item)] * max_items_per_batch, device=self.device, pin_memory=self.pin_memory, ) @@ -4033,12 +5215,13 @@ def _dummy_run( is_profile: bool = False, create_mixed_batch: bool = False, remove_lora: bool = True, - activate_lora: bool = False, is_graph_capturing: bool = False, + num_active_loras: int = 0, + profile_seq_lens: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the - graph for the model. + CUDA graph for the model. Args: num_tokens: Number of tokens to run the dummy forward pass. @@ -4057,11 +5240,21 @@ def _dummy_run( create_mixed_batch: If True, create a mixed batch with both decode (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run - activate_lora: If False, dummy_run is performed without LoRAs. + num_active_loras: Number of distinct active LoRAs to capture for. + LoRA is activated when num_active_loras > 0. + profile_seq_lens: If provided, use this value for seq_lens instead + of max_query_len. Used to profile attention workspace that + scales with context length. """ + mm_config = self.vllm_config.model_config.multimodal_config + if mm_config and mm_config.mm_encoder_only: + # The current dummy run only covers LM execution, so we can skip it. + # mm encoder dummy run may need to add in the future. 
+ return torch.tensor([]), torch.tensor([]) + assert ( cudagraph_runtime_mode is None - or cudagraph_runtime_mode.valid_runtime_modes() + or cudagraph_runtime_mode.is_valid_runtime_mode() ) # If cudagraph_mode.decode_mode() == FULL and @@ -4082,7 +5275,7 @@ def _dummy_run( # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. - assert num_tokens <= self.scheduler_config.max_num_batched_tokens + assert num_tokens <= self.max_num_tokens max_num_reqs = self.scheduler_config.max_num_seqs if create_mixed_batch: assert not uniform_decode @@ -4133,7 +5326,10 @@ def _dummy_run( # `force_has_lora` is used for cudagraph capture; because LoRA is # activated later in the context manager, but we need to know the # LoRA state when determining the batch descriptor for capture - force_has_lora=activate_lora, + force_has_lora=num_active_loras > 0, + # `force_num_active_loras` is used for cudagraph capture; because we + # need to capture graphs for specific num_active_loras counts + force_num_active_loras=num_active_loras, ) ) @@ -4150,51 +5346,88 @@ def _dummy_run( batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs ) ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices( - should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded + should_ubatch, + num_scheduled_tokens, + num_tokens_padded, + num_reqs_padded, + self.vllm_config.parallel_config.num_ubatches, + ) + logger.debug( + "ubatch_slices: %s, ubatch_slices_padded: %s", + ubatch_slices, + ubatch_slices_padded, ) attn_metadata: PerLayerAttnMetadata | None = None - # If force_attention is True, we always capture attention. Otherwise, - # it only happens for cudagraph_runtime_mode=FULL. - if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: - if create_mixed_batch: - # In the mixed batch mode (used for FI warmup), we use - # shorter sequence lengths to run faster. 
- # TODO(luka) better system for describing dummy batches - seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] - else: - seq_lens = max_query_len # type: ignore[assignment] - self.seq_lens.np[:num_reqs] = seq_lens - self.seq_lens.np[num_reqs:] = 0 - self.seq_lens.copy_to_gpu() + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens, + num_reqs_padded=num_reqs_padded, + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=ubatch_slices_padded, + ) - cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) - self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens - self.query_start_loc.copy_to_gpu() + # _dummy_run shares pinned CPU buffers (seq_lens, query_start_loc, + # etc.) with execute_model. It must participate in the same event + # protocol so that back-to-back dummy/real steps don't overwrite + # pinned memory while a prior non_blocking H2D DMA is still reading. + with self.synchronize_input_prep(): + # If force_attention is True, we always capture attention. + # Otherwise, it only happens for cudagraph_runtime_mode=FULL. + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + if profile_seq_lens is not None: + seq_lens = profile_seq_lens # type: ignore[assignment] + elif create_mixed_batch: + # In the mixed batch mode (used for FI warmup), we use + # shorter sequence lengths to run faster. 
+ # TODO(luka) better system for describing dummy batches + seq_lens = torch.tensor( # type: ignore[assignment] + [1] * num_decode_tokens + [num_prefill_tokens + 1], + dtype=torch.int, + ) + else: + seq_lens = max_query_len # type: ignore[assignment] + self.optimistic_seq_lens_cpu[:num_reqs] = seq_lens + self.optimistic_seq_lens_cpu[num_reqs:].fill_(0) + self.seq_lens.copy_(self.optimistic_seq_lens_cpu, non_blocking=True) - pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL - attn_metadata, _ = self._build_attention_metadata( - num_tokens=num_tokens_unpadded, - num_reqs=num_reqs_padded, - max_query_len=max_query_len, - ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, - for_cudagraph_capture=is_graph_capturing, - ) + cum_num_tokens = self._get_cumsum_and_arange( + num_scheduled_tokens, self.query_pos.np + ) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens + self.query_start_loc.copy_to_gpu() + + # Sync block table CPU->GPU so cleared rows from + # remove_request() are visible to the attention metadata + # builder. Without this, stale block IDs from finished + # requests can corrupt Mamba state. 
+ self.input_batch.block_table.commit_block_table(num_reqs_padded) + + pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL + attn_metadata, _ = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs_padded, + max_query_len=max_query_len, + ubatch_slices=(ubatch_slices_padded if pad_attn else ubatch_slices), + for_cudagraph_capture=is_graph_capturing, + slot_mappings=slot_mappings_by_group, + use_spec_decode=self.speculative_config is not None, + ) with self.maybe_dummy_run_with_lora( self.lora_config, num_scheduled_tokens, num_sampled_tokens, - activate_lora, remove_lora, + num_active_loras, ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens - model_kwargs = self._init_model_kwargs(num_tokens_padded) + model_kwargs = self._init_model_kwargs() if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: - input_ids = None - inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] + input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded) + model_kwargs = { **model_kwargs, **self._dummy_mm_kwargs(num_reqs), @@ -4202,7 +5435,7 @@ def _dummy_run( elif self.enable_prompt_embeds: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] - model_kwargs = self._init_model_kwargs(num_tokens_padded) + model_kwargs = self._init_model_kwargs() else: input_ids = self.input_ids.gpu[:num_tokens_padded] inputs_embeds = None @@ -4212,7 +5445,7 @@ def _dummy_run( elif self.uses_xdrope_dim > 0: positions = self.xdrope_positions.gpu[:, :num_tokens_padded] else: - positions = self.positions.gpu[:num_tokens_padded] + positions = self.positions[:num_tokens_padded] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -4237,6 +5470,7 @@ def _dummy_run( num_tokens_padded = ubatch_slices_padded[0].num_tokens if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_padded + 
with ( self.maybe_randomize_inputs(input_ids, inputs_embeds), set_forward_context( @@ -4247,6 +5481,7 @@ def _dummy_run( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_desc, ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, ), ): outputs = self.model( @@ -4262,8 +5497,16 @@ def _dummy_run( else: hidden_states = outputs - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + or self.speculative_config.uses_extract_hidden_states() + ): + assert isinstance( + self.drafter, + EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, + ) + assert self.speculative_config is not None # Eagle currently only supports PIECEWISE cudagraphs. # Therefore only use cudagraphs if the main model uses PIECEWISE # NOTE(lucas): this is a hack, need to clean up. @@ -4282,13 +5525,17 @@ def _dummy_run( # lora cases when cudagraph_specialize_lora is enabled. This is a # short term mitigation for issue mentioned in # https://github.com/vllm-project/vllm/issues/28334 - if self.compilation_config.cudagraph_specialize_lora and activate_lora: + if ( + self.compilation_config.cudagraph_specialize_lora + and num_active_loras > 0 + ): use_cudagraphs = False self.drafter.dummy_run( num_tokens, use_cudagraphs=use_cudagraphs, is_graph_capturing=is_graph_capturing, + slot_mappings=slot_mappings, ) # We register layerwise NVTX hooks here after the first dynamo tracing is @@ -4326,6 +5573,12 @@ def _dummy_sampler_run( # The dummy hidden states may contain special values, # like `inf` or `nan`. # To avoid breaking the sampler, we use a random tensor here instead. + + mm_config = self.vllm_config.model_config.multimodal_config + if mm_config and mm_config.mm_encoder_only: + # MM Encoder only model no need to run sampler. 
+ return torch.tensor([]) + hidden_states = torch.rand_like(hidden_states) logits = self.model.compute_logits(hidden_states) @@ -4400,24 +5653,21 @@ def _dummy_pooler_run_task( max_num_reqs = self.scheduler_config.max_num_seqs num_reqs = min(num_tokens, max_num_reqs) min_tokens_per_req = num_tokens // num_reqs - num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs - num_scheduled_tokens_list[-1] += num_tokens % num_reqs - assert sum(num_scheduled_tokens_list) == num_tokens - assert len(num_scheduled_tokens_list) == num_reqs + num_scheduled_tokens_np = np.full(num_reqs, min_tokens_per_req) + num_scheduled_tokens_np[-1] += num_tokens % num_reqs + assert np.sum(num_scheduled_tokens_np) == num_tokens + assert len(num_scheduled_tokens_np) == num_reqs req_num_tokens = num_tokens // num_reqs - dummy_prompt_lens = torch.tensor( - num_scheduled_tokens_list, - device="cpu", - ) + dummy_prompt_lens = torch.from_numpy(num_scheduled_tokens_np) dummy_token_ids = torch.zeros( (num_reqs, req_num_tokens), dtype=torch.int32, device=self.device ) model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) - dummy_pooling_params.verify(task=task, model_config=self.model_config) + dummy_pooling_params.verify(self.model_config) to_update = model.pooler.get_pooling_updates(task) to_update.apply(dummy_pooling_params) @@ -4429,7 +5679,7 @@ def _dummy_pooler_run_task( ) dummy_metadata.build_pooling_cursor( - num_scheduled_tokens_list, + num_scheduled_tokens_np, seq_lens_cpu=dummy_prompt_lens, device=hidden_states.device, ) @@ -4454,6 +5704,11 @@ def _dummy_pooler_run( self, hidden_states: torch.Tensor, ) -> PoolerOutput: + mm_config = self.vllm_config.model_config.multimodal_config + if mm_config and mm_config.mm_encoder_only: + # MM Encoder only model not need to run pooler. 
+ return torch.tensor([]) + # Find the task that has the largest output for subsequent steps supported_pooling_tasks = self.get_supported_pooling_tasks() @@ -4469,7 +5724,7 @@ def _dummy_pooler_run( for task in supported_pooling_tasks: # Run a full batch with each task to ensure none of them OOMs output = self._dummy_pooler_run_task(hidden_states, task) - output_size[task] = sum(o.nbytes for o in output) + output_size[task] = sum(o.nbytes for o in output if o is not None) del output # Allow GC max_task = max(output_size.items(), key=lambda x: x[1])[0] @@ -4489,40 +5744,51 @@ def profile_run(self) -> None: assert mm_budget is not None if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ - dummy_modality - ] + if not mm_budget.mm_max_toks_per_item: + # All modality limits are 0 — embedding-only mode. + # Budget is non-zero for embedding storage, but + # there's no encoder to profile. + logger.info( + "Skipping encoder profiling for embedding-only " + "mode (all modality limits=0 with " + "enable_mm_embeds=True).", + ) + else: + # NOTE: Currently model is profiled with a single + # non-text modality with the max possible input + # tokens even when it supports multiple. 
+ dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget.mm_max_items_per_batch[ + dummy_modality + ] - logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the " - "maximum feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) + logger.info_once( + "Encoder cache will be initialized with a " + "budget of %s tokens, and profiled with " + "%s %s items of the maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + scope="local", + ) - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) - # Run multimodal encoder. - dummy_encoder_outputs = self.model.embed_multimodal( - **batched_dummy_mm_inputs - ) + # Run multimodal encoder. 
+ dummy_encoder_outputs = self.model.embed_multimodal( + **batched_dummy_mm_inputs + ) - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) - for i, output in enumerate(dummy_encoder_outputs): - self.encoder_cache[f"tmp_{i}"] = output + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + for i, output in enumerate(dummy_encoder_outputs): + self.encoder_cache[f"tmp_{i}"] = output # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states = self._dummy_run( @@ -4540,6 +5806,172 @@ def profile_run(self) -> None: self.encoder_cache.clear() gc.collect() + def _init_minimal_kv_cache_for_profiling(self) -> None: + from vllm.v1.core.kv_cache_utils import ( + get_kv_cache_config_from_groups, + get_kv_cache_groups, + ) + + kv_cache_spec = self.get_kv_cache_spec() + kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec) + min_blocks = self.compilation_config.max_cudagraph_capture_size or 1 + + # Temporarily change num_gpu_blocks_override to allocate a minimal KV cache + saved_override = self.cache_config.num_gpu_blocks_override + self.cache_config.num_gpu_blocks_override = min_blocks + minimal_config = get_kv_cache_config_from_groups( + self.vllm_config, kv_cache_groups, available_memory=0 + ) + self.cache_config.num_gpu_blocks_override = saved_override + + self.initialize_kv_cache(minimal_config) + self.cache_config.num_gpu_blocks = minimal_config.num_blocks + + logger.debug("Initialized minimal KV cache for CUDA graph profiling") + + @staticmethod + @contextmanager + def _freeze_gc(): + gc.collect() + should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC + if should_freeze: + gc.freeze() + try: + yield + finally: + if should_freeze: + gc.unfreeze() + gc.collect() + + def _cleanup_profiling_kv_cache(self) -> None: + torch.accelerator.synchronize() + if hasattr(self, "kv_caches") and self.kv_caches: + for i in 
range(len(self.kv_caches)): + self.kv_caches[i] = None # type: ignore + self.kv_caches.clear() + if hasattr(self, "cross_layers_kv_cache"): + self.cross_layers_kv_cache = None + self.cross_layers_attn_backend = None + if hasattr(self, "attn_groups"): + self.attn_groups.clear() + if hasattr(self, "kv_cache_config"): + delattr(self, "kv_cache_config") + self.cache_config.num_gpu_blocks = None + + for layer in self.compilation_config.static_forward_context.values(): + if hasattr(layer, "kv_cache"): + kv_cache = layer.kv_cache + layer.kv_cache = ( + torch.tensor([]) if isinstance(kv_cache, torch.Tensor) else [] + ) + + gc.collect() + torch.accelerator.empty_cache() + + logger.debug("Cleaned up profiling KV cache and CUDA graphs") + + @torch.inference_mode() + def profile_cudagraph_memory(self) -> int: + with set_current_vllm_config(self.vllm_config): + self._init_minimal_kv_cache_for_profiling() + + saved_num_cudagraph_captured = compilation_counter.num_cudagraph_captured + + capture_descs = self.cudagraph_dispatcher.get_capture_descs() + + total_graphs = sum(len(descs) for _, descs in capture_descs) + if total_graphs == 0: + logger.debug("No CUDA graphs will be captured, skipping profiling") + self._cleanup_profiling_kv_cache() + return 0 + + logger.info( + "Profiling CUDA graph memory: %s", + ", ".join( + f"{mode.name}={len(descs)} (largest={descs[0].num_tokens})" + for mode, descs in capture_descs + if descs + ), + ) + + # Use a temporary pool for profiling to avoid fragmentation in the main pool. 
+ profiling_pool = current_platform.graph_pool_handle() + original_pools: dict[int, Any] = {} + for instance in list(CUDAGraphWrapper._all_instances): + original_pools[id(instance)] = instance.graph_pool + instance.graph_pool = profiling_pool + + set_cudagraph_capturing_enabled(True) + with self._freeze_gc(), graph_capture(device=self.device): + shared_memory_estimate = {} + per_graph_estimate = {} + torch.accelerator.synchronize() + torch.accelerator.empty_cache() + + for mode, descs in capture_descs: + profile_descs = descs[:2] + mem_samples: list[int] = [] + + for i, desc in enumerate(profile_descs): + mem_before = torch.cuda.mem_get_info()[0] + self._warmup_and_capture( + desc, + cudagraph_runtime_mode=mode, + profile_seq_lens=( + min( + self.max_model_len, + self.max_num_tokens // desc.num_tokens, + ) + if mode == CUDAGraphMode.FULL and i == 0 + else None + ), + ) + torch.accelerator.synchronize() + free_after = torch.cuda.mem_get_info()[0] + mem_samples.append(mem_before - free_after) + + first_capture = mem_samples[0] + # Use at least 1 MiB per graph for driver overhead + per_graph = max(mem_samples[1] if len(mem_samples) > 1 else 0, 1 << 20) + + shared_memory_estimate[mode] = first_capture + per_graph_estimate[mode] = per_graph * (len(descs) - 1) + + logger.debug( + "Estimated %s CUDA graph memory: " + "%.2f MiB first-capture + (%d-1) × %.2f MiB per-graph", + mode.name, + first_capture / (1 << 20), + len(descs), + per_graph / (1 << 20), + ) + + set_cudagraph_capturing_enabled(False) + CUDAGraphWrapper.clear_all_graphs() + for instance in list(CUDAGraphWrapper._all_instances): + if id(instance) in original_pools: + instance.graph_pool = original_pools[id(instance)] + for key_set in self.cudagraph_dispatcher.cudagraph_keys.values(): + key_set.clear() + self.cudagraph_dispatcher.keys_initialized = False + self.maybe_remove_all_loras(self.lora_config) + self._cleanup_profiling_kv_cache() + compilation_counter.num_cudagraph_captured = 
saved_num_cudagraph_captured + + # FULL and PIECEWISE graphs share the global pool at runtime and are + # never replayed concurrently, so the pool overlays their memory. + # Take the max to avoid double-counting the overlap. + total_estimate = max(shared_memory_estimate.values()) + sum( + per_graph_estimate.values() + ) + logger.info( + "Estimated CUDA graph memory: %.2f GiB total", + total_estimate / (1 << 30), + ) + + return int(total_estimate) + + @instrument(span_name="Capture model") def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( @@ -4548,79 +5980,61 @@ def capture_model(self) -> int: ) return 0 + # Initialize encoder CUDA graph manager if enabled. + # Use get_model() to unwrap CUDAGraphWrapper/UBatchWrapper, + # because @runtime_checkable Protocol isinstance() checks do not + # work through __getattr__ forwarding. + if ( + self.compilation_config.cudagraph_mm_encoder + and self.supports_mm_inputs + and self.encoder_cudagraph_manager is None + ): + from vllm.model_executor.models.interfaces import ( + SupportsEncoderCudaGraph, + supports_encoder_cudagraph, + ) + from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager + + raw_model = self.get_model() + if supports_encoder_cudagraph(raw_model): + self.encoder_cudagraph_manager = EncoderCudaGraphManager( + vllm_config=self.vllm_config, + device=self.device, + dtype=self.dtype, + model=cast(SupportsEncoderCudaGraph, raw_model), + ) + logger.info("Initialized EncoderCudaGraphManager for vision encoder") + compilation_counter.num_gpu_runner_capture_triggers += 1 start_time = time.perf_counter() - @contextmanager - def freeze_gc(): - # Optimize garbage collection during CUDA graph capture. - # Clean up, then freeze all remaining objects from being included - # in future collections. 
- gc.collect() - should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC - if should_freeze: - gc.freeze() - try: - yield - finally: - if should_freeze: - gc.unfreeze() - gc.collect() - # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. set_cudagraph_capturing_enabled(True) - with freeze_gc(), graph_capture(device=self.device): - start_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0] - cudagraph_mode = self.compilation_config.cudagraph_mode - assert cudagraph_mode is not None - - if self.lora_config: - if self.compilation_config.cudagraph_specialize_lora: - lora_cases = [True, False] - else: - lora_cases = [True] - else: - lora_cases = [False] - if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - cudagraph_runtime_mode = cudagraph_mode.mixed_mode() - # make sure we capture the largest batch size first - compilation_cases = list( - product(reversed(self.cudagraph_batch_sizes), lora_cases) - ) - self._capture_cudagraphs( - compilation_cases, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False, - ) + with self._freeze_gc(), graph_capture(device=self.device): + torch.accelerator.synchronize() + torch.accelerator.empty_cache() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] - # Capture full cudagraph for uniform decode batches if we - # don't already have full mixed prefill-decode cudagraphs. 
- if ( - cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and cudagraph_mode.separate_routine() - ): - max_num_tokens = ( - self.scheduler_config.max_num_seqs * self.uniform_decode_query_len - ) - decode_cudagraph_batch_sizes = [ - x - for x in self.cudagraph_batch_sizes - if max_num_tokens >= x >= self.uniform_decode_query_len - ] - compilation_cases_decode = list( - product(reversed(decode_cudagraph_batch_sizes), lora_cases) - ) + for ( + runtime_mode, + batch_descs, + ) in self.cudagraph_dispatcher.get_capture_descs(): self._capture_cudagraphs( - compilation_cases=compilation_cases_decode, - cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True, + batch_descriptors=batch_descs, + cudagraph_runtime_mode=runtime_mode, ) + torch.accelerator.synchronize() + + # Capture encoder CUDA graphs if enabled + if self.encoder_cudagraph_manager is not None: + self.encoder_cudagraph_manager.capture() + + torch.accelerator.synchronize() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] - current_platform.torch_device_fn.synchronize() - end_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0] # Disable cudagraph capturing globally, so any unexpected cudagraph # capturing will be detected and raise an error after here. # Note: We don't put it into graph_capture context manager because @@ -4628,6 +6042,9 @@ def freeze_gc(): # after here. set_cudagraph_capturing_enabled(False) + torch.accelerator.synchronize() + torch.accelerator.empty_cache() + # Lock workspace to prevent resizing during execution. # Max workspace sizes should have been captured during warmup/profiling. 
lock_workspace() @@ -4644,21 +6061,59 @@ def freeze_gc(): ) return cuda_graph_size + def _warmup_and_capture( + self, + desc: BatchDescriptor, + cudagraph_runtime_mode: CUDAGraphMode, + profile_seq_lens: int | None = None, + allow_microbatching: bool = False, + num_warmups: int | None = None, + ): + if num_warmups is None: + num_warmups = self.compilation_config.cudagraph_num_of_warmups + force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + for _ in range(num_warmups): + self._dummy_run( + desc.num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=desc.uniform, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + num_active_loras=desc.num_active_loras, + ) + self._dummy_run( + desc.num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=desc.uniform, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + num_active_loras=desc.num_active_loras, + is_graph_capturing=True, + profile_seq_lens=profile_seq_lens, + ) + def _capture_cudagraphs( self, - compilation_cases: list[tuple[int, bool]], + batch_descriptors: list[BatchDescriptor], cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool, ): assert ( cudagraph_runtime_mode != CUDAGraphMode.NONE - and cudagraph_runtime_mode.valid_runtime_modes() + and cudagraph_runtime_mode.is_valid_runtime_mode() ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" + if not batch_descriptors: + return + + uniform_decode = batch_descriptors[0].uniform + # Only rank 0 should print progress bar during capture if is_global_first_rank(): - compilation_cases = tqdm( - compilation_cases, + batch_descriptors = tqdm( + batch_descriptors, disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", @@ -4667,49 +6122,27 @@ def _capture_cudagraphs( ) # We skip EPLB here since we don't want to 
record dummy metrics - for num_tokens, activate_lora in compilation_cases: + for batch_desc in batch_descriptors: # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched # version of the graph allow_microbatching = ( - self.parallel_config.enable_dbo + self.parallel_config.use_ubatching and cudagraph_runtime_mode == CUDAGraphMode.FULL and uniform_decode and check_ubatch_thresholds( config=self.vllm_config.parallel_config, - num_tokens=num_tokens, + num_tokens=batch_desc.num_tokens, uniform_decode=uniform_decode, ) ) - - for _ in range(self.compilation_config.cudagraph_num_of_warmups): - # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. - # But be careful, warm up with `NONE`is orthogonal to - # if we want to warm up attention or not. This is - # different from the case where `FULL` implies capture - # attention while `PIECEWISE` implies no attention. - force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL - self._dummy_run( - num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False, - activate_lora=activate_lora, - ) - self._dummy_run( - num_tokens, + self._warmup_and_capture( + batch_desc, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False, - activate_lora=activate_lora, - is_graph_capturing=True, ) + torch.accelerator.synchronize() self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: @@ -4808,14 +6241,22 @@ def initialize_metadata_builders( if kv_cache_group_id < len(kernel_block_sizes) else None, num_metadata_builders=1 - if not self.parallel_config.enable_dbo - else 2, + if not 
self.parallel_config.use_ubatching + else self.parallel_config.num_ubatches, ) # Calculate reorder batch threshold (if needed) # Note (tdoublep): do this *after* constructing builders, # because some of them change the threshold at init time. self.calculate_reorder_batch_threshold() + # Initialize drafter attention backend + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ): + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) + self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes) + def _check_and_update_cudagraph_mode( self, attention_backends: list[set[type[AttentionBackend]]], @@ -4959,10 +6400,22 @@ def _check_and_update_cudagraph_mode( self.compilation_config.adjust_cudagraph_sizes_for_spec_decode( self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size ) - capture_sizes = self.compilation_config.cudagraph_capture_sizes - self.cudagraph_batch_sizes = ( - capture_sizes if capture_sizes is not None else [] + + # If the model has Mamba layers and cudagraph mode includes FULL + # decode, cap cudagraph capture sizes to the number of available + # Mamba cache blocks. Each decode request needs one conv_state + # cache line, so capture batch sizes cannot exceed num_blocks. + # Only FULL decode graphs are affected because PIECEWISE captures + # run GDN/Mamba ops eagerly (prefill path, no causal_conv1d_update). + # See: https://github.com/vllm-project/vllm/issues/34094 + if cudagraph_mode.has_full_cudagraphs(): + has_mamba = any( + isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups ) + if has_mamba and self.kv_cache_config is not None: + self.compilation_config.adjust_cudagraph_sizes_for_mamba_cache( + self.kv_cache_config.num_blocks + ) # Trigger cudagraph dispatching keys initialization after # resolved cudagraph mode. 
@@ -4971,6 +6424,14 @@ def _check_and_update_cudagraph_mode( cudagraph_mode, self.uniform_decode_query_len ) + # Initialize drafter's cudagraph dispatcher if using spec decode. + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_extract_hidden_states() + ): + assert isinstance(self.drafter, EagleProposer | ExtractHiddenStatesProposer) + self.drafter.initialize_cudagraph_keys(cudagraph_mode) + def calculate_reorder_batch_threshold(self) -> None: """ Choose the minimum reorder batch threshold from all attention groups. @@ -4991,120 +6452,75 @@ def calculate_reorder_batch_threshold(self) -> None: return self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment] - @staticmethod - def select_common_block_size( - kv_manager_block_size: int, attn_groups: list[AttentionGroup] - ) -> int: - """ - Select a block size that is supported by all backends and is a factor of - kv_manager_block_size. - - If kv_manager_block_size is supported by all backends, return it directly. - Otherwise, return the max supported size. - - Args: - kv_manager_block_size: Block size of KV cache - attn_groups: List of attention groups - - Returns: - The selected block size - - Raises: - ValueError: If no valid block size found - """ - - def block_size_is_supported( - backends: list[type[AttentionBackend]], block_size: int - ) -> bool: - """ - Check if the block size is supported by all backends. 
- """ - for backend in backends: - is_supported = False - for supported_size in backend.get_supported_kernel_block_sizes(): - if isinstance(supported_size, int): - if block_size == supported_size: - is_supported = True - elif isinstance(supported_size, MultipleOf): - if block_size % supported_size.base == 0: - is_supported = True - else: - raise ValueError(f"Unknown supported size: {supported_size}") - if not is_supported: - return False - return True - - backends = [group.backend for group in attn_groups] - - # Case 1: if the block_size of kv cache manager is supported by all backends, - # return it directly - if block_size_is_supported(backends, kv_manager_block_size): - return kv_manager_block_size - - # Case 2: otherwise, the block_size must be an `int`-format supported size of - # at least one backend. Iterate over all `int`-format supported sizes in - # descending order and return the first one that is supported by all backends. - # Simple proof: - # If the supported size b is in MultipleOf(x_i) format for all attention - # backends i, and b a factor of kv_manager_block_size, then - # kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will - # return kv_manager_block_size in case 1. - all_int_supported_sizes = set( - supported_size - for backend in backends - for supported_size in backend.get_supported_kernel_block_sizes() - if isinstance(supported_size, int) - ) - - for supported_size in sorted(all_int_supported_sizes, reverse=True): - if kv_manager_block_size % supported_size != 0: - continue - if block_size_is_supported(backends, supported_size): - return supported_size - raise ValueError(f"No common block size for {kv_manager_block_size}. ") - def may_reinitialize_input_batch( self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] ) -> None: """ Re-initialize the input batch if the block sizes are different from - `[self.cache_config.block_size]`. This usually happens when there - are multiple KV cache groups. 
+ what it was originally created with. This happens when the final + block size (determined after model loading) differs from the + placeholder used during __init__, or when there are multiple + KV cache groups. Args: kv_cache_config: The KV cache configuration. kernel_block_sizes: The kernel block sizes for each KV cache group. """ - block_sizes = [ - kv_cache_group.kv_cache_spec.block_size - for kv_cache_group in kv_cache_config.kv_cache_groups - if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec) - ] + block_sizes = [] + max_num_blocks = [] + max_model_len = max(self.max_model_len, self.max_encoder_len) + for kv_cache_group in kv_cache_config.kv_cache_groups: + if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): + continue + block_size = kv_cache_group.kv_cache_spec.block_size + block_sizes.append(block_size) + max_num_blocks_per_req = cdiv( + max_model_len, block_size * get_total_cp_world_size() + ) + if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + max_num_blocks_per_req = ( + max_num_blocks_per_req + if self.cache_config.enable_prefix_caching + else 1 + ) + kv_cache_group.kv_cache_spec.num_speculative_blocks + max_num_blocks.append(max_num_blocks_per_req) - if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ - self.cache_config.block_size - ]: - assert self.cache_config.cpu_offload_gb == 0, ( + if ( + block_sizes != self._init_block_sizes + or kernel_block_sizes != self._init_kernel_block_sizes + ): + assert self.offload_config.uva.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 "for more details." 
) + self._init_block_sizes = block_sizes + self._init_kernel_block_sizes = kernel_block_sizes self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=max(self.max_model_len, self.max_encoder_len), + max_model_len=max_model_len, max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, kernel_block_sizes=kernel_block_sizes, + max_num_blocks_per_req=max_num_blocks, is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=self.input_batch.logitsprocs, logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, - num_speculative_tokens=self.num_spec_tokens, ) + assert self._init_block_sizes == block_sizes, ( + f"InputBatch block_sizes {self._init_block_sizes} != " + f"kv_cache block_sizes {block_sizes}" + ) + assert self._init_kernel_block_sizes == kernel_block_sizes, ( + f"InputBatch kernel_block_sizes {self._init_kernel_block_sizes} " + f"!= kv_cache kernel_block_sizes {kernel_block_sizes}" + ) + def _allocate_kv_cache_tensors( self, kv_cache_config: KVCacheConfig ) -> dict[str, torch.Tensor]: @@ -5146,49 +6562,6 @@ def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: for attn_groups in self.attn_groups: yield from attn_groups - def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]: - """ - Generate kernel_block_sizes that matches each block_size. - - For attention backends that support virtual block splitting, - use the supported block sizes from the backend. - For other backends (like Mamba), use the same block size (no splitting). - - Args: - kv_cache_config: The KV cache configuration. - - Returns: - list[int]: List of kernel block sizes for each cache group. 
- """ - kernel_block_sizes = [] - for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): - kv_cache_spec = kv_cache_group.kv_cache_spec - if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs): - # All layers in the UniformTypeKVCacheSpecs have the same type, - # Pick an arbitrary one to dispatch. - kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values())) - if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): - continue - elif isinstance(kv_cache_spec, AttentionSpec): - # This is an attention backend that supports virtual - # block splitting. Get the supported block sizes from - # all backends in the group. - attn_groups = self.attn_groups[kv_cache_gid] - kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size - selected_kernel_size = self.select_common_block_size( - kv_manager_block_size, attn_groups - ) - kernel_block_sizes.append(selected_kernel_size) - elif isinstance(kv_cache_spec, MambaSpec): - # This is likely Mamba or other non-attention cache, - # no splitting. - kernel_block_sizes.append(kv_cache_spec.block_size) - else: - raise NotImplementedError( - f"unknown kv cache spec {kv_cache_group.kv_cache_spec}" - ) - return kernel_block_sizes - def _reshape_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, @@ -5252,7 +6625,8 @@ def _reshape_kv_cache_tensors( ) # Maintain original KV shape view. 
inv_order = [ - kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) + kv_cache_stride_order.index(i) + for i in range(len(kv_cache_stride_order)) ] kv_caches[layer_name] = ( kv_cache_raw_tensors[layer_name] @@ -5411,6 +6785,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config + self._mamba_copy_bufs = None self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) @@ -5419,7 +6794,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # backends for that group only supports block_size 64, we will return # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64 # tokens each. - kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) + kernel_block_sizes = prepare_kernel_block_sizes( + kv_cache_config, self.attn_groups + ) + self._kernel_block_sizes = kernel_block_sizes # create metadata builders self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes) @@ -5430,8 +6808,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config, kernel_block_sizes ) - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if ( + self.speculative_config + and self.speculative_config.uses_extract_hidden_states() + ): + assert isinstance(self.drafter, ExtractHiddenStatesProposer) # validate all draft model layers belong to the same kv cache # group self.drafter.validate_same_kv_cache_group(kv_cache_config) @@ -5447,6 +6828,58 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_transfer_group.register_kv_caches(kv_caches) kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) + def _get_attention_kv_cache_gid(self) -> int: + """Find the KV cache group index for 
attention layers.""" + for gid, group in enumerate(self.kv_cache_config.kv_cache_groups): + if isinstance(group.kv_cache_spec, AttentionSpec): + return gid + return 0 + + def init_routed_experts_capturer(self): + logger.info( + "Initializing routed experts capturer, enable_return_routed_experts: %s", + self.model_config.enable_return_routed_experts, + ) + routed_experts_capturer = RoutedExpertsCapturer.create() + self.routed_experts_attn_gid = self._get_attention_kv_cache_gid() + min_block_size = min( + [ + group.kv_cache_spec.block_size + for group in self.kv_cache_config.kv_cache_groups + ] + ) + num_groups = len(self.kv_cache_config.kv_cache_groups) + self.max_num_kv_tokens = ( + self.kv_cache_config.num_blocks // num_groups + ) * min_block_size + dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size + pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size + if pcp_size * dcp_size > 1: + self.max_num_kv_tokens *= pcp_size * dcp_size + + routed_experts_capturer.init_buffer( + max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens, + max_num_kv_tokens=self.max_num_kv_tokens, + vllm_config=self.vllm_config, + ) + self._bind_routed_experts_capturer(routed_experts_capturer) + self.routed_experts_initialized = True + + def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None: + from vllm.model_executor.layers.fused_moe.layer import FusedMoE + from vllm.model_executor.layers.fused_moe.router.base_router import ( + BaseRouter, + ) + + for module in self.compilation_config.static_forward_context.values(): + if isinstance(module, FusedMoE) and isinstance(module.router, BaseRouter): + layer_id = module.layer_id + + def _capture_fn(topk_ids, _layer_id=layer_id, _capturer=capturer): + _capturer.capture(_layer_id, topk_ids) + + module.router.set_capture_fn(_capture_fn) + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
@@ -5481,7 +6914,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: KVCacheSpec: A dictionary mapping layer names to their KV cache format. Layers that do not need KV cache are not included. """ - if has_ec_transfer() and get_ec_transfer().is_producer: + if has_ec_transfer() and not get_ec_transfer().is_consumer: return {} kv_cache_spec: dict[str, KVCacheSpec] = {} layer_type = cast(type[Any], AttentionLayerBase) @@ -5519,3 +6952,79 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: self.transfer_event.record() self.transfer_event.synchronize() return pinned.tolist() + + def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]: + """ + Get encoder timing stats for all requests and clear the registry. + + Returns: + Dictionary mapping request_id to stats dict. + """ + with self._encoder_timing_lock: + stats = { + req_id: stats_obj.to_dict() + for req_id, stats_obj in self.encoder_timing_registry.items() + } + self.encoder_timing_registry.clear() + return stats + + @contextmanager + def timed_encoder_operation( + self, + should_time: bool, + group_lora_refs: list[tuple[str, Any]], + current_item_idx: int, + num_items: int, + ): + """ + Context manager to time encoder forward operations. 
+ + Args: + should_time: Whether timing is enabled + group_lora_refs: Full list of (request_id, pos_info) tuples + current_item_idx: Starting index for this group + num_items: Number of items in this group + """ + if not should_time: + yield + return + + group_refs = group_lora_refs[current_item_idx : current_item_idx + num_items] + group_request_ids = {req_id for req_id, _ in group_refs} + + torch.accelerator.synchronize() + start_time = time.perf_counter() + + try: + yield + finally: + torch.accelerator.synchronize() + elapsed = time.perf_counter() - start_time + + per_request_time = elapsed / max(len(group_request_ids), 1) + + with self._encoder_timing_lock: + for req_id in group_request_ids: + if req_id not in self.encoder_timing_registry: + self.encoder_timing_registry[req_id] = EncoderTimingStats() + + stats = self.encoder_timing_registry[req_id] + stats.encoder_forward_secs += per_request_time + stats.num_encoder_calls += 1 + + +@dataclass +class EncoderTimingStats: + """Per-request timing statistics for encoder forward pass.""" + + encoder_forward_secs: float = 0.0 + """Time spent in vision encoder forward pass (seconds).""" + + num_encoder_calls: int = 0 + """Number of times encoder was called for this request.""" + + def to_dict(self) -> dict[str, float | int]: + return { + "encoder_forward_secs": self.encoder_forward_secs, + "num_encoder_calls": self.num_encoder_calls, + } diff --git a/vllm_fl/worker/worker.py b/vllm_fl/worker/worker.py index de6dbb0f..4cd7631f 100644 --- a/vllm_fl/worker/worker.py +++ b/vllm_fl/worker/worker.py @@ -1,5 +1,5 @@ # Copyright (c) 2025 BAAI. All rights reserved. 
-# Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/v1/worker/gpu_model_runner.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/v1/worker/gpu_worker.py # Below is the original copyright: # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project @@ -16,7 +16,7 @@ import torch.distributed import torch.nn as nn -from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config from vllm.config.compilation import CompilationMode from vllm.distributed import ( ensure_model_parallel_initialized, @@ -47,7 +47,7 @@ def kernel_warmup(worker): ) from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed +from vllm.utils.torch_utils import set_random_seed from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.platforms import current_platform from vllm.profiler.wrapper import TorchProfilerWrapper @@ -206,11 +206,6 @@ def __init__( distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils.import_utils import init_cached_hf_modules - - init_cached_hf_modules() # Buffers saved before sleep self._sleep_saved_buffers: dict[str, torch.Tensor] = {} @@ -429,10 +424,10 @@ def init_device(self): # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool # to hijack tensor allocation. - def load_model(self) -> None: - eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1" + def load_model(self, *, load_dummy_weights: bool = False) -> None: ### TODO(lms): support manages a memory pool for device tensors. 
- self.model_runner.load_model(eep_scale_up=eep_scale_up) + with set_current_vllm_config(self.vllm_config): + self.model_runner.load_model(load_dummy_weights=load_dummy_weights) # with self._maybe_get_memory_pool_context(tag="weights"): # self.model_runner.load_model(eep_scale_up=eep_scale_up) @@ -546,6 +541,11 @@ def get_kv_connector_handshake_metadata(self) -> dict | None: def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() + def update_max_model_len(self, max_model_len: int) -> None: + """Update max_model_len after auto-fit to GPU memory.""" + self.model_config.max_model_len = max_model_len + self.model_runner.max_model_len = max_model_len + def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" # Init kv cache connector here, because it requires @@ -565,7 +565,7 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: # context = nullcontext() self.model_runner.initialize_kv_cache(kv_cache_config) - def compile_or_warm_up_model(self) -> None: + def compile_or_warm_up_model(self) -> float: warmup_sizes = [] if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE: # warm up sizes that are not in cudagraph capture sizes, @@ -688,15 +688,24 @@ def compile_or_warm_up_model(self) -> None: # the model initialization and profiling. 
set_random_seed(self.model_config.seed) + return self.compilation_config.compilation_time + def reset_mm_cache(self) -> None: self.model_runner.reset_mm_cache() + def reset_encoder_cache(self) -> None: + self.model_runner.reset_encoder_cache() + def get_model(self) -> nn.Module: return self.model_runner.get_model() def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self.model_runner.get_supported_tasks() + def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]: + """Get encoder timing stats from model runner.""" + return self.model_runner.get_encoder_timing_stats() + def annotate_profile(self, scheduler_output): # add trace annotation so that we can easily distinguish # new/cached request numbers in each iteration @@ -789,7 +798,7 @@ def execute_model( def take_draft_token_ids(self) -> Optional[DraftTokenIds]: return self.model_runner.take_draft_token_ids() - def profile(self, is_start: bool = True): + def profile(self, is_start: bool = True, profile_prefix: str | None = None): if self.profiler is None: raise RuntimeError("Profiling is not enabled.") if is_start: From a4f63519da8b9af5c759a7b778eeecac2474278c Mon Sep 17 00:00:00 2001 From: cyberpioneer Date: Tue, 31 Mar 2026 08:07:11 +0000 Subject: [PATCH 2/7] fix manually --- vllm_fl/dispatch/README.md | 4 +- vllm_fl/dispatch/__init__.py | 2 +- vllm_fl/dispatch/backends/__init__.py | 28 +----- vllm_fl/dispatch/backends/vendor/__init__.py | 35 +------- .../impl/attention/ops/merge_attn_states.py | 2 +- vllm_fl/dispatch/builtin_ops.py | 68 ++++++++++++++- vllm_fl/dispatch/config/__init__.py | 2 + vllm_fl/dispatch/config/utils.py | 31 ++++++- vllm_fl/dispatch/manager.py | 13 +-- vllm_fl/dispatch/policy.py | 13 ++- vllm_fl/ops/custom_ops.py | 7 +- vllm_fl/platform.py | 20 ++--- vllm_fl/utils.py | 57 +++++++++++++ vllm_fl/worker/model_runner.py | 85 ++++++++++++++----- 14 files changed, 263 insertions(+), 104 deletions(-) diff --git a/vllm_fl/dispatch/README.md 
b/vllm_fl/dispatch/README.md index 6a382901..51fb5a84 100644 --- a/vllm_fl/dispatch/README.md +++ b/vllm_fl/dispatch/README.md @@ -480,7 +480,7 @@ export VLLM_FL_FLAGOS_BLACKLIST="custom_op1,custom_op2" - **Environment variables override, not merge**: Setting an env var replaces the config value entirely - **`VLLM_FL_PREFER` sets preference, not exclusivity**: It defines the selection order but will fall back to other backends if the preferred one is unavailable - **To force a specific backend**: Combine `PREFER` with `DENY_VENDORS` or use `PER_OP` to exclude unwanted backends -- **`VLLM_FL_STRICT=1`**: Enables automatic fallback when the primary implementation fails at runtime +- **`VLLM_FL_STRICT=1`**: Enables strict mode — fails immediately if the primary implementation fails, no fallback is attempted #### Backend Priority Values @@ -536,7 +536,7 @@ Currently supported operators: ## Fallback Mechanism -When `VLLM_FL_STRICT=1`, if the primary implementation fails, the system automatically tries other available implementations: +When `VLLM_FL_STRICT=0` (default), if the primary implementation fails, the system automatically tries other available implementations: ``` Op 'rms_norm' using 'default.flagos' (kind=flagos, vendor=None) diff --git a/vllm_fl/dispatch/__init__.py b/vllm_fl/dispatch/__init__.py index 1eaea799..a2f10206 100644 --- a/vllm_fl/dispatch/__init__.py +++ b/vllm_fl/dispatch/__init__.py @@ -21,7 +21,7 @@ Environment Variables: VLLM_FL_CONFIG: Path to YAML configuration file (highest priority, overrides env vars) VLLM_FL_PREFER: Preferred backend ("flagos", "vendor", "reference") - VLLM_FL_STRICT: Enable strict mode ("1" or "0") + VLLM_FL_STRICT: Strict mode: "1" = fail immediately on error (no fallback), "0" = try fallback (default) VLLM_FL_DENY_VENDORS: Comma-separated list of denied vendors VLLM_FL_ALLOW_VENDORS: Comma-separated list of allowed vendors VLLM_FL_PER_OP: Per-operator order (format: op1=a|b|c;op2=x|y) diff --git 
a/vllm_fl/dispatch/backends/__init__.py b/vllm_fl/dispatch/backends/__init__.py index ae8c3144..078452c4 100644 --- a/vllm_fl/dispatch/backends/__init__.py +++ b/vllm_fl/dispatch/backends/__init__.py @@ -2,6 +2,10 @@ """ Backend implementations for vllm-plugin-FL dispatch. + +Vendor backends are dynamically discovered and loaded by builtin_ops.py +based on the current platform. This package does not eagerly import vendor +backends to avoid loading unnecessary dependencies at startup. """ from .base import Backend @@ -9,27 +13,3 @@ from .reference import ReferenceBackend __all__ = ["Backend", "FlagGemsBackend", "ReferenceBackend"] - -# Try to import vendor backends -try: - from .vendor.ascend import AscendBackend - - __all__.append("AscendBackend") -except ImportError: - AscendBackend = None - -# Add more vendor backends here as they become available -try: - from .vendor.cuda import CudaBackend - - __all__.append("CudaBackend") -except ImportError: - CudaBackend = None - -# Import MACA backend -try: - from .vendor.maca import MacaBackend - - __all__.append("MacaBackend") -except ImportError: - pass diff --git a/vllm_fl/dispatch/backends/vendor/__init__.py b/vllm_fl/dispatch/backends/vendor/__init__.py index bdeacfbe..6818887b 100644 --- a/vllm_fl/dispatch/backends/vendor/__init__.py +++ b/vllm_fl/dispatch/backends/vendor/__init__.py @@ -8,6 +8,10 @@ Available vendor backends: - ascend: Huawei Ascend NPU backend +This package intentionally avoids eager imports of vendor subpackages. +Importing a specific backend such as ``vllm_fl.dispatch.backends.vendor.ascend`` +should not pull in other vendor branches. + To add a new vendor backend: 1. Create a subdirectory: vendor// 2. 
Implement the backend class inheriting from Backend @@ -18,34 +22,3 @@ """ __all__ = [] - -# Import Ascend backend -try: - from .ascend import AscendBackend - - __all__.append("AscendBackend") -except ImportError: - pass - -# Import CUDA backend -try: - from .cuda import CudaBackend - - __all__.append("CudaBackend") -except ImportError: - pass - -# Import MACA backend -try: - from .maca import MacaBackend - - __all__.append("MacaBackend") -except ImportError: - pass - -# Add more vendor backends here as they become available: -# try: -# from .rocm import RocmBackend -# __all__.append("RocmBackend") -# except ImportError: -# pass diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py index 4395bb25..c2c51159 100644 --- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py +++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py @@ -43,7 +43,7 @@ def supported_headdim(o: torch.Tensor) -> bool: output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse ) else: - from vllm.v1.attention.ops.merge_attn_states import merge_attn_states + from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states return merge_attn_states( output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse diff --git a/vllm_fl/dispatch/builtin_ops.py b/vllm_fl/dispatch/builtin_ops.py index 426c6019..2d0883d9 100644 --- a/vllm_fl/dispatch/builtin_ops.py +++ b/vllm_fl/dispatch/builtin_ops.py @@ -12,6 +12,7 @@ import importlib import os +from .config import get_vendor_device_map from .registry import OpRegistry from .logger_manager import get_logger @@ -20,6 +21,44 @@ # Directory containing vendor backends _VENDOR_BACKENDS_DIR = os.path.join(os.path.dirname(__file__), "backends", "vendor") +def _find_vendor_backend_dir( + vendor_name: str, + available_vendor_dirs: set[str], +) -> str | 
None: + """Return the backend directory for *vendor_name*, or None if not found. + + Resolves *vendor_name* against get_vendor_device_map() and picks the first + candidate directory that exists in *available_vendor_dirs*. The "maca" alias + is treated as "metax" for MetaX runtime compatibility. + """ + # Keep compatibility with MetaX runtime naming. + if vendor_name == "maca": + vendor_name = "metax" + vendor_map = get_vendor_device_map() + if vendor_name not in vendor_map: + return None + value = vendor_map[vendor_name] + device_type = value.get("device_type") + device_name = value.get("device_name") + return next( + (c for c in (vendor_name, device_name, device_type) if c in available_vendor_dirs), + None, + ) + + +def _get_current_vendor_backend_dirs(available_vendor_dirs: set[str]) -> set[str]: + """Detect current platform vendor name and return its backend directory.""" + try: + from vllm.platforms import current_platform + + vendor_name = getattr(current_platform, "vendor_name", None) + if not isinstance(vendor_name, str) or not vendor_name: + return None + return _find_vendor_backend_dir(vendor_name, available_vendor_dirs) + except Exception as exc: + raise RuntimeError( + "Failed to detect current vendor backend from current_platform." 
+ ) from exc def _register_vendor_backends(registry: OpRegistry) -> None: """ @@ -38,8 +77,33 @@ def _register_vendor_backends(registry: OpRegistry) -> None: for vendor_name in os.listdir(_VENDOR_BACKENDS_DIR): vendor_path = os.path.join(_VENDOR_BACKENDS_DIR, vendor_name) - # Skip non-directories and special files - if not os.path.isdir(vendor_path) or vendor_name.startswith("_"): + available_vendor_dirs = { + vendor_name + for vendor_name in os.listdir(_VENDOR_BACKENDS_DIR) + if os.path.isdir(os.path.join(_VENDOR_BACKENDS_DIR, vendor_name)) + and not vendor_name.startswith("_") + } + + current_vendor_dir = _get_current_vendor_backend_dirs(available_vendor_dirs) + if not current_vendor_dir: + logger.warning( + "Unable to detect current vendor backend; skipping vendor backend registration" + ) + return + + logger.info( + "Registering vendor backends for current platform: %s", + current_vendor_dir, + ) + + for vendor_name in available_vendor_dirs: + vendor_path = os.path.join(_VENDOR_BACKENDS_DIR, vendor_name) + + if vendor_name != current_vendor_dir: + logger.debug( + "Skipping vendor backend '%s' for current platform", + vendor_name, + ) continue # Skip if no register_ops.py exists diff --git a/vllm_fl/dispatch/config/__init__.py b/vllm_fl/dispatch/config/__init__.py index c60b1ce8..cc4e8b49 100644 --- a/vllm_fl/dispatch/config/__init__.py +++ b/vllm_fl/dispatch/config/__init__.py @@ -15,6 +15,7 @@ get_per_op_order, get_platform_name, load_platform_config, + get_vendor_device_map, ) __all__ = [ @@ -25,4 +26,5 @@ 'get_flagos_blacklist', 'get_oot_blacklist', 'get_effective_config', + 'get_vendor_device_map', ] diff --git a/vllm_fl/dispatch/config/utils.py b/vllm_fl/dispatch/config/utils.py index 411cf943..3877923c 100644 --- a/vllm_fl/dispatch/config/utils.py +++ b/vllm_fl/dispatch/config/utils.py @@ -11,7 +11,7 @@ 1. VLLM_FL_CONFIG: User-specified config file path (complete override) 2. 
Environment variables: Override specific items from platform config - VLLM_FL_PREFER: Backend preference (flagos, vendor, reference) - - VLLM_FL_STRICT: Strict mode (1 or 0) + - VLLM_FL_STRICT: Strict mode: 1 = fail immediately on error (no fallback), 0 = try fallback (default) - VLLM_FL_PER_OP: Per-operator backend order - VLLM_FL_FLAGOS_BLACKLIST: FlagOS operator blacklist - VLLM_FL_OOT_BLACKLIST: OOT operator blacklist @@ -35,6 +35,7 @@ from typing import Any, Optional import yaml +from vllm_fl.utils import VENDOR_DEVICE_MAP # Directory containing config files (config/) _CONFIG_DIR = Path(__file__).parent @@ -207,3 +208,31 @@ def get_effective_config() -> dict[str, Any]: # Return empty config return {} +def get_vendor_device_map() -> dict[str, dict[str, str]]: + """Load vendor mapping from Python config module. + + Returns: + Mapping where key is vendor_name and value is + {"device_type": ..., "device_name": ...}. + """ + if not isinstance(VENDOR_DEVICE_MAP, dict): + return {} + + result: dict[str, dict[str, str]] = {} + for vendor_name, value in VENDOR_DEVICE_MAP.items(): + if not isinstance(vendor_name, str) or not isinstance(value, dict): + continue + + device_type = value.get("device_type") + device_name = value.get("device_name") + if not isinstance(device_type, str) or not device_type: + continue + if not isinstance(device_name, str) or not device_name: + continue + + result[vendor_name] = { + "device_type": device_type, + "device_name": device_name, + } + + return result diff --git a/vllm_fl/dispatch/manager.py b/vllm_fl/dispatch/manager.py index cee87132..c886940e 100644 --- a/vllm_fl/dispatch/manager.py +++ b/vllm_fl/dispatch/manager.py @@ -483,8 +483,11 @@ def call(self, op_name: str, *args, **kwargs): """ Resolve and call an operator implementation with optional fallback support. - When VLLM_FL_STRICT=1, this method will try alternative implementations - if the primary one fails. Otherwise, it behaves like the original implementation. 
+ Behavior is controlled by the active policy's strict flag (VLLM_FL_STRICT): + - VLLM_FL_STRICT=0 (default): fallback mode — if the primary implementation + fails, the system automatically tries the next available implementation. + - VLLM_FL_STRICT=1: strict mode — fail immediately on the first error, + no fallback is attempted. Logs on first call or when the implementation changes (e.g., backend switch). @@ -496,10 +499,10 @@ def call(self, op_name: str, *args, **kwargs): Result from the implementation Raises: - RuntimeError: If all implementations fail (when fallback enabled) or - if the primary implementation fails (when fallback disabled) + RuntimeError: If all implementations fail (fallback mode) or + if the primary implementation fails (strict mode) """ - enable_fallback = os.getenv("VLLM_FL_STRICT", "1") != "0" + enable_fallback = not get_policy().strict if not enable_fallback: # Original behavior: use cached resolve() and fast-fail diff --git a/vllm_fl/dispatch/policy.py b/vllm_fl/dispatch/policy.py index f1523e80..e7f06139 100644 --- a/vllm_fl/dispatch/policy.py +++ b/vllm_fl/dispatch/policy.py @@ -386,7 +386,7 @@ def _policy_from_env(self) -> SelectionPolicy: Environment variables: - VLLM_FL_CONFIG: Path to YAML configuration file (complete override) - VLLM_FL_PREFER: Preference (flagos, vendor, reference) - - VLLM_FL_STRICT: Enable strict mode (1 or 0) + - VLLM_FL_STRICT: Strict mode: 1 = fail immediately on error (no fallback), 0 = try fallback (default) - VLLM_FL_DENY_VENDORS: Comma-separated list of denied vendors - VLLM_FL_ALLOW_VENDORS: Comma-separated list of allowed vendors - VLLM_FL_PER_OP: Per-op order (format: op1=a|b|c;op2=x|y) @@ -409,7 +409,7 @@ def _policy_from_env(self) -> SelectionPolicy: # Priority 2: Environment variables override platform config # Get values from environment variables env_prefer_str = os.environ.get("VLLM_FL_PREFER", "").strip().lower() - env_strict_str = os.environ.get("VLLM_FL_STRICT", "").strip() + env_strict_str = 
os.environ.get("VLLM_FL_STRICT", "0").strip() env_deny_str = os.environ.get("VLLM_FL_DENY_VENDORS", "").strip() env_allow_str = os.environ.get("VLLM_FL_ALLOW_VENDORS", "").strip() env_per_op_str = os.environ.get("VLLM_FL_PER_OP", "").strip() @@ -423,7 +423,14 @@ def _policy_from_env(self) -> SelectionPolicy: prefer_str = PREFER_DEFAULT if env_strict_str: - strict = env_strict_str == "1" + if env_strict_str not in ("0", "1"): + logger.warning( + f"Invalid VLLM_FL_STRICT value '{env_strict_str}', " + f"expected '0' or '1'. Defaulting to '0' (fallback mode)." + ) + strict = False + else: + strict = env_strict_str == "1" elif platform_policy: strict = platform_policy.strict else: diff --git a/vllm_fl/ops/custom_ops.py b/vllm_fl/ops/custom_ops.py index 0485156a..6c971cb8 100644 --- a/vllm_fl/ops/custom_ops.py +++ b/vllm_fl/ops/custom_ops.py @@ -68,6 +68,11 @@ def register_oot_ops(whitelist: Optional[List[str]] = None) -> None: logger.warning(f"OOT op '{op_name}' not found in OOT_OPS, skipping.") continue + # unquantized_fused_moe_method only registers when use_flaggems_op is True + if op_name == "unquantized_fused_moe_method" and not use_flaggems_op(op_name): + logger.debug(f"Skipping '{op_name}': use_flaggems_op returned False") + continue + op_cls, registration_name = OOT_OPS[op_name] logger.info(f"Registering oot op: {op_name} as '{registration_name}'") CustomOp.register_oot(_decorated_op_cls=op_cls, name=registration_name) @@ -77,4 +82,4 @@ def register_oot_ops(whitelist: Optional[List[str]] = None) -> None: from vllm.platforms import current_platform if current_platform.device_type == "npu": from vllm_fl.dispatch.backends.vendor.ascend.patch import apply_ascend_patches - apply_ascend_patches() \ No newline at end of file + apply_ascend_patches() diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index 45882cc8..c53e8332 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -29,7 +29,7 @@ VllmConfig = None CacheDType = None -from vllm_fl.utils import 
DeviceInfo +from vllm_fl.utils import DeviceInfo, get_device_name, get_device_type logger = init_logger(__name__) @@ -46,6 +46,8 @@ class PlatformFL(Platform): _enum = PlatformEnum.OOT device_info = DeviceInfo() vendor_name = device_info.vendor_name + device_type = get_device_type(vendor_name) + device_name = get_device_name(vendor_name) # cuda_alike (nvidia/metax): device_name = vendor_name (not used in torch.device) # non-cuda_alike (iluvatar/ascend): device_name = device_type (used in torch.device) device_name = device_info.vendor_name if ( @@ -336,18 +338,16 @@ def pre_register_and_update(cls, parser=None) -> None: if cls.device_name == "npu": import vllm_fl.dispatch.backends.vendor.ascend - @classmethod def supports_fp8(cls) -> bool: - return cls.has_device_capability(89) + if cls.vendor_name == "nvidia": + return True + return False @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: - if cls.device_type == "npu": - return cls.torch_device_fn.get_device_properties( - device_id - ).total_memory - device_props = torch.cuda.get_device_properties(device_id) - return device_props.total_memory + return cls.torch_device_fn.get_device_properties( + device_id + ).total_memory @classmethod def use_custom_op_collectives(cls) -> bool: @@ -357,7 +357,7 @@ def use_custom_op_collectives(cls) -> bool: @classmethod def num_compute_units(cls, device_id: int = 0) -> int: - return torch.cuda.get_device_properties(device_id).multi_processor_count + return cls.torch_device_fn.get_device_properties(device_id).multi_processor_count @classmethod diff --git a/vllm_fl/utils.py b/vllm_fl/utils.py index 2b958cad..afbc553c 100644 --- a/vllm_fl/utils.py +++ b/vllm_fl/utils.py @@ -10,6 +10,63 @@ _OP_CONFIG: Optional[dict[str, str]] = None +# Mapping used by dispatch registration to resolve the current runtime platform +# into a backend directory under dispatch/backends/vendor. 
+# +# Field definitions and sources: +# - top-level key (vendor_name): normalized hardware vendor identifier. +# Source: runtime platform detection (current_platform.vendor_name) and +# fallback device probing (DeviceInfo.vendor_name). +# - device_type: compute class reported by runtime, such as "cuda" or "npu". +# Source: runtime platform detection (current_platform.device_type) and +# fallback device probing (DeviceInfo.device_type). +# - device_name: runtime device family/product alias used by vLLM platform. +# Source: runtime platform detection (current_platform.device_name). +# +# Values are normalized to lowercase and matched against available backend +# subdirectories (for example, cuda/ascend/metax/iluvatar). +VENDOR_DEVICE_MAP: dict[str, dict[str, str]] = { + # Registered backend: vendor/cuda + "nvidia": {"device_type": "cuda", "device_name": "nvidia"}, + # Registered backend: vendor/ascend + "ascend": {"device_type": "npu", "device_name": "npu"}, + # Registered backend: vendor/iluvatar + "iluvatar": {"device_type": "cuda", "device_name": "cuda"}, + # Registered backend: vendor/metax + "metax": {"device_type": "cuda", "device_name": "metax"}, +} + + +def _get_vendor_device_field(vendor_name: str, field: str) -> str: + """Get a required field from VENDOR_DEVICE_MAP for the specified vendor.""" + if not isinstance(vendor_name, str) or not vendor_name.strip(): + raise ValueError("vendor_name must be a non-empty string.") + + normalized_vendor = vendor_name + device_info = VENDOR_DEVICE_MAP.get(normalized_vendor) + if not isinstance(device_info, dict): + raise ValueError( + f"Vendor '{normalized_vendor}' not found in VENDOR_DEVICE_MAP." + ) + + value = device_info.get(field) + if not isinstance(value, str) or not value.strip(): + raise ValueError( + f"Field '{field}' for vendor '{normalized_vendor}' is missing " + "or empty in VENDOR_DEVICE_MAP." 
+ ) + return value + + +def get_device_type(vendor_name: str) -> str: + """Return the configured device_type for the given vendor.""" + return _get_vendor_device_field(vendor_name, "device_type") + + +def get_device_name(vendor_name: str) -> str: + """Return the configured device_name for the given vendor.""" + return _get_vendor_device_field(vendor_name, "device_name") + def use_flaggems(default: bool = True) -> bool: if os.environ.get("VLLM_FL_PREFER_ENABLED", "True").lower() not in ("true", "1"): diff --git a/vllm_fl/worker/model_runner.py b/vllm_fl/worker/model_runner.py index bb32fb24..3e348b24 100644 --- a/vllm_fl/worker/model_runner.py +++ b/vllm_fl/worker/model_runner.py @@ -110,6 +110,40 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.nvtx_pytorch_hooks import PytHooks from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units +from vllm.platforms import current_platform +if current_platform.dist_backend == "flagcx": + @contextmanager + def graph_capture(device: torch.device): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the NPU graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current NPU stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. 
+ """ + graph_capture_context = GraphCaptureContext( + current_platform.torch_device_fn.Stream(device=device)) + stream = graph_capture_context.stream + + # we use nullcontext now + maybe_ca_context = nullcontext() + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = current_platform.torch_device_fn.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with current_platform.torch_device_fn.stream(stream), maybe_ca_context: + yield graph_capture_context +else: + from vllm.distributed.parallel_state import graph_capture from vllm.utils.torch_utils import ( get_dtype_size, kv_cache_dtype_str_to_dtype, @@ -223,7 +257,7 @@ init_io_dump_from_env, register_io_module_hooks, ) -CUDAGraphWrapper = GraphWrapper +GraphWrapper = GraphWrapper AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled @@ -822,9 +856,11 @@ def __init__( self._draft_token_req_ids: list[str] | None = None self.transfer_event = torch.Event() + # TODO(yxa): NPU uses int32, CUDA uses int64 for sampled token ids + sampled_ids_dtype = torch.int32 if current_platform.device_type == "npu" else torch.int64 self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_num_reqs, 1), - dtype=torch.int64, + dtype=sampled_ids_dtype, device="cpu", pin_memory=self.pin_memory, ) @@ -1062,7 +1098,7 @@ def _init_device_properties(self) -> None: # Note: used for model runner override. 
def _sync_device(self) -> None: - torch.accelerator.synchronize() + current_platform.torch_device_fn.synchronize() def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None: """Update the cached states and the persistent batch with the scheduler @@ -3008,7 +3044,7 @@ def _gather_mm_embeddings( def get_model(self) -> nn.Module: if not hasattr(self, "model"): raise ValueError("Cannot get model before model has been initialized") - if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)): + if isinstance(self.model, (GraphWrapper, UBatchWrapper)): # get raw model out of the cudagraph wrapper. return self.model.unwrap() return self.model @@ -4420,9 +4456,9 @@ def _copy_draft_token_ids_to_cpu( assert self.draft_token_ids_event is not None assert self.draft_token_ids_copy_stream is not None assert self.draft_token_ids_cpu is not None - default_stream = torch.cuda.current_stream() + default_stream = current_platform.torch_device_fn.current_stream() num_reqs = draft_token_ids.shape[0] - with torch.cuda.stream(self.draft_token_ids_copy_stream): + with current_platform.torch_device_fn.stream(self.draft_token_ids_copy_stream): if not zeros_only: # Trigger async copy of draft token ids to cpu. self.draft_token_ids_copy_stream.wait_stream(default_stream) @@ -4451,10 +4487,10 @@ def _copy_valid_sampled_token_count( if self.valid_sampled_token_count_event is None: return - default_stream = torch.cuda.current_stream() + default_stream = current_platform.torch_device_fn.current_stream() # Initialize a new stream to overlap the copy operation with # prepare_input of draft model. 
- with torch.cuda.stream(self.valid_sampled_token_count_copy_stream): + with current_platform.torch_device_fn.stream(self.valid_sampled_token_count_copy_stream): self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore counts = valid_sampled_tokens_count counts_cpu = self.valid_sampled_token_count_cpu @@ -4813,16 +4849,19 @@ def load_model(self, load_dummy_weights: bool = False) -> None: self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - except torch.cuda.OutOfMemoryError as e: - msg = ( - "Failed to load model - not enough GPU memory. " - "Try lowering --gpu-memory-utilization to free memory for weights, " - "increasing --tensor-parallel-size, or using --quantization. " - "See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ " - "for more tips." - ) - combined_msg = f"{msg} (original error: {e})" - logger.error(combined_msg) + except Exception as e: + is_oom = 'out of memory' in str(e).lower() + + if is_oom: + msg = ( + "Failed to load model - not enough device memory. " + "Try lowering --gpu-memory-utilization to free memory for weights, " + "increasing --tensor-parallel-size, or using --quantization. " + "See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ " + "for more tips." + ) + combined_msg = f"{msg} (original error: {e})" + logger.error(combined_msg) raise e logger.info_once( "Model loading took %s GiB memory and %.6f seconds", @@ -4880,7 +4919,7 @@ def load_model(self, load_dummy_weights: bool = False) -> None: cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.use_ubatching ): - self.model = CUDAGraphWrapper( + self.model = GraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) elif self.parallel_config.use_ubatching: @@ -5897,7 +5936,7 @@ def profile_cudagraph_memory(self) -> int: # Use a temporary pool for profiling to avoid fragmentation in the main pool. 
profiling_pool = current_platform.graph_pool_handle() original_pools: dict[int, Any] = {} - for instance in list(CUDAGraphWrapper._all_instances): + for instance in list(GraphWrapper._all_instances): original_pools[id(instance)] = instance.graph_pool instance.graph_pool = profiling_pool @@ -5947,8 +5986,8 @@ def profile_cudagraph_memory(self) -> int: ) set_cudagraph_capturing_enabled(False) - CUDAGraphWrapper.clear_all_graphs() - for instance in list(CUDAGraphWrapper._all_instances): + GraphWrapper.clear_all_graphs() + for instance in list(GraphWrapper._all_instances): if id(instance) in original_pools: instance.graph_pool = original_pools[id(instance)] for key_set in self.cudagraph_dispatcher.cudagraph_keys.values(): @@ -5981,7 +6020,7 @@ def capture_model(self) -> int: return 0 # Initialize encoder CUDA graph manager if enabled. - # Use get_model() to unwrap CUDAGraphWrapper/UBatchWrapper, + # Use get_model() to unwrap GraphWrapper/UBatchWrapper, # because @runtime_checkable Protocol isinstance() checks do not # work through __getattr__ forwarding. 
if ( From c1fc8e7d01972ed5eb65f1bf5f4a0bdd43e2750f Mon Sep 17 00:00:00 2001 From: cyberpioneer Date: Thu, 2 Apr 2026 06:57:53 +0000 Subject: [PATCH 3/7] fix unittest --- tests/unit_tests/ops/test_layernorm.py | 4 + tests/unit_tests/worker/test_model_runner.py | 3 + vllm_fl/__init__.py | 4 +- vllm_fl/patches/glm_moe_dsa.py | 29 - vllm_fl/worker/model_runner.py | 624 +++++-------------- 5 files changed, 177 insertions(+), 487 deletions(-) diff --git a/tests/unit_tests/ops/test_layernorm.py b/tests/unit_tests/ops/test_layernorm.py index 4119a3d0..3dcbe1ec 100644 --- a/tests/unit_tests/ops/test_layernorm.py +++ b/tests/unit_tests/ops/test_layernorm.py @@ -13,6 +13,10 @@ class TestRMSNormFL: """Test RMSNormFL class behavior.""" + def __init__(self): + from vllm.config import VllmConfig, set_current_vllm_config + set_current_vllm_config(VllmConfig()) + @pytest.fixture def mock_call_op(self): with patch("vllm_fl.ops.layernorm.call_op") as mock: diff --git a/tests/unit_tests/worker/test_model_runner.py b/tests/unit_tests/worker/test_model_runner.py index 2e844470..04a15981 100644 --- a/tests/unit_tests/worker/test_model_runner.py +++ b/tests/unit_tests/worker/test_model_runner.py @@ -60,6 +60,7 @@ def test_fields_match_expected_contract(self): "aux_hidden_states", "ec_connector_output", "cudagraph_stats", + 'slot_mappings', ) assert ExecuteModelState._fields == expected_fields, ( "ExecuteModelState fields changed - this may break execute_model consumers" @@ -79,6 +80,7 @@ def test_immutability_prevents_accidental_mutation(self): aux_hidden_states=None, ec_connector_output=None, cudagraph_stats=None, + slot_mappings=None, ) with pytest.raises(AttributeError): @@ -101,6 +103,7 @@ def test_unpacking_for_downstream_processing(self): aux_hidden_states=None, ec_connector_output=None, cudagraph_stats=None, + slot_mappings=None, ) # Simulate downstream unpacking diff --git a/vllm_fl/__init__.py b/vllm_fl/__init__.py index 56e3ece4..117a6502 100644 --- a/vllm_fl/__init__.py +++ 
b/vllm_fl/__init__.py @@ -56,7 +56,7 @@ def register_model(): from vllm_fl.configs.glm_moe_dsa import GlmMoeDsaConfig _CONFIG_REGISTRY["glm_moe_dsa"] = GlmMoeDsaConfig - from vllm_fl.patches.glm_moe_dsa import apply_model_patches as glm5_model - glm5_model() + #from vllm_fl.patches.glm_moe_dsa import apply_model_patches as glm5_model + #glm5_model() except Exception as e: logger.error(f"Register GlmMoeDsa model error: {str(e)}") diff --git a/vllm_fl/patches/glm_moe_dsa.py b/vllm_fl/patches/glm_moe_dsa.py index df8ec255..f4a7dd67 100644 --- a/vllm_fl/patches/glm_moe_dsa.py +++ b/vllm_fl/patches/glm_moe_dsa.py @@ -52,33 +52,6 @@ def _patched_set(self, special_tokens=None): pass -def patch_fp8_mqa_logits_dim(): - """Fix k_scale dim mismatch for deep_gemm fp8_mqa_logits. - - vLLM 0.13.0 passes k_scale as [N, 1] but deep_gemm 2.3.0 expects [N]. - Upstream fix: https://github.com/vllm-project/vllm/pull/32652 - We wrap fp8_mqa_logits to flatten k_scale before calling the native impl. - """ - import vllm.utils.deep_gemm as dg_mod - - dg_mod._lazy_init() - _orig_impl = dg_mod._fp8_mqa_logits_impl - if _orig_impl is None: - return - - def _fixed_fp8_mqa_logits(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke): - k_fp8, k_scale = kv - return _orig_impl( - q, (k_fp8, k_scale.flatten()), weights, - cu_seqlen_ks, cu_seqlen_ke, - ) - - dg_mod._fp8_mqa_logits_impl = _fixed_fp8_mqa_logits - dg_mod._lazy_init = lambda: None - logger.info("Patched fp8_mqa_logits: flatten k_scale [N,1] -> [N] " - "for deep_gemm 2.3.0 compat") - - def patch_indexer_schedule_metadata(): """Fix schedule_metadata not computed when VLLM_USE_DEEP_GEMM=0. @@ -125,8 +98,6 @@ def _patched_build(self, common_prefix_len, common_attn_metadata, def apply_platform_patches(): """All GLM-5 patches needed at platform registration time.""" patch_tokenizer_compat() - patch_fp8_mqa_logits_dim() - def patch_indexer_rope_reshape(): """Fix RoPE output shape in Indexer.forward for DSA models. 
diff --git a/vllm_fl/worker/model_runner.py b/vllm_fl/worker/model_runner.py index 3e348b24..e5f552b7 100644 --- a/vllm_fl/worker/model_runner.py +++ b/vllm_fl/worker/model_runner.py @@ -209,7 +209,6 @@ def graph_capture(device: torch.device): update_scheduler_for_invalid_drafts, ) from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer -from vllm.v1.spec_decode.utils import update_num_computed_tokens_for_batch_change from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker import mamba_utils @@ -245,7 +244,6 @@ def graph_capture(device: torch.device): if TYPE_CHECKING: from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.spec_decode.ngram_proposer import NgramProposer - from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager logger = init_logger(__name__) @@ -272,7 +270,7 @@ def __init__( sampled_token_ids: torch.Tensor, logprobs_tensors: LogprobsTensors | None, invalid_req_indices: list[int], - async_output_copy_stream: torch.cuda.Stream, + async_output_copy_stream: current_platform.torch_device_fn.Stream, vocab_size: int, ): self._model_runner_output = model_runner_output @@ -288,8 +286,8 @@ def __init__( self._logprobs_tensors = logprobs_tensors # Initiate the copy on a separate stream, but do not synchronize it. 
- default_stream = torch.cuda.current_stream() - with torch.cuda.stream(async_output_copy_stream): + default_stream = current_platform.torch_device_fn.current_stream() + with current_platform.torch_device_fn.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) self.sampled_token_ids_cpu = self._sampled_token_ids.to( "cpu", non_blocking=True @@ -384,7 +382,7 @@ def __init__( model_runner_output: ModelRunnerOutput, raw_pooler_output: PoolerOutput, finished_mask: list[bool], - async_output_copy_stream: torch.cuda.Stream, + async_output_copy_stream: current_platform.torch_device_fn.Stream, ): self._model_runner_output = model_runner_output @@ -396,8 +394,8 @@ def __init__( self._raw_pooler_output = raw_pooler_output # Initiate the copy on a separate stream, but do not synchronize it. - default_stream = torch.cuda.current_stream() - with torch.cuda.stream(async_output_copy_stream): + default_stream = current_platform.torch_device_fn.current_stream() + with current_platform.torch_device_fn.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) self._model_runner_output.pooler_output = _copy_pooler_output_to_cpu( raw_pooler_output=self._raw_pooler_output, @@ -469,9 +467,8 @@ def __init__( self.is_multimodal_raw_input_only_model = ( model_config.is_multimodal_raw_input_only_model ) - # These will be overridden in load_model() + # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False - self.requires_sequential_video_encoding = False # Set to True after init_routed_experts_capturer() completes. # Prevents routed experts code from running during profiling/dummy run. 
self.routed_experts_initialized = False @@ -548,9 +545,6 @@ def __init__( self.encoder_cache: dict[str, torch.Tensor] = {} self.late_interaction_runner = LateInteractionRunner() - # Encoder CUDA graph manager (initialized after model load if enabled) - self.encoder_cudagraph_manager: EncoderCudaGraphManager | None = None - self.use_aux_hidden_state_outputs = False # Set up speculative decoding. # NOTE(Jiayi): currently we put the entire draft model on @@ -618,7 +612,6 @@ def __init__( self.rejection_sampler = RejectionSampler(self.sampler) self.num_spec_tokens = 0 - self.valid_sampled_token_count_gpu: torch.Tensor | None = None if self.speculative_config: self.num_spec_tokens = self.speculative_config.num_speculative_tokens draft_config = self.speculative_config.draft_model_config @@ -626,16 +619,13 @@ def __init__( self.effective_drafter_max_model_len = draft_config.max_model_len else: self.effective_drafter_max_model_len = self.max_model_len - self.use_async_spec_decode = ( - self.use_async_scheduling and self.num_spec_tokens > 0 - ) # Request states. self.requests: dict[str, CachedRequestState] = {} # NOTE(rob): num_prompt_logprobs only includes reqs # that are currently in the prefill phase. self.num_prompt_logprobs: dict[str, int] = {} - self.comm_stream = torch.cuda.Stream() + self.comm_stream = current_platform.torch_device_fn.Stream() # Input Batch # NOTE(Chen): Ideally, we should initialize the input batch inside @@ -676,22 +666,19 @@ def __init__( ), # We currently don't know whether a particular custom logits processor # uses output token ids so we set this conservatively. - # ThinkingTokenBudgetLogitsProcessor also needs output token ids to - # correctly track think start/end token sequences in async scheduling. 
- logitsprocs_need_output_token_ids=bool(custom_logitsprocs) - or self.vllm_config.reasoning_config is not None, + logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, ) # Separate cuda stream for overlapping transfer of sampled token ids from # GPU to CPU when async scheduling is enabled. - self.async_output_copy_stream: torch.cuda.Stream | None = None + self.async_output_copy_stream: current_platform.torch_device_fn.Stream | None = None # cuda event to synchronize use of reused CPU tensors between steps # when async scheduling is enabled. self.prepare_inputs_event: torch.Event | None = None if self.use_async_scheduling: - self.async_output_copy_stream = torch.cuda.Stream() + self.async_output_copy_stream = current_platform.torch_device_fn.Stream() self.prepare_inputs_event = torch.Event() # self.cudagraph_batch_sizes sorts in ascending order. @@ -714,31 +701,11 @@ def __init__( # Persistent buffers for CUDA graphs. 
self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) - self.positions = torch.zeros( - self.max_num_tokens, dtype=torch.int64, device=self.device - ) + self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) self.query_start_loc = self._make_buffer( self.max_num_reqs + 1, dtype=torch.int32 ) - self.seq_lens = torch.zeros( - self.max_num_reqs, dtype=torch.int32, device=self.device - ) - self.optimistic_seq_lens_cpu = torch.zeros( - self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory - ) - self.num_computed_tokens = torch.zeros( - self.max_num_reqs, dtype=torch.int32, device=self.device - ) - self.prev_num_draft_tokens = self._make_buffer( - self.max_num_reqs, dtype=torch.int32 - ) - self.req_indices = self._make_buffer(self.max_num_tokens, dtype=torch.int64) - # Maps current batch position -> previous batch position (-1 for new reqs) - self.prev_positions = self._make_buffer(self.max_num_reqs, dtype=torch.int64) - self.num_scheduled_tokens = self._make_buffer( - self.max_num_reqs, dtype=torch.int32 - ) - + self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) if self.dcp_world_size > 1: self.dcp_local_seq_lens = self._make_buffer( @@ -758,7 +725,7 @@ def __init__( self.max_num_reqs, dtype=torch.int32 ) self.num_accepted_tokens = self._make_buffer( - self.max_num_reqs, dtype=torch.int32 + self.max_num_reqs, dtype=torch.int64 ) # Only relevant for multimodal models @@ -797,14 +764,12 @@ def __init__( # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: IntermediateTensors | None = None - # OPTIMIZATION: Cache the arange tensors rather than creating them - # every step. Keep in int64 to avoid overflow with long context. - # - arange_np: immutable [0, 1, 2, ...] 
used as source for batched computation - # - query_pos: CpuGpuBuffer for the computed batched arange result - arange_size = max(self.max_num_reqs + 1, self.max_num_tokens) - self.arange_np = np.arange(arange_size, dtype=np.int64) - self.query_pos = self._make_buffer(arange_size, dtype=torch.int64) - self._arange_scratch = np.empty(arange_size, dtype=np.int64) + # OPTIMIZATION: Cache the tensors rather than creating them every step. + # Keep in int64 to avoid overflow with long context + self.arange_np = np.arange( + max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), + dtype=np.int64, + ) # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -842,8 +807,8 @@ def __init__( # N-gram GPU path: async D2H buffer/event for per-request valid draft counts. self._num_valid_draft_tokens: torch.Tensor | None = None self._num_valid_draft_tokens_cpu: torch.Tensor | None = None - self._num_valid_draft_tokens_event: torch.cuda.Event | None = None - self._num_valid_draft_tokens_copy_stream: torch.cuda.Stream | None = None + self._num_valid_draft_tokens_event: current_platform.torch_device_fn.Event | None = None + self._num_valid_draft_tokens_copy_stream: current_platform.torch_device_fn.Stream | None = None if ( self.speculative_config is not None and self.speculative_config.use_ngram_gpu() @@ -851,8 +816,8 @@ def __init__( self._num_valid_draft_tokens_cpu = torch.empty( self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory ) - self._num_valid_draft_tokens_event = torch.cuda.Event() - self._num_valid_draft_tokens_copy_stream = torch.cuda.Stream() + self._num_valid_draft_tokens_event = current_platform.torch_device_fn.Event() + self._num_valid_draft_tokens_copy_stream = current_platform.torch_device_fn.Stream() self._draft_token_req_ids: list[str] | None = None self.transfer_event = torch.Event() @@ -868,18 +833,18 @@ def __init__( # Pre-allocated tensor for copying valid sampled token counts 
to CPU, # with dedicated stream for overlapping and event for coordination. self.valid_sampled_token_count_event: torch.Event | None = None - self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None + self.valid_sampled_token_count_copy_stream: current_platform.torch_device_fn.Stream | None = None # We also copy the drafted tokens to the CPU asynchronously, # in case we need them for structured outputs. self.draft_token_ids_event: torch.Event | None = None - self.draft_token_ids_copy_stream: torch.cuda.Stream | None = None + self.draft_token_ids_copy_stream: current_platform.torch_device_fn.Stream | None = None self.valid_sampled_token_count_cpu: torch.Tensor | None = None self.draft_token_ids_cpu: torch.Tensor | None = None self.num_accepted_tokens_event: torch.Event | None = None if self.num_spec_tokens: self.draft_token_ids_event = torch.Event() self.num_accepted_tokens_event = torch.Event() - self.draft_token_ids_copy_stream = torch.cuda.Stream() + self.draft_token_ids_copy_stream = current_platform.torch_device_fn.Stream() self.draft_token_ids_cpu = torch.empty( (self.max_num_reqs, self.num_spec_tokens), dtype=torch.int64, @@ -888,10 +853,10 @@ def __init__( ) if self.use_async_scheduling: self.valid_sampled_token_count_event = torch.Event() - self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() + self.valid_sampled_token_count_copy_stream = current_platform.torch_device_fn.Stream() self.valid_sampled_token_count_cpu = torch.empty( self.max_num_reqs, - dtype=torch.int32, + dtype=sampled_ids_dtype, device="cpu", pin_memory=self.pin_memory, ) @@ -982,13 +947,13 @@ def _get_positions(self, num_tokens: Any): return self.mrope_positions.gpu[:, :num_tokens] if self.uses_xdrope_dim > 0: return self.xdrope_positions.gpu[:, :num_tokens] - return self.positions[:num_tokens] + return self.positions.gpu[:num_tokens] else: if self.uses_mrope: return self.mrope_positions.gpu[:, num_tokens] if self.uses_xdrope_dim > 0: return 
self.xdrope_positions.gpu[:, num_tokens] - return self.positions[num_tokens] + return self.positions.gpu[num_tokens] def _make_buffer( self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True @@ -1032,7 +997,7 @@ def _init_model_kwargs(self): if len(token_type_id_requests) == 0: return model_kwargs - seq_lens = self.seq_lens[:num_reqs] + seq_lens = self.seq_lens.gpu[:num_reqs] token_type_ids = [] for i in range(num_reqs): @@ -1092,7 +1057,7 @@ def _zero_block_ids(self, block_ids: list[int]) -> None: # Note: used for model runner override. def _init_device_properties(self) -> None: - """Initialize attributes from torch.cuda.get_device_properties""" + """Initialize attributes from current_platform.torch_device_fn.get_device_properties""" self.num_sms = num_compute_units(self.device.index) @@ -1100,7 +1065,7 @@ def _init_device_properties(self) -> None: def _sync_device(self) -> None: current_platform.torch_device_fn.synchronize() - def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None: + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -1165,8 +1130,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None ngram_gpu_new_reqs: list[CachedRequestState] = [] reqs_to_add: list[CachedRequestState] = [] - deferred_spec_decode_corrections = [] - # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id @@ -1253,8 +1216,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None scheduler_output, self.input_batch.req_id_to_index, ) - if self.use_async_spec_decode: - self.prev_num_draft_tokens.np.fill(0) + + # Wait until valid_sampled_tokens_count is copied to cpu, + # then use it to update actual num_computed_tokens of each request. 
+ valid_sampled_token_count = self._get_valid_sampled_token_count() for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] @@ -1281,30 +1246,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None if req_index is None: req_state.prev_num_draft_len = 0 else: - # Optimistically assume all accepted; queue up a correction - # to be called after the model forward to preserve async - # scheduling. Corrected on GPU in _prepare_inputs. - optimistic_num_accepted = req_state.prev_num_draft_len - req_state.output_token_ids.extend([-1] * optimistic_num_accepted) - - deferred_spec_decode_corrections.append( - (req_id, optimistic_num_accepted, req_state) - ) - - prev_req_index = ( - self.input_batch.prev_req_id_to_index.get(req_id) - if self.input_batch.prev_req_id_to_index - else None - ) - if prev_req_index is not None: - self.prev_num_draft_tokens.np[prev_req_index] = ( - optimistic_num_accepted - ) + assert self.input_batch.prev_req_id_to_index is not None + prev_req_index = self.input_batch.prev_req_id_to_index[req_id] + num_accepted = valid_sampled_token_count[prev_req_index] - 1 + num_rejected = req_state.prev_num_draft_len - num_accepted + num_computed_tokens -= num_rejected + req_state.output_token_ids.extend([-1] * num_accepted) - if is_ngram_gpu and optimistic_num_accepted > 0: - self.input_batch.num_tokens_no_spec[req_index] += ( - optimistic_num_accepted - ) + if is_ngram_gpu and num_accepted > 0 and req_index is not None: + self.input_batch.num_tokens_no_spec[req_index] += num_accepted # Update the cached states. req_state.num_computed_tokens = num_computed_tokens @@ -1332,8 +1282,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None ) elif num_output_tokens < len(req_state.output_token_ids): # Some output tokens were discarded due to a sync-KV-load - # failure, or output_token_ids was inflated by the optimistic - # extend above (async spec decode). Align the cached state. 
+ # failure. Align the cached state. del req_state.output_token_ids[num_output_tokens:] if req_index is not None: end_idx = ( @@ -1421,40 +1370,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None _pinned_val_buf=self._ngram_pinned_val_buf, ) - if deferred_spec_decode_corrections: - - def correct_spec_decode_token_counts(): - valid_sampled_token_count = self._get_valid_sampled_token_count() - if not valid_sampled_token_count: - return - prev_req_id_to_index = self.input_batch.prev_req_id_to_index - if not prev_req_id_to_index: - return - for ( - req_id, - optimistic_num_accepted, - req_state, - ) in deferred_spec_decode_corrections: - prev_req_index = prev_req_id_to_index.get(req_id) - if prev_req_index is None: - continue - num_accepted = valid_sampled_token_count[prev_req_index] - 1 - correction = optimistic_num_accepted - num_accepted - req_state.num_computed_tokens -= correction - cur_req_index = self.input_batch.req_id_to_index.get(req_id) - if cur_req_index is None: - continue - self.input_batch.num_computed_tokens_cpu[cur_req_index] -= ( - correction - ) - if is_ngram_gpu and correction > 0: - self.input_batch.num_tokens_no_spec[cur_req_index] -= correction - self.num_tokens_no_spec_gpu[cur_req_index] -= correction - - return correct_spec_decode_token_counts - else: - return None - def _update_states_after_model_execute( self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput" ) -> None: @@ -1469,9 +1384,6 @@ def _update_states_after_model_execute( if not self.speculative_config or not self.model_config.is_hybrid: return - # TODO: Remove .cpu() sync to enable fully async for hybrid model; - # Use num_computed_tokens.gpu instead of req.num_computed_tokens to - # support aligned mamba cache mode. # Find the number of accepted tokens for each sequence. 
num_reqs = output_token_ids.size(0) self.num_accepted_tokens.gpu[:num_reqs] = ( @@ -1492,12 +1404,12 @@ def _update_states_after_model_execute( .int() .argmax(-1) ) - if self.cache_config.mamba_cache_mode == "align": for i, num_tokens in enumerate( self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy() ): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + mamba_utils.postprocess_mamba( scheduler_output, self.kv_cache_config, @@ -1618,14 +1530,12 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: def _get_cumsum_and_arange( self, num_tokens: np.ndarray, - arange_out: np.ndarray, cumsum_dtype: np.dtype | None = None, - ) -> np.ndarray: + ) -> tuple[np.ndarray, np.ndarray]: """Get the cumulative sum and batched arange of the given array. - E.g., [2, 5, 3] -> [2, 7, 10], arange written to - arange_out[:10] as [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]. - Equivalent to but faster than: - np.concatenate([np.arange(n) for n in num_tokens]) + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_tokens]) """ # Step 1. [2, 5, 3] -> [2, 7, 10] cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) @@ -1633,33 +1543,13 @@ def _get_cumsum_and_arange( # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - np.subtract( - self.arange_np[:total_num_tokens], - cumsums_offsets, - out=arange_out[:total_num_tokens], - ) + arange = self.arange_np[:total_num_tokens] - cumsums_offsets - return cu_num_tokens - - def _compute_prev_positions(self, num_reqs: int) -> None: - """Build prev_positions mapping: current pos -> previous pos (-1 if new). - - Populates self.prev_positions.np[:num_reqs] with the mapping. 
- """ - prev_req_id_to_index = self.input_batch.prev_req_id_to_index - prev_positions = self.prev_positions.np[:num_reqs] - - if not prev_req_id_to_index: - prev_positions.fill(-1) - return - - for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - prev_positions[i] = prev_req_id_to_index.get(req_id, -1) + return cu_num_tokens, arange def _prepare_input_ids( self, scheduler_output: "SchedulerOutput", - num_reqs: int, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray, ) -> None: @@ -1667,11 +1557,7 @@ def _prepare_input_ids( Carefully handles the `prev_sampled_token_ids` which can be cached from the previous engine iteration, in which case those tokens on the - GPU need to be copied into the corresponding slots into input_ids. - - Uses self.prev_positions[:num_reqs] which maps current pos -> prev pos - (-1 for new requests). - """ + GPU need to be copied into the corresponding slots into input_ids.""" if self.input_batch.prev_sampled_token_ids is None: # Normal scheduling case @@ -1684,50 +1570,47 @@ def _prepare_input_ids( # Async scheduling case, where some decode requests from the previous # iteration won't have entries in input_ids_cpu and need to be copied # on the GPU from prev_sampled_token_ids. 
- prev_positions = self.prev_positions.np[:num_reqs] - scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + assert prev_req_id_to_index is not None sample_flattened_indices: list[int] = [] spec_flattened_indices: list[int] = [] + prev_common_req_indices: list[int] = [] prev_draft_token_indices: list[int] = [] - prev_indices: list[int] = [] - common_indices_match = True + indices_match = True max_flattened_index = -1 total_num_spec_tokens = 0 + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens - for cur_index in range(num_reqs): - prev_index = prev_positions[cur_index] - if prev_index < 0: - continue - prev_indices.append(prev_index) - req_id = self.input_batch.req_ids[cur_index] - # We need to compute the flattened input_ids index of the - # last token in each common request. - draft_len = len(scheduled_spec_tokens.get(req_id, ())) - total_num_spec_tokens += draft_len - flattened_index = cu_num_tokens[cur_index].item() - 1 - # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2] - # sample_flattened_indices = [0, 2, 5] - # spec_flattened_indices = [1, 3, 4, 6, 7] - sample_flattened_indices.append(flattened_index - draft_len) - spec_flattened_indices.extend( - range(flattened_index - draft_len + 1, flattened_index + 1) - ) - start = prev_index * self.num_spec_tokens - # prev_draft_token_indices is used to find which draft_tokens_id - # should be copied to input_ids - # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] - # flatten draft_tokens_id [1,2,3,4,5,6] - # draft_len of each request [1, 2, 1] - # then prev_draft_token_indices is [0, 2, 3, 4] - prev_draft_token_indices.extend(range(start, start + draft_len)) - common_indices_match &= prev_index == flattened_index - max_flattened_index = max(max_flattened_index, flattened_index) - + for req_id, cur_index in self.input_batch.req_id_to_index.items(): + if (prev_index := prev_req_id_to_index.get(req_id)) 
is not None: + prev_common_req_indices.append(prev_index) + # We need to compute the flattened input_ids index of the + # last token in each common request. + draft_len = len(scheduled_spec_tokens.get(req_id, ())) + total_num_spec_tokens += draft_len + flattened_index = cu_num_tokens[cur_index].item() - 1 + # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2] + # sample_flattened_indices = [0, 2, 5] + # spec_flattened_indices = [1, 3, 4, 6, 7] + sample_flattened_indices.append(flattened_index - draft_len) + spec_flattened_indices.extend( + range(flattened_index - draft_len + 1, flattened_index + 1) + ) + start = prev_index * self.num_spec_tokens + # prev_draft_token_indices is used to find which draft_tokens_id + # should be copied to input_ids + # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] + # flatten draft_tokens_id [1,2,3,4,5,6] + # draft_len of each request [1, 2, 1] + # then prev_draft_token_indices is [0, 2, 3, 4] + prev_draft_token_indices.extend(range(start, start + draft_len)) + indices_match &= prev_index == flattened_index + max_flattened_index = max(max_flattened_index, flattened_index) num_common_tokens = len(sample_flattened_indices) total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens if num_common_tokens < total_without_spec: # If not all requests are decodes from the last iteration, - # we need to copy the input_ids_cpu to the GPU first. + # We need to copy the input_ids_cpu to the GPU first. self.input_ids.copy_to_gpu(total_num_scheduled_tokens) if self.enable_prompt_embeds: self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) @@ -1736,7 +1619,7 @@ def _prepare_input_ids( # No requests in common with the previous iteration # So input_ids.cpu will have all the input ids. return - if common_indices_match and max_flattened_index == (num_common_tokens - 1): + if indices_match and max_flattened_index == (num_common_tokens - 1): # Common-case optimization: the batch is unchanged # and no reordering happened. 
# The indices are both the same permutation of 0..N-1 so @@ -1753,7 +1636,7 @@ def _prepare_input_ids( sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory ).to(self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( - prev_indices, dtype=torch.int64, pin_memory=self.pin_memory + prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory ).to(self.device, non_blocking=True) self.input_ids.gpu.scatter_( dim=0, @@ -1857,15 +1740,15 @@ def _prepare_inputs( req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] - # self.query_pos.np[:10]: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens = self._get_cumsum_and_arange( - num_scheduled_tokens, self.query_pos.np - ) + # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) # Get positions. - positions_np = ( - self.input_batch.num_computed_tokens_cpu[req_indices] - + self.query_pos.np[: cu_num_tokens[-1]] + positions_np = self.positions.np[:total_num_scheduled_tokens] + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, ) # Calculate M-RoPE positions. @@ -1943,6 +1826,9 @@ def _prepare_inputs( output_idx += num_sched + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) + # Prepare the attention metadata. self.query_start_loc.np[0] = 0 self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens @@ -1952,21 +1838,12 @@ def _prepare_inputs( self.query_start_loc.copy_to_gpu() query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] - # Compute optimistic seq_lens (assumes all draft tokens from previous - # iteration accepted). Store in optimistic_seq_lens_cpu for use by - # _build_attention_metadata (max_seq_len) and discard_request_mask. 
- # seq_lens (GPU) will be computed later using the same optimistic values. - torch.add( - self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs], - torch.from_numpy(num_scheduled_tokens), - out=self.optimistic_seq_lens_cpu[:num_reqs], + self.seq_lens.np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens ) - self.optimistic_seq_lens_cpu[num_reqs:].fill_(0) - - # Build prev_positions mapping: current pos -> prev pos (-1 if new). - # Used for gathering from previous iteration's GPU tensors. - prev_req_id_to_index = self.input_batch.prev_req_id_to_index - self._compute_prev_positions(num_reqs) + # Fill unused with 0 for full cuda graph mode. + self.seq_lens.np[num_reqs:].fill(0) + self.seq_lens.copy_to_gpu() num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens_np = np.array(num_tokens, dtype=np.int32) @@ -1974,78 +1851,13 @@ def _prepare_inputs( # Record which requests should not be sampled, # so that we could clear the sampled tokens before returning self.discard_request_mask.np[:num_reqs] = ( - self.optimistic_seq_lens_cpu[:num_reqs].numpy() < num_tokens_np + self.seq_lens.np[:num_reqs] < num_tokens_np ) self.discard_request_mask.copy_to_gpu(num_reqs) - # Sync num_accepted_tokens from CPU (set by - # _update_states_after_model_execute for hybrid models). - if self.num_accepted_tokens_event is not None: - self.num_accepted_tokens_event.synchronize() - self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs] - ) - self.num_accepted_tokens.np[num_reqs:].fill(1) - self.num_accepted_tokens.copy_to_gpu() - else: - self.num_accepted_tokens.np.fill(1) - self.num_accepted_tokens.gpu.fill_(1) - - # Update num_computed_tokens on GPU. In async spec decode, - # CPU values are optimistic (all drafts accepted). The kernel - # corrects on GPU using the previous step's - # valid_sampled_token_count_gpu. Otherwise, just copy from CPU. 
- if ( - self.use_async_spec_decode - and self.valid_sampled_token_count_gpu is not None - and prev_req_id_to_index - ): - self.prev_positions.copy_to_gpu(num_reqs) - self.prev_num_draft_tokens.copy_to_gpu() - cpu_values = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs].to( - device=self.device, non_blocking=True - ) - update_num_computed_tokens_for_batch_change( - self.num_computed_tokens, - self.num_accepted_tokens.gpu[:num_reqs], - self.prev_positions.gpu[:num_reqs], - self.valid_sampled_token_count_gpu, - self.prev_num_draft_tokens.gpu, - cpu_values, - ) - else: - self.num_computed_tokens[:num_reqs].copy_( - self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs], - non_blocking=True, - ) - - self.req_indices.np[:total_num_scheduled_tokens] = req_indices - self.req_indices.copy_to_gpu(total_num_scheduled_tokens) - req_indices_gpu = self.req_indices.gpu[:total_num_scheduled_tokens] - - self.query_pos.copy_to_gpu(total_num_scheduled_tokens) - self.num_scheduled_tokens.np[:num_reqs] = num_scheduled_tokens - self.num_scheduled_tokens.copy_to_gpu(num_reqs) - num_scheduled_tokens_gpu = self.num_scheduled_tokens.gpu[:num_reqs] - self.positions[:total_num_scheduled_tokens] = ( - self.num_computed_tokens[req_indices_gpu].to(torch.int64) - + self.query_pos.gpu[:total_num_scheduled_tokens] - ) - self.seq_lens[:num_reqs] = ( - self.num_computed_tokens[:num_reqs] + num_scheduled_tokens_gpu - ) - self.seq_lens[num_reqs:].fill_(0) - - self.input_batch.block_table.compute_slot_mapping( - num_reqs, - self.query_start_loc.gpu[: num_reqs + 1], - self.positions[:total_num_scheduled_tokens], - ) - # Copy the tensors to the GPU. 
self._prepare_input_ids( scheduler_output, - num_reqs, total_num_scheduled_tokens, cu_num_tokens, ) @@ -2062,14 +1874,9 @@ def _prepare_inputs( self.xdrope_positions.cpu[:, :total_num_scheduled_tokens], non_blocking=True, ) - if self.use_async_spec_decode and (self.uses_mrope or self.uses_xdrope_dim > 0): - drift = self.num_computed_tokens[req_indices_gpu].to( - torch.int64 - ) - self.input_batch.num_computed_tokens_cpu_tensor[req_indices].to( - device=self.device, dtype=torch.int64, non_blocking=True - ) - target = self.mrope_positions if self.uses_mrope else self.xdrope_positions - target.gpu[:, :total_num_scheduled_tokens] += drift + else: + # Common case (1D positions) + self.positions.copy_to_gpu(total_num_scheduled_tokens) use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -2094,13 +1901,12 @@ def _prepare_inputs( draft_token_ids, ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] - draft_len = len(draft_token_ids) - num_draft_tokens[req_idx] = draft_len + num_draft_tokens[req_idx] = len(draft_token_ids) if ( self.input_batch.num_computed_tokens_cpu[req_idx] >= self.input_batch.num_prompt_tokens[req_idx] ): - num_decode_draft_tokens[req_idx] = draft_len + num_decode_draft_tokens[req_idx] = len(draft_token_ids) spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens ) @@ -2162,7 +1968,16 @@ def _build_attention_metadata( # window size when capturing to make sure the correct kernel is selected. 
max_seq_len = self.max_model_len else: - max_seq_len = self.optimistic_seq_lens_cpu.numpy()[:num_reqs].max().item() + max_seq_len = self.seq_lens.np[:num_reqs].max().item() + + if use_spec_decode: + if self.num_accepted_tokens_event is not None: + self.num_accepted_tokens_event.synchronize() + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() kv_cache_groups = self.kv_cache_config.kv_cache_groups @@ -2192,30 +2007,14 @@ def _get_block_table(kv_cache_gid: int): attn_gid = self.routed_experts_attn_gid slot_mapping_attn = slot_mappings[attn_gid] self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy() - num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ - :num_reqs_padded - ] - num_prompt_tokens_cpu = self.input_batch.num_prompt_tokens_cpu_tensor[ - :num_reqs_padded - ] - seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs_padded] - - # is_prefilling: True if request is still in prefill phase. - # Used by mamba backends to distinguish actual decodes from - # short extends. - is_prefilling = num_computed_tokens_cpu < num_prompt_tokens_cpu - - if self.use_async_spec_decode: - # GPU tensors are authoritative in async mode. 
- seq_lens_cpu = None - num_computed_tokens_cpu = None - cm_base = CommonAttentionMetadata( query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], - seq_lens=self.seq_lens[:num_reqs_padded], - _seq_lens_cpu=seq_lens_cpu, - _num_computed_tokens_cpu=num_computed_tokens_cpu, + seq_lens=self.seq_lens.gpu[:num_reqs_padded], + _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], + _num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs_padded + ], num_reqs=num_reqs_padded, num_actual_tokens=num_tokens_padded, max_query_len=max_query_len, @@ -2223,12 +2022,11 @@ def _get_block_table(kv_cache_gid: int): block_table_tensor=block_table_gid_0, slot_mapping=slot_mapping_gid_0, causal=True, - is_prefilling=is_prefilling, ) if self.dcp_world_size > 1: self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( - self.optimistic_seq_lens_cpu[:num_reqs], + self.seq_lens.cpu[:num_reqs], self.dcp_world_size, self.dcp_rank, self.parallel_config.cp_kv_cache_interleave_size, @@ -2632,34 +2430,33 @@ def _calc_spec_decode_metadata( # [4, 1, 3, 1, 2] num_sampled_tokens = num_draft_tokens + 1 - # Step 1. - # cu_num_sampled_tokens: [4, 5, 8, 9, 11] - # _arange_scratch[:11]: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - cu_num_sampled_tokens = self._get_cumsum_and_arange( - num_sampled_tokens, self._arange_scratch, cumsum_dtype=np.int32 + # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] + # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( + num_sampled_tokens, cumsum_dtype=np.int32 ) # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat( cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens ) # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - logits_indices += self._arange_scratch[: cu_num_sampled_tokens[-1]] + logits_indices += arange # Compute the bonus logits indices. 
bonus_logits_indices = cu_num_sampled_tokens - 1 # Compute the draft logits indices. # cu_num_draft_tokens: [3, 3, 5, 5, 6] - # _arange_scratch[:6]: [0, 1, 2, 0, 1, 0] - cu_num_draft_tokens = self._get_cumsum_and_arange( - num_draft_tokens, self._arange_scratch, cumsum_dtype=np.int32 + # arange: [0, 1, 2, 0, 1, 0] + cu_num_draft_tokens, arange = self._get_cumsum_and_arange( + num_draft_tokens, cumsum_dtype=np.int32 ) # [0, 0, 0, 5, 5, 9] target_logits_indices = np.repeat( cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens ) # [0, 1, 2, 5, 6, 9] - target_logits_indices += self._arange_scratch[: cu_num_draft_tokens[-1]] + target_logits_indices += arange # TODO: Optimize the CPU -> GPU copy. cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( @@ -2852,23 +2649,17 @@ def _execute_mm_encoder( ): batch_outputs: MultiModalEmbeddings - # EVS and dynamic res video related change. + # EVS-related change. # (ekhvedchenia): Temporary hack to limit peak memory usage when # processing multimodal data. This solves the issue with scheduler # putting too many video samples into a single batch. Scheduler # uses pruned vision tokens count to compare it versus compute # budget which is incorrect (Either input media size or non-pruned # output vision tokens count should be considered) - # dynamic res video for nemotron temporarily uses this hack via - # requires_sequential_video_encoding - # because it doesn't yet support video batching. # TODO(ywang96): Fix memory profiling to take EVS into account and # remove this hack. 
if ( - ( - self.is_multimodal_pruning_enabled - or self.requires_sequential_video_encoding - ) + self.is_multimodal_pruning_enabled and modality == "video" and num_items > 1 ): @@ -2905,19 +2696,7 @@ def _execute_mm_encoder( with self.timed_encoder_operation( should_time, mm_lora_refs, current_item_idx, num_items ): - cudagraph_output = None - if ( - self.encoder_cudagraph_manager is not None - and self.encoder_cudagraph_manager.supports_modality(modality) - ): - cudagraph_output = self.encoder_cudagraph_manager.execute( - mm_kwargs_batch, - ) - - if cudagraph_output is not None: - batch_outputs = cudagraph_output - else: - batch_outputs = model.embed_multimodal(**mm_kwargs_batch) + batch_outputs = model.embed_multimodal(**mm_kwargs_batch) sanity_check_mm_encoder_outputs(batch_outputs, expected_num_items=num_items) encoder_outputs.extend(batch_outputs) @@ -3072,7 +2851,15 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: if not is_pooling_model(model): return [] - return list(model.pooler.get_supported_tasks()) + supported_tasks = list(model.pooler.get_supported_tasks()) + + if "score" in supported_tasks: + num_labels = getattr(self.model_config.hf_config, "num_labels", 0) + if num_labels != 1: + supported_tasks.remove("score") + logger.debug_once("Score API is only enabled for num_labels == 1.") + + return supported_tasks def get_supported_tasks(self) -> tuple[SupportedTask, ...]: tasks = list[SupportedTask]() @@ -3161,14 +2948,11 @@ def _pool( ) hidden_states = hidden_states[:num_scheduled_tokens] - seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs] + seq_lens_cpu = self.seq_lens.cpu[:num_reqs] pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata.build_pooling_cursor( - num_scheduled_tokens_np, - seq_lens_cpu, - device=hidden_states.device, - query_start_loc_gpu=self.query_start_loc.gpu[: num_reqs + 1], + num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device ) model = cast(VllmModelForPooling, self.model) @@ 
-3320,9 +3104,7 @@ def _preprocess( elif self.uses_xdrope_dim > 0: positions = self.xdrope_positions.gpu[:, :num_input_tokens] else: - positions = self.positions[:num_input_tokens] - if num_input_tokens > num_scheduled_tokens: - self.positions[num_scheduled_tokens:num_input_tokens].zero_() + positions = self.positions.gpu[:num_input_tokens] if is_first_rank: intermediate_tensors = None @@ -3836,10 +3618,10 @@ def execute_model( scheduled_spec_decode_tokens=spec_decode_tokens_copy, ) - if has_kv_transfer_group(): - kv_connector_metadata = scheduler_output.kv_connector_metadata - assert kv_connector_metadata is not None - get_kv_transfer_group().handle_preemptions(kv_connector_metadata) + if scheduler_output.preempted_req_ids and has_kv_transfer_group(): + get_kv_transfer_group().handle_preemptions( + scheduler_output.preempted_req_ids + ) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with ( @@ -3847,7 +3629,7 @@ def execute_model( self.synchronize_input_prep(), ): # Update persistent batch states. - deferred_state_corrections_fn = self._update_states(scheduler_output) + self._update_states(scheduler_output) if has_ec_transfer() and not get_ec_transfer().is_consumer: with self.maybe_get_ec_connector_output( @@ -3960,12 +3742,6 @@ def execute_model( pad_attn = cudagraph_mode == CUDAGraphMode.FULL if self.cache_config.mamba_cache_mode == "align": - # preprocess_mamba reads req_state.num_computed_tokens (CPU) - # to decide copy operations, so we must apply deferred - # corrections before it runs. - if deferred_state_corrections_fn: - deferred_state_corrections_fn() - deferred_state_corrections_fn = None mamba_utils.preprocess_mamba( scheduler_output, self.kv_cache_config, @@ -3977,14 +3753,6 @@ def execute_model( self.model.get_mamba_state_copy_func(), self._get_mamba_copy_bufs(), ) - # preprocess_mamba resets num_accepted_tokens_cpu to 1 - # for requests whose state was copied to a new block. 
- # Re-sync to GPU so the mamba kernel reads from the - # correct initial state slot (init_token_idx = 0). - self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs] - ) - self.num_accepted_tokens.copy_to_gpu(num_reqs) use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices @@ -4145,12 +3913,6 @@ def execute_model( slot_mappings, ) self.kv_connector_output = kv_connector_output - - # Now the batch has been launched we can wait for corrections from the - # previous model forward without breaking async scheduling. - if deferred_state_corrections_fn: - deferred_state_corrections_fn() - return None @managed_inference_mode() @@ -4215,7 +3977,6 @@ def sample_tokens( self._draft_token_ids = None self._draft_token_req_ids = None - self.valid_sampled_token_count_gpu = None self.input_batch.prev_sampled_token_ids = None def propose_draft_token_ids(sampled_token_ids): @@ -4260,7 +4021,7 @@ def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - self.optimistic_seq_lens_cpu, + spec_decode_common_attn_metadata, sampled_token_ids, self.requests, self.input_batch, @@ -4498,9 +4259,6 @@ def _copy_valid_sampled_token_count( counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True) self.valid_sampled_token_count_event.record() - if self.use_async_spec_decode: - # Stash for GPU-side correction in _prepare_inputs. 
- self.valid_sampled_token_count_gpu = valid_sampled_tokens_count self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1) def _get_valid_sampled_token_count(self) -> list[int]: @@ -4630,7 +4388,7 @@ def propose_draft_token_ids( ) next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - self.optimistic_seq_lens_cpu, + common_attn_metadata, sampled_token_ids, self.requests, self.input_batch, @@ -4669,7 +4427,7 @@ def propose_draft_token_ids( ) next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - self.optimistic_seq_lens_cpu, + common_attn_metadata, sampled_token_ids, self.requests, self.input_batch, @@ -4883,9 +4641,6 @@ def load_model(self, load_dummy_weights: bool = False) -> None: and mm_config is not None and mm_config.is_multimodal_pruning_enabled() ) - self.requires_sequential_video_encoding = hasattr( - self.get_model(), "requires_sequential_video_encoding" - ) # Temporary hack for dynamic res video w/o support for bs>1 yet if ( is_mixture_of_experts(self.model) @@ -5420,28 +5175,17 @@ def _dummy_run( # In the mixed batch mode (used for FI warmup), we use # shorter sequence lengths to run faster. 
# TODO(luka) better system for describing dummy batches - seq_lens = torch.tensor( # type: ignore[assignment] - [1] * num_decode_tokens + [num_prefill_tokens + 1], - dtype=torch.int, - ) + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment] else: seq_lens = max_query_len # type: ignore[assignment] - self.optimistic_seq_lens_cpu[:num_reqs] = seq_lens - self.optimistic_seq_lens_cpu[num_reqs:].fill_(0) - self.seq_lens.copy_(self.optimistic_seq_lens_cpu, non_blocking=True) + self.seq_lens.np[:num_reqs] = seq_lens + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() - cum_num_tokens = self._get_cumsum_and_arange( - num_scheduled_tokens, self.query_pos.np - ) + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() - # Sync block table CPU->GPU so cleared rows from - # remove_request() are visible to the attention metadata - # builder. Without this, stale block IDs from finished - # requests can corrupt Mamba state. 
- self.input_batch.block_table.commit_block_table(num_reqs_padded) - pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL attn_metadata, _ = self._build_attention_metadata( num_tokens=num_tokens_unpadded, @@ -5484,7 +5228,7 @@ def _dummy_run( elif self.uses_xdrope_dim > 0: positions = self.xdrope_positions.gpu[:, :num_tokens_padded] else: - positions = self.positions[:num_tokens_padded] + positions = self.positions.gpu[:num_tokens_padded] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -5899,10 +5643,7 @@ def _cleanup_profiling_kv_cache(self) -> None: for layer in self.compilation_config.static_forward_context.values(): if hasattr(layer, "kv_cache"): - kv_cache = layer.kv_cache - layer.kv_cache = ( - torch.tensor([]) if isinstance(kv_cache, torch.Tensor) else [] - ) + layer.kv_cache = [] gc.collect() torch.accelerator.empty_cache() @@ -5952,7 +5693,7 @@ def profile_cudagraph_memory(self) -> int: mem_samples: list[int] = [] for i, desc in enumerate(profile_descs): - mem_before = torch.cuda.mem_get_info()[0] + mem_before = current_platform.torch_device_fn.mem_get_info()[0] self._warmup_and_capture( desc, cudagraph_runtime_mode=mode, @@ -5966,7 +5707,7 @@ def profile_cudagraph_memory(self) -> int: ), ) torch.accelerator.synchronize() - free_after = torch.cuda.mem_get_info()[0] + free_after = current_platform.torch_device_fn.mem_get_info()[0] mem_samples.append(mem_before - free_after) first_capture = mem_samples[0] @@ -6019,31 +5760,6 @@ def capture_model(self) -> int: ) return 0 - # Initialize encoder CUDA graph manager if enabled. - # Use get_model() to unwrap GraphWrapper/UBatchWrapper, - # because @runtime_checkable Protocol isinstance() checks do not - # work through __getattr__ forwarding. 
- if ( - self.compilation_config.cudagraph_mm_encoder - and self.supports_mm_inputs - and self.encoder_cudagraph_manager is None - ): - from vllm.model_executor.models.interfaces import ( - SupportsEncoderCudaGraph, - supports_encoder_cudagraph, - ) - from vllm.v1.worker.encoder_cudagraph import EncoderCudaGraphManager - - raw_model = self.get_model() - if supports_encoder_cudagraph(raw_model): - self.encoder_cudagraph_manager = EncoderCudaGraphManager( - vllm_config=self.vllm_config, - device=self.device, - dtype=self.dtype, - model=cast(SupportsEncoderCudaGraph, raw_model), - ) - logger.info("Initialized EncoderCudaGraphManager for vision encoder") - compilation_counter.num_gpu_runner_capture_triggers += 1 start_time = time.perf_counter() @@ -6055,7 +5771,7 @@ def capture_model(self) -> int: with self._freeze_gc(), graph_capture(device=self.device): torch.accelerator.synchronize() torch.accelerator.empty_cache() - start_free_gpu_memory = torch.cuda.mem_get_info()[0] + start_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0] for ( runtime_mode, @@ -6067,12 +5783,8 @@ def capture_model(self) -> int: ) torch.accelerator.synchronize() - # Capture encoder CUDA graphs if enabled - if self.encoder_cudagraph_manager is not None: - self.encoder_cudagraph_manager.capture() - torch.accelerator.synchronize() - end_free_gpu_memory = torch.cuda.mem_get_info()[0] + end_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0] # Disable cudagraph capturing globally, so any unexpected cudagraph # capturing will be detected and raise an error after here. 
From 3cfb83544f666f1fc7a716ad7b1da2100f35f680 Mon Sep 17 00:00:00 2001 From: cyberpioneer Date: Fri, 3 Apr 2026 09:01:50 +0000 Subject: [PATCH 4/7] fix --- .github/ISSUE_TEMPLATE/blank.md | 1 - benchmarks/flagos_eval/README.md | 2 +- examples/fuse_moe_tune/E=512,N=128,device_name=cuda.json | 1 - tests/unit_tests/ops/test_layernorm.py | 1 + tests/unit_tests/worker/test_model_runner.py | 2 +- vllm_fl/dispatch/backends/vendor/ascend/ascend.py | 2 +- .../dispatch/backends/vendor/ascend/impl/normalization.py | 2 +- vllm_fl/dispatch/backends/vendor/ascend/register_ops.py | 2 +- vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py | 2 +- vllm_fl/dispatch/builtin_ops.py | 3 --- vllm_fl/distributed/communicator.py | 6 +----- vllm_fl/distributed/device_communicators/flagcx.py | 2 +- vllm_fl/version.py | 1 - 13 files changed, 9 insertions(+), 18 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/blank.md b/.github/ISSUE_TEMPLATE/blank.md index 9399f4fd..5d50f39e 100644 --- a/.github/ISSUE_TEMPLATE/blank.md +++ b/.github/ISSUE_TEMPLATE/blank.md @@ -5,4 +5,3 @@ title: '' labels: '' assignees: '' --- - diff --git a/benchmarks/flagos_eval/README.md b/benchmarks/flagos_eval/README.md index b089268e..cfcfed83 100644 --- a/benchmarks/flagos_eval/README.md +++ b/benchmarks/flagos_eval/README.md @@ -6,7 +6,7 @@ Evaluation toolkit for large language models with LM Eval and vLLM Benchmark. ### 1. 
Dependencies -* **[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)** +* **[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)** ```bash git clone https://github.com/EleutherAI/lm-evaluation-harness.git cd lm-evaluation-harness diff --git a/examples/fuse_moe_tune/E=512,N=128,device_name=cuda.json b/examples/fuse_moe_tune/E=512,N=128,device_name=cuda.json index dc478aaa..94408e27 100644 --- a/examples/fuse_moe_tune/E=512,N=128,device_name=cuda.json +++ b/examples/fuse_moe_tune/E=512,N=128,device_name=cuda.json @@ -144,4 +144,3 @@ "num_stages": 3 } } - diff --git a/tests/unit_tests/ops/test_layernorm.py b/tests/unit_tests/ops/test_layernorm.py index 3dcbe1ec..04acfe5a 100644 --- a/tests/unit_tests/ops/test_layernorm.py +++ b/tests/unit_tests/ops/test_layernorm.py @@ -15,6 +15,7 @@ class TestRMSNormFL: def __init__(self): from vllm.config import VllmConfig, set_current_vllm_config + set_current_vllm_config(VllmConfig()) @pytest.fixture diff --git a/tests/unit_tests/worker/test_model_runner.py b/tests/unit_tests/worker/test_model_runner.py index 04a15981..d1e8f998 100644 --- a/tests/unit_tests/worker/test_model_runner.py +++ b/tests/unit_tests/worker/test_model_runner.py @@ -60,7 +60,7 @@ def test_fields_match_expected_contract(self): "aux_hidden_states", "ec_connector_output", "cudagraph_stats", - 'slot_mappings', + "slot_mappings", ) assert ExecuteModelState._fields == expected_fields, ( "ExecuteModelState fields changed - this may break execute_model consumers" diff --git a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py index 9d0db4b1..eb787948 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/ascend.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/ascend.py @@ -145,4 +145,4 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) -> if use_sparse: raise NotImplementedError("MLA with sparse attention is not implemented for Ascend 
yet.") return "vllm_fl.dispatch.backends.vendor.ascend.impl.attention.AscendMLABackend" - return "vllm_fl.dispatch.backends.vendor.ascend.impl.attention.AscendAttentionBackend" \ No newline at end of file + return "vllm_fl.dispatch.backends.vendor.ascend.impl.attention.AscendAttentionBackend" diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py index f5072a64..8bcc2672 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/normalization.py @@ -38,4 +38,4 @@ def rms_norm_ascend( return x, residual x, _ = torch_npu.npu_rms_norm(x, weight, epsilon) - return x \ No newline at end of file + return x diff --git a/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py b/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py index 88491153..fc140d08 100644 --- a/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py +++ b/vllm_fl/dispatch/backends/vendor/ascend/register_ops.py @@ -76,4 +76,4 @@ def register_builtins(registry: OpRegistry) -> None: ), ] - registry.register_many(impls) \ No newline at end of file + registry.register_many(impls) diff --git a/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py b/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py index e54f7987..d4896cd4 100644 --- a/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py +++ b/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py @@ -58,7 +58,7 @@ def is_available(self) -> bool: IluvatarBackend._available = True else: IluvatarBackend._available = False - + else: IluvatarBackend._available = False except Exception: diff --git a/vllm_fl/dispatch/builtin_ops.py b/vllm_fl/dispatch/builtin_ops.py index 2d0883d9..c8a8c990 100644 --- a/vllm_fl/dispatch/builtin_ops.py +++ b/vllm_fl/dispatch/builtin_ops.py @@ -74,9 +74,6 @@ def _register_vendor_backends(registry: OpRegistry) -> None: logger.debug(f"Vendor backends directory not found: 
{_VENDOR_BACKENDS_DIR}") return - for vendor_name in os.listdir(_VENDOR_BACKENDS_DIR): - vendor_path = os.path.join(_VENDOR_BACKENDS_DIR, vendor_name) - available_vendor_dirs = { vendor_name for vendor_name in os.listdir(_VENDOR_BACKENDS_DIR) diff --git a/vllm_fl/distributed/communicator.py b/vllm_fl/distributed/communicator.py index f037dbaa..ae4b243c 100644 --- a/vllm_fl/distributed/communicator.py +++ b/vllm_fl/distributed/communicator.py @@ -142,7 +142,7 @@ def destroy(self): if self.all2all_manager is not None: self.all2all_manager.destroy() self.all2all_manager = None - + def all_gatherv(self, input_: Union[torch.Tensor, list[torch.Tensor]], dim: int = 0, @@ -207,7 +207,3 @@ def combine(self, hidden_states = self.all2all_manager.combine(hidden_states, is_sequence_parallel) return hidden_states - - - - diff --git a/vllm_fl/distributed/device_communicators/flagcx.py b/vllm_fl/distributed/device_communicators/flagcx.py index d0db73d7..f620166d 100644 --- a/vllm_fl/distributed/device_communicators/flagcx.py +++ b/vllm_fl/distributed/device_communicators/flagcx.py @@ -123,7 +123,7 @@ def __init__( self.all_reduce(data) stream.synchronize() del data - + def all_reduce(self, in_tensor: torch.Tensor, out_tensor: torch.Tensor = None, diff --git a/vllm_fl/version.py b/vllm_fl/version.py index 1ef06b0e..6c7597a4 100644 --- a/vllm_fl/version.py +++ b/vllm_fl/version.py @@ -89,4 +89,3 @@ def _load_scm() -> tuple[str | None, str | None]: "id": git_version, "date": _scm_date or _git_commit_date_from_repo() or "Unknown", } - From 8268235a11d86e19c53a49a42781db709bee9820 Mon Sep 17 00:00:00 2001 From: keennddyl Date: Fri, 3 Apr 2026 18:43:22 +0800 Subject: [PATCH 5/7] add musa support (#97) ### PR Category Vendor ### PR Type New Features ### Description This pull request adds support for the MUSA hardware backend throughout the codebase, enabling vLLM-FL to run on MUSA devices with appropriate configuration, device handling, and operator dispatch. 
The main changes include platform detection, device context management, configuration updates, and backend selection logic for MUSA. Platform and Device Support: * Added detection and handling for the "musa" platform in platform utilities and device capability queries, including `is_cuda_alike`, `is_cuda`, and a new `is_musa` method in `platform.py` [[1]](diffhunk://#diff-e62f96d38d994f2068a59e290d710e4900afc1b54bd4f58334de77c01c233c57R64-R77) [[2]](diffhunk://#diff-e62f96d38d994f2068a59e290d710e4900afc1b54bd4f58334de77c01c233c57L107-R115) [[3]](diffhunk://#diff-e62f96d38d994f2068a59e290d710e4900afc1b54bd4f58334de77c01c233c57R153-R155) [[4]](diffhunk://#diff-e62f96d38d994f2068a59e290d710e4900afc1b54bd4f58334de77c01c233c57L337-R350) [[5]](diffhunk://#diff-55ef01748202be695636b2c430841f377ace4d833eddea714e620f6cac70fd9eL49-R49). * Updated device context management in `flagcx.py` to use `torch.musa.device` when running on MUSA hardware. Operator Dispatch and Configuration: * Introduced a new `musa.yaml` dispatch configuration file specifying backend preferences, operator backend order, and blacklists for MUSA hardware. Graph and Execution Support: * Added support for `torch.musa.MUSAGraph` in the graph compilation logic to enable graph execution on MUSA devices. These changes collectively ensure that vLLM-FL can detect, configure, and efficiently utilize MUSA hardware in a manner similar to CUDA and other supported platforms. 
### Related Issues ### Changes - ### Testing - ### Checklist - [ ] I have run the existing tests and they pass - [ ] I have added tests for my changes (if applicable) - [ ] I have updated the documentation (if applicable) --------- Co-authored-by: jiamingwang-mt Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: keennddyl Co-authored-by: hozier Co-authored-by: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> --- requirements.txt | 11 ++ requirements/musa.txt | 0 vllm_fl/compilation/graph.py | 2 + .../dispatch/backends/vendor/musa/__init__.py | 9 ++ .../backends/vendor/musa/impl/__init__.py | 5 + .../backends/vendor/musa/impl/activation.py | 26 ++++ .../vendor/musa/impl/normalization.py | 44 ++++++ .../backends/vendor/musa/impl/rotary.py | 98 ++++++++++++ vllm_fl/dispatch/backends/vendor/musa/musa.py | 139 ++++++++++++++++++ .../backends/vendor/musa/register_ops.py | 78 ++++++++++ vllm_fl/dispatch/config/musa.yaml | 61 ++++++++ vllm_fl/dispatch/config/utils.py | 2 +- .../device_communicators/flagcx.py | 11 +- vllm_fl/platform.py | 17 ++- vllm_fl/utils.py | 6 +- 15 files changed, 501 insertions(+), 8 deletions(-) create mode 100644 requirements.txt create mode 100644 requirements/musa.txt create mode 100644 vllm_fl/dispatch/backends/vendor/musa/__init__.py create mode 100644 vllm_fl/dispatch/backends/vendor/musa/impl/__init__.py create mode 100644 vllm_fl/dispatch/backends/vendor/musa/impl/activation.py create mode 100644 vllm_fl/dispatch/backends/vendor/musa/impl/normalization.py create mode 100644 vllm_fl/dispatch/backends/vendor/musa/impl/rotary.py create mode 100644 vllm_fl/dispatch/backends/vendor/musa/musa.py create mode 100644 vllm_fl/dispatch/backends/vendor/musa/register_ops.py create mode 100644 vllm_fl/dispatch/config/musa.yaml diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..f90450c7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +decorator 
+pyyaml +scipy +modelscope>=1.18.1 +setuptools <= 79.0.1 +setuptools-scm +scikit-build-core==0.11 +pybind11 +ninja +cmake +fastsafetensors diff --git a/requirements/musa.txt b/requirements/musa.txt new file mode 100644 index 00000000..e69de29b diff --git a/vllm_fl/compilation/graph.py b/vllm_fl/compilation/graph.py index 74e515a5..111aa9d6 100644 --- a/vllm_fl/compilation/graph.py +++ b/vllm_fl/compilation/graph.py @@ -46,6 +46,8 @@ class Graph: graph = torch.cuda.CUDAGraph elif current_platform.device_type == "npu": graph = torch.npu.NPUGraph + elif current_platform.device_type == "musa": + graph = torch.musa.MUSAGraph else: raise NotImplementedError("not support graph") diff --git a/vllm_fl/dispatch/backends/vendor/musa/__init__.py b/vllm_fl/dispatch/backends/vendor/musa/__init__.py new file mode 100644 index 00000000..a143a1fb --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/musa/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +MUSA backend for vllm-plugin-FL dispatch. +""" + +from .musa import MusaBackend + +__all__ = ["MusaBackend"] diff --git a/vllm_fl/dispatch/backends/vendor/musa/impl/__init__.py b/vllm_fl/dispatch/backends/vendor/musa/impl/__init__.py new file mode 100644 index 00000000..23c6f3e3 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/musa/impl/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +MUSA operator implementations. +""" diff --git a/vllm_fl/dispatch/backends/vendor/musa/impl/activation.py b/vllm_fl/dispatch/backends/vendor/musa/impl/activation.py new file mode 100644 index 00000000..2f9ddadf --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/musa/impl/activation.py @@ -0,0 +1,26 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +Reference activation operator implementations using PyTorch. 
+""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +def silu_and_mul_musa(obj, x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication using PyTorch. + + Args: + obj: The calling obj (for interface consistency) + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + d = x.shape[-1] // 2 + x1, x2 = x[..., :d], x[..., d:] + return F.silu(x1) * x2 diff --git a/vllm_fl/dispatch/backends/vendor/musa/impl/normalization.py b/vllm_fl/dispatch/backends/vendor/musa/impl/normalization.py new file mode 100644 index 00000000..87b69ca7 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/musa/impl/normalization.py @@ -0,0 +1,44 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +Reference normalization operator implementations using PyTorch. +""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + + +def rms_norm_musa( + obj, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization using PyTorch. 
+ + Args: + obj: The calling obj (e.g., RMSNorm layer) + x: Input tensor + residual: Optional residual tensor + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + # Get weight and epsilon from obj + weight = obj.weight + epsilon = obj.variance_epsilon + + if residual is not None: + x = x + residual + residual = x + + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + epsilon) + output = weight * x + + if residual is not None: + return output, residual + return output \ No newline at end of file diff --git a/vllm_fl/dispatch/backends/vendor/musa/impl/rotary.py b/vllm_fl/dispatch/backends/vendor/musa/impl/rotary.py new file mode 100644 index 00000000..b70322fe --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/musa/impl/rotary.py @@ -0,0 +1,98 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +MUSA rotary embedding operator implementations. + +NOTE: This is a template/stub implementation using PyTorch reference code. +Replace with actual Musa-optimized implementations when available. +""" + +from __future__ import annotations + +import torch + + +def rotary_embedding_musa( + obj, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding using Musa. + + This is a placeholder implementation using PyTorch reference code. + TODO: Replace with actual Musa GPU optimized implementation. 
+ + Args: + obj: The calling obj (for interface consistency) + query: Query tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] + key: Key tensor [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] + cos: Cosine cache [max_seq_len, rotary_dim] where rotary_dim = head_dim or head_dim // 2 + sin: Sine cache [max_seq_len, rotary_dim] where rotary_dim = head_dim or head_dim // 2 + position_ids: Position indices [batch, seq_len] or [seq_len] + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place (ignored in reference impl) + + Returns: + Tuple of (embedded_query, embedded_key) + """ + # Get cos/sin for the positions + # position_ids can be [batch, seq_len] or [seq_len] + if position_ids.dim() == 1: + # [seq_len] -> [seq_len, rotary_dim] + cos_selected = cos[position_ids] + sin_selected = sin[position_ids] + else: + # [batch, seq_len] -> [batch, seq_len, rotary_dim] + cos_selected = cos[position_ids] + sin_selected = sin[position_ids] + + # Expand dimensions to match query/key shape + # query/key: [batch, num_heads, seq_len, head_dim] or [seq_len, num_heads, head_dim] + if query.dim() == 4: + # [batch, num_heads, seq_len, head_dim] + # cos_selected: [batch, seq_len, rotary_dim] -> [batch, 1, seq_len, rotary_dim] + cos_selected = cos_selected.unsqueeze(1) + sin_selected = sin_selected.unsqueeze(1) + elif query.dim() == 3: + # [seq_len, num_heads, head_dim] + # cos_selected: [seq_len, rotary_dim] -> [seq_len, 1, rotary_dim] + cos_selected = cos_selected.unsqueeze(1) + sin_selected = sin_selected.unsqueeze(1) + + # Check if we need to repeat cos/sin to match head_dim + rotary_dim = cos_selected.shape[-1] + head_dim = query.shape[-1] + + if rotary_dim != head_dim: + # cos/sin only covers half of head_dim, need to repeat + # This handles the case where rotary is only applied to part of the dimensions + cos_selected = torch.cat([cos_selected, cos_selected], dim=-1) + sin_selected = 
torch.cat([sin_selected, sin_selected], dim=-1)
+
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    if rotary_interleaved:
+        # Interleaved rotary
+        def rotate_interleaved(x):
+            x1 = x[..., ::2]
+            x2 = x[..., 1::2]
+            return torch.stack((-x2, x1), dim=-1).flatten(-2)
+
+        q_embed = (query * cos_selected) + (rotate_interleaved(query) * sin_selected)
+        k_embed = (key * cos_selected) + (rotate_interleaved(key) * sin_selected)
+    else:
+        # Standard rotary (neox style)
+        q_embed = (query * cos_selected) + (rotate_half(query) * sin_selected)
+        k_embed = (key * cos_selected) + (rotate_half(key) * sin_selected)
+
+    return q_embed, k_embed
\ No newline at end of file
diff --git a/vllm_fl/dispatch/backends/vendor/musa/musa.py b/vllm_fl/dispatch/backends/vendor/musa/musa.py
new file mode 100644
index 00000000..5a8ed08f
--- /dev/null
+++ b/vllm_fl/dispatch/backends/vendor/musa/musa.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2026 BAAI. All rights reserved.
+
+"""
+MUSA backend implementation.
+
+This backend provides operator implementations for Musa GPUs.
+Musa uses a CUDA-compatible architecture.
+"""
+
+from __future__ import annotations
+
+from typing import Optional, Union
+
+import torch
+
+from vllm_fl.dispatch.backends.base import Backend
+
+
+class MusaBackend(Backend):
+    """
+    Musa backend for operator implementations.
+
+    This backend uses Musa libraries to provide high-performance
+    operator implementations for Musa GPUs.
+    """
+
+    _available: Optional[bool] = None
+
+    @property
+    def name(self) -> str:
+        return "musa"
+
+    @property
+    def vendor(self) -> Optional[str]:
+        return "musa"
+
+    def is_available(self) -> bool:
+        """
+        Check if Musa hardware and libraries are available.
+
+        This method checks the MUSA runtime to determine
+        whether the device is a MUSA GPU.
+ """ + return torch.musa.is_available() + + # ==================== Operator Implementations ==================== + + def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: + """ + SiLU activation followed by element-wise multiplication. + + Args: + obj: The calling obj (for interface consistency) + x: Input tensor of shape [..., 2*d] + + Returns: + Output tensor of shape [..., d] + """ + from .impl.activation import silu_and_mul_musa + + return silu_and_mul_musa(obj, x) + + def rms_norm( + self, + obj, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + RMS normalization. + + Args: + obj: The calling obj (e.g., RMSNorm layer) + x: Input tensor + residual: Optional residual tensor + + Returns: + Normalized tensor, or tuple of (normalized, residual) if residual is provided + """ + from .impl.normalization import rms_norm_musa + + return rms_norm_musa(obj, x, residual) + + def rotary_embedding( + self, + obj, + query: torch.Tensor, + key: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor, + rotary_interleaved: bool = False, + inplace: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embedding. + + Args: + obj: The calling obj (for interface consistency) + query: Query tensor + key: Key tensor + cos: Cosine cache + sin: Sine cache + position_ids: Position indices + rotary_interleaved: Whether to use interleaved rotary + inplace: Whether to modify tensors in-place + + Returns: + Tuple of (embedded_query, embedded_key) + """ + from .impl.rotary import rotary_embedding_musa + + return rotary_embedding_musa( + obj, + query, + key, + cos, + sin, + position_ids, + rotary_interleaved=rotary_interleaved, + inplace=inplace, + ) + + def attention_backend(self, use_mla: bool = False) -> str: + """ + Get the attention backend class path for Musa. 
+ + Args: + use_mla: Whether to use Multi-head Latent Attention (MLA) + + Returns: + Fully qualified class path string + """ + from vllm.attention.backends.registry import AttentionBackendEnum + + if use_mla: + return AttentionBackendEnum.TRITON_MLA.get_path() + + return AttentionBackendEnum.FLASH_ATTN.get_path() diff --git a/vllm_fl/dispatch/backends/vendor/musa/register_ops.py b/vllm_fl/dispatch/backends/vendor/musa/register_ops.py new file mode 100644 index 00000000..1698aa05 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/musa/register_ops.py @@ -0,0 +1,78 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +MUSA backend operator registrations. + +This module registers all VENDOR (MUSA) implementations. +""" + +from __future__ import annotations + +import functools + +from vllm_fl.dispatch.types import OpImpl, BackendImplKind, BackendPriority + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all MUSA (VENDOR) operator implementations. 
+ + Args: + registry: Registry to register into + """ + from .musa import MusaBackend + + backend = MusaBackend() + is_avail = backend.is_available + + impls = [ + # Activation + OpImpl( + op_name="silu_and_mul", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.silu_and_mul, is_avail), + vendor="musa", + priority=BackendPriority.VENDOR, + ), + # Normalization + OpImpl( + op_name="rms_norm", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rms_norm, is_avail), + vendor="musa", + priority=BackendPriority.VENDOR, + ), + # Rotary Embedding + OpImpl( + op_name="rotary_embedding", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rotary_embedding, is_avail), + vendor="musa", + priority=BackendPriority.VENDOR, + ), + # Attention Backend + OpImpl( + op_name="attention_backend", + impl_id="vendor.musa", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.attention_backend, is_avail), + vendor="musa", + priority=BackendPriority.VENDOR, + ), + ] + + registry.register_many(impls) diff --git a/vllm_fl/dispatch/config/musa.yaml b/vllm_fl/dispatch/config/musa.yaml new file mode 100644 index 00000000..cc493e0c --- /dev/null +++ b/vllm_fl/dispatch/config/musa.yaml @@ -0,0 +1,61 @@ +# vLLM-FL Dispatch Configuration for Musa GPU +# Auto-loaded when running on Musa hardware + +# Preferred default backend type: flagos, vendor, reference +prefer: flagos + +# Strict Mode: +# true = Raise an error immediately on failure; do not attempt other backends. +# false = Attempt the next available backend in sequence upon failure (Default). +strict: false + +# Vendor Whitelist (Optional, allows all if not set) +# allow_vendors: +# - musa + +# Vendor Blacklist (Optional) +# deny_vendors: +# - ascend +# - cuda +# - metax + +# Per-operator backend execution order (Optional) +# Only the backends listed here will be attempted, in the order specified. 
+# +# Supported tokens: +# - flagos : Default FlagGems implementation (Triton) +# - reference : PyTorch reference implementation +# - vendor : Any available vendor backend (auto-detected) +# - vendor:musa : Musa-specific vendor backend +op_backends: + # attention_backends: prioritize vendor, fallback to FlagGems (Triton), then reference, others: prioritize FlagGems (Triton), fallback to vendor, then reference + attention_backend: + - vendor:musa + - flagos + - reference + + rms_norm: + - flagos + - vendor:musa + - reference + + silu_and_mul: + - flagos + - vendor:musa + - reference + + rotary_embedding: + - flagos + - vendor:musa + - reference + +# FlagOS operator blacklist +# Musa is CUDA-compatible, so most FlagGems ops should work +# Blacklist only ops that are known to have issues +flagos_blacklist: + - scaled_dot_product_attention + +# OOT (Out-of-Tree) operator blacklist +# These operators will NOT be registered as OOT replacements. +# oot_blacklist: +# - fused_moe diff --git a/vllm_fl/dispatch/config/utils.py b/vllm_fl/dispatch/config/utils.py index 3877923c..087239a1 100644 --- a/vllm_fl/dispatch/config/utils.py +++ b/vllm_fl/dispatch/config/utils.py @@ -46,7 +46,7 @@ def get_platform_name() -> str: Detect the current hardware platform. 
Returns: - Platform name string: 'ascend', 'iluvatar', 'cuda', or 'unknown' + Platform name string: 'ascend', 'iluvatar', 'musa', 'cuda', or 'unknown' """ try: from vllm.platforms import current_platform diff --git a/vllm_fl/distributed/device_communicators/flagcx.py b/vllm_fl/distributed/device_communicators/flagcx.py index f620166d..922624c3 100644 --- a/vllm_fl/distributed/device_communicators/flagcx.py +++ b/vllm_fl/distributed/device_communicators/flagcx.py @@ -111,9 +111,14 @@ def __init__( assert isinstance(device, torch.device) self.device = device # nccl communicator and stream will use this device - # `torch.cuda.device` is a context manager that changes the - # current cuda device to the specified one - with torch.cuda.device(device): + # `torch.cuda.device` / `torch.musa.device` are context managers that + # change the current device to the specified one + if self.device.type == "musa": + device_ctx = torch.musa.device(self.device) + else: + device_ctx = torch.cuda.device(self.device) + + with device_ctx: self.comm = self.flagcx.flagcxCommInitRank( self.world_size, ctypes.byref(self.unique_id), self.rank) diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index c53e8332..203e7a96 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -67,12 +67,20 @@ def is_cuda_alike(self) -> bool: """Stateless version of [torch.cuda.is_available][].""" if self.vendor_name == "iluvatar": return False + if self.vendor_name == "musa": + return True return self.device_type == "cuda" def is_cuda(self) -> bool: """Stateless version of [torch.cuda.is_available][].""" + if self.vendor_name == "musa": + return True return self.device_type == "cuda" and self.vendor_name == "nvidia" + def is_musa(self) -> bool: + if hasattr(torch, 'musa') and torch.musa.is_available(): + return True + return False @property def supported_dtypes(self) -> list[torch.dtype]: return [torch.bfloat16, torch.float16, torch.float32] @@ -110,7 +118,7 @@ def get_device_name(cls, device_id: int 
= 0) -> str: ### TODO(lms): change pin_memory depend device @classmethod def is_pin_memory_available(cls): - if cls.device_type in ["cuda", "xpu", "npu"]: + if cls.device_type in ["cuda", "xpu", "npu", "musa"]: return True return False @@ -151,6 +159,9 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: if cls.device_type == "npu": cache_config.block_size = 128 logger.info("Setting kv cache block size to 128 for Ascend NPU.") + elif cls.device_type == "musa": + cache_config.block_size = 64 + logger.info("Setting kv cache block size to 64 for MUSA.") else: cache_config.block_size = 16 @@ -365,7 +376,9 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: # TODO(yxa): For NPU/Ascend devices, return None (no capability version like CUDA) if cls.device_type == "npu": return None - # For CUDA devices + if cls.device_type == "musa": + major, minor = torch.musa.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) diff --git a/vllm_fl/utils.py b/vllm_fl/utils.py index afbc553c..73a68b93 100644 --- a/vllm_fl/utils.py +++ b/vllm_fl/utils.py @@ -24,7 +24,7 @@ # Source: runtime platform detection (current_platform.device_name). # # Values are normalized to lowercase and matched against available backend -# subdirectories (for example, cuda/ascend/metax/iluvatar). +# subdirectories (for example, cuda/ascend/metax/iluvatar/mthreads). 
VENDOR_DEVICE_MAP: dict[str, dict[str, str]] = { # Registered backend: vendor/cuda "nvidia": {"device_type": "cuda", "device_name": "nvidia"}, @@ -34,6 +34,8 @@ "iluvatar": {"device_type": "cuda", "device_name": "cuda"}, # Registered backend: vendor/metax "metax": {"device_type": "cuda", "device_name": "metax"}, + # Registered backend: vendor/musa + "mthreads": {"device_type": "musa", "device_name": "musa"}, } @@ -202,7 +204,7 @@ def get_op_config() -> Optional[dict[str, str]]: class DeviceInfo: def __init__(self): self.device = DeviceDetector() - self.supported_device = ["nvidia", "ascend", "metax"] + self.supported_device = ["nvidia", "ascend", "metax", "mthreads"] backend.set_torch_backend_device_fn(self.device.vendor_name) @property From 136a3dc97cb7eaec6056fd7e5d182f83d2b2b2ca Mon Sep 17 00:00:00 2001 From: xiangbin <74005582+li199959@users.noreply.github.com> Date: Fri, 3 Apr 2026 18:50:46 +0800 Subject: [PATCH 6/7] Feat/bge m3 test (#111) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Others Test Case Add e2e precision test for BGE-M3 embedding model covering three pooling modes: - **Dense**: cosine similarity between query and passage embeddings - **Lexical (sparse BM25)**: weighted token overlap score via `/tokenize` + `/pooling` (task=token_classify) - **ColBERT (multi-vector)**: MaxSim score via `/pooling` (task=token_embed) Also removes the outdated vLLM 0.13 backport note from README and updates the BAAI/bge-m3 entry to point to the implementation. 
- `tests/e2e_tests/serving/test_bge_m3.py`: New e2e test (179 lines) — dense, lexical, ColBERT precision validation - `README.md`: Remove vLLM 0.13 backport section; update BAAI/bge-m3 row to link to implementation - `pytest tests/e2e_tests/serving/test_bge_m3.py -v` - Requires server running: `vllm serve BAAI/bge-m3 --hf-overrides '{"architectures":["BgeM3EmbeddingModel"]}'` - [x] I have run the existing tests and they pass - [x] I have added tests for my changes - [] I have updated the documentation --------- Co-authored-by: ceci3 --- .github/scripts/cuda/setup.sh | 0 README.md | 1 + docker/build.sh | 0 tests/e2e_tests/serving/test_bge_m3.py | 222 +++++++++++++++++++++++++ vllm_fl/__init__.py | 9 + vllm_fl/models/__init__.py | 3 + vllm_fl/models/bge_m3.py | 208 +++++++++++++++++++++++ 7 files changed, 443 insertions(+) mode change 100755 => 100644 .github/scripts/cuda/setup.sh mode change 100755 => 100644 docker/build.sh create mode 100644 tests/e2e_tests/serving/test_bge_m3.py create mode 100644 vllm_fl/models/bge_m3.py diff --git a/.github/scripts/cuda/setup.sh b/.github/scripts/cuda/setup.sh old mode 100755 new mode 100644 diff --git a/README.md b/README.md index e4e31094..0e67e552 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ In theory, vllm-plugin-FL can support all models available in vLLM, as long as n | MiniCPM-o 4.5 | Supported | [example](./examples/minicpm/) | | GLM-5 | Supported | [example](./examples/glm_5_offline_inference.py) | | Qwen3.5-35B-A3B | Supported | [example](./examples/glm_5_offline_inference.py) | +| BAAI/bge-m3 | Supported | [implementation](./vllm_fl/models/bge_m3.py) | ### Supported Chips diff --git a/docker/build.sh b/docker/build.sh old mode 100755 new mode 100644 diff --git a/tests/e2e_tests/serving/test_bge_m3.py b/tests/e2e_tests/serving/test_bge_m3.py new file mode 100644 index 00000000..005b2d91 --- /dev/null +++ b/tests/e2e_tests/serving/test_bge_m3.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +""" +BGE-M3 
embedding precision test. + +Server must already be running at localhost:8000, e.g.: + + vllm serve BAAI/bge-m3 --hf-overrides '{"architectures":["BgeM3EmbeddingModel"]}' +""" + +import sys + +import numpy as np +import requests + +BASE_URL = "http://localhost:8000" +MODEL_NAME = "BAAI/bge-m3" + +SENTENCES_1 = ["What is BGE M3?", "Defination of BM25"] +SENTENCES_2 = [ + "BGE M3 is an embedding model supporting dense retrieval, " + "lexical matching and multi-vector interaction.", + "BM25 is a bag-of-words retrieval function that ranks a set " + "of documents based on the query terms appearing in each document", +] + +SIMILARITY_REFERENCE = [[0.6265, 0.3477], [0.3499, 0.678]] +LEXICAL_SCORE_REFERENCE = [0.181622, 0.0] +COLBERT_SCORE_REFERENCE = [0.7797, 0.4620] + +all_passed = True + + +def post(path: str, payload: dict) -> dict: + r = requests.post(f"{BASE_URL}{path}", json=payload, timeout=30) + r.raise_for_status() + return r.json() + + +def cosine_sim(a, b): + na = np.array(a) / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-10) + nb = np.array(b) / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-10) + return (na @ nb.T).tolist() + + +def test_dense(): + global all_passed + print("\n" + "=" * 60) + print("1. 
Dense Embedding (cosine similarity)") + print("=" * 60) + + emb1 = post( + "/v1/embeddings", + {"model": MODEL_NAME, "input": SENTENCES_1}, + )["data"] + emb2 = post( + "/v1/embeddings", + {"model": MODEL_NAME, "input": SENTENCES_2}, + )["data"] + + sim = cosine_sim( + [e["embedding"] for e in emb1], + [e["embedding"] for e in emb2], + ) + + ref = SIMILARITY_REFERENCE + diffs = [[abs(sim[i][j] - ref[i][j]) for j in range(2)] for i in range(2)] + max_diff = max(max(row) for row in diffs) + + print(" similarity matrix:") + print(f" vLLM: {sim[0][0]:.4f}, {sim[0][1]:.4f}") + print(f" {sim[1][0]:.4f}, {sim[1][1]:.4f}") + print(f" reference: {ref[0][0]:.4f}, {ref[0][1]:.4f}") + print(f" {ref[1][0]:.4f}, {ref[1][1]:.4f}") + print(f" max diff: {max_diff:.6f} (tolerance: 0.01)") + + passed = max_diff < 0.01 + print(f" {'✓ PASS' if passed else '✗ FAIL'}") + if not passed: + all_passed = False + return passed + + +def test_lexical(): + global all_passed + print("\n" + "=" * 60) + print("2. Lexical Sparse (BM25-style score)") + print("=" * 60) + + tokens1 = [ + post("/tokenize", {"model": MODEL_NAME, "prompt": s})["tokens"] + for s in SENTENCES_1 + ] + tokens2 = [ + post("/tokenize", {"model": MODEL_NAME, "prompt": s})["tokens"] + for s in SENTENCES_2 + ] + + sparse1 = post( + "/pooling", + {"model": MODEL_NAME, "input": SENTENCES_1, "task": "token_classify"}, + )["data"] + sparse2 = post( + "/pooling", + {"model": MODEL_NAME, "input": SENTENCES_2, "task": "token_classify"}, + )["data"] + + def merge(tokens, vals_per_token): + # vals_per_token: list of [val] (from /pooling data field) + if tokens and tokens[0] == 0: + tokens, vals_per_token = tokens[1:], vals_per_token[1:] + d = {} + for t, v in zip(tokens, vals_per_token): + val = float(v[0]) + if t not in d or val > d[t]: + d[t] = val + return d + + def lexical(a, b): + return sum(w * b[t] for t, w in a.items() if t in b) + + lw1 = [merge(t, s["data"]) for t, s in zip(tokens1, sparse1)] + lw2 = [merge(t, s["data"]) for t, s 
in zip(tokens2, sparse2)] + + score_1_0_x_2_0 = lexical(lw1[0], lw2[0]) + score_1_0_x_1_1 = lexical(lw1[0], lw1[1]) + + diff1 = abs(score_1_0_x_2_0 - LEXICAL_SCORE_REFERENCE[0]) + diff2 = abs(score_1_0_x_1_1 - LEXICAL_SCORE_REFERENCE[1]) + + print( + f" sent1[0] vs sent2[0]: vLLM={score_1_0_x_2_0:.6f} " + f"ref={LEXICAL_SCORE_REFERENCE[0]:.6f} diff={diff1:.6f}" + ) + print( + f" sent1[0] vs sent1[1]: vLLM={score_1_0_x_1_1:.6f} " + f"ref={LEXICAL_SCORE_REFERENCE[1]:.6f} diff={diff2:.6f}" + ) + + passed = diff1 < 0.05 * LEXICAL_SCORE_REFERENCE[0] and diff2 < 0.05 + print(f" {'✓ PASS' if passed else '✗ FAIL'}") + if not passed: + all_passed = False + return passed + + +def test_colbert(): + global all_passed + print("\n" + "=" * 60) + print("3. Multi-Vector ColBERT (MaxSim score)") + print("=" * 60) + + emb1 = post( + "/pooling", + {"model": MODEL_NAME, "input": SENTENCES_1, "task": "token_embed"}, + )["data"] + emb2 = post( + "/pooling", + {"model": MODEL_NAME, "input": SENTENCES_2, "task": "token_embed"}, + )["data"] + + def colbert(q_data, p_data): + # token_embed: data is [[f1,f2,...,f1024], ...] 
— list of token vectors + # Already flat vectors, just convert to numpy + q = np.array(q_data) + p = np.array(p_data) + q_norm = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-10) + p_norm = p / (np.linalg.norm(p, axis=1, keepdims=True) + 1e-10) + scores = q_norm @ p_norm.T # (n_query_tokens, n_passage_tokens) + return float(np.mean(np.max(scores, axis=1))) + + score_1_0_x_2_0 = colbert(emb1[0]["data"], emb2[0]["data"]) + score_1_0_x_2_1 = colbert(emb1[0]["data"], emb2[1]["data"]) + + diff1 = abs(score_1_0_x_2_0 - COLBERT_SCORE_REFERENCE[0]) + diff2 = abs(score_1_0_x_2_1 - COLBERT_SCORE_REFERENCE[1]) + + print( + f" sent1[0] vs sent2[0]: vLLM={score_1_0_x_2_0:.6f} " + f"ref={COLBERT_SCORE_REFERENCE[0]:.6f} diff={diff1:.6f}" + ) + print( + f" sent1[0] vs sent2[1]: vLLM={score_1_0_x_2_1:.6f} " + f"ref={COLBERT_SCORE_REFERENCE[1]:.6f} diff={diff2:.6f}" + ) + + passed = ( + diff1 < 0.01 * COLBERT_SCORE_REFERENCE[0] + and diff2 < 0.01 * COLBERT_SCORE_REFERENCE[1] + ) + print(f" {'✓ PASS' if passed else '✗ FAIL'}") + if not passed: + all_passed = False + return passed + + +if __name__ == "__main__": + print("BGE-M3 Embedding Precision Test") + print(f"Server: {BASE_URL}") + + try: + r = requests.get(f"{BASE_URL}/health", timeout=5) + print(f"Server status: {r.status_code} OK\n") + except Exception: + print("✗ Server is not running at localhost:8000") + print( + " Please start with: vllm serve BAAI/bge-m3 " + "--hf-overrides " + '\'{"architectures":["BgeM3EmbeddingModel"]}\'' + ) + sys.exit(1) + + test_dense() + test_lexical() + test_colbert() + + print("\n" + "=" * 60) + if all_passed: + print("✓ ALL TESTS PASSED") + sys.exit(0) + else: + print("✗ SOME TESTS FAILED") + sys.exit(1) diff --git a/vllm_fl/__init__.py b/vllm_fl/__init__.py index 117a6502..f53973ce 100644 --- a/vllm_fl/__init__.py +++ b/vllm_fl/__init__.py @@ -60,3 +60,12 @@ def register_model(): #glm5_model() except Exception as e: logger.error(f"Register GlmMoeDsa model error: {str(e)}") + + # Register 
BGE-M3 pooling backport for vLLM 0.13.x + try: + ModelRegistry.register_model( + "BgeM3EmbeddingModel", + "vllm_fl.models.bge_m3:BgeM3EmbeddingModel", + ) + except Exception as e: + logger.error(f"Register BgeM3EmbeddingModel error: {str(e)}") diff --git a/vllm_fl/models/__init__.py b/vllm_fl/models/__init__.py index e69de29b..58b12e17 100644 --- a/vllm_fl/models/__init__.py +++ b/vllm_fl/models/__init__.py @@ -0,0 +1,3 @@ +from .bge_m3 import BgeM3EmbeddingModel + +__all__ = ["BgeM3EmbeddingModel"] diff --git a/vllm_fl/models/bge_m3.py b/vllm_fl/models/bge_m3.py new file mode 100644 index 00000000..0b23c6dc --- /dev/null +++ b/vllm_fl/models/bge_m3.py @@ -0,0 +1,208 @@ +import itertools +from collections.abc import Iterable, Set + +import torch +from torch import nn + +from vllm.config import PoolerConfig, VllmConfig, get_current_vllm_config +from vllm.model_executor.layers.pooler import ( + AllPooler, + DispatchPooler, + Pooler, + PoolerNormalize, + PoolingParamsUpdate, + StepPooler, +) +from vllm.model_executor.model_loader import DefaultModelLoader +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.adapters import _load_st_projector +from vllm.model_executor.models.roberta import RobertaEmbeddingModel +from vllm.pooling_params import PoolingParams +from vllm.tasks import PoolingTask +from vllm.v1.outputs import PoolerOutput +from vllm.v1.pool.metadata import PoolingMetadata + + +def filter_secondary_weights( + all_weights: Iterable[tuple[str, torch.Tensor]], + secondary_weight_prefixes: list[str], +) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]: + all_weights_1, all_weights_2 = itertools.tee(all_weights) + + def is_secondary(name: str) -> bool: + return any(name.startswith(prefix) for prefix in secondary_weight_prefixes) + + secondary = ( + (name, weight) for name, weight in all_weights_1 if is_secondary(name) + ) + primary = ( + (name, weight) for name, weight in 
all_weights_2 if not is_secondary(name) + ) + return secondary, primary + + +class TokenEmbeddingProjectionHead(nn.Module): + def __init__(self, projector: nn.Module | None) -> None: + super().__init__() + vllm_config = get_current_vllm_config() + assert vllm_config is not None + + self.projector = _load_st_projector(vllm_config.model_config) + self.token_projector = projector + self.activation = PoolerNormalize() + self.head_dtype = vllm_config.model_config.head_dtype + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"token_embed"} + + def forward( + self, + pooled_data: torch.Tensor | None, + pooling_param: PoolingParams, + ) -> PoolerOutput: + if pooled_data is None: + return None + + pooled_data = pooled_data.to(self.head_dtype) + + if self.projector is not None: + pooled_data = self.projector(pooled_data) + + if self.token_projector is not None: + pooled_data = self.token_projector(pooled_data) + + pooled_data = pooled_data[..., : pooling_param.dimensions] + + if pooling_param.normalize: + pooled_data = self.activation(pooled_data) + + return pooled_data + + +class SpecialTokenFilterPooler(Pooler): + def __init__(self, pooler: Pooler, token_ids_to_skip: list[int | None]) -> None: + super().__init__() + self.pooler = pooler + self.token_ids_to_skip = tuple( + token_id for token_id in token_ids_to_skip if token_id is not None + ) + + def get_supported_tasks(self) -> Set[PoolingTask]: + return self.pooler.get_supported_tasks() + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate(requires_token_ids=True) + + def _filter_one( + self, + data: torch.Tensor | None, + token_ids: torch.Tensor, + ) -> torch.Tensor | None: + if data is None: + return None + + keep_mask = torch.ones_like(token_ids, dtype=torch.bool) + for token_id in self.token_ids_to_skip: + keep_mask &= token_ids != token_id + return data[keep_mask] + + def forward( + self, + hidden_states: torch.Tensor | list[torch.Tensor], + 
pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + outputs = self.pooler(hidden_states, pooling_metadata) + prompt_token_ids = pooling_metadata.get_prompt_token_ids() + return [ + self._filter_one(output, token_ids) + for output, token_ids in zip(outputs, prompt_token_ids) + ] + + +class BgeM3EmbeddingModel(RobertaEmbeddingModel): + """Backport the vLLM 0.15 BGE-M3 embedding adapter to vLLM 0.13.x.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + self.hidden_size = vllm_config.model_config.hf_config.hidden_size + + model_config = vllm_config.model_config + self.head_dtype = model_config.head_dtype + self.bos_token_id = model_config.hf_config.bos_token_id + self.eos_token_id = model_config.hf_config.eos_token_id + + super().__init__(vllm_config=vllm_config, prefix=prefix) + + self.secondary_weight_prefixes = ["sparse_linear.", "colbert_linear."] + self.secondary_weight_files = [ + weight_prefix + "pt" for weight_prefix in self.secondary_weight_prefixes + ] + self.secondary_weights = [ + DefaultModelLoader.Source( + model_or_path=vllm_config.model_config.model, + revision=None, + prefix=weight_prefix, + allow_patterns_overrides=[filename], + ) + for filename, weight_prefix in zip( + self.secondary_weight_files, + self.secondary_weight_prefixes, + ) + ] + + def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: + self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype) + self.colbert_linear = nn.Linear( + self.hidden_size, + self.hidden_size, + dtype=self.head_dtype, + ) + + token_embed_pooler = Pooler.for_token_embed(pooler_config) + if isinstance(token_embed_pooler, AllPooler): + token_embed_pooler = AllPooler( + head=TokenEmbeddingProjectionHead(self.colbert_linear) + ) + elif isinstance(token_embed_pooler, StepPooler): + token_embed_pooler = StepPooler( + head=TokenEmbeddingProjectionHead(self.colbert_linear) + ) + else: + raise TypeError( + f"Unsupported token_embed pooler: {type(token_embed_pooler)}" + 
) + + token_classify_pooler = Pooler.for_token_classify( + pooler_config, + classifier=self.sparse_linear, + act_fn=torch.relu, + ) + + return DispatchPooler( + { + "embed": Pooler.for_embed(pooler_config), + "token_embed": SpecialTokenFilterPooler( + token_embed_pooler, + [self.bos_token_id], + ), + "token_classify": SpecialTokenFilterPooler( + token_classify_pooler, + [self.bos_token_id, self.eos_token_id], + ), + } + ) + + def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]): + secondary_weights, primary_weights = filter_secondary_weights( + all_weights, + self.secondary_weight_prefixes, + ) + + super().load_weights(primary_weights) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in secondary_weights: + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) From 838341f74ff18e5e7a96690b86b8904155f679e0 Mon Sep 17 00:00:00 2001 From: XMing Date: Mon, 6 Apr 2026 16:16:17 +0800 Subject: [PATCH 7/7] Bump vLLM version to v0.18.0 in ci image (#117) ### PR Category CICD ### PR Type CI images and CI workflow configs ### Description Catch up the versions of vLLM ### Related Issues No ### Changes - dockerfiles - ci workflow config ### Testing The workflow should be broken till core changes made in vllm-fl ### Checklist --- .github/configs/ascend.yml | 2 +- .github/configs/cuda.yml | 2 +- .github/scripts/ascend/setup.sh | 2 + .github/scripts/cuda/setup.sh | 2 + .github/workflows/ci.yml | 4 + docker/ascend/Dockerfile | 19 ++-- docker/ascend/Dockerfile.v0.1.0 | 79 +++++++++++++++ docker/build.sh | 4 +- docker/cuda/Dockerfile | 15 ++- pyproject.toml | 2 +- tests/models/qwen3/next_tp8.yaml | 6 +- tests/run.py | 21 +++- tests/utils/cleanup.py | 161 +++++++++++++++++++++++++++++++ 13 files changed, 289 insertions(+), 30 deletions(-) create mode 100644 docker/ascend/Dockerfile.v0.1.0 diff --git
a/.github/configs/ascend.yml b/.github/configs/ascend.yml index fd1e9fee..5a3cc7f4 100644 --- a/.github/configs/ascend.yml +++ b/.github/configs/ascend.yml @@ -4,7 +4,7 @@ platform: ascend # Docker image for this hardware -ci_image: harbor.baai.ac.cn/flagscale/vllm-plugin-fl:v0.1.0-ascend-ci +ci_image: harbor.baai.ac.cn/flagscale/vllm-plugin-fl:v0.2.0-ascend-ci # Runner labels for this hardware runner_labels: diff --git a/.github/configs/cuda.yml b/.github/configs/cuda.yml index e4650af5..78d8e3b3 100644 --- a/.github/configs/cuda.yml +++ b/.github/configs/cuda.yml @@ -4,7 +4,7 @@ platform: cuda # Docker image for this hardware -ci_image: harbor.baai.ac.cn/flagscale/vllm-plugin-fl:v0.1.0-cuda-ci +ci_image: harbor.baai.ac.cn/flagscale/vllm-plugin-fl:v0.2.0-cuda-ci # Runner labels for this hardware runner_labels: diff --git a/.github/scripts/ascend/setup.sh b/.github/scripts/ascend/setup.sh index dc1119c4..dc9b6ffa 100644 --- a/.github/scripts/ascend/setup.sh +++ b/.github/scripts/ascend/setup.sh @@ -3,5 +3,7 @@ # Setup script for Ascend NPU CI environment. set -euo pipefail +git config --global --add safe.directory "$(pwd)" + pip install --upgrade pip "setuptools>=77.0.3" pip install --no-build-isolation -e ".[test]" diff --git a/.github/scripts/cuda/setup.sh b/.github/scripts/cuda/setup.sh index f22faa3d..86a4f345 100644 --- a/.github/scripts/cuda/setup.sh +++ b/.github/scripts/cuda/setup.sh @@ -3,5 +3,7 @@ # Setup script for CUDA CI environment. 
set -euo pipefail +git config --global --add safe.directory "$(pwd)" + uv pip install --system --upgrade pip uv pip install --system --no-build-isolation -e ".[test]" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 153ac43f..2d996e79 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,6 +9,8 @@ on: paths-ignore: - "**.md" - "docs/**" + - "examples/**" + - "docker/**" - "LICENSE" - ".github/ISSUE_TEMPLATE/**" - ".github/PULL_REQUEST_TEMPLATE.md" @@ -17,6 +19,8 @@ on: paths-ignore: - "**.md" - "docs/**" + - "examples/**" + - "docker/**" - "LICENSE" - ".github/ISSUE_TEMPLATE/**" - ".github/PULL_REQUEST_TEMPLATE.md" diff --git a/docker/ascend/Dockerfile b/docker/ascend/Dockerfile index 9d10abcf..17ec33f4 100644 --- a/docker/ascend/Dockerfile +++ b/docker/ascend/Dockerfile @@ -1,20 +1,12 @@ -ARG VLLM_VERSION=0.13.0 +ARG VLLM_VERSION=0.18.0 # ---------- base stage ---------- FROM quay.io/ascend/vllm-ascend:v${VLLM_VERSION}rc1-a3 AS base RUN pip install --upgrade pip setuptools -# CANN Toolkit environment variables (mirrors set_env.sh baked in at build time) -ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest -ENV LD_LIBRARY_PATH="${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling/lib/linux/aarch64:${LD_LIBRARY_PATH}" \ - PYTHONPATH="${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH}" \ - PATH="${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${ASCEND_TOOLKIT_HOME}/tools/ccec_compiler/bin:${PATH}" - -# Set ATB environment variables -ENV ATB_HOME_PATH=/usr/local/Ascend/nnal/atb/latest/atb/cxx_abi_1 -ENV LD_LIBRARY_PATH="${ATB_HOME_PATH}/lib:${ATB_HOME_PATH}/examples:${ATB_HOME_PATH}/tests/atbopstest:${LD_LIBRARY_PATH}" \ - PATH="${ATB_HOME_PATH}/bin:${PATH}" +# Add BiShengIR compiler to PATH 
+ENV PATH="${ASCEND_TOOLKIT_HOME}/tools/bishengir/bin:${PATH}" # ---------- dev stage ---------- FROM base AS dev @@ -49,8 +41,9 @@ RUN pip install \ cmake # Install FlagGems (NPU backend) +ARG FLAGGEMS_VERSION=v5.0.0 RUN pip install -U scikit-build-core==0.11 pybind11 \ - && git clone https://github.com/flagos-ai/FlagGems /workspace/FlagGems \ + && git clone --branch ${FLAGGEMS_VERSION} --depth 1 https://github.com/flagos-ai/FlagGems /workspace/FlagGems \ && pip install --no-build-isolation \ --config-settings=cmake.define.FLAGGEMS_BACKEND=NPU \ /workspace/FlagGems @@ -71,7 +64,7 @@ FROM base AS release ARG INDEX_URL ARG EXTRA_INDEX_URL -ARG VLLM_VERSION=0.13.0 +ARG VLLM_VERSION=0.18.0 # Install vLLM # Todo diff --git a/docker/ascend/Dockerfile.v0.1.0 b/docker/ascend/Dockerfile.v0.1.0 new file mode 100644 index 00000000..9d10abcf --- /dev/null +++ b/docker/ascend/Dockerfile.v0.1.0 @@ -0,0 +1,79 @@ +ARG VLLM_VERSION=0.13.0 + +# ---------- base stage ---------- +FROM quay.io/ascend/vllm-ascend:v${VLLM_VERSION}rc1-a3 AS base + +RUN pip install --upgrade pip setuptools + +# CANN Toolkit environment variables (mirrors set_env.sh baked in at build time) +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LD_LIBRARY_PATH="${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling/lib/linux/aarch64:${LD_LIBRARY_PATH}" \ + PYTHONPATH="${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH}" \ + PATH="${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${ASCEND_TOOLKIT_HOME}/tools/ccec_compiler/bin:${PATH}" + +# Set ATB environment variables +ENV ATB_HOME_PATH=/usr/local/Ascend/nnal/atb/latest/atb/cxx_abi_1 +ENV LD_LIBRARY_PATH="${ATB_HOME_PATH}/lib:${ATB_HOME_PATH}/examples:${ATB_HOME_PATH}/tests/atbopstest:${LD_LIBRARY_PATH}" \ + 
PATH="${ATB_HOME_PATH}/bin:${PATH}" + +# ---------- dev stage ---------- +FROM base AS dev + +# Install dev tools +RUN pip install \ + pytest \ + pytest-cov \ + pytest-json-report \ + ruff \ + pre-commit \ + ninja \ + cmake + +# ---------- ci stage ---------- +FROM base AS ci + +# Install dev/test tools +RUN pip install --upgrade pip +RUN pip install \ + pytest \ + pytest-cov \ + pytest-timeout \ + pytest-json-report \ + numpy \ + requests \ + decorator \ + "modelscope>=1.18.1" \ + ruff \ + pre-commit \ + ninja \ + cmake + +# Install FlagGems (NPU backend) +RUN pip install -U scikit-build-core==0.11 pybind11 \ + && git clone https://github.com/flagos-ai/FlagGems /workspace/FlagGems \ + && pip install --no-build-isolation \ + --config-settings=cmake.define.FLAGGEMS_BACKEND=NPU \ + /workspace/FlagGems + +# Install FlagTree +RUN pip install flagtree==0.4.0+ascend3.2 \ + --index-url=https://resource.flagos.net/repository/flagos-pypi-hosted/simple \ + --trusted-host=resource.flagos.net + +# Set environment variables for vLLM and Triton +ENV VLLM_PLUGINS=fl +ENV TRITON_ALL_BLOCKS_PARALLEL=1 + +WORKDIR /workspace + +# ---------- release stage ---------- +FROM base AS release + +ARG INDEX_URL +ARG EXTRA_INDEX_URL +ARG VLLM_VERSION=0.13.0 + +# Install vLLM +# Todo + +WORKDIR /workspace diff --git a/docker/build.sh b/docker/build.sh index f79c05cd..ad92b88d 100644 --- a/docker/build.sh +++ b/docker/build.sh @@ -14,12 +14,12 @@ PYTHON_VERSION="${PYTHON_VERSION:-3.12}" UV_VERSION="${UV_VERSION:-0.7.12}" CUDA_VERSION="${CUDA_VERSION:-12.8.1}" UBUNTU_VERSION="${UBUNTU_VERSION:-22.04}" -VLLM_VERSION="${VLLM_VERSION:-0.13.0}" +VLLM_VERSION="${VLLM_VERSION:-0.18.0}" # ---- Build options ---- PLATFORM="${PLATFORM:-cuda}" TARGET="dev" -IMAGE_NAME="localhost:5000/vllm-plugin-fl" +IMAGE_NAME="harbor.baai.ac.cn/flagscale/vllm-plugin-fl" IMAGE_TAG="" INDEX_URL="${INDEX_URL:-}" EXTRA_INDEX_URL="${EXTRA_INDEX_URL:-}" diff --git a/docker/cuda/Dockerfile b/docker/cuda/Dockerfile index 
8cea7c94..177d696d 100644 --- a/docker/cuda/Dockerfile +++ b/docker/cuda/Dockerfile @@ -6,7 +6,7 @@ FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} AS base ARG PYTHON_VERSION=3.12 ARG UV_VERSION=0.7.12 -ARG VLLM_VERSION=0.13.0 +ARG VLLM_VERSION=0.18.0 ENV DEBIAN_FRONTEND=noninteractive @@ -47,7 +47,7 @@ FROM base AS dev ARG INDEX_URL ARG EXTRA_INDEX_URL -ARG VLLM_VERSION=0.13.0 +ARG VLLM_VERSION=0.18.0 # Install vLLM RUN uv pip install --system \ @@ -70,7 +70,7 @@ FROM base AS ci ARG INDEX_URL ARG EXTRA_INDEX_URL -ARG VLLM_VERSION=0.13.0 +ARG VLLM_VERSION=0.18.0 # Install vLLM RUN uv pip install --system \ @@ -92,18 +92,17 @@ RUN uv pip install --system \ pre-commit \ ninja \ cmake - +ARG FLAGGEMS_VERSION=v5.0.0 ARG FLAGCX_VERSION=v0.9.0 # Install FlagGems RUN uv pip install --system scikit-build-core==0.11 pybind11 \ - && git clone https://github.com/flagos-ai/FlagGems /workspace/FlagGems \ + && git clone --branch ${FLAGGEMS_VERSION} --depth 1 https://github.com/flagos-ai/FlagGems /workspace/FlagGems \ && uv pip install --system --no-build-isolation /workspace/FlagGems # Install FlagCX (NVIDIA) -RUN git clone https://github.com/flagos-ai/FlagCX.git /workspace/FlagCX \ +RUN git clone --branch ${FLAGCX_VERSION} --depth 1 https://github.com/flagos-ai/FlagCX.git /workspace/FlagCX \ && cd /workspace/FlagCX \ - && git checkout ${FLAGCX_VERSION} \ && git submodule update --init --recursive \ && make USE_NVIDIA=1 \ && cd plugin/torch \ @@ -125,7 +124,7 @@ FROM base AS release ARG INDEX_URL ARG EXTRA_INDEX_URL -ARG VLLM_VERSION=0.13.0 +ARG VLLM_VERSION=0.18.0 # Install vLLM RUN uv pip install --system \ diff --git a/pyproject.toml b/pyproject.toml index f9766d71..e09f006d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ test = [ "requests", "openai", "decorator", - "vllm[audio]==0.13.0", + "vllm[audio]==0.18.0", "modelscope>=1.18.1", ] diff --git a/tests/models/qwen3/next_tp8.yaml b/tests/models/qwen3/next_tp8.yaml index 
0352ab86..6c11d026 100644 --- a/tests/models/qwen3/next_tp8.yaml +++ b/tests/models/qwen3/next_tp8.yaml @@ -3,10 +3,10 @@ llm: model: "/data/models/Qwen/Qwen3-Next-80B-A3B-Instruct" tensor_parallel_size: 8 - max_model_len: 16384 - max_num_batched_tokens: 16384 + max_model_len: 8192 + max_num_batched_tokens: 8192 max_num_seqs: 512 - gpu_memory_utilization: 0.7 + gpu_memory_utilization: 0.8 enforce_eager: true trust_remote_code: true diff --git a/tests/run.py b/tests/run.py index 1e4c2029..81287692 100644 --- a/tests/run.py +++ b/tests/run.py @@ -47,7 +47,8 @@ _REPO_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(_REPO_ROOT)) -from tests.utils.cleanup import device_cleanup +from tests.utils.cleanup import device_cleanup, wait_for_memory +from tests.utils.model_config import ModelConfig from tests.utils.platform_config import PlatformConfig from tests.utils.report import TestReport, TestResult @@ -351,6 +352,24 @@ def _run_single(self, tc: TestCase) -> TestResult: message="dry-run", ) + # Wait for sufficient device memory before e2e tests + if tc.task in ("inference", "serving") and tc.model and tc.case: + gpu_util = ModelConfig.load(tc.model, tc.case).engine.get( + "gpu_memory_utilization", 0.9 + ) + ok, info = wait_for_memory(self.config.platform, gpu_util) + if not ok: + print("[run] FAILED: timed out waiting for device memory") + return TestResult( + name=tc.name, + passed=False, + duration=0.0, + message=f"OOM: timed out waiting for device memory\n{info}", + task=tc.task, + model=tc.model, + case=tc.case, + ) + # Merge extra env vars (e.g. 
FL_TEST_MODEL/FL_TEST_CASE for inference) env = None if tc.extra_env: diff --git a/tests/utils/cleanup.py b/tests/utils/cleanup.py index 6bbe0b28..ae01a517 100644 --- a/tests/utils/cleanup.py +++ b/tests/utils/cleanup.py @@ -22,6 +22,7 @@ import signal import subprocess import time +from collections.abc import Callable def device_cleanup(platform: str, wait: float = 3.0) -> None: @@ -37,6 +38,10 @@ def device_cleanup(platform: str, wait: float = 3.0) -> None: """ _kill_stale_processes() + # Clear framework cache to reclaim memory held by PyTorch allocator + cache_fn = _PLATFORM_CACHE_CLEAR.get(platform, _cache_clear_noop) + cache_fn() + if wait > 0: time.sleep(wait) @@ -134,3 +139,159 @@ def _cleanup_noop() -> None: "cuda": _cleanup_cuda, "ascend": _cleanup_ascend, } + + +# --------------------------------------------------------------------------- +# Memory info (platform-specific) +# --------------------------------------------------------------------------- + + +def _mem_info_cuda() -> list[tuple[int, int]]: + """Return [(free_bytes, total_bytes), ...] for each CUDA device.""" + import torch + + result = [] + for i in range(torch.cuda.device_count()): + free, total = torch.cuda.mem_get_info(i) + result.append((free, total)) + return result + + +def _mem_info_ascend() -> list[tuple[int, int]]: + """Return [(free_bytes, total_bytes), ...] 
for each Ascend NPU.""" + import torch + + try: + import torch_npu # noqa: F401 + + result = [] + for i in range(torch.npu.device_count()): + free, total = torch.npu.mem_get_info(i) + result.append((free, total)) + return result + except (ImportError, AttributeError): + return [] + + +def _mem_info_noop() -> list[tuple[int, int]]: + return [] + + +_PLATFORM_MEMORY_INFO: dict[str, Callable[[], list[tuple[int, int]]]] = { + "cuda": _mem_info_cuda, + "ascend": _mem_info_ascend, +} + + +# --------------------------------------------------------------------------- +# Cache clear (platform-specific) +# --------------------------------------------------------------------------- + + +def _cache_clear_cuda() -> None: + """Clear PyTorch CUDA cache.""" + import torch + + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +def _cache_clear_ascend() -> None: + """Clear PyTorch NPU cache.""" + try: + import torch + import torch_npu # noqa: F401 + + torch.npu.empty_cache() + except (ImportError, AttributeError): + pass + + +def _cache_clear_noop() -> None: + pass + + +_PLATFORM_CACHE_CLEAR: dict[str, Callable[[], None]] = { + "cuda": _cache_clear_cuda, + "ascend": _cache_clear_ascend, +} + + +# --------------------------------------------------------------------------- +# Public memory API +# --------------------------------------------------------------------------- + + +def get_device_memory(platform: str) -> list[tuple[float, float]]: + """Return [(free_mb, total_mb), ...] for each device on the platform.""" + mem_fn = _PLATFORM_MEMORY_INFO.get(platform, _mem_info_noop) + return [(free / (1024 * 1024), total / (1024 * 1024)) for free, total in mem_fn()] + + +def wait_for_memory( + platform: str, + gpu_memory_utilization: float = 0.9, + timeout: int = 1800, + interval: int = 30, +) -> tuple[bool, str]: + """Wait until devices have enough free memory for the given utilization. + + Args: + platform: Platform name (e.g. ``"cuda"``, ``"ascend"``). 
+ gpu_memory_utilization: Fraction of total memory the model needs. + timeout: Maximum seconds to wait (default 30 min). + interval: Seconds between polls. + + Returns: + ``(True, info)`` if memory is available, ``(False, info)`` on timeout. + """ + mem_fn = _PLATFORM_MEMORY_INFO.get(platform, _mem_info_noop) + cache_fn = _PLATFORM_CACHE_CLEAR.get(platform, _cache_clear_noop) + + deadline = time.time() + timeout + attempt = 0 + + while True: + attempt += 1 + + # Kill stale vllm processes from previous e2e tests + _kill_stale_processes() + # Clear framework cache + cache_fn() + # Brief pause for resources to be released + time.sleep(1) + + mem_info = mem_fn() + if not mem_info: + return (True, "no devices detected, skipping memory check") + + # Check each device + all_ok = True + lines = [] + for i, (free, total) in enumerate(mem_info): + required = total * gpu_memory_utilization + free_mb = free / (1024 * 1024) + total_mb = total / (1024 * 1024) + required_mb = required / (1024 * 1024) + ok = free >= required + status = "OK" if ok else "WAIT" + lines.append( + f" Device {i}: {free_mb:.0f}/{total_mb:.0f} MiB free, " + f"need {required_mb:.0f} MiB ({gpu_memory_utilization:.0%}) [{status}]" + ) + if not ok: + all_ok = False + + info = "\n".join(lines) + print( + f"[memory] Attempt {attempt} (util={gpu_memory_utilization:.0%}):\n{info}" + ) + + if all_ok: + return (True, info) + + if time.time() >= deadline: + return (False, info) + + print(f"[memory] Waiting {interval}s for memory to free up...") + time.sleep(interval)