diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index e69de29b..084b9282 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-FL project +# +# vLLM-FL C++ extensions - Root CMakeLists.txt + +cmake_minimum_required(VERSION 3.26) +project(vllm_fl_extensions LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +# ============================================================================= +# Vendor Selection (REQUIRED - no auto-detection) +# ============================================================================= + +if(NOT DEFINED VLLM_VENDOR) + if(DEFINED ENV{VLLM_VENDOR}) + set(VLLM_VENDOR $ENV{VLLM_VENDOR}) + endif() +endif() + +if(NOT VLLM_VENDOR) + message(FATAL_ERROR + "VLLM_VENDOR is required but not specified.\n" + "Please set VLLM_VENDOR environment variable or cmake option:\n" + " export VLLM_VENDOR=cuda # For NVIDIA CUDA\n" + " export VLLM_VENDOR=ascend # For Huawei Ascend\n" + "\n" + "Or pass to cmake:\n" + " cmake -DVLLM_VENDOR=cuda .." 
+ ) +endif() + +set(SUPPORTED_VENDORS cuda ascend) +if(NOT VLLM_VENDOR IN_LIST SUPPORTED_VENDORS) + message(FATAL_ERROR + "Unsupported vendor: ${VLLM_VENDOR}\n" + "Supported vendors: ${SUPPORTED_VENDORS}" + ) +endif() + +message(STATUS "==============================================") +message(STATUS "vLLM-FL Extensions: ${VLLM_VENDOR}") +message(STATUS "==============================================") + +# ============================================================================= +# Find Python +# ============================================================================= + +if(VLLM_PYTHON_EXECUTABLE) + set(Python_EXECUTABLE ${VLLM_PYTHON_EXECUTABLE}) +endif() + +find_package(Python REQUIRED COMPONENTS Interpreter Development.Module) +message(STATUS "Python: ${Python_EXECUTABLE} (${Python_VERSION})") + +# ============================================================================= +# Find PyTorch +# ============================================================================= + +execute_process( + COMMAND ${Python_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" + OUTPUT_VARIABLE TORCH_CMAKE_PREFIX + OUTPUT_STRIP_TRAILING_WHITESPACE +) +list(APPEND CMAKE_PREFIX_PATH ${TORCH_CMAKE_PREFIX}) + +find_package(Torch REQUIRED) +message(STATUS "PyTorch: ${Torch_VERSION}") + +# ============================================================================= +# Include directories +# ============================================================================= + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +# ============================================================================= +# Build Vendor Backend +# ============================================================================= + +add_subdirectory(${VLLM_VENDOR}) diff --git a/csrc/ascend/CMakeLists.txt b/csrc/ascend/CMakeLists.txt index e69de29b..cb79f42c 100644 --- a/csrc/ascend/CMakeLists.txt +++ b/csrc/ascend/CMakeLists.txt @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM-FL project +# +# Ascend backend for vLLM-FL + +# Ascend CANN toolkit +set(ASCEND_TOOLKIT_PATH "$ENV{ASCEND_TOOLKIT_HOME}") +if(NOT ASCEND_TOOLKIT_PATH) + set(ASCEND_TOOLKIT_PATH "/usr/local/Ascend/ascend-toolkit/latest") +endif() + +if(NOT EXISTS "${ASCEND_TOOLKIT_PATH}/include/acl/acl.h") + message(WARNING "Ascend CANN not found at ${ASCEND_TOOLKIT_PATH}. Skipping.") + return() +endif() + +message(STATUS "Ascend CANN: ${ASCEND_TOOLKIT_PATH}") + +# ============================================================================= +# Source files +# ============================================================================= + +set(VLLM_FL_ASCEND_SRCS + weak_ref_tensor.cpp + torch_bindings.cpp +) + +# ============================================================================= +# Define extension target +# ============================================================================= + +# Create Python extension module named _C +# This will be importable as: import vllm_fl._C +Python_add_library(_C MODULE WITH_SOABI ${VLLM_FL_ASCEND_SRCS}) + +# Set TORCH_EXTENSION_NAME so TORCH_LIBRARY_EXPAND works +target_compile_definitions(_C PRIVATE "-DTORCH_EXTENSION_NAME=_C") + +# Include directories +target_include_directories(_C PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/.. 
+    ${ASCEND_TOOLKIT_PATH}/include
+    ${TORCH_INCLUDE_DIRS}
+)
+
+# Link libraries
+get_filename_component(TORCH_LIB_DIR "${TORCH_LIBRARY}" DIRECTORY)
+target_link_directories(_C PRIVATE
+    ${TORCH_LIB_DIR}
+    ${ASCEND_TOOLKIT_PATH}/lib64
+)
+
+target_link_libraries(_C PRIVATE ${TORCH_LIBRARIES})
+
+# C++ settings
+set_target_properties(_C PROPERTIES
+    CXX_STANDARD 17
+    CXX_STANDARD_REQUIRED ON
+)
+
+# =============================================================================
+# Install to vllm_fl package directory
+# =============================================================================
+
+install(TARGETS _C LIBRARY DESTINATION vllm_fl COMPONENT _C)
diff --git a/csrc/ascend/torch_bindings.cpp b/csrc/ascend/torch_bindings.cpp
new file mode 100644
index 00000000..94a5e95e
--- /dev/null
+++ b/csrc/ascend/torch_bindings.cpp
@@ -0,0 +1,25 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright contributors to the vLLM-FL project
+//
+// Ascend torch bindings for vLLM-FL operators
+
+#include <torch/all.h>
+#include <torch/library.h>
+
+#include "registration.h"
+
+namespace vllm_fl {
+
+// Forward declarations of Ascend implementations
+torch::Tensor weak_ref_tensor_ascend(const torch::Tensor& tensor);
+
+}  // namespace vllm_fl
+
+// Register extension for Python import
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
+
+// Define operators using the extension name
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
+  ops.impl("weak_ref_tensor", c10::kPrivateUse1, &vllm_fl::weak_ref_tensor_ascend);
+}
diff --git a/csrc/ascend/weak_ref_tensor.cpp b/csrc/ascend/weak_ref_tensor.cpp
new file mode 100644
index 00000000..d4a5c0c5
--- /dev/null
+++ b/csrc/ascend/weak_ref_tensor.cpp
@@ -0,0 +1,25 @@
+// Copyright (c) 2026 BAAI. All rights reserved.
+// Ascend weak_ref_tensor implementation
+
+#include <torch/all.h>
+// NOTE(review): at_npu::native::from_blob is provided by torch_npu; confirm exact header path.
+#include "torch_npu/csrc/aten/common/from_blob.h"
+
+namespace vllm_fl {
+  // Signature matches the forward declaration in torch_bindings.cpp (const torch::Tensor&).
+  torch::Tensor weak_ref_tensor_ascend(const torch::Tensor& tensor) {
+    if (!tensor.is_privateuseone()) {
+      throw std::runtime_error("Tensor must be on NPU device");
+    }
+    // Get the raw data pointer
+    void* data_ptr = tensor.data_ptr();
+    // Get tensor sizes and strides
+    std::vector<int64_t> sizes = tensor.sizes().vec();
+    std::vector<int64_t> strides = tensor.strides().vec();
+    // Get tensor options (dtype, device)
+    auto options = tensor.options();
+    // Create a new tensor from the raw data pointer
+    auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
+    return new_tensor;
+  }
+
+}  // namespace vllm_fl
diff --git a/csrc/cuda/CMakeLists.txt b/csrc/cuda/CMakeLists.txt
index e69de29b..455f5ef7 100644
--- a/csrc/cuda/CMakeLists.txt
+++ b/csrc/cuda/CMakeLists.txt
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM-FL project
+#
+# CUDA backend for vLLM-FL
+
+find_package(CUDAToolkit REQUIRED)
+enable_language(CUDA)
+
+message(STATUS "CUDA Toolkit: ${CUDAToolkit_VERSION}")
+
+# =============================================================================
+# Source files
+# =============================================================================
+
+set(VLLM_FL_CUDA_SRCS
+    weak_ref_tensor.cu
+    torch_bindings.cpp
+)
+
+# =============================================================================
+# Define extension target
+# =============================================================================
+
+# Create Python extension module named _C
+# This will be importable as: import vllm_fl._C
+Python_add_library(_C MODULE WITH_SOABI ${VLLM_FL_CUDA_SRCS})
+
+# Set TORCH_EXTENSION_NAME so TORCH_LIBRARY_EXPAND works
+target_compile_definitions(_C PRIVATE "-DTORCH_EXTENSION_NAME=_C")
+
+# Include directories
+target_include_directories(_C PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/..
+    ${CUDAToolkit_INCLUDE_DIRS}
+    ${TORCH_INCLUDE_DIRS}
+)
+
+# Link libraries
+target_link_libraries(_C PRIVATE
+    torch
+    CUDA::cudart
+    CUDA::cuda_driver
+)
+
+# CUDA settings
+set_target_properties(_C PROPERTIES
+    CUDA_STANDARD 17
+    CUDA_STANDARD_REQUIRED ON
+)
+
+if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+    set(CMAKE_CUDA_ARCHITECTURES "70;75;80;86;89;90")
+endif()
+
+# Apply optimization flags only to CUDA translation units (generator-expression
+# condition restored; it had been garbled to "$<$:...>").
+target_compile_options(_C PRIVATE
+    $<$<COMPILE_LANGUAGE:CUDA>:-O3;--use_fast_math>
+)
+
+# =============================================================================
+# Install to vllm_fl package directory
+# =============================================================================
+
+install(TARGETS _C LIBRARY DESTINATION vllm_fl COMPONENT _C)
diff --git a/csrc/cuda/torch_bindings.cpp b/csrc/cuda/torch_bindings.cpp
new file mode 100644
index 00000000..327a66b7
--- /dev/null
+++ b/csrc/cuda/torch_bindings.cpp
@@ -0,0 +1,29 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright contributors to the vLLM-FL project
+//
+// CUDA torch bindings for vLLM-FL operators
+
+#include <torch/all.h>
+#include <torch/library.h>
+
+#include "registration.h"
+
+namespace vllm_fl {
+
+// Forward declarations of CUDA implementations
+torch::Tensor weak_ref_tensor_cuda(torch::Tensor& tensor);
+
+}  // namespace vllm_fl
+
+// Register extension for Python import
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
+
+// Define operators using the extension name
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
+  ops.impl("weak_ref_tensor", c10::kCUDA, &vllm_fl::weak_ref_tensor_cuda);
+
+  // Add more operators here:
+  // ops.def("another_op(Tensor input) -> Tensor");
+  // ops.impl("another_op", c10::kCUDA, &vllm_fl::another_op_cuda);
+}
diff --git a/csrc/cuda/weak_ref_tensor.cu b/csrc/cuda/weak_ref_tensor.cu
new file mode 100644
index 00000000..93346a8c
--- /dev/null
+++ b/csrc/cuda/weak_ref_tensor.cu
@@ -0,0 +1,30 @@
+// Copyright (c) 2026 BAAI. All rights reserved.
+// CUDA weak_ref_tensor implementation
+
+#include <torch/all.h>
+#include <cuda_runtime.h>
+
+namespace vllm_fl {
+  torch::Tensor weak_ref_tensor_cuda(torch::Tensor& tensor) {
+    // Ensure tensor is on CUDA
+    if (!tensor.is_cuda()) {
+      throw std::runtime_error("Tensor must be on CUDA device");
+    }
+
+    // Get the raw data pointer
+    void* data_ptr = tensor.data_ptr();
+
+    // Get tensor sizes and strides
+    std::vector<int64_t> sizes = tensor.sizes().vec();
+    std::vector<int64_t> strides = tensor.strides().vec();
+
+    // Get tensor options (dtype, device)
+    auto options = tensor.options();
+
+    // Create a new tensor from the raw data pointer
+    auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options);
+
+    return new_tensor;
+  }
+
+}  // namespace vllm_fl
diff --git a/csrc/registration.h b/csrc/registration.h
new file mode 100644
index 00000000..4d0ce1c5
--- /dev/null
+++ b/csrc/registration.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <Python.h>
+
+#define _CONCAT(A, B) A##B
+#define CONCAT(A, B) _CONCAT(A, B)
+
+#define _STRINGIFY(A) #A
+#define STRINGIFY(A) _STRINGIFY(A)
+
+// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
+  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
+
+// REGISTER_EXTENSION allows the shared library to be loaded and initialized
+// via python's import statement.
+#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } diff --git a/setup.py b/setup.py index dc775e74..818281ad 100644 --- a/setup.py +++ b/setup.py @@ -1,63 +1,369 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-FL project + +import glob +import logging import os +import shutil +import subprocess +import sys +from pathlib import Path +from shutil import which +from typing import List + +from setuptools import Extension, find_packages, setup +from setuptools.command.build_ext import build_ext + +ROOT_DIR = Path(__file__).parent.resolve() +logger = logging.getLogger(__name__) + +# ============================================================================= +# Environment Variables +# ============================================================================= + +VLLM_VENDOR = os.environ.get("VLLM_VENDOR", "").lower() +MAX_JOBS = os.environ.get("MAX_JOBS") +NVCC_THREADS = os.environ.get("NVCC_THREADS") +CMAKE_BUILD_TYPE = os.environ.get("CMAKE_BUILD_TYPE") +VERBOSE = os.environ.get("VERBOSE", "0") == "1" + +SUPPORTED_VENDORS = ["cuda", "ascend"] + + +# ============================================================================= +# Utility Functions +# ============================================================================= + +def is_sccache_available() -> bool: + return which("sccache") is not None + + +def is_ccache_available() -> bool: + return which("ccache") is not None + + +def is_ninja_available() -> bool: + return which("ninja") is not None + + +def _is_cuda() -> bool: + return VLLM_VENDOR == "cuda" + + +def _is_ascend() -> bool: + return VLLM_VENDOR == "ascend" + + +# ============================================================================= +# Version +# 
============================================================================= + +def get_cuda_version() -> str: + """Detect CUDA version from nvcc.""" + try: + output = subprocess.check_output(["nvcc", "--version"], stderr=subprocess.STDOUT) + output = output.decode("utf-8") + # Parse "release X.Y" from nvcc output + import re + match = re.search(r"release (\d+)\.(\d+)", output) + if match: + major, minor = match.groups() + return f"cu{major}{minor}" + except Exception: + pass + return "cu" + + +def get_git_commit() -> str: + """Get the first 8 characters of git commit hash.""" + try: + output = subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.STDOUT) + return output.decode("utf-8").strip()[:8] + except Exception: + return "unknown" + + +def get_build_date() -> str: + """Get current date in YYYYMMDD format.""" + from datetime import datetime + return datetime.now().strftime("%Y%m%d") + + +def get_vllm_fl_version() -> str: + version = "0.0.1.dev0" + commit_id = get_git_commit() + build_date = get_build_date() + + if VLLM_VENDOR == "cuda": + cuda_ver = get_cuda_version() + version += f"+g{commit_id}.{build_date}.{cuda_ver}" + elif VLLM_VENDOR: + version += f"+g{commit_id}.{build_date}.{VLLM_VENDOR}" + else: + version += f"+g{commit_id}.{build_date}" + + return version + + +# ============================================================================= +# CMake Extension +# ============================================================================= + +class CMakeExtension(Extension): + def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None: + super().__init__(name, sources=[], **kwa) + self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) + + +class cmake_build_ext(build_ext): + """CMake build extension for vLLM-FL operators.""" + + did_config: dict = {} + + def run(self): + """Override run to skip the default extension copying.""" + self.build_extensions() -from setuptools import find_packages, setup + def 
compute_num_jobs(self): + """Compute number of parallel compilation jobs.""" + num_jobs = MAX_JOBS + if num_jobs is not None: + num_jobs = int(num_jobs) + logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) + else: + try: + num_jobs = len(os.sched_getaffinity(0)) + except AttributeError: + num_jobs = os.cpu_count() or 1 -ROOT_DIR = os.path.dirname(__file__) -VERSION = "0.0.0" + nvcc_threads = None + if _is_cuda() and NVCC_THREADS is not None: + nvcc_threads = int(NVCC_THREADS) + logger.info( + "Using NVCC_THREADS=%d as the number of nvcc threads.", + nvcc_threads, + ) + num_jobs = max(1, num_jobs // nvcc_threads) + return num_jobs, nvcc_threads -def get_path(*filepath) -> str: - return os.path.join(ROOT_DIR, *filepath) + def configure(self, ext: CMakeExtension) -> None: + """Configure cmake for the extension.""" + if ext.cmake_lists_dir in cmake_build_ext.did_config: + return + cmake_build_ext.did_config[ext.cmake_lists_dir] = True + + # Build type + default_cfg = "Debug" if self.debug else "RelWithDebInfo" + cfg = CMAKE_BUILD_TYPE or default_cfg + + cmake_args = [ + f"-DCMAKE_BUILD_TYPE={cfg}", + f"-DVLLM_VENDOR={VLLM_VENDOR}", + f"-DVLLM_PYTHON_EXECUTABLE={sys.executable}", + ] + + if VERBOSE: + cmake_args.append("-DCMAKE_VERBOSE_MAKEFILE=ON") + + # Compiler cache + if is_sccache_available(): + cmake_args += [ + "-DCMAKE_C_COMPILER_LAUNCHER=sccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache", + "-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache", + ] + elif is_ccache_available(): + cmake_args += [ + "-DCMAKE_C_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache", + ] + + # Parallelism and build tool + num_jobs, nvcc_threads = self.compute_num_jobs() + + if nvcc_threads: + cmake_args.append(f"-DNVCC_THREADS={nvcc_threads}") + + if is_ninja_available(): + build_tool = ["-G", "Ninja"] + cmake_args += [ + "-DCMAKE_JOB_POOL_COMPILE:STRING=compile", + f"-DCMAKE_JOB_POOLS:STRING=compile={num_jobs}", + ] + 
else: + build_tool = [] + + # Additional cmake args from environment + extra_cmake_args = os.environ.get("CMAKE_ARGS") + if extra_cmake_args: + cmake_args += extra_cmake_args.split() + + print(f"Configuring CMake for vendor: {VLLM_VENDOR}") + print(f" Source dir: {ext.cmake_lists_dir}") + print(f" Build dir: {self.build_temp}") + + subprocess.check_call( + ["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args], + cwd=self.build_temp, + ) + + def build_extensions(self) -> None: + """Build all extensions.""" + # Check CMake + try: + subprocess.check_output(["cmake", "--version"], stderr=subprocess.STDOUT) + except (OSError, subprocess.CalledProcessError) as e: + raise RuntimeError( + f"CMake not available or not working: {e}\n" + "Please install with:\n" + f" VLLM_VENDOR={VLLM_VENDOR} pip install --no-build-isolation -e ." + ) from e + + # Create build directory + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + + # Configure and collect targets + targets = [] + for ext in self.extensions: + self.configure(ext) + target_name = ext.name.split(".")[-1] + targets.append(target_name) + + # Build + num_jobs, _ = self.compute_num_jobs() + build_args = [ + "--build", ".", + f"-j={num_jobs}", + *[f"--target={name}" for name in targets], + ] + + print(f"Building targets: {targets}") + subprocess.check_call(["cmake", *build_args], cwd=self.build_temp) + + # Copy built extensions to where setuptools expects them + for ext in self.extensions: + # Get the full path where setuptools expects the extension + dest_path = Path(self.get_ext_fullpath(ext.name)).absolute() + dest_dir = dest_path.parent + + # Create destination directory if it doesn't exist + dest_dir.mkdir(parents=True, exist_ok=True) + + # Find the built .so file in the build directory + target_name = ext.name.split(".")[-1] + # Look for the .so file in various possible locations + so_patterns = [ + f"{self.build_temp}/{VLLM_VENDOR}/{target_name}*.so", + f"{self.build_temp}/{target_name}*.so", 
+ ] + + built_so = None + for pattern in so_patterns: + matches = glob.glob(pattern) + if matches: + built_so = matches[0] + break + + if built_so is None: + raise RuntimeError( + f"Could not find built extension {target_name}.so in {self.build_temp}" + ) + + # Copy to destination with the correct name + print(f"Copying {built_so} to {dest_path}") + shutil.copy2(built_so, dest_path) + + +# ============================================================================= +# Package Configuration +# ============================================================================= def read_readme() -> str: """Read the README file if present.""" - p = get_path("README.md") - if os.path.isfile(p): - with open(get_path("README.md"), encoding="utf-8") as f: - return f.read() - else: - return "" + readme_path = ROOT_DIR / "README.md" + if readme_path.is_file(): + return readme_path.read_text(encoding="utf-8") + return "" -def get_requirements() -> list[str]: +def get_requirements() -> List[str]: """Get Python package dependencies from requirements.txt.""" + requirements_path = ROOT_DIR / "requirements.txt" + if not requirements_path.is_file(): + logger.warning("requirements.txt not found") + return [] - def _read_requirements(filename: str) -> list[str]: - with open(get_path(filename)) as f: - requirements = f.read().strip().split("\n") - resolved_requirements = [] - for line in requirements: + def _read_requirements(filepath: Path) -> List[str]: + resolved = [] + for line in filepath.read_text().strip().split("\n"): + line = line.strip() + if not line or line.startswith("#"): + continue if line.startswith("-r "): - resolved_requirements += _read_requirements(line.split()[1]) + inc_file = filepath.parent / line.split()[1] + resolved += _read_requirements(inc_file) elif line.startswith("--"): continue else: - resolved_requirements.append(line) - return resolved_requirements + resolved.append(line) + return resolved try: - requirements = _read_requirements("requirements.txt") - 
except ValueError: - print("Failed to read requirements.txt in vllm_flagos.") - return requirements + return _read_requirements(requirements_path) + except Exception as e: + logger.warning(f"Failed to read requirements.txt: {e}") + return [] + + +# ============================================================================= +# Extension Modules +# ============================================================================= + +ext_modules = [] + +if VLLM_VENDOR: + if VLLM_VENDOR not in SUPPORTED_VENDORS: + raise ValueError( + f"Unsupported vendor: {VLLM_VENDOR}\n" + f"Supported vendors: {SUPPORTED_VENDORS}" + ) + csrc_dir = str(ROOT_DIR / "csrc") + # Extension name is vllm_fl._C - will be importable as `import vllm_fl._C` + ext_modules.append(CMakeExtension(name="vllm_fl._C", cmake_lists_dir=csrc_dir)) +# ============================================================================= +# Command Classes +# ============================================================================= + +if ext_modules: + cmdclass = {"build_ext": cmake_build_ext} +else: + cmdclass = {} + + +# ============================================================================= +# Setup +# ============================================================================= + setup( name="vllm_fl", - # Follow: - # https://packaging.python.org/en/latest/specifications/version-specifiers - version=VERSION, + version=get_vllm_fl_version(), author="vLLM-FL team", license="Apache 2.0", - description=("vLLM FL backend plugin"), - # long_description=read_readme(), + description="vLLM FL backend plugin with multi-vendor C++ operators", + long_description=read_readme(), long_description_content_type="text/markdown", url="https://github.com/flagos-ai/vllm-plugin-FL", project_urls={ "Homepage": "https://github.com/flagos-ai/vllm-plugin-FL", }, classifiers=[ + "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python 
:: 3.12",
@@ -68,14 +374,29 @@ def _read_requirements(filename: str) -> list[str]:
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
         "Topic :: Scientific/Engineering :: Information Analysis",
     ],
-    packages=find_packages(exclude=("docs", "examples", "tests*")),
-    package_data={
-        "vllm_fl.dispatch.config": ["*.yaml"],
-    },
+    packages=find_packages(exclude=("docs", "examples", "tests*", "csrc")),
     python_requires=">=3.10",
     install_requires=get_requirements(),
+    ext_modules=ext_modules,
+    cmdclass=cmdclass,
+    extras_require={
+        "dev": [
+            "pytest>=7.0",
+            "pytest-asyncio",
+            "black",
+            "isort",
+            "mypy",
+        ],
+    },
     entry_points={
         "vllm.platform_plugins": ["fl = vllm_fl:register"],
         "vllm.general_plugins": ["fl = vllm_fl:register_model"],
     },
+    package_data={
+        "vllm_fl": [
+            "*.so",
+            "dispatch/config/*.yaml",
+        ],
+    },
+    include_package_data=True,
 )
diff --git a/vllm_fl/compilation/graph.py b/vllm_fl/compilation/graph.py
index 909ec2d7..0b9a10ab 100644
--- a/vllm_fl/compilation/graph.py
+++ b/vllm_fl/compilation/graph.py
@@ -26,11 +26,10 @@ def weak_ref_tensors(tensor: Any) -> Any:
-    if current_platform.device_type == "cuda":
+    try:
         from vllm.utils.torch_utils import weak_ref_tensors
         return weak_ref_tensors(tensor)
-    else:
-        ### TODO: add csrc npu custom op
+    except Exception:  # not bare `except:` -- don't swallow SystemExit/KeyboardInterrupt
         return tensor