Skip to content

Commit 0c068f8

Browse files
cyberpioneerceci3
authored and committed
add weak_ref_tensor csrc
1 parent 2d4ca81 commit 0c068f8

10 files changed

Lines changed: 701 additions & 33 deletions

File tree

csrc/CMakeLists.txt

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-FL project
#
# vLLM-FL C++ extensions - Root CMakeLists.txt
#
# Entry point: selects a vendor backend (VLLM_VENDOR), locates Python and
# PyTorch, then descends into csrc/<vendor>/ which defines the _C module.

cmake_minimum_required(VERSION 3.26)
project(vllm_fl_extensions LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# =============================================================================
# Vendor Selection (REQUIRED - no auto-detection)
# =============================================================================

# A -DVLLM_VENDOR=... cmake option wins; otherwise fall back to the
# environment variable of the same name.
if(NOT DEFINED VLLM_VENDOR)
  if(DEFINED ENV{VLLM_VENDOR})
    set(VLLM_VENDOR "$ENV{VLLM_VENDOR}")
  endif()
endif()

if(NOT VLLM_VENDOR)
  message(FATAL_ERROR
    "VLLM_VENDOR is required but not specified.\n"
    "Please set VLLM_VENDOR environment variable or cmake option:\n"
    "  export VLLM_VENDOR=cuda    # For NVIDIA CUDA\n"
    "  export VLLM_VENDOR=ascend  # For Huawei Ascend\n"
    "\n"
    "Or pass to cmake:\n"
    "  cmake -DVLLM_VENDOR=cuda .."
  )
endif()

set(SUPPORTED_VENDORS cuda ascend)
if(NOT VLLM_VENDOR IN_LIST SUPPORTED_VENDORS)
  # Join with ", " so the error prints "cuda, ascend" rather than the raw
  # ;-separated CMake list.
  list(JOIN SUPPORTED_VENDORS ", " _supported_pretty)
  message(FATAL_ERROR
    "Unsupported vendor: ${VLLM_VENDOR}\n"
    "Supported vendors: ${_supported_pretty}"
  )
endif()

message(STATUS "==============================================")
message(STATUS "vLLM-FL Extensions: ${VLLM_VENDOR}")
message(STATUS "==============================================")

# =============================================================================
# Find Python
# =============================================================================

# Allow the build driver (e.g. setup.py) to pin the interpreter explicitly.
if(VLLM_PYTHON_EXECUTABLE)
  set(Python_EXECUTABLE "${VLLM_PYTHON_EXECUTABLE}")
endif()

find_package(Python REQUIRED COMPONENTS Interpreter Development.Module)
message(STATUS "Python: ${Python_EXECUTABLE} (${Python_VERSION})")

# =============================================================================
# Find PyTorch
# =============================================================================

# Ask the chosen interpreter where torch ships its CMake package config.
# COMMAND_ERROR_IS_FATAL makes the configure step fail right here when torch
# is not importable; previously a failed import silently appended an empty
# prefix and surfaced as a confusing find_package(Torch) error later.
execute_process(
  COMMAND "${Python_EXECUTABLE}" -c "import torch; print(torch.utils.cmake_prefix_path)"
  OUTPUT_VARIABLE TORCH_CMAKE_PREFIX
  OUTPUT_STRIP_TRAILING_WHITESPACE
  COMMAND_ERROR_IS_FATAL ANY
)
list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PREFIX}")

find_package(Torch REQUIRED)
message(STATUS "PyTorch: ${Torch_VERSION}")

# =============================================================================
# Build Vendor Backend
# =============================================================================

# NOTE: the directory-scoped include_directories(${CMAKE_CURRENT_SOURCE_DIR})
# was removed; both vendor CMakeLists already add csrc/ to the _C target via
# target_include_directories, so the global include was redundant and leaked
# into every target below this directory.
add_subdirectory(${VLLM_VENDOR})

csrc/ascend/CMakeLists.txt

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-FL project
#
# Ascend backend for vLLM-FL: builds the vllm_fl._C extension module.

# Locate the Ascend CANN toolkit: honour ASCEND_TOOLKIT_HOME if set,
# otherwise fall back to the conventional install location.
set(ASCEND_TOOLKIT_PATH "$ENV{ASCEND_TOOLKIT_HOME}")
if(NOT ASCEND_TOOLKIT_PATH)
  set(ASCEND_TOOLKIT_PATH "/usr/local/Ascend/ascend-toolkit/latest")
endif()

# acl.h is used as the sentinel for a usable CANN installation.
if(NOT EXISTS "${ASCEND_TOOLKIT_PATH}/include/acl/acl.h")
  message(WARNING "Ascend CANN not found at ${ASCEND_TOOLKIT_PATH}. Skipping.")
  return()
endif()

message(STATUS "Ascend CANN: ${ASCEND_TOOLKIT_PATH}")

# =============================================================================
# Source files
# =============================================================================

set(VLLM_FL_ASCEND_SRCS
  weak_ref_tensor.cpp
  torch_bindings.cpp
)

# =============================================================================
# Define extension target
# =============================================================================

# Create Python extension module named _C
# This will be importable as: import vllm_fl._C
Python_add_library(_C MODULE WITH_SOABI ${VLLM_FL_ASCEND_SRCS})

# Set TORCH_EXTENSION_NAME so TORCH_LIBRARY_EXPAND works.
# target_compile_definitions takes bare NAME=VALUE entries; the previous
# "-D" prefix was legacy noise that CMake merely strips.
target_compile_definitions(_C PRIVATE TORCH_EXTENSION_NAME=_C)

# Include directories: csrc/ (for registration.h), CANN, and Torch headers.
target_include_directories(_C PRIVATE
  "${CMAKE_CURRENT_SOURCE_DIR}/.."
  "${ASCEND_TOOLKIT_PATH}/include"
  ${TORCH_INCLUDE_DIRS}
)

# Link setup. TORCH_LIBRARIES already carries absolute paths / imported
# targets; the extra link directories cover CANN's lib64 and torch's lib dir.
cmake_path(GET TORCH_LIBRARY PARENT_PATH TORCH_LIB_DIR)
target_link_directories(_C PRIVATE
  "${TORCH_LIB_DIR}"
  "${ASCEND_TOOLKIT_PATH}/lib64"
)
target_link_libraries(_C PRIVATE ${TORCH_LIBRARIES})
# NOTE(review): weak_ref_tensor.cpp calls at_npu::native::from_blob, which is
# provided by torch_npu (libtorch_npu) — that library is not linked here.
# Confirm whether libtorch_npu must be added for the module to link/load.

# C++ settings
set_target_properties(_C PROPERTIES
  CXX_STANDARD 17
  CXX_STANDARD_REQUIRED ON
)

# =============================================================================
# Install to vllm_fl package directory
# =============================================================================

install(TARGETS _C LIBRARY DESTINATION vllm_fl COMPONENT _C)

csrc/ascend/torch_bindings.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM-FL project
//
// Ascend torch bindings for vLLM-FL operators.
// Defines the _C Python module and registers the ops it exposes.

#include <torch/torch.h>
#include <torch/library.h>

#include "registration.h"

namespace vllm_fl {

// Forward declarations of Ascend implementations.
// NOTE(review): this declaration must match the definition in
// weak_ref_tensor.cpp exactly — including const-ness of the reference —
// or the extension fails to link. The definition currently takes a
// non-const `torch::Tensor&`; one of the two must be aligned.
torch::Tensor weak_ref_tensor_ascend(const torch::Tensor& tensor);

}  // namespace vllm_fl

// Register extension for Python import: defines PyInit__C so that
// `import vllm_fl._C` can load this shared library.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

// Define operators under the extension's namespace. The def() registers the
// schema; the impl() binds it to the PrivateUse1 dispatch key (the
// implementation rejects tensors that are not is_privateuseone(), i.e. NPU).
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
  ops.impl("weak_ref_tensor", c10::kPrivateUse1, &vllm_fl::weak_ref_tensor_ascend);
}

csrc/ascend/weak_ref_tensor.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (c) 2026 BAAI. All rights reserved.
// Ascend weak_ref_tensor implementation

#include <torch/torch.h>

namespace vllm_fl {

// Returns a non-owning ("weak") alias of an NPU tensor: a new tensor that
// shares the same device memory, sizes and strides but does not own the
// storage. The caller must keep `tensor` alive while the alias is in use.
//
// Fixes vs. the original:
//  * takes `const torch::Tensor&` to match the forward declaration in
//    torch_bindings.cpp (the non-const reference would fail to link);
//  * removed a stray extra `}` that unbalanced the namespace braces.
//
// Throws std::runtime_error if the tensor is not on the NPU
// (PrivateUse1) device.
torch::Tensor weak_ref_tensor_ascend(const torch::Tensor& tensor) {
  if (!tensor.is_privateuseone()) {
    throw std::runtime_error("Tensor must be on NPU device");
  }
  // Raw device pointer; the alias below does NOT take ownership of it.
  void* data_ptr = tensor.data_ptr();
  // Preserve the exact view: sizes, strides, dtype and device.
  std::vector<int64_t> sizes = tensor.sizes().vec();
  std::vector<int64_t> strides = tensor.strides().vec();
  auto options = tensor.options();
  // NOTE(review): at_npu::native::from_blob is declared by torch_npu; this
  // translation unit includes no torch_npu header and the build does not
  // link libtorch_npu — confirm the required include/link before shipping.
  auto new_tensor = at_npu::native::from_blob(data_ptr, sizes, strides, options);
  return new_tensor;
}

}  // namespace vllm_fl

csrc/cuda/CMakeLists.txt

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-FL project
#
# CUDA backend for vLLM-FL: builds the vllm_fl._C extension module.

# CMAKE_CUDA_ARCHITECTURES must be set BEFORE enable_language(CUDA) and
# before the target is created: a target's CUDA_ARCHITECTURES property is
# initialized from this variable at target-creation time, so the original
# placement (after Python_add_library) silently had no effect on _C.
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
  set(CMAKE_CUDA_ARCHITECTURES "70;75;80;86;89;90")
endif()

find_package(CUDAToolkit REQUIRED)
enable_language(CUDA)

message(STATUS "CUDA Toolkit: ${CUDAToolkit_VERSION}")

# =============================================================================
# Source files
# =============================================================================

set(VLLM_FL_CUDA_SRCS
  weak_ref_tensor.cu
  torch_bindings.cpp
)

# =============================================================================
# Define extension target
# =============================================================================

# Create Python extension module named _C
# This will be importable as: import vllm_fl._C
Python_add_library(_C MODULE WITH_SOABI ${VLLM_FL_CUDA_SRCS})

# Set TORCH_EXTENSION_NAME so TORCH_LIBRARY_EXPAND works.
# (Bare NAME=VALUE form; the "-D" prefix was legacy noise.)
target_compile_definitions(_C PRIVATE TORCH_EXTENSION_NAME=_C)

# Include directories: csrc/ (for registration.h), CUDA and Torch headers.
target_include_directories(_C PRIVATE
  "${CMAKE_CURRENT_SOURCE_DIR}/.."
  ${CUDAToolkit_INCLUDE_DIRS}
  ${TORCH_INCLUDE_DIRS}
)

# Link libraries. ${TORCH_LIBRARIES} (set by find_package(Torch)) pulls in
# torch, torch_cpu, c10, etc.; the bare name `torch` previously relied on
# the linker resolving it by search path alone.
target_link_libraries(_C PRIVATE
  ${TORCH_LIBRARIES}
  CUDA::cudart
  CUDA::cuda_driver
)

# CUDA settings
set_target_properties(_C PROPERTIES
  CUDA_STANDARD 17
  CUDA_STANDARD_REQUIRED ON
)

# Per-language optimization flags. The options inside the genex must be
# ;-separated and the genex quoted — the original space-separated form
# passed "-O3 --use_fast_math" to nvcc as a single malformed argument.
# NOTE(review): --use_fast_math trades accuracy for speed unconditionally;
# consider gating it behind an option if bit-exact results matter.
target_compile_options(_C PRIVATE
  "$<$<COMPILE_LANGUAGE:CUDA>:-O3;--use_fast_math>"
)

# =============================================================================
# Install to vllm_fl package directory
# =============================================================================

install(TARGETS _C LIBRARY DESTINATION vllm_fl COMPONENT _C)

csrc/cuda/torch_bindings.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM-FL project
//
// CUDA torch bindings for vLLM-FL operators.
// Defines the _C Python module and registers the ops it exposes.

#include <torch/torch.h>
#include <torch/library.h>

#include "registration.h"

namespace vllm_fl {

// Forward declarations of CUDA implementations.
// Must match the definition in weak_ref_tensor.cu exactly (it does: both
// take a non-const torch::Tensor&), or the extension fails to link.
torch::Tensor weak_ref_tensor_cuda(torch::Tensor& tensor);

}  // namespace vllm_fl

// Register extension for Python import: defines PyInit__C so that
// `import vllm_fl._C` can load this shared library.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

// Define operators under the extension's namespace. def() registers the
// schema; impl() binds the CUDA dispatch key to the implementation.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
  ops.impl("weak_ref_tensor", c10::kCUDA, &vllm_fl::weak_ref_tensor_cuda);

  // Add more operators here:
  // ops.def("another_op(Tensor input) -> Tensor");
  // ops.impl("another_op", c10::kCUDA, &vllm_fl::another_op_cuda);
}

csrc/cuda/weak_ref_tensor.cu

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (c) 2026 BAAI. All rights reserved.
// CUDA weak_ref_tensor implementation

#include <torch/torch.h>
#include <c10/cuda/CUDAGuard.h>

namespace vllm_fl {

// Returns a non-owning ("weak") alias of a CUDA tensor: a fresh tensor
// sharing the same device memory, sizes, strides and options, without
// taking ownership of the storage. The caller is responsible for keeping
// the source tensor alive while the alias is in use.
//
// Throws std::runtime_error if `tensor` is not on a CUDA device.
torch::Tensor weak_ref_tensor_cuda(torch::Tensor& tensor) {
  if (!tensor.is_cuda()) {
    throw std::runtime_error("Tensor must be on CUDA device");
  }

  // from_blob wraps the existing allocation in place: no copy is made and
  // the resulting tensor's deleter is a no-op, hence the "weak" semantics.
  return torch::from_blob(tensor.data_ptr(),
                          tensor.sizes().vec(),
                          tensor.strides().vec(),
                          tensor.options());
}

}  // namespace vllm_fl

csrc/registration.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// registration.h — shared preprocessor helpers for building a loadable
// torch extension module. Used by both vendor backends (csrc/cuda and
// csrc/ascend). Comments are kept above the macros: an inline // inside a
// continuation line would swallow the trailing backslash.

#pragma once

#include <Python.h>

// Token pasting with one extra level of macro expansion, so that
// CONCAT(PyInit_, NAME) works when NAME is itself a macro
// (e.g. TORCH_EXTENSION_NAME, injected by the build system).
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

// Stringification with one extra level of macro expansion, for the same
// reason as CONCAT above.
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement. It defines PyInit_<NAME> returning an
// otherwise-empty module object (no methods, no state); the torch operators
// are registered separately by the TORCH_LIBRARY blocks in the bindings TU.
#define REGISTER_EXTENSION(NAME)                                       \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                             \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,         \
                                        STRINGIFY(NAME), nullptr, 0,   \
                                        nullptr};                      \
    return PyModule_Create(&module);                                   \
  }

0 commit comments

Comments
 (0)