Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ In theory, vllm-plugin-FL can support all models available in vLLM, as long as n

### Setup

1. Install vllm from the official [v0.13.0](https://github.com/vllm-project/vllm/tree/v0.13.0) (optional if the correct version is installed) or from the fork [vllm-FL](https://github.com/flagos-ai/vllm-FL).
1. Install vllm from the official [v0.18.1](https://github.com/vllm-project/vllm/tree/v0.18.1) (optional if the correct version is installed) or from the fork [vllm-FL](https://github.com/flagos-ai/vllm-FL).

Comment on lines 35 to 39
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR title/description says upgrade to vLLM 0.18.0, but this change updates docs/source references to v0.18.1 (and several file headers also reference v0.18.1). Please align the stated target version (either update PR metadata to 0.18.1 or change the references back to 0.18.0) to avoid confusion about the required dependency version.

Copilot uses AI. Check for mistakes.

2. Install vllm-plugin-FL
Expand Down Expand Up @@ -65,6 +65,7 @@ In theory, vllm-plugin-FL can support all models available in vLLM, as long as n

```sh
git clone https://github.com/flagos-ai/FlagGems
cd FlagGems
git checkout v5.0.0
pip install --no-build-isolation .
# or editable install
Expand Down
66 changes: 5 additions & 61 deletions vllm_fl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,74 +45,18 @@ def register():


def register_model():
"""Register the FL model."""
from vllm import ModelRegistry
import vllm.model_executor.models.qwen3_next as qwen3_next_module
"""Register FL-specific models not yet upstream."""
# Models now upstream in vLLM v0.18.1 (no longer need plugin registration):
# Qwen3NextForCausalLM, Qwen3_5MoeForConditionalGeneration,
# MiniCPMO, KimiK25ForConditionalGeneration, Qwen3_5MoeConfig

# Register Qwen3.5 MoE config
try:
from vllm.transformers_utils.config import _CONFIG_REGISTRY
from vllm_fl.configs.qwen3_5_moe import Qwen3_5MoeConfig
_CONFIG_REGISTRY["qwen3_5_moe"] = Qwen3_5MoeConfig
except Exception as e:
logger.error(f"Register Qwen3.5 MoE config error: {str(e)}")

# Register Qwen3Next model
try:
from vllm_fl.models.qwen3_next import Qwen3NextForCausalLM # noqa: F401

qwen3_next_module.Qwen3NextForCausalLM = Qwen3NextForCausalLM
logger.warning(
"Qwen3NextForCausalLM has been patched to use vllm_fl.models.qwen3_next, "
"original vLLM implementation is overridden"
)

ModelRegistry.register_model(
"Qwen3NextForCausalLM",
"vllm_fl.models.qwen3_next:Qwen3NextForCausalLM"
)
except Exception as e:
logger.error(f"Register Qwen3Next model error: {str(e)}")

# Register Qwen3.5 MoE model
try:
ModelRegistry.register_model(
"Qwen3_5MoeForConditionalGeneration",
"vllm_fl.models.qwen3_5:Qwen3_5MoeForConditionalGeneration"
)
except Exception as e:
logger.error(f"Register Qwen3.5 MoE model error: {str(e)}")

# Register MiniCPMO model
try:
ModelRegistry.register_model(
"MiniCPMO",
"vllm_fl.models.minicpmo:MiniCPMO"
)
except Exception as e:
logger.error(f"Register MiniCPMO model error: {str(e)}")

# Register Kimi-K2.5 model
try:
ModelRegistry.register_model(
"KimiK25ForConditionalGeneration",
"vllm_fl.models.kimi_k25:KimiK25ForConditionalGeneration",
)
except Exception as e:
logger.error(f"Register KimiK25 model error: {str(e)}")

# Register GLM-5 (GlmMoeDsa) model
# Register GLM-5 (GlmMoeDsa) — config not yet upstream
try:
from vllm.transformers_utils.config import _CONFIG_REGISTRY
from vllm_fl.configs.glm_moe_dsa import GlmMoeDsaConfig
_CONFIG_REGISTRY["glm_moe_dsa"] = GlmMoeDsaConfig

from vllm_fl.patches.glm_moe_dsa import apply_model_patches as glm5_model
glm5_model()

ModelRegistry.register_model(
"GlmMoeDsaForCausalLM",
"vllm_fl.models.glm_moe_dsa:GlmMoeDsaForCausalLM"
)
except Exception as e:
logger.error(f"Register GlmMoeDsa model error: {str(e)}")
4 changes: 2 additions & 2 deletions vllm_fl/attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def patch_mm_encoder_attention():
FLASH_ATTN branch to import directly from vllm.vllm_flash_attn with a
fallback to flash_attn.
"""
import vllm.attention.layers.mm_encoder_attention as mm_mod
from vllm.attention.backends.registry import AttentionBackendEnum
import vllm.model_executor.layers.attention.mm_encoder_attention as mm_mod
from vllm.v1.attention.backends.registry import AttentionBackendEnum

def _patched_maybe_get_vit_flash_attn_backend(attn_backend):
if attn_backend == AttentionBackendEnum.FLASH_ATTN:
Expand Down
137 changes: 92 additions & 45 deletions vllm_fl/compilation/graph.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Copyright (c) 2025 BAAI. All rights reserved.
# Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/compilation/cuda_graph.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/compilation/cuda_graph.py
# Below is the original copyright:
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import weakref
from collections import Counter
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any, Optional
from typing import Any, ClassVar
from unittest.mock import patch

import torch
Expand All @@ -18,13 +19,18 @@
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.forward_context import (
BatchDescriptor,
get_forward_context,
is_forward_context_available,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


# FL-specific: platform-agnostic weak_ref_tensors
def weak_ref_tensors(tensor: Any) -> Any:
if current_platform.device_type == "cuda":
from vllm.utils.torch_utils import weak_ref_tensors
Expand All @@ -34,6 +40,7 @@
return tensor


# FL-specific: platform-agnostic graph class selection
class Graph:
if current_platform.device_type == "cuda":
graph = torch.cuda.CUDAGraph
Expand All @@ -42,15 +49,21 @@
else:
raise NotImplementedError("not support graph")


# Re-export CUDAGraphStat for compatibility
from vllm.compilation.cuda_graph import CUDAGraphStat # noqa: F401, E402

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'CUDAGraphStat' is not used.


@dataclasses.dataclass
class GraphEntry:
batch_descriptor: BatchDescriptor
graph: Optional[Graph] = None
output: Optional[Any] = None
graph: Any | None = None
output: Any | None = None

# for graph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None
input_addresses: list[int] | None = None


@dataclasses.dataclass
class GraphOptions:
Expand All @@ -60,21 +73,33 @@


class GraphWrapper:
"""FL-specific graph wrapper that supports multiple device types (CUDA, NPU).
Adapted from upstream CUDAGraphWrapper with platform-agnostic graph capture."""

_all_instances: ClassVar[weakref.WeakSet["GraphWrapper"]] = weakref.WeakSet()

@classmethod
def clear_all_graphs(cls) -> None:
"""Clear captured graphs from all GraphWrapper instances."""
for instance in list(cls._all_instances):
instance.clear_graphs()

def __init__(self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
cudagraph_options: Optional[GraphOptions] = None):
cudagraph_options: GraphOptions | None = None):
self.runnable = runnable
self.vllm_config = vllm_config
self.runtime_mode = runtime_mode
self.compilation_config = vllm_config.compilation_config

self.first_run_finished = False
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
self._runnable_str = str(runnable) if self.is_debugging_mode else None

# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
# need to initialize a CUDAGraphWrapper.
# need to initialize a GraphWrapper.
assert self.runtime_mode != CUDAGraphMode.NONE
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
Expand All @@ -83,25 +108,41 @@

if cudagraph_options is None:
cudagraph_options = GraphOptions()
self.graph_options = cudagraph_options
self.cudagraph_options = cudagraph_options
# the entries for different batch descriptors that we need to capture
# cudagraphs for.
self.concrete_graph_entries: dict[BatchDescriptor, GraphEntry] = {}

def __getattr__(self, key: str):
GraphWrapper._all_instances.add(self)

def __getattr__(self, key: str) -> Any:
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
raise AttributeError(
f"Attribute {key} not exists in the runnable of "
f"cudagraph wrapper: {self.runnable}"
)
if self.is_debugging_mode:
raise AttributeError(
f"Attribute {key} not exists in the runnable of "
f"cudagraph wrapper: {self._runnable_str}"
)
raise AttributeError

def unwrap(self) -> Callable:
# in case we need to access the original runnable.
return self.runnable

@property
def cudagraph_wrapper(self) -> "GraphWrapper":
return self

def clear_graphs(self) -> None:
self.concrete_graph_entries.clear()

def __call__(self, *args, **kwargs):
if not is_forward_context_available():
# No forward context means we are outside the normal
# inference path (e.g. a vision encoder forward pass).
return self.runnable(*args, **kwargs)

forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
graph_runtime_mode = forward_context.cudagraph_runtime_mode
Expand All @@ -110,14 +151,9 @@
graph_runtime_mode == CUDAGraphMode.NONE
or graph_runtime_mode != self.runtime_mode
):
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
# running without cudagraphs.
# We do not trigger capture/replay if the runtime mode does not
# match. This enables properly dispatching to the correct
# CUDAGraphWrapper when nesting multiple instances with different
# runtime modes.
return self.runnable(*args, **kwargs)

assert batch_descriptor is not None
if batch_descriptor not in self.concrete_graph_entries:
# create a new entry for this batch descriptor
self.concrete_graph_entries[batch_descriptor] = GraphEntry(
Expand All @@ -127,11 +163,7 @@
entry = self.concrete_graph_entries[batch_descriptor]

if entry.graph is None:
if self.graph_options.debug_log_enable:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every
# shape. E.g. we only log it for the first subgraph in
# piecewise mode.
if self.cudagraph_options.debug_log_enable:
logger.debug(
"Capturing a cudagraph on (%s,%s)",
self.runtime_mode.name,
Expand All @@ -147,32 +179,40 @@
graph = Graph.graph()

with ExitStack() as stack:
if self.graph_options.gc_disable:
# during every model forward for piecewise graph
# mode, we will capture many pieces of graphs
# (roughly one per layer). running gc again and again
# across layers will make the graph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
if self.cudagraph_options.gc_disable:
stack.enter_context(patch("gc.collect", lambda: None))
# FL-specific: patch our platform's empty_cache
stack.enter_context(
Comment on lines 183 to 187
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ExitStack scope ends before graph capture starts, so the gc_disable patches (gc.collect / PlatformFL.empty_cache) are not active during capture. Move the capture setup inside the ExitStack so the patches apply for the whole capture.

Copilot uses AI. Check for mistakes.
patch("vllm_fl.platform.PlatformFL.empty_cache", lambda: None)
patch("vllm_fl.platform.PlatformFL.empty_cache",
lambda: None)
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The patched replacement for PlatformFL.empty_cache is lambda: None. Since empty_cache is normally a @classmethod and may be invoked via the current_platform instance, the patched function can receive an implicit argument and raise TypeError. Use a replacement that accepts *args/**kwargs (or wrap with classmethod) to keep the call signature compatible.

Suggested change
lambda: None)
lambda *args, **kwargs: None)

Copilot uses AI. Check for mistakes.
)

set_graph_pool_id(self.graph_pool)

# mind-exploding: carefully manage the reference and memory.
with current_platform.torch_device_fn.graph(graph, pool=self.graph_pool):
if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
else:
set_graph_pool_id(current_platform.graph_pool_handle())

# Sync offloader's copy stream before capture if available.
try:
from vllm.model_executor.offloader.base import get_offloader
get_offloader().sync_prev_onload()
except (ImportError, RuntimeError):

Check notice

Code scanning / CodeQL

Empty except Note

'except' clause does nothing but pass and there is no explanatory comment.
pass

# FL-specific: use platform-agnostic graph capture
with current_platform.torch_device_fn.graph(
graph, pool=self.graph_pool
):
# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
if self.graph_options.weak_ref_output:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph in piecewise cudagraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
output = weak_ref_tensors(output)
# Join offloader's copy stream after forward if available
try:
from vllm.model_executor.offloader.base import get_offloader
get_offloader().join_after_forward()
except (ImportError, RuntimeError):

Check notice

Code scanning / CodeQL

Empty except Note

'except' clause does nothing but pass and there is no explanatory comment.
pass
if self.cudagraph_options.weak_ref_output:
output = weak_ref_tensors(output)

entry.output = weak_ref_tensors(output)
entry.graph = graph
Expand All @@ -195,6 +235,13 @@
f"got {new_input_addresses}"
)

# Sync offloader before replay if available
try:
from vllm.model_executor.offloader.base import get_offloader
get_offloader().sync_prev_onload()
except (ImportError, RuntimeError):

Check notice

Code scanning / CodeQL

Empty except Note

'except' clause does nothing but pass and there is no explanatory comment.
pass

current_platform.torch_device_fn.synchronize()
entry.graph.replay()
return entry.output
Loading
Loading