-
Notifications
You must be signed in to change notification settings - Fork 42
upgrade vllm to 0.18.1 #112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
d5bb6d6
a4f6351
c1fc8e7
3cfb835
5d9682a
8268235
136a3dc
838341f
a6a3b11
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,14 +1,15 @@ | ||||||
| # Copyright (c) 2025 BAAI. All rights reserved. | ||||||
| # Adapted from https://github.com/vllm-project/vllm/blob/v0.11.0/vllm/compilation/cuda_graph.py | ||||||
| # Adapted from https://github.com/vllm-project/vllm/blob/v0.18.1/vllm/compilation/cuda_graph.py | ||||||
| # Below is the original copyright: | ||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||
|
|
||||||
| import dataclasses | ||||||
| import weakref | ||||||
| from collections import Counter | ||||||
| from collections.abc import Callable | ||||||
| from contextlib import ExitStack | ||||||
| from typing import Any, Optional | ||||||
| from typing import Any, ClassVar | ||||||
| from unittest.mock import patch | ||||||
|
|
||||||
| import torch | ||||||
|
|
@@ -18,13 +19,18 @@ | |||||
| from vllm.compilation.monitor import validate_cudagraph_capturing_enabled | ||||||
| from vllm.config import CUDAGraphMode, VllmConfig | ||||||
| from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id | ||||||
| from vllm.forward_context import BatchDescriptor, get_forward_context | ||||||
| from vllm.forward_context import ( | ||||||
| BatchDescriptor, | ||||||
| get_forward_context, | ||||||
| is_forward_context_available, | ||||||
| ) | ||||||
| from vllm.logger import init_logger | ||||||
| from vllm.platforms import current_platform | ||||||
|
|
||||||
| logger = init_logger(__name__) | ||||||
|
|
||||||
|
|
||||||
| # FL-specific: platform-agnostic weak_ref_tensors | ||||||
| def weak_ref_tensors(tensor: Any) -> Any: | ||||||
| if current_platform.device_type == "cuda": | ||||||
| from vllm.utils.torch_utils import weak_ref_tensors | ||||||
|
|
@@ -34,6 +40,7 @@ | |||||
| return tensor | ||||||
|
|
||||||
|
|
||||||
| # FL-specific: platform-agnostic graph class selection | ||||||
| class Graph: | ||||||
| if current_platform.device_type == "cuda": | ||||||
| graph = torch.cuda.CUDAGraph | ||||||
|
|
@@ -42,15 +49,21 @@ | |||||
| else: | ||||||
| raise NotImplementedError("not support graph") | ||||||
|
|
||||||
|
|
||||||
| # Re-export CUDAGraphStat for compatibility | ||||||
| from vllm.compilation.cuda_graph import CUDAGraphStat # noqa: F401, E402 | ||||||
Check notice: Code scanning / CodeQL: Unused import (Note)
Import of 'CUDAGraphStat' is not used.
|
||||||
|
|
||||||
|
|
||||||
| @dataclasses.dataclass | ||||||
| class GraphEntry: | ||||||
| batch_descriptor: BatchDescriptor | ||||||
| graph: Optional[Graph] = None | ||||||
| output: Optional[Any] = None | ||||||
| graph: Any | None = None | ||||||
| output: Any | None = None | ||||||
|
|
||||||
| # for graph debugging, track the input addresses | ||||||
| # during capture, and check if they are the same during replay | ||||||
| input_addresses: Optional[list[int]] = None | ||||||
| input_addresses: list[int] | None = None | ||||||
|
|
||||||
|
|
||||||
| @dataclasses.dataclass | ||||||
| class GraphOptions: | ||||||
|
|
@@ -60,21 +73,33 @@ | |||||
|
|
||||||
|
|
||||||
| class GraphWrapper: | ||||||
| """FL-specific graph wrapper that supports multiple device types (CUDA, NPU). | ||||||
| Adapted from upstream CUDAGraphWrapper with platform-agnostic graph capture.""" | ||||||
|
|
||||||
| _all_instances: ClassVar[weakref.WeakSet["GraphWrapper"]] = weakref.WeakSet() | ||||||
|
|
||||||
| @classmethod | ||||||
| def clear_all_graphs(cls) -> None: | ||||||
| """Clear captured graphs from all GraphWrapper instances.""" | ||||||
| for instance in list(cls._all_instances): | ||||||
| instance.clear_graphs() | ||||||
|
|
||||||
| def __init__(self, | ||||||
| runnable: Callable, | ||||||
| vllm_config: VllmConfig, | ||||||
| runtime_mode: CUDAGraphMode, | ||||||
| cudagraph_options: Optional[GraphOptions] = None): | ||||||
| cudagraph_options: GraphOptions | None = None): | ||||||
| self.runnable = runnable | ||||||
| self.vllm_config = vllm_config | ||||||
| self.runtime_mode = runtime_mode | ||||||
| self.compilation_config = vllm_config.compilation_config | ||||||
|
|
||||||
| self.first_run_finished = False | ||||||
| self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" | ||||||
| self._runnable_str = str(runnable) if self.is_debugging_mode else None | ||||||
|
|
||||||
| # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't | ||||||
| # need to initialize a CUDAGraphWrapper. | ||||||
| # need to initialize a GraphWrapper. | ||||||
| assert self.runtime_mode != CUDAGraphMode.NONE | ||||||
| # TODO: in the future, if we want to use multiple | ||||||
| # streams, it might not be safe to share a global pool. | ||||||
|
|
@@ -83,25 +108,41 @@ | |||||
|
|
||||||
| if cudagraph_options is None: | ||||||
| cudagraph_options = GraphOptions() | ||||||
| self.graph_options = cudagraph_options | ||||||
| self.cudagraph_options = cudagraph_options | ||||||
| # the entries for different batch descriptors that we need to capture | ||||||
| # cudagraphs for. | ||||||
| self.concrete_graph_entries: dict[BatchDescriptor, GraphEntry] = {} | ||||||
|
|
||||||
| def __getattr__(self, key: str): | ||||||
| GraphWrapper._all_instances.add(self) | ||||||
|
|
||||||
| def __getattr__(self, key: str) -> Any: | ||||||
| # allow accessing the attributes of the runnable. | ||||||
| if hasattr(self.runnable, key): | ||||||
| return getattr(self.runnable, key) | ||||||
| raise AttributeError( | ||||||
| f"Attribute {key} not exists in the runnable of " | ||||||
| f"cudagraph wrapper: {self.runnable}" | ||||||
| ) | ||||||
| if self.is_debugging_mode: | ||||||
| raise AttributeError( | ||||||
| f"Attribute {key} not exists in the runnable of " | ||||||
| f"cudagraph wrapper: {self._runnable_str}" | ||||||
| ) | ||||||
| raise AttributeError | ||||||
|
|
||||||
| def unwrap(self) -> Callable: | ||||||
| # in case we need to access the original runnable. | ||||||
| return self.runnable | ||||||
|
|
||||||
| @property | ||||||
| def cudagraph_wrapper(self) -> "GraphWrapper": | ||||||
| return self | ||||||
|
|
||||||
| def clear_graphs(self) -> None: | ||||||
| self.concrete_graph_entries.clear() | ||||||
|
|
||||||
| def __call__(self, *args, **kwargs): | ||||||
| if not is_forward_context_available(): | ||||||
| # No forward context means we are outside the normal | ||||||
| # inference path (e.g. a vision encoder forward pass). | ||||||
| return self.runnable(*args, **kwargs) | ||||||
|
|
||||||
| forward_context = get_forward_context() | ||||||
| batch_descriptor = forward_context.batch_descriptor | ||||||
| graph_runtime_mode = forward_context.cudagraph_runtime_mode | ||||||
|
|
@@ -110,14 +151,9 @@ | |||||
| graph_runtime_mode == CUDAGraphMode.NONE | ||||||
| or graph_runtime_mode != self.runtime_mode | ||||||
| ): | ||||||
| # CUDAGraphMode.NONE could mean the profile run, a warmup run, or | ||||||
| # running without cudagraphs. | ||||||
| # We do not trigger capture/replay if the runtime mode is not | ||||||
| # matches. This enables properly dispatching to the correct | ||||||
| # CUDAGraphWrapper when nesting multiple instances with different | ||||||
| # runtime modes. | ||||||
| return self.runnable(*args, **kwargs) | ||||||
|
|
||||||
| assert batch_descriptor is not None | ||||||
| if batch_descriptor not in self.concrete_graph_entries: | ||||||
| # create a new entry for this batch descriptor | ||||||
| self.concrete_graph_entries[batch_descriptor] = GraphEntry( | ||||||
|
|
@@ -127,11 +163,7 @@ | |||||
| entry = self.concrete_graph_entries[batch_descriptor] | ||||||
|
|
||||||
| if entry.graph is None: | ||||||
| if self.graph_options.debug_log_enable: | ||||||
| # Since we capture cudagraph for many different shapes and | ||||||
| # capturing is fast, we don't need to log it for every | ||||||
| # shape. E.g. we only log it for the first subgraph in | ||||||
| # piecewise mode. | ||||||
| if self.cudagraph_options.debug_log_enable: | ||||||
| logger.debug( | ||||||
| "Capturing a cudagraph on (%s,%s)", | ||||||
| self.runtime_mode.name, | ||||||
|
|
@@ -147,32 +179,40 @@ | |||||
| graph = Graph.graph() | ||||||
|
|
||||||
| with ExitStack() as stack: | ||||||
| if self.graph_options.gc_disable: | ||||||
| # during every model forward for piecewise graph | ||||||
| # mode, we will capture many pieces of graphs | ||||||
| # (roughly one per layer). running gc again and again | ||||||
| # across layers will make the graph capture very slow. | ||||||
| # therefore, we only run gc for the first graph, | ||||||
| # and disable gc for the rest of the graphs. | ||||||
| if self.cudagraph_options.gc_disable: | ||||||
| stack.enter_context(patch("gc.collect", lambda: None)) | ||||||
| # FL-specific: patch our platform's empty_cache | ||||||
| stack.enter_context( | ||||||
|
Comment on lines 183 to 187
|
||||||
| patch("vllm_fl.platform.PlatformFL.empty_cache", lambda: None) | ||||||
| patch("vllm_fl.platform.PlatformFL.empty_cache", | ||||||
| lambda: None) | ||||||
|
||||||
| lambda: None) | |
| lambda *args, **kwargs: None) |
Check notice
Code scanning / CodeQL
Empty except Note
Check notice
Code scanning / CodeQL
Empty except Note
Check notice
Code scanning / CodeQL
Empty except Note
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR title/description says upgrade to vLLM 0.18.0, but this change updates docs/source references to v0.18.1 (and several file headers also reference v0.18.1). Please align the stated target version (either update PR metadata to 0.18.1 or change the references back to 0.18.0) to avoid confusion about the required dependency version.