diff --git a/README.md b/README.md
index 13c1ac75..0e67e552 100644
--- a/README.md
+++ b/README.md
@@ -35,7 +35,7 @@ In theory, vllm-plugin-FL can support all models available in vLLM, as long as n
### Setup
-1. Install vllm from the official [v0.13.0](https://github.com/vllm-project/vllm/tree/v0.13.0) (optional if the correct version is installed) or from the fork [vllm-FL](https://github.com/flagos-ai/vllm-FL).
+1. Install vllm from the official [v0.18.1](https://github.com/vllm-project/vllm/tree/v0.18.1) (optional if the correct version is installed) or from the fork [vllm-FL](https://github.com/flagos-ai/vllm-FL).
2. Install vllm-plugin-FL
@@ -66,6 +66,7 @@ In theory, vllm-plugin-FL can support all models available in vLLM, as long as n
```sh
git clone https://github.com/flagos-ai/FlagGems
 cd FlagGems
+  git checkout v5.0.0
pip install --no-build-isolation .
# or editble install
diff --git a/tests/unit_tests/ops/test_layernorm.py b/tests/unit_tests/ops/test_layernorm.py
index 4119a3d0..04acfe5a 100644
--- a/tests/unit_tests/ops/test_layernorm.py
+++ b/tests/unit_tests/ops/test_layernorm.py
@@ -13,6 +13,13 @@
 class TestRMSNormFL:
     """Test RMSNormFL class behavior."""
+    @pytest.fixture(autouse=True)
+    def _set_vllm_config(self):
+        from vllm.config import VllmConfig, set_current_vllm_config
+
+        # set_current_vllm_config is a context manager; keep it active for the test.
+        with set_current_vllm_config(VllmConfig()):
+            yield
@pytest.fixture
def mock_call_op(self):
with patch("vllm_fl.ops.layernorm.call_op") as mock:
diff --git a/tests/unit_tests/worker/test_model_runner.py b/tests/unit_tests/worker/test_model_runner.py
index 2e844470..d1e8f998 100644
--- a/tests/unit_tests/worker/test_model_runner.py
+++ b/tests/unit_tests/worker/test_model_runner.py
@@ -60,6 +60,7 @@ def test_fields_match_expected_contract(self):
"aux_hidden_states",
"ec_connector_output",
"cudagraph_stats",
+ "slot_mappings",
)
assert ExecuteModelState._fields == expected_fields, (
"ExecuteModelState fields changed - this may break execute_model consumers"
@@ -79,6 +80,7 @@ def test_immutability_prevents_accidental_mutation(self):
aux_hidden_states=None,
ec_connector_output=None,
cudagraph_stats=None,
+ slot_mappings=None,
)
with pytest.raises(AttributeError):
@@ -101,6 +103,7 @@ def test_unpacking_for_downstream_processing(self):
aux_hidden_states=None,
ec_connector_output=None,
cudagraph_stats=None,
+ slot_mappings=None,
)
# Simulate downstream unpacking
diff --git a/vllm_fl/__init__.py b/vllm_fl/__init__.py
index cb612ed9..f53973ce 100644
--- a/vllm_fl/__init__.py
+++ b/vllm_fl/__init__.py
@@ -45,75 +45,19 @@ def register():
def register_model():
- """Register the FL model."""
- from vllm import ModelRegistry
- import vllm.model_executor.models.qwen3_next as qwen3_next_module
+ """Register FL-specific models not yet upstream."""
+ # Models now upstream in vLLM v0.18.1 (no longer need plugin registration):
+ # Qwen3NextForCausalLM, Qwen3_5MoeForConditionalGeneration,
+ # MiniCPMO, KimiK25ForConditionalGeneration, Qwen3_5MoeConfig
- # Register Qwen3.5 MoE config
- try:
- from vllm.transformers_utils.config import _CONFIG_REGISTRY
- from vllm_fl.configs.qwen3_5_moe import Qwen3_5MoeConfig
- _CONFIG_REGISTRY["qwen3_5_moe"] = Qwen3_5MoeConfig
- except Exception as e:
- logger.error(f"Register Qwen3.5 MoE config error: {str(e)}")
-
- # Register Qwen3Next model
- try:
- from vllm_fl.models.qwen3_next import Qwen3NextForCausalLM # noqa: F401
-
- qwen3_next_module.Qwen3NextForCausalLM = Qwen3NextForCausalLM
- logger.warning(
- "Qwen3NextForCausalLM has been patched to use vllm_fl.models.qwen3_next, "
- "original vLLM implementation is overridden"
- )
-
- ModelRegistry.register_model(
- "Qwen3NextForCausalLM",
- "vllm_fl.models.qwen3_next:Qwen3NextForCausalLM"
- )
- except Exception as e:
- logger.error(f"Register Qwen3Next model error: {str(e)}")
-
- # Register Qwen3.5 MoE model
- try:
- ModelRegistry.register_model(
- "Qwen3_5MoeForConditionalGeneration",
- "vllm_fl.models.qwen3_5:Qwen3_5MoeForConditionalGeneration"
- )
- except Exception as e:
- logger.error(f"Register Qwen3.5 MoE model error: {str(e)}")
-
- # Register MiniCPMO model
- try:
- ModelRegistry.register_model(
- "MiniCPMO",
- "vllm_fl.models.minicpmo:MiniCPMO"
- )
- except Exception as e:
- logger.error(f"Register MiniCPMO model error: {str(e)}")
-
- # Register Kimi-K2.5 model
- try:
- ModelRegistry.register_model(
- "KimiK25ForConditionalGeneration",
- "vllm_fl.models.kimi_k25:KimiK25ForConditionalGeneration",
- )
- except Exception as e:
- logger.error(f"Register KimiK25 model error: {str(e)}")
-
- # Register GLM-5 (GlmMoeDsa) model
+ # Register GLM-5 (GlmMoeDsa) — config not yet upstream
try:
from vllm.transformers_utils.config import _CONFIG_REGISTRY
from vllm_fl.configs.glm_moe_dsa import GlmMoeDsaConfig
_CONFIG_REGISTRY["glm_moe_dsa"] = GlmMoeDsaConfig
- from vllm_fl.patches.glm_moe_dsa import apply_model_patches as glm5_model
- glm5_model()
-
- ModelRegistry.register_model(
- "GlmMoeDsaForCausalLM",
- "vllm_fl.models.glm_moe_dsa:GlmMoeDsaForCausalLM"
- )
+        # TODO(review): glm_moe_dsa model patches and ModelRegistry registration
+        # were disabled during the v0.18.1 migration; re-enable or remove after verification.
except Exception as e:
logger.error(f"Register GlmMoeDsa model error: {str(e)}")
diff --git a/vllm_fl/attention/utils.py b/vllm_fl/attention/utils.py
index 642e7800..88982dfc 100644
--- a/vllm_fl/attention/utils.py
+++ b/vllm_fl/attention/utils.py
@@ -13,8 +13,8 @@ def patch_mm_encoder_attention():
FLASH_ATTN branch to import directly from vllm.vllm_flash_attn with a
fallback to flash_attn.
"""
- import vllm.attention.layers.mm_encoder_attention as mm_mod
- from vllm.attention.backends.registry import AttentionBackendEnum
+ import vllm.model_executor.layers.attention.mm_encoder_attention as mm_mod
+ from vllm.v1.attention.backends.registry import AttentionBackendEnum
def _patched_maybe_get_vit_flash_attn_backend(attn_backend):
if attn_backend == AttentionBackendEnum.FLASH_ATTN:
diff --git a/vllm_fl/compilation/graph.py b/vllm_fl/compilation/graph.py
index 2ff4eb8e..111aa9d6 100644
--- a/vllm_fl/compilation/graph.py
+++ b/vllm_fl/compilation/graph.py
@@ -1,14 +1,15 @@
# Copyright (c) 2025 BAAI. All rights reserved.
-# Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/compilation/cuda_graph.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/compilation/cuda_graph.py
# Below is the original copyright:
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
+import weakref
from collections import Counter
from collections.abc import Callable
from contextlib import ExitStack
-from typing import Any, Optional
+from typing import Any, ClassVar
from unittest.mock import patch
import torch
@@ -18,13 +19,18 @@
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
-from vllm.forward_context import BatchDescriptor, get_forward_context
+from vllm.forward_context import (
+ BatchDescriptor,
+ get_forward_context,
+ is_forward_context_available,
+)
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
+# FL-specific: platform-agnostic weak_ref_tensors
def weak_ref_tensors(tensor: Any) -> Any:
if current_platform.device_type == "cuda":
from vllm.utils.torch_utils import weak_ref_tensors
@@ -34,6 +40,7 @@ def weak_ref_tensors(tensor: Any) -> Any:
return tensor
+# FL-specific: platform-agnostic graph class selection
class Graph:
if current_platform.device_type == "cuda":
graph = torch.cuda.CUDAGraph
@@ -44,15 +51,21 @@ class Graph:
else:
raise NotImplementedError("not support graph")
+
+# Re-export CUDAGraphStat for compatibility
+from vllm.compilation.cuda_graph import CUDAGraphStat # noqa: F401, E402
+
+
@dataclasses.dataclass
class GraphEntry:
batch_descriptor: BatchDescriptor
- graph: Optional[Graph] = None
- output: Optional[Any] = None
+ graph: Any | None = None
+ output: Any | None = None
# for graph debugging, track the input addresses
# during capture, and check if they are the same during replay
- input_addresses: Optional[list[int]] = None
+ input_addresses: list[int] | None = None
+
@dataclasses.dataclass
class GraphOptions:
@@ -62,11 +75,22 @@ class GraphOptions:
class GraphWrapper:
+ """FL-specific graph wrapper that supports multiple device types (CUDA, NPU).
+ Adapted from upstream CUDAGraphWrapper with platform-agnostic graph capture."""
+
+ _all_instances: ClassVar[weakref.WeakSet["GraphWrapper"]] = weakref.WeakSet()
+
+ @classmethod
+ def clear_all_graphs(cls) -> None:
+ """Clear captured graphs from all GraphWrapper instances."""
+ for instance in list(cls._all_instances):
+ instance.clear_graphs()
+
def __init__(self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
- cudagraph_options: Optional[GraphOptions] = None):
+ cudagraph_options: GraphOptions | None = None):
self.runnable = runnable
self.vllm_config = vllm_config
self.runtime_mode = runtime_mode
@@ -74,9 +98,10 @@ def __init__(self,
self.first_run_finished = False
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
+ self._runnable_str = str(runnable) if self.is_debugging_mode else None
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
- # need to initialize a CUDAGraphWrapper.
+ # need to initialize a GraphWrapper.
assert self.runtime_mode != CUDAGraphMode.NONE
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
@@ -85,25 +110,41 @@ def __init__(self,
if cudagraph_options is None:
cudagraph_options = GraphOptions()
- self.graph_options = cudagraph_options
+ self.cudagraph_options = cudagraph_options
# the entries for different batch descriptors that we need to capture
# cudagraphs for.
self.concrete_graph_entries: dict[BatchDescriptor, GraphEntry] = {}
- def __getattr__(self, key: str):
+ GraphWrapper._all_instances.add(self)
+
+ def __getattr__(self, key: str) -> Any:
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
- raise AttributeError(
- f"Attribute {key} not exists in the runnable of "
- f"cudagraph wrapper: {self.runnable}"
- )
+ if self.is_debugging_mode:
+ raise AttributeError(
+ f"Attribute {key} not exists in the runnable of "
+ f"cudagraph wrapper: {self._runnable_str}"
+ )
+        raise AttributeError(key)
def unwrap(self) -> Callable:
# in case we need to access the original runnable.
return self.runnable
+ @property
+ def cudagraph_wrapper(self) -> "GraphWrapper":
+ return self
+
+ def clear_graphs(self) -> None:
+ self.concrete_graph_entries.clear()
+
def __call__(self, *args, **kwargs):
+ if not is_forward_context_available():
+ # No forward context means we are outside the normal
+ # inference path (e.g. a vision encoder forward pass).
+ return self.runnable(*args, **kwargs)
+
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
graph_runtime_mode = forward_context.cudagraph_runtime_mode
@@ -112,14 +153,9 @@ def __call__(self, *args, **kwargs):
graph_runtime_mode == CUDAGraphMode.NONE
or graph_runtime_mode != self.runtime_mode
):
- # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
- # running without cudagraphs.
- # We do not trigger capture/replay if the runtime mode is not
- # matches. This enables properly dispatching to the correct
- # CUDAGraphWrapper when nesting multiple instances with different
- # runtime modes.
return self.runnable(*args, **kwargs)
+ assert batch_descriptor is not None
if batch_descriptor not in self.concrete_graph_entries:
# create a new entry for this batch descriptor
self.concrete_graph_entries[batch_descriptor] = GraphEntry(
@@ -129,11 +165,7 @@ def __call__(self, *args, **kwargs):
entry = self.concrete_graph_entries[batch_descriptor]
if entry.graph is None:
- if self.graph_options.debug_log_enable:
- # Since we capture cudagraph for many different shapes and
- # capturing is fast, we don't need to log it for every
- # shape. E.g. we only log it for the first subgraph in
- # piecewise mode.
+ if self.cudagraph_options.debug_log_enable:
logger.debug(
"Capturing a cudagraph on (%s,%s)",
self.runtime_mode.name,
@@ -149,32 +181,40 @@ def __call__(self, *args, **kwargs):
graph = Graph.graph()
with ExitStack() as stack:
- if self.graph_options.gc_disable:
- # during every model forward for piecewise graph
- # mode, we will capture many pieces of graphs
- # (roughly one per layer). running gc again and again
- # across layers will make the graph capture very slow.
- # therefore, we only run gc for the first graph,
- # and disable gc for the rest of the graphs.
+ if self.cudagraph_options.gc_disable:
stack.enter_context(patch("gc.collect", lambda: None))
+ # FL-specific: patch our platform's empty_cache
stack.enter_context(
- patch("vllm_fl.platform.PlatformFL.empty_cache", lambda: None)
+ patch("vllm_fl.platform.PlatformFL.empty_cache",
+ lambda: None)
)
- set_graph_pool_id(self.graph_pool)
-
- # mind-exploding: carefully manage the reference and memory.
- with current_platform.torch_device_fn.graph(graph, pool=self.graph_pool):
+ if self.graph_pool is not None:
+ set_graph_pool_id(self.graph_pool)
+ else:
+                set_graph_pool_id(current_platform.graph_pool_handle())  # NOTE(review): capture below still passes pool=self.graph_pool (None) — confirm the registered pool id is intended to differ
+
+ # Sync offloader's copy stream before capture if available.
+ try:
+ from vllm.model_executor.offloader.base import get_offloader
+ get_offloader().sync_prev_onload()
+ except (ImportError, RuntimeError):
+ pass
+
+ # FL-specific: use platform-agnostic graph capture
+ with current_platform.torch_device_fn.graph(
+ graph, pool=self.graph_pool
+ ):
# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
- if self.graph_options.weak_ref_output:
- # by converting it to weak ref,
- # the original `output` will immediately be released
- # to save memory. It is only safe to do this for
- # the last graph in piecewise cuadgraph mode, because
- # the output of the last graph will not be used by
- # any other cuda graph.
- output = weak_ref_tensors(output)
+ # Join offloader's copy stream after forward if available
+ try:
+ from vllm.model_executor.offloader.base import get_offloader
+ get_offloader().join_after_forward()
+ except (ImportError, RuntimeError):
+ pass
+ if self.cudagraph_options.weak_ref_output:
+ output = weak_ref_tensors(output)
entry.output = weak_ref_tensors(output)
entry.graph = graph
@@ -197,6 +237,13 @@ def __call__(self, *args, **kwargs):
f"got {new_input_addresses}"
)
+ # Sync offloader before replay if available
+ try:
+ from vllm.model_executor.offloader.base import get_offloader
+ get_offloader().sync_prev_onload()
+ except (ImportError, RuntimeError):
+ pass
+
current_platform.torch_device_fn.synchronize()
entry.graph.replay()
return entry.output
diff --git a/vllm_fl/configs/qwen3_5_moe.py b/vllm_fl/configs/qwen3_5_moe.py
deleted file mode 100644
index 48d68865..00000000
--- a/vllm_fl/configs/qwen3_5_moe.py
+++ /dev/null
@@ -1,185 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The Qwen Team and The HuggingFace Inc. team.
-# All rights reserved.
-"""Qwen3.5-MoE model configuration for vLLM plugin."""
-
-from transformers.configuration_utils import PretrainedConfig
-
-
-def _layer_type_validation(layer_types, num_hidden_layers):
- if layer_types is not None and num_hidden_layers is not None:
- if len(layer_types) != num_hidden_layers:
- raise ValueError(
- f"Length of layer_types ({len(layer_types)}) must match "
- f"num_hidden_layers ({num_hidden_layers})"
- )
-
-
-class Qwen3_5MoeTextConfig(PretrainedConfig):
- model_type = "qwen3_5_moe_text"
- keys_to_ignore_at_inference = ["past_key_values"]
- base_config_key = "text_config"
-
- def __init__(
- self,
- vocab_size=248320,
- hidden_size=2048,
- num_hidden_layers=40,
- num_attention_heads=16,
- num_key_value_heads=2,
- hidden_act="silu",
- max_position_embeddings=32768,
- initializer_range=0.02,
- rms_norm_eps=1e-6,
- use_cache=True,
- tie_word_embeddings=False,
- rope_parameters=None,
- attention_bias=False,
- attention_dropout=0.0,
- head_dim=256,
- linear_conv_kernel_dim=4,
- linear_key_head_dim=128,
- linear_value_head_dim=128,
- linear_num_key_heads=16,
- linear_num_value_heads=32,
- moe_intermediate_size=512,
- shared_expert_intermediate_size=512,
- num_experts_per_tok=8,
- num_experts=256,
- norm_topk_prob=True,
- output_router_logits=False,
- router_aux_loss_coef=0.001,
- layer_types=None,
- full_attention_interval=4,
- attn_output_gate=True,
- pad_token_id=None,
- bos_token_id=None,
- eos_token_id=None,
- **kwargs,
- ):
- kwargs.pop("ignore_keys_at_rope_validation", None)
- self.vocab_size = vocab_size
- self.max_position_embeddings = max_position_embeddings
- self.hidden_size = hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.num_key_value_heads = num_key_value_heads
- self.hidden_act = hidden_act
- self.initializer_range = initializer_range
- self.rms_norm_eps = rms_norm_eps
- self.use_cache = use_cache
- self.attention_bias = attention_bias
- self.attention_dropout = attention_dropout
- self.head_dim = head_dim
- self.rope_parameters = rope_parameters
- self.attn_output_gate = attn_output_gate
-
- self.layer_types = layer_types
- if self.layer_types is None:
- self.layer_types = [
- "linear_attention"
- if bool((i + 1) % full_attention_interval)
- else "full_attention"
- for i in range(self.num_hidden_layers)
- ]
- _layer_type_validation(self.layer_types, self.num_hidden_layers)
-
- self.linear_conv_kernel_dim = linear_conv_kernel_dim
- self.linear_key_head_dim = linear_key_head_dim
- self.linear_value_head_dim = linear_value_head_dim
- self.linear_num_key_heads = linear_num_key_heads
- self.linear_num_value_heads = linear_num_value_heads
- self.moe_intermediate_size = moe_intermediate_size
- self.shared_expert_intermediate_size = shared_expert_intermediate_size
- self.num_experts_per_tok = num_experts_per_tok
- self.num_experts = num_experts
- self.norm_topk_prob = norm_topk_prob
- self.output_router_logits = output_router_logits
- self.router_aux_loss_coef = router_aux_loss_coef
- self.full_attention_interval = full_attention_interval
-
- # partial_rotary_factor is needed by rope
- kwargs.setdefault("partial_rotary_factor", 0.25)
-
- super().__init__(**kwargs)
- self.pad_token_id = pad_token_id
- self.bos_token_id = bos_token_id
- self.eos_token_id = eos_token_id
- self.tie_word_embeddings = tie_word_embeddings
-
-
-class Qwen3_5MoeVisionConfig(PretrainedConfig):
- model_type = "qwen3_5_moe"
- base_config_key = "vision_config"
-
- def __init__(
- self,
- depth=27,
- hidden_size=1152,
- hidden_act="gelu_pytorch_tanh",
- intermediate_size=4304,
- num_heads=16,
- in_channels=3,
- patch_size=16,
- spatial_merge_size=2,
- temporal_patch_size=2,
- out_hidden_size=3584,
- num_position_embeddings=2304,
- initializer_range=0.02,
- deepstack_visual_indexes=None,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.depth = depth
- self.hidden_size = hidden_size
- self.hidden_act = hidden_act
- self.intermediate_size = intermediate_size
- self.num_heads = num_heads
- self.in_channels = in_channels
- self.patch_size = patch_size
- self.spatial_merge_size = spatial_merge_size
- self.temporal_patch_size = temporal_patch_size
- self.out_hidden_size = out_hidden_size
- self.num_position_embeddings = num_position_embeddings
- self.initializer_range = initializer_range
- self.deepstack_visual_indexes = deepstack_visual_indexes or []
-
-
-class Qwen3_5MoeConfig(PretrainedConfig):
- model_type = "qwen3_5_moe"
- sub_configs = {
- "vision_config": Qwen3_5MoeVisionConfig,
- "text_config": Qwen3_5MoeTextConfig,
- }
- keys_to_ignore_at_inference = ["past_key_values"]
-
- def __init__(
- self,
- text_config=None,
- vision_config=None,
- image_token_id=248056,
- video_token_id=248057,
- vision_start_token_id=248053,
- vision_end_token_id=248054,
- tie_word_embeddings=False,
- **kwargs,
- ):
- if isinstance(vision_config, dict):
- self.vision_config = self.sub_configs["vision_config"](**vision_config)
- elif vision_config is None:
- self.vision_config = self.sub_configs["vision_config"]()
-
- if isinstance(text_config, dict):
- self.text_config = self.sub_configs["text_config"](**text_config)
- elif text_config is None:
- self.text_config = self.sub_configs["text_config"]()
-
- self.image_token_id = image_token_id
- self.video_token_id = video_token_id
- self.vision_start_token_id = vision_start_token_id
- self.vision_end_token_id = vision_end_token_id
- super().__init__(**kwargs)
- self.tie_word_embeddings = tie_word_embeddings
-
-
-__all__ = ["Qwen3_5MoeConfig", "Qwen3_5MoeTextConfig", "Qwen3_5MoeVisionConfig"]
diff --git a/vllm_fl/dispatch/backends/flaggems/flaggems.py b/vllm_fl/dispatch/backends/flaggems/flaggems.py
index e12afaf7..a7b166ea 100644
--- a/vllm_fl/dispatch/backends/flaggems/flaggems.py
+++ b/vllm_fl/dispatch/backends/flaggems/flaggems.py
@@ -144,7 +144,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) ->
Returns:
Fully qualified class path string
"""
- from vllm.attention.backends.registry import AttentionBackendEnum
+ from vllm.v1.attention.backends.registry import AttentionBackendEnum
# TritonAttentionBackend requires CUDA, check if available
if not torch.cuda.is_available():
diff --git a/vllm_fl/dispatch/backends/flaggems/impl/attention.py b/vllm_fl/dispatch/backends/flaggems/impl/attention.py
index 44c751c4..d7e14dea 100644
--- a/vllm_fl/dispatch/backends/flaggems/impl/attention.py
+++ b/vllm_fl/dispatch/backends/flaggems/impl/attention.py
@@ -11,27 +11,29 @@
import torch
from vllm import envs
-from vllm.attention.backends.abstract import (
+from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
-from vllm.attention.layer import Attention
-from vllm.attention.ops.common import cp_lse_ag_out_rs
-from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.model_executor.layers.attention.attention import Attention
+from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
+from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+from vllm.model_executor.layers.batch_invariant import _batch_invariant_MODE as _bi_mode  # NOTE(review): value snapshot at import time; original vllm_is_batch_invariant() was evaluated per call — confirm the mode is fixed before this module imports
from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backends.utils import (
+from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionMetadataBuilder,
+)
+from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_kv_cache_layout,
@@ -457,7 +459,7 @@ def __init__(
self.attn_type = attn_type
self.vllm_flash_attn_version = 3 # 2 #get_flash_attn_version()
# Cache the batch invariant result for use in forward passes
- self.batch_invariant_enabled = vllm_is_batch_invariant()
+ self.batch_invariant_enabled = _bi_mode
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
diff --git a/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py b/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py
index 5aa2860d..f6dc37e3 100644
--- a/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py
+++ b/vllm_fl/dispatch/backends/flaggems/impl/custom_attention.py
@@ -1,4 +1,4 @@
-from vllm.attention.backends.registry import (
+from vllm.v1.attention.backends.registry import (
AttentionBackendEnum,
register_backend,
)
diff --git a/vllm_fl/dispatch/backends/flaggems/impl/mla.py b/vllm_fl/dispatch/backends/flaggems/impl/mla.py
index 866dcd51..1789498d 100644
--- a/vllm_fl/dispatch/backends/flaggems/impl/mla.py
+++ b/vllm_fl/dispatch/backends/flaggems/impl/mla.py
@@ -8,7 +8,7 @@
import torch
-from vllm.attention.backends.abstract import (
+from vllm.v1.attention.backend import (
AttentionLayer,
AttentionType,
is_quantized_kv_cache,
@@ -17,7 +17,7 @@
# from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
# from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.logger import init_logger
-from vllm.v1.attention.backends.mla.common import (
+from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
diff --git a/vllm_fl/dispatch/backends/reference/reference.py b/vllm_fl/dispatch/backends/reference/reference.py
index 7f20276c..9c10a2e5 100644
--- a/vllm_fl/dispatch/backends/reference/reference.py
+++ b/vllm_fl/dispatch/backends/reference/reference.py
@@ -150,7 +150,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) ->
Fully qualified class path string (vLLM native backend)
"""
# Return vLLM's native flash attention backend as reference
- from vllm.attention.backends.registry import AttentionBackendEnum
+ from vllm.v1.attention.backends.registry import AttentionBackendEnum
if use_mla:
# vLLM native MLA backend
diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py
index 4dcdd7d7..494cf508 100644
--- a/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py
+++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/attention.py
@@ -28,7 +28,7 @@
import torch
import torch.nn as nn
-from vllm.attention.backends.abstract import (
+from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
@@ -36,7 +36,8 @@
)
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backends.utils import AttentionCGSupport, CommonAttentionMetadata
+from vllm.v1.attention.backend import AttentionCGSupport
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm_fl.dispatch.backends.vendor.ascend.impl.attention_mask import (
AttentionMaskBuilder,
diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py
index 2aad980a..a6fa90f2 100644
--- a/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py
+++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/causal_conv1d.py
@@ -15,7 +15,7 @@
import torch.nn.functional as F
import triton
import triton.language as tl
-from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.v1.attention.backends.utils import PAD_SLOT_ID
def causal_conv1d_ref(
diff --git a/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py b/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py
index 4e8c0758..78dfcd24 100644
--- a/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py
+++ b/vllm_fl/dispatch/backends/vendor/ascend/impl/mm_encoder_attention.py
@@ -21,7 +21,7 @@
import torch
import torch.nn.functional as F
import torch_npu
-from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
+from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
MIN_PAD_SIZE = 64 # min_size to pad weight
diff --git a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py
index 8260f067..e504559d 100644
--- a/vllm_fl/dispatch/backends/vendor/cuda/cuda.py
+++ b/vllm_fl/dispatch/backends/vendor/cuda/cuda.py
@@ -174,7 +174,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) ->
Returns:
Fully qualified class path string
"""
- from vllm.attention.backends.registry import AttentionBackendEnum
+ from vllm.v1.attention.backends.registry import AttentionBackendEnum
if use_mla:
if use_sparse:
diff --git a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py
index 20f67053..049deb3d 100644
--- a/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py
+++ b/vllm_fl/dispatch/backends/vendor/cuda/impl/activation.py
@@ -22,11 +22,9 @@ def silu_and_mul_cuda(obj, x: torch.Tensor) -> torch.Tensor:
Returns:
Output tensor of shape [..., d]
"""
- from vllm._custom_ops import silu_and_mul as vllm_silu_and_mul
-
d = x.shape[-1] // 2
out = torch.empty(*x.shape[:-1], d, dtype=x.dtype, device=x.device)
- vllm_silu_and_mul(out, x)
+ torch.ops._C.silu_and_mul(out, x)
return out
diff --git a/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py b/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py
index 4d7d1449..d4896cd4 100644
--- a/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py
+++ b/vllm_fl/dispatch/backends/vendor/iluvatar/iluvatar.py
@@ -154,7 +154,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) ->
Returns:
Fully qualified class path string
"""
- from vllm.attention.backends.registry import AttentionBackendEnum
+ from vllm.v1.attention.backends.registry import AttentionBackendEnum
if use_mla:
if use_sparse:
diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py
index 35cdd6a9..7d10e52f 100644
--- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py
+++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/flash_attn.py
@@ -9,15 +9,15 @@
import numpy as np
import torch
-from vllm.attention.backends.abstract import (
+from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
-from vllm.attention.layer import Attention
-from vllm.attention.ops.common import cp_lse_ag_out_rs
+from vllm.model_executor.layers.attention.attention import Attention
+from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
# --------------------------------------------------------------
# Note: use Maca's merge_attn_states to get cuda kernel invoked
@@ -42,13 +42,15 @@
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
- vllm_is_batch_invariant,
+    _batch_invariant_MODE as _bi_mode,  # NOTE(review): bound at import time, not re-read per call — confirm mode cannot change afterwards
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backends.utils import (
+from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionMetadataBuilder,
+)
+from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_kv_cache_layout,
@@ -56,7 +58,7 @@
reshape_attn_output_for_spec_decode, # used for prefill decode split with mtp
reshape_query_for_spec_decode, # used for prefill decode split with mtp
)
-from vllm.attention.backends.registry import AttentionBackendEnum, register_backend
+from vllm.v1.attention.backends.registry import AttentionBackendEnum, register_backend
from vllm.v1.kv_cache_interface import AttentionSpec
# --------------------------------------------------------------
@@ -450,7 +452,7 @@ def build(
prefill_block_table_tensor = None
# \------------------------- Metax Modification -------------------------/
- if vllm_is_batch_invariant():
+ if _bi_mode:
max_num_splits = 1
def schedule(
@@ -645,7 +647,7 @@ def __init__(
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
# Cache the batch invariant result for use in forward passes
- self.batch_invariant_enabled = vllm_is_batch_invariant()
+ self.batch_invariant_enabled = _bi_mode
if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
raise NotImplementedError(
diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py
index 3d06c6c3..94ec2e51 100644
--- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py
+++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/common.py
@@ -198,13 +198,13 @@
from vllm import _custom_ops as ops
from vllm import envs
-from vllm.attention.backends.abstract import (
+from vllm.v1.attention.backend import (
AttentionBackend,
AttentionLayer,
MLAAttentionImpl,
)
-from vllm.attention.backends.utils import get_mla_dims
-from vllm.attention.ops.common import cp_lse_ag_out_rs
+from vllm.model_executor.layers.attention.mla_attention import get_mla_dims
+from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
# --------------------------------------------------------------
# Note: use Maca's merge_attn_states to get cuda kernel invoked
@@ -214,7 +214,7 @@
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group
from vllm.model_executor.layers.batch_invariant import (
- vllm_is_batch_invariant,
+ _batch_invariant_MODE as _bi_mode,
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -272,7 +272,9 @@ class QueryLenSupport(Enum):
flashinfer_available = False
-from vllm.v1.attention.backends.mla.common import logger
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
CUDNN_WORKSPACE_SIZE = 12800
@@ -1276,7 +1278,7 @@ def _flash_attn_varlen_diff_headdims(
# ROCm leverages the upstream flash_attn, which takes a parameter
# called "return_attn_probs" instead of return_softmax_lse
kwargs["return_attn_probs"] = return_softmax_lse
- if vllm_is_batch_invariant():
+ if _bi_mode:
kwargs["num_splits"] = 1
attn_out = self.flash_attn_varlen_func(
diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py
index 823ff4ba..6f10c44e 100644
--- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py
+++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/mla/flashmla.py
@@ -7,7 +7,7 @@
import torch
-from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
+from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
from ..ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
@@ -17,7 +17,7 @@
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
- vllm_is_batch_invariant,
+ _batch_invariant_MODE as _bi_mode,
)
from vllm.platforms.interface import DeviceCapability
from .common import (
@@ -28,13 +28,13 @@
MLACommonMetadataBuilder,
QueryLenSupport,
)
+from vllm.v1.attention.backend import AttentionCGSupport
from vllm.v1.attention.backends.utils import (
- AttentionCGSupport,
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
)
from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.attention.backends.registry import AttentionBackendEnum, register_backend
+from vllm.v1.attention.backends.registry import AttentionBackendEnum, register_backend
logger = init_logger(__name__)
@@ -271,7 +271,7 @@ def _forward_decode(
tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
num_splits = attn_metadata.decode.num_splits
- if vllm_is_batch_invariant():
+ if _bi_mode:
device = q.device
dtype = torch.int32
diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py
index 446772c2..c2c51159 100644
--- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py
+++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/ops/merge_attn_states.py
@@ -43,7 +43,7 @@ def supported_headdim(o: torch.Tensor) -> bool:
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
)
else:
- from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
+ from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states
return merge_attn_states(
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
diff --git a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py
index 56d2c787..b21344c5 100644
--- a/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py
+++ b/vllm_fl/dispatch/backends/vendor/metax/impl/attention/utils/fa_utils.py
@@ -2,7 +2,9 @@
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.attention.utils.fa_utils import logger
+import logging
+
+logger = logging.getLogger(__name__)
from vllm.platforms import current_platform
diff --git a/vllm_fl/dispatch/backends/vendor/metax/metax.py b/vllm_fl/dispatch/backends/vendor/metax/metax.py
index 66de575f..d9dee3c9 100644
--- a/vllm_fl/dispatch/backends/vendor/metax/metax.py
+++ b/vllm_fl/dispatch/backends/vendor/metax/metax.py
@@ -14,7 +14,7 @@
from vllm_fl.dispatch.backends.base import Backend
-from vllm.attention.backends.registry import AttentionBackendEnum, register_backend
+from vllm.v1.attention.backends.registry import AttentionBackendEnum, register_backend
# Register attention backends for MACA
@@ -147,7 +147,7 @@ def attention_backend(self, use_mla: bool = False, use_sparse: bool = False) ->
Returns:
Fully qualified class path string
"""
- from vllm.attention.backends.registry import AttentionBackendEnum
+ from vllm.v1.attention.backends.registry import AttentionBackendEnum
# register before selection
register_attention_backends()
diff --git a/vllm_fl/dispatch/builtin_ops.py b/vllm_fl/dispatch/builtin_ops.py
index ec207a5b..c8a8c990 100644
--- a/vllm_fl/dispatch/builtin_ops.py
+++ b/vllm_fl/dispatch/builtin_ops.py
@@ -21,7 +21,6 @@
# Directory containing vendor backends
_VENDOR_BACKENDS_DIR = os.path.join(os.path.dirname(__file__), "backends", "vendor")
-
def _find_vendor_backend_dir(
vendor_name: str,
available_vendor_dirs: set[str],
@@ -61,7 +60,6 @@ def _get_current_vendor_backend_dirs(available_vendor_dirs: set[str]) -> set[str
"Failed to detect current vendor backend from current_platform."
) from exc
-
def _register_vendor_backends(registry: OpRegistry) -> None:
"""
Auto-discover and register all vendor backends.
diff --git a/vllm_fl/dispatch/config/__init__.py b/vllm_fl/dispatch/config/__init__.py
index bafd6d4e..cc4e8b49 100644
--- a/vllm_fl/dispatch/config/__init__.py
+++ b/vllm_fl/dispatch/config/__init__.py
@@ -14,8 +14,8 @@
get_oot_blacklist,
get_per_op_order,
get_platform_name,
- get_vendor_device_map,
load_platform_config,
+ get_vendor_device_map,
)
__all__ = [
diff --git a/vllm_fl/dispatch/config/nvidia.yaml b/vllm_fl/dispatch/config/nvidia.yaml
index 3c8676ff..0b06a8a3 100644
--- a/vllm_fl/dispatch/config/nvidia.yaml
+++ b/vllm_fl/dispatch/config/nvidia.yaml
@@ -10,8 +10,8 @@ prefer: flagos
strict: false
# Vendor Whitelist (Optional, allows all if not set)
-# allow_vendors:
-# - cuda
+allow_vendors:
+ - cuda
# Vendor Blacklist (Optional)
# deny_vendors:
diff --git a/vllm_fl/dispatch/config/utils.py b/vllm_fl/dispatch/config/utils.py
index 25f1fbfa..087239a1 100644
--- a/vllm_fl/dispatch/config/utils.py
+++ b/vllm_fl/dispatch/config/utils.py
@@ -208,7 +208,6 @@ def get_effective_config() -> dict[str, Any]:
# Return empty config
return {}
-
def get_vendor_device_map() -> dict[str, dict[str, str]]:
"""Load vendor mapping from Python config module.
diff --git a/vllm_fl/dispatch/logger_manager.py b/vllm_fl/dispatch/logger_manager.py
index 37867e8e..57e4a4f0 100644
--- a/vllm_fl/dispatch/logger_manager.py
+++ b/vllm_fl/dispatch/logger_manager.py
@@ -38,7 +38,7 @@ def get_logger(name: str = "vllm_fl.dispatch") -> logging.Logger:
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter(
- "[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
+ "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
diff --git a/vllm_fl/models/glm_moe_dsa.py b/vllm_fl/models/glm_moe_dsa.py
deleted file mode 100644
index 4b88c4d0..00000000
--- a/vllm_fl/models/glm_moe_dsa.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-"""Inference-only GLM-5 (GlmMoeDsa) model.
-
-GLM-5 uses a DeepSeek V2/V3-style architecture with MLA (Multi-head Latent
-Attention) and Mixture of Experts. The HF model type is ``glm_moe_dsa`` and
-the architecture class is ``GlmMoeDsaForCausalLM``.
-
-This thin wrapper inherits from vLLM's ``DeepseekV2ForCausalLM`` which already
-handles MLA, MoE, the DSA Indexer, and MTP speculative decoding layers.
-"""
-
-from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM
-
-
-class GlmMoeDsaForCausalLM(DeepseekV2ForCausalLM):
- """GLM-5 model for causal language modelling."""
- pass
diff --git a/vllm_fl/models/kimi_k25.py b/vllm_fl/models/kimi_k25.py
deleted file mode 100644
index 894839e5..00000000
--- a/vllm_fl/models/kimi_k25.py
+++ /dev/null
@@ -1,245 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Minimal Kimi-K2.5 model support for text-only inference.
-
-This is a simplified implementation that wraps DeepseekV3 for text-only
-benchmarking. Vision components are not included.
-"""
-
-import copy
-from collections.abc import Iterable
-
-import torch
-from torch import nn
-
-from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group
-from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.model_loader.weight_utils import (
- default_weight_loader,
- maybe_remap_kv_scale_name,
-)
-from vllm.model_executor.models.deepseek_v2 import (
- DeepseekV2Model,
- get_spec_layer_idx_from_weight_name,
-)
-from vllm.model_executor.models.interfaces import SupportsPP
-from vllm.model_executor.models.utils import (
- PPMissingLayer,
- is_pp_missing_parameter,
- maybe_prefix,
-)
-from vllm.sequence import IntermediateTensors
-
-logger = init_logger(__name__)
-
-
-class KimiK25ForConditionalGeneration(nn.Module, SupportsPP):
- """Kimi-K2.5 model for text-only conditional generation.
-
- This is a minimal implementation that uses DeepseekV2Model as the
- language backbone. Vision components are not included for simplicity.
- """
-
- def __init__(
- self,
- vllm_config: VllmConfig,
- prefix: str = "",
- ) -> None:
- super().__init__()
- model_config = vllm_config.model_config
- config = model_config.hf_config
- self.config = config
- quant_config = vllm_config.quant_config
-
- # Extract text config from KimiK25Config
- # The text_config is a DeepseekV3Config
- text_config = getattr(config, "text_config", config)
- self.hidden_size = text_config.hidden_size
-
- # Create a modified vllm_config with text_config as hf_config
- sub_vllm_config = copy.deepcopy(vllm_config)
- sub_vllm_config.model_config.hf_config = text_config
-
- # Build language model using DeepseekV2Model
- self.language_model = DeepseekV2Model(
- vllm_config=sub_vllm_config,
- prefix=maybe_prefix(prefix, "language_model"),
- )
-
- # Build lm_head
- if get_pp_group().is_last_rank:
- vocab_size = getattr(config, "vocab_size", text_config.vocab_size)
- self.lm_head = ParallelLMHead(
- vocab_size,
- text_config.hidden_size,
- quant_config=quant_config,
- prefix=maybe_prefix(prefix, "lm_head"),
- )
- else:
- self.lm_head = PPMissingLayer()
-
- self.make_empty_intermediate_tensors = (
- self.language_model.make_empty_intermediate_tensors
- )
- logit_scale = getattr(config, "logit_scale", 1.0)
- vocab_size = getattr(config, "vocab_size", text_config.vocab_size)
- self.logits_processor = LogitsProcessor(vocab_size, scale=logit_scale)
-
- def get_language_model(self) -> nn.Module:
- return self.language_model
-
- def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
- """Apply token embeddings to input_ids."""
- return self.language_model.embed_input_ids(input_ids)
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs,
- ) -> IntermediateTensors:
- if intermediate_tensors is not None:
- inputs_embeds = None
- hidden_states = self.language_model(
- input_ids=input_ids,
- positions=positions,
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds,
- )
- return hidden_states
-
- def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
- logits = self.logits_processor(self.lm_head, hidden_states, **kwargs)
- return logits
-
- def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
- """Get expert parameter mapping for MoE layers."""
- text_config = getattr(self.config, "text_config", self.config)
- if not getattr(text_config, "n_routed_experts", None):
- return []
- return SharedFusedMoE.make_expert_params_mapping(
- ckpt_gate_proj_name="gate_proj",
- ckpt_down_proj_name="down_proj",
- ckpt_up_proj_name="up_proj",
- num_experts=text_config.n_routed_experts,
- num_redundant_experts=0,
- )
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
- """Load weights with proper name remapping for Kimi-K2.5."""
- text_config = getattr(self.config, "text_config", self.config)
-
- # Weight name remapping for Kimi-K2.5 -> DeepseekV2
- _KEYS_TO_MODIFY_MAPPING = {
- "language_model.lm_head": "lm_head",
- "language_model.model": "language_model",
- }
-
- stacked_params_mapping = [
- (".gate_up_proj", ".gate_proj", 0),
- (".gate_up_proj", ".up_proj", 1),
- ]
- if getattr(text_config, "kv_lora_rank", None) and getattr(
- text_config, "q_lora_rank", None
- ):
- stacked_params_mapping += [
- (".fused_qkv_a_proj", ".q_a_proj", 0),
- (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
- ]
- expert_params_mapping = self.get_expert_mapping()
-
- params_dict = dict(self.named_parameters())
-
- for args in weights:
- name, loaded_weight = args[:2]
- kwargs = args[2] if len(args) > 2 else {}
-
- # Skip rotary embedding cached values
- if "rotary_emb.inv_freq" in name:
- continue
- if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
- continue
-
- # Skip speculative decode layers
- spec_layer = get_spec_layer_idx_from_weight_name(text_config, name)
- if spec_layer is not None:
- continue
-
- # Skip vision tower weights (not needed for text-only inference)
- if "vision_tower" in name or "mm_projector" in name:
- continue
-
- # Apply key remapping
- for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
- if key_to_modify in name:
- name = name.replace(key_to_modify, new_key)
-
- use_default_weight_loading = False
-
- # Handle stacked parameters
- for param_name, weight_name, shard_id in stacked_params_mapping:
- if weight_name not in name:
- continue
- if ("mlp.experts." in name) and name not in params_dict:
- continue
- name = name.replace(weight_name, param_name)
- if name.endswith(".bias") and name not in params_dict:
- continue
-
- if is_pp_missing_parameter(name, self):
- continue
-
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id, **kwargs)
- break
- else:
- # Handle expert parameters
- for _, (
- param_name,
- weight_name,
- expert_id,
- shard_id,
- ) in enumerate(expert_params_mapping):
- if weight_name not in name:
- continue
- name = name.replace(weight_name, param_name)
-
- if is_pp_missing_parameter(name, self):
- continue
-
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(
- param,
- loaded_weight,
- name,
- expert_id=expert_id,
- shard_id=shard_id,
- **kwargs,
- )
- break
- else:
- use_default_weight_loading = True
-
- if use_default_weight_loading:
- if name.endswith(".bias") and name not in params_dict:
- continue
- name = maybe_remap_kv_scale_name(name, params_dict)
- if name is None:
- continue
-
- if is_pp_missing_parameter(name, self):
- continue
-
- param = params_dict.get(name)
- if param is None:
- continue
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight, **kwargs)
diff --git a/vllm_fl/models/minicpmo.py b/vllm_fl/models/minicpmo.py
deleted file mode 100644
index 5affd08d..00000000
--- a/vllm_fl/models/minicpmo.py
+++ /dev/null
@@ -1,854 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-# Adapted from
-# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The vLLM team.
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
-
-from collections.abc import Callable, Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, TypeAlias
-
-import torch
-from torch import nn
-from transformers import BatchFeature
-from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.whisper.modeling_whisper import (
- ACT2FN,
- WhisperAttention,
- WhisperConfig,
- WhisperEncoder,
-)
-
-from vllm.config import VllmConfig
-from vllm.config.multimodal import BaseDummyOptions
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
-from vllm.multimodal.inputs import (
- MultiModalDataDict,
- MultiModalFieldConfig,
- NestedTensors,
-)
-from vllm.multimodal.parse import (
- AudioItem,
- AudioProcessorItems,
- DictEmbeddingItems,
- ModalityData,
- ModalityDataItems,
- MultiModalDataItems,
- MultiModalDataParser,
-)
-from vllm.multimodal.processing import (
- PromptReplacement,
- PromptUpdate,
- PromptUpdateDetails,
-)
-from vllm.utils.tensor_schema import TensorSchema, TensorShape
-
-from vllm.model_executor.models.minicpmv import (
- _MAX_FRAMES_PER_VIDEO,
- MiniCPMV2_6,
- MiniCPMV4_5,
- MiniCPMVDummyInputsBuilder,
- MiniCPMVMultiModalDataParser,
- MiniCPMVMultiModalProcessor,
- MiniCPMVProcessingInfo,
- _minicpmv_field_config,
-)
-from vllm.model_executor.models.utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
-
-CPU_DEVICE = torch.device("cpu")
-
-
-class MiniCPMOAudioFeatureInputs(TensorSchema):
- """
- Dimensions:
- - bns: Batch size * number of audios * number of slices
- - bn: Batch size * number of audios
- - c: Number of channels
- - l: Length
- - s: Number of slices
- """
-
- type: Literal["audio_features"] = "audio_features"
-
- audio_features: Annotated[
- torch.Tensor | list[torch.Tensor],
- TensorShape("bns", "c", "l", dynamic_dims={"l"}),
- ]
- """
- Slice here means chunk. Audio that is too long will be split into slices,
- which is the same as image. Padding is used therefore `audio_features` is
- `torch.Tensor`.
- """
-
- audio_feature_lens: Annotated[
- torch.Tensor | list[torch.Tensor],
- TensorShape("bn", "s"),
- ]
- """
- This should be feature length of each audio slice,
- which equals to `audio_features.shape[-1]`
- """
-
-
-class MiniCPMOAudioEmbeddingInputs(TensorSchema):
- """
- Dimensions:
- - bn: Batch size * number of audios
- - s: Number of slices
- - h: Hidden size (must match language model backbone)
-
- Length of each slice may vary, so pass it as a list.
- """
-
- type: Literal["audio_embeds"] = "audio_embeds"
-
- audio_embeds: Annotated[
- torch.Tensor | list[torch.Tensor],
- TensorShape("bn", "s", "h", dynamic_dims={"s"}),
- ]
-
-
-MiniCPMOAudioInputs: TypeAlias = (
- MiniCPMOAudioFeatureInputs | MiniCPMOAudioEmbeddingInputs
-)
-
-
-def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
- return dict(
- **_minicpmv_field_config(hf_inputs),
- audio_features=MultiModalFieldConfig.batched("audio"),
- audio_feature_lens=MultiModalFieldConfig.batched("audio"),
- audio_embeds=MultiModalFieldConfig.batched("audio"),
- )
-
-
-class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
- def __init__(
- self,
- data: Mapping[str, torch.Tensor],
- fields_factory: Callable[
- [Mapping[str, torch.Tensor]],
- Mapping[str, MultiModalFieldConfig],
- ],
- ) -> None:
- super().__init__(
- data,
- modality="image",
- required_fields={"audio_embeds"},
- fields_factory=fields_factory,
- )
-
-
-class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
- def _parse_audio_data(
- self,
- data: dict[str, torch.Tensor] | ModalityData[AudioItem],
- ) -> ModalityDataItems[Any, Any] | None:
- if isinstance(data, dict):
- return MiniCPMOAudioEmbeddingItems(
- data,
- fields_factory=_minicpmo_field_config,
- )
-
- return super()._parse_audio_data(data)
-
-
-class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
-    audio_pattern = "(<audio>./</audio>)"
-
- def get_supported_mm_limits(self) -> Mapping[str, int | None]:
- return {**super().get_supported_mm_limits(), "audio": None}
-
- def get_audio_placeholder(
- self,
- audio_lens: int,
- chunk_input: bool = True,
- chunk_length: int = 1,
- ) -> str:
- hf_processor = self.get_hf_processor()
-
- return hf_processor.get_audio_placeholder(
- audio_lens,
- chunk_input=chunk_input,
- chunk_length=chunk_length,
- )
-
- def get_default_audio_pool_step(self) -> int:
- hf_config = self.get_hf_config()
- # MiniCPM-o 4.5 uses pool_step=5, older versions use 2
- return getattr(hf_config, "audio_pool_step", 2)
-
- def get_default_audio_sampling_rate(self) -> int:
- return 16000
-
- def get_chunk_length(self) -> int:
- return self.get_hf_config().audio_chunk_length
-
- def get_max_audio_tokens_per_chunk(self) -> int:
- pool_step = self.get_default_audio_pool_step()
- fbank_feat_in_chunk = 100
- cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
- return (cnn_feat_in_chunk - pool_step) // pool_step + 1
-
- def get_max_audio_chunks_with_most_features(self) -> int:
- return 30
-
- def get_max_audio_tokens(self) -> int:
- num_chunks = self.get_max_audio_chunks_with_most_features()
- return self.get_max_audio_tokens_per_chunk() * num_chunks
-
- def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
- sampling_rate = self.get_default_audio_sampling_rate()
- num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
- return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
-
- def get_num_frames_with_most_features(
- self,
- seq_len: int,
- mm_counts: Mapping[str, int],
- ) -> int:
- max_images = mm_counts.get("image", 0)
- max_videos = mm_counts.get("video", 0)
- max_audios = mm_counts.get("audio", 0)
-
- max_image_tokens = self.get_max_image_tokens() * max_images
- max_audio_tokens = self.get_max_audio_tokens() * max_audios
- max_total_frames = self.get_max_video_frames(
- seq_len - max_image_tokens - max_audio_tokens
- )
- max_frames_per_video = min(
- max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO
- )
-
- return max(max_frames_per_video, 1)
-
-
-class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
- def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
- num_audios = mm_counts.get("audio", 0)
-
- audio_prompt_texts = self.info.audio_pattern * num_audios
-
- return super().get_dummy_text(mm_counts) + audio_prompt_texts
-
- def get_dummy_mm_data(
- self,
- seq_len: int,
- mm_counts: Mapping[str, int],
- mm_options: Mapping[str, BaseDummyOptions] | None = None,
- ) -> MultiModalDataDict:
- num_audios = mm_counts.get("audio", 0)
- audio_len = (
- self.info.get_max_audio_chunks_with_most_features()
- * self.info.get_default_audio_sampling_rate()
- )
-
- audio_overrides = mm_options.get("audio") if mm_options else None
-
- audio_mm_data = {
- "audio": self._get_dummy_audios(
- length=audio_len, num_audios=num_audios, overrides=audio_overrides
- )
- }
-
- return {
- **super().get_dummy_mm_data(seq_len, mm_counts, mm_options),
- **audio_mm_data,
- }
-
-
-class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
- def _get_data_parser(self) -> MultiModalDataParser:
- return MiniCPMOMultiModalDataParser(
- target_sr=self.info.get_default_audio_sampling_rate()
- )
-
- def get_audio_prompt_texts(
- self,
- audio_lens: int,
- chunk_input: bool = True,
- chunk_length: int = 1,
- ) -> str:
- return self.info.get_audio_placeholder(
- audio_lens,
- chunk_input=chunk_input,
- chunk_length=chunk_length,
- )
-
- def process_audios(
- self,
- mm_data: Mapping[str, object],
- mm_kwargs: Mapping[str, object],
- tok_kwargs: Mapping[str, object],
- ) -> Mapping[str, NestedTensors]:
- if (audios := mm_data.get("audios")) is None:
- return {}
-
- parsed_audios = (
- self._get_data_parser()
- .parse_mm_data({"audio": audios})
- .get_items("audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems))
- )
-
- if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
- audio_inputs = {}
- else:
- audio_inputs = self._base_call_hf_processor(
- prompts=[self.info.audio_pattern] * len(parsed_audios),
- mm_data={"audios": [[audio] for audio in parsed_audios]},
- mm_kwargs={**mm_kwargs, "chunk_input": True},
- tok_kwargs=tok_kwargs,
- out_keys={"audio_features", "audio_feature_lens"},
- )
-
- # Avoid padding since we need the output for each audio to be
- # independent of other audios for the cache to work correctly
- unpadded_audio_features = [
- feat[:, :feature_len]
- for feat, feature_len in zip(
- audio_inputs["audio_features"],
- audio_inputs["audio_feature_lens"],
- )
- ]
- audio_inputs["audio_features"] = unpadded_audio_features
-
- return audio_inputs
-
- def process_mm_inputs(
- self,
- mm_data: Mapping[str, object],
- mm_kwargs: Mapping[str, object],
- tok_kwargs: Mapping[str, object],
- ) -> Mapping[str, NestedTensors]:
- return {
- **super().process_mm_inputs(mm_data, mm_kwargs, tok_kwargs),
- **self.process_audios(mm_data, mm_kwargs, tok_kwargs),
- }
-
- def _get_prompt_updates(
- self,
- mm_items: MultiModalDataItems,
- hf_processor_mm_kwargs: Mapping[str, object],
- out_mm_kwargs: MultiModalKwargsItems,
- ) -> Sequence[PromptUpdate]:
- base_updates = super()._get_prompt_updates(
- mm_items=mm_items,
- hf_processor_mm_kwargs=hf_processor_mm_kwargs,
- out_mm_kwargs=out_mm_kwargs,
- )
-
- audio_placeholder = self.info.audio_pattern
-
- def get_audio_replacement(item_idx: int):
- audios = mm_items.get_items(
- "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)
- )
-
- if isinstance(audios, MiniCPMOAudioEmbeddingItems):
- single_audio_embeds = audios.get(item_idx)["audio_embeds"]
- audio_len = self.info.get_audio_len_by_num_chunks(
- sum(map(len, single_audio_embeds))
- )
- else:
- audio_len = audios.get_audio_length(item_idx)
-
- return PromptUpdateDetails.select_text(
- self.get_audio_prompt_texts(audio_len),
- "",
- )
-
- return [
- *base_updates,
- PromptReplacement(
- modality="audio",
- target=audio_placeholder,
- replacement=get_audio_replacement,
- ),
- ]
-
- def _get_mm_fields_config(
- self,
- hf_inputs: BatchFeature,
- hf_processor_mm_kwargs: Mapping[str, object],
- ) -> Mapping[str, MultiModalFieldConfig]:
- return _minicpmo_field_config(hf_inputs)
-
-
-class MultiModalProjector(nn.Module):
- def __init__(self, in_dim: int, out_dim: int):
- super().__init__()
- self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True)
- self.relu = nn.ReLU()
- self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True)
-
- def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
- hidden_states = self.relu(self.linear1(audio_features))
- hidden_states = self.linear2(hidden_states)
- return hidden_states
-
-
-class MiniCPMWhisperEncoderLayer(nn.Module):
- def __init__(self, config: WhisperConfig, layer_idx: int):
- super().__init__()
- self.embed_dim = config.d_model
- self.self_attn = WhisperAttention(
- embed_dim=self.embed_dim,
- num_heads=config.encoder_attention_heads,
- dropout=config.attention_dropout,
- config=config,
- layer_idx=layer_idx,
- )
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.dropout = config.dropout
- self.activation_fn = ACT2FN[config.activation_function]
- self.activation_dropout = config.activation_dropout
- self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
- self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor,
- ) -> torch.Tensor:
- residual = hidden_states
- hidden_states = self.self_attn_layer_norm(hidden_states)
- hidden_states, _ = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- )
- hidden_states = nn.functional.dropout(
- hidden_states, p=self.dropout, training=self.training
- )
- hidden_states = residual + hidden_states
-
- residual = hidden_states
- hidden_states = self.final_layer_norm(hidden_states)
- hidden_states = self.activation_fn(self.fc1(hidden_states))
- hidden_states = nn.functional.dropout(
- hidden_states, p=self.activation_dropout, training=self.training
- )
- hidden_states = self.fc2(hidden_states)
- hidden_states = nn.functional.dropout(
- hidden_states, p=self.dropout, training=self.training
- )
- hidden_states = residual + hidden_states
-
- if hidden_states.dtype == torch.float16:
- hidden_states = cast_overflow_tensors(hidden_states)
-
- outputs = (hidden_states,)
-
- return outputs
-
-
-class MiniCPMWhisperEncoder(WhisperEncoder):
- def __init__(self, config: WhisperConfig):
- super().__init__(config)
- self.layers = nn.ModuleList(
- [
- MiniCPMWhisperEncoderLayer(config, layer_idx=i)
- for i in range(config.encoder_layers)
- ]
- )
-
- def forward(
- self,
- input_features: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- ) -> BaseModelOutputWithPast:
- # Ignore copy
- input_features = input_features.to(
- dtype=self.conv1.weight.dtype, device=self.conv1.weight.device
- )
-
- inputs_embeds = nn.functional.gelu(self.conv1(input_features))
- inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
-
- inputs_embeds = inputs_embeds.permute(0, 2, 1)
-
- embed_pos = self.embed_positions.weight
-
- embed_pos = embed_pos[: inputs_embeds.shape[1], :]
-
- hidden_states = inputs_embeds + embed_pos
- hidden_states = nn.functional.dropout(
- hidden_states, p=self.dropout, training=self.training
- )
-
- encoder_states = ()
-
- for idx, encoder_layer in enumerate(self.layers):
- encoder_states = encoder_states + (hidden_states,)
- to_drop = False
- if self.training:
- dropout_probability = torch.rand([])
- if dropout_probability < self.layerdrop: # skip the layer
- to_drop = True
-
- # Ignore copy
- if to_drop:
- layer_outputs = (None, None)
- else:
- layer_outputs = encoder_layer(
- hidden_states,
- attention_mask,
- )
-
- hidden_states = layer_outputs[0]
-
- hidden_states = self.layer_norm(hidden_states)
- encoder_states = encoder_states + (hidden_states,)
-
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- hidden_states=encoder_states,
- )
-
-
-class MiniCPMOBaseModel:
- """Base mixin class for MiniCPM-O models with audio support."""
-
- packed_modules_mapping = {
- "qkv_proj": [
- "q_proj",
- "k_proj",
- "v_proj",
- ],
- "gate_up_proj": [
- "gate_proj",
- "up_proj",
- ],
- }
-
- @classmethod
- def get_placeholder_str(cls, modality: str, i: int) -> str | None:
- if modality.startswith("image"):
- return "(./)"
- if modality.startswith("video"):
- return "()"
- if modality.startswith("audio"):
- return "()"
-
- raise ValueError("Only image, video or audio modality is supported")
-
- def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
- # Do not use parameters temporarily
- audio_config = self.config.audio_config
- model = MiniCPMWhisperEncoder(audio_config)
- audio_output_dim = int(audio_config.encoder_ffn_dim // 4)
- self.audio_avg_pooler = nn.AvgPool1d(
- self.config.audio_pool_step, stride=self.config.audio_pool_step
- )
- self.audio_projection_layer = MultiModalProjector(
- in_dim=audio_output_dim, out_dim=self.embed_dim
- )
- self.audio_encoder_layer = -1
- return model
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self, skip_prefixes=["tts"])
- return loader.load_weights(weights)
-
- def subsequent_chunk_mask(
- self,
- size: int,
- chunk_size: int,
- num_left_chunks: int = -1,
- device: torch.device = CPU_DEVICE,
- num_lookhead: int = 0,
- ) -> torch.Tensor:
- ret = torch.zeros(size, size, device=device, dtype=torch.bool)
- # Vectorized computation of row indices and chunk boundaries
- row_indices = torch.arange(size, device=device)
- chunk_indices = row_indices // chunk_size
- if num_left_chunks < 0:
- # If num_left_chunks < 0, start is always 0 for all rows
- start_indices = torch.zeros_like(row_indices)
- else:
- # Compute start indices vectorially
- start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, min=0)
- start_indices = start_chunk_indices * chunk_size
- # Compute ending indices vectorially
- end_chunk_indices = chunk_indices + 1
- end_indices = torch.clamp(
- end_chunk_indices * chunk_size + num_lookhead, max=size
- )
- # Create column indices for broadcasting
- col_indices = torch.arange(size, device=device).unsqueeze(0)
- start_indices = start_indices.unsqueeze(1)
- end_indices = end_indices.unsqueeze(1)
- # Vectorized mask creation
- ret = (col_indices >= start_indices) & (col_indices < end_indices)
- return ret
-
- def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
- input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
- input_lengths_after_pooling = (
- input_lengths_after_cnn - self.config.audio_pool_step
- ) // self.config.audio_pool_step + 1
- input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)
-
- return input_lengths_after_cnn, input_lengths_after_pooling
-
- def get_audio_hidden_states(
- self, data: MiniCPMOAudioFeatureInputs
- ) -> list[torch.Tensor]:
- chunk_length = self.config.audio_chunk_length
-
- # (bs, 80, frames) or [], multi audios need filled in advance
- wavforms_raw = data["audio_features"]
- if isinstance(wavforms_raw, list):
- B = len(wavforms_raw)
- C = wavforms_raw[0].shape[-2]
- L = max(item.shape[-1] for item in wavforms_raw)
- device = wavforms_raw[0].device
- dtype = wavforms_raw[0].dtype
-
- wavforms = torch.zeros((B, C, L), dtype=dtype, device=device)
- for i, wavforms_item in enumerate(wavforms_raw):
- L_item = wavforms_item.shape[-1]
- wavforms[i, ..., :L_item] = wavforms_item
- else:
- wavforms = wavforms_raw
-
- # list, [[x1, x2], [y1], [z1]]
- audio_feature_lens_raw = data["audio_feature_lens"]
- if isinstance(audio_feature_lens_raw, torch.Tensor):
- audio_feature_lens_raw = audio_feature_lens_raw.unbind(0)
-
- audio_feature_lens = torch.hstack(audio_feature_lens_raw)
- batch_size, _, max_mel_seq_len = wavforms.shape
- max_seq_len = (max_mel_seq_len - 1) // 2 + 1
-
- # Create a sequence tensor of shape (batch_size, max_seq_len)
- seq_range = (
- torch.arange(
- 0,
- max_seq_len,
- dtype=audio_feature_lens.dtype,
- device=audio_feature_lens.device,
- )
- .unsqueeze(0)
- .expand(batch_size, max_seq_len)
- )
- lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len)
- # Create mask
- padding_mask = seq_range >= lengths_expand # 1 for padded values
-
- audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
- batch_size, 1, max_seq_len, max_seq_len
- )
- audio_attention_mask = audio_attention_mask_.to(
- dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device
- )
-
- if chunk_length > 0:
- chunk_num_frame = int(chunk_length * 50)
- chunk_mask = self.subsequent_chunk_mask(
- size=max_seq_len,
- chunk_size=chunk_num_frame,
- num_left_chunks=-1,
- device=audio_attention_mask_.device,
- )
- audio_attention_mask_ = torch.logical_or(
- audio_attention_mask_, torch.logical_not(chunk_mask)
- )
-
- audio_attention_mask[audio_attention_mask_] = float("-inf")
- audio_states = self.apm(
- wavforms, attention_mask=audio_attention_mask
- ).hidden_states[self.audio_encoder_layer]
- audio_embeds = self.audio_projection_layer(audio_states)
-
- audio_embeds = audio_embeds.transpose(1, 2)
- audio_embeds = self.audio_avg_pooler(audio_embeds)
- audio_embeds = audio_embeds.transpose(1, 2)
-
- _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(
- audio_feature_lens
- )
-
- num_audio_tokens = feature_lens_after_pooling
-
- final_audio_embeds = list[torch.Tensor]()
- idx = 0
- for i in range(len(audio_feature_lens_raw)):
- target_audio_embeds_lst = list[torch.Tensor]()
- for _ in range(len(audio_feature_lens_raw[i])):
- target_audio_embeds_lst.append(
- audio_embeds[idx, : num_audio_tokens[idx], :]
- )
- idx += 1
-
- final_audio_embeds.append(torch.cat(target_audio_embeds_lst))
-
- return final_audio_embeds
-
- def _parse_and_validate_audio_input(
- self, **kwargs: object
- ) -> MiniCPMOAudioInputs | None:
- audio_features = kwargs.pop("audio_features", None)
- audio_embeds = kwargs.pop("audio_embeds", None)
-
- if audio_features is None and audio_embeds is None:
- return None
-
- if audio_embeds is not None:
- return MiniCPMOAudioEmbeddingInputs(
- type="audio_embeds",
- audio_embeds=audio_embeds,
- )
-
- audio_feature_lens = kwargs.pop("audio_feature_lens")
-
- return MiniCPMOAudioFeatureInputs(
- type="audio_features",
- audio_features=audio_features,
- audio_feature_lens=audio_feature_lens,
- )
-
- def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
- modalities = super()._parse_and_validate_multimodal_inputs(**kwargs)
-
- # Preserve the order of modalities if there are multiple of them
- # from the order of kwargs.
- for input_key in kwargs:
- if (
- input_key in ("audio_features", "audio_embeds")
- and "audios" not in modalities
- ):
- modalities["audios"] = self._parse_and_validate_audio_input(**kwargs)
-
- return modalities
-
- def _process_audio_input(
- self,
- audio_input: MiniCPMOAudioInputs,
- ) -> torch.Tensor | list[torch.Tensor]:
- if audio_input["type"] == "audio_embeds":
- return audio_input["audio_embeds"]
-
- return self.get_audio_hidden_states(audio_input)
-
- def _process_multimodal_inputs(self, modalities: dict):
- multimodal_embeddings = super()._process_multimodal_inputs(modalities)
-
- for modality in modalities:
- if modality == "audios":
- audio_input = modalities["audios"]
- audio_embeddings = self._process_audio_input(audio_input)
- multimodal_embeddings += tuple(audio_embeddings)
-
- return multimodal_embeddings
-
-
-class MiniCPMO2_6(MiniCPMOBaseModel, MiniCPMV2_6):
- """MiniCPM-O model based on MiniCPMV 2.6 (Qwen2 backbone)."""
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- # Skip MiniCPMV2_6.__init__ version assertion, call MiniCPMVBaseModel directly
- from vllm.model_executor.models.minicpmv import MiniCPMVBaseModel
- MiniCPMVBaseModel.__init__(self, vllm_config=vllm_config, prefix=prefix)
- # Override version for MiniCPM-O 2.6
- self.version = (2, 6)
- self.apm = self.init_audio_module(
- vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
- )
-
-
-class MiniCPMO4_5(MiniCPMOBaseModel, MiniCPMV4_5):
- """MiniCPM-O 4.5 model based on MiniCPMV 4.5 (Qwen3 backbone)."""
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- # Skip MiniCPMV4_5.__init__ version assertion, call MiniCPMVBaseModel directly
- from vllm.model_executor.models.minicpmv import MiniCPMVBaseModel
- MiniCPMVBaseModel.__init__(self, vllm_config=vllm_config, prefix=prefix)
- # Override version for MiniCPM-O 4.5
- self.version = (4, 5)
- self.apm = self.init_audio_module(
- vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
- )
-
-
-_MINICPMO_SUPPORT_VERSION = {
- (2, 6): MiniCPMO2_6,
- (4, 5): MiniCPMO4_5,
-}
-
-
-@MULTIMODAL_REGISTRY.register_processor(
- MiniCPMOMultiModalProcessor,
- info=MiniCPMOProcessingInfo,
- dummy_inputs=MiniCPMODummyInputsBuilder,
-)
-class MiniCPMO(MiniCPMOBaseModel, MiniCPMV2_6):
- """
- MiniCPM-O model with audio support.
- Different versions use different LLM backbones:
- - Version 2.6: Uses Qwen2
- - Version 4.5: Uses Qwen3
- """
-
- def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
- config = vllm_config.model_config.hf_config
-
- # Determine version from config
- if hasattr(config, "version"):
- version = str(config.version).split(".")
- version = tuple([int(x) for x in version])
- else:
- # Auto-detect version based on config features:
- # - MiniCPM-o 4.5 (Qwen3 backbone): has head_dim attribute,
- # hidden_size=4096, num_hidden_layers=36
- # - MiniCPM-o 2.6 (Qwen2 backbone): no head_dim, different arch
- has_head_dim = hasattr(config, "head_dim")
- is_qwen3_like = (
- has_head_dim
- and getattr(config, "hidden_size", 0) == 4096
- and getattr(config, "num_hidden_layers", 0) == 36
- )
- if is_qwen3_like:
- version = (4, 5)
- else:
- # Default to 2.6 for backward compatibility
- version = (2, 6)
-
- # Dispatch class based on version
- instance_cls = _MINICPMO_SUPPORT_VERSION.get(version)
- if instance_cls is None:
- supported_versions = ", ".join(
- [f"{v[0]}.{v[1]}" for v in sorted(_MINICPMO_SUPPORT_VERSION.keys())]
- )
- raise ValueError(
- f"Currently, MiniCPMO only supports versions "
- f"{supported_versions}. Got version: {version}"
- )
-
- return instance_cls(vllm_config=vllm_config, prefix=prefix)
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- # This __init__ won't be called due to __new__ returning a different class
- super().__init__(vllm_config=vllm_config, prefix=prefix)
- self.apm = self.init_audio_module(
- vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
- )
diff --git a/vllm_fl/models/qwen3_5.py b/vllm_fl/models/qwen3_5.py
deleted file mode 100644
index ad3ec5f9..00000000
--- a/vllm_fl/models/qwen3_5.py
+++ /dev/null
@@ -1,951 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM team.
-# Copyright 2025 The Qwen Team.
-"""Inference-only Qwen3.5 MoE model compatible with HuggingFace weights."""
-
-import typing
-from collections.abc import Callable, Iterable
-
-import torch
-from einops import rearrange
-from torch import nn
-
-from vllm.compilation.decorators import support_torch_compile
-from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group
-from vllm.logger import init_logger
-from vllm.model_executor.layers.layernorm import (
- GemmaRMSNorm as Qwen3_5RMSNorm,
-)
-from vllm.model_executor.layers.linear import MergedColumnParallelLinear
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mamba.mamba_utils import (
- MambaStateDtypeCalculator,
- MambaStateShapeCalculator,
-)
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.vocab_parallel_embedding import (
- ParallelLMHead,
- VocabParallelEmbedding,
-)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.sequence import IntermediateTensors
-
-from vllm.model_executor.models.interfaces import (
- HasInnerState,
- IsHybrid,
- MixtureOfExperts,
- MultiModalEmbeddings,
- SupportsLoRA,
- SupportsPP,
- _require_is_multimodal,
-)
-from vllm.model_executor.models.utils import (
- AutoWeightsLoader,
- PPMissingLayer,
- _merge_multimodal_embeddings,
- extract_layer_index,
- is_pp_missing_parameter,
- make_empty_intermediate_tensors_factory,
- make_layers,
- maybe_prefix,
-)
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.model_executor.models.qwen3_vl import (
- Qwen3_VisionTransformer,
- Qwen3VLDummyInputsBuilder,
- Qwen3VLForConditionalGeneration,
- Qwen3VLMultiModalProcessor,
- Qwen3VLProcessingInfo,
-)
-
-# from vllm_fl.models.qwen3_next import (
-from vllm.model_executor.models.qwen3_next import (
- Qwen3NextAttention,
- Qwen3NextDecoderLayer,
- Qwen3NextGatedDeltaNet,
- Qwen3NextModel,
- Qwen3NextSparseMoeBlock,
- QwenNextMixtureOfExperts,
-)
-from vllm_fl.configs.qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeTextConfig
-import vllm_fl.models.qwen3_next # for error ''_OpNamespace' 'vllm' object has no attribute 'gdn_attention_core''
-
-logger = init_logger(__name__)
-
-
-class Qwen3_5SparseMoeBlock(Qwen3NextSparseMoeBlock):
- """Override SparseMoeBlock to read config from hf_text_config."""
-
- def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
- # Temporarily patch hf_config to point to hf_text_config
- # so the parent class reads the right config
- original_hf_config = vllm_config.model_config.hf_config
- vllm_config.model_config.hf_config = (
- vllm_config.model_config.hf_text_config
- )
- try:
- super().__init__(vllm_config=vllm_config, prefix=prefix)
- finally:
- vllm_config.model_config.hf_config = original_hf_config
-
-
-class Qwen3_5GatedDeltaNet(Qwen3NextGatedDeltaNet):
- """Qwen3.5 uses MergedColumnParallelLinear for qkvz projection
- and has a different forward pass without fix_query_key_value_ordering."""
-
- def fix_query_key_value_ordering(
- self,
- mixed_qkvz: torch.Tensor,
- mixed_ba: torch.Tensor,
- ):
- raise NotImplementedError(
- "Qwen3.5 Series dont need to fix query key value ordering"
- )
-
- def create_qkvz_proj(
- self,
- hidden_size: int,
- key_dim: int,
- value_dim: int,
- quant_config: QuantizationConfig | None,
- prefix: str,
- ) -> MergedColumnParallelLinear:
- return MergedColumnParallelLinear(
- input_size=hidden_size,
- output_sizes=[key_dim, key_dim, value_dim, value_dim],
- bias=False,
- quant_config=quant_config,
- prefix=prefix,
- )
-
- def __init__(self, config, model_config=None, cache_config=None,
- quant_config=None, speculative_config=None, prefix=""):
- # Call grandparent init to skip Qwen3NextGatedDeltaNet.__init__
- # but set up the same attributes
- nn.Module.__init__(self)
- from vllm.distributed import (
- divide,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- )
- from vllm.config import get_current_vllm_config
- from vllm.model_executor.layers.linear import (
- ColumnParallelLinear,
- RowParallelLinear,
- )
- from vllm.model_executor.layers.layernorm import RMSNormGated
- from vllm.model_executor.layers.mamba.mamba_mixer2 import (
- mamba_v2_sharded_weight_loader,
- )
- from vllm.model_executor.model_loader.weight_utils import (
- sharded_weight_loader,
- )
- from vllm.model_executor.utils import set_weight_attrs
- from vllm.platforms import current_platform
- from transformers.activations import ACT2FN
-
- self.tp_size = get_tensor_model_parallel_world_size()
- self.tp_rank = get_tensor_model_parallel_rank()
- self.hidden_size = config.hidden_size
- self.num_v_heads = config.linear_num_value_heads
- self.num_k_heads = config.linear_num_key_heads
- self.head_k_dim = config.linear_key_head_dim
- self.head_v_dim = config.linear_value_head_dim
- self.key_dim = self.head_k_dim * self.num_k_heads
- self.value_dim = self.head_v_dim * self.num_v_heads
-
- self.conv_kernel_size = config.linear_conv_kernel_dim
- self.layer_idx = extract_layer_index(prefix)
- self.activation = config.hidden_act
- self.act = ACT2FN[config.hidden_act]
- self.layer_norm_epsilon = config.rms_norm_eps
- self.prefix = prefix
-
- self.config = config
- self.model_config = model_config
- self.cache_config = cache_config
- self.quant_config = quant_config
- self.speculative_config = speculative_config
- self.num_spec = (
- self.speculative_config.num_speculative_tokens
- if self.speculative_config
- else 0
- )
-
- # Convolution
- self.conv_dim = self.key_dim * 2 + self.value_dim
- self.conv1d = ColumnParallelLinear(
- input_size=self.conv_kernel_size,
- output_size=self.conv_dim,
- bias=False,
- prefix=f"{prefix}.conv1d",
- )
- self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
-
- query_key_settings = (self.key_dim, 0, False)
- value_settings = (self.value_dim, 0, False)
-
- delattr(self.conv1d.weight, "weight_loader")
- set_weight_attrs(
- self.conv1d.weight,
- {
- "weight_loader": mamba_v2_sharded_weight_loader(
- [query_key_settings, query_key_settings, value_settings],
- self.tp_size,
- self.tp_rank,
- )
- },
- )
-
- # Use MergedColumnParallelLinear for qkvz projection
- self.in_proj_qkvz = self.create_qkvz_proj(
- hidden_size=self.hidden_size,
- key_dim=self.key_dim,
- value_dim=self.value_dim,
- quant_config=quant_config,
- prefix=f"{prefix}.in_proj_qkvz",
- )
-
- self.projection_size_ba = self.num_v_heads * 2
- self.in_proj_ba = MergedColumnParallelLinear(
- input_size=self.hidden_size,
- output_sizes=[self.num_v_heads, self.num_v_heads],
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.in_proj_ba",
- )
-
- self.dt_bias = nn.Parameter(
- torch.ones(self.num_v_heads // self.tp_size),
- )
- self.A_log = nn.Parameter(
- torch.empty(divide(self.num_v_heads, self.tp_size)),
- )
-
- set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
- set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
-
- self.norm = RMSNormGated(
- self.head_v_dim,
- eps=self.layer_norm_epsilon,
- group_size=None,
- norm_before_gate=True,
- device=current_platform.current_device(),
- dtype=getattr(config, "dtype", torch.bfloat16),
- )
-
- self.out_proj = RowParallelLinear(
- self.value_dim,
- self.hidden_size,
- bias=False,
- input_is_parallel=True,
- quant_config=quant_config,
- prefix=f"{prefix}.out_proj",
- )
-
- compilation_config = get_current_vllm_config().compilation_config
- if prefix in compilation_config.static_forward_context:
- raise ValueError(f"Duplicate layer name: {prefix}")
- compilation_config.static_forward_context[prefix] = self
-
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- output: torch.Tensor,
- ):
- num_tokens = hidden_states.size(0)
-
- # Input Projection - Qwen3.5 style
- mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
- qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
- z_size = self.value_dim // self.tp_size
- mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
- z = z.reshape(z.size(0), -1, self.head_v_dim)
- ba, _ = self.in_proj_ba(hidden_states)
- b, a = ba.chunk(2, dim=-1)
-
- b = b.contiguous()
- a = a.contiguous()
-
- # Core Attention
- core_attn_out = torch.zeros(
- (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
- dtype=hidden_states.dtype,
- device=hidden_states.device,
- )
-
- torch.ops.vllm.gdn_attention_core(
- mixed_qkv,
- b,
- a,
- core_attn_out,
- self.prefix,
- )
-
- # Output Projection
- z_shape_og = z.shape
- core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
- z = z.reshape(-1, z.shape[-1])
- core_attn_out = self.norm(core_attn_out, z)
- core_attn_out = core_attn_out.reshape(z_shape_og)
- core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
- output[:num_tokens], _ = self.out_proj(core_attn_out)
-
-
-class Qwen3_5DecoderLayer(Qwen3NextDecoderLayer):
- def __init__(
- self,
- vllm_config: VllmConfig,
- layer_type: str,
- prefix: str = "",
- ) -> None:
- # Call nn.Module.__init__ directly, skipping Qwen3NextDecoderLayer's
- super(Qwen3NextDecoderLayer, self).__init__()
-
- config = vllm_config.model_config.hf_text_config
- model_config = vllm_config.model_config
- cache_config = vllm_config.cache_config
- quant_config = vllm_config.quant_config
- speculative_config = vllm_config.speculative_config
-
- self.layer_type = layer_type
- self.layer_idx = extract_layer_index(prefix)
-
- if self.layer_type == "linear_attention":
- self.linear_attn = Qwen3_5GatedDeltaNet(
- config,
- model_config=model_config,
- cache_config=cache_config,
- quant_config=quant_config,
- speculative_config=speculative_config,
- prefix=f"{prefix}.linear_attn",
- )
- elif self.layer_type == "full_attention":
- self.self_attn = Qwen3NextAttention(
- config,
- model_config=model_config,
- cache_config=cache_config,
- quant_config=quant_config,
- prefix=f"{prefix}.self_attn",
- )
- else:
- raise ValueError(f"Invalid layer_type {self.layer_type}")
-
- # Qwen3.5 MoE: all layers use sparse MoE blocks
- self.mlp = Qwen3_5SparseMoeBlock(
- vllm_config=vllm_config,
- prefix=f"{prefix}.mlp",
- )
-
- self.input_layernorm = Qwen3_5RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- )
- self.post_attention_layernorm = Qwen3_5RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- )
-
- self.layer_scale = getattr(config, "layer_scale", False)
- if self.layer_scale:
- self.attn_layer_scale = torch.nn.Parameter(
- torch.zeros(
- 1, 1, config.hidden_size,
- dtype=getattr(config, "dtype", torch.bfloat16),
- ),
- )
- self.ffn_layer_scale = torch.nn.Parameter(
- torch.zeros(
- 1, 1, config.hidden_size,
- dtype=getattr(config, "dtype", torch.bfloat16),
- ),
- )
-
-
-@support_torch_compile(
- dynamic_arg_dims={
- "input_ids": 0,
- "positions": -1,
- "intermediate_tensors": 0,
- "inputs_embeds": 0,
- }
-)
-class Qwen3_5Model(Qwen3NextModel):
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super(Qwen3NextModel, self).__init__()
-
- config = vllm_config.model_config.hf_text_config
- parallel_config = vllm_config.parallel_config
-
- eplb_config = parallel_config.eplb_config
- self.num_redundant_experts = eplb_config.num_redundant_experts
-
- self.config = config
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = VocabParallelEmbedding(
- self.vocab_size,
- config.hidden_size,
- )
-
- def get_layer(prefix: str):
- return Qwen3_5DecoderLayer(
- vllm_config,
- layer_type=config.layer_types[extract_layer_index(prefix)],
- prefix=prefix,
- )
-
- self.start_layer, self.end_layer, self.layers = make_layers(
- config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
- )
- self.make_empty_intermediate_tensors = (
- make_empty_intermediate_tensors_factory(
- ["hidden_states", "residual"], config.hidden_size
- )
- )
-
- if get_pp_group().is_last_rank:
- self.norm = Qwen3_5RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- )
- else:
- self.norm = PPMissingLayer()
-
- def load_fused_expert_weights(
- self,
- name: str,
- params_dict: dict,
- loaded_weight: torch.Tensor,
- shard_id: str,
- num_experts: int,
- ) -> bool:
- param = params_dict[name]
- weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
- loaded_local_expert = False
- for expert_id in range(num_experts):
- curr_expert_weight = loaded_weight[expert_id]
- success = weight_loader(
- param,
- curr_expert_weight,
- name,
- shard_id,
- expert_id,
- return_success=True,
- )
- if success:
- loaded_local_expert = True
- return loaded_local_expert
-
- def load_weights(
- self, weights: Iterable[tuple[str, torch.Tensor]]
- ) -> set[str]:
- stacked_params_mapping = [
- # self attention
- ("qkv_proj", "q_proj", "q"),
- ("qkv_proj", "k_proj", "k"),
- ("qkv_proj", "v_proj", "v"),
- # mlp
- ("gate_up_proj", "gate_proj", 0),
- ("gate_up_proj", "up_proj", 1),
- # GDN - Qwen3.5 style
- ("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
- ("in_proj_qkvz", "in_proj_z", 3),
- ("in_proj_ba", "in_proj_b", 0),
- ("in_proj_ba", "in_proj_a", 1),
- ]
-
- params_dict = dict(self.named_parameters())
- loaded_params: set[str] = set()
- expert_params_mapping = self.get_expert_mapping()
- is_fused_expert = False
- fused_expert_params_mapping = [
- ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
- ("experts.w2_weight", "experts.down_proj", 0, "w2"),
- ]
- num_experts = (
- self.config.num_experts
- if hasattr(self.config, "num_experts")
- else 0
- )
- for name, loaded_weight in weights:
- if "rotary_emb.inv_freq" in name:
- continue
- if name.startswith("mtp."):
- continue
-
- for param_name, weight_name, shard_id in stacked_params_mapping:
- if (
- "experts.gate_up_proj" in name
- or "experts.down_proj" in name
- ):
- is_fused_expert = True
- expert_params_mapping = fused_expert_params_mapping
-
- if weight_name not in name:
- continue
- if "mlp.experts" in name:
- continue
-
- name = name.replace(weight_name, param_name)
- if name.endswith(".bias") and name not in params_dict:
- continue
- if is_pp_missing_parameter(name, self):
- continue
- if name not in params_dict:
- continue
- param = params_dict[name]
- weight_loader = param.weight_loader
- if isinstance(shard_id, tuple):
- # Multi-shard: split loaded_weight and load each
- layer = weight_loader.__self__
- split_sizes = [
- layer.output_sizes[s] for s in shard_id
- ]
- output_dim = getattr(param, "output_dim", 0)
- parts = loaded_weight.split(
- split_sizes, dim=output_dim
- )
- for s, part in zip(shard_id, parts):
- weight_loader(param, part, s)
- else:
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- is_expert_weight = False
- for mapping in expert_params_mapping:
- param_name, weight_name, expert_id, shard_id = mapping
- if weight_name not in name:
- continue
- is_expert_weight = True
- name_mapped = name.replace(weight_name, param_name)
- if is_pp_missing_parameter(name_mapped, self):
- continue
- if is_fused_expert:
- if "experts.gate_up_proj" in name:
- loaded_weight = loaded_weight.chunk(2, dim=-2)
- success_w1 = self.load_fused_expert_weights(
- name_mapped, params_dict,
- loaded_weight[0], "w1", num_experts,
- )
- success_w3 = self.load_fused_expert_weights(
- name_mapped, params_dict,
- loaded_weight[1], "w3", num_experts,
- )
- success = success_w1 and success_w3
- else:
- success = self.load_fused_expert_weights(
- name_mapped, params_dict,
- loaded_weight, shard_id, num_experts,
- )
- if success:
- name = name_mapped
- break
- else:
- if (
- name_mapped.endswith(".bias")
- or name_mapped.endswith("_bias")
- ) and name_mapped not in params_dict:
- continue
- param = params_dict[name_mapped]
- weight_loader = param.weight_loader
- success = weight_loader(
- param,
- loaded_weight,
- name_mapped,
- shard_id=shard_id,
- expert_id=expert_id,
- return_success=True,
- )
- if success:
- name = name_mapped
- break
- else:
- if is_expert_weight:
- continue
- if name.endswith(".bias") and name not in params_dict:
- continue
- if is_pp_missing_parameter(name, self):
- continue
- if name not in params_dict:
- logger.warning_once(
- f"Parameter {name} not found in params_dict, "
- "skip loading"
- )
- continue
- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
- return loaded_params
-
-
-class Qwen3_5_MoeMixtureOfExperts(MixtureOfExperts):
- def _get_moe_model_layers(self):
- """Get the model layers, handling both CausalLM and VL wrapper."""
- if hasattr(self, "language_model"):
- return self.language_model.model.layers
- return self.model.layers
-
- def update_physical_experts_metadata(
- self,
- num_physical_experts: int,
- num_local_physical_experts: int,
- ) -> None:
- assert self.num_local_physical_experts == num_local_physical_experts
- self.num_physical_experts = num_physical_experts
- self.num_local_physical_experts = num_local_physical_experts
- self.num_redundant_experts = (
- num_physical_experts - self.num_logical_experts
- )
- for layer in self._get_moe_model_layers():
- if isinstance(layer.mlp, Qwen3_5SparseMoeBlock):
- moe = layer.mlp
- moe.n_local_physical_experts = num_local_physical_experts
- moe.n_physical_experts = num_physical_experts
- moe.n_redundant_experts = self.num_redundant_experts
- moe.experts.update_expert_map()
-
- def set_moe_parameters(self):
- self.expert_weights = []
- self.moe_layers = []
- example_moe = None
- for layer in self._get_moe_model_layers():
- if isinstance(layer, Qwen3_5DecoderLayer) and isinstance(
- layer.mlp, (Qwen3NextSparseMoeBlock, Qwen3_5SparseMoeBlock)
- ):
- example_moe = layer.mlp
- self.moe_layers.append(layer.mlp.experts)
-
- if example_moe is None:
- raise RuntimeError(
- "No Qwen3_5 MoE layer found in the model.layers."
- )
-
- self.num_moe_layers = len(self.moe_layers)
- self.num_expert_groups = 1
- self.num_shared_experts = 0
- self.num_logical_experts = example_moe.n_logical_experts
- self.num_physical_experts = example_moe.n_physical_experts
- self.num_local_physical_experts = example_moe.n_local_physical_experts
- self.num_routed_experts = example_moe.n_routed_experts
- self.num_redundant_experts = example_moe.n_redundant_experts
-
-
-class Qwen3_5MoeForCausalLM(
- nn.Module,
- HasInnerState,
- SupportsLoRA,
- SupportsPP,
- Qwen3_5_MoeMixtureOfExperts,
- IsHybrid,
-):
- packed_modules_mapping = {
- "qkv_proj": ["q_proj", "k_proj", "v_proj"],
- "gate_up_proj": ["gate_proj", "up_proj"],
- }
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- config = vllm_config.model_config.hf_text_config
- self.vllm_config = vllm_config
- self.model_config = vllm_config.model_config
- cache_config = vllm_config.cache_config
-
- scheduler_config = vllm_config.scheduler_config
- assert not cache_config.enable_prefix_caching, (
- "Qwen3.5 currently does not support prefix caching"
- )
- self.quant_config = vllm_config.quant_config
-
- super().__init__()
- self.config = config
- self.scheduler_config = scheduler_config
- self.model = Qwen3_5Model(
- vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
- )
-
- if get_pp_group().is_last_rank:
- if config.tie_word_embeddings:
- self.lm_head = self.model.embed_tokens
- else:
- self.lm_head = ParallelLMHead(
- config.vocab_size,
- config.hidden_size,
- prefix=maybe_prefix(prefix, "lm_head"),
- )
- else:
- self.lm_head = PPMissingLayer()
-
- self.logits_processor = LogitsProcessor(config.vocab_size)
- self.make_empty_intermediate_tensors = (
- self.model.make_empty_intermediate_tensors
- )
-
- # Set MoE hyperparameters
- self.set_moe_parameters()
-
- def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
- return self.model.embed_input_ids(input_ids)
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs: object,
- ):
- hidden_states = self.model(
- input_ids, positions, intermediate_tensors, inputs_embeds
- )
- return hidden_states
-
- def compute_logits(
- self,
- hidden_states: torch.Tensor,
- ) -> torch.Tensor | None:
- return self.logits_processor(self.lm_head, hidden_states)
-
- def load_weights(
- self, weights: Iterable[tuple[str, torch.Tensor]]
- ) -> set[str]:
-
- def _remap_weights(
- weights: Iterable[tuple[str, torch.Tensor]]
- ) -> Iterable[tuple[str, torch.Tensor]]:
- for name, tensor in weights:
- # The HF checkpoint for Qwen3_5MoeForConditionalGeneration
- # uses model.language_model.* for the text model, but our
- # model structure uses model.* directly (no language_model
- # wrapper level).
- name = name.replace("model.language_model.", "model.")
- yield name, tensor
-
- loader = AutoWeightsLoader(
- self,
- skip_prefixes=["mtp.", "model.visual."],
- )
- return loader.load_weights(_remap_weights(weights))
-
- def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
- return self.model.get_expert_mapping()
-
- @classmethod
- def get_mamba_state_dtype_from_config(
- cls,
- vllm_config: "VllmConfig",
- ) -> tuple[torch.dtype, torch.dtype]:
- return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
- vllm_config.model_config.dtype,
- vllm_config.cache_config.mamba_cache_dtype,
- )
-
- @classmethod
- def get_mamba_state_shape_from_config(
- cls, vllm_config: "VllmConfig"
- ) -> tuple[tuple[int, int], tuple[int, int]]:
- parallel_config = vllm_config.parallel_config
- hf_config = vllm_config.model_config.hf_text_config
- tp_size = parallel_config.tensor_parallel_size
- num_spec = (
- vllm_config.speculative_config.num_speculative_tokens
- if vllm_config.speculative_config
- else 0
- )
- return MambaStateShapeCalculator.gated_delta_net_state_shape(
- tp_size,
- hf_config.linear_num_key_heads,
- hf_config.linear_num_value_heads,
- hf_config.linear_key_head_dim,
- hf_config.linear_value_head_dim,
- hf_config.linear_conv_kernel_dim,
- num_spec,
- )
-
-
-########################################################
-# Qwen3_5-MoE Multimodal (VL) Wrapper
-########################################################
-
-
-class Qwen3_5MoeProcessingInfo(Qwen3VLProcessingInfo):
- """Processing info that uses Qwen3_5MoeConfig instead of Qwen3VLConfig."""
-
- def get_hf_config(self):
- return self.ctx.get_hf_config(Qwen3_5MoeConfig)
-
-
-@MULTIMODAL_REGISTRY.register_processor(
- Qwen3VLMultiModalProcessor,
- info=Qwen3_5MoeProcessingInfo,
- dummy_inputs=Qwen3VLDummyInputsBuilder,
-)
-class Qwen3_5MoeForConditionalGeneration(
- Qwen3VLForConditionalGeneration,
- HasInnerState,
- Qwen3_5_MoeMixtureOfExperts,
- IsHybrid,
-):
- """Multimodal Qwen3.5-MoE model wrapping VisionTransformer + CausalLM."""
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
- nn.Module.__init__(self)
- config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config
- quant_config = vllm_config.quant_config
- multimodal_config = vllm_config.model_config.multimodal_config
-
- self.config = config
- self.multimodal_config = multimodal_config
- self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
- self.video_pruning_rate = multimodal_config.video_pruning_rate
- self.is_multimodal_pruning_enabled = (
- multimodal_config.is_multimodal_pruning_enabled()
- )
-
- # Vision encoder
- if not multimodal_config.get_limit_per_prompt(
- "image"
- ) and not multimodal_config.get_limit_per_prompt("video"):
- self.visual = None
- else:
- self.visual = Qwen3_VisionTransformer(
- config.vision_config,
- norm_eps=getattr(config, "rms_norm_eps", 1e-6),
- quant_config=quant_config,
- multimodal_config=multimodal_config,
- prefix=maybe_prefix(prefix, "visual"),
- )
-
- # Language model (MoE CausalLM)
- self.language_model = Qwen3_5MoeForCausalLM(
- vllm_config=vllm_config,
- prefix=maybe_prefix(prefix, "language_model"),
- )
-
- self.make_empty_intermediate_tensors = (
- self.language_model.make_empty_intermediate_tensors
- )
-
- # Deepstack support
- self.use_deepstack = hasattr(
- config.vision_config, "deepstack_visual_indexes"
- ) and bool(config.vision_config.deepstack_visual_indexes)
- self.deepstack_num_level = (
- len(config.vision_config.deepstack_visual_indexes)
- if self.use_deepstack
- else 0
- )
- if self.use_deepstack and self.visual is not None:
- self.deepstack_input_embeds = [
- torch.zeros(
- vllm_config.scheduler_config.max_num_batched_tokens,
- config.text_config.hidden_size,
- )
- for _ in range(self.deepstack_num_level)
- ]
- else:
- self.deepstack_input_embeds = None
- self.visual_dim = config.vision_config.out_hidden_size
- self.multiscale_dim = self.visual_dim * self.deepstack_num_level
-
- # Set MoE hyperparameters
- self.set_moe_parameters()
-
- def embed_input_ids(
- self,
- input_ids: torch.Tensor,
- multimodal_embeddings: MultiModalEmbeddings | None = None,
- *,
- is_multimodal: torch.Tensor | None = None,
- handle_oov_mm_token: bool = False,
- ) -> torch.Tensor:
- inputs_embeds = self._embed_text_input_ids(
- input_ids,
- self.language_model.embed_input_ids,
- is_multimodal=is_multimodal,
- handle_oov_mm_token=handle_oov_mm_token,
- )
-
- if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
- return inputs_embeds
-
- is_multimodal = _require_is_multimodal(is_multimodal)
-
- inputs_embeds = _merge_multimodal_embeddings(
- inputs_embeds=inputs_embeds,
- multimodal_embeddings=multimodal_embeddings,
- is_multimodal=is_multimodal,
- )
-
- return inputs_embeds
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs: object,
- ) -> torch.Tensor | IntermediateTensors:
- if intermediate_tensors is not None:
- inputs_embeds = None
-
- hidden_states = self.language_model.model(
- input_ids=input_ids,
- positions=positions,
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds,
- )
-
- return hidden_states
-
- def compute_logits(
- self,
- hidden_states: torch.Tensor,
- ) -> torch.Tensor | None:
- return self.language_model.compute_logits(hidden_states)
-
- def load_weights(
- self, weights: Iterable[tuple[str, torch.Tensor]]
- ) -> set[str]:
- skip_prefixes = ["mtp."]
- if self.visual is None:
- skip_prefixes.append("visual.")
- loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
- return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
-
- def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
- return self.language_model.model.get_expert_mapping()
-
- @classmethod
- def get_mamba_state_dtype_from_config(
- cls,
- vllm_config: "VllmConfig",
- ) -> tuple[torch.dtype, torch.dtype]:
- return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
- vllm_config.model_config.dtype,
- vllm_config.cache_config.mamba_cache_dtype,
- )
-
- @classmethod
- def get_mamba_state_shape_from_config(
- cls, vllm_config: "VllmConfig"
- ) -> tuple[tuple[int, int], tuple[int, int]]:
- parallel_config = vllm_config.parallel_config
- hf_config = vllm_config.model_config.hf_text_config
- tp_size = parallel_config.tensor_parallel_size
- num_spec = (
- vllm_config.speculative_config.num_speculative_tokens
- if vllm_config.speculative_config
- else 0
- )
- return MambaStateShapeCalculator.gated_delta_net_state_shape(
- tp_size,
- hf_config.linear_num_key_heads,
- hf_config.linear_num_value_heads,
- hf_config.linear_key_head_dim,
- hf_config.linear_value_head_dim,
- hf_config.linear_conv_kernel_dim,
- num_spec,
- )
diff --git a/vllm_fl/models/qwen3_next.py b/vllm_fl/models/qwen3_next.py
deleted file mode 100644
index f3afa3fe..00000000
--- a/vllm_fl/models/qwen3_next.py
+++ /dev/null
@@ -1,1396 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Inference-only Qwen3Next model."""
-
-from collections.abc import Iterable
-from itertools import islice
-
-import torch
-from einops import rearrange
-from torch import nn
-from transformers.activations import ACT2FN
-from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.attention.layer import Attention
-from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (
- CacheConfig,
- ModelConfig,
- SpeculativeConfig,
- VllmConfig,
- get_current_vllm_config,
-)
-from vllm.distributed import (
- divide,
- get_ep_group,
- get_pp_group,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- tensor_model_parallel_all_gather,
-)
-from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
-from vllm.model_executor.layers.layernorm import (
- GemmaRMSNorm as Qwen3NextRMSNorm,
-)
-from vllm.model_executor.layers.layernorm import RMSNormGated
-from vllm.model_executor.layers.linear import (
- ColumnParallelLinear,
- QKVParallelLinear,
- ReplicatedLinear,
- RowParallelLinear,
-)
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.mamba.abstract import MambaBase
-from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader
-from vllm.model_executor.layers.mamba.mamba_utils import (
- MambaStateDtypeCalculator,
- MambaStateShapeCalculator,
-)
-from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
- causal_conv1d_fn,
- causal_conv1d_update,
-)
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
- ParallelLMHead,
- VocabParallelEmbedding,
-)
-from vllm.model_executor.model_loader.weight_utils import (
- default_weight_loader,
- sharded_weight_loader,
-)
-from vllm.model_executor.models.interfaces import (
- HasInnerState,
- IsHybrid,
- MixtureOfExperts,
- SupportsLoRA,
- SupportsPP,
-)
-from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
-from vllm.model_executor.models.utils import (
- AutoWeightsLoader,
- PPMissingLayer,
- extract_layer_index,
- is_pp_missing_parameter,
- make_empty_intermediate_tensors_factory,
- make_layers,
- maybe_prefix,
- sequence_parallel_chunk,
-)
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
-from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.configs import Qwen3NextConfig
-from vllm.triton_utils import tl, triton
-from vllm.utils.torch_utils import direct_register_custom_op
-from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
-from vllm_fl.ops.fla import ChunkGatedDeltaRuleOp, FusedRecurrentGatedDeltaRuleOp
-
-logger = init_logger(__name__)
-
-KVCache = tuple[torch.Tensor, torch.Tensor]
-
-class Qwen3NextSparseMoeBlock(nn.Module):
- def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
-
- config = vllm_config.model_config.hf_config
- parallel_config = vllm_config.parallel_config
- quant_config = vllm_config.quant_config
-
- self.tp_size = get_tensor_model_parallel_world_size()
-
- self.ep_group = get_ep_group().device_group
- self.ep_rank = get_ep_group().rank_in_group
- self.ep_size = self.ep_group.size()
- self.n_routed_experts = config.num_experts
-
- self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
-
- if self.tp_size > config.num_experts:
- raise ValueError(
- f"Tensor parallel size {self.tp_size} is greater than "
- f"the number of experts {config.num_experts}."
- )
-
- # Load balancing settings.
- vllm_config = get_current_vllm_config()
- eplb_config = vllm_config.parallel_config.eplb_config
- self.enable_eplb = parallel_config.enable_eplb
-
- self.n_logical_experts = self.n_routed_experts
- self.n_redundant_experts = eplb_config.num_redundant_experts
- self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
- self.n_local_physical_experts = self.n_physical_experts // self.ep_size
-
- self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
- self.physical_expert_end = (
- self.physical_expert_start + self.n_local_physical_experts
- )
-
- self.gate = ReplicatedLinear(
- config.hidden_size,
- config.num_experts,
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.gate",
- )
-
- self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
-
- if config.shared_expert_intermediate_size > 0:
- self.shared_expert = Qwen3NextMLP(
- hidden_size=config.hidden_size,
- intermediate_size=config.shared_expert_intermediate_size,
- hidden_act=config.hidden_act,
- quant_config=quant_config,
- reduce_results=False,
- expert_gate=self.shared_expert_gate,
- prefix=f"{prefix}.shared_expert",
- )
- else:
- self.shared_expert = None
-
- self.experts = SharedFusedMoE(
- shared_experts=self.shared_expert,
- gate=self.gate,
- num_experts=self.n_routed_experts,
- top_k=config.num_experts_per_tok,
- hidden_size=config.hidden_size,
- intermediate_size=config.moe_intermediate_size,
- reduce_results=False,
- renormalize=config.norm_topk_prob,
- quant_config=quant_config,
- prefix=f"{prefix}.experts",
- enable_eplb=self.enable_eplb,
- num_redundant_experts=self.n_redundant_experts,
- is_sequence_parallel=self.is_sequence_parallel,
- routing_method_type=RoutingMethodType.Renormalize,
- )
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- # NOTE: hidden_states can have either 1D or 2D shape.
- orig_shape = hidden_states.shape
- num_tokens, hidden_dim = hidden_states.shape
- hidden_states = hidden_states.view(-1, hidden_dim)
-
- if self.is_sequence_parallel:
- hidden_states = sequence_parallel_chunk(hidden_states)
-
- if self.experts.is_internal_router:
- # In this case, the gate/router runs inside the FusedMoE class
- final_hidden_states = self.experts(
- hidden_states=hidden_states, router_logits=hidden_states
- )
- else:
- # router_logits: (num_tokens, n_experts)
- router_logits, _ = self.gate(hidden_states)
- final_hidden_states = self.experts(
- hidden_states=hidden_states, router_logits=router_logits
- )
-
- if self.shared_expert is not None:
- final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
-
- if self.is_sequence_parallel:
- final_hidden_states = tensor_model_parallel_all_gather(
- final_hidden_states, 0
- )
- final_hidden_states = final_hidden_states[:num_tokens]
- elif self.tp_size > 1:
- final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
- final_hidden_states
- )
-
- return final_hidden_states.view(orig_shape)
-
-
-class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
- @property
- def mamba_type(self) -> str:
- return "gdn_attention"
-
- def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
- return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
- self.model_config.dtype, self.cache_config.mamba_cache_dtype
- )
-
- def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
- return MambaStateShapeCalculator.gated_delta_net_state_shape(
- self.tp_size,
- self.num_k_heads,
- self.num_v_heads,
- self.head_k_dim,
- self.head_v_dim,
- self.conv_kernel_size,
- self.num_spec,
- )
-
- def __init__(
- self,
- config: Qwen3NextConfig,
- model_config: ModelConfig | None = None,
- cache_config: CacheConfig | None = None,
- quant_config: QuantizationConfig | None = None,
- speculative_config: SpeculativeConfig | None = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.tp_size = get_tensor_model_parallel_world_size()
- self.tp_rank = get_tensor_model_parallel_rank()
- self.hidden_size = config.hidden_size
- self.num_v_heads = config.linear_num_value_heads
- self.num_k_heads = config.linear_num_key_heads
- self.head_k_dim = config.linear_key_head_dim
- self.head_v_dim = config.linear_value_head_dim
- self.key_dim = self.head_k_dim * self.num_k_heads
- self.value_dim = self.head_v_dim * self.num_v_heads
-
- self.conv_kernel_size = config.linear_conv_kernel_dim
- self.layer_idx = extract_layer_index(prefix)
- self.activation = config.hidden_act
- self.act = ACT2FN[config.hidden_act]
- self.layer_norm_epsilon = config.rms_norm_eps
- self.prefix = prefix
-
- self.config = config
- self.model_config = model_config
- self.cache_config = cache_config
- self.quant_config = quant_config
- self.speculative_config = speculative_config
- self.num_spec = (
- self.speculative_config.num_speculative_tokens
- if self.speculative_config
- else 0
- )
-
- # QKV
- self.conv_dim = self.key_dim * 2 + self.value_dim
- self.conv1d = ColumnParallelLinear(
- input_size=self.conv_kernel_size,
- output_size=self.conv_dim,
- bias=False,
- prefix=f"{prefix}.conv1d",
- )
- self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
-
- # projection of the input hidden states
- self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
- self.projection_size_ba = self.num_v_heads * 2
- self.in_proj_qkvz = ColumnParallelLinear(
- input_size=self.hidden_size,
- output_size=self.projection_size_qkvz,
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.in_proj_qkvz",
- )
- # ba_proj doesn't support blockwise fp8 quantization.
- self.in_proj_ba = ColumnParallelLinear(
- input_size=self.hidden_size,
- output_size=self.projection_size_ba,
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.in_proj_ba",
- )
-
- query_key_settings = (self.key_dim, 0, False)
- value_settings = (self.value_dim, 0, False)
-
- delattr(self.conv1d.weight, "weight_loader")
- set_weight_attrs(
- self.conv1d.weight,
- {
- "weight_loader": mamba_v2_sharded_weight_loader(
- [
- query_key_settings,
- query_key_settings,
- value_settings,
- ],
- self.tp_size,
- self.tp_rank,
- )
- },
- )
-
- # selective projection used to make dt, B and C input dependant
-
- # time step projection (discretization)
- # instantiate once and copy inv_dt in init_weights of PretrainedModel
- self.dt_bias = nn.Parameter(
- torch.ones(self.num_v_heads // self.tp_size),
- )
- self.A_log = nn.Parameter(
- torch.empty(
- divide(self.num_v_heads, self.tp_size),
- )
- )
-
- set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
- set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
-
- self.norm = RMSNormGated(
- self.head_v_dim,
- eps=self.layer_norm_epsilon,
- group_size=None,
- norm_before_gate=True,
- device=current_platform.current_device(),
- dtype=config.dtype,
- )
-
- self.out_proj = RowParallelLinear(
- self.value_dim,
- self.hidden_size,
- bias=False,
- input_is_parallel=True,
- quant_config=quant_config,
- prefix=f"{prefix}.out_proj",
- )
-
- compilation_config = get_current_vllm_config().compilation_config
- if prefix in compilation_config.static_forward_context:
- raise ValueError(f"Duplicate layer name: {prefix}")
- compilation_config.static_forward_context[prefix] = self
- self.chunk_gated_delta_rule = ChunkGatedDeltaRuleOp(
- output_final_state = True,
- use_qk_l2norm_in_kernel=True,
- )
-
- self.fused_recurrent_gated_delta_rule_multi_query = FusedRecurrentGatedDeltaRuleOp(
- inplace_final_state=True,
- use_qk_l2norm_in_kernel=True,
- )
- self.fused_recurrent_gated_delta_rule_remain_query = FusedRecurrentGatedDeltaRuleOp(
- inplace_final_state=True,
- use_qk_l2norm_in_kernel=True,
- )
-
- def fix_query_key_value_ordering(
- self,
- mixed_qkvz,
- mixed_ba,
- ):
- """
- Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
- """
- new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
- self.num_k_heads // self.tp_size,
- (
- self.head_k_dim
- + self.head_k_dim
- + (self.head_v_dim + self.head_v_dim)
- * self.num_v_heads
- // self.num_k_heads
- ),
- )
- new_tensor_shape_ba = mixed_qkvz.size()[:-1] + (
- self.num_k_heads // self.tp_size,
- 2 * self.num_v_heads // self.num_k_heads,
- )
-
- mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
- mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
-
- split_arg_list_qkvz = [
- self.head_k_dim,
- self.head_k_dim,
- (self.num_v_heads // self.num_k_heads * self.head_v_dim),
- (self.num_v_heads // self.num_k_heads * self.head_v_dim),
- ]
- split_arg_list_ba = [
- self.num_v_heads // self.num_k_heads,
- self.num_v_heads // self.num_k_heads,
- ]
-
- # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
- # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn],
- # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
- (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2)
- (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)
-
- # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
- value = value.reshape(value.size(0), -1, self.head_v_dim)
- z = z.reshape(z.size(0), -1, self.head_v_dim)
- b = b.reshape(b.size(0), self.num_v_heads // self.tp_size)
- a = a.reshape(a.size(0), self.num_v_heads // self.tp_size)
-
- return query, key, value, z, b, a
-
- def rearrange_mixed_qkv(self, mixed_qkv):
- if mixed_qkv is None:
- return None, None, None
- query, key, value = torch.split(
- mixed_qkv,
- [
- self.key_dim // self.tp_size,
- self.key_dim // self.tp_size,
- self.value_dim // self.tp_size,
- ],
- dim=-1,
- )
- query, key = map(
- lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim),
- (query, key),
- )
- value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
- return query, key, value
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- output: torch.Tensor,
- ):
- """
- Forward pass with three parts:
- 1. Input projection
- 2. Core attention (custom op)
- 3. Output projection
- """
- num_tokens = hidden_states.size(0)
-
- # ============================================================
- # Part 1: Input Projection
- # ============================================================
- projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
- projected_states_ba, _ = self.in_proj_ba(hidden_states)
- query, key, value, z, b, a = self.fix_query_key_value_ordering(
- projected_states_qkvz, projected_states_ba
- )
- query, key, value = map(
- lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
- )
- mixed_qkv = torch.cat((query, key, value), dim=-1)
-
- # ============================================================
- # Part 2: Core Attention (Custom Op)
- # ============================================================
- # Note: we should not use torch.empty here like other attention backends,
- # see discussions in https://github.com/vllm-project/vllm/pull/28182
- core_attn_out = torch.zeros(
- (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
- dtype=hidden_states.dtype,
- device=hidden_states.device,
- )
-
- torch.ops.vllm.gdn_attention_core(
- mixed_qkv,
- b,
- a,
- core_attn_out,
- self.prefix,
- )
-
- # ============================================================
- # Part 3: Output Projection
- # ============================================================
- z_shape_og = z.shape
- # Reshape input data into 2D tensor
- core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
- z = z.reshape(-1, z.shape[-1])
- core_attn_out = self.norm(core_attn_out, z)
- core_attn_out = core_attn_out.reshape(z_shape_og)
- core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
- output[:num_tokens], _ = self.out_proj(core_attn_out)
-
- def _forward_core(
- self,
- mixed_qkv: torch.Tensor,
- b: torch.Tensor,
- a: torch.Tensor,
- core_attn_out: torch.Tensor,
- ):
- """
- Core attention computation (called by custom op).
- """
- forward_context = get_forward_context()
- attn_metadata: AttentionMetadata = forward_context.attn_metadata
-
- if attn_metadata is None:
- # V1 profile run
- return
-
- assert isinstance(attn_metadata, dict)
- attn_metadata = attn_metadata[self.prefix]
- assert isinstance(attn_metadata, GDNAttentionMetadata)
- has_initial_state = attn_metadata.has_initial_state
- spec_query_start_loc = attn_metadata.spec_query_start_loc
- non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
- spec_sequence_masks = attn_metadata.spec_sequence_masks
- spec_token_indx = attn_metadata.spec_token_indx
- non_spec_token_indx = attn_metadata.non_spec_token_indx
- spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
- non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
- self_kv_cache = self.kv_cache[forward_context.virtual_engine]
- conv_state = self_kv_cache[0].transpose(-1, -2)
- ssm_state = self_kv_cache[1]
- num_actual_tokens = attn_metadata.num_actual_tokens
- num_accepted_tokens = attn_metadata.num_accepted_tokens
-
- mixed_qkv = mixed_qkv[:num_actual_tokens]
- b = b[:num_actual_tokens]
- a = a[:num_actual_tokens]
-
- # 1. Convolution sequence transformation
- conv_weights = self.conv1d.weight.view(
- self.conv1d.weight.size(0), self.conv1d.weight.size(2)
- )
-
- if spec_sequence_masks is not None:
- if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
- mixed_qkv_spec = mixed_qkv
- mixed_qkv_non_spec = None
- else:
- mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
- mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
- else:
- mixed_qkv_spec = None
- mixed_qkv_non_spec = mixed_qkv
-
- # 1.1: Process the multi-query part
- if spec_sequence_masks is not None:
- mixed_qkv_spec = causal_conv1d_update(
- mixed_qkv_spec,
- conv_state,
- conv_weights,
- self.conv1d.bias,
- self.activation,
- conv_state_indices=spec_state_indices_tensor[:, 0][
- : attn_metadata.num_spec_decodes
- ],
- num_accepted_tokens=num_accepted_tokens,
- query_start_loc=spec_query_start_loc,
- max_query_len=spec_state_indices_tensor.size(-1),
- validate_data=False,
- )
-
- # 1.2: Process the remaining part
- if attn_metadata.num_prefills > 0:
- mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
- # - "cache_indices" updates the conv_state cache in positions
- # pointed to by "state_indices_tensor"
- mixed_qkv_non_spec = causal_conv1d_fn(
- mixed_qkv_non_spec_T,
- conv_weights,
- self.conv1d.bias,
- activation=self.activation,
- conv_states=conv_state,
- has_initial_state=has_initial_state,
- cache_indices=non_spec_state_indices_tensor,
- query_start_loc=non_spec_query_start_loc,
- metadata=attn_metadata,
- ).transpose(0, 1)
- elif attn_metadata.num_decodes > 0:
- mixed_qkv_non_spec = causal_conv1d_update(
- mixed_qkv_non_spec,
- conv_state,
- conv_weights,
- self.conv1d.bias,
- self.activation,
- conv_state_indices=non_spec_state_indices_tensor[
- : attn_metadata.num_actual_tokens
- ],
- validate_data=True,
- )
- else:
- mixed_qkv_non_spec = None
-
- query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
- query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
- mixed_qkv_non_spec
- )
-
- g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
-
- if spec_sequence_masks is not None:
- if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
- g_spec = g
- beta_spec = beta
- g_non_spec = None
- beta_non_spec = None
- else:
- g_spec = g.index_select(1, spec_token_indx)
- beta_spec = beta.index_select(1, spec_token_indx)
- g_non_spec = g.index_select(1, non_spec_token_indx)
- beta_non_spec = beta.index_select(1, non_spec_token_indx)
- else:
- g_spec = None
- beta_spec = None
- g_non_spec = g
- beta_non_spec = beta
-
- # 2. Recurrent attention
-
- # 2.1: Process the multi-query part
- if spec_sequence_masks is not None:
- core_attn_out_spec, last_recurrent_state = self.fused_recurrent_gated_delta_rule_multi_query(
- q=query_spec,
- k=key_spec,
- v=value_spec,
- g=g_spec,
- beta=beta_spec,
- initial_state=ssm_state,
- cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
- ssm_state_indices=spec_state_indices_tensor,
- num_accepted_tokens=num_accepted_tokens,
- )
- else:
- core_attn_out_spec, last_recurrent_state = None, None
-
- # 2.2: Process the remaining part
- if attn_metadata.num_prefills > 0:
- initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
- initial_state[~has_initial_state, ...] = 0
- (
- core_attn_out_non_spec,
- last_recurrent_state,
- ) = self.chunk_gated_delta_rule(
- q=query_non_spec.contiguous(),
- k=key_non_spec.contiguous(),
- v=value_non_spec.contiguous(),
- g=g_non_spec,
- beta=beta_non_spec,
- initial_state=initial_state,
- cu_seqlens=non_spec_query_start_loc,
- )
- # Init cache
- ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
- ssm_state.dtype
- )
- elif attn_metadata.num_decodes > 0:
- core_attn_out_non_spec, last_recurrent_state = (
- self.fused_recurrent_gated_delta_rule_remain_query(
- q=query_non_spec,
- k=key_non_spec,
- v=value_non_spec,
- g=g_non_spec,
- beta=beta_non_spec,
- initial_state=ssm_state,
- cu_seqlens=non_spec_query_start_loc[
- : attn_metadata.num_decodes + 1
- ],
- ssm_state_indices=non_spec_state_indices_tensor,
- )
- )
- else:
- core_attn_out_non_spec, last_recurrent_state = None, None
-
- # 3. Merge core attention output
- if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
- merged_out = torch.empty(
- (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
- dtype=core_attn_out_non_spec.dtype,
- device=core_attn_out_non_spec.device,
- )
- merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
- merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
- core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
- elif spec_sequence_masks is not None:
- core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
- else:
- core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
-
-
-class Qwen3NextAttention(nn.Module):
- def __init__(
- self,
- config: Qwen3NextConfig,
- model_config: ModelConfig | None = None,
- cache_config: CacheConfig | None = None,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- tp_size = get_tensor_model_parallel_world_size()
- self.total_num_heads = config.num_attention_heads
- assert self.total_num_heads % tp_size == 0
- self.num_heads = self.total_num_heads // tp_size
- self.total_num_kv_heads = config.num_key_value_heads
- if self.total_num_kv_heads >= tp_size:
- # Number of KV heads is greater than TP size, so we partition
- # the KV heads across multiple tensor parallel GPUs.
- assert self.total_num_kv_heads % tp_size == 0
- else:
- # Number of KV heads is less than TP size, so we replicate
- # the KV heads across multiple tensor parallel GPUs.
- assert tp_size % self.total_num_kv_heads == 0
- self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
- self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
- self.q_size = self.num_heads * self.head_dim
- self.kv_size = self.num_kv_heads * self.head_dim
- self.scaling = self.head_dim**-0.5
- self.dual_chunk_attention_config = getattr(
- config, "dual_chunk_attention_config", None
- )
- self.attn_output_gate = getattr(config, "attn_output_gate", True)
-
- self.qkv_proj = QKVParallelLinear(
- config.hidden_size,
- self.head_dim,
- self.total_num_heads * (1 + self.attn_output_gate),
- self.total_num_kv_heads,
- bias=getattr(config, "qkv_bias", False),
- quant_config=quant_config,
- prefix=f"{prefix}.qkv_proj",
- )
-
- self.o_proj = RowParallelLinear(
- self.total_num_heads * self.head_dim,
- config.hidden_size,
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.o_proj",
- )
-
- self.rotary_emb = get_rope(
- head_size=self.head_dim,
- max_position=config.max_position_embeddings,
- rope_parameters=config.rope_parameters,
- dual_chunk_attention_config=self.dual_chunk_attention_config,
- )
-
- self.attn = Attention(
- self.num_heads,
- self.head_dim,
- self.scaling,
- num_kv_heads=self.num_kv_heads,
- cache_config=cache_config,
- quant_config=quant_config,
- prefix=f"{prefix}.attn",
- **{
- "layer_idx": extract_layer_index(prefix),
- "dual_chunk_attention_config": self.dual_chunk_attention_config,
- }
- if self.dual_chunk_attention_config
- else {},
- )
-
- self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
- self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
- def forward(
- self,
- positions: torch.Tensor,
- output: torch.Tensor,
- hidden_states: torch.Tensor,
- ):
- qkv, _ = self.qkv_proj(hidden_states)
-
- if self.attn_output_gate:
- q_gate, k, v = qkv.split(
- [self.q_size * 2, self.kv_size, self.kv_size], dim=-1
- )
- orig_shape = q_gate.shape[:-1]
- q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
- q, gate = torch.chunk(q_gate, 2, dim=-1)
- q = q.reshape(*orig_shape, -1)
- gate = gate.reshape(*orig_shape, -1)
- else:
- q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-
- q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
- -1, self.num_heads * self.head_dim
- )
- k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
- -1, self.num_kv_heads * self.head_dim
- )
-
- q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v)
-
- if self.attn_output_gate:
- gate = torch.sigmoid(gate)
- attn_output = attn_output * gate
-
- output[:], _ = self.o_proj(attn_output)
-
-
-class Qwen3NextDecoderLayer(nn.Module):
- def __init__(
- self,
- vllm_config: VllmConfig,
- layer_type: str,
- prefix: str = "",
- ) -> None:
- super().__init__()
-
- config = vllm_config.model_config.hf_config
- model_config = vllm_config.model_config
- cache_config = vllm_config.cache_config
- quant_config = vllm_config.quant_config
- speculative_config = vllm_config.speculative_config
-
- self.layer_type = layer_type
- self.layer_idx = extract_layer_index(prefix)
-
- if self.layer_type == "linear_attention":
- self.linear_attn = Qwen3NextGatedDeltaNet(
- config,
- model_config=model_config,
- cache_config=cache_config,
- quant_config=quant_config,
- speculative_config=speculative_config,
- prefix=f"{prefix}.linear_attn",
- )
- elif self.layer_type == "full_attention":
- self.self_attn = Qwen3NextAttention(
- config,
- model_config=model_config,
- cache_config=cache_config,
- quant_config=quant_config,
- prefix=f"{prefix}.self_attn",
- )
- else:
- raise ValueError(f"Invalid layer_type {self.layer_type}")
-
- mlp_only_layers = (
- [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
- )
- if (self.layer_idx not in mlp_only_layers) and (
- config.num_experts > 0
- and (self.layer_idx + 1) % config.decoder_sparse_step == 0
- ):
- self.mlp = Qwen3NextSparseMoeBlock(
- vllm_config=vllm_config,
- prefix=f"{prefix}.mlp",
- )
- else:
- self.mlp = Qwen3NextMLP(
- hidden_size=config.hidden_size,
- intermediate_size=config.intermediate_size,
- hidden_act=config.hidden_act,
- quant_config=quant_config,
- )
-
- self.input_layernorm = Qwen3NextRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- )
- self.post_attention_layernorm = Qwen3NextRMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- )
-
- self.layer_scale = getattr(config, "layer_scale", False)
- if self.layer_scale:
- self.attn_layer_scale = torch.nn.Parameter(
- torch.zeros(
- 1,
- 1,
- config.hidden_size,
- dtype=config.dtype,
- ),
- )
- self.ffn_layer_scale = torch.nn.Parameter(
- torch.zeros(
- 1,
- 1,
- config.hidden_size,
- dtype=config.dtype,
- ),
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- residual: torch.Tensor | None,
- positions: torch.Tensor = None,
- **kwargs: object,
- ):
- if residual is None:
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- else:
- hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
- self_attention_output = torch.empty_like(hidden_states)
- if self.layer_type == "linear_attention":
- self.linear_attn(
- hidden_states=hidden_states,
- output=self_attention_output,
- )
- elif self.layer_type == "full_attention":
- self.self_attn(
- hidden_states=hidden_states,
- output=self_attention_output,
- positions=positions,
- )
- else:
- raise ValueError("Invalid layer_type")
- hidden_states = self_attention_output
-
- if self.layer_scale:
- if len(hidden_states.shape) == 2:
- hidden_states = hidden_states * (
- self.attn_layer_scale.to(hidden_states.dtype)[0] + 1
- )
- else:
- hidden_states = hidden_states * (
- self.attn_layer_scale.to(hidden_states.dtype) + 1
- )
-
- # Fully Connected
- hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-
- hidden_states = self.mlp(hidden_states)
-
- if self.layer_scale:
- if len(hidden_states.shape) == 2:
- hidden_states = hidden_states * (
- self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1
- )
- else:
- assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), (
- f"shape must be the same {len(hidden_states.shape)}, "
- f"{len(self.ffn_layer_scale.shape)}"
- )
- hidden_states = hidden_states * (
- self.ffn_layer_scale.to(hidden_states.dtype) + 1
- )
-
- return hidden_states, residual
-
-
-@support_torch_compile
-class Qwen3NextModel(nn.Module):
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
-
- config: Qwen3NextConfig = vllm_config.model_config.hf_config
- parallel_config = vllm_config.parallel_config
-
- eplb_config = parallel_config.eplb_config
- self.num_redundant_experts = eplb_config.num_redundant_experts
-
- self.config = config
-
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = VocabParallelEmbedding(
- self.vocab_size,
- config.hidden_size,
- )
-
- def get_layer(prefix: str):
- return Qwen3NextDecoderLayer(
- vllm_config,
- layer_type=config.layer_types[extract_layer_index(prefix)],
- prefix=prefix,
- )
-
- self.start_layer, self.end_layer, self.layers = make_layers(
- config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
- )
- self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
- ["hidden_states", "residual"], config.hidden_size
- )
-
- if get_pp_group().is_last_rank:
- self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- else:
- self.norm = PPMissingLayer()
-
- def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
- return self.embed_tokens(input_ids)
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- ) -> torch.Tensor:
- if get_pp_group().is_first_rank:
- if inputs_embeds is not None:
- hidden_states = inputs_embeds
- else:
- hidden_states = self.embed_input_ids(input_ids)
- residual = None
- else:
- assert intermediate_tensors is not None
- hidden_states = intermediate_tensors["hidden_states"]
- residual = intermediate_tensors["residual"]
-
- for layer in islice(self.layers, self.start_layer, self.end_layer):
- hidden_states, residual = layer(
- positions=positions,
- hidden_states=hidden_states,
- residual=residual,
- )
-
- if not get_pp_group().is_last_rank:
- return IntermediateTensors(
- {"hidden_states": hidden_states, "residual": residual}
- )
- hidden_states, _ = self.norm(hidden_states, residual)
- return hidden_states
-
- def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
- # Params for weights, fp8 weight scales, fp8 activation scales
- # (param_name, weight_name, expert_id, shard_id)
- return SharedFusedMoE.make_expert_params_mapping(
- ckpt_gate_proj_name="gate_proj",
- ckpt_down_proj_name="down_proj",
- ckpt_up_proj_name="up_proj",
- num_experts=self.config.num_experts,
- num_redundant_experts=self.num_redundant_experts,
- )
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- stacked_params_mapping = [
- # (param_name, shard_name, shard_id)
- ("qkv_proj", "q_proj", "q"),
- ("qkv_proj", "k_proj", "k"),
- ("qkv_proj", "v_proj", "v"),
- ("gate_up_proj", "gate_proj", 0),
- ("gate_up_proj", "up_proj", 1),
- ]
-
- params_dict = dict(self.named_parameters())
- loaded_params: set[str] = set()
- expert_params_mapping = self.get_expert_mapping()
- for name, loaded_weight in weights:
- if "rotary_emb.inv_freq" in name:
- continue
-
- if name.startswith("mtp."):
- continue
-
- for param_name, weight_name, shard_id in stacked_params_mapping:
- if weight_name not in name:
- continue
-
- if "mlp.experts" in name:
- continue
-
- name = name.replace(weight_name, param_name)
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
- # Skip layers on other devices.
- if is_pp_missing_parameter(name, self):
- continue
- # name = apply_attn_prefix(name, params_dict)
- if name not in params_dict:
- continue
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- for mapping in expert_params_mapping:
- param_name, weight_name, expert_id, shard_id = mapping
- if weight_name not in name:
- continue
- name = name.replace(weight_name, param_name)
- # Skip layers on other devices.
- if is_pp_missing_parameter(name, self):
- continue
- # Skip loading extra bias for GPTQ models.
- if (
- name.endswith(".bias") or name.endswith("_bias")
- ) and name not in params_dict:
- continue
- if name not in params_dict:
- continue
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(
- param,
- loaded_weight,
- name,
- shard_id=shard_id,
- expert_id=expert_id,
- )
- break
- else:
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
- if is_pp_missing_parameter(name, self):
- continue
- if name not in params_dict:
- logger.warning_once(
- f"Parameter {name} not found in params_dict, skip loading"
- )
- continue
- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
- return loaded_params
-
-
-class QwenNextMixtureOfExperts(MixtureOfExperts):
- def update_physical_experts_metadata(
- self,
- num_physical_experts: int,
- num_local_physical_experts: int,
- ) -> None:
- assert self.num_local_physical_experts == num_local_physical_experts
- self.num_physical_experts = num_physical_experts
- self.num_local_physical_experts = num_local_physical_experts
- self.num_redundant_experts = num_physical_experts - self.num_logical_experts
- for layer in self.model.layers:
- if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
- moe = layer.mlp
- moe.n_local_physical_experts = num_local_physical_experts
- moe.n_physical_experts = num_physical_experts
- moe.n_redundant_experts = self.num_redundant_experts
- moe.experts.update_expert_map()
-
- def set_moe_parameters(self):
- self.expert_weights = []
-
- self.moe_layers = []
- example_moe = None
- for layer in self.model.layers:
- if isinstance(layer, Qwen3NextDecoderLayer) and isinstance(
- layer.mlp, Qwen3NextSparseMoeBlock
- ):
- example_moe = layer.mlp
- self.moe_layers.append(layer.mlp.experts)
-
- if example_moe is None:
- raise RuntimeError("No Qwen3Next layer found in the model.layers.")
-
- # Set MoE hyperparameters
- self.num_moe_layers = len(self.moe_layers)
- self.num_expert_groups = 1
- self.num_shared_experts = 0
- self.num_logical_experts = example_moe.n_logical_experts
- self.num_physical_experts = example_moe.n_physical_experts
- self.num_local_physical_experts = example_moe.n_local_physical_experts
- self.num_routed_experts = example_moe.n_routed_experts
- self.num_redundant_experts = example_moe.n_redundant_experts
-
-
-class Qwen3NextForCausalLM(
- nn.Module,
- HasInnerState,
- SupportsLoRA,
- SupportsPP,
- QwenNextMixtureOfExperts,
- IsHybrid,
-):
- packed_modules_mapping = {
- "qkv_proj": [
- "q_proj",
- "k_proj",
- "v_proj",
- ],
- "gate_up_proj": ["gate_proj", "up_proj"],
- }
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- config = vllm_config.model_config.hf_config
- self.vllm_config = vllm_config
- self.model_config = vllm_config.model_config
- cache_config = vllm_config.cache_config
-
- scheduler_config = vllm_config.scheduler_config
- assert not cache_config.enable_prefix_caching, (
- "Qwen3Next currently does not support prefix caching"
- )
- self.quant_config = vllm_config.quant_config
-
- super().__init__()
- self.config = config
- self.scheduler_config = scheduler_config
- self.model = Qwen3NextModel(
- vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
- )
-
- self.lm_head = ParallelLMHead(
- config.vocab_size,
- config.hidden_size,
- prefix=maybe_prefix(prefix, "lm_head"),
- )
- self.logits_processor = LogitsProcessor(config.vocab_size)
- self.make_empty_intermediate_tensors = (
- self.model.make_empty_intermediate_tensors
- )
-
- # Set MoE hyperparameters
- self.set_moe_parameters()
-
- def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
- return self.model.embed_input_ids(input_ids)
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs: object,
- ):
- hidden_states = self.model(
- input_ids, positions, intermediate_tensors, inputs_embeds
- )
-
- return hidden_states
-
- @classmethod
- def get_mamba_state_dtype_from_config(
- cls,
- vllm_config: "VllmConfig",
- ) -> tuple[torch.dtype, torch.dtype]:
- return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
- vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
- )
-
- @classmethod
- def get_mamba_state_shape_from_config(
- cls, vllm_config: "VllmConfig"
- ) -> tuple[tuple[int, int], tuple[int, int]]:
- parallel_config = vllm_config.parallel_config
- hf_config = vllm_config.model_config.hf_config
- tp_size = parallel_config.tensor_parallel_size
- num_spec = (
- vllm_config.speculative_config.num_speculative_tokens
- if vllm_config.speculative_config
- else 0
- )
- return MambaStateShapeCalculator.gated_delta_net_state_shape(
- tp_size,
- hf_config.linear_num_key_heads,
- hf_config.linear_num_value_heads,
- hf_config.linear_key_head_dim,
- hf_config.linear_value_head_dim,
- hf_config.linear_conv_kernel_dim,
- num_spec,
- )
-
- def compute_logits(
- self,
- hidden_states: torch.Tensor,
- ) -> torch.Tensor | None:
- return self.logits_processor(self.lm_head, hidden_states)
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(
- self,
- skip_prefixes=["mtp."],
- )
- return loader.load_weights(weights)
-
- def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
- return self.model.get_expert_mapping()
-
-
-def gdn_attention_core(
- mixed_qkv: torch.Tensor,
- b: torch.Tensor,
- a: torch.Tensor,
- core_attn_out: torch.Tensor,
- layer_name: str,
-) -> None:
- """
- Custom op for the core attention computation.
- Only handles the convolution + recurrent attention part.
- Input/output projections are handled outside this op.
- """
- forward_context: ForwardContext = get_forward_context()
- self = forward_context.no_compile_layers[layer_name]
- self._forward_core(
- mixed_qkv=mixed_qkv,
- b=b,
- a=a,
- core_attn_out=core_attn_out,
- )
-
-
-def gdn_attention_core_fake(
- mixed_qkv: torch.Tensor,
- b: torch.Tensor,
- a: torch.Tensor,
- core_attn_out: torch.Tensor,
- layer_name: str,
-) -> None:
- """Fake implementation for torch.compile."""
- return
-
-
-if not hasattr(torch.ops.vllm, "gdn_attention_core"):
- direct_register_custom_op(
- op_name="gdn_attention_core",
- op_func=gdn_attention_core,
- mutates_args=["core_attn_out"],
- fake_impl=gdn_attention_core_fake,
- )
-
-
-@triton.jit
-def fused_gdn_gating_kernel(
- g,
- beta_output,
- A_log,
- a,
- b,
- dt_bias,
- seq_len,
- NUM_HEADS: tl.constexpr,
- beta: tl.constexpr,
- threshold: tl.constexpr,
- BLK_HEADS: tl.constexpr,
-):
- i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
- head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
- off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
- mask = head_off < NUM_HEADS
- blk_A_log = tl.load(A_log + head_off, mask=mask)
- blk_a = tl.load(a + off, mask=mask)
- blk_b = tl.load(b + off, mask=mask)
- blk_bias = tl.load(dt_bias + head_off, mask=mask)
- # If the model is loaded in fp16, without the .float() here, A might be -inf
- x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
- softplus_x = tl.where(
- beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
- )
- blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
- tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
- # compute beta_output = sigmoid(b)
- blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
- tl.store(
- beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask
- )
-
-
-def fused_gdn_gating(
- A_log: torch.Tensor,
- a: torch.Tensor,
- b: torch.Tensor,
- dt_bias: torch.Tensor,
- beta: float = 1.0,
- threshold: float = 20.0,
-) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Fused computation of g and beta for Gated Delta Net.
- g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
- beta_output = b.sigmoid()
- TODO maybe use torch.compile to replace this triton kernel
- """
- batch, num_heads = a.shape
- seq_len = 1
- grid = (batch, seq_len, triton.cdiv(num_heads, 8))
- g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device)
- beta_output = torch.empty(1, batch, num_heads, dtype=b.dtype, device=b.device)
- fused_gdn_gating_kernel[grid](
- g,
- beta_output,
- A_log,
- a,
- b,
- dt_bias,
- seq_len,
- num_heads,
- beta,
- threshold,
- 8,
- num_warps=1,
- )
- return g, beta_output
diff --git a/vllm_fl/ops/fused_moe/fused_moe.py b/vllm_fl/ops/fused_moe/fused_moe.py
index b35dc3c5..72465aa9 100644
--- a/vllm_fl/ops/fused_moe/fused_moe.py
+++ b/vllm_fl/ops/fused_moe/fused_moe.py
@@ -8,7 +8,6 @@
import functools
import torch
import torch.nn.functional as F
-import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
@@ -17,7 +16,7 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
_get_config_quant_dtype,
try_get_optimal_moe_config,
- invoke_fused_moe_kernel,
+ dispatch_fused_moe_kernel,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.triton_utils import tl
@@ -109,7 +108,9 @@ def fused_experts_impl(
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
- CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+ # Note: upstream removed VLLM_FUSED_MOE_CHUNK_SIZE in v0.18.1;
+ # keep the old default (64K) for the FL chunked implementation.
+ CHUNK_SIZE = 65536
M = min(num_tokens, CHUNK_SIZE)
config_dtype = _get_config_dtype_str(
@@ -209,7 +210,7 @@ def fused_experts_impl(
ignore_invalid_experts=True,
)
- invoke_fused_moe_kernel(
+ dispatch_fused_moe_kernel(
qcurr_hidden_states,
w1,
intermediate_cache1,
@@ -235,6 +236,9 @@ def fused_experts_impl(
# Activation function with multiplication
# todo: dispatch to flag_gems and other backends
+ # v0.18.1: activation may be a MoEActivation enum; normalize to str
+ if hasattr(activation, "value"):
+ activation = activation.value
if activation == "silu":
intermediate_cache2 = call_op(
"silu_and_mul", None, intermediate_cache1.view(-1, N)
@@ -262,7 +266,7 @@ def fused_experts_impl(
block_shape=block_shape,
)
- invoke_fused_moe_kernel(
+ dispatch_fused_moe_kernel(
qintermediate_cache2,
w2,
intermediate_cache3,
diff --git a/vllm_fl/ops/fused_moe/layer.py b/vllm_fl/ops/fused_moe/layer.py
index 3d3368b3..740dcd59 100644
--- a/vllm_fl/ops/fused_moe/layer.py
+++ b/vllm_fl/ops/fused_moe/layer.py
@@ -11,8 +11,8 @@
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
-from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
-from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
+from vllm.model_executor.layers.fused_moe.router.routing_simulator_router import RoutingSimulator
+from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import grouped_topk
from vllm.platforms import current_platform
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
@@ -21,7 +21,7 @@
if current_platform.is_cuda_alike():
- from vllm.model_executor.layers.fused_moe.fused_moe import (
+ from vllm.model_executor.layers.fused_moe.router.base_router import (
eplb_map_to_physical_and_record,
)
else:
@@ -76,7 +76,30 @@ def forward_oot(
else:
return result
- forward_native = forward_oot
+ def forward_native(
+ self,
+ layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
+ x: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ shared_experts_input: torch.Tensor | None,
+ ) -> torch.Tensor:
+ """v0.18.1 forward_native signature — called when custom ops are
+ disabled (e.g. under torch.compile / Inductor).
+ Note: shared experts are handled by the upstream runner
+ (_apply_quant_method), so we must NOT return a tuple here."""
+ return fused_experts(
+ hidden_states=x,
+ w1=layer.w13_weight,
+ w2=layer.w2_weight,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ activation=layer.activation,
+ quant_config=self.moe_quant_config,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
+ global_num_experts=layer.global_num_experts,
+ expert_map=layer.expert_map,
+ )
class FusedMoEFL(FusedMoE):
@@ -127,7 +150,7 @@ def select_experts(
plain MoE implementations without redundant experts.
"""
from vllm_fl.ops.fused_moe.fused_moe import fused_topk
- from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk_bias
+ from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import fused_topk_bias
if self.enable_eplb:
if self.quant_method.supports_eplb:
diff --git a/vllm_fl/patches/glm_moe_dsa.py b/vllm_fl/patches/glm_moe_dsa.py
index 4d205537..f4a7dd67 100644
--- a/vllm_fl/patches/glm_moe_dsa.py
+++ b/vllm_fl/patches/glm_moe_dsa.py
@@ -52,52 +52,6 @@ def _patched_set(self, special_tokens=None):
pass
-def patch_is_deepseek_mla():
- """Patch ModelConfig.is_deepseek_mla to recognise glm_moe_dsa as MLA."""
- from vllm.config.model import ModelConfig
- _orig_is_mla = ModelConfig.is_deepseek_mla.fget
-
- @property
- def _patched_is_mla(self):
- if (
- hasattr(self.hf_text_config, "model_type")
- and self.hf_text_config.model_type == "glm_moe_dsa"
- and getattr(self.hf_text_config, "kv_lora_rank", None)
- is not None
- ):
- return True
- return _orig_is_mla(self)
-
- ModelConfig.is_deepseek_mla = _patched_is_mla
-
-
-def patch_fp8_mqa_logits_dim():
- """Fix k_scale dim mismatch for deep_gemm fp8_mqa_logits.
-
- vLLM 0.13.0 passes k_scale as [N, 1] but deep_gemm 2.3.0 expects [N].
- Upstream fix: https://github.com/vllm-project/vllm/pull/32652
- We wrap fp8_mqa_logits to flatten k_scale before calling the native impl.
- """
- import vllm.utils.deep_gemm as dg_mod
-
- dg_mod._lazy_init()
- _orig_impl = dg_mod._fp8_mqa_logits_impl
- if _orig_impl is None:
- return
-
- def _fixed_fp8_mqa_logits(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke):
- k_fp8, k_scale = kv
- return _orig_impl(
- q, (k_fp8, k_scale.flatten()), weights,
- cu_seqlen_ks, cu_seqlen_ke,
- )
-
- dg_mod._fp8_mqa_logits_impl = _fixed_fp8_mqa_logits
- dg_mod._lazy_init = lambda: None
- logger.info("Patched fp8_mqa_logits: flatten k_scale [N,1] -> [N] "
- "for deep_gemm 2.3.0 compat")
-
-
def patch_indexer_schedule_metadata():
"""Fix schedule_metadata not computed when VLLM_USE_DEEP_GEMM=0.
@@ -144,8 +98,6 @@ def _patched_build(self, common_prefix_len, common_attn_metadata,
def apply_platform_patches():
"""All GLM-5 patches needed at platform registration time."""
patch_tokenizer_compat()
- patch_fp8_mqa_logits_dim()
-
def patch_indexer_rope_reshape():
"""Fix RoPE output shape in Indexer.forward for DSA models.
@@ -222,6 +174,5 @@ def _patched_forward(self, hidden_states, qr, positions, rotary_emb):
def apply_model_patches():
"""All GLM-5 patches needed at model registration time."""
- patch_is_deepseek_mla()
patch_indexer_schedule_metadata()
patch_indexer_rope_reshape()
diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py
index 6a76600e..203e7a96 100644
--- a/vllm_fl/platform.py
+++ b/vllm_fl/platform.py
@@ -1,11 +1,11 @@
# Copyright (c) 2025 BAAI. All rights reserved.
-# Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/platforms/cuda.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/platforms/cuda.py
# Below is the original copyright:
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
-from typing import TYPE_CHECKING, Optional, TypeVar
+from typing import TYPE_CHECKING, TypeVar
from typing_extensions import ParamSpec
import torch
@@ -16,15 +16,15 @@
except (ImportError, OSError):
pass # NPU or other platforms may not have vllm._C
-from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.platforms import Platform, PlatformEnum
from vllm.platforms.interface import DeviceCapability
+from vllm.v1.attention.backends.registry import AttentionBackendEnum
if TYPE_CHECKING:
- from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
+ from vllm.v1.attention.selector import AttentionSelectorConfig
else:
VllmConfig = None
CacheDType = None
@@ -48,6 +48,12 @@ class PlatformFL(Platform):
vendor_name = device_info.vendor_name
device_type = get_device_type(vendor_name)
device_name = get_device_name(vendor_name)
+ # cuda_alike (nvidia/metax): device_name = vendor_name (not used in torch.device)
+ # non-cuda_alike (iluvatar/ascend): device_name = device_type (used in torch.device)
+ device_name = device_info.vendor_name if (
+ device_info.device_type == "cuda" and device_info.vendor_name != "iluvatar"
+ ) else device_info.device_type
+ device_type = device_info.device_type
dispatch_key = device_info.dispatch_key
torch_device_fn = device_info.torch_device_fn
ray_device_key: str = "GPU"
@@ -88,7 +94,7 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
@classmethod
def get_current_memory_usage(
- cls, device: Optional[torch.types.Device] = None
+ cls, device: torch.types.Device | None = None
) -> float:
cls.torch_device_fn.empty_cache()
cls.torch_device_fn.reset_peak_memory_stats(device)
@@ -120,6 +126,9 @@ def is_pin_memory_available(cls):
def import_kernels(cls) -> None:
"""Import device-specific kernels."""
logger.info(f"current vendor_name is: {cls.vendor_name}")
+ # Always load base vLLM C extensions
+ super().import_kernels()
+
if cls.vendor_name == "metax":
try:
import mcoplib._C # noqa: F401
@@ -209,9 +218,10 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
@classmethod
def get_attn_backend_cls(
cls,
- selected_backend: "AttentionBackendEnum",
+ selected_backend: "AttentionBackendEnum | None",
attn_selector_config: "AttentionSelectorConfig",
- ) -> list[str]:
+ num_heads: int | None = None,
+ ) -> str:
"""Get the attention backend class path using the dispatch mechanism."""
from vllm_fl.dispatch import call_op
@@ -245,8 +255,8 @@ def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
- backend: Optional["AttentionBackendEnum"] = None,
- ) -> list[str]:
+ backend: "AttentionBackendEnum | None" = None,
+ ) -> "AttentionBackendEnum":
from vllm_fl.attention.utils import patch_mm_encoder_attention
patch_mm_encoder_attention()
@@ -335,10 +345,31 @@ def use_custom_allreduce(cls) -> bool:
return True
@classmethod
- def pre_register_and_update(cls, parser = None) -> None:
+ def pre_register_and_update(cls, parser=None) -> None:
if cls.device_name == "npu":
import vllm_fl.dispatch.backends.vendor.ascend
+ def supports_fp8(cls) -> bool:
+ if cls.vendor_name == "nvidia":
+ return True
+ return False
+
+ @classmethod
+ def get_device_total_memory(cls, device_id: int = 0) -> int:
+ return cls.torch_device_fn.get_device_properties(
+ device_id
+ ).total_memory
+
+ @classmethod
+ def use_custom_op_collectives(cls) -> bool:
+ if cls.vendor_name == "nvidia":
+ return True
+ return False
+
+ @classmethod
+ def num_compute_units(cls, device_id: int = 0) -> int:
+ return cls.torch_device_fn.get_device_properties(device_id).multi_processor_count
+
@classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
diff --git a/vllm_fl/utils.py b/vllm_fl/utils.py
index bb2b2d38..73a68b93 100644
--- a/vllm_fl/utils.py
+++ b/vllm_fl/utils.py
@@ -10,7 +10,6 @@
_OP_CONFIG: Optional[dict[str, str]] = None
-
# Mapping used by dispatch registration to resolve the current runtime platform
# into a backend directory under dispatch/backends/vendor.
#
diff --git a/vllm_fl/worker/model_runner.py b/vllm_fl/worker/model_runner.py
index 155c0316..e5f552b7 100644
--- a/vllm_fl/worker/model_runner.py
+++ b/vllm_fl/worker/model_runner.py
@@ -1,19 +1,20 @@
-# Mainly adopted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/v1/worker/gpu_model_runner.py
+# Copyright (c) 2025 BAAI. All rights reserved.
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/v1/worker/gpu_model_runner.py
# Below is the original copyright:
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import os
import functools
import gc
import itertools
+import threading
import time
from collections import defaultdict
-from collections.abc import Iterator, Sequence
-from contextlib import contextmanager, nullcontext
+from collections.abc import Callable, Iterable, Iterator, Sequence
+from contextlib import contextmanager
from copy import copy, deepcopy
+from dataclasses import dataclass, replace
from functools import reduce
-from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
import numpy as np
@@ -23,22 +24,18 @@
from tqdm import tqdm
import vllm.envs as envs
-from vllm.attention.backends.abstract import(
- AttentionBackend,
- AttentionMetadata,
- AttentionType,
- MultipleOf)
-from vllm.attention.layer import Attention, MLAAttention
from vllm.compilation.counter import compilation_counter
-from vllm.compilation.cuda_graph import CUDAGraphStat #CUDAGraphWrapper
+from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
CompilationMode,
CUDAGraphMode,
VllmConfig,
get_layers_from_vllm_config,
+ set_current_vllm_config,
update_config,
)
+from vllm.config.cache import CacheConfig
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
@@ -47,22 +44,32 @@
get_dcp_group,
get_pp_group,
get_tp_group,
+ graph_capture,
is_global_first_rank,
prepare_communication_buffer_for_model,
- GraphCaptureContext
)
from vllm.forward_context import (
BatchDescriptor,
set_forward_context,
)
from vllm.logger import init_logger
+from vllm.lora.layers import LoRAMapping, LoRAMappingType
+from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
+ RoutedExpertsCapturer,
+)
from vllm.model_executor.layers.rotary_embedding import (
MRotaryEmbedding,
XDRotaryEmbedding,
)
-from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.model_loader.reload import (
+ finalize_layerwise_reload,
+ initialize_layerwise_reload,
+)
from vllm.model_executor.models.interfaces import (
+ MultiModalEmbeddings,
SupportsMRoPE,
SupportsMultiModal,
SupportsXDRoPE,
@@ -70,6 +77,7 @@
supports_eagle3,
supports_mrope,
supports_multimodal_pruning,
+ supports_realtime,
supports_transcription,
supports_xdrope,
)
@@ -78,26 +86,30 @@
is_pooling_model,
is_text_generation_model,
)
+from vllm.model_executor.offloader import (
+ create_offloader,
+ get_offloader,
+ set_offloader,
+)
from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.encoder_budget import MultiModalBudget
from vllm.multimodal.inputs import (
BatchedTensorInputs,
MultiModalKwargsItem,
PlaceholderRange,
)
-from vllm.multimodal.utils import group_mm_kwargs_by_modality
+from vllm.multimodal.utils import group_and_batch_mm_kwargs
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
+from vllm.tracing import instrument
from vllm.utils import length_from_prompt_token_ids_or_embeds
-from vllm.utils.jsontree import json_map_leaves
from vllm.utils.math_utils import cdiv, round_up
-from vllm.utils.mem_constants import GiB_bytes
-from vllm.utils.mem_utils import DeviceMemoryProfiler
+from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.nvtx_pytorch_hooks import PytHooks
-from vllm.utils.platform_utils import is_pin_memory_available
-
+from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units
from vllm.platforms import current_platform
if current_platform.dist_backend == "flagcx":
@contextmanager
@@ -135,18 +147,23 @@ def graph_capture(device: torch.device):
from vllm.utils.torch_utils import (
get_dtype_size,
kv_cache_dtype_str_to_dtype,
- supports_dynamo,
)
-from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
-from vllm.v1.attention.backends.utils import (
+from vllm.v1.attention.backend import (
+ AttentionBackend,
AttentionCGSupport,
+ AttentionMetadata,
AttentionMetadataBuilder,
+ AttentionType,
CommonAttentionMetadata,
+)
+from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
+from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadataBuilder
+from vllm.v1.attention.backends.utils import (
create_fast_prefill_custom_backend,
get_dcp_local_seq_lens,
reorder_batch_to_split_decodes_and_prefills,
- split_attn_metadata,
)
+from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (
AttentionSpec,
@@ -180,16 +197,28 @@ def graph_capture(device: torch.device):
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
+from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
+from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.spec_decode.ngram_proposer_gpu import (
+ NgramProposerGPU,
+ copy_num_valid_draft_tokens,
+ update_ngram_gpu_tensors_incremental,
+ update_scheduler_for_invalid_drafts,
+)
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
-from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
+from vllm.v1.worker import mamba_utils
+from vllm.v1.worker.cp_utils import (
+ check_attention_cp_compatibility,
+ get_total_cp_world_size,
+)
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
+from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -198,22 +227,27 @@ def graph_capture(device: torch.device):
UBatchSlices,
check_ubatch_thresholds,
maybe_create_ubatch_slices,
+ split_attn_metadata,
)
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.workspace import lock_workspace
from vllm.v1.worker.utils import (
AttentionGroup,
- MultiModalBudget,
+ KVBlockZeroer,
add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache,
+ prepare_kernel_block_sizes,
sanity_check_mm_encoder_outputs,
)
if TYPE_CHECKING:
- from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
+ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+
+logger = init_logger(__name__)
+# FL-specific imports
from vllm_fl.compilation.graph import GraphWrapper
from vllm_fl.dispatch.io_common import managed_inference_mode
from vllm_fl.dispatch.io_dumper import (
@@ -221,9 +255,7 @@ def graph_capture(device: torch.device):
init_io_dump_from_env,
register_io_module_hooks,
)
-
-logger = init_logger(__name__)
-
+GraphWrapper = GraphWrapper
AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
# list when ubatching is enabled
@@ -238,7 +270,7 @@ def __init__(
sampled_token_ids: torch.Tensor,
logprobs_tensors: LogprobsTensors | None,
invalid_req_indices: list[int],
- async_output_copy_stream: current_platform.torch_device_fn.Stream | None,
+ async_output_copy_stream: current_platform.torch_device_fn.Stream,
vocab_size: int,
):
self._model_runner_output = model_runner_output
@@ -258,7 +290,7 @@ def __init__(
with current_platform.torch_device_fn.stream(async_output_copy_stream):
async_output_copy_stream.wait_stream(default_stream)
self.sampled_token_ids_cpu = self._sampled_token_ids.to(
- 'cpu', non_blocking=True
+ "cpu", non_blocking=True
)
self._logprobs_tensors_cpu = (
self._logprobs_tensors.to_cpu_nonblocking()
@@ -282,22 +314,106 @@ def get_output(self) -> ModelRunnerOutput:
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
- cu_num_tokens = None
+ logprobs_lists = None
+ if self._logprobs_tensors_cpu is not None:
+ logprobs_lists = self._logprobs_tensors_cpu.tolists()
else:
- valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
+ valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output(
self.sampled_token_ids_cpu,
self.vocab_size,
self._invalid_req_indices,
- return_cu_num_tokens=self._logprobs_tensors_cpu is not None,
+ logprobs_tensors=self._logprobs_tensors_cpu,
)
output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
- if self._logprobs_tensors_cpu:
- output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens)
+ output.logprobs = logprobs_lists
return output
+def _copy_pooler_output_to_cpu(
+ raw_pooler_output: PoolerOutput, finished_mask: list[bool]
+) -> list[torch.Tensor | None]:
+ num_reqs = len(finished_mask)
+
+ if isinstance(raw_pooler_output, torch.Tensor):
+ if raw_pooler_output.shape[0] != num_reqs:
+ raise ValueError(
+ "Pooler output batch size does not match finished mask size: "
+ f"{raw_pooler_output.shape[0]} != {num_reqs}."
+ )
+
+ num_finished = sum(finished_mask)
+ if num_finished == 0:
+ return [None] * num_reqs
+ if num_finished == num_reqs:
+ return list(raw_pooler_output.to("cpu", non_blocking=True))
+
+    # Partially finished batch: gather only the finished rows before copying.
+ finished_indices = [i for i, include in enumerate(finished_mask) if include]
+ index_tensor = torch.tensor(
+ finished_indices, device=raw_pooler_output.device, dtype=torch.long
+ )
+ finished_outputs = raw_pooler_output.index_select(0, index_tensor).to(
+ "cpu", non_blocking=True
+ )
+ partial_pooler_output: list[torch.Tensor | None] = [None] * num_reqs
+ for i, out in zip(finished_indices, finished_outputs):
+ partial_pooler_output[i] = out
+ return partial_pooler_output
+
+ assert isinstance(raw_pooler_output, list)
+ if len(raw_pooler_output) != num_reqs:
+ raise ValueError(
+ "Pooler output batch size does not match finished mask size: "
+ f"{len(raw_pooler_output)} != {num_reqs}."
+ )
+
+ pooler_output: list[torch.Tensor | None] = [None] * num_reqs
+ for i, (out, include) in enumerate(zip(raw_pooler_output, finished_mask)):
+ if include and out is not None:
+ pooler_output[i] = out.to("cpu", non_blocking=True)
+ return pooler_output
+
+
+class AsyncGPUPoolingModelRunnerOutput(AsyncModelRunnerOutput):
+ def __init__(
+ self,
+ model_runner_output: ModelRunnerOutput,
+ raw_pooler_output: PoolerOutput,
+ finished_mask: list[bool],
+ async_output_copy_stream: current_platform.torch_device_fn.Stream,
+ ):
+ self._model_runner_output = model_runner_output
+
+ # Event on the copy stream so we can synchronize the non-blocking copy.
+ self.async_copy_ready_event = torch.Event()
+
+ # Keep a reference to the device tensors to avoid them being
+ # deallocated until we finish copying it to the host.
+ self._raw_pooler_output = raw_pooler_output
+
+ # Initiate the copy on a separate stream, but do not synchronize it.
+ default_stream = current_platform.torch_device_fn.current_stream()
+ with current_platform.torch_device_fn.stream(async_output_copy_stream):
+ async_output_copy_stream.wait_stream(default_stream)
+ self._model_runner_output.pooler_output = _copy_pooler_output_to_cpu(
+ raw_pooler_output=self._raw_pooler_output,
+ finished_mask=finished_mask,
+ )
+ self.async_copy_ready_event.record()
+
+ def get_output(self) -> ModelRunnerOutput:
+        """Wait for the async device-to-host copy and return a ModelRunnerOutput.
+        The copy was initiated in __init__; this blocks until it has finished.
+ """
+ self.async_copy_ready_event.synchronize()
+
+ # Release the device tensors once the copy has completed.
+ del self._raw_pooler_output
+ return self._model_runner_output
+
+
class ExecuteModelState(NamedTuple):
"""Ephemeral cached state transferred between execute_model() and
sample_tokens(), after execute_model() returns None."""
@@ -311,6 +427,7 @@ class ExecuteModelState(NamedTuple):
aux_hidden_states: list[torch.Tensor] | None
ec_connector_output: ECConnectorOutput | None
cudagraph_stats: CUDAGraphStat | None
+ slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None
class ModelRunnerFL(
@@ -324,6 +441,7 @@ def __init__(
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
+ self.offload_config = vllm_config.offload_config
self.compilation_config = vllm_config.compilation_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
@@ -332,10 +450,6 @@ def __init__(
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
- from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
-
- set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3))
-
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config
@@ -355,6 +469,9 @@ def __init__(
)
# This will be overridden in load_model()
self.is_multimodal_pruning_enabled = False
+ # Set to True after init_routed_experts_capturer() completes.
+ # Prevents routed experts code from running during profiling/dummy run.
+ self.routed_experts_initialized = False
self.max_model_len = model_config.max_model_len
# Always set to false after the first forward pass
@@ -366,11 +483,11 @@ def __init__(
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
- # TODO: Support overlapping mirco-batches
+ # TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019
self.broadcast_pp_output = (
self.parallel_config.distributed_executor_backend == "external_launcher"
- and len(get_pp_group().ranks) > 0
+ and len(get_pp_group().ranks) > 1
)
# Model-related.
@@ -398,10 +515,15 @@ def __init__(
else:
self.max_encoder_len = 0
+ # Async scheduling
+ self.use_async_scheduling = self.scheduler_config.async_scheduling
+
# Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
self.eplb_state: EplbState | None = None
+ # NOTE(yongji): flag to temporarily disable EPLB during scaling up/down
+ self.eep_eplb_suppressed = False
"""
State of the expert parallelism load balancer.
@@ -421,6 +543,7 @@ def __init__(
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
+ self.late_interaction_runner = LateInteractionRunner()
self.use_aux_hidden_state_outputs = False
# Set up speculative decoding.
@@ -429,10 +552,41 @@ def __init__(
# layers in the draft model.
if self.speculative_config and get_pp_group().is_last_rank:
self.drafter: (
- NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer
+ NgramProposer # noqa: F823
+ | NgramProposerGPU
+ | SuffixDecodingProposer
+ | EagleProposer
+ | DraftModelProposer
+ | MedusaProposer
+ | ExtractHiddenStatesProposer
)
if self.speculative_config.method == "ngram":
+ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+
self.drafter = NgramProposer(self.vllm_config)
+ elif self.speculative_config.uses_draft_model():
+ self.drafter = DraftModelProposer(
+ vllm_config=self.vllm_config,
+ device=self.device,
+ runner=self,
+ )
+ elif self.speculative_config.use_ngram_gpu():
+ self.drafter = NgramProposerGPU(self.vllm_config, self.device, self)
+ self.num_tokens_no_spec_gpu = torch.zeros(
+ self.max_num_reqs, dtype=torch.int32, device=device
+ )
+ self.token_ids_gpu_tensor = torch.zeros(
+ self.max_num_reqs,
+ self.max_model_len,
+ dtype=torch.int32,
+ device=device,
+ )
+ self._ngram_pinned_idx_buf = torch.zeros(
+ self.max_num_reqs, dtype=torch.long, pin_memory=True
+ )
+ self._ngram_pinned_val_buf = torch.zeros(
+ self.max_num_reqs, dtype=torch.int32, pin_memory=True
+ )
elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
@@ -445,15 +599,26 @@ def __init__(
self.drafter = MedusaProposer(
vllm_config=self.vllm_config, device=self.device
)
+ elif self.speculative_config.method == "extract_hidden_states":
+ self.drafter = ExtractHiddenStatesProposer(
+ vllm_config=self.vllm_config, device=self.device
+ )
+ self.use_aux_hidden_state_outputs = True
else:
raise ValueError(
"Unknown speculative decoding method: "
f"{self.speculative_config.method}"
)
self.rejection_sampler = RejectionSampler(self.sampler)
+
self.num_spec_tokens = 0
if self.speculative_config:
self.num_spec_tokens = self.speculative_config.num_speculative_tokens
+ draft_config = self.speculative_config.draft_model_config
+ if draft_config is not None and draft_config.max_model_len is not None:
+ self.effective_drafter_max_model_len = draft_config.max_model_len
+ else:
+ self.effective_drafter_max_model_len = self.max_model_len
# Request states.
self.requests: dict[str, CachedRequestState] = {}
@@ -475,17 +640,22 @@ def __init__(
custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (
tuple(logits_processors) if logits_processors is not None else ()
)
+ placeholder_block_size = (
+ self.cache_config.block_size or CacheConfig.DEFAULT_BLOCK_SIZE
+ )
+ self._init_block_sizes = [placeholder_block_size]
+ self._init_kernel_block_sizes = [placeholder_block_size]
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
- # We need to use the encoder length for encoder-decoer
+ # We need to use the encoder length for encoder-decoder
# because of KV cache for cross-attention.
max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
- block_sizes=[self.cache_config.block_size],
- kernel_block_sizes=[self.cache_config.block_size],
+ block_sizes=[placeholder_block_size],
+ kernel_block_sizes=[placeholder_block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config,
@@ -501,13 +671,14 @@ def __init__(
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
)
- self.use_async_scheduling = self.scheduler_config.async_scheduling
- self.async_output_copy_stream = current_platform.torch_device_fn.Stream() if \
- self.use_async_scheduling else None
+ # Separate cuda stream for overlapping transfer of sampled token ids from
+ # GPU to CPU when async scheduling is enabled.
+ self.async_output_copy_stream: current_platform.torch_device_fn.Stream | None = None
# cuda event to synchronize use of reused CPU tensors between steps
# when async scheduling is enabled.
self.prepare_inputs_event: torch.Event | None = None
if self.use_async_scheduling:
+ self.async_output_copy_stream = current_platform.torch_device_fn.Stream()
self.prepare_inputs_event = torch.Event()
# self.cudagraph_batch_sizes sorts in ascending order.
@@ -518,6 +689,16 @@ def __init__(
self.cudagraph_batch_sizes = sorted(
self.compilation_config.cudagraph_capture_sizes
)
+ else:
+ self.cudagraph_batch_sizes = []
+
+ # Cache the device properties.
+ self._init_device_properties()
+
+ # Encoder timing registry for observability
+ self.encoder_timing_registry: dict[str, EncoderTimingStats] = {}
+ self._encoder_timing_lock = threading.Lock()
+
# Persistent buffers for CUDA graphs.
self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
@@ -546,9 +727,16 @@ def __init__(
self.num_accepted_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int64
)
+
# Only relevant for multimodal models
if self.supports_mm_inputs:
- self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
+ # Double buffer to avoid race condition: previous iteration's async
+ # copy may still be reading from CPU while current iteration writes.
+ self.is_mm_embed_buffers = [
+ self._make_buffer(self.max_num_tokens, dtype=torch.bool),
+ self._make_buffer(self.max_num_tokens, dtype=torch.bool),
+ ]
+ self.is_mm_embed_idx = 0
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
@@ -602,11 +790,7 @@ def __init__(
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.mm_budget = (
- MultiModalBudget(
- self.model_config,
- self.scheduler_config,
- self.mm_registry,
- )
+ MultiModalBudget(self.vllm_config, self.mm_registry)
if self.supports_mm_inputs
else None
)
@@ -620,6 +804,22 @@ def __init__(
# Cached outputs.
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
+ # N-gram GPU path: async D2H buffer/event for per-request valid draft counts.
+ self._num_valid_draft_tokens: torch.Tensor | None = None
+ self._num_valid_draft_tokens_cpu: torch.Tensor | None = None
+ self._num_valid_draft_tokens_event: current_platform.torch_device_fn.Event | None = None
+ self._num_valid_draft_tokens_copy_stream: current_platform.torch_device_fn.Stream | None = None
+ if (
+ self.speculative_config is not None
+ and self.speculative_config.use_ngram_gpu()
+ ):
+ self._num_valid_draft_tokens_cpu = torch.empty(
+ self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory
+ )
+ self._num_valid_draft_tokens_event = current_platform.torch_device_fn.Event()
+ self._num_valid_draft_tokens_copy_stream = current_platform.torch_device_fn.Stream()
+
+ self._draft_token_req_ids: list[str] | None = None
self.transfer_event = torch.Event()
# TODO(yxa): NPU uses int32, CUDA uses int64 for sampled token ids
sampled_ids_dtype = torch.int32 if current_platform.device_type == "npu" else torch.int64
@@ -629,28 +829,73 @@ def __init__(
device="cpu",
pin_memory=self.pin_memory,
)
+
# Pre-allocated tensor for copying valid sampled token counts to CPU,
# with dedicated stream for overlapping and event for coordination.
self.valid_sampled_token_count_event: torch.Event | None = None
- self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
- if self.use_async_scheduling and self.num_spec_tokens:
- self.valid_sampled_token_count_event = torch.Event()
- self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
- self.valid_sampled_token_count_cpu = torch.empty(
- self.max_num_reqs,
- dtype=torch.int64,
- device="cpu",
- pin_memory=self.pin_memory,
- )
+ self.valid_sampled_token_count_copy_stream: current_platform.torch_device_fn.Stream | None = None
+ # We also copy the drafted tokens to the CPU asynchronously,
+ # in case we need them for structured outputs.
+ self.draft_token_ids_event: torch.Event | None = None
+ self.draft_token_ids_copy_stream: current_platform.torch_device_fn.Stream | None = None
+ self.valid_sampled_token_count_cpu: torch.Tensor | None = None
+ self.draft_token_ids_cpu: torch.Tensor | None = None
+ self.num_accepted_tokens_event: torch.Event | None = None
+ if self.num_spec_tokens:
+ self.draft_token_ids_event = torch.Event()
+ self.num_accepted_tokens_event = torch.Event()
+ self.draft_token_ids_copy_stream = current_platform.torch_device_fn.Stream()
+ self.draft_token_ids_cpu = torch.empty(
+ (self.max_num_reqs, self.num_spec_tokens),
+ dtype=torch.int64,
+ device="cpu",
+ pin_memory=self.pin_memory,
+ )
+ if self.use_async_scheduling:
+ self.valid_sampled_token_count_event = torch.Event()
+ self.valid_sampled_token_count_copy_stream = current_platform.torch_device_fn.Stream()
+ self.valid_sampled_token_count_cpu = torch.empty(
+ self.max_num_reqs,
+ dtype=sampled_ids_dtype,
+ device="cpu",
+ pin_memory=self.pin_memory,
+ )
+
+ # Model weight offloader
+ # Make sure this is called before any get_offloader call
+ set_offloader(create_offloader(self.offload_config))
# Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None
self.kv_connector_output: KVConnectorOutput | None = None
+ self.mamba_state_idx: dict[str, int] = {}
+ self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None
self.layerwise_nvtx_hooks_registered = False
+ def update_max_model_len(self, max_model_len: int) -> None:
+ self.max_model_len = max_model_len
+ if self.speculative_config:
+ draft_config = self.speculative_config.draft_model_config
+ if draft_config is None or draft_config.max_model_len is None:
+ self.effective_drafter_max_model_len = self.max_model_len
+
def reset_mm_cache(self) -> None:
+ """
+ Clear the multi-modal cache that was used during profiling,
+ but no longer needed during inference.
+ """
if self.mm_budget:
self.mm_budget.reset_cache()
+ self.late_interaction_runner.clear()
+
+ def reset_encoder_cache(self) -> None:
+ """Clear the GPU-side encoder cache storing vision embeddings.
+
+ This should be called when model weights are updated to ensure
+ stale embeddings computed with old weights are not reused.
+ """
+ self.encoder_cache.clear()
+ self.late_interaction_runner.clear()
@managed_inference_mode()
def init_fp8_kv_scales(self) -> None:
@@ -721,7 +966,17 @@ def _make_buffer(
with_numpy=numpy,
)
- def _init_model_kwargs(self, num_tokens: int):
+ def _get_mamba_copy_bufs(self) -> mamba_utils.MambaCopyBuffers:
+ if self._mamba_copy_bufs is None:
+ self._mamba_copy_bufs = mamba_utils.MambaCopyBuffers.create(
+ self.max_num_reqs,
+ self.kv_cache_config,
+ self.model.get_mamba_state_copy_func(),
+ self._make_buffer,
+ )
+ return self._mamba_copy_bufs
+
+ def _init_model_kwargs(self):
model_kwargs = dict[str, Any]()
if not self.is_pooling_model:
@@ -765,7 +1020,7 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
Args:
scheduler_output: The scheduler output.
"""
- # Attention free models have zero kv_cache_goups, however models
+ # Attention free models have zero kv_cache_groups, however models
# like Mamba are also attention free but use the kv_cache for
# keeping its internal state. This is why we check the number
# of kv_cache groups instead of solely checking
@@ -780,6 +1035,32 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
decode_threshold=self.reorder_batch_threshold,
)
+ def _init_kv_zero_meta(self) -> None:
+ """One-time precomputation for _zero_block_ids.
+
+ Delegates to KVBlockZeroer.init_meta with the runner's state.
+ Called from gpu_worker.py outside the CuMem pool context.
+ """
+ self._kv_block_zeroer = KVBlockZeroer(self.device, self.pin_memory)
+ self._kv_block_zeroer.init_meta(
+ attn_groups_iter=self._kv_cache_spec_attn_group_iterator(),
+ kernel_block_sizes=self._kernel_block_sizes,
+ cache_dtype=self.cache_config.cache_dtype,
+ runner_only_attn_layers=self.runner_only_attn_layers,
+ static_forward_context=(self.compilation_config.static_forward_context),
+ )
+
+ def _zero_block_ids(self, block_ids: list[int]) -> None:
+ """Zero the KV cache memory for the given block IDs."""
+ if hasattr(self, "_kv_block_zeroer"):
+ self._kv_block_zeroer.zero_block_ids(block_ids)
+
+ # Note: used for model runner override.
+ def _init_device_properties(self) -> None:
+        """Cache device properties (e.g. the compute-unit count) for this device."""
+
+ self.num_sms = num_compute_units(self.device.index)
+
# Note: used for model runner override.
def _sync_device(self) -> None:
current_platform.torch_device_fn.synchronize()
@@ -798,6 +1079,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
+ self.late_interaction_runner.on_requests_finished(
+ scheduler_output.finished_req_ids
+ )
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -807,6 +1091,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id)
+ # Zero GPU memory for freshly allocated cache blocks to prevent
+ # stale NaN/data from corrupting attention or SSM computation.
+ if scheduler_output.new_block_ids_to_zero:
+ self._zero_block_ids(scheduler_output.new_block_ids_to_zero)
+
# Free the cached encoder outputs.
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
@@ -833,10 +1122,23 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
for req_id in unscheduled_req_ids:
self.input_batch.remove_request(req_id)
+ is_ngram_gpu = (
+ self.speculative_config is not None
+ and self.speculative_config.use_ngram_gpu()
+ )
+ if is_ngram_gpu:
+ ngram_gpu_new_reqs: list[CachedRequestState] = []
+
reqs_to_add: list[CachedRequestState] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
+ if req_id in self.requests:
+ # For streaming case only.
+ req_state = self._update_streaming_request(req_id, new_req_data)
+ reqs_to_add.append(req_state)
+ continue
+
sampling_params = new_req_data.sampling_params
pooling_params = new_req_data.pooling_params
@@ -872,6 +1174,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
lora_request=new_req_data.lora_request,
)
self.requests[req_id] = req_state
+ self.late_interaction_runner.register_request(req_id, pooling_params)
if sampling_params and sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
@@ -889,10 +1192,30 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
self._init_xdrope_positions(req_state)
reqs_to_add.append(req_state)
+ # Track new requests for ngram_gpu full tensor copy
+ if is_ngram_gpu:
+ ngram_gpu_new_reqs.append(req_state)
# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs
+ scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
+
+ # Save scheduler-allocated spec lengths before trimming so
+ # prev_num_draft_len keeps the optimistic count for rejection correction.
+ original_num_spec_per_req: dict[str, int] = {}
+ if (
+ self.speculative_config is not None
+ and self.speculative_config.use_ngram_gpu()
+ ):
+ for req_id, toks in scheduled_spec_tokens.items():
+ original_num_spec_per_req[req_id] = len(toks)
+ update_scheduler_for_invalid_drafts(
+ self._num_valid_draft_tokens_event,
+ self._num_valid_draft_tokens_cpu,
+ scheduler_output,
+ self.input_batch.req_id_to_index,
+ )
# Wait until valid_sampled_tokens_count is copied to cpu,
# then use it to update actual num_computed_tokens of each request.
@@ -906,20 +1229,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
num_output_tokens = req_data.num_output_tokens[i]
req_index = self.input_batch.req_id_to_index.get(req_id)
- # prev_num_draft_len is used in async scheduling mode with
- # spec decode. it indicates if need to update num_computed_tokens
- # of the request. for example:
- # fist step: num_computed_tokens = 0, spec_tokens = [],
- # prev_num_draft_len = 0.
- # second step: num_computed_tokens = 100(prompt lenth),
- # spec_tokens = [a,b], prev_num_draft_len = 0.
- # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
- # prev_num_draft_len = 2.
- # num_computed_tokens in first step and second step does't contain
- # the spec tokens length, but in third step it contains the
- # spec tokens length. we only need to update num_computed_tokens
- # when prev_num_draft_len > 0.
- if req_state.prev_num_draft_len:
+ if req_state.prev_num_draft_len and self.use_async_scheduling:
+ # prev_num_draft_len is used in async scheduling mode with
+ # spec decode. it indicates if need to update num_computed_tokens
+ # of the request. for example:
+ # first step: num_computed_tokens = 0, spec_tokens = [],
+ # prev_num_draft_len = 0.
+ # second step: num_computed_tokens = 100(prompt length),
+ # spec_tokens = [a,b], prev_num_draft_len = 0.
+ # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
+ # prev_num_draft_len = 2.
+ # num_computed_tokens in first step and second step doesn't contain
+ # the spec tokens length, but in third step it contains the
+ # spec tokens length. we only need to update num_computed_tokens
+ # when prev_num_draft_len > 0.
if req_index is None:
req_state.prev_num_draft_len = 0
else:
@@ -930,24 +1253,33 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted)
+ if is_ngram_gpu and num_accepted > 0 and req_index is not None:
+ self.input_batch.num_tokens_no_spec[req_index] += num_accepted
+
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
if not is_last_rank:
- # When using PP, the scheduler sends the sampled tokens back,
- # because there's no direct communication between the first-
- # stage worker and the last-stage worker.
- new_token_ids = req_data.new_token_ids[i]
- # Add the sampled token(s) from the previous step (if any).
- # This doesn't include "unverified" tokens like spec tokens.
- num_new_tokens = (
- num_computed_tokens + len(new_token_ids) - req_state.num_tokens
- )
- if num_new_tokens == 1:
- # Avoid slicing list in most common case.
- req_state.output_token_ids.append(new_token_ids[-1])
- elif num_new_tokens > 0:
- req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])
+ if not req_data.new_token_ids:
+ # Async scheduled PP: Sampled tokens propagated via GPU broadcast.
+ new_token_ids: list[int] = []
+ else:
+ # Non-async scheduling with PP: The scheduler sends
+ # sampled token ids back because there's no direct communication
+ # between the first-stage worker and the last-stage worker.
+ new_token_ids = req_data.new_token_ids[i]
+ # Add the sampled token(s) from the previous step (if any).
+ # This doesn't include "unverified" tokens like spec tokens.
+ num_new_tokens = (
+ num_computed_tokens + len(new_token_ids) - req_state.num_tokens
+ )
+ if num_new_tokens == 1:
+ # Avoid slicing list in most common case.
+ req_state.output_token_ids.append(new_token_ids[-1])
+ elif num_new_tokens > 0:
+ req_state.output_token_ids.extend(
+ new_token_ids[-num_new_tokens:]
+ )
elif num_output_tokens < len(req_state.output_token_ids):
# Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state.
@@ -957,7 +1289,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
self.input_batch.num_prompt_tokens[req_index]
+ num_output_tokens
)
- self.input_batch.num_tokens[req_index] = end_idx
self.input_batch.num_tokens_no_spec[req_index] = end_idx
# Update the block IDs.
@@ -985,6 +1316,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
reqs_to_add.append(req_state)
+ # Track resumed requests for ngram_gpu full tensor copy
+ if is_ngram_gpu:
+ ngram_gpu_new_reqs.append(req_state)
continue
# Update the persistent batch.
@@ -1002,46 +1336,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
req_index, start_token_index:end_token_index
] = new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
- self.input_batch.num_tokens[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
- spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
- req_id, []
- )
- num_spec_tokens = len(spec_token_ids)
- # For async scheduling, token_ids_cpu assigned from
- # spec_token_ids are placeholders and will be overwritten in
- # _prepare_input_ids.
- if num_spec_tokens:
- start_index = self.input_batch.num_tokens_no_spec[req_index]
- end_token_index = start_index + num_spec_tokens
- self.input_batch.token_ids_cpu[
- req_index, start_index:end_token_index
- ] = spec_token_ids
- # NOTE(woosuk): `num_tokens` here may include spec tokens.
- self.input_batch.num_tokens[req_index] += num_spec_tokens
-
- # When speculative decoding is used with structured output,
- # the scheduler can drop draft tokens that do not
- # conform to the schema. This can result in
- # scheduler_output.scheduled_spec_decode_tokens being empty,
- # even when speculative decoding is enabled.
- self.input_batch.spec_token_ids[req_index].clear()
- self.input_batch.spec_token_ids[req_index].extend(spec_token_ids)
-
- # there are no draft tokens with async scheduling,
- # we clear the spec_decoding info in scheduler_output and
- # use normal sampling but rejection_sampling.
- if self.use_async_scheduling:
- req_state.prev_num_draft_len = num_spec_tokens
- if num_spec_tokens and self._draft_token_ids is None:
- scheduler_output.total_num_scheduled_tokens -= num_spec_tokens
- scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens
- scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None)
+ self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
+ # Restore scheduler-side draft count after ngram trimming.
+ if original_num_spec_per_req:
+ orig = original_num_spec_per_req.get(req_id, 0)
+ if orig != req_state.prev_num_draft_len:
+ req_state.prev_num_draft_len = orig
+
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
for request in reqs_to_add:
self.input_batch.add_request(request)
+ self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens)
# Condense the batched states if there are gaps left by removed requests
self.input_batch.condense()
@@ -1050,8 +1358,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
+ # Incrementally update ngram_gpu tensors after batch is stable
+ if is_ngram_gpu:
+ update_ngram_gpu_tensors_incremental(
+ self.input_batch,
+ self.token_ids_gpu_tensor,
+ self.num_tokens_no_spec_gpu,
+ ngram_gpu_new_reqs,
+ self.device,
+ _pinned_idx_buf=self._ngram_pinned_idx_buf,
+ _pinned_val_buf=self._ngram_pinned_val_buf,
+ )
+
def _update_states_after_model_execute(
- self, output_token_ids: torch.Tensor
+ self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
) -> None:
"""Update the cached states after model execution.
@@ -1061,17 +1381,18 @@ def _update_states_after_model_execute(
each sequence, and a shifting is done during the next iteration
based on the number of accepted tokens.
"""
- if not self.model_config.is_hybrid or not self.speculative_config:
+ if not self.speculative_config or not self.model_config.is_hybrid:
return
# Find the number of accepted tokens for each sequence.
- num_accepted_tokens = (
+ num_reqs = output_token_ids.size(0)
+ self.num_accepted_tokens.gpu[:num_reqs] = (
(
torch.cat(
[
output_token_ids,
torch.full(
- (output_token_ids.size(0), 1),
+ (num_reqs, 1),
-1,
device=output_token_ids.device,
),
@@ -1082,11 +1403,64 @@ def _update_states_after_model_execute(
)
.int()
.argmax(-1)
- .cpu()
- .numpy()
)
- for i, num_tokens in enumerate(num_accepted_tokens):
- self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
+ if self.cache_config.mamba_cache_mode == "align":
+ for i, num_tokens in enumerate(
+ self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy()
+ ):
+ self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
+
+ mamba_utils.postprocess_mamba(
+ scheduler_output,
+ self.kv_cache_config,
+ self.input_batch,
+ self.requests,
+ self.mamba_state_idx,
+ self.compilation_config.static_forward_context,
+ self.model.get_mamba_state_copy_func(),
+ self._get_mamba_copy_bufs(),
+ )
+ else:
+ self.input_batch.num_accepted_tokens_cpu_tensor[:num_reqs].copy_(
+ self.num_accepted_tokens.gpu[:num_reqs], non_blocking=True
+ )
+ assert self.num_accepted_tokens_event is not None
+ self.num_accepted_tokens_event.record()
+
+ def _update_streaming_request(
+ self, req_id: str, new_req_data: NewRequestData
+ ) -> CachedRequestState:
+ """Updates a streaming session request from `scheduled_new_reqs`.
+
+ Removes the request from InputBatch (if present), updates the cached
+ state, and prepares it for re-addition to the batch.
+
+ NOTE: prompt_token_ids includes intermediate output tokens - tokens
+ previously generated that are now input context (part of the prompt).
+ """
+ self.input_batch.remove_request(req_id)
+ req_state = self.requests[req_id]
+
+ req_state.prompt_token_ids = new_req_data.prompt_token_ids
+ req_state.mm_features = new_req_data.mm_features
+ req_state.prompt_embeds = new_req_data.prompt_embeds
+ req_state.sampling_params = new_req_data.sampling_params
+ req_state.pooling_params = new_req_data.pooling_params
+ self.late_interaction_runner.register_request(req_id, req_state.pooling_params)
+ req_state.block_ids = new_req_data.block_ids
+ req_state.num_computed_tokens = new_req_data.num_computed_tokens
+ req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+ req_state.prompt_token_ids, req_state.prompt_embeds
+ )
+
+ # Clear `output_token_ids` as previous output tokens are now part of
+ # `prompt_token_ids`.
+ req_state.output_token_ids.clear()
+
+ if self.uses_mrope:
+ self._init_mrope_positions(req_state)
+
+ return req_state
def _init_mrope_positions(self, req_state: CachedRequestState):
model = self.get_model()
@@ -1123,22 +1497,20 @@ def _extract_mm_kwargs(
if not scheduler_output or not self.is_multimodal_raw_input_only_model:
return {}
- mm_kwargs = list[MultiModalKwargsItem]()
+ mm_kwargs = list[tuple[str, MultiModalKwargsItem]]()
for req in scheduler_output.scheduled_new_reqs:
for feature in req.mm_features:
if feature.data is not None:
- mm_kwargs.append(feature.data)
+ mm_kwargs.append((feature.modality, feature.data))
# Input all modalities at once
- model = cast(SupportsMultiModal, self.model)
mm_kwargs_combined: BatchedTensorInputs = {}
- for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
+ for _, _, mm_kwargs_batch in group_and_batch_mm_kwargs(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
- merge_by_field_config=model.merge_by_field_config,
):
- mm_kwargs_combined.update(mm_kwargs_group)
+ mm_kwargs_combined.update(mm_kwargs_batch)
return mm_kwargs_combined
@@ -1149,6 +1521,9 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
mm_budget = self.mm_budget
assert mm_budget is not None
+ if not mm_budget.mm_max_toks_per_item:
+ return {} # No tower modalities (embed-only mode)
+
dummy_modality = mm_budget.get_modality_with_max_tokens()
return self._get_mm_dummy_batch(dummy_modality, num_seqs)
@@ -1231,30 +1606,30 @@ def _prepare_input_ids(
prev_draft_token_indices.extend(range(start, start + draft_len))
indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
- num_commmon_tokens = len(sample_flattened_indices)
+ num_common_tokens = len(sample_flattened_indices)
total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
- if num_commmon_tokens < total_without_spec:
+ if num_common_tokens < total_without_spec:
# If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the GPU first.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
- if num_commmon_tokens == 0:
+ if num_common_tokens == 0:
# No requests in common with the previous iteration
# So input_ids.cpu will have all the input ids.
return
- if indices_match and max_flattened_index == (num_commmon_tokens - 1):
+ if indices_match and max_flattened_index == (num_common_tokens - 1):
# Common-case optimization: the batch is unchanged
# and no reordering happened.
# The indices are both the same permutation of 0..N-1 so
# we can copy directly using a single slice.
- self.input_ids.gpu[:num_commmon_tokens].copy_(
- self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0],
+ self.input_ids.gpu[:num_common_tokens].copy_(
+ self.input_batch.prev_sampled_token_ids[:num_common_tokens, 0],
non_blocking=True,
)
if self.enable_prompt_embeds:
- self.is_token_ids.gpu[:num_commmon_tokens] = True
+ self.is_token_ids.gpu[:num_common_tokens] = True
return
# Upload the index tensors asynchronously so the scatter can be non-blocking.
sampled_tokens_index_tensor = torch.tensor(
@@ -1286,7 +1661,6 @@ def _prepare_input_ids(
# because input_ids dtype is torch.int32,
# so convert draft_token_ids to torch.int32 here.
draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
- self._draft_token_ids = None
self.input_ids.gpu.scatter_(
dim=0,
@@ -1299,12 +1673,14 @@ def _get_encoder_seq_lens(
num_scheduled_tokens: dict[str, int],
kv_cache_spec: KVCacheSpec,
num_reqs: int,
+ for_cudagraph_capture: bool = False,
) -> tuple[torch.Tensor | None, np.ndarray | None]:
if not isinstance(kv_cache_spec, CrossAttentionSpec):
return None, None
# Zero out buffer for padding requests that are not actually scheduled (CGs)
self.encoder_seq_lens.np[:num_reqs] = 0
+
# Build encoder_seq_lens array mapping request indices to
# encoder lengths for inputs scheduled in this batch
for req_id in num_scheduled_tokens:
@@ -1321,6 +1697,15 @@ def _get_encoder_seq_lens(
feature.mm_position.length for feature in req_state.mm_features
)
self.encoder_seq_lens.np[req_index] = encoder_input_tokens
+ if for_cudagraph_capture:
+ # During CUDA graph capture, we need to use realistic encoder lengths
+ # so that max_seqlen_k is captured with the correct value.
+ max_encoder_len = getattr(
+ self.model_config.hf_config,
+ "max_source_positions",
+ self.max_encoder_len,
+ )
+ self.encoder_seq_lens.np[:num_reqs] = max_encoder_len
self.encoder_seq_lens.copy_to_gpu(num_reqs)
encoder_seq_lens = self.encoder_seq_lens.gpu[:num_reqs]
@@ -1501,7 +1886,6 @@ def _prepare_inputs(
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
- num_draft_tokens = None
spec_decode_metadata = None
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
else:
@@ -1518,14 +1902,11 @@ def _prepare_inputs(
) in scheduler_output.scheduled_spec_decode_tokens.items():
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
- num_decode_draft_tokens[req_idx] = (
- len(draft_token_ids)
- if (
- self.input_batch.num_computed_tokens_cpu[req_idx]
- >= self.input_batch.num_prompt_tokens[req_idx]
- )
- else -1
- )
+ if (
+ self.input_batch.num_computed_tokens_cpu[req_idx]
+ >= self.input_batch.num_prompt_tokens[req_idx]
+ ):
+ num_decode_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens
)
@@ -1564,6 +1945,7 @@ def _build_attention_metadata(
for_cudagraph_capture: bool = False,
num_scheduled_tokens: dict[str, int] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None,
+ slot_mappings: dict[int, torch.Tensor] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
@@ -1578,7 +1960,7 @@ def _build_attention_metadata(
attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None:
- attn_metadata = [{} for _ in range(len(ubatch_slices))]
+ attn_metadata = [dict() for _ in range(len(ubatch_slices))]
if for_cudagraph_capture:
# For some attention backends (e.g. FA) with sliding window models we need
@@ -1589,6 +1971,8 @@ def _build_attention_metadata(
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
if use_spec_decode:
+ if self.num_accepted_tokens_event is not None:
+ self.num_accepted_tokens_event.synchronize()
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
@@ -1597,7 +1981,7 @@ def _build_attention_metadata(
kv_cache_groups = self.kv_cache_config.kv_cache_groups
- def _get_block_table_and_slot_mapping(kv_cache_gid: int):
+ def _get_block_table(kv_cache_gid: int):
assert num_reqs_padded is not None and num_tokens_padded is not None
kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
@@ -1606,24 +1990,23 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int):
dtype=torch.int32,
device=self.device,
)
- slot_mapping = torch.zeros(
- (num_tokens_padded,),
- dtype=torch.int64,
- device=self.device,
- )
else:
blk_table = self.input_batch.block_table[kv_cache_gid]
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
- slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
- slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
+ return blk_table_tensor
- return blk_table_tensor, slot_mapping
+ assert slot_mappings is not None
+ block_table_gid_0 = _get_block_table(0)
+ slot_mapping_gid_0 = slot_mappings[0]
- block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
+ if self.routed_experts_initialized:
+ attn_gid = self.routed_experts_attn_gid
+ slot_mapping_attn = slot_mappings[attn_gid]
+ self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
@@ -1662,6 +2045,15 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int):
logits_indices
)
+ # Cache attention metadata builds across hybrid KV-cache groups
+ # The only thing that changes between different hybrid KV-cache groups that
+ # share the same metadata builder and KVCacheSpec is the block table, so we
+ # can cache the attention metadata builds and just update the block table using
+ # `builder.update_block_table` if the builder supports it.
+ cached_attn_metadata: dict[
+ tuple[KVCacheSpec, type[AttentionMetadataBuilder]], AttentionMetadata
+ ] = {}
+
def _build_attn_group_metadata(
kv_cache_gid: int,
attn_gid: int,
@@ -1669,15 +2061,22 @@ def _build_attn_group_metadata(
ubid: int | None = None,
) -> None:
attn_group = self.attn_groups[kv_cache_gid][attn_gid]
+ builder = attn_group.get_metadata_builder(ubid or 0)
+ kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec
+ if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
+ kv_cache_spec = kv_cache_spec.kv_cache_specs[attn_group.layer_names[0]]
+ cache_key = (kv_cache_spec, type(builder))
+
cascade_attn_prefix_len = (
cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
if cascade_attn_prefix_lens
else 0
)
- builder = attn_group.get_metadata_builder(ubid or 0)
extra_attn_metadata_args = {}
- if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
+ if use_spec_decode and isinstance(
+ builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder)
+ ):
assert ubid is None, "UBatching not supported with GDN yet"
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded],
@@ -1690,12 +2089,23 @@ def _build_attn_group_metadata(
attn_metadata_i = builder.build_for_cudagraph_capture(
common_attn_metadata
)
+ elif (
+ cache_key in cached_attn_metadata
+ and builder.supports_update_block_table
+ ):
+ attn_metadata_i = builder.update_block_table(
+ cached_attn_metadata[cache_key],
+ common_attn_metadata.block_table_tensor,
+ common_attn_metadata.slot_mapping,
+ )
else:
attn_metadata_i = builder.build(
common_prefix_len=cascade_attn_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args,
)
+ if builder.supports_update_block_table:
+ cached_attn_metadata[cache_key] = attn_metadata_i
if ubid is None:
assert isinstance(attn_metadata, dict)
@@ -1719,15 +2129,15 @@ def _build_attn_group_metadata(
num_scheduled_tokens or {},
kv_cache_group.kv_cache_spec,
num_reqs_padded,
+ for_cudagraph_capture=for_cudagraph_capture,
)
if kv_cache_gid > 0:
- cm.block_table_tensor, cm.slot_mapping = (
- _get_block_table_and_slot_mapping(kv_cache_gid)
- )
+ cm.block_table_tensor = _get_block_table(kv_cache_gid)
+ cm.slot_mapping = slot_mappings[kv_cache_gid]
if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer):
- if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
+ if self.drafter.kv_cache_gid == kv_cache_gid:
spec_decode_common_attn_metadata = cm
else:
spec_decode_common_attn_metadata = cm
@@ -1808,7 +2218,6 @@ def _compute_cascade_attn_prefix_lens(
return cascade_attn_prefix_lens if use_cascade_attn else None
-
def _compute_cascade_attn_prefix_len(
self,
num_scheduled_tokens: np.ndarray,
@@ -1834,6 +2243,7 @@ def _compute_cascade_attn_prefix_len(
Returns:
int: Length of common prefix in tokens.
"""
+
common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
if common_prefix_len == 0:
# Common case.
@@ -1892,8 +2302,6 @@ def _compute_cascade_attn_prefix_len(
and kv_cache_spec.attention_chunk_size is not None
)
assert isinstance(kv_cache_spec, AttentionSpec)
-
- ###TODO(lms): fix use of num_sms
use_cascade = attn_metadata_builder.use_cascade_attention(
common_prefix_len=common_prefix_len,
query_lens=num_scheduled_tokens,
@@ -1902,7 +2310,7 @@ def _compute_cascade_attn_prefix_len(
use_alibi=self.use_alibi,
use_sliding_window=use_sliding_window,
use_local_attention=use_local_attention,
- num_sms=1,
+ num_sms=self.num_sms,
dcp_world_size=self.dcp_world_size,
)
return common_prefix_len if use_cascade else 0
@@ -2076,6 +2484,7 @@ def _calc_spec_decode_metadata(
draft_token_ids=draft_token_ids,
num_draft_tokens=num_draft_tokens.tolist(),
cu_num_draft_tokens=cu_num_draft_tokens,
+ cu_num_sampled_tokens=cu_num_sampled_tokens,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
@@ -2096,42 +2505,45 @@ def _prepare_kv_sharing_fast_prefill(
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
logits_indices[-1].item()
)
- if (
- self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
- and num_logits <= self.cudagraph_batch_sizes[-1]
- ):
- # Use piecewise CUDA graphs.
- # Add padding to the batch size.
- num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
- else:
- num_logits_padded = num_logits
+ # Dispatch for the decoder portion of the model.
+ _, batch_desc = self.cudagraph_dispatcher.dispatch(
+ num_logits, invalid_modes={CUDAGraphMode.FULL}
+ )
+ num_logits_padded = batch_desc.num_tokens
logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[
:num_logits_padded
]
return logits_indices_padded
- def _batch_mm_kwargs_from_scheduler(
+ def _batch_mm_inputs_from_scheduler(
self,
scheduler_output: "SchedulerOutput",
- ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
- """Batch multimodal kwargs from scheduled encoder inputs.
+ ) -> tuple[
+ list[str],
+ list[tuple[str, MultiModalKwargsItem]],
+ list[tuple[str, PlaceholderRange]],
+ ]:
+ """Batch multimodal inputs from scheduled encoder inputs.
Args:
scheduler_output: The scheduler output containing scheduled encoder
inputs.
Returns:
- A tuple of (mm_kwargs, req_ids_pos) where:
- - mm_kwargs: List of multimodal kwargs items to be batched
- - mm_hashes_pos: List of (mm_hash, position_info) tuples
+ A tuple of (mm_hashes, mm_kwargs, mm_lora_refs) where:
+ - mm_hashes: List of multimodal hashes for each item
+ - mm_kwargs: List of multimodal kwargs for each item
+ - mm_lora_refs: List of (req_id, placeholder_range) for each item
"""
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
- return [], []
- # Batch the multi-modal inputs.
- mm_kwargs = list[MultiModalKwargsItem]()
- # list of tuple (mm_hash, position_info)
- mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
+ return [], [], []
+
+ mm_hashes = list[str]()
+ mm_kwargs = list[tuple[str, MultiModalKwargsItem]]()
+ # Multimodal LoRA reference info to map each multimodal item
+ # back to its request & position
+ mm_lora_refs = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
@@ -2139,23 +2551,29 @@ def _batch_mm_kwargs_from_scheduler(
mm_feature = req_state.mm_features[mm_input_id]
if mm_feature.data is None:
continue
- mm_hash = mm_feature.identifier
- mm_kwargs.append(mm_feature.data)
- mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
- return mm_kwargs, mm_hashes_pos
+ mm_hashes.append(mm_feature.identifier)
+ mm_kwargs.append((mm_feature.modality, mm_feature.data))
+ mm_lora_refs.append((req_id, mm_feature.mm_position))
+
+ return mm_hashes, mm_kwargs, mm_lora_refs
def _execute_mm_encoder(
self, scheduler_output: "SchedulerOutput"
) -> list[torch.Tensor]:
- # Batch the multi-modal inputs using the helper method.
- mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
+ mm_hashes, mm_kwargs, mm_lora_refs = self._batch_mm_inputs_from_scheduler(
scheduler_output
)
if not mm_kwargs:
return []
+ should_time = bool(
+ self.observability_config
+ and self.observability_config.enable_mm_processor_stats
+ and scheduler_output.scheduled_encoder_inputs
+ )
+
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
# we process it separately to preserve item order.
@@ -2164,13 +2582,72 @@ def _execute_mm_encoder(
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
model = cast(SupportsMultiModal, self.model)
+
+ if self.lora_config and self.lora_manager.supports_tower_connector_lora():
+ # Build LoRA mappings independently for encoder inputs
+ # (encoder batch structure is different from main batch)
+ prompt_lora_mapping = []
+ token_lora_mapping = []
+ lora_requests = set()
+ encoder_token_counts = []
+
+ for req_id, pos_info in mm_lora_refs:
+ req_idx = self.input_batch.req_id_to_index[req_id]
+ lora_id = int(self.input_batch.request_lora_mapping[req_idx])
+
+ # Prefer pos_info.get_num_embeds to count precise MM embedding tokens.
+ num_tokens = self.model.get_num_mm_encoder_tokens( # type: ignore[attr-defined]
+ pos_info.get_num_embeds()
+ )
+ prompt_lora_mapping.append(lora_id)
+ token_lora_mapping.extend([lora_id] * num_tokens)
+ encoder_token_counts.append(num_tokens)
+
+ if lora_id > 0:
+ lora_request = self.input_batch.lora_id_to_lora_request.get(lora_id)
+ if lora_request is not None:
+ lora_requests.add(lora_request)
+
+ # Set tower adapter mapping
+ tower_mapping = LoRAMapping(
+ tuple(token_lora_mapping),
+ tuple(prompt_lora_mapping),
+ is_prefill=True,
+ type=LoRAMappingType.TOWER,
+ )
+ self.lora_manager.set_active_adapters(lora_requests, tower_mapping)
+
+ if hasattr(self.model, "get_num_mm_connector_tokens"):
+ post_op_counts = [
+ self.model.get_num_mm_connector_tokens(num_tokens) # type: ignore[attr-defined]
+ for num_tokens in encoder_token_counts
+ ]
+
+ connector_token_mapping = np.repeat(
+ np.array(prompt_lora_mapping, dtype=np.int32),
+ np.array(post_op_counts, dtype=np.int32),
+ )
+ connector_mapping = LoRAMapping(
+ index_mapping=tuple(connector_token_mapping.tolist()),
+ prompt_mapping=tuple(prompt_lora_mapping),
+ is_prefill=True,
+ type=LoRAMappingType.CONNECTOR,
+ )
+
+ self.lora_manager.set_active_adapters(
+ lora_requests,
+ connector_mapping,
+ )
+
encoder_outputs: list[torch.Tensor] = []
- for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+ # Track the current index in mm_kwargs/mm_lora_refs to map groups to request IDs
+ current_item_idx = 0
+ for modality, num_items, mm_kwargs_batch in group_and_batch_mm_kwargs(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
- curr_group_outputs: list[torch.Tensor] = []
+ batch_outputs: MultiModalEmbeddings
# EVS-related change.
# (ekhvedchenia): Temporary hack to limit peak memory usage when
@@ -2186,40 +2663,48 @@ def _execute_mm_encoder(
and modality == "video"
and num_items > 1
):
- for video_mm_kwargs_item in filter(
- lambda item: item.modality == "video", mm_kwargs
- ):
- _, _, micro_batch_mm_inputs = next(
- group_mm_kwargs_by_modality(
- [video_mm_kwargs_item],
- device=self.device,
- pin_memory=self.pin_memory,
+ batch_outputs_lst = list[torch.Tensor]()
+ for video_idx in range(num_items):
+ video_mm_kwargs_item = mm_kwargs[current_item_idx + video_idx]
+ with self.timed_encoder_operation(
+ should_time, mm_lora_refs, current_item_idx + video_idx, 1
+ ):
+ _, _, micro_batch_mm_inputs = next(
+ group_and_batch_mm_kwargs(
+ [video_mm_kwargs_item],
+ device=self.device,
+ pin_memory=self.pin_memory,
+ )
)
- )
- micro_batch_outputs = model.embed_multimodal(
- **micro_batch_mm_inputs
- )
+ micro_batch_outputs = model.embed_multimodal(
+ **micro_batch_mm_inputs
+ )
+
+ batch_outputs_lst.extend(micro_batch_outputs)
- curr_group_outputs.extend(micro_batch_outputs)
+ batch_outputs = batch_outputs_lst
else:
# Run the encoder.
- # `curr_group_outputs` is either of the following:
+ # `batch_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
# in case feature_size is fixed across all multimodal items.
# 2. A list or tuple (length: num_items) of tensors,
# each of shape (feature_size, hidden_size) in case the feature
# size is dynamic depending on the input multimodal items.
- curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) # type: ignore[assignment]
- sanity_check_mm_encoder_outputs(
- curr_group_outputs,
- expected_num_items=num_items,
- )
- encoder_outputs.extend(curr_group_outputs)
+ with self.timed_encoder_operation(
+ should_time, mm_lora_refs, current_item_idx, num_items
+ ):
+ batch_outputs = model.embed_multimodal(**mm_kwargs_batch)
+
+ sanity_check_mm_encoder_outputs(batch_outputs, expected_num_items=num_items)
+ encoder_outputs.extend(batch_outputs)
+
+ current_item_idx += num_items
# Cache the encoder outputs by mm_hash
- for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
+ for mm_hash, output in zip(mm_hashes, encoder_outputs):
self.encoder_cache[mm_hash] = output
logger.debug("Finish execute for mm hash %s", mm_hash)
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
@@ -2233,8 +2718,13 @@ def _gather_mm_embeddings(
) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+ # Swap to the other buffer to avoid race condition with previous
+ # iteration's async copy that may still be reading from CPU.
+ self.is_mm_embed_idx = 1 - self.is_mm_embed_idx
+ is_mm_embed_buf = self.is_mm_embed_buffers[self.is_mm_embed_idx]
+
mm_embeds = list[torch.Tensor]()
- is_mm_embed = self.is_mm_embed.cpu
+ is_mm_embed = is_mm_embed_buf.cpu
is_mm_embed[:total_num_scheduled_tokens] = False
req_start_idx = 0
@@ -2290,9 +2780,15 @@ def _gather_mm_embeddings(
mm_embeds_item = encoder_output[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
- is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
- True if is_embed is None else is_embed
- )
+ # OR mask for overlapping mm_features (use_audio_in_video)
+ if is_embed is None:
+ is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
+ True
+ )
+ else:
+ is_mm_embed[
+ req_start_pos + start_idx : req_start_pos + end_idx
+ ] |= is_embed
mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope:
@@ -2312,7 +2808,7 @@ def _gather_mm_embeddings(
mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens
- is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
+ is_mm_embed = is_mm_embed_buf.copy_to_gpu(total_num_scheduled_tokens)
if should_sync_mrope_positions:
self._calc_mrope_positions(scheduler_output)
@@ -2325,8 +2821,10 @@ def _gather_mm_embeddings(
return mm_embeds, is_mm_embed
def get_model(self) -> nn.Module:
- # get raw model out of the cudagraph wrapper.
+ if not hasattr(self, "model"):
+ raise ValueError("Cannot get model before model has been initialized")
if isinstance(self.model, (GraphWrapper, UBatchWrapper)):
+ # get raw model out of the cudagraph wrapper.
return self.model.unwrap()
return self.model
@@ -2343,6 +2841,9 @@ def get_supported_generation_tasks(self) -> list[GenerationTask]:
supported_tasks.append("transcription")
+ if supports_realtime(model):
+ supported_tasks.append("realtime")
+
return supported_tasks
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
@@ -2405,7 +2906,7 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None:
"""
Step for the EPLB (Expert Parallelism Load Balancing) state.
"""
- if not self.parallel_config.enable_eplb:
+ if not self.parallel_config.enable_eplb or self.eep_eplb_suppressed:
return
assert self.eplb_state is not None
@@ -2417,59 +2918,104 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None:
log_stats=self.parallel_config.eplb_config.log_balancedness,
)
+ def setup_eplb_from_mapping(
+ self,
+ expanded_physical_to_logical: torch.Tensor,
+ old_num_physical_experts: int,
+ ) -> None:
+ model = self.get_model()
+ assert is_mixture_of_experts(model)
+
+ self.eplb_state = EplbState.from_mapping(
+ model=model,
+ model_config=self.model_config,
+ device=self.device,
+ parallel_config=self.parallel_config,
+ expanded_physical_to_logical=expanded_physical_to_logical,
+ num_valid_physical_experts=old_num_physical_experts,
+ )
+
def _pool(
self,
hidden_states: torch.Tensor,
num_scheduled_tokens: int,
num_scheduled_tokens_np: np.ndarray,
- ) -> ModelRunnerOutput:
- assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), (
+ kv_connector_output: KVConnectorOutput | None,
+ ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
+ num_reqs = self.input_batch.num_reqs
+ assert num_reqs == len(self.input_batch.pooling_params), (
"Either all or none of the requests in a batch must be pooling request"
)
hidden_states = hidden_states[:num_scheduled_tokens]
- seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
+ seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor(
- num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device
+ num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device
)
model = cast(VllmModelForPooling, self.model)
raw_pooler_output: PoolerOutput = model.pooler(
- hidden_states=hidden_states,
- pooling_metadata=pooling_metadata,
- )
- raw_pooler_output = json_map_leaves(
- lambda x: x.to("cpu", non_blocking=True) if x is not None else x,
- raw_pooler_output,
+ hidden_states=hidden_states, pooling_metadata=pooling_metadata
)
- self._sync_device()
-
- pooler_output: list[torch.Tensor | None] = []
- for raw_output, seq_len, prompt_len in zip(
- raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens
- ):
- output = raw_output if seq_len == prompt_len else None
- pooler_output.append(output)
- return ModelRunnerOutput(
+ finished_mask = [
+ seq_len == prompt_len
+ for seq_len, prompt_len in zip(seq_lens_cpu, pooling_metadata.prompt_lens)
+ ]
+ raw_pooler_output = self.late_interaction_runner.postprocess_pooler_output(
+ raw_pooler_output=raw_pooler_output,
+ pooling_params=pooling_metadata.pooling_params,
req_ids=self.input_batch.req_ids,
- req_id_to_index=self.input_batch.req_id_to_index,
- sampled_token_ids=[],
- logprobs=None,
- prompt_logprobs_dict={},
- pooler_output=pooler_output,
+ finished_mask=finished_mask,
)
- def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
- # Pad tokens to multiple of tensor_parallel_size when
- # enabled collective fusion for SP
- tp_size = self.vllm_config.parallel_config.tensor_parallel_size
- if self.compilation_config.pass_config.enable_sp and tp_size > 1:
- return round_up(num_scheduled_tokens, tp_size)
+ model_runner_output = ModelRunnerOutput(
+ req_ids=self.input_batch.req_ids.copy(),
+ req_id_to_index=self.input_batch.req_id_to_index.copy(),
+ kv_connector_output=kv_connector_output,
+ )
+
+ if raw_pooler_output is None or not any(finished_mask):
+ model_runner_output.pooler_output = [None] * num_reqs
+ return model_runner_output
+
+ if self.use_async_scheduling:
+ return AsyncGPUPoolingModelRunnerOutput(
+ model_runner_output=model_runner_output,
+ raw_pooler_output=raw_pooler_output,
+ finished_mask=finished_mask,
+ async_output_copy_stream=self.async_output_copy_stream,
+ )
+
+ model_runner_output.pooler_output = _copy_pooler_output_to_cpu(
+ raw_pooler_output=raw_pooler_output,
+ finished_mask=finished_mask,
+ )
+ self._sync_device()
+
+ return model_runner_output
+
+ def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
+ # Pad tokens to multiple of tensor_parallel_size when
+ # enabled collective fusion for SP
+ tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+ if self.compilation_config.pass_config.enable_sp and tp_size > 1:
+ return round_up(num_scheduled_tokens, tp_size)
return num_scheduled_tokens
+ def _prepare_mm_inputs(
+ self, num_tokens: int
+ ) -> tuple[torch.Tensor | None, torch.Tensor]:
+ if self.model.requires_raw_input_tokens:
+ input_ids = self.input_ids.gpu[:num_tokens]
+ else:
+ input_ids = None
+
+ inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
+ return input_ids, inputs_embeds
+
def _preprocess(
self,
scheduler_output: "SchedulerOutput",
@@ -2512,10 +3058,9 @@ def _preprocess(
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled)
- input_ids = None
- inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
+ input_ids, inputs_embeds = self._prepare_mm_inputs(num_input_tokens)
model_kwargs = {
- **self._init_model_kwargs(num_scheduled_tokens),
+ **self._init_model_kwargs(),
**self._extract_mm_kwargs(scheduler_output),
}
elif self.enable_prompt_embeds and is_first_rank:
@@ -2543,7 +3088,7 @@ def _preprocess(
self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
- model_kwargs = self._init_model_kwargs(num_input_tokens)
+ model_kwargs = self._init_model_kwargs()
input_ids = None
else:
# For text-only models, we use token ids as input.
@@ -2552,7 +3097,7 @@ def _preprocess(
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids.gpu[:num_input_tokens]
inputs_embeds = None
- model_kwargs = self._init_model_kwargs(num_input_tokens)
+ model_kwargs = self._init_model_kwargs()
if self.uses_mrope:
positions = self.mrope_positions.gpu[:, :num_input_tokens]
@@ -2594,22 +3139,27 @@ def _sample(
) -> SamplerOutput:
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
+ # Update output token ids with tokens sampled in last step
+ # if async scheduling and required by current sampling params.
+ self.input_batch.update_async_output_token_ids()
if spec_decode_metadata is None:
- # Update output token ids with tokens sampled in last step
- # if async scheduling and required by current sampling params.
- self.input_batch.update_async_output_token_ids()
return self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
+ # Update spec_token_ids with real draft tokens from the previous step only when
+ # output_token_ids is needed (penalties or bad_words are in use).
+ if self.use_async_scheduling and self._draft_token_req_ids is not None:
+ draft_token_ids_cpu, _ = self._get_draft_token_ids_cpu()
+ self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu)
+
sampler_output = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
logits,
sampling_metadata,
)
- self._update_states_after_model_execute(sampler_output.sampled_token_ids)
return sampler_output
def _bookkeeping_sync(
@@ -2651,7 +3201,7 @@ def _bookkeeping_sync(
sampled_token_ids = sampler_output.sampled_token_ids
logprobs_tensors = sampler_output.logprobs_tensors
invalid_req_indices = []
- cu_num_tokens: list[int] | None = None
+ logprobs_lists = None
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
@@ -2661,13 +3211,16 @@ def _bookkeeping_sync(
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
+
+ if logprobs_tensors is not None:
+ logprobs_lists = logprobs_tensors.tolists()
else:
# Includes spec decode tokens.
- valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
+ valid_sampled_token_ids, logprobs_lists = RejectionSampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
discard_sampled_tokens_req_indices,
- return_cu_num_tokens=logprobs_tensors is not None,
+ logprobs_tensors=logprobs_tensors,
)
else:
valid_sampled_token_ids = []
@@ -2715,18 +3268,11 @@ def _bookkeeping_sync(
self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
- self.input_batch.num_tokens[req_idx] = end_idx
req_id = req_ids[req_idx]
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
- logprobs_lists = (
- logprobs_tensors.tolists(cu_num_tokens)
- if not self.use_async_scheduling and logprobs_tensors is not None
- else None
- )
-
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
@@ -2790,6 +3336,27 @@ def _model_forward(
**model_kwargs,
)
+ @staticmethod
+ def _is_uniform_decode(
+ max_num_scheduled_tokens: int,
+ uniform_decode_query_len: int,
+ num_tokens: int,
+ num_reqs: int,
+ force_uniform_decode: bool | None = None,
+ ) -> bool:
+ """
+ Checks if it's a decode batch with the same number of scheduled tokens
+ across all requests.
+ """
+ return (
+ (
+ (max_num_scheduled_tokens == uniform_decode_query_len)
+ and (num_tokens == max_num_scheduled_tokens * num_reqs)
+ )
+ if force_uniform_decode is None
+ else force_uniform_decode
+ )
+
def _determine_batch_execution_and_padding(
self,
num_tokens: int,
@@ -2803,6 +3370,7 @@ def _determine_batch_execution_and_padding(
# be improved in model runner v2)
force_uniform_decode: bool | None = None,
force_has_lora: bool | None = None,
+ force_num_active_loras: int | None = None,
num_encoder_reqs: int = 0,
) -> tuple[
CUDAGraphMode,
@@ -2811,14 +3379,12 @@ def _determine_batch_execution_and_padding(
torch.Tensor | None,
CUDAGraphStat | None,
]:
- num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
- uniform_decode = (
- (
- (max_num_scheduled_tokens == self.uniform_decode_query_len)
- and (num_tokens_padded == max_num_scheduled_tokens * num_reqs)
- )
- if force_uniform_decode is None
- else force_uniform_decode
+ uniform_decode = self._is_uniform_decode(
+ max_num_scheduled_tokens=max_num_scheduled_tokens,
+ uniform_decode_query_len=self.uniform_decode_query_len,
+ num_tokens=num_tokens,
+ num_reqs=num_reqs,
+ force_uniform_decode=force_uniform_decode,
)
# Encoder-decoder models only support CG for decoder_step > 0 (no enc_output
# is present). Also, chunked-prefill is disabled, so batch are uniform.
@@ -2826,46 +3392,49 @@ def _determine_batch_execution_and_padding(
self.model_config.is_encoder_decoder and num_encoder_reqs > 0
)
- has_lora = (
- len(self.input_batch.lora_id_to_lora_request) > 0
- if force_has_lora is None
- else force_has_lora
+ # Compute LoRA state for cudagraph dispatch
+ num_active_loras = (
+ force_num_active_loras
+ if force_num_active_loras is not None
+ else len(self.input_batch.lora_id_to_lora_request)
)
+ has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora
+
+ num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
- dispatch_cudagraph = (
- lambda num_tokens, disable_full: self.cudagraph_dispatcher.dispatch(
+ def dispatch_cudagraph(num_tokens, disable_full=False, valid_modes=None):
+ return self.cudagraph_dispatcher.dispatch(
num_tokens=num_tokens,
has_lora=has_lora,
uniform_decode=uniform_decode,
- disable_full=disable_full,
+ num_active_loras=num_active_loras,
+ valid_modes={CUDAGraphMode.NONE} if force_eager else valid_modes,
+ invalid_modes={CUDAGraphMode.FULL} if disable_full else None,
)
- if not force_eager
- else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
- )
cudagraph_mode, batch_descriptor = dispatch_cudagraph(
- num_tokens_padded, use_cascade_attn or has_encoder_output
+ num_tokens_padded, disable_full=use_cascade_attn or has_encoder_output
)
num_tokens_padded = batch_descriptor.num_tokens
+ if self.compilation_config.pass_config.enable_sp:
+ assert (
+ batch_descriptor.num_tokens
+ % self.vllm_config.parallel_config.tensor_parallel_size
+ == 0
+ ), (
+ "Sequence parallelism requires num_tokens to be "
+ "a multiple of tensor parallel size"
+ )
# Extra coordination when running data-parallel since we need to coordinate
# across ranks
should_ubatch, num_tokens_across_dp = False, None
if self.vllm_config.parallel_config.data_parallel_size > 1:
- # Disable DP padding when running eager to avoid excessive padding when
- # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
- # in a P/D setup and still use CUDA graphs (enabled by this padding) on the
- # decoder.
- allow_dp_padding = (
- self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
- )
-
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.parallel_config,
allow_microbatching=allow_microbatching,
- allow_dp_padding=allow_dp_padding,
num_tokens_padded=num_tokens_padded,
uniform_decode=uniform_decode,
num_scheduled_tokens_per_request=num_scheduled_tokens_np,
@@ -2880,7 +3449,7 @@ def _determine_batch_execution_and_padding(
# Re-dispatch with DP padding so we have the correct batch_descriptor
cudagraph_mode, batch_descriptor = dispatch_cudagraph(
num_tokens_padded,
- disable_full=synced_cudagraph_mode <= CUDAGraphMode.PIECEWISE.value,
+ valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
)
# Assert to make sure the agreed upon token count is correct otherwise
# num_tokens_across_dp will no-longer be valid
@@ -2939,152 +3508,281 @@ def _register_layerwise_nvtx_hooks(self) -> None:
pyt_hooks.register_hooks(self.model, self.model.__class__.__name__)
self.layerwise_nvtx_hooks_registered = True
+ def _get_slot_mappings(
+ self,
+ num_tokens_padded: int,
+ num_reqs_padded: int,
+ num_tokens_unpadded: int,
+ ubatch_slices: "UBatchSlices | None" = None,
+ ) -> tuple[
+ dict[int, torch.Tensor] | None,
+ dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
+ ]:
+ """
+ Build slot mappings in both formats needed by the system.
+
+ Args:
+ num_tokens_padded: Total number of tokens (padded)
+ num_reqs_padded: Total number of requests (padded)
+ num_tokens_unpadded: Actual number of tokens (unpadded)
+ ubatch_slices: Optional ubatch slicing info for DBO
+
+ Returns:
+ A tuple of:
+ - slot_mappings_by_gid: dict[int, torch.Tensor] for attention metadata
+ - slot_mappings_by_layer: dict[str, torch.Tensor] or list for ForwardContext
+ """
+ if not (
+ hasattr(self, "kv_cache_config")
+ and self.kv_cache_config is not None
+ and len(self.kv_cache_config.kv_cache_groups) > 0
+ ):
+ return None, None
+
+ def _get_slot_mapping(kv_cache_gid: int):
+ assert num_reqs_padded is not None and num_tokens_padded is not None
+ kv_cache_spec = self.kv_cache_config.kv_cache_groups[
+ kv_cache_gid
+ ].kv_cache_spec
+ if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
+ slot_mapping = torch.zeros(
+ (num_tokens_padded,),
+ dtype=torch.int64,
+ device=self.device,
+ )
+ else:
+ blk_table = self.input_batch.block_table[kv_cache_gid]
+ slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
+
+ # Fill unused entries with -1; needed for reshape_and_cache in full cuda
+ # graph mode, and -1 in `blk_table_tensor` matches mamba's PAD_SLOT_ID
+ slot_mapping[num_tokens_unpadded:num_tokens_padded].fill_(-1)
+
+ return slot_mapping
+
+ slot_mappings_by_gid = {
+ gid: _get_slot_mapping(gid)
+ for gid, _ in enumerate(self.kv_cache_config.kv_cache_groups)
+ }
+
+ slot_mappings_by_layer: dict[str, torch.Tensor] = {}
+ for gid, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups):
+ slot_mapping = slot_mappings_by_gid[gid]
+ for layer_name in kv_cache_group.layer_names:
+ slot_mappings_by_layer[layer_name] = slot_mapping
+
+ if ubatch_slices is not None:
+ result: list[dict[str, torch.Tensor]] = []
+ for ubatch in ubatch_slices:
+ sliced_mappings: dict[str, torch.Tensor] = {}
+ for layer_name, slot_mapping in slot_mappings_by_layer.items():
+ sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice]
+ result.append(sliced_mappings)
+ return slot_mappings_by_gid, result
+
+ return slot_mappings_by_gid, slot_mappings_by_layer
+
@managed_inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None,
- ) -> ModelRunnerOutput | IntermediateTensors | None:
-
+ ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None:
if self.execute_model_state is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
- # self._draft_token_ids is None when `input_fits_in_drafter=False`
- # and there is no draft tokens scheduled. so it need to update the
- # spec_decoding info in scheduler_output with async_scheduling.
- # use deepcopy to avoid the modification has influence on the
- # scheduler_output in engine core process.
- # TODO(Ronald1995): deepcopy is expensive when there is a large
- # number of requests, optimize it later.
+ if self.routed_experts_initialized:
+ capturer = RoutedExpertsCapturer.get_instance()
+ if capturer is not None:
+ capturer.clear_buffer() # noqa
+ else:
+ logger.error("RoutedExpertsCapturer not initialized.")
+
+ # If ngram_gpu is used, copy the scheduler_output so that our
+ # modifications do not affect the scheduler_output in the engine core process.
+ # The replace is much faster than deepcopy.
if (
- self.use_async_scheduling
- and self.num_spec_tokens
- and self._draft_token_ids is None
+ self.speculative_config is not None
+ and self.speculative_config.use_ngram_gpu()
):
- scheduler_output = deepcopy(scheduler_output)
+ num_scheduled_tokens_copy = scheduler_output.num_scheduled_tokens.copy()
+ spec_decode_tokens_copy = (
+ scheduler_output.scheduled_spec_decode_tokens.copy()
+ )
+ scheduler_output = replace(
+ scheduler_output,
+ num_scheduled_tokens=num_scheduled_tokens_copy,
+ scheduled_spec_decode_tokens=spec_decode_tokens_copy,
+ )
- num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
- with record_function_or_nullcontext("gpu_model_runner: preprocess"):
- with self.synchronize_input_prep():
- # Update persistent batch states.
- self._update_states(scheduler_output)
-
- if has_ec_transfer() and get_ec_transfer().is_producer:
- with self.maybe_get_ec_connector_output(
- scheduler_output,
- encoder_cache=self.encoder_cache,
- ) as ec_connector_output:
- self._execute_mm_encoder(scheduler_output)
- return make_empty_encoder_model_runner_output(scheduler_output)
-
- if not num_scheduled_tokens:
- if (
- self.parallel_config.distributed_executor_backend
- == "external_launcher"
- and self.parallel_config.data_parallel_size > 1
- ):
- # this is a corner case when both external launcher
- # and DP are enabled, num_scheduled_tokens could be
- # 0, and has_unfinished_requests in the outer loop
- # returns True. before returning early here we call
- # dummy run to ensure coordinate_batch_across_dp
- # is called into to avoid out of sync issues.
- self._dummy_run(1)
- if not has_kv_transfer_group():
- # Return empty ModelRunnerOutput if no work to do.
- return EMPTY_MODEL_RUNNER_OUTPUT
- return self.kv_connector_no_forward(
- scheduler_output, self.vllm_config
- )
- if self.cache_config.kv_sharing_fast_prefill:
- assert not self.num_prompt_logprobs, (
- "--kv-sharing-fast-prefill produces incorrect "
- "logprobs for prompt tokens, tokens, please disable "
- "it when the requests need prompt logprobs"
- )
+ if scheduler_output.preempted_req_ids and has_kv_transfer_group():
+ get_kv_transfer_group().handle_preemptions(
+ scheduler_output.preempted_req_ids
+ )
- num_reqs = self.input_batch.num_reqs
- req_ids = self.input_batch.req_ids
- tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
- num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
- max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
- num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
+ num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+ with (
+ record_function_or_nullcontext("gpu_model_runner: preprocess"),
+ self.synchronize_input_prep(),
+ ):
+ # Update persistent batch states.
+ self._update_states(scheduler_output)
- (
- logits_indices,
- spec_decode_metadata,
- ) = self._prepare_inputs(
+ if has_ec_transfer() and not get_ec_transfer().is_consumer:
+ with self.maybe_get_ec_connector_output(
scheduler_output,
- num_scheduled_tokens_np,
+ encoder_cache=self.encoder_cache,
+ ) as ec_connector_output:
+ self._execute_mm_encoder(scheduler_output)
+ return make_empty_encoder_model_runner_output(scheduler_output)
+
+ if not num_scheduled_tokens:
+ if (
+ self.parallel_config.distributed_executor_backend
+ == "external_launcher"
+ and self.parallel_config.data_parallel_size > 1
+ ):
+ # this is a corner case when both external launcher
+ # and DP are enabled, num_scheduled_tokens could be
+ # 0, and has_unfinished_requests in the outer loop
+ # returns True. before returning early here we call
+ # dummy run to ensure coordinate_batch_across_dp
+ # is called into to avoid out of sync issues.
+ self._dummy_run(1)
+ if not has_kv_transfer_group():
+ # Return empty ModelRunnerOutput if no work to do.
+ return EMPTY_MODEL_RUNNER_OUTPUT
+ return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
+
+ if self.cache_config.kv_sharing_fast_prefill:
+ assert not self.num_prompt_logprobs, (
+ "--kv-sharing-fast-prefill produces incorrect "
+ "logprobs for prompt tokens, tokens, please disable "
+ "it when the requests need prompt logprobs"
)
- cascade_attn_prefix_lens = None
- # Disable cascade attention when using microbatching (DBO)
- if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
- # Pre-compute cascade attention prefix lengths
- cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
- num_scheduled_tokens_np,
- self.input_batch.num_computed_tokens_cpu[:num_reqs],
- scheduler_output.num_common_prefix_blocks,
- )
+ num_reqs = self.input_batch.num_reqs
+ req_ids = self.input_batch.req_ids
+ tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+ num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
+ max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
+ num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
- (
- cudagraph_mode,
- batch_desc,
- should_ubatch,
- num_tokens_across_dp,
- cudagraph_stats,
- ) = self._determine_batch_execution_and_padding(
- num_tokens=num_tokens_unpadded,
- num_reqs=num_reqs,
- num_scheduled_tokens_np=num_scheduled_tokens_np,
- max_num_scheduled_tokens=max_num_scheduled_tokens,
- use_cascade_attn=cascade_attn_prefix_lens is not None,
- num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs),
- )
+ logits_indices, spec_decode_metadata = self._prepare_inputs(
+ scheduler_output,
+ num_scheduled_tokens_np,
+ )
- logger.debug(
- "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
- "should_ubatch: %s, num_tokens_across_dp: %s",
- cudagraph_mode,
- batch_desc,
- should_ubatch,
- num_tokens_across_dp,
+ cascade_attn_prefix_lens = None
+ # Disable cascade attention when using microbatching (DBO)
+ if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
+ # Pre-compute cascade attention prefix lengths
+ cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
+ num_scheduled_tokens_np,
+ self.input_batch.num_computed_tokens_cpu[:num_reqs],
+ scheduler_output.num_common_prefix_blocks,
)
- num_tokens_padded = batch_desc.num_tokens
- num_reqs_padded = (
- batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
+ (
+ cudagraph_mode,
+ batch_desc,
+ should_ubatch,
+ num_tokens_across_dp,
+ cudagraph_stats,
+ ) = self._determine_batch_execution_and_padding(
+ num_tokens=num_tokens_unpadded,
+ num_reqs=num_reqs,
+ num_scheduled_tokens_np=num_scheduled_tokens_np,
+ max_num_scheduled_tokens=max_num_scheduled_tokens,
+ use_cascade_attn=cascade_attn_prefix_lens is not None,
+ num_encoder_reqs=len(scheduler_output.scheduled_encoder_inputs),
+ )
+
+ logger.debug(
+ "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
+ "should_ubatch: %s, num_tokens_across_dp: %s",
+ cudagraph_mode,
+ batch_desc,
+ should_ubatch,
+ num_tokens_across_dp,
+ )
+
+ num_tokens_padded = batch_desc.num_tokens
+ num_reqs_padded = (
+ batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
+ )
+ ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+ should_ubatch,
+ num_scheduled_tokens_np,
+ num_tokens_padded,
+ num_reqs_padded,
+ self.parallel_config.num_ubatches,
+ )
+
+ logger.debug(
+ "ubatch_slices: %s, ubatch_slices_padded: %s",
+ ubatch_slices,
+ ubatch_slices_padded,
+ )
+
+ # True if any attention backend handles KV cache update separately
+ # from forward() (i.e., forward_includes_kv_cache_update=False). When true,
+ # slot_mappings must use padded dimensions to match the key/value tensors.
+ has_separate_kv_update = not all(
+ all(
+ g.backend.forward_includes_kv_cache_update
+ for g in self.attn_groups[id]
)
- ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
- should_ubatch,
- num_scheduled_tokens_np,
- num_tokens_padded,
- num_reqs_padded,
+ for id, spec in enumerate(self.kv_cache_config.kv_cache_groups)
+ if not isinstance(spec.kv_cache_spec, EncoderOnlyAttentionSpec)
+ )
+ pad_attn = cudagraph_mode == CUDAGraphMode.FULL
+
+ if self.cache_config.mamba_cache_mode == "align":
+ mamba_utils.preprocess_mamba(
+ scheduler_output,
+ self.kv_cache_config,
+ self.cache_config,
+ self.mamba_state_idx,
+ self.input_batch,
+ self.requests,
+ self.compilation_config.static_forward_context,
+ self.model.get_mamba_state_copy_func(),
+ self._get_mamba_copy_bufs(),
)
- pad_attn = cudagraph_mode == CUDAGraphMode.FULL
-
- use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
- ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
-
- (attn_metadata, spec_decode_common_attn_metadata) = (
- self._build_attention_metadata(
- num_tokens=num_tokens_unpadded,
- num_tokens_padded=num_tokens_padded if pad_attn else None,
- num_reqs=num_reqs,
- num_reqs_padded=num_reqs_padded if pad_attn else None,
- max_query_len=max_num_scheduled_tokens,
- ubatch_slices=ubatch_slices_attn,
- logits_indices=logits_indices,
- use_spec_decode=use_spec_decode,
- num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
- cascade_attn_prefix_lens=cascade_attn_prefix_lens,
- )
+ use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
+ ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
+
+ slot_mappings_by_group, slot_mappings = self._get_slot_mappings(
+ num_tokens_padded=num_tokens_padded
+ if pad_attn or has_separate_kv_update
+ else num_tokens_unpadded,
+ num_reqs_padded=(
+ num_reqs_padded if pad_attn or has_separate_kv_update else num_reqs
+ ),
+ num_tokens_unpadded=num_tokens_unpadded,
+ ubatch_slices=ubatch_slices_padded,
+ )
+
+ attn_metadata, spec_decode_common_attn_metadata = (
+ self._build_attention_metadata(
+ num_tokens=num_tokens_unpadded,
+ num_tokens_padded=num_tokens_padded if pad_attn else None,
+ num_reqs=num_reqs,
+ num_reqs_padded=num_reqs_padded if pad_attn else None,
+ max_query_len=max_num_scheduled_tokens,
+ ubatch_slices=ubatch_slices_attn,
+ logits_indices=logits_indices,
+ use_spec_decode=use_spec_decode,
+ num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
+ cascade_attn_prefix_lens=cascade_attn_prefix_lens,
+ slot_mappings=slot_mappings_by_group,
)
+ )
(
input_ids,
@@ -3105,8 +3803,18 @@ def execute_model(
# Mark KV scales as calculated after the first forward pass
self.calculate_kv_scales = False
+ # Encoder-decoder models can only compile the pure decode steps where no
+ # encoder inputs are present. Use eager for the first pass.
+ num_encoder_reqs = len(scheduler_output.scheduled_encoder_inputs)
+ has_encoder_input = (
+ self.model_config.is_encoder_decoder and num_encoder_reqs > 0
+ )
+
# Run the model.
# Use persistent buffers for CUDA graphs.
+ # When spec decode is enabled, defer connector finalization
+ # (wait_for_save + clear metadata) until after draft model runs.
+ defer_kv_connector_finalize = self.speculative_config is not None
with (
set_forward_context(
attn_metadata,
@@ -3116,9 +3824,14 @@ def execute_model(
cudagraph_runtime_mode=cudagraph_mode,
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded,
+ slot_mapping=slot_mappings,
+ skip_compiled=has_encoder_input,
),
record_function_or_nullcontext("gpu_model_runner: forward"),
- self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
+ self.maybe_get_kv_connector_output(
+ scheduler_output,
+ defer_finalize=defer_kv_connector_finalize,
+ ) as kv_connector_output,
):
model_output = self._model_forward(
input_ids=input_ids,
@@ -3148,11 +3861,12 @@ def execute_model(
if self.is_pooling_model:
# Return the pooling output.
- output = self._pool(
- hidden_states, num_scheduled_tokens, num_scheduled_tokens_np
+ return self._pool(
+ hidden_states,
+ num_scheduled_tokens,
+ num_scheduled_tokens_np,
+ kv_connector_output,
)
- output.kv_connector_output = kv_connector_output
- return output
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
@@ -3179,6 +3893,7 @@ def execute_model(
model_output_broadcast_data: dict[str, Any] = {}
if logits is not None:
model_output_broadcast_data["logits"] = logits.contiguous()
+
broadcasted = get_pp_group().broadcast_tensor_dict(
model_output_broadcast_data, src=len(get_pp_group().ranks) - 1
)
@@ -3195,20 +3910,21 @@ def execute_model(
aux_hidden_states,
ec_connector_output,
cudagraph_stats,
+ slot_mappings,
)
self.kv_connector_output = kv_connector_output
-
return None
@managed_inference_mode()
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
- kv_connector_output = self.kv_connector_output
- self.kv_connector_output = None
-
if self.execute_model_state is None:
- # Nothing to do (PP non-final rank case), output isn't used.
+ kv_connector_output = self.kv_connector_output
+ self.kv_connector_output = None
+ # receive sampled token ids from the last PP rank.
+ if self.use_async_scheduling and get_pp_group().world_size > 1:
+ self._pp_receive_prev_sampled_token_ids_to_input_batch()
if not kv_connector_output:
return None # type: ignore[return-value]
@@ -3232,6 +3948,7 @@ def sample_tokens(
aux_hidden_states,
ec_connector_output,
cudagraph_stats,
+ slot_mappings,
) = self.execute_model_state
# Clear ephemeral state.
self.execute_model_state = None
@@ -3245,6 +3962,21 @@ def sample_tokens(
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
+ self._update_states_after_model_execute(
+ sampler_output.sampled_token_ids, scheduler_output
+ )
+ if self.use_async_scheduling:
+ pp = get_pp_group()
+ # For torchrun external_launcher PP mode with broadcast_pp_output=True,
+ # PP outputs have been broadcasted to all ranks at logits computation.
+ # Therefore, there is no need to send sampled token ids again in this case.
+ if not self.broadcast_pp_output and pp.world_size > 1 and pp.is_last_rank:
+ self._pp_broadcast_prev_sampled_token_ids(
+ sampler_output.sampled_token_ids
+ )
+
+ self._draft_token_ids = None
+ self._draft_token_req_ids = None
self.input_batch.prev_sampled_token_ids = None
def propose_draft_token_ids(sampled_token_ids):
@@ -3259,51 +3991,80 @@ def propose_draft_token_ids(sampled_token_ids):
aux_hidden_states,
spec_decode_metadata,
spec_decode_common_attn_metadata,
+ slot_mappings,
)
+ self._copy_draft_token_ids_to_cpu(scheduler_output)
spec_config = self.speculative_config
- use_padded_batch_for_eagle = (
- spec_config is not None
- and spec_config.use_eagle()
- and not spec_config.disable_padded_drafter_batch
- )
- effective_drafter_max_model_len = self.max_model_len
- if effective_drafter_max_model_len is None:
- effective_drafter_max_model_len = self.model_config.max_model_len
- if (
- spec_config is not None
- and spec_config.draft_model_config is not None
- and spec_config.draft_model_config.max_model_len is not None
- ):
- effective_drafter_max_model_len = (
- spec_config.draft_model_config.max_model_len
+ propose_drafts_after_bookkeeping = False
+ if spec_config is not None:
+ input_fits_in_drafter = spec_decode_common_attn_metadata is not None and (
+ spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
+ <= self.effective_drafter_max_model_len
)
- input_fits_in_drafter = spec_decode_common_attn_metadata and (
- spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
- <= effective_drafter_max_model_len
- )
- if use_padded_batch_for_eagle:
- assert self.speculative_config is not None
- assert isinstance(self.drafter, EagleProposer)
- sampled_token_ids = sampler_output.sampled_token_ids
- if input_fits_in_drafter:
- # EAGLE speculative decoding can use the GPU sampled tokens
+ use_gpu_toks = (
+ spec_config.use_eagle()
+ or spec_config.uses_draft_model()
+ or spec_config.uses_extract_hidden_states()
+ ) and not spec_config.disable_padded_drafter_batch
+ if use_gpu_toks:
+ # EAGLE/DraftModel speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
- propose_draft_token_ids(sampled_token_ids)
- elif self.valid_sampled_token_count_event is not None:
- assert spec_decode_common_attn_metadata is not None
- next_token_ids, valid_sampled_tokens_count = (
- self.drafter.prepare_next_token_ids_padded(
- spec_decode_common_attn_metadata,
- sampled_token_ids,
- self.requests,
- self.input_batch,
- self.discard_request_mask.gpu,
- )
- )
- self._copy_valid_sampled_token_count(
- next_token_ids, valid_sampled_tokens_count
+ assert isinstance(
+ self.drafter,
+ EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
)
+ sampled_token_ids = sampler_output.sampled_token_ids
+ if input_fits_in_drafter:
+ propose_draft_token_ids(sampled_token_ids)
+ elif self.valid_sampled_token_count_event is not None:
+ assert spec_decode_common_attn_metadata is not None
+ next_token_ids, valid_sampled_tokens_count = (
+ self.drafter.prepare_next_token_ids_padded(
+ spec_decode_common_attn_metadata,
+ sampled_token_ids,
+ self.requests,
+ self.input_batch,
+ self.discard_request_mask.gpu,
+ )
+ )
+ self._copy_valid_sampled_token_count(
+ next_token_ids, valid_sampled_tokens_count
+ )
+ self._draft_token_ids = torch.zeros(
+ 1, device=self.device, dtype=torch.int32
+ ).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
+ self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
+ elif (
+ spec_config.use_ngram_gpu()
+ and not spec_config.disable_padded_drafter_batch
+ ):
+ assert isinstance(self.drafter, NgramProposerGPU)
+ sampled_token_ids = sampler_output.sampled_token_ids
+ if input_fits_in_drafter:
+ propose_draft_token_ids(sampled_token_ids)
+ elif self.valid_sampled_token_count_event is not None:
+ assert spec_decode_common_attn_metadata is not None
+ next_token_ids, valid_sampled_tokens_count, _ = (
+ self.drafter.update_token_ids_ngram(
+ sampled_token_ids,
+ self.input_batch,
+ self.token_ids_gpu_tensor,
+ self.num_tokens_no_spec_gpu,
+ self.discard_request_mask.gpu,
+ )
+ )
+ self._copy_valid_sampled_token_count(
+ next_token_ids, valid_sampled_tokens_count
+ )
+ # Since we couldn't run the drafter,
+ # just use zeros for the draft tokens.
+ self._draft_token_ids = torch.zeros(
+ 1, device=self.device, dtype=torch.int32
+ ).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
+ self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
+ else:
+ propose_drafts_after_bookkeeping = input_fits_in_drafter
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
(
@@ -3323,25 +4084,38 @@ def propose_draft_token_ids(sampled_token_ids):
spec_decode_metadata,
)
- if (
- self.speculative_config
- and not use_padded_batch_for_eagle
- and input_fits_in_drafter
- ):
+ if propose_drafts_after_bookkeeping:
# ngram and other speculative decoding methods use the sampled
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)
+ # Finalize KV connector (wait_for_save + clear metadata) after
+ # draft model runs. Deferred from target model forward to allow
+ # draft model to also save its KV cache.
+ if spec_config is not None:
+ self.finalize_kv_connector()
+
with record_function_or_nullcontext("gpu_model_runner: eplb"):
self.eplb_step()
+
+ # self.kv_connector_output may be modified during drafting
+ kv_connector_output = self.kv_connector_output
+ self.kv_connector_output = None
+
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
+ if self.routed_experts_initialized:
+ capturer = RoutedExpertsCapturer.get_instance()
+ if capturer is not None:
+ capturer.save_captured_experts(indices=self.slot_mapping) # noqa
+ else:
+ logger.error("RoutedExpertsCapturer not initialized.")
+
output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
req_id_to_index=req_id_to_index_output_copy,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
- pooler_output=[],
kv_connector_output=kv_connector_output,
ec_connector_output=ec_connector_output
if self.supports_mm_inputs
@@ -3350,13 +4124,12 @@ def propose_draft_token_ids(sampled_token_ids):
cudagraph_stats=cudagraph_stats,
)
- # Advance IO step after the full inference cycle (forward + sampling)
- # so that one step encompasses both execute_model and sample_tokens.
+ # FL: Advance IO step after the full inference cycle
advance_io_step()
- ### TODO(lms): abstract async schedule for all hardware
if not self.use_async_scheduling:
return output
+
with record_function_or_nullcontext(
"gpu_model_runner: AsyncGPUModelRunnerOutput"
):
@@ -3380,17 +4153,95 @@ def propose_draft_token_ids(sampled_token_ids):
return async_output
+ def _pp_broadcast_prev_sampled_token_ids(
+ self, sampled_token_ids: torch.Tensor
+ ) -> None:
+ """Broadcast sampled token ids (GPU) from last PP stage"""
+ pp = get_pp_group()
+ assert pp.is_last_rank
+ # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
+ assert sampled_token_ids.dim() == 2 and sampled_token_ids.shape[-1] == 1, (
+ "PP+async expects sampled_token_ids to have shape [num_reqs, 1]"
+ )
+ torch.distributed.broadcast(
+ sampled_token_ids, src=pp.rank, group=pp.device_group
+ )
+
+ def _pp_receive_prev_sampled_token_ids_to_input_batch(self) -> None:
+ """Receive sampled token ids broadcast from last PP stage"""
+ pp = get_pp_group()
+ assert not pp.is_last_rank
+ num_reqs = self.input_batch.num_reqs
+ # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
+ recv = torch.empty((num_reqs, 1), dtype=torch.int32, device=self.device)
+ torch.distributed.broadcast(recv, src=pp.last_rank, group=pp.device_group)
+ self.input_batch.prev_sampled_token_ids = recv
+
+ # construct `prev_req_id_to_index` here so `_prepare_input_ids`
+ # can map req_id -> previous batch row
+ discard_req_indices = np.nonzero(self.discard_request_mask.np[:num_reqs])[0]
+ discard_req_indices_set = set(discard_req_indices)
+ prev_req_id_to_index: dict[str, int] = {}
+ for i, req_id in enumerate(self.input_batch.req_ids):
+ if i in discard_req_indices_set:
+ continue
+ prev_req_id_to_index[req_id] = i
+ # PP+async scheduling: advance per-request local cached output length by
+ # appending a placeholder (-1) token id.
+ if (req_state := self.requests.get(req_id)) is not None:
+ req_state.output_token_ids.append(-1)
+ self.input_batch.prev_req_id_to_index = prev_req_id_to_index
+
def take_draft_token_ids(self) -> DraftTokenIds | None:
- if self._draft_token_ids is None:
+ if not self.num_spec_tokens or not self._draft_token_req_ids:
return None
- req_ids = self.input_batch.req_ids
- if isinstance(self._draft_token_ids, torch.Tensor):
- draft_token_ids = self._draft_token_ids.tolist()
- else:
- draft_token_ids = self._draft_token_ids
- self._draft_token_ids = None
+ draft_token_ids, req_ids = self._get_draft_token_ids_cpu()
return DraftTokenIds(req_ids, draft_token_ids)
+ def _copy_draft_token_ids_to_cpu(
+ self, scheduler_output: "SchedulerOutput", zeros_only: bool = False
+ ) -> None:
+ # Check if we need to copy draft tokens to CPU. In async scheduling,
+ # we only copy when needed for structured output, penalties or bad_words.
+ if self.use_async_scheduling and not (
+ scheduler_output.has_structured_output_requests
+ or self.input_batch.sampling_metadata.output_token_ids
+ ):
+ return
+ # We must also set the corresponding request ids.
+ self._draft_token_req_ids = self.input_batch.req_ids.copy()
+
+ draft_token_ids: torch.Tensor = self._draft_token_ids
+ if not torch.is_tensor(draft_token_ids):
+ return
+ assert self.draft_token_ids_event is not None
+ assert self.draft_token_ids_copy_stream is not None
+ assert self.draft_token_ids_cpu is not None
+ default_stream = current_platform.torch_device_fn.current_stream()
+ num_reqs = draft_token_ids.shape[0]
+ with current_platform.torch_device_fn.stream(self.draft_token_ids_copy_stream):
+ if not zeros_only:
+ # Trigger async copy of draft token ids to cpu.
+ self.draft_token_ids_copy_stream.wait_stream(default_stream)
+ self.draft_token_ids_cpu[:num_reqs].copy_(
+ draft_token_ids, non_blocking=True
+ )
+ else:
+ # No copy needed, just zero-out cpu tensor.
+ self.draft_token_ids_cpu[:num_reqs] = 0
+ self.draft_token_ids_event.record()
+
+ def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]:
+ if isinstance(self._draft_token_ids, list):
+ return self._draft_token_ids, self.input_batch.req_ids
+ req_ids = self._draft_token_req_ids
+ if req_ids is None:
+ return [], []
+ assert self.draft_token_ids_event is not None
+ assert self.draft_token_ids_cpu is not None
+ self.draft_token_ids_event.synchronize()
+ return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids
+
def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
) -> None:
@@ -3400,10 +4251,11 @@ def _copy_valid_sampled_token_count(
default_stream = current_platform.torch_device_fn.current_stream()
# Initialize a new stream to overlap the copy operation with
# prepare_input of draft model.
- with torch.cuda.stream(self.valid_sampled_token_count_copy_stream):
+ with current_platform.torch_device_fn.stream(self.valid_sampled_token_count_copy_stream):
self.valid_sampled_token_count_copy_stream.wait_stream(default_stream) # type: ignore
counts = valid_sampled_tokens_count
counts_cpu = self.valid_sampled_token_count_cpu
+ assert counts_cpu is not None
counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True)
self.valid_sampled_token_count_event.record()
@@ -3412,14 +4264,13 @@ def _copy_valid_sampled_token_count(
def _get_valid_sampled_token_count(self) -> list[int]:
# Wait until valid_sampled_tokens_count is copied to cpu,
prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
- if (
- self.valid_sampled_token_count_event is None
- or prev_sampled_token_ids is None
- ):
+ sampled_count_event = self.valid_sampled_token_count_event
+ if sampled_count_event is None or prev_sampled_token_ids is None:
return []
counts_cpu = self.valid_sampled_token_count_cpu
- self.valid_sampled_token_count_event.synchronize()
+ assert counts_cpu is not None
+ sampled_count_event.synchronize()
return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
def propose_draft_token_ids(
@@ -3432,24 +4283,65 @@ def propose_draft_token_ids(
aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None,
common_attn_metadata: CommonAttentionMetadata,
+ slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
) -> list[list[int]] | torch.Tensor:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
spec_config = self.speculative_config
assert spec_config is not None
if spec_config.method == "ngram":
+ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.drafter.propose(
sampled_token_ids,
- self.input_batch.req_ids,
self.input_batch.num_tokens_no_spec,
self.input_batch.token_ids_cpu,
- self.input_batch.spec_decode_unsupported_reqs,
+ slot_mappings=slot_mappings,
+ )
+ elif spec_config.use_ngram_gpu():
+ assert isinstance(self.drafter, NgramProposerGPU)
+ (
+ next_token_ids,
+ valid_sampled_tokens_count,
+ valid_sampled_token_ids_gpu,
+ ) = self.drafter.update_token_ids_ngram(
+ sampled_token_ids,
+ self.input_batch,
+ self.token_ids_gpu_tensor,
+ self.num_tokens_no_spec_gpu,
+ self.discard_request_mask.gpu,
+ )
+ self._copy_valid_sampled_token_count(
+ next_token_ids, valid_sampled_tokens_count
+ )
+
+ batch_size = next_token_ids.shape[0]
+
+ draft_token_ids, num_valid_draft_tokens = self.drafter.propose(
+ self.num_tokens_no_spec_gpu[:batch_size],
+ self.token_ids_gpu_tensor[:batch_size],
+ valid_sampled_token_ids_gpu,
+ valid_sampled_tokens_count,
+ )
+
+ # Cache valid draft counts for scheduler-side trimming.
+ self._num_valid_draft_tokens = num_valid_draft_tokens
+
+ # Async D2H copy on a dedicated stream.
+ copy_num_valid_draft_tokens(
+ self._num_valid_draft_tokens_cpu,
+ self._num_valid_draft_tokens_copy_stream,
+ self._num_valid_draft_tokens_event,
+ self._num_valid_draft_tokens,
+ self.input_batch.num_reqs,
)
elif spec_config.method == "suffix":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, SuffixDecodingProposer)
- draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids)
+ draft_token_ids = self.drafter.propose(
+ self.input_batch, sampled_token_ids, slot_mappings=slot_mappings
+ )
elif spec_config.method == "medusa":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, MedusaProposer)
@@ -3474,9 +4366,41 @@ def propose_draft_token_ids(
draft_token_ids = self.drafter.propose(
target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
+ slot_mappings=slot_mappings,
+ )
+ elif spec_config.uses_extract_hidden_states():
+ assert isinstance(self.drafter, ExtractHiddenStatesProposer)
+ assert isinstance(sampled_token_ids, torch.Tensor), (
+ "sampled_token_ids should be a torch.Tensor for "
+ "extract_hidden_states method."
)
- elif spec_config.use_eagle():
- assert isinstance(self.drafter, EagleProposer)
+ if not self.use_aux_hidden_state_outputs or aux_hidden_states is None:
+ raise ValueError(
+ "aux_hidden_states are required when using `extract_hidden_states`"
+ )
+ target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
+
+ draft_token_ids = self.drafter.propose(
+ sampled_token_ids=sampled_token_ids,
+ target_hidden_states=target_hidden_states,
+ common_attn_metadata=common_attn_metadata,
+ slot_mappings=slot_mappings,
+ )
+ next_token_ids, valid_sampled_tokens_count = (
+ self.drafter.prepare_next_token_ids_padded(
+ common_attn_metadata,
+ sampled_token_ids,
+ self.requests,
+ self.input_batch,
+ self.discard_request_mask.gpu,
+ )
+ )
+ self._copy_valid_sampled_token_count(
+ next_token_ids, valid_sampled_tokens_count
+ )
+
+ elif spec_config.use_eagle() or spec_config.uses_draft_model():
+ assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
if spec_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
@@ -3514,6 +4438,7 @@ def propose_draft_token_ids(
next_token_ids, valid_sampled_tokens_count
)
+ num_rejected_tokens_gpu = None
if spec_decode_metadata is None:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
@@ -3544,12 +4469,14 @@ def propose_draft_token_ids(
else:
target_hidden_states = hidden_states[token_indices]
else:
- common_attn_metadata, token_indices_to_sample = (
- self.drafter.prepare_inputs_padded(
- common_attn_metadata,
- spec_decode_metadata,
- valid_sampled_tokens_count,
- )
+ (
+ common_attn_metadata,
+ token_indices_to_sample,
+ num_rejected_tokens_gpu,
+ ) = self.drafter.prepare_inputs_padded(
+ common_attn_metadata,
+ spec_decode_metadata,
+ valid_sampled_tokens_count,
)
total_num_tokens = common_attn_metadata.num_actual_tokens
# When padding the batch, token_indices is just a range
@@ -3563,7 +4490,7 @@ def propose_draft_token_ids(
else:
target_hidden_states = hidden_states[:total_num_tokens]
- if self.supports_mm_inputs:
+ if self.supports_mm_inputs and self.drafter.supports_mm_inputs:
mm_embed_inputs = self._gather_mm_embeddings(
scheduler_output,
shift_computed_tokens=1,
@@ -3576,10 +4503,12 @@ def propose_draft_token_ids(
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
- last_token_indices=token_indices_to_sample,
+ token_indices_to_sample=token_indices_to_sample,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
mm_embed_inputs=mm_embed_inputs,
+ num_rejected_tokens_gpu=num_rejected_tokens_gpu,
+ slot_mappings=slot_mappings,
)
return draft_token_ids
@@ -3595,34 +4524,30 @@ def update_config(self, overrides: dict[str, Any]) -> None:
new_config = update_config(config, config_overrides)
setattr(self, config_name, new_config)
- def load_model(self, eep_scale_up: bool = False) -> None:
+ @instrument(span_name="Loading (GPU)")
+ def load_model(self, load_dummy_weights: bool = False) -> None:
"""
Args:
- eep_scale_up: the model loading is for elastic EP scale up.
+ load_dummy_weights: load dummy weights instead of real weights.
"""
logger.info_once(
"Starting to load model %s...",
self.model_config.model,
scope="global",
)
- global_expert_loads, old_global_expert_indices_per_model, rank_mapping = (
- EplbState.get_eep_state(self.parallel_config)
- if eep_scale_up
- else (None, None, None)
- )
if self.parallel_config.enable_eplb:
self.eplb_state = EplbState(self.parallel_config, self.device)
eplb_models = 0
- # IO dumper is only supported in eager mode. In graph mode (torch.compile)
- # TorchDispatchMode and the module forward hooks are incompatible with
- # Dynamo tracing, so IO dumping is silently skipped.
+ # FL: IO dumper init (skipped under Dynamo tracing)
init_io_dump_from_env(getattr(self.model_config, "enforce_eager", False))
try:
with DeviceMemoryProfiler() as m:
time_before_load = time.perf_counter()
+ if load_dummy_weights:
+ self.load_config.load_format = "dummy"
model_loader = get_model_loader(self.load_config)
self.model = model_loader.load_model(
vllm_config=self.vllm_config, model_config=self.model_config
@@ -3639,6 +4564,9 @@ def load_model(self, eep_scale_up: bool = False) -> None:
and is_mixture_of_experts(self.drafter.model)
and self.parallel_config.enable_eplb
):
+ assert not self.parallel_config.enable_elastic_ep, (
+ "Elastic EP is not supported with drafter model."
+ )
spec_config = self.vllm_config.speculative_config
assert spec_config is not None
assert spec_config.draft_model_config is not None
@@ -3646,17 +4574,6 @@ def load_model(self, eep_scale_up: bool = False) -> None:
"EPLB is enabled for drafter model %s.",
spec_config.draft_model_config.model,
)
-
- global_expert_load = (
- global_expert_loads[eplb_models]
- if global_expert_loads
- else None
- )
- old_global_expert_indices = (
- old_global_expert_indices_per_model[eplb_models]
- if old_global_expert_indices_per_model
- else None
- )
if self.eplb_state is None:
self.eplb_state = EplbState(
self.parallel_config, self.device
@@ -3664,9 +4581,6 @@ def load_model(self, eep_scale_up: bool = False) -> None:
self.eplb_state.add_model(
self.drafter.model,
spec_config.draft_model_config,
- global_expert_load,
- old_global_expert_indices,
- rank_mapping,
)
eplb_models += 1
@@ -3686,7 +4600,9 @@ def load_model(self, eep_scale_up: bool = False) -> None:
aux_layers,
)
else:
- aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
+ aux_layers = (
+ self.model.get_eagle3_default_aux_hidden_state_layers()
+ )
self.model.set_aux_hidden_state_layers(aux_layers)
time_after_load = time.perf_counter()
@@ -3706,20 +4622,19 @@ def load_model(self, eep_scale_up: bool = False) -> None:
logger.error(combined_msg)
raise e
logger.info_once(
- "Model loading took %.4f GiB memory and %.6f seconds",
- self.model_memory_usage / GiB_bytes,
+ "Model loading took %s GiB memory and %.6f seconds",
+ format_gib(self.model_memory_usage),
time_after_load - time_before_load,
scope="local",
)
-
- # IO dumper: register module paths and install module context hooks.
- register_io_module_hooks(self.model)
-
- prepare_communication_buffer_for_model(self.model)
- if (drafter := getattr(self, "drafter", None)) and (
- drafter_model := getattr(drafter, "model", None)
- ):
- prepare_communication_buffer_for_model(drafter_model)
+ if not load_dummy_weights:
+ prepare_communication_buffer_for_model(self.model)
+ # FL: register IO dumper module hooks
+ register_io_module_hooks(self.model)
+ if (drafter := getattr(self, "drafter", None)) and (
+ drafter_model := getattr(drafter, "model", None)
+ ):
+ prepare_communication_buffer_for_model(drafter_model)
mm_config = self.model_config.multimodal_config
self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model())
@@ -3727,49 +4642,42 @@ def load_model(self, eep_scale_up: bool = False) -> None:
and mm_config.is_multimodal_pruning_enabled()
)
- if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
+ if (
+ is_mixture_of_experts(self.model)
+ and self.parallel_config.enable_eplb
+ and not load_dummy_weights
+ ):
logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
- global_expert_load = (
- global_expert_loads[eplb_models] if global_expert_loads else None
- )
- old_global_expert_indices = (
- old_global_expert_indices_per_model[eplb_models]
- if old_global_expert_indices_per_model
- else None
- )
assert self.eplb_state is not None
self.eplb_state.add_model(
self.model,
self.model_config,
- global_expert_load,
- old_global_expert_indices,
- rank_mapping,
)
if self.eplb_state.is_async:
- self.eplb_state.start_async_loop(rank_mapping=rank_mapping)
-
- # print(f"{self.vllm_config.compilation_config.mode=}")
+ self.eplb_state.start_async_loop()
if (
self.vllm_config.compilation_config.mode
== CompilationMode.STOCK_TORCH_COMPILE
- and supports_dynamo()
):
backend = self.vllm_config.compilation_config.init_backend(self.vllm_config)
compilation_counter.stock_torch_compile_count += 1
self.model.compile(fullgraph=True, backend=backend)
return
# for other compilation modes, cudagraph behavior is controlled by
- # CudagraphWraper and CudagraphDispatcher of vllm.
+ # CudagraphWrapper and CudagraphDispatcher of vllm.
# wrap the model with full cudagraph wrapper if needed.
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
- if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo:
+ if (
+ cudagraph_mode.has_full_cudagraphs()
+ and not self.parallel_config.use_ubatching
+ ):
self.model = GraphWrapper(
self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
)
- elif self.parallel_config.enable_dbo:
+ elif self.parallel_config.use_ubatching:
if cudagraph_mode.has_full_cudagraphs():
self.model = UBatchWrapper(
self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
@@ -3779,6 +4687,8 @@ def load_model(self, eep_scale_up: bool = False) -> None:
self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
)
+ get_offloader().post_init()
+
def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None:
"""Extract Eagle3 auxiliary layer indices from speculative config.
@@ -3803,23 +4713,89 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None:
return None
- def reload_weights(self) -> None:
- assert getattr(self, "model", None) is not None, (
- "Cannot reload weights before model is loaded."
- )
- model_loader = get_model_loader(self.load_config)
- logger.info("Reloading weights inplace...")
- model_loader.load_weights(self.get_model(), model_config=self.model_config)
-
- def save_tensorized_model(
+ def reload_weights(
self,
- tensorizer_config: "TensorizerConfig",
+ weights_iterator: Iterable[tuple[str, torch.Tensor]] | None = None,
+ weights_path: str | None = None,
+ is_checkpoint_format: bool = True,
) -> None:
- TensorizerLoader.save_model(
- self.get_model(),
- tensorizer_config=tensorizer_config,
- model_config=self.model_config,
+ """
+ Reload weights from a weights iterator or from disk
+
+ :param weights_iterator: weights to load into model
+ :param weights_path: path to load weights from if weights_iterator is not
+ provided. Use path of original model if neither is provided.
+ :param is_checkpoint_format: set to False if weights have already been processed
+ into kernel format (repacking, renaming, etc.)
+ """
+ # TODO(@kylesayrs): generalize to all runners and loaders
+ # argument validation
+ if weights_iterator is None and not is_checkpoint_format:
+ logger.warning(
+ "Reloading from disk means that weights will be in checkpoint format. "
+ "Please use `is_checkpoint_format=True` "
+ "to avoid weight reloading errors"
+ )
+
+ model = self.get_model()
+ weights_to_load = {name for name, _ in model.named_parameters()}
+ counter_before_reloading = time.perf_counter()
+
+ # load weights from disk if none are provided
+ if weights_iterator is None:
+ model_loader = get_model_loader(self.load_config)
+ if not hasattr(model_loader, "get_all_weights"):
+ raise NotImplementedError(
+ f"Model reloading with `{self.load_config.load_format}` format"
+ )
+
+ if weights_path is not None:
+ self.model_config.model = weights_path
+ weights_iterator = model_loader.get_all_weights(self.model_config, model)
+ weights_iterator = cast(
+ Iterable[tuple[str, torch.Tensor]], weights_iterator
+ )
+
+ # begin loading weights
+ logger.info_once("Reloading weights inplace...", scope="local")
+ load_device = (
+ self.vllm_config.load_config.device or self.vllm_config.device_config.device
)
+ with torch.device(load_device):
+ if is_checkpoint_format:
+ # load weights from checkpoint/ original model format
+ initialize_layerwise_reload(model)
+ loaded_weights = model.load_weights(weights_iterator)
+ finalize_layerwise_reload(model, self.model_config)
+
+ else:
+ # load weights from kernel format
+ logger.warning_once(
+                    "Reloading with `is_checkpoint_format=False` requires that "
+ "weights be in kernel format and already sharded",
+ scope="local",
+ )
+ loaded_weights = set()
+ for name, loaded_weight in weights_iterator:
+ param = model.get_parameter(name) # TODO: buffers?
+ param.copy_(loaded_weight)
+ loaded_weights.add(name)
+
+ # logging and validation
+ counter_after_reloading = time.perf_counter()
+ diff_seconds = counter_after_reloading - counter_before_reloading
+ logger.info_once(
+ "Reloading and processing weights took %.2f seconds",
+ diff_seconds,
+ scope="local",
+ )
+ if self.model_config.quantization is None and loaded_weights is not None:
+ weights_not_loaded = weights_to_load - loaded_weights
+ if weights_not_loaded:
+ logger.warning(
+ "Following weights were not loaded from checkpoint: %s",
+ weights_not_loaded,
+ )
def _get_prompt_logprobs_dict(
self,
@@ -3900,7 +4876,7 @@ def _get_prompt_logprobs_dict(
# Compute prompt logprobs.
logprobs = self.sampler.compute_logprobs(logits)
- token_ids, logprobs, ranks = self.sampler.gather_logprobs(
+ token_ids, logprobs, ranks, _ = self.sampler.gather_logprobs(
logprobs, num_prompt_logprobs, tgt_token_ids
)
@@ -3932,7 +4908,7 @@ def _get_nans_in_logits(
) -> dict[str, int]:
try:
if logits is None:
- return dict.fromkeys(self.input_batch.req_ids, 0)
+ return {req_id: 0 for req_id in self.input_batch.req_ids}
num_nans_in_logits = {}
num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
@@ -4000,22 +4976,22 @@ def _get_mm_dummy_batch(
"""Dummy data for profiling and precompiling multimodal models."""
assert self.mm_budget is not None
- dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
- model_config=self.model_config,
- seq_len=self.max_model_len,
+ # Don't use `max_items_per_batch` here to avoid redundant computation
+ dummy_mm_inputs = self.mm_registry.get_dummy_mm_inputs(
+ self.model_config,
mm_counts={modality: 1},
cache=self.mm_budget.cache,
)
- dummy_mm_data = dummy_decoder_data.multi_modal_data
+ dummy_mm_item = dummy_mm_inputs["mm_kwargs"][modality][0]
- # Result in the maximum GPU consumption of the model
- dummy_mm_item = dummy_mm_data[modality][0]
- dummy_mm_items = [dummy_mm_item] * max_items_per_batch
+ # We use the cache so that the item is saved to the cache,
+ # but not read from the cache
+ assert dummy_mm_item is not None, "Item should not already be cached"
return next(
- mm_kwargs_group
- for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
- dummy_mm_items,
+ mm_kwargs_batch
+ for _, _, mm_kwargs_batch in group_and_batch_mm_kwargs(
+ [(modality, dummy_mm_item)] * max_items_per_batch,
device=self.device,
pin_memory=self.pin_memory,
)
@@ -4033,12 +5009,13 @@ def _dummy_run(
is_profile: bool = False,
create_mixed_batch: bool = False,
remove_lora: bool = True,
- activate_lora: bool = False,
is_graph_capturing: bool = False,
+ num_active_loras: int = 0,
+ profile_seq_lens: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
- graph for the model.
+ CUDA graph for the model.
Args:
num_tokens: Number of tokens to run the dummy forward pass.
@@ -4057,11 +5034,21 @@ def _dummy_run(
create_mixed_batch: If True, create a mixed batch with both decode
(1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run
- activate_lora: If False, dummy_run is performed without LoRAs.
+ num_active_loras: Number of distinct active LoRAs to capture for.
+ LoRA is activated when num_active_loras > 0.
+ profile_seq_lens: If provided, use this value for seq_lens instead
+ of max_query_len. Used to profile attention workspace that
+ scales with context length.
"""
+ mm_config = self.vllm_config.model_config.multimodal_config
+ if mm_config and mm_config.mm_encoder_only:
+ # The current dummy run only covers LM execution, so we can skip it.
+ # mm encoder dummy run may need to add in the future.
+ return torch.tensor([]), torch.tensor([])
+
assert (
cudagraph_runtime_mode is None
- or cudagraph_runtime_mode.valid_runtime_modes()
+ or cudagraph_runtime_mode.is_valid_runtime_mode()
)
# If cudagraph_mode.decode_mode() == FULL and
@@ -4082,7 +5069,7 @@ def _dummy_run(
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
- assert num_tokens <= self.scheduler_config.max_num_batched_tokens
+ assert num_tokens <= self.max_num_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
if create_mixed_batch:
assert not uniform_decode
@@ -4133,7 +5120,10 @@ def _dummy_run(
# `force_has_lora` is used for cudagraph capture; because LoRA is
# activated later in the context manager, but we need to know the
# LoRA state when determining the batch descriptor for capture
- force_has_lora=activate_lora,
+ force_has_lora=num_active_loras > 0,
+ # `force_num_active_loras` is used for cudagraph capture; because we
+ # need to capture graphs for specific num_active_loras counts
+ force_num_active_loras=num_active_loras,
)
)
@@ -4150,51 +5140,77 @@ def _dummy_run(
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
)
ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
- should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded
+ should_ubatch,
+ num_scheduled_tokens,
+ num_tokens_padded,
+ num_reqs_padded,
+ self.vllm_config.parallel_config.num_ubatches,
+ )
+ logger.debug(
+ "ubatch_slices: %s, ubatch_slices_padded: %s",
+ ubatch_slices,
+ ubatch_slices_padded,
)
attn_metadata: PerLayerAttnMetadata | None = None
- # If force_attention is True, we always capture attention. Otherwise,
- # it only happens for cudagraph_runtime_mode=FULL.
- if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
- if create_mixed_batch:
- # In the mixed batch mode (used for FI warmup), we use
- # shorter sequence lengths to run faster.
- # TODO(luka) better system for describing dummy batches
- seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
- else:
- seq_lens = max_query_len # type: ignore[assignment]
- self.seq_lens.np[:num_reqs] = seq_lens
- self.seq_lens.np[num_reqs:] = 0
- self.seq_lens.copy_to_gpu()
+ slot_mappings_by_group, slot_mappings = self._get_slot_mappings(
+ num_tokens_padded=num_tokens,
+ num_reqs_padded=num_reqs_padded,
+ num_tokens_unpadded=num_tokens_unpadded,
+ ubatch_slices=ubatch_slices_padded,
+ )
- cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
- self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
- self.query_start_loc.copy_to_gpu()
+ # _dummy_run shares pinned CPU buffers (seq_lens, query_start_loc,
+ # etc.) with execute_model. It must participate in the same event
+ # protocol so that back-to-back dummy/real steps don't overwrite
+ # pinned memory while a prior non_blocking H2D DMA is still reading.
+ with self.synchronize_input_prep():
+ # If force_attention is True, we always capture attention.
+ # Otherwise, it only happens for cudagraph_runtime_mode=FULL.
+ if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
+ if profile_seq_lens is not None:
+ seq_lens = profile_seq_lens # type: ignore[assignment]
+ elif create_mixed_batch:
+ # In the mixed batch mode (used for FI warmup), we use
+ # shorter sequence lengths to run faster.
+ # TODO(luka) better system for describing dummy batches
+ seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment]
+ else:
+ seq_lens = max_query_len # type: ignore[assignment]
+ self.seq_lens.np[:num_reqs] = seq_lens
+ self.seq_lens.np[num_reqs:] = 0
+ self.seq_lens.copy_to_gpu()
- pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
- attn_metadata, _ = self._build_attention_metadata(
- num_tokens=num_tokens_unpadded,
- num_reqs=num_reqs_padded,
- max_query_len=max_query_len,
- ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
- for_cudagraph_capture=is_graph_capturing,
- )
+ cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
+ self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
+ self.query_start_loc.copy_to_gpu()
+
+ pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
+ attn_metadata, _ = self._build_attention_metadata(
+ num_tokens=num_tokens_unpadded,
+ num_tokens_padded=num_tokens_padded if pad_attn else None,
+ num_reqs=num_reqs_padded,
+ max_query_len=max_query_len,
+ ubatch_slices=(ubatch_slices_padded if pad_attn else ubatch_slices),
+ for_cudagraph_capture=is_graph_capturing,
+ slot_mappings=slot_mappings_by_group,
+ use_spec_decode=self.speculative_config is not None,
+ )
with self.maybe_dummy_run_with_lora(
self.lora_config,
num_scheduled_tokens,
num_sampled_tokens,
- activate_lora,
remove_lora,
+ num_active_loras,
):
# Make sure padding doesn't exceed max_num_tokens
assert num_tokens_padded <= self.max_num_tokens
- model_kwargs = self._init_model_kwargs(num_tokens_padded)
+ model_kwargs = self._init_model_kwargs()
if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
- input_ids = None
- inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
+ input_ids, inputs_embeds = self._prepare_mm_inputs(num_tokens_padded)
+
model_kwargs = {
**model_kwargs,
**self._dummy_mm_kwargs(num_reqs),
@@ -4202,7 +5218,7 @@ def _dummy_run(
elif self.enable_prompt_embeds:
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
- model_kwargs = self._init_model_kwargs(num_tokens_padded)
+ model_kwargs = self._init_model_kwargs()
else:
input_ids = self.input_ids.gpu[:num_tokens_padded]
inputs_embeds = None
@@ -4237,6 +5253,7 @@ def _dummy_run(
num_tokens_padded = ubatch_slices_padded[0].num_tokens
if num_tokens_across_dp is not None:
num_tokens_across_dp[:] = num_tokens_padded
+
with (
self.maybe_randomize_inputs(input_ids, inputs_embeds),
set_forward_context(
@@ -4247,6 +5264,7 @@ def _dummy_run(
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded,
+ slot_mapping=slot_mappings,
),
):
outputs = self.model(
@@ -4262,8 +5280,16 @@ def _dummy_run(
else:
hidden_states = outputs
- if self.speculative_config and self.speculative_config.use_eagle():
- assert isinstance(self.drafter, EagleProposer)
+ if self.speculative_config and (
+ self.speculative_config.use_eagle()
+ or self.speculative_config.uses_draft_model()
+ or self.speculative_config.uses_extract_hidden_states()
+ ):
+ assert isinstance(
+ self.drafter,
+ EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
+ )
+ assert self.speculative_config is not None
# Eagle currently only supports PIECEWISE cudagraphs.
# Therefore only use cudagraphs if the main model uses PIECEWISE
# NOTE(lucas): this is a hack, need to clean up.
@@ -4282,13 +5308,17 @@ def _dummy_run(
# lora cases when cudagraph_specialize_lora is enabled. This is a
# short term mitigation for issue mentioned in
# https://github.com/vllm-project/vllm/issues/28334
- if self.compilation_config.cudagraph_specialize_lora and activate_lora:
+ if (
+ self.compilation_config.cudagraph_specialize_lora
+ and num_active_loras > 0
+ ):
use_cudagraphs = False
self.drafter.dummy_run(
num_tokens,
use_cudagraphs=use_cudagraphs,
is_graph_capturing=is_graph_capturing,
+ slot_mappings=slot_mappings,
)
# We register layerwise NVTX hooks here after the first dynamo tracing is
@@ -4326,6 +5356,12 @@ def _dummy_sampler_run(
# The dummy hidden states may contain special values,
# like `inf` or `nan`.
# To avoid breaking the sampler, we use a random tensor here instead.
+
+ mm_config = self.vllm_config.model_config.multimodal_config
+ if mm_config and mm_config.mm_encoder_only:
+            # MM encoder-only models do not need to run the sampler.
+ return torch.tensor([])
+
hidden_states = torch.rand_like(hidden_states)
logits = self.model.compute_logits(hidden_states)
@@ -4400,24 +5436,21 @@ def _dummy_pooler_run_task(
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
- num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
- num_scheduled_tokens_list[-1] += num_tokens % num_reqs
- assert sum(num_scheduled_tokens_list) == num_tokens
- assert len(num_scheduled_tokens_list) == num_reqs
+ num_scheduled_tokens_np = np.full(num_reqs, min_tokens_per_req)
+ num_scheduled_tokens_np[-1] += num_tokens % num_reqs
+ assert np.sum(num_scheduled_tokens_np) == num_tokens
+ assert len(num_scheduled_tokens_np) == num_reqs
req_num_tokens = num_tokens // num_reqs
- dummy_prompt_lens = torch.tensor(
- num_scheduled_tokens_list,
- device="cpu",
- )
+ dummy_prompt_lens = torch.from_numpy(num_scheduled_tokens_np)
dummy_token_ids = torch.zeros(
(num_reqs, req_num_tokens), dtype=torch.int32, device=self.device
)
model = cast(VllmModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task)
- dummy_pooling_params.verify(task=task, model_config=self.model_config)
+ dummy_pooling_params.verify(self.model_config)
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(dummy_pooling_params)
@@ -4429,7 +5462,7 @@ def _dummy_pooler_run_task(
)
dummy_metadata.build_pooling_cursor(
- num_scheduled_tokens_list,
+ num_scheduled_tokens_np,
seq_lens_cpu=dummy_prompt_lens,
device=hidden_states.device,
)
@@ -4454,6 +5487,11 @@ def _dummy_pooler_run(
self,
hidden_states: torch.Tensor,
) -> PoolerOutput:
+ mm_config = self.vllm_config.model_config.multimodal_config
+ if mm_config and mm_config.mm_encoder_only:
+            # MM encoder-only models do not need to run the pooler.
+ return torch.tensor([])
+
# Find the task that has the largest output for subsequent steps
supported_pooling_tasks = self.get_supported_pooling_tasks()
@@ -4469,7 +5507,7 @@ def _dummy_pooler_run(
for task in supported_pooling_tasks:
# Run a full batch with each task to ensure none of them OOMs
output = self._dummy_pooler_run_task(hidden_states, task)
- output_size[task] = sum(o.nbytes for o in output)
+ output_size[task] = sum(o.nbytes for o in output if o is not None)
del output # Allow GC
max_task = max(output_size.items(), key=lambda x: x[1])[0]
@@ -4489,40 +5527,51 @@ def profile_run(self) -> None:
assert mm_budget is not None
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
- # NOTE: Currently model is profiled with a single non-text
- # modality with the max possible input tokens even when
- # it supports multiple.
- dummy_modality = mm_budget.get_modality_with_max_tokens()
- max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[
- dummy_modality
- ]
+ if not mm_budget.mm_max_toks_per_item:
+ # All modality limits are 0 — embedding-only mode.
+ # Budget is non-zero for embedding storage, but
+ # there's no encoder to profile.
+ logger.info(
+ "Skipping encoder profiling for embedding-only "
+ "mode (all modality limits=0 with "
+ "enable_mm_embeds=True).",
+ )
+ else:
+ # NOTE: Currently model is profiled with a single
+ # non-text modality with the max possible input
+ # tokens even when it supports multiple.
+ dummy_modality = mm_budget.get_modality_with_max_tokens()
+ max_mm_items_per_batch = mm_budget.mm_max_items_per_batch[
+ dummy_modality
+ ]
- logger.info(
- "Encoder cache will be initialized with a budget of "
- "%s tokens, and profiled with %s %s items of the "
- "maximum feature size.",
- encoder_budget,
- max_mm_items_per_batch,
- dummy_modality,
- )
+ logger.info_once(
+ "Encoder cache will be initialized with a "
+ "budget of %s tokens, and profiled with "
+ "%s %s items of the maximum feature size.",
+ encoder_budget,
+ max_mm_items_per_batch,
+ dummy_modality,
+ scope="local",
+ )
- # Create dummy batch of multimodal inputs.
- batched_dummy_mm_inputs = self._get_mm_dummy_batch(
- dummy_modality,
- max_mm_items_per_batch,
- )
+ # Create dummy batch of multimodal inputs.
+ batched_dummy_mm_inputs = self._get_mm_dummy_batch(
+ dummy_modality,
+ max_mm_items_per_batch,
+ )
- # Run multimodal encoder.
- dummy_encoder_outputs = self.model.embed_multimodal(
- **batched_dummy_mm_inputs
- )
+ # Run multimodal encoder.
+ dummy_encoder_outputs = self.model.embed_multimodal(
+ **batched_dummy_mm_inputs
+ )
- sanity_check_mm_encoder_outputs(
- dummy_encoder_outputs,
- expected_num_items=max_mm_items_per_batch,
- )
- for i, output in enumerate(dummy_encoder_outputs):
- self.encoder_cache[f"tmp_{i}"] = output
+ sanity_check_mm_encoder_outputs(
+ dummy_encoder_outputs,
+ expected_num_items=max_mm_items_per_batch,
+ )
+ for i, output in enumerate(dummy_encoder_outputs):
+ self.encoder_cache[f"tmp_{i}"] = output
# Add `is_profile` here to pre-allocate communication buffers
hidden_states, last_hidden_states = self._dummy_run(
@@ -4540,6 +5589,169 @@ def profile_run(self) -> None:
self.encoder_cache.clear()
gc.collect()
+ def _init_minimal_kv_cache_for_profiling(self) -> None:
+ from vllm.v1.core.kv_cache_utils import (
+ get_kv_cache_config_from_groups,
+ get_kv_cache_groups,
+ )
+
+ kv_cache_spec = self.get_kv_cache_spec()
+ kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec)
+ min_blocks = self.compilation_config.max_cudagraph_capture_size or 1
+
+ # Temporarily change num_gpu_blocks_override to allocate a minimal KV cache
+ saved_override = self.cache_config.num_gpu_blocks_override
+ self.cache_config.num_gpu_blocks_override = min_blocks
+ minimal_config = get_kv_cache_config_from_groups(
+ self.vllm_config, kv_cache_groups, available_memory=0
+ )
+ self.cache_config.num_gpu_blocks_override = saved_override
+
+ self.initialize_kv_cache(minimal_config)
+ self.cache_config.num_gpu_blocks = minimal_config.num_blocks
+
+ logger.debug("Initialized minimal KV cache for CUDA graph profiling")
+
+ @staticmethod
+ @contextmanager
+ def _freeze_gc():
+ gc.collect()
+ should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
+ if should_freeze:
+ gc.freeze()
+ try:
+ yield
+ finally:
+ if should_freeze:
+ gc.unfreeze()
+ gc.collect()
+
+ def _cleanup_profiling_kv_cache(self) -> None:
+ torch.accelerator.synchronize()
+ if hasattr(self, "kv_caches") and self.kv_caches:
+ for i in range(len(self.kv_caches)):
+ self.kv_caches[i] = None # type: ignore
+ self.kv_caches.clear()
+ if hasattr(self, "cross_layers_kv_cache"):
+ self.cross_layers_kv_cache = None
+ self.cross_layers_attn_backend = None
+ if hasattr(self, "attn_groups"):
+ self.attn_groups.clear()
+ if hasattr(self, "kv_cache_config"):
+ delattr(self, "kv_cache_config")
+ self.cache_config.num_gpu_blocks = None
+
+ for layer in self.compilation_config.static_forward_context.values():
+ if hasattr(layer, "kv_cache"):
+ layer.kv_cache = []
+
+ gc.collect()
+ torch.accelerator.empty_cache()
+
+ logger.debug("Cleaned up profiling KV cache and CUDA graphs")
+
+ @torch.inference_mode()
+ def profile_cudagraph_memory(self) -> int:
+ with set_current_vllm_config(self.vllm_config):
+ self._init_minimal_kv_cache_for_profiling()
+
+ saved_num_cudagraph_captured = compilation_counter.num_cudagraph_captured
+
+ capture_descs = self.cudagraph_dispatcher.get_capture_descs()
+
+ total_graphs = sum(len(descs) for _, descs in capture_descs)
+ if total_graphs == 0:
+ logger.debug("No CUDA graphs will be captured, skipping profiling")
+ self._cleanup_profiling_kv_cache()
+ return 0
+
+ logger.info(
+ "Profiling CUDA graph memory: %s",
+ ", ".join(
+ f"{mode.name}={len(descs)} (largest={descs[0].num_tokens})"
+ for mode, descs in capture_descs
+ if descs
+ ),
+ )
+
+ # Use a temporary pool for profiling to avoid fragmentation in the main pool.
+ profiling_pool = current_platform.graph_pool_handle()
+ original_pools: dict[int, Any] = {}
+ for instance in list(GraphWrapper._all_instances):
+ original_pools[id(instance)] = instance.graph_pool
+ instance.graph_pool = profiling_pool
+
+ set_cudagraph_capturing_enabled(True)
+ with self._freeze_gc(), graph_capture(device=self.device):
+ shared_memory_estimate = {}
+ per_graph_estimate = {}
+ torch.accelerator.synchronize()
+ torch.accelerator.empty_cache()
+
+ for mode, descs in capture_descs:
+ profile_descs = descs[:2]
+ mem_samples: list[int] = []
+
+ for i, desc in enumerate(profile_descs):
+ mem_before = current_platform.torch_device_fn.mem_get_info()[0]
+ self._warmup_and_capture(
+ desc,
+ cudagraph_runtime_mode=mode,
+ profile_seq_lens=(
+ min(
+ self.max_model_len,
+ self.max_num_tokens // desc.num_tokens,
+ )
+ if mode == CUDAGraphMode.FULL and i == 0
+ else None
+ ),
+ )
+ torch.accelerator.synchronize()
+ free_after = current_platform.torch_device_fn.mem_get_info()[0]
+ mem_samples.append(mem_before - free_after)
+
+ first_capture = mem_samples[0]
+ # Use at least 1 MiB per graph for driver overhead
+ per_graph = max(mem_samples[1] if len(mem_samples) > 1 else 0, 1 << 20)
+
+ shared_memory_estimate[mode] = first_capture
+ per_graph_estimate[mode] = per_graph * (len(descs) - 1)
+
+ logger.debug(
+ "Estimated %s CUDA graph memory: "
+ "%.2f MiB first-capture + (%d-1) × %.2f MiB per-graph",
+ mode.name,
+ first_capture / (1 << 20),
+ len(descs),
+ per_graph / (1 << 20),
+ )
+
+ set_cudagraph_capturing_enabled(False)
+ GraphWrapper.clear_all_graphs()
+ for instance in list(GraphWrapper._all_instances):
+ if id(instance) in original_pools:
+ instance.graph_pool = original_pools[id(instance)]
+ for key_set in self.cudagraph_dispatcher.cudagraph_keys.values():
+ key_set.clear()
+ self.cudagraph_dispatcher.keys_initialized = False
+ self.maybe_remove_all_loras(self.lora_config)
+ self._cleanup_profiling_kv_cache()
+ compilation_counter.num_cudagraph_captured = saved_num_cudagraph_captured
+
+ # FULL and PIECEWISE graphs share the global pool at runtime and are
+ # never replayed concurrently, so the pool overlays their memory.
+ # Take the max to avoid double-counting the overlap.
+ total_estimate = max(shared_memory_estimate.values()) + sum(
+ per_graph_estimate.values()
+ )
+ logger.info(
+ "Estimated CUDA graph memory: %.2f GiB total",
+ total_estimate / (1 << 30),
+ )
+
+ return int(total_estimate)
+
+ @instrument(span_name="Capture model")
def capture_model(self) -> int:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
logger.warning(
@@ -4552,75 +5764,28 @@ def capture_model(self) -> int:
start_time = time.perf_counter()
- @contextmanager
- def freeze_gc():
- # Optimize garbage collection during CUDA graph capture.
- # Clean up, then freeze all remaining objects from being included
- # in future collections.
- gc.collect()
- should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
- if should_freeze:
- gc.freeze()
- try:
- yield
- finally:
- if should_freeze:
- gc.unfreeze()
- gc.collect()
-
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
set_cudagraph_capturing_enabled(True)
- with freeze_gc(), graph_capture(device=self.device):
+ with self._freeze_gc(), graph_capture(device=self.device):
+ torch.accelerator.synchronize()
+ torch.accelerator.empty_cache()
start_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0]
- cudagraph_mode = self.compilation_config.cudagraph_mode
- assert cudagraph_mode is not None
-
- if self.lora_config:
- if self.compilation_config.cudagraph_specialize_lora:
- lora_cases = [True, False]
- else:
- lora_cases = [True]
- else:
- lora_cases = [False]
- if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
- cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
- # make sure we capture the largest batch size first
- compilation_cases = list(
- product(reversed(self.cudagraph_batch_sizes), lora_cases)
- )
- self._capture_cudagraphs(
- compilation_cases,
- cudagraph_runtime_mode=cudagraph_runtime_mode,
- uniform_decode=False,
- )
- # Capture full cudagraph for uniform decode batches if we
- # don't already have full mixed prefill-decode cudagraphs.
- if (
- cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
- and cudagraph_mode.separate_routine()
- ):
- max_num_tokens = (
- self.scheduler_config.max_num_seqs * self.uniform_decode_query_len
- )
- decode_cudagraph_batch_sizes = [
- x
- for x in self.cudagraph_batch_sizes
- if max_num_tokens >= x >= self.uniform_decode_query_len
- ]
- compilation_cases_decode = list(
- product(reversed(decode_cudagraph_batch_sizes), lora_cases)
- )
+ for (
+ runtime_mode,
+ batch_descs,
+ ) in self.cudagraph_dispatcher.get_capture_descs():
self._capture_cudagraphs(
- compilation_cases=compilation_cases_decode,
- cudagraph_runtime_mode=CUDAGraphMode.FULL,
- uniform_decode=True,
+ batch_descriptors=batch_descs,
+ cudagraph_runtime_mode=runtime_mode,
)
+ torch.accelerator.synchronize()
- current_platform.torch_device_fn.synchronize()
+ torch.accelerator.synchronize()
end_free_gpu_memory = current_platform.torch_device_fn.mem_get_info()[0]
+
# Disable cudagraph capturing globally, so any unexpected cudagraph
# capturing will be detected and raise an error after here.
# Note: We don't put it into graph_capture context manager because
@@ -4628,6 +5793,9 @@ def freeze_gc():
# after here.
set_cudagraph_capturing_enabled(False)
+ torch.accelerator.synchronize()
+ torch.accelerator.empty_cache()
+
# Lock workspace to prevent resizing during execution.
# Max workspace sizes should have been captured during warmup/profiling.
lock_workspace()
@@ -4644,21 +5812,59 @@ def freeze_gc():
)
return cuda_graph_size
+ def _warmup_and_capture(
+ self,
+ desc: BatchDescriptor,
+ cudagraph_runtime_mode: CUDAGraphMode,
+ profile_seq_lens: int | None = None,
+ allow_microbatching: bool = False,
+ num_warmups: int | None = None,
+ ):
+ if num_warmups is None:
+ num_warmups = self.compilation_config.cudagraph_num_of_warmups
+ force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
+ for _ in range(num_warmups):
+ self._dummy_run(
+ desc.num_tokens,
+ cudagraph_runtime_mode=CUDAGraphMode.NONE,
+ force_attention=force_attention,
+ uniform_decode=desc.uniform,
+ allow_microbatching=allow_microbatching,
+ skip_eplb=True,
+ remove_lora=False,
+ num_active_loras=desc.num_active_loras,
+ )
+ self._dummy_run(
+ desc.num_tokens,
+ cudagraph_runtime_mode=cudagraph_runtime_mode,
+ uniform_decode=desc.uniform,
+ allow_microbatching=allow_microbatching,
+ skip_eplb=True,
+ remove_lora=False,
+ num_active_loras=desc.num_active_loras,
+ is_graph_capturing=True,
+ profile_seq_lens=profile_seq_lens,
+ )
+
def _capture_cudagraphs(
self,
- compilation_cases: list[tuple[int, bool]],
+ batch_descriptors: list[BatchDescriptor],
cudagraph_runtime_mode: CUDAGraphMode,
- uniform_decode: bool,
):
assert (
cudagraph_runtime_mode != CUDAGraphMode.NONE
- and cudagraph_runtime_mode.valid_runtime_modes()
+ and cudagraph_runtime_mode.is_valid_runtime_mode()
), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"
+ if not batch_descriptors:
+ return
+
+ uniform_decode = batch_descriptors[0].uniform
+
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
- compilation_cases = tqdm(
- compilation_cases,
+ batch_descriptors = tqdm(
+ batch_descriptors,
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
@@ -4667,49 +5873,27 @@ def _capture_cudagraphs(
)
# We skip EPLB here since we don't want to record dummy metrics
- for num_tokens, activate_lora in compilation_cases:
+ for batch_desc in batch_descriptors:
# We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
# version of the graph
allow_microbatching = (
- self.parallel_config.enable_dbo
+ self.parallel_config.use_ubatching
and cudagraph_runtime_mode == CUDAGraphMode.FULL
and uniform_decode
and check_ubatch_thresholds(
config=self.vllm_config.parallel_config,
- num_tokens=num_tokens,
+ num_tokens=batch_desc.num_tokens,
uniform_decode=uniform_decode,
)
)
-
- for _ in range(self.compilation_config.cudagraph_num_of_warmups):
- # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
- # But be careful, warm up with `NONE`is orthogonal to
- # if we want to warm up attention or not. This is
- # different from the case where `FULL` implies capture
- # attention while `PIECEWISE` implies no attention.
- force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
- self._dummy_run(
- num_tokens,
- cudagraph_runtime_mode=CUDAGraphMode.NONE,
- force_attention=force_attention,
- uniform_decode=uniform_decode,
- allow_microbatching=allow_microbatching,
- skip_eplb=True,
- remove_lora=False,
- activate_lora=activate_lora,
- )
- self._dummy_run(
- num_tokens,
+ self._warmup_and_capture(
+ batch_desc,
cudagraph_runtime_mode=cudagraph_runtime_mode,
- uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching,
- skip_eplb=True,
- remove_lora=False,
- activate_lora=activate_lora,
- is_graph_capturing=True,
)
+ torch.accelerator.synchronize()
self.maybe_remove_all_loras(self.lora_config)
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
@@ -4808,14 +5992,22 @@ def initialize_metadata_builders(
if kv_cache_group_id < len(kernel_block_sizes)
else None,
num_metadata_builders=1
- if not self.parallel_config.enable_dbo
- else 2,
+ if not self.parallel_config.use_ubatching
+ else self.parallel_config.num_ubatches,
)
# Calculate reorder batch threshold (if needed)
# Note (tdoublep): do this *after* constructing builders,
# because some of them change the threshold at init time.
self.calculate_reorder_batch_threshold()
+ # Initialize drafter attention backend
+ if self.speculative_config and (
+ self.speculative_config.use_eagle()
+ or self.speculative_config.uses_draft_model()
+ ):
+ assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
+ self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes)
+
def _check_and_update_cudagraph_mode(
self,
attention_backends: list[set[type[AttentionBackend]]],
@@ -4959,10 +6151,22 @@ def _check_and_update_cudagraph_mode(
self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
)
- capture_sizes = self.compilation_config.cudagraph_capture_sizes
- self.cudagraph_batch_sizes = (
- capture_sizes if capture_sizes is not None else []
+
+ # If the model has Mamba layers and cudagraph mode includes FULL
+ # decode, cap cudagraph capture sizes to the number of available
+ # Mamba cache blocks. Each decode request needs one conv_state
+ # cache line, so capture batch sizes cannot exceed num_blocks.
+ # Only FULL decode graphs are affected because PIECEWISE captures
+ # run GDN/Mamba ops eagerly (prefill path, no causal_conv1d_update).
+ # See: https://github.com/vllm-project/vllm/issues/34094
+ if cudagraph_mode.has_full_cudagraphs():
+ has_mamba = any(
+ isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups
)
+ if has_mamba and self.kv_cache_config is not None:
+ self.compilation_config.adjust_cudagraph_sizes_for_mamba_cache(
+ self.kv_cache_config.num_blocks
+ )
# Trigger cudagraph dispatching keys initialization after
# resolved cudagraph mode.
@@ -4971,6 +6175,14 @@ def _check_and_update_cudagraph_mode(
cudagraph_mode, self.uniform_decode_query_len
)
+ # Initialize drafter's cudagraph dispatcher if using spec decode.
+ if self.speculative_config and (
+ self.speculative_config.use_eagle()
+ or self.speculative_config.uses_extract_hidden_states()
+ ):
+ assert isinstance(self.drafter, EagleProposer | ExtractHiddenStatesProposer)
+ self.drafter.initialize_cudagraph_keys(cudagraph_mode)
+
def calculate_reorder_batch_threshold(self) -> None:
"""
Choose the minimum reorder batch threshold from all attention groups.
@@ -4991,120 +6203,75 @@ def calculate_reorder_batch_threshold(self) -> None:
return
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment]
- @staticmethod
- def select_common_block_size(
- kv_manager_block_size: int, attn_groups: list[AttentionGroup]
- ) -> int:
- """
- Select a block size that is supported by all backends and is a factor of
- kv_manager_block_size.
-
- If kv_manager_block_size is supported by all backends, return it directly.
- Otherwise, return the max supported size.
-
- Args:
- kv_manager_block_size: Block size of KV cache
- attn_groups: List of attention groups
-
- Returns:
- The selected block size
-
- Raises:
- ValueError: If no valid block size found
- """
-
- def block_size_is_supported(
- backends: list[type[AttentionBackend]], block_size: int
- ) -> bool:
- """
- Check if the block size is supported by all backends.
- """
- for backend in backends:
- is_supported = False
- for supported_size in backend.get_supported_kernel_block_sizes():
- if isinstance(supported_size, int):
- if block_size == supported_size:
- is_supported = True
- elif isinstance(supported_size, MultipleOf):
- if block_size % supported_size.base == 0:
- is_supported = True
- else:
- raise ValueError(f"Unknown supported size: {supported_size}")
- if not is_supported:
- return False
- return True
-
- backends = [group.backend for group in attn_groups]
-
- # Case 1: if the block_size of kv cache manager is supported by all backends,
- # return it directly
- if block_size_is_supported(backends, kv_manager_block_size):
- return kv_manager_block_size
-
- # Case 2: otherwise, the block_size must be an `int`-format supported size of
- # at least one backend. Iterate over all `int`-format supported sizes in
- # descending order and return the first one that is supported by all backends.
- # Simple proof:
- # If the supported size b is in MultipleOf(x_i) format for all attention
- # backends i, and b a factor of kv_manager_block_size, then
- # kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
- # return kv_manager_block_size in case 1.
- all_int_supported_sizes = set(
- supported_size
- for backend in backends
- for supported_size in backend.get_supported_kernel_block_sizes()
- if isinstance(supported_size, int)
- )
-
- for supported_size in sorted(all_int_supported_sizes, reverse=True):
- if kv_manager_block_size % supported_size != 0:
- continue
- if block_size_is_supported(backends, supported_size):
- return supported_size
- raise ValueError(f"No common block size for {kv_manager_block_size}. ")
-
def may_reinitialize_input_batch(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> None:
"""
Re-initialize the input batch if the block sizes are different from
- `[self.cache_config.block_size]`. This usually happens when there
- are multiple KV cache groups.
+ what it was originally created with. This happens when the final
+ block size (determined after model loading) differs from the
+ placeholder used during __init__, or when there are multiple
+ KV cache groups.
Args:
kv_cache_config: The KV cache configuration.
kernel_block_sizes: The kernel block sizes for each KV cache group.
"""
- block_sizes = [
- kv_cache_group.kv_cache_spec.block_size
- for kv_cache_group in kv_cache_config.kv_cache_groups
- if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
- ]
+ block_sizes = []
+ max_num_blocks = []
+ max_model_len = max(self.max_model_len, self.max_encoder_len)
+ for kv_cache_group in kv_cache_config.kv_cache_groups:
+ if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
+ continue
+ block_size = kv_cache_group.kv_cache_spec.block_size
+ block_sizes.append(block_size)
+ max_num_blocks_per_req = cdiv(
+ max_model_len, block_size * get_total_cp_world_size()
+ )
+ if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
+ max_num_blocks_per_req = (
+ max_num_blocks_per_req
+ if self.cache_config.enable_prefix_caching
+ else 1
+ ) + kv_cache_group.kv_cache_spec.num_speculative_blocks
+ max_num_blocks.append(max_num_blocks_per_req)
- if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
- self.cache_config.block_size
- ]:
- assert self.cache_config.cpu_offload_gb == 0, (
+ if (
+ block_sizes != self._init_block_sizes
+ or kernel_block_sizes != self._init_kernel_block_sizes
+ ):
+ assert self.offload_config.uva.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details."
)
+ self._init_block_sizes = block_sizes
+ self._init_kernel_block_sizes = kernel_block_sizes
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
- max_model_len=max(self.max_model_len, self.max_encoder_len),
+ max_model_len=max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
+ max_num_blocks_per_req=max_num_blocks,
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model,
- num_speculative_tokens=self.num_spec_tokens,
)
+ assert self._init_block_sizes == block_sizes, (
+ f"InputBatch block_sizes {self._init_block_sizes} != "
+ f"kv_cache block_sizes {block_sizes}"
+ )
+ assert self._init_kernel_block_sizes == kernel_block_sizes, (
+ f"InputBatch kernel_block_sizes {self._init_kernel_block_sizes} "
+ f"!= kv_cache kernel_block_sizes {kernel_block_sizes}"
+ )
+
def _allocate_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig
) -> dict[str, torch.Tensor]:
@@ -5146,49 +6313,6 @@ def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
for attn_groups in self.attn_groups:
yield from attn_groups
- def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]:
- """
- Generate kernel_block_sizes that matches each block_size.
-
- For attention backends that support virtual block splitting,
- use the supported block sizes from the backend.
- For other backends (like Mamba), use the same block size (no splitting).
-
- Args:
- kv_cache_config: The KV cache configuration.
-
- Returns:
- list[int]: List of kernel block sizes for each cache group.
- """
- kernel_block_sizes = []
- for kv_cache_gid, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
- kv_cache_spec = kv_cache_group.kv_cache_spec
- if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
- # All layers in the UniformTypeKVCacheSpecs have the same type,
- # Pick an arbitrary one to dispatch.
- kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
- if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
- continue
- elif isinstance(kv_cache_spec, AttentionSpec):
- # This is an attention backend that supports virtual
- # block splitting. Get the supported block sizes from
- # all backends in the group.
- attn_groups = self.attn_groups[kv_cache_gid]
- kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
- selected_kernel_size = self.select_common_block_size(
- kv_manager_block_size, attn_groups
- )
- kernel_block_sizes.append(selected_kernel_size)
- elif isinstance(kv_cache_spec, MambaSpec):
- # This is likely Mamba or other non-attention cache,
- # no splitting.
- kernel_block_sizes.append(kv_cache_spec.block_size)
- else:
- raise NotImplementedError(
- f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
- )
- return kernel_block_sizes
-
def _reshape_kv_cache_tensors(
self,
kv_cache_config: KVCacheConfig,
@@ -5252,7 +6376,8 @@ def _reshape_kv_cache_tensors(
)
# Maintain original KV shape view.
inv_order = [
- kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order))
+ kv_cache_stride_order.index(i)
+ for i in range(len(kv_cache_stride_order))
]
kv_caches[layer_name] = (
kv_cache_raw_tensors[layer_name]
@@ -5411,6 +6536,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
+ self._mamba_copy_bufs = None
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
@@ -5419,7 +6545,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
# backends for that group only supports block_size 64, we will return
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
# tokens each.
- kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
+ kernel_block_sizes = prepare_kernel_block_sizes(
+ kv_cache_config, self.attn_groups
+ )
+ self._kernel_block_sizes = kernel_block_sizes
# create metadata builders
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
@@ -5430,8 +6559,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
kv_cache_config, kernel_block_sizes
)
- if self.speculative_config and self.speculative_config.use_eagle():
- assert isinstance(self.drafter, EagleProposer)
+ if (
+ self.speculative_config
+ and self.speculative_config.uses_extract_hidden_states()
+ ):
+ assert isinstance(self.drafter, ExtractHiddenStatesProposer)
# validate all draft model layers belong to the same kv cache
# group
self.drafter.validate_same_kv_cache_group(kv_cache_config)
@@ -5447,6 +6579,58 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
kv_transfer_group.register_kv_caches(kv_caches)
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
+ def _get_attention_kv_cache_gid(self) -> int:
+ """Find the KV cache group index for attention layers."""
+ for gid, group in enumerate(self.kv_cache_config.kv_cache_groups):
+ if isinstance(group.kv_cache_spec, AttentionSpec):
+ return gid  # first group whose spec is an AttentionSpec
+ return 0  # no attention group found; falls back to group 0 — NOTE(review): confirm callers tolerate a non-attention group here
+
+ def init_routed_experts_capturer(self):
+ logger.info(
+ "Initializing routed experts capturer, enable_return_routed_experts: %s",
+ self.model_config.enable_return_routed_experts,
+ )
+ routed_experts_capturer = RoutedExpertsCapturer.create()  # capturer instance; lifetime/semantics of create() not visible here — confirm
+ self.routed_experts_attn_gid = self._get_attention_kv_cache_gid()  # remember which KV group holds attention layers
+ min_block_size = min(
+ [
+ group.kv_cache_spec.block_size
+ for group in self.kv_cache_config.kv_cache_groups
+ ]
+ )
+ num_groups = len(self.kv_cache_config.kv_cache_groups)
+ self.max_num_kv_tokens = (
+ self.kv_cache_config.num_blocks // num_groups
+ ) * min_block_size  # even block split across groups x smallest block size — presumably a conservative token bound; verify
+ dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size
+ pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
+ if pcp_size * dcp_size > 1:
+ self.max_num_kv_tokens *= pcp_size * dcp_size  # context parallelism shards KV across ranks, so total tokens scale up
+
+ routed_experts_capturer.init_buffer(
+ max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
+ max_num_kv_tokens=self.max_num_kv_tokens,
+ vllm_config=self.vllm_config,
+ )
+ self._bind_routed_experts_capturer(routed_experts_capturer)  # hook capture callback into every FusedMoE router
+ self.routed_experts_initialized = True
+
+ def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None:  # attach a per-layer topk-ids capture hook to each MoE router
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+ from vllm.model_executor.layers.fused_moe.router.base_router import (
+ BaseRouter,
+ )
+
+ for module in self.compilation_config.static_forward_context.values():
+ if isinstance(module, FusedMoE) and isinstance(module.router, BaseRouter):
+ layer_id = module.layer_id
+
+ def _capture_fn(topk_ids, _layer_id=layer_id, _capturer=capturer):  # default args freeze values now, avoiding the late-binding closure pitfall in this loop
+ _capturer.capture(_layer_id, topk_ids)
+
+ module.router.set_capture_fn(_capture_fn)
+
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
"""
Add encoder-only layers to the KV cache config.
@@ -5481,7 +6665,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
- if has_ec_transfer() and get_ec_transfer().is_producer:
+ if has_ec_transfer() and not get_ec_transfer().is_consumer:
return {}
kv_cache_spec: dict[str, KVCacheSpec] = {}
layer_type = cast(type[Any], AttentionLayerBase)
@@ -5519,3 +6703,79 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
self.transfer_event.record()
self.transfer_event.synchronize()
return pinned.tolist()
+
+ def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
+ """
+ Get encoder timing stats for all requests and clear the registry.
+
+ Returns:
+ Dictionary mapping request_id to stats dict.
+ """
+ with self._encoder_timing_lock:  # read-and-reset under one lock so concurrent writers can't be lost between snapshot and clear
+ stats = {
+ req_id: stats_obj.to_dict()
+ for req_id, stats_obj in self.encoder_timing_registry.items()
+ }
+ self.encoder_timing_registry.clear()  # destructive read: each stat is reported exactly once
+ return stats
+
+ @contextmanager
+ def timed_encoder_operation(
+ self,
+ should_time: bool,
+ group_lora_refs: list[tuple[str, Any]],
+ current_item_idx: int,
+ num_items: int,
+ ):
+ """
+ Context manager to time encoder forward operations.
+
+ Args:
+ should_time: Whether timing is enabled
+ group_lora_refs: Full list of (request_id, pos_info) tuples
+ current_item_idx: Starting index for this group
+ num_items: Number of items in this group
+ """
+ if not should_time:  # fast path: skip device syncs and bookkeeping entirely
+ yield
+ return
+
+ group_refs = group_lora_refs[current_item_idx : current_item_idx + num_items]
+ group_request_ids = {req_id for req_id, _ in group_refs}  # dedupe: a request may contribute several items to one group
+
+ torch.accelerator.synchronize()  # drain previously queued device work so the timer measures only this operation
+ start_time = time.perf_counter()
+
+ try:
+ yield
+ finally:
+ torch.accelerator.synchronize()  # wait for kernels launched inside the block before stopping the clock
+ elapsed = time.perf_counter() - start_time
+
+ per_request_time = elapsed / max(len(group_request_ids), 1)  # split evenly per unique request (approximation); max(...,1) guards an empty group
+
+ with self._encoder_timing_lock:
+ for req_id in group_request_ids:
+ if req_id not in self.encoder_timing_registry:
+ self.encoder_timing_registry[req_id] = EncoderTimingStats()
+
+ stats = self.encoder_timing_registry[req_id]
+ stats.encoder_forward_secs += per_request_time
+ stats.num_encoder_calls += 1  # one call per group this request appeared in
+
+
+@dataclass
+class EncoderTimingStats:
+ """Per-request timing statistics for encoder forward pass."""
+
+ encoder_forward_secs: float = 0.0
+ """Time spent in vision encoder forward pass (seconds)."""
+
+ num_encoder_calls: int = 0
+ """Number of times encoder was called for this request."""
+
+ def to_dict(self) -> dict[str, float | int]:  # plain-dict view for serialization in get_encoder_timing_stats()
+ return {
+ "encoder_forward_secs": self.encoder_forward_secs,
+ "num_encoder_calls": self.num_encoder_calls,
+ }
diff --git a/vllm_fl/worker/worker.py b/vllm_fl/worker/worker.py
index de6dbb0f..4cd7631f 100644
--- a/vllm_fl/worker/worker.py
+++ b/vllm_fl/worker/worker.py
@@ -1,5 +1,5 @@
# Copyright (c) 2025 BAAI. All rights reserved.
-# Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/v1/worker/gpu_model_runner.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/v1/worker/gpu_model_runner.py
# Below is the original copyright:
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
@@ -16,7 +16,7 @@
import torch.distributed
import torch.nn as nn
-from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
from vllm.config.compilation import CompilationMode
from vllm.distributed import (
ensure_model_parallel_initialized,
@@ -47,7 +47,7 @@ def kernel_warmup(worker):
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
-from vllm.model_executor import set_random_seed
+from vllm.utils.torch_utils import set_random_seed
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.platforms import current_platform
from vllm.profiler.wrapper import TorchProfilerWrapper
@@ -206,11 +206,6 @@ def __init__(
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
- if self.model_config.trust_remote_code:
- # note: lazy import to avoid importing torch before initializing
- from vllm.utils.import_utils import init_cached_hf_modules
-
- init_cached_hf_modules()
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
@@ -429,10 +424,10 @@ def init_device(self):
# FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
# to hijack tensor allocation.
- def load_model(self) -> None:
- eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
+ def load_model(self, *, load_dummy_weights: bool = False) -> None:
### TODO(lms): support manages a memory pool for device tensors.
- self.model_runner.load_model(eep_scale_up=eep_scale_up)
+ with set_current_vllm_config(self.vllm_config):
+ self.model_runner.load_model(load_dummy_weights=load_dummy_weights)
# with self._maybe_get_memory_pool_context(tag="weights"):
# self.model_runner.load_model(eep_scale_up=eep_scale_up)
@@ -546,6 +541,11 @@ def get_kv_connector_handshake_metadata(self) -> dict | None:
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
+ def update_max_model_len(self, max_model_len: int) -> None:
+ """Update max_model_len after auto-fit to GPU memory."""
+ self.model_config.max_model_len = max_model_len
+ self.model_runner.max_model_len = max_model_len
+
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
# Init kv cache connector here, because it requires
@@ -565,7 +565,7 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
# context = nullcontext()
self.model_runner.initialize_kv_cache(kv_cache_config)
- def compile_or_warm_up_model(self) -> None:
+ def compile_or_warm_up_model(self) -> float:
warmup_sizes = []
if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
# warm up sizes that are not in cudagraph capture sizes,
@@ -688,15 +688,24 @@ def compile_or_warm_up_model(self) -> None:
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
+ return self.compilation_config.compilation_time
+
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
+ def reset_encoder_cache(self) -> None:
+ self.model_runner.reset_encoder_cache()
+
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks()
+ def get_encoder_timing_stats(self) -> dict[str, dict[str, float | int]]:
+ """Get encoder timing stats from model runner."""
+ return self.model_runner.get_encoder_timing_stats()
+
def annotate_profile(self, scheduler_output):
# add trace annotation so that we can easily distinguish
# new/cached request numbers in each iteration
@@ -789,7 +798,7 @@ def execute_model(
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
return self.model_runner.take_draft_token_ids()
- def profile(self, is_start: bool = True):
+ def profile(self, is_start: bool = True, profile_prefix: str | None = None):
if self.profiler is None:
raise RuntimeError("Profiling is not enabled.")
if is_start: