From 82968d4de568652629e27342d6b60c85adc29c36 Mon Sep 17 00:00:00 2001 From: Bruce Xue Date: Tue, 31 Dec 2024 09:55:32 +0800 Subject: [PATCH] separate sgl-kernel with amd backend --- python/sglang/srt/server.py | 5 +- sgl-kernel/amd/CMakeLists.txt | 51 +++++ sgl-kernel/amd/LICENSE | 201 ++++++++++++++++++ sgl-kernel/amd/Makefile | 22 ++ sgl-kernel/amd/setup.py | 103 +++++++++ sgl-kernel/amd/src/sgl-kernel/__init__.py | 7 + .../src/sgl-kernel/csrc/moe_align_kernel.cu | 135 ++++++++++++ .../amd/src/sgl-kernel/csrc/sgl_kernel_ops.cu | 11 + sgl-kernel/amd/src/sgl-kernel/csrc/utils.hpp | 36 ++++ sgl-kernel/amd/src/sgl-kernel/ops/__init__.py | 23 ++ sgl-kernel/setup.py | 100 +++------ 11 files changed, 625 insertions(+), 69 deletions(-) create mode 100644 sgl-kernel/amd/CMakeLists.txt create mode 100644 sgl-kernel/amd/LICENSE create mode 100644 sgl-kernel/amd/Makefile create mode 100644 sgl-kernel/amd/setup.py create mode 100644 sgl-kernel/amd/src/sgl-kernel/__init__.py create mode 100644 sgl-kernel/amd/src/sgl-kernel/csrc/moe_align_kernel.cu create mode 100644 sgl-kernel/amd/src/sgl-kernel/csrc/sgl_kernel_ops.cu create mode 100644 sgl-kernel/amd/src/sgl-kernel/csrc/utils.hpp create mode 100644 sgl-kernel/amd/src/sgl-kernel/ops/__init__.py diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index d95ce5931b5..15934b9603d 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -578,8 +578,9 @@ def _set_envs_and_config(server_args: ServerArgs): os.environ["NCCL_NVLS_ENABLE"] = "0" os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" - if "GLOO_SOCKET_IFNAME" not in os.environ: - os.environ["GLOO_SOCKET_IFNAME"] = "eth0" + #TODO(fix socket error with gpu backend) + #if "GLOO_SOCKET_IFNAME" not in os.environ: + # os.environ["GLOO_SOCKET_IFNAME"] = "eth0" # Set prometheus env vars if server_args.enable_metrics: diff --git a/sgl-kernel/amd/CMakeLists.txt b/sgl-kernel/amd/CMakeLists.txt new
file mode 100644 index 00000000000..974018c78e1 --- /dev/null +++ b/sgl-kernel/amd/CMakeLists.txt @@ -0,0 +1,51 @@ +cmake_minimum_required(VERSION 3.18) +project(sgl-kernel LANGUAGES CXX CUDA) + +# Basic settings +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) + +# Find PyTorch +execute_process( + COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" + OUTPUT_VARIABLE TORCH_CMAKE_PATH + OUTPUT_STRIP_TRAILING_WHITESPACE +) +list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}") + +find_package(Torch REQUIRED) + +# Kernel extension library: MoE align kernel plus the op bindings (must be SHARED to be importable) +add_library(_kernels SHARED + src/sgl-kernel/csrc/moe_align_kernel.cu + src/sgl-kernel/csrc/sgl_kernel_ops.cu +) + +target_include_directories(_kernels + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc + ${CUDA_INCLUDE_DIRS} + ${TORCH_INCLUDE_DIRS} +) + +target_link_libraries(_kernels + PRIVATE + ${TORCH_LIBRARIES} + Python3::Python +) + +# Set common properties for the kernel library (loop kept so further targets can be appended) +foreach(target _kernels) + set_target_properties(${target} PROPERTIES + CUDA_SEPARABLE_COMPILATION ON + POSITION_INDEPENDENT_CODE ON + CUDA_RESOLVE_DEVICE_SYMBOLS ON + PREFIX "" + SUFFIX ".so" + ) +endforeach() diff --git a/sgl-kernel/amd/LICENSE b/sgl-kernel/amd/LICENSE new file mode 100644 index 00000000000..9c422689c8f --- /dev/null +++ b/sgl-kernel/amd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License.
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023-2024 SGLang Team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/sgl-kernel/amd/Makefile b/sgl-kernel/amd/Makefile new file mode 100644 index 00000000000..7a041b1ed40 --- /dev/null +++ b/sgl-kernel/amd/Makefile @@ -0,0 +1,22 @@ +.PHONY: tree ln install build clean test format + +tree: + @tree --prune -I "__pycache__|*.egg-info|*.so|build" + +ln: + @rm -rf build && cmake . -DCMAKE_EXPORT_COMPILE_COMMANDS=1 -DCMAKE_CUDA_COMPILER=nvcc -B build && rm -rf compile_commands.json && ln -s build/compile_commands.json compile_commands.json + +install: + @pip install -e . 
+ +build: + @export MAX_JOBS=$$(nproc) && python3 setup.py bdist_wheel + +clean: + @rm -rf build dist *.egg-info + +test: + @pytest tests/ + +format: + @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black diff --git a/sgl-kernel/amd/setup.py b/sgl-kernel/amd/setup.py new file mode 100644 index 00000000000..a025a8a5332 --- /dev/null +++ b/sgl-kernel/amd/setup.py @@ -0,0 +1,103 @@ +import os +import shutil +import zipfile +from pathlib import Path + +import torch + + +def is_hip() -> bool: + """Return whether it is HIP on the AMD ROCm platform.""" + return torch.version.hip is not None + + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +root = Path(__file__).parent.resolve() + + +def get_version(): + with open(root / "pyproject.toml") as f: + for line in f: + if line.startswith("version"): + return line.split("=")[1].strip().strip('"') + + +def rename_wheel(): + if not os.environ.get("CUDA_VERSION"): + return + cuda_version = os.environ["CUDA_VERSION"].replace(".", "") + base_version = get_version() + + wheel_dir = Path("dist") + old_wheel = next(wheel_dir.glob("*.whl")) + tmp_dir = wheel_dir / "tmp" + tmp_dir.mkdir(exist_ok=True) + + with zipfile.ZipFile(old_wheel, "r") as zip_ref: + zip_ref.extractall(tmp_dir) + + old_info = tmp_dir / f"sgl_kernel-{base_version}.dist-info" + new_info = tmp_dir / f"sgl_kernel-{base_version}.post0+cu{cuda_version}.dist-info" + old_info.rename(new_info) + + platform = "manylinux2014_x86_64" + new_wheel = wheel_dir / old_wheel.name.replace("linux_x86_64", platform) + new_wheel = wheel_dir / new_wheel.name.replace( + base_version, f"{base_version}.post0+cu{cuda_version}" + ) + + with zipfile.ZipFile(new_wheel, "w", zipfile.ZIP_DEFLATED) as new_zip: + for file_path in tmp_dir.rglob("*"): + if file_path.is_file(): +
new_zip.write(file_path, file_path.relative_to(tmp_dir)) + + old_wheel.unlink() + shutil.rmtree(tmp_dir) + + +def update_wheel_platform_tag(): + wheel_dir = Path("dist") + old_wheel = next(wheel_dir.glob("*.whl")) + new_wheel = wheel_dir / old_wheel.name.replace( + "linux_x86_64", "manylinux2014_x86_64" + ) + old_wheel.rename(new_wheel) + + +hipcc_flags = [ + "-D__HIP_PLATFORM_AMD__=1", + "--amdgpu-target=gfx90a,gfx940,gfx941,gfx942", +] +ext_modules=[ + CUDAExtension( + "sgl_kernel.ops._kernels", + [ + "src/sgl-kernel/csrc/moe_align_kernel.cu", + "src/sgl-kernel/csrc/sgl_kernel_ops.cu", + ], + extra_compile_args={ + "nvcc": hipcc_flags + + [ + "-O3", + "-fPIC", + ], + "cxx": ["-O3"], + }, + libraries=["hiprtc", "amdhip64", "c10", "torch", "torch_python"], + extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"], + ), +] + +setup( + name="sgl-kernel", + version=get_version(), + packages=["sgl_kernel"], + package_dir={"": "src"}, + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, + install_requires=["torch"], +) + +update_wheel_platform_tag() diff --git a/sgl-kernel/amd/src/sgl-kernel/__init__.py b/sgl-kernel/amd/src/sgl-kernel/__init__.py new file mode 100644 index 00000000000..557a9712ac2 --- /dev/null +++ b/sgl-kernel/amd/src/sgl-kernel/__init__.py @@ -0,0 +1,7 @@ +from sgl_kernel.ops import ( + moe_align_block_size, +) + +__all__ = [ + "moe_align_block_size", +] diff --git a/sgl-kernel/amd/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/amd/src/sgl-kernel/csrc/moe_align_kernel.cu new file mode 100644 index 00000000000..dfd28032fb3 --- /dev/null +++ b/sgl-kernel/amd/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -0,0 +1,135 @@ +// Adapted from https://github.com/vllm-project/vllm/blob/v0.6.5/csrc/moe/moe_align_sum_kernels.cu + +#include <ATen/ATen.h> +#include <ATen/cuda/CUDAContext.h> +#include <c10/cuda/CUDAGuard.h> + +#include <THC/THCAtomics.cuh> + +#include "utils.hpp" + +#ifdef USE_ROCM +#include <hip/hip_runtime.h> +#endif + +#ifndef USE_ROCM +#define WARP_SIZE 32 +#else +#define WARP_SIZE warpSize +#endif + +#ifndef USE_ROCM +#define
DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) +#else +#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#endif + +#define CEILDIV(x, y) (((x) + (y)-1) / (y)) + +#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; +} + +template <typename scalar_t> +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, + int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads.
+ if (threadIdx.x < num_experts) { + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard.
+ */ + int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } +} + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, + torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, + torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // single-block launch with max(num_experts, WARP_SIZE) threads and no dynamic shared + // memory: `tokens_cnts` and `cumsum` are the caller-provided buffer tensors + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); + + auto kernel = moe_align_block_size_kernel<scalar_t>; + kernel<<<1, num_thread, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), + experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), + num_experts, block_size, topk_ids.numel(), + token_cnts_buffer.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>()); + }); +} diff --git a/sgl-kernel/amd/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/amd/src/sgl-kernel/csrc/sgl_kernel_ops.cu new file mode 100644 index 00000000000..59eb8a369fd --- /dev/null +++ b/sgl-kernel/amd/src/sgl-kernel/csrc/sgl_kernel_ops.cu @@ -0,0 +1,11 @@ +#include "utils.hpp" + +// moe_align_block_size +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, + torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, + torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // moe_align_block_size + m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); +} diff --git a/sgl-kernel/amd/src/sgl-kernel/csrc/utils.hpp b/sgl-kernel/amd/src/sgl-kernel/csrc/utils.hpp new file mode 100644 index 00000000000..bbdc6311be9 --- /dev/null +++
b/sgl-kernel/amd/src/sgl-kernel/csrc/utils.hpp @@ -0,0 +1,36 @@ +#pragma once +#include <torch/extension.h> + +#include <sstream> + +struct cuda_error : public std::runtime_error { + /** + * @brief Constructs a `cuda_error` object with the given `message`. + * + * @param message The error char array used to construct `cuda_error` + */ + cuda_error(const char* message) : std::runtime_error(message) {} + /** + * @brief Constructs a `cuda_error` object with the given `message` string. + * + * @param message The `std::string` used to construct `cuda_error` + */ + cuda_error(std::string const& message) : cuda_error{message.c_str()} {} +}; + +#define CHECK_CUDA_SUCCESS(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + std::stringstream _message; \ + auto s = cudaGetErrorString(e); \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + throw cuda_error(_message.str()); \ + } \ + } while (0) + +#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT(x) \ + CHECK_IS_CUDA(x); \ + CHECK_IS_CONTIGUOUS(x) diff --git a/sgl-kernel/amd/src/sgl-kernel/ops/__init__.py b/sgl-kernel/amd/src/sgl-kernel/ops/__init__.py new file mode 100644 index 00000000000..6773acf44d6 --- /dev/null +++ b/sgl-kernel/amd/src/sgl-kernel/ops/__init__.py @@ -0,0 +1,23 @@ +from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size + + +def moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + token_cnts_buffer, + cumsum_buffer, +): + _moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + token_cnts_buffer, + cumsum_buffer, + ) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index f52a85d377b..a8fc9d82737 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -3,14 +3,6 @@ import
zipfile from pathlib import Path -import torch - - -def is_hip() -> bool: - """Return whether it is HIP on the AMD ROCm platform.""" - return torch.version.hip is not None - - from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension @@ -66,64 +58,38 @@ def update_wheel_platform_tag(): old_wheel.rename(new_wheel) -if not is_hip(): - nvcc_flags = [ - "-O3", - "-Xcompiler", - "-fPIC", - "-gencode=arch=compute_75,code=sm_75", - "-gencode=arch=compute_80,code=sm_80", - "-gencode=arch=compute_89,code=sm_89", - "-gencode=arch=compute_90,code=sm_90", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - ] - cxx_flags = ["-O3"] - libraries = ["c10", "torch", "torch_python"] - extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"] - ext_modules = [ - CUDAExtension( - name="sgl_kernel.ops._kernels", - sources=[ - "src/sgl-kernel/csrc/warp_reduce_kernel.cu", - "src/sgl-kernel/csrc/trt_reduce_internal.cu", - "src/sgl-kernel/csrc/trt_reduce_kernel.cu", - "src/sgl-kernel/csrc/moe_align_kernel.cu", - "src/sgl-kernel/csrc/sgl_kernel_ops.cu", - ], - extra_compile_args={ - "nvcc": nvcc_flags, - "cxx": cxx_flags, - }, - libraries=libraries, - extra_link_args=extra_link_args, - ), - ] -else: - hipcc_flags = [ - "-D__HIP_PLATFORM_AMD__=1", - "--amdgpu-target=gfx90a,gfx940,gfx941,gfx942", - ] - ext_modules=[ - CUDAExtension( - "sgl_kernel.ops.moe_align_block_size", - [ - "src/sgl-kernel/csrc/moe_align_kernel.cu", - "src/sgl-kernel/csrc/sgl_kernel_ops.cu", - ], - extra_compile_args={ - "nvcc": hipcc_flags - + [ - "-O3", - "-Xcompiler", - "-fPIC", - ], - "cxx": ["-O3"], - }, - libraries=["hiprtc", "amdhip64", "c10", "torch", "torch_python"], - extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"], - ), - ] +nvcc_flags = [ + "-O3", + "-Xcompiler", + "-fPIC", + "-gencode=arch=compute_75,code=sm_75", + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_89,code=sm_89", + "-gencode=arch=compute_90,code=sm_90", + 
"-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF2_OPERATORS__", +] +cxx_flags = ["-O3"] +libraries = ["c10", "torch", "torch_python"] +extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"] +ext_modules = [ + CUDAExtension( + name="sgl_kernel.ops._kernels", + sources=[ + "src/sgl-kernel/csrc/warp_reduce_kernel.cu", + "src/sgl-kernel/csrc/trt_reduce_internal.cu", + "src/sgl-kernel/csrc/trt_reduce_kernel.cu", + "src/sgl-kernel/csrc/moe_align_kernel.cu", + "src/sgl-kernel/csrc/sgl_kernel_ops.cu", + ], + extra_compile_args={ + "nvcc": nvcc_flags, + "cxx": cxx_flags, + }, + libraries=libraries, + extra_link_args=extra_link_args, + ), +] setup( name="sgl-kernel", @@ -135,4 +101,4 @@ def update_wheel_platform_tag(): install_requires=["torch"], ) -update_wheel_platform_tag() +update_wheel_platform_tag() \ No newline at end of file