diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt
index 600c781..5806787 100644
--- a/csrc/CMakeLists.txt
+++ b/csrc/CMakeLists.txt
@@ -2,32 +2,146 @@ cmake_minimum_required(VERSION 3.22)
 project(PPLXKernels
     VERSION 0.0.1
     DESCRIPTION "PPLX Kernels"
-    LANGUAGES CXX CUDA)
+    LANGUAGES CXX)
+
+
+set(ROCM_HOME "/opt/rocm" CACHE PATH "ROCm SDK installation home")
+if (NOT IS_DIRECTORY ${ROCM_HOME})
+    message(WARNING "ROCM_HOME ${ROCM_HOME} is not a directory")
+endif()
+
+if (LINUX)
+    # The HIP CMake config files read the SDK root from the environment:
+    # Linux defaults to ENV{ROCM_PATH}, WIN32 defaults to ENV{HIP_PATH}.
+    set(ENV{ROCM_PATH} ${ROCM_HOME})
+endif()
+
+if(NOT DEFINED HIP_CMAKE_PATH)
+    if(NOT DEFINED ENV{HIP_CMAKE_PATH})
+        # NOTE(yiakwy) : find_package(HIP) first searches
+        #   cmake/Modules/FindAMDDeviceLibs.cmake
+        # , then
+        #   /opt/rocm/lib/cmake/AMDDeviceLibs/AMDDeviceLibsConfig.cmake
+        # which adds the hip::host and hip::device dependencies linked by any HIP target (ROCm >= 6.x).
+        # Add hip-config.cmake to the CMake module search path.
+        # set(HIP_CMAKE_PATH "${ROCM_HOME}/share/rocm/cmake" "${ROCM_HOME}/share/rocmcmakebuildtools/cmake/" CACHE PATH "Path to which HIP has been installed")
+        # NOTE(yiakwy) : adding ${ROCM_HOME}/lib/cmake/hip conflicts with 3rdparty/mscclpp
+        set(ROCSHMEM_LIB_DIR "/root/rocshmem")
+        set(ROCSHMEM_CMAKE_PATH "${ROCSHMEM_LIB_DIR}/lib/cmake/rocshmem")
+
+        set(HIP_CMAKE_PATH
+            # NOTE (yiakwy) : by default rocSHMEM installs into the local directory ~/rocshmem
+            "${ROCSHMEM_CMAKE_PATH}"
+            "${ROCM_HOME}/lib/cmake/AMDDeviceLibs"
+            "${ROCM_HOME}/lib/cmake/amd_comgr"
+            "${ROCM_HOME}/lib/cmake/hsa-runtime64"
+            "${ROCM_HOME}/lib/cmake/hipcub"
+            "${ROCM_HOME}/lib/cmake/rccl"
+            "${ROCM_HOME}/lib/cmake/composable_kernel" CACHE PATH "Path to which HIP has been installed")
+        message(WARNING "Environment variable HIP_CMAKE_PATH is not set; defaulting to ${HIP_CMAKE_PATH}")
+
+        set(CMAKE_PREFIX_PATH "${ROCM_HOME};${ROCM_HOME}/lib/cmake/hip;${ROCSHMEM_CMAKE_PATH};${CMAKE_PREFIX_PATH}")
+    else()
+        set(HIP_CMAKE_PATH $ENV{HIP_CMAKE_PATH} CACHE PATH "Path to which HIP has been installed")
+    endif()
+
+    set(CMAKE_MODULE_PATH "${HIP_CMAKE_PATH}" ${CMAKE_MODULE_PATH})
+
+endif()
+
+add_definitions(-Wall)
+find_package(HIP QUIET)
+if(HIP_FOUND)
+    message(STATUS "Found HIP: " ${HIP_VERSION})
+    execute_process(COMMAND bash -c "${ROCM_HOME}/bin/rocminfo | grep -o -m1 'gfx.*'"
+                    OUTPUT_VARIABLE CMAKE_HIP_ARCHITECTURES OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+    message(STATUS "CMAKE_HIP_ARCHITECTURES : ${CMAKE_HIP_ARCHITECTURES}")
+
+    enable_language(HIP)
+
+    add_definitions(-DUSE_ROCM=1)
+
+    if (NOT DEFINED CMAKE_CXX_COMPILER)
+        find_program(CMAKE_CXX_COMPILER hipcc PATHS /opt/rocm)
+    endif()
+
+    # NOTE (yiakwy) : the modern way to include the ROCm CMake tooling
+    find_package(ROCmCMakeBuildTools PATHS /opt/rocm)
+    include(ROCMCreatePackage)
+    include(ROCMInstallTargets)
+    include(ROCMCheckTargetIds)
+
+    # NOTE (yiakwy) : include rocSHMEM
+    if(NOT TARGET roc::rocshmem)
+        # find_package(rocshmem REQUIRED)
+        find_package(rocshmem REQUIRED PATHS /root) #${ROCSHMEM_LIB_DIR})
+    endif()
+
+else()
+    message(WARNING "Could not find HIP. 
Ensure that the ROCm SDK is installed in /opt/rocm or that HIP_CMAKE_PATH points to the right location.")
+endif()
+
+
+find_package(CUDA QUIET)
+if (CUDA_FOUND)
+    message(STATUS "Found CUDA: " ${CUDA_TOOLKIT_ROOT_DIR})
+
+    execute_process(COMMAND bash -c "/usr/bin/nvidia-smi --query-gpu=compute_cap --format=csv,noheader | grep -o -m1 '[0-9.]*'"
+                    OUTPUT_VARIABLE CUDA_ARCHITECTURES OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+    set(CUDA_SUPPORTED_ARCHS "9.0")
+
+    find_package(CUDAToolkit REQUIRED)
+    set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON)
+
+    set(CMAKE_CUDA_ARCHITECTURES 90a CACHE STRING "CUDA architecture to target")
+    set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
+
+    find_package(NVSHMEM REQUIRED)
+else()
+    message(WARNING "Could not find CUDA.")
+endif()
+
+if (NOT (HIP_FOUND) AND NOT (CUDA_FOUND))
+    message(FATAL_ERROR "Either the ROCm or the CUDA SDK is required")
+endif()
+
 # === Configuration options ===
 option(WITH_TESTS "Build tests" OFF)
 option(WITH_BENCHMARKS "Build benchmarks" OFF)
-set(CMAKE_CUDA_ARCHITECTURES 90a CACHE STRING "CUDA architecture to target")
 
 # === CMake configuration ===
-set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CXX_STANDARD_REQUIRED ON)
-set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(CMAKE_INCLUDE_CURRENT_DIR ON)
+set(CMAKE_CXX_EXTENSIONS OFF)
+set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_FLAGS_DEBUG "-g -ggdb -O0")
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE BOOL "" FORCE)
+
 # === Dependencies ===
 include(FetchContent)
-find_package(CUDAToolkit REQUIRED) # Modern replacement for find_package(CUDA)
 find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
+
+include(cmake/py_helper.cmake)
+append_torch_cmake_prefix_path()
 find_package(Torch REQUIRED)
-find_package(NVSHMEM REQUIRED)
+
+find_package(MPI REQUIRED)
 
 if(WITH_TESTS)
     enable_testing()
-    find_package(MPI REQUIRED)
     find_package(PkgConfig REQUIRED)
-    pkg_check_modules(NCCL nccl)
+
+    if (HIP_FOUND)
+        pkg_check_modules(RCCL rccl)
+    else()
+        pkg_check_modules(NCCL nccl)
+    endif()
 endif()
 
 # Create imported target for PyTorch
@@ -44,33 +158,38 @@ add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1)
 add_compile_definitions(Py_LIMITED_API=0x03090000)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 
-# CUDA-specific compile options function
-function(set_cuda_compile_options target)
-    target_compile_options(${target} PRIVATE
-        $<$<COMPILE_LANGUAGE:CUDA>:--threads=32 -O3>)
-endfunction()
-
 # === Library targets ===
+if (HIP_FOUND)
+    include(cmake/rocm_helper.cmake)
+endif()
 add_subdirectory(all_to_all)
 add_subdirectory(core)
 
 # Main shared library
-add_library(pplx_kernels SHARED
-    bindings/all_to_all_ops.cpp
-    bindings/bindings.cpp
-)
-target_link_libraries(pplx_kernels PUBLIC
-    all_to_all_internode_lib
-    all_to_all_intranode_lib
-    core_lib
-    torch::py_limited
-    Python::Module
-    CUDA::cuda_driver
-    CUDA::cudart
-    nvshmem::nvshmem_host
-    nvshmem::nvshmem_device
-)
-set_target_properties(pplx_kernels PROPERTIES
-    LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../src/pplx_kernels
-    CUDA_SEPARABLE_COMPILATION ON
-)
+# NOTE (yiakwy) : TODO
+# add_library(pplx_kernels SHARED
+#     bindings/all_to_all_ops.cpp
+#     bindings/bindings.cpp
+# )
+
+# target_link_libraries(pplx_kernels PUBLIC
+#     all_to_all_internode_lib
+#     all_to_all_intranode_lib
+#     core_lib
+#     torch::py_limited
+#     Python::Module
+
+#     # CUDA::cuda_driver
+#     # CUDA::cudart
+
+#     roc::rocshmem
+
+#     # nvshmem::nvshmem_host
+#     # 
nvshmem::nvshmem_device +# ) + +# set_target_properties(pplx_kernels PROPERTIES +# LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../src/pplx_kernels +# CUDA_SEPARABLE_COMPILATION ON +# ) diff --git a/csrc/all_to_all/CMakeLists.txt b/csrc/all_to_all/CMakeLists.txt index 9cd333f..d732e9a 100644 --- a/csrc/all_to_all/CMakeLists.txt +++ b/csrc/all_to_all/CMakeLists.txt @@ -1,61 +1,105 @@ # All-to-All library - -add_library(all_to_all_common STATIC +set(all_to_all_srcs all_to_all.cpp ) -target_link_libraries(all_to_all_common PUBLIC - CUDA::cudart -) - -add_library(all_to_all_intranode_lib STATIC +set(all_to_all_intranode_lib_srcs intranode_combine.cu intranode_dispatch.cu intranode.cpp ) -target_link_libraries(all_to_all_intranode_lib PUBLIC - all_to_all_common - CUDA::cudart -) -target_link_libraries(all_to_all_intranode_lib INTERFACE - nvshmem::nvshmem_host -) -target_include_directories(all_to_all_intranode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR}) -set_cuda_compile_options(all_to_all_intranode_lib) -add_library(all_to_all_internode_lib STATIC +set(all_to_all_internode_lib_srcs internode_combine.cu internode_dispatch.cu internode.cpp ) -target_link_libraries(all_to_all_internode_lib PUBLIC - all_to_all_common - CUDA::cudart -) -target_link_libraries(all_to_all_internode_lib INTERFACE - nvshmem::nvshmem_host -) -target_include_directories(all_to_all_internode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR}) -set_cuda_compile_options(all_to_all_internode_lib) -if(WITH_TESTS) - # All-to-All test - add_executable(test_all_to_all - test_all_to_all.cpp +if (HIP_FOUND) + rocshmem_add_library(all_to_all_common + "${all_to_all_srcs}" ) - target_link_libraries(test_all_to_all PUBLIC - all_to_all_intranode_lib - all_to_all_internode_lib - core_lib + + # add all_to_all_intranode_lib + rocshmem_add_library(all_to_all_intranode_lib + "${all_to_all_intranode_lib_srcs}" + DEPS + all_to_all_common + ) + + rocshmem_add_library(all_to_all_internode_lib + "${all_to_all_internode_lib_srcs}" + DEPS + all_to_all_common + ) +else() + add_library(all_to_all_common STATIC + ${all_to_all_srcs} + ) + + target_link_libraries(all_to_all_common PUBLIC CUDA::cudart - CUDA::cuda_driver - MPI::MPI_CXX + ) + + # add all_to_all_intranode_lib + add_library(all_to_all_intranode_lib STATIC + "${all_to_all_intranode_lib_srcs}" + ) + target_link_libraries(all_to_all_intranode_lib PUBLIC + all_to_all_common + CUDA::cudart + ) + target_link_libraries(all_to_all_intranode_lib INTERFACE nvshmem::nvshmem_host ) - set_cuda_compile_options(test_all_to_all) - add_test(NAME AllToAllTest - COMMAND ${MPIEXEC_EXECUTABLE} -np 4 $) - set_tests_properties(AllToAllTest PROPERTIES ENVIRONMENT "NVSHMEM_REMOTE_TRANSPORT=None") + target_include_directories(all_to_all_intranode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR}) + set_cuda_compile_options(all_to_all_intranode_lib) + + # add all_to_all_internode_lib + add_library(all_to_all_internode_lib STATIC + "${all_to_all_internode_lib_srcs}" + ) + target_link_libraries(all_to_all_internode_lib PUBLIC + all_to_all_common + CUDA::cudart + ) + target_link_libraries(all_to_all_internode_lib INTERFACE + nvshmem::nvshmem_host + ) + target_include_directories(all_to_all_internode_lib PRIVATE ${NVSHMEM_INCLUDE_DIR}) + set_cuda_compile_options(all_to_all_internode_lib) +endif() + +if(WITH_TESTS OR HIP_FOUND) # OR HIP_FOUND + + if (HIP_FOUND) + rocshmem_add_executable(test_all_to_all + "internode.h;intranode.h;test_all_to_all.cpp" + DEPS + all_to_all_intranode_lib + # all_to_all_internode_lib + core_lib + roctx64 + ) + else() + # 
All-to-All test + add_executable(test_all_to_all + test_all_to_all.cpp + ) + target_link_libraries(test_all_to_all PUBLIC + all_to_all_intranode_lib + all_to_all_internode_lib + core_lib + CUDA::cudart + CUDA::cuda_driver + MPI::MPI_CXX + nvshmem::nvshmem_host + ) + set_cuda_compile_options(test_all_to_all) + add_test(NAME AllToAllTest + COMMAND ${MPIEXEC_EXECUTABLE} -np 4 $) + set_tests_properties(AllToAllTest PROPERTIES ENVIRONMENT "NVSHMEM_REMOTE_TRANSPORT=None") + endif() endif() if (WITH_BENCHMARKS) diff --git a/csrc/all_to_all/internode.cpp b/csrc/all_to_all/internode.cpp index b73e547..0373a4a 100644 --- a/csrc/all_to_all/internode.cpp +++ b/csrc/all_to_all/internode.cpp @@ -1,5 +1,9 @@ +#ifdef USE_ROCM +#include "core/hip_dist_defs.h" +#else #include +#endif #include #include diff --git a/csrc/all_to_all/internode.h b/csrc/all_to_all/internode.h index 28aa939..546bc76 100644 --- a/csrc/all_to_all/internode.h +++ b/csrc/all_to_all/internode.h @@ -2,7 +2,12 @@ #include #include + +#ifdef USE_ROCM +#include "core/hip_cuda_dtype_defs.h" +#else #include +#endif // USE_ROCM #include "all_to_all/all_to_all.h" #include "core/buffer.h" diff --git a/csrc/all_to_all/internode_combine.cu b/csrc/all_to_all/internode_combine.cu index 0e02a06..e2ff2c0 100644 --- a/csrc/all_to_all/internode_combine.cu +++ b/csrc/all_to_all/internode_combine.cu @@ -2,14 +2,39 @@ #include "core/utils.h" #include "internode.h" +#ifdef USE_ROCM + +#include + +#include + +#include "core/hip_defs.h" +#include "core/hip_cuda_dtype_defs.h" +#include "core/hip_dist_defs.h" +#include "core/hip_roctx_defs.h" + +#include + +#else + #include + #include #include +#endif // USE_ROCM + +#include "core/common_utils.h" + using namespace pplx; +#if USE_ROCM +using namespace rocshmem; +#define __ldg(...) __builtin_nontemporal_load(__VA_ARGS__) +#endif + template -__global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel( +__global__ __launch_bounds__(NUM_WARPS * WARP_SIZE, 1) void combineKernel( U *outTokens, size_t outTokensStrideElem, const uint32_t *indices, @@ -42,15 +67,29 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel( ) { const unsigned numLocalExperts = numExperts / worldSize; const size_t stride = hiddenDim * sizeof(T); - constexpr unsigned WARP_SIZE = 32; + uint32_t warpId = threadIdx.x / WARP_SIZE; + uint32_t laneId = threadIdx.x % WARP_SIZE; + + uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; if (DO_SEND) { + const size_t numSendTokens = __ldg(&globalTokenIndex); + +#if USE_ROCM + if (tid < worldSize) { + for (int i= tid; i < worldSize; i += gridDim.x * blockDim.x) { + rocshmem_ulonglong_p((unsigned long long *)&combineSyncBuffer[rank], 1U, i); + } + } + rocshmem_fence(); +#else for (unsigned i = blockIdx.x * blockDim.x + threadIdx.x; i < worldSize; i += gridDim.x * blockDim.x) { nvshmemx_signal_op(&combineSyncBuffer[rank], 1, NVSHMEM_SIGNAL_SET, i); } +#endif // USE_ROCM // Dispatch the tokens from the expert to the DP groups. 
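+      // NOTE : numSendTokens is latched once via __ldg above; each block then walks the
+      // token list in a grid-stride loop, so blocks send disjoint tokens.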
for (uint32_t token = blockIdx.x; token < numSendTokens; token += gridDim.x) { @@ -70,7 +109,12 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel( const unsigned n = stride / sizeof(int4); #pragma unroll(4) for (unsigned j = threadIdx.x; j < n; j += blockDim.x) { +#if USE_ROCM + // + *dstPtr = *srcPtr; +#else *dstPtr = __ldg(srcPtr); +#endif // USE_ROCM dstPtr += blockDim.x; srcPtr += blockDim.x; } @@ -85,24 +129,42 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel( const int dstRank = dp * dpSize + i; const unsigned index = dstExpert * maxNumTokens + source; std::byte *dstPtr = xBufferOut + index * stride; +#if USE_ROCM + rocshmem::rocshmem_putmem_signal_nbi_wave( + dstPtr, xTokenPtr, stride, &combineSignalBuffer[source], 1, rocshmem::ROCSHMEM_SIGNAL_ADD, dstRank + ); +#else nvshmemx_putmem_signal_nbi_warp( dstPtr, xTokenPtr, stride, &combineSignalBuffer[source], 1, NVSHMEM_SIGNAL_ADD, dstRank ); +#endif // USE_ROCM } } } - // Synchronize the grid to ensure that tokens routed within the rank are - // correctly transported from one block to another. if (DO_RECV) { if (DO_SEND) { + // Synchronize the grid to ensure that tokens routed within the rank are + // correctly transported from one block to another. cooperative_groups::this_grid().sync(); +#if USE_ROCM + // NOTE (yiakwy) : for AMD GPU we explicitly make sure IO fence inserted + rocshmem::rocshmem_fence(); + rocshmem::rocshmem_barrier_all_wg(); +#endif // USE_ROCM } // Compute the weighed sum of the input tokens. const size_t localNumTokens = boundM ? __ldg(boundM) : m; for (unsigned i = blockIdx.x; i < localNumTokens; i += gridDim.x) { + +#if USE_ROCM + // TODO (yiakwy) : add uint64 interface + rocshmem::rocshmem_ulonglong_wait_until((unsigned long long *)&combineSignalBuffer[i], rocshmem::ROCSHMEM_CMP_EQ, expertsPerToken); +#else nvshmem_uint64_wait_until(&combineSignalBuffer[i], NVSHMEM_CMP_EQ, expertsPerToken); +#endif + __syncthreads(); combineSignalBuffer[i] = 0; @@ -136,7 +198,13 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel( for (unsigned i = blockIdx.x * blockDim.x + threadIdx.x; i < worldSize; i += gridDim.x * blockDim.x) { + +#if USE_ROCM + rocshmem::rocshmem_ulonglong_wait_until((unsigned long long *)&combineSignalBuffer[i], rocshmem::ROCSHMEM_CMP_EQ, 1U); +#else nvshmem_uint64_wait_until(&combineSyncBuffer[i], NVSHMEM_CMP_EQ, 1); +#endif // USE_ROCM + combineSyncBuffer[i] = 0; } @@ -167,7 +235,7 @@ void AllToAllInterNode::combine( assert(hiddenDimBytes % 16 == 0); dim3 dimGrid(numBlocks, 1, 1); - dim3 dimBlock(NUM_WARPS * 32, 1, 1); + dim3 dimBlock(NUM_WARPS * WARP_SIZE, 1, 1); void *args[] = { const_cast(&outTokens.data), diff --git a/csrc/all_to_all/internode_dispatch.cu b/csrc/all_to_all/internode_dispatch.cu index d7374e6..68b2463 100644 --- a/csrc/all_to_all/internode_dispatch.cu +++ b/csrc/all_to_all/internode_dispatch.cu @@ -1,18 +1,43 @@ +#include "core/nvshmem_utils.h" +#include "core/utils.h" +#include "internode.h" + +#ifdef USE_ROCM + +#include + +#include + +#include "core/hip_defs.h" +#include "core/hip_cuda_dtype_defs.h" +#include "core/hip_dist_defs.h" +#include "core/hip_roctx_defs.h" + +#else + #include #include #include #include +#endif // USE_ROCM + #include "all_to_all/internode.h" +#include "core/common_utils.h" #include "core/device_utils.cuh" #include "core/utils.h" using namespace pplx; -namespace { +#ifdef USE_ROCM +using namespace rocshmem; +#define __ldg(...) 
__builtin_nontemporal_load(__VA_ARGS__) +#endif + +// namespace { template -__global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( +__global__ __launch_bounds__(NUM_WARPS * WARP_SIZE, 1) void dispatchKernel( int32_t *outNumTokensPerExpert, size_t outNumTokensPerExpertStrideElem, std::byte *expertX, @@ -59,7 +84,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( const unsigned dpRank = rank % dpSize; const unsigned tokenDim = hiddenDim + hiddenDimScale; const unsigned tokenStride = round_up(tokenDim + sizeof(uint32_t), sizeof(int4)); - const unsigned WARP_SIZE = 32; + // const unsigned WARP_SIZE = 32; const unsigned warpId = threadIdx.x / WARP_SIZE; const unsigned laneId = threadIdx.x % WARP_SIZE; @@ -101,7 +126,13 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( uint64_t *dstCount = &numTokensBuffer[dstLocalExpert * numDPGroups + dpGroup]; if (laneId == 0) { + +#ifdef USE_ROCM + rocshmem_ulong_p((unsigned long *)dstCount, numTokensPerExpert + 1, dstRank); +#else nvshmemx_signal_op(dstCount, numTokensPerExpert + 1, NVSHMEM_SIGNAL_SET, dstRank); +#endif // USE_ROCM + } } @@ -141,7 +172,11 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( } // Synchronize the warps within this warp group. +#if USE_ROCM + __syncthreads(); +#else asm volatile("bar.sync 1, %0;" ::"r"(numGroupThreads)); +#endif // USE_ROCM // Send the token to the other ranks, one send per warp. for (unsigned j = warpId; j < numExpertsPerToken; j += numGroupWarps) { @@ -154,6 +189,17 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( const unsigned loc = group * maxNumTokens + index; std::byte *destPointer = xBufferOut + loc * tokenStride; +#if USE_ROCM + rocshmem_putmem_signal_nbi_wave( + destPointer, + xInPtr, + tokenStride, + &numRecvBuffer[group], + 1, + ROCSHMEM_SIGNAL_ADD, + dstRank + ); +#else nvshmemx_putmem_signal_nbi_warp( destPointer, xInPtr, @@ -163,6 +209,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( NVSHMEM_SIGNAL_ADD, dstRank ); +#endif // USE_ROCM } } } @@ -188,9 +235,19 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( // Fetch the token count per DP, which is non-zero to indicate receipt. // Afterwards, wait for exactly that many tokens to be sent to us. 
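+      // NOTE : the sender publishes count + 1 so that zero unambiguously means "nothing
+      // received yet"; the actual token count is recovered below by subtracting one.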
+#if USE_ROCM + rocshmem_ulong_wait_until((unsigned long *)&numTokensBuffer[group], ROCSHMEM_CMP_NE, 0); +#else nvshmem_uint64_wait_until(&numTokensBuffer[group], NVSHMEM_CMP_NE, 0); +#endif // USE_ROCM + size_t numTokens = numTokensBuffer[group] - 1; + +#if USE_ROCM + rocshmem_ulong_wait_until((unsigned long *)&numRecvBuffer[group], ROCSHMEM_CMP_EQ, numTokens); +#else nvshmem_uint64_wait_until(&numRecvBuffer[group], NVSHMEM_CMP_EQ, numTokens); +#endif // USE_ROCM numTokensPerDP[group] = numTokens; numTokensBuffer[group] = 0; @@ -249,7 +306,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( } } -} // namespace +// } // namespace void AllToAllInterNode::dispatch( const Strided1D &outNumTokensPerExpert, @@ -262,7 +319,8 @@ void AllToAllInterNode::dispatch( const unsigned *boundM, SplitMode splitMode, cudaStream_t stream -) { +) +{ constexpr unsigned NUM_WARPS = 10; const unsigned numBlocks = std::min( std::max( @@ -271,7 +329,7 @@ void AllToAllInterNode::dispatch( static_cast(numSMs) ); dim3 dimGrid(numBlocks, 1, 1); - dim3 dimBlock(NUM_WARPS * 32, 1, 1); + dim3 dimBlock(NUM_WARPS * WARP_SIZE, 1, 1); const size_t expertsPerBlock = ceil_div(numLocalExperts * numDPGroups, numBlocks); const size_t sharedMemorySend = sizeof(uint32_t) * numExperts; diff --git a/csrc/all_to_all/intranode.h b/csrc/all_to_all/intranode.h index 871347e..6bbb33b 100644 --- a/csrc/all_to_all/intranode.h +++ b/csrc/all_to_all/intranode.h @@ -6,7 +6,11 @@ #include #include +#ifdef USE_ROCM +#include "core/hip_cuda_dtype_defs.h" +#else #include +#endif // USE_ROCM #include "all_to_all/all_to_all.h" #include "core/buffer.h" diff --git a/csrc/all_to_all/intranode_combine.cu b/csrc/all_to_all/intranode_combine.cu index 54f9102..2ab2d01 100644 --- a/csrc/all_to_all/intranode_combine.cu +++ b/csrc/all_to_all/intranode_combine.cu @@ -1,20 +1,37 @@ #include "all_to_all/intranode.cuh" + #include "core/atomic.cuh" + #include "core/device_utils.cuh" #include "core/utils.h" #include "intranode.h" #include +#ifdef USE_ROCM + +#include + +#include + +#include "core/hip_defs.h" +#include "core/hip_roctx_defs.h" + +#else #include #include +#endif // USE_ROCM using namespace pplx; +#if USE_ROCM +// #define __ldg(...) 
__builtin_nontemporal_load(__VA_ARGS__) +#endif + namespace { template -__global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel( +__global__ __launch_bounds__(NUM_WARPS * WARP_SIZE, 1) void combineKernel( U *outTokens, size_t outTokensStrideElem, uint32_t *indices, @@ -48,7 +65,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel( const unsigned numLocalExperts = numExperts / worldSize; const size_t tokenDim = hiddenDim * sizeof(T); const size_t tokenStride = round_up(tokenDim, sizeof(int4)); - constexpr unsigned WARP_SIZE = 32; + // constexpr unsigned WARP_SIZE = 32; uint32_t warpId = threadIdx.x / WARP_SIZE; const unsigned laneId = threadIdx.x % WARP_SIZE; @@ -208,7 +225,7 @@ void AllToAllIntraNode::combine( assert(hiddenDimBytes % 16 == 0); dim3 dimGrid(numBlocks, 1, 1); - dim3 dimBlock(NUM_WARPS * 32, 1, 1); + dim3 dimBlock(NUM_WARPS * WARP_SIZE, 1, 1); void *args[] = { const_cast(&outTokens.data), diff --git a/csrc/all_to_all/intranode_dispatch.cu b/csrc/all_to_all/intranode_dispatch.cu index 402773b..0175aab 100644 --- a/csrc/all_to_all/intranode_dispatch.cu +++ b/csrc/all_to_all/intranode_dispatch.cu @@ -4,10 +4,27 @@ #include "core/utils.h" #include "intranode.h" +#ifdef USE_ROCM + +#include + +#include + +#include "core/hip_defs.h" +#include "core/hip_roctx_defs.h" + +#define FULL_MASK 0xffffffffffffffff + +#else + #include #include #include +#define FULL_MASK 0xffffffff + +#endif // USE_ROCM + using namespace pplx; namespace { @@ -125,7 +142,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( __syncthreads(); // Send the token to the other ranks, one send per warp. - const unsigned WARP_SIZE = 32; + // const unsigned WARP_SIZE = 32; const unsigned warpId = threadIdx.x / WARP_SIZE; const unsigned laneId = threadIdx.x % WARP_SIZE; for (unsigned j = warpId; j < numExpertsPerToken; j += NUM_WARPS) { @@ -141,7 +158,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel( } else { index = 0; } - index = __shfl_sync(0xffffffff, index, 0); + index = __shfl_sync(FULL_MASK, index, 0); // Copy the token to the shared buffer. std::byte *buffer = remoteBuffer.getTokenPtr(dstRank, dstLocalExpert, index); diff --git a/csrc/all_to_all/test_all_to_all.cpp b/csrc/all_to_all/test_all_to_all.cpp index f549416..ae4397e 100644 --- a/csrc/all_to_all/test_all_to_all.cpp +++ b/csrc/all_to_all/test_all_to_all.cpp @@ -1,10 +1,20 @@ // All-to-all kernel test +#ifdef USE_ROCM +#include +#include +#include +#include + +#include "core/nvshmem_utils.h" + +#else #include #include #include #include #include +#endif // USE_ROCM #include #include @@ -12,6 +22,7 @@ #include #include "all_to_all/internode.h" + #include "all_to_all/intranode.h" #include "all_to_all/test_utils.h" #include "core/buffer.h" @@ -346,9 +357,52 @@ int main(int argc, char **argv) { MPICHECK(MPI_Finalize()); return EXIT_FAILURE; } + + MPI_Comm mpi_comm = MPI_COMM_WORLD; + +#if USE_ROCM + using namespace rocshmem; + + rocshmem_init_attr_t attr; + attr.mpi_comm = &mpi_comm; + + NVSHMEMCHECK(rocshmem_init_attr(ROCSHMEM_INIT_WITH_UNIQUEID, &attr)); + // roc_shmem_init(); + + int currentPE = rocshmem_my_pe(); + int numPEs = rocshmem_n_pes(); + + CUDACHECK(hipSetDevice(rank)); + + hipStream_t stream; + CUDACHECK(hipStreamCreate(&stream)); + + // Run the tests. + int exit_code = EXIT_SUCCESS; + + // Intra-node tests. 
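+  // NOTE : only the intra-node path is exercised under ROCm for now; the inter-node
+  // library is still commented out of the test executable's DEPS in CMakeLists.txt.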
+ std::shared_ptr distributed = std::make_shared(rank, world_size); + if (!testDispatchCombine( + stream, rank / 2, 2, rank, world_size, distributed + )) { + exit_code = EXIT_FAILURE; + } + if (!testDispatchCombine<__hip_bfloat16, AllToAllIntraNode>( + stream, rank / 2, 2, rank, world_size, distributed + )) { + exit_code = EXIT_FAILURE; + } + + // Cleanup. + CUDACHECK(hipStreamDestroy(stream)); + rocshmem_barrier_all(); + rocshmem_finalize(); + + MPICHECK(MPI_Finalize()); + return exit_code; +#else // Set up NVSHMEM. - MPI_Comm mpi_comm = MPI_COMM_WORLD; nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER; attr.mpi_comm = &mpi_comm; nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr); @@ -391,4 +445,5 @@ int main(int argc, char **argv) { nvshmem_finalize(); MPICHECK(MPI_Finalize()); return exit_code; +#endif } diff --git a/csrc/all_to_all/test_utils.h b/csrc/all_to_all/test_utils.h index c08fec1..aecd05e 100644 --- a/csrc/all_to_all/test_utils.h +++ b/csrc/all_to_all/test_utils.h @@ -10,7 +10,34 @@ #include #include +#if USE_ROCM + +#include + +// NOTE (yiakwy) : __hip_fp8_e4m3_fnuz only represents -240..0..240 +using __nv_fp8_e4m3 = __hip_fp8_e4m3_fnuz; + +__hip_fp8_storage_t +convert_float_to_fp8(float in, /* Input val */ + __hip_fp8_interpretation_t + interpret, /* interpretation of number E4M3/E5M2 */ + __hip_saturation_t sat /* Saturation behavior */ +) { + return __hip_cvt_float_to_fp8(in, sat, interpret); +} + +float convert_fp8_to_float( + __hip_fp8_storage_t in, /* Input val */ + __hip_fp8_interpretation_t + interpret /* interpretation of number E4M3/E5M2 */ +) { + __half hf = __hip_cvt_fp8_to_halfraw(in, interpret); + return static_cast(hf); +} + +#else #include +#endif namespace pplx { @@ -82,7 +109,8 @@ RankTestData::RankTestData( // Populate the tokens. 
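+      // NOTE : presumably narrowed from ±256 so that test values stay inside the ±240
+      // span representable by __hip_fp8_e4m3_fnuz (see the conversion helpers above).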
if constexpr (std::is_integral::value) { - std::uniform_int_distribution<> value(-256, 256); + // std::uniform_int_distribution<> value(-256, 256); + std::uniform_int_distribution<> value(-224, 224); for (size_t i = 0; i < m; ++i) { for (size_t j = 0; j < hiddenDim; ++j) { x[i * hiddenDim + j] = value(gen); @@ -93,7 +121,11 @@ RankTestData::RankTestData( for (size_t i = 0; i < m; ++i) { for (size_t j = 0; j < hiddenDim; ++j) { if constexpr (std::is_same::value) { +#if USE_ROCM + x[i * hiddenDim + j] = static_cast<__hip_fp8_e4m3_fnuz>(convert_float_to_fp8(value(gen), __HIP_E4M3_FNUZ, __HIP_SATFINITE)); +#else x[i * hiddenDim + j] = __nv_cvt_float_to_fp8(value(gen), __NV_SATFINITE, __NV_E4M3); +#endif } else { x[i * hiddenDim + j] = value(gen); } diff --git a/csrc/cmake/py_helper.cmake b/csrc/cmake/py_helper.cmake new file mode 100644 index 0000000..6f02d54 --- /dev/null +++ b/csrc/cmake/py_helper.cmake @@ -0,0 +1,66 @@ +# Adapt from: https://github.com/neuralmagic/vllm-flash-attention/blob/main/cmake/utils.cmake + +function (run_python OUT EXPR ERR_MSG) + execute_process( + COMMAND + "${Python_EXECUTABLE}" "-c" "${EXPR}" + OUTPUT_VARIABLE PYTHON_OUT + RESULT_VARIABLE PYTHON_ERROR_CODE + ERROR_VARIABLE PYTHON_STDERR + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT PYTHON_ERROR_CODE EQUAL 0) + message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}") + endif() + set(${OUT} ${PYTHON_OUT} PARENT_SCOPE) +endfunction() + +macro (append_torch_cmake_prefix_path) + run_python(TORCH_CMAKE_PREFIX_PATH + "import torch; print(torch.utils.cmake_prefix_path);" "Failed to locate torch path") + list(APPEND CMAKE_PREFIX_PATH ${TORCH_CMAKE_PREFIX_PATH}) + message(STATUS "TORCH_CMAKE_PREFIX_PATH : ${TORCH_CMAKE_PREFIX_PATH}") +endmacro() + +# +# Get additional GPU compiler flags from torch. +# +function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) + if (${GPU_LANG} STREQUAL "CUDA") + # + # Get common NVCC flags from torch. + # + run_python(GPU_FLAGS + "from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))" + "Failed to determine torch nvcc compiler flags") + + if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) + list(APPEND GPU_FLAGS "-DENABLE_FP8") + endif() + if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) + list(REMOVE_ITEM GPU_FLAGS + "-D__CUDA_NO_HALF_OPERATORS__" + "-D__CUDA_NO_HALF_CONVERSIONS__" + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" + "-D__CUDA_NO_HALF2_OPERATORS__") + endif() + + elseif(${GPU_LANG} STREQUAL "HIP") + # + # Get common HIP/HIPCC flags from torch. 
+ # + run_python(GPU_FLAGS + "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))" + "Failed to determine torch nvcc compiler flags") + + list(APPEND GPU_FLAGS + "-DUSE_ROCM" + "-DENABLE_FP8" + "-U__HIP_NO_HALF_CONVERSIONS__" + "-U__HIP_NO_HALF_OPERATORS__" + "-Werror=unused-variable" + "-fno-gpu-rdc") + + endif() + set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE) +endfunction() \ No newline at end of file diff --git a/csrc/cmake/rocm_helper.cmake b/csrc/cmake/rocm_helper.cmake new file mode 100644 index 0000000..9c7ce6d --- /dev/null +++ b/csrc/cmake/rocm_helper.cmake @@ -0,0 +1,54 @@ +set(common_libs + hiprtc + amdhip64 + hsa-runtime64::hsa-runtime64 + hip::host + hip::device +) + +function(rocshmem_add_library lib_name srcs) + cmake_parse_arguments(ARG "" "" "DEPS" ${ARGN}) + + set_source_files_properties(${srcs} PROPERTIES LANGUAGE HIP) + + hip_add_library(${lib_name} STATIC ${srcs}) + + set_target_properties(${lib_name} PROPERTIES HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES}) + target_compile_options(${lib_name} PRIVATE -fgpu-rdc) + + target_link_libraries(${lib_name} + PRIVATE + ${common_libs} + MPI::MPI_CXX + roc::rocshmem + ${ARG_DEPS} + ) + + target_include_directories(${lib_name} PRIVATE ${ROCSHMEM_INCLUDE_DIR}) + # target_link_libraries(${lib_name} INTERFACE + # roc::rocshmem + # ) +endfunction() + +function(rocshmem_add_executable exec_name srcs) + cmake_parse_arguments(ARG "" "" "DEPS" ${ARGN}) + + set(CMAKE_CXX_COMPIILER "/opt/rocm/bin/hipcc") + + message(STATUS "exec_name : ${exec_name}") + message(STATUS "srcs : ${srcs}") + + set_source_files_properties(${srcs} PROPERTIES LANGUAGE HIP) + + add_executable(${exec_name} ${srcs}) + + # target_compile_options(${exec_name} PRIVATE -fgpu-rdc) + + target_link_libraries(${exec_name} + PRIVATE + ${common_libs} + MPI::MPI_CXX + roc::rocshmem + ${ARG_DEPS} + ) +endfunction() \ No newline at end of file diff --git a/csrc/core/CMakeLists.txt b/csrc/core/CMakeLists.txt index 821035d..5401a99 100644 --- a/csrc/core/CMakeLists.txt +++ b/csrc/core/CMakeLists.txt @@ -1,14 +1,44 @@ # Core library with common kernels. 
-
-add_library(core_lib STATIC
+set(core_srcs
     kernels.cu
     distributed.cpp
 )
-target_link_libraries(core_lib PUBLIC
-    CUDA::cudart
-)
-target_link_libraries(core_lib INTERFACE
-    nvshmem::nvshmem_host
-)
-target_include_directories(core_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
-set_cuda_compile_options(core_lib)
+
+if (HIP_FOUND)
+    set_source_files_properties(${core_srcs} PROPERTIES LANGUAGE HIP)
+
+    hip_add_library(core_lib STATIC
+        ${core_srcs}
+    )
+
+    set_target_properties(core_lib PROPERTIES HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
+    target_link_libraries(core_lib PUBLIC
+        hiprtc
+        amdhip64
+        hsa-runtime64::hsa-runtime64
+        hip::device
+        hip::host
+        MPI::MPI_CXX
+        roc::rocshmem
+        -fgpu-rdc
+    )
+    target_include_directories(core_lib PRIVATE ${ROCSHMEM_INCLUDE_DIR})
+    target_link_libraries(core_lib INTERFACE
+        roc::rocshmem
+    )
+else()
+    add_library(core_lib STATIC
+        ${core_srcs}
+    )
+    target_link_libraries(core_lib PUBLIC
+        MPI::MPI_CXX
+        CUDA::cudart
+    )
+    target_include_directories(core_lib PRIVATE ${NVSHMEM_INCLUDE_DIR})
+    target_link_libraries(core_lib INTERFACE
+        nvshmem::nvshmem_host
+    )
+
+endif()
+
+message(STATUS "ROCSHMEM include dir : ${ROCSHMEM_INCLUDE_DIR}")
\ No newline at end of file
diff --git a/csrc/core/atomic.cuh b/csrc/core/atomic.cuh
index 6f6e7f6..0bf02fe 100644
--- a/csrc/core/atomic.cuh
+++ b/csrc/core/atomic.cuh
@@ -2,31 +2,67 @@
 
 #include
 
+#ifdef USE_ROCM
+// ROCm device-side atomics
+#include
+#endif
+
 namespace pplx {
 
 __forceinline__ __device__ void st_flag_volatile(uint32_t *flag_addr, uint32_t flag) {
+  #ifdef __HIPCC__
+  __threadfence_system();
+  // *reinterpret_cast(flag_addr) = flag;
+  __builtin_nontemporal_store(flag, flag_addr);
+  #else
   asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+  #endif
 }
 
__forceinline__ __device__ uint32_t ld_flag_volatile(uint32_t *flag_addr) {
   uint32_t flag;
+  #ifdef __HIPCC__
+  // NOTE(yiakwy) : a non-temporal access reads L2 cache/system memory directly, bypassing the L1 cache
+  __threadfence_system();
+  // flag = *reinterpret_cast(flag_addr);
+  flag = __builtin_nontemporal_load(flag_addr);
+  #else
   asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
+  #endif
   return flag;
 }
 
 __forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr) {
   uint32_t flag;
+  #ifdef __HIPCC__
+  flag = *flag_addr;
+  // NOTE : acquire ordering needs the fence *after* the load so that no later access can be
+  // reordered before it; this is the closest match to the ld.acquire.sys PTX semantics
+  __threadfence_system();
+  #else
   asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
+  #endif
   return flag;
 }
 
 __forceinline__ __device__ void st_flag_release(uint32_t *flag_addr, uint32_t flag) {
+  #ifdef __HIPCC__
+  // NOTE : release ordering needs the fence *before* the store so that all prior writes
+  // are visible to other threads before the flag is published
+  __threadfence_system();
+  *flag_addr = flag;
+  #else
   asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
+  #endif
 }
 
 __forceinline__ __device__ uint32_t add_flag_release(uint32_t *addr, uint32_t val) {
   uint32_t flag;
+  #ifdef __HIPCC__
+  // NOTE : as above, fence before the atomic so that prior writes are visible on release
+  __threadfence_system();
+  flag = atomicAdd((unsigned int*)addr, val);
+  #else
   asm volatile("atom.release.sys.global.add.u32 %0, [%1], %2;" : "=r"(flag) : "l"(addr), "r"(val));
+  #endif
   return flag;
 }
diff --git a/csrc/core/buffer.h b/csrc/core/buffer.h
index 970bd05..38a3839 100644
--- a/csrc/core/buffer.h
+++ b/csrc/core/buffer.h
@@ -1,6 +1,11 @@
 #pragma once
 
+#ifdef USE_ROCM
+#include "core/hip_defs.h"
+#include "core/hip_cuda_dtype_defs.h"
"core/hip_cuda_dtype_defs.h" +#else #include +#endif #include "core/cuda_utils.h" @@ -14,7 +19,7 @@ template class HostBuffer final { public: HostBuffer(size_t size) : size_(size) { - CUDACHECK(cudaMallocHost(&data_, size * sizeof(T))); + CUDACHECK(cudaMallocHost( (void **) &data_, size * sizeof(T))); } HostBuffer(const DeviceBuffer &device_buffer); @@ -90,7 +95,7 @@ template class DeviceBuffer final { template HostBuffer::HostBuffer(const DeviceBuffer &device_buffer) : size_(device_buffer.size()) { - CUDACHECK(cudaMallocHost(&data_, size_ * sizeof(T))); + CUDACHECK(cudaMallocHost( (void**) &data_, size_ * sizeof(T))); CUDACHECK(cudaMemcpy(data_, device_buffer.get(), size_ * sizeof(T), cudaMemcpyDeviceToHost)); } diff --git a/csrc/core/common_utils.h b/csrc/core/common_utils.h index c457a1f..18a7ea8 100644 --- a/csrc/core/common_utils.h +++ b/csrc/core/common_utils.h @@ -3,12 +3,27 @@ #include #include -#ifdef __CUDA_ARCH__ -#define PPLX_HOST_DEVICE __host__ __device__ +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) || defined(__CUDACC__) +#include + +#define PPLX_HOST __host__ +#define PPLX_DEVICE __device__ + +#define WARP_SIZE 64 // warpSize is restriced in device codes since SDK 7.0 + #else -#define PPLX_HOST_DEVICE + +#include + +#define PPLX_HOST +#define PPLX_DEVICE + +#define WARP_SIZE 32 + #endif +#define PPLX_HOST_DEVICE PPLX_HOST PPLX_DEVICE + namespace pplx { /// Return the next power of 2 following the given number. @@ -24,3 +39,36 @@ template PPLX_HOST_DEVICE T ceil_div(T x, T y) { return (x + y - 1) template PPLX_HOST_DEVICE T round_up(T x, T y) { return ceil_div(x, y) * y; } } // namespace pplx + +#if defined(__HIP_PLATFORM_AMD__) + +#include +#include +#include +#include + +namespace amdgpu { + +#define DEVICE_INLINE __device__ __inline__ + +template +DEVICE_INLINE T shfl_xor_sync(unsigned mask, T var, int laneMask, int width = warpSize); + +template <> +DEVICE_INLINE float shfl_xor_sync(unsigned mask, float var, int laneMask, int width) { + return __shfl_xor(var, laneMask, width); +} + +template <> +DEVICE_INLINE int shfl_xor_sync(unsigned mask, int var, int laneMask, int width) { + return __shfl_xor(var, laneMask, width); +} + +} // namespace amdgpu + +template +DEVICE_INLINE T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width = 32/*use CUDA warp size as default*/) { + return amdgpu::shfl_xor_sync(mask, var, laneMask, width); +} + +#endif // defined(__HIP_PLATFORM_AMD__) diff --git a/csrc/core/cuda_utils.h b/csrc/core/cuda_utils.h index 80909c9..609b0ad 100644 --- a/csrc/core/cuda_utils.h +++ b/csrc/core/cuda_utils.h @@ -2,7 +2,14 @@ #include #include + +#ifdef USE_ROCM + +#include "hip_defs.h" + +#else #include +#endif #define CUDACHECK(cmd) \ do { \ diff --git a/csrc/core/device_utils.cuh b/csrc/core/device_utils.cuh index c60aa87..9f48c71 100644 --- a/csrc/core/device_utils.cuh +++ b/csrc/core/device_utils.cuh @@ -26,6 +26,11 @@ template struct enable_sm90_or_later : Kernel { #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 Kernel::operator()(std::forward(args)...); #endif + +#if defined __HIPCC__ + Kernel::operator()(std::forward(args)...); +#endif + } }; diff --git a/csrc/core/distributed.cpp b/csrc/core/distributed.cpp index 1cd0065..a833523 100644 --- a/csrc/core/distributed.cpp +++ b/csrc/core/distributed.cpp @@ -4,6 +4,31 @@ using namespace pplx; +#ifdef USE_ROCM + +#include + +using namespace rocshmem; + +__global__ void alltoallmem(int *source, int *dest, size_t nelem, rocshmem_team_t team) { + __shared__ rocshmem_ctx_t ctx; + int64_t ctx_type = 
+
+  rocshmem_wg_init();
+  rocshmem_wg_ctx_create(ctx_type, &ctx);
+  int num_pes = rocshmem_ctx_n_pes(ctx);
+
+  rocshmem_ctx_int_alltoall_wg(ctx, team, dest, source, nelem);
+
+  rocshmem_ctx_quiet(ctx);
+  __syncthreads();
+
+  rocshmem_wg_ctx_destroy(&ctx);
+  rocshmem_wg_finalize();
+}
+
+#endif
+
 Distributed::Distributed(unsigned rank, unsigned worldSize) : rank(rank), worldSize(worldSize) {}
 
@@ -14,6 +39,36 @@
 DistributedNVSHMEM::DistributedNVSHMEM(unsigned rank, unsigned worldSize)
     : Distributed(rank, worldSize) {}
 
 void DistributedNVSHMEM::allToAllImpl(const void *input, void *output, size_t size, size_t count) {
+#ifdef USE_ROCM
+  PPLX_ASSERT(count == worldSize, "count must be equal to world size");
+
+  void *srcBuffer = nvshmem_malloc(size * count);
+  PPLX_ASSERT(srcBuffer != nullptr, "Failed to allocate src buffer");
+  void *dstBuffer = nvshmem_malloc(size * count);
+  PPLX_ASSERT(dstBuffer != nullptr, "Failed to allocate dst buffer");
+
+  CUDACHECK(cudaMemcpy(srcBuffer, input, size * count, cudaMemcpyHostToDevice));
+
+  rocshmem_team_t team_reduce_world_dup;
+  team_reduce_world_dup = ROCSHMEM_TEAM_INVALID;
+  rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, count, nullptr, 0,
+                              &team_reduce_world_dup);
+
+  HIP_MPI_CHECK(hipDeviceSynchronize());
+
+  int threadsPerBlock = 256;
+  dim3 grid(1);
+  dim3 block(threadsPerBlock);
+
+  alltoallmem<<<grid, block>>>((int *)srcBuffer, (int *)dstBuffer, size, team_reduce_world_dup);
+
+  HIP_MPI_CHECK(hipDeviceSynchronize());
+
+  // NOTE : copy the gathered data back out before releasing the symmetric buffers,
+  // mirroring the NVSHMEM branch below.
+  CUDACHECK(cudaMemcpy(output, dstBuffer, size * count, cudaMemcpyDeviceToHost));
+
+  nvshmem_free(srcBuffer);
+  nvshmem_free(dstBuffer);
+
+  // NOTE : do not call rocshmem_finalize() here; the runtime is owned by the caller,
+  // and finalizing it after a single collective would break subsequent communication.
+#else
   PPLX_ASSERT(count == worldSize, "count must be equal to world size");
 
   void *srcBuffer = nvshmem_malloc(size * count);
@@ -30,4 +85,5 @@
 
   nvshmem_free(dstBuffer);
   nvshmem_free(srcBuffer);
+#endif
 }
diff --git a/csrc/core/distributed.h b/csrc/core/distributed.h
index 7cc1d9e..0e5da2b 100644
--- a/csrc/core/distributed.h
+++ b/csrc/core/distributed.h
@@ -1,6 +1,18 @@
 #pragma once
 
+// NOTE : this unconditionally forces the ROCm path for every file that includes this
+// header; it should eventually come from the build system instead (CMake already
+// adds -DUSE_ROCM=1 when HIP is found).
+#define USE_ROCM 1
+
+#ifdef USE_ROCM
+
+#include
+
+#include "hip_dist_defs.h"
+
+__global__ void alltoallmem(int *source, int *dest, size_t nelem, rocshmem::rocshmem_team_t team);
+
+#else
 #include
+#endif
 
 #include
diff --git a/csrc/core/hip_cuda_dtype_defs.h b/csrc/core/hip_cuda_dtype_defs.h
new file mode 100644
index 0000000..c2b7e29
--- /dev/null
+++ b/csrc/core/hip_cuda_dtype_defs.h
@@ -0,0 +1,33 @@
+/* Copyright 2025 flashFloat authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#if defined(__HIP_PLATFORM_AMD__) + +#include + +#include +#include +#include + +using nv_bfloat162 = __hip_bfloat162; +using __nv_bfloat162 = __hip_bfloat162; + +using nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat16 = __hip_bfloat16; + +using half2 = __half2; +using nv_bfloat162 = __hip_bfloat162; + +#endif // defined(__HIP_PLATFORM_AMD__) \ No newline at end of file diff --git a/csrc/core/hip_defs.h b/csrc/core/hip_defs.h new file mode 100644 index 0000000..262eb61 --- /dev/null +++ b/csrc/core/hip_defs.h @@ -0,0 +1,163 @@ +// adpated from MSC mscclpp project, also see examples from cholla (https://github.com/cholla-hydro/cholla/blob/main/src/utils/gpu.hpp) +// Copyright LEI WANG (yiak.wy@gmail.com) +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef PPLX_HIP_DEFS_H_ +#define PPLX_HIP_DEFS_H_ + +#ifndef __HIP_PLATFORM_AMD__ +#define __HIP_PLATFORM_AMD__ +#endif + +#ifdef __HIP_PLATFORM_NVIDIA__ +#undef __HIP_PLATFORM_NVIDIA__ +#endif + +#if defined(__HIP_PLATFORM_AMD__) + +#include +#include + +// enum alias +using cudaFuncAttribute = hipFuncAttribute; +const cudaFuncAttribute cudaFuncAttributeMaxDynamicSharedMemorySize = hipFuncAttribute::hipFuncAttributeMaxDynamicSharedMemorySize; +const cudaFuncAttribute cudaFuncAttributePreferredSharedMemoryCarveout = hipFuncAttribute::hipFuncAttributePreferredSharedMemoryCarveout; +const cudaFuncAttribute cudaFuncAttributeMax = hipFuncAttribute::hipFuncAttributeMax; + +using cudaDeviceAttr = hipDeviceAttribute_t; +// Number of multiprocessors on the device +const cudaDeviceAttr cudaDevAttrMultiProcessorCount = hipDeviceAttribute_t::hipDeviceAttributeMultiprocessorCount; +const cudaDeviceAttr cudaDevAttrMaxSharedMemoryPerMultiprocessor = hipDeviceAttribute_t::hipDeviceAttributeMaxSharedMemoryPerMultiprocessor; + +// function alias +template +inline static hipError_t cudaFuncSetAttribute(Func&& func, const hipFuncAttribute& attr, int value) { + return hipFuncSetAttribute((void*)func, attr, value); +} + +template +static __inline__ __host__ __device__ +auto cudaLaunchKernel(Args&&... 
args) -> decltype(hipLaunchKernel(std::forward(args)...)) { + return hipLaunchKernel(std::forward(args)...); +} + +static __inline__ __host__ __device__ +hipError_t cudaDeviceGetAttribute(int *value, cudaDeviceAttr attr, int device) { + return hipDeviceGetAttribute(value, attr, device); +} + +template +inline static hipError_t cudaOccupancyMaxActiveBlocksPerMultiprocessor(int* numBlocks, + Func func, + int blockSize, + size_t dynamicSMemSize) { + return hipOccupancyMaxActiveBlocksPerMultiprocessor(numBlocks, (void*)func, + blockSize, dynamicSMemSize); +} + +static __inline__ __host__ __device__ hipError_t cudaLaunchCooperativeKernel(const void* f, dim3 gridDim, + dim3 blockDimX, void** kernelParams, size_t sharedMem = 0, hipStream_t stream = hipStreamDefault) { + return hipLaunchCooperativeKernel(f, gridDim, blockDimX, kernelParams, sharedMem, stream); +} + +// Type alias +using cudaError_t = hipError_t; +using cudaGraph_t = hipGraph_t; +using cudaGraphExec_t = hipGraphExec_t; +using cudaDeviceProp = hipDeviceProp_t; +using cudaStream_t = hipStream_t; +using cudaStreamCaptureMode = hipStreamCaptureMode; +using cudaMemcpyKind = hipMemcpyKind; +using cudaIpcMemHandle_t = hipIpcMemHandle_t; + +using CUresult = hipError_t; +using CUdeviceptr = hipDeviceptr_t; +using CUmemGenericAllocationHandle = hipMemGenericAllocationHandle_t; +using CUmemAllocationProp = hipMemAllocationProp; +using CUmemAccessDesc = hipMemAccessDesc; + +constexpr auto cudaSuccess = hipSuccess; +constexpr auto cudaStreamNonBlocking = hipStreamNonBlocking; +constexpr auto cudaStreamCaptureModeGlobal = hipStreamCaptureModeGlobal; +constexpr auto cudaStreamCaptureModeRelaxed = hipStreamCaptureModeRelaxed; +constexpr auto cudaHostAllocMapped = hipHostMallocMapped; +constexpr auto cudaHostAllocWriteCombined = hipHostMallocWriteCombined; +constexpr auto cudaMemcpyDefault = hipMemcpyDefault; +constexpr auto cudaMemcpyDeviceToDevice = hipMemcpyDeviceToDevice; +constexpr auto cudaMemcpyHostToDevice = hipMemcpyHostToDevice; +constexpr auto cudaMemcpyDeviceToHost = hipMemcpyDeviceToHost; +constexpr auto cudaIpcMemLazyEnablePeerAccess = hipIpcMemLazyEnablePeerAccess; + +constexpr auto CU_MEM_ALLOCATION_TYPE_PINNED = hipMemAllocationTypePinned; +constexpr auto CU_MEM_LOCATION_TYPE_DEVICE = hipMemLocationTypeDevice; +constexpr auto CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR = hipMemHandleTypePosixFileDescriptor; +constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWrite; + +#ifndef CUDA_SUCCESS +#define CUDA_SUCCESS hipSuccess +#endif // CUDA_SUCCESS + +#define cudaGetErrorString(...) hipGetErrorString(__VA_ARGS__) +#define cudaGetDevice(...) hipGetDevice(__VA_ARGS__) +#define cudaGetDeviceCount(...) hipGetDeviceCount(__VA_ARGS__) +#define cudaGetDeviceProperties(...) hipGetDeviceProperties(__VA_ARGS__) +#define cudaGetLastError(...) hipGetLastError(__VA_ARGS__) +#define cudaSetDevice(...) hipSetDevice(__VA_ARGS__) +#define cudaDeviceSynchronize(...) hipDeviceSynchronize(__VA_ARGS__) +#define cudaDeviceGetPCIBusId(...) hipDeviceGetPCIBusId(__VA_ARGS__) +#define cudaHostAlloc(...) hipHostMalloc(__VA_ARGS__) +#define cudaMalloc(...) hipMalloc(__VA_ARGS__) + +// #define cudaMallocHost(...) hipMallocHost(__VA_ARGS__) +#define cudaMallocHost(...) hipHostMalloc(__VA_ARGS__) + +#define cudaFree(...) hipFree(__VA_ARGS__) +#define cudaFreeHost(...) hipHostFree(__VA_ARGS__) +#define cudaMemset(...) hipMemset(__VA_ARGS__) +#define cudaMemsetAsync(...) hipMemsetAsync(__VA_ARGS__) +#define cudaMemcpy(...) 
hipMemcpy(__VA_ARGS__) +#define cudaMemcpyAsync(...) hipMemcpyAsync(__VA_ARGS__) +#define cudaMemcpyToSymbol(...) hipMemcpyToSymbol(__VA_ARGS__) +#define cudaMemcpyToSymbolAsync(...) hipMemcpyToSymbolAsync(__VA_ARGS__) +#define cudaStreamCreate(...) hipStreamCreate(__VA_ARGS__) +#define cudaStreamCreateWithFlags(...) hipStreamCreateWithFlags(__VA_ARGS__) +#define cudaStreamSynchronize(...) hipStreamSynchronize(__VA_ARGS__) +#define cudaStreamBeginCapture(...) hipStreamBeginCapture(__VA_ARGS__) +#define cudaStreamEndCapture(...) hipStreamEndCapture(__VA_ARGS__) +#define cudaStreamDestroy(...) hipStreamDestroy(__VA_ARGS__) +#define cudaGraphInstantiate(...) hipGraphInstantiate(__VA_ARGS__) +#define cudaGraphLaunch(...) hipGraphLaunch(__VA_ARGS__) +#define cudaGraphDestroy(...) hipGraphDestroy(__VA_ARGS__) +#define cudaGraphExecDestroy(...) hipGraphExecDestroy(__VA_ARGS__) +#define cudaThreadExchangeStreamCaptureMode(...) hipThreadExchangeStreamCaptureMode(__VA_ARGS__) +#define cudaIpcGetMemHandle(...) hipIpcGetMemHandle(__VA_ARGS__) +#define cudaIpcOpenMemHandle(...) hipIpcOpenMemHandle(__VA_ARGS__) +#define cudaIpcCloseMemHandle(...) hipIpcCloseMemHandle(__VA_ARGS__) + +#define cuGetErrorString(...) hipDrvGetErrorString(__VA_ARGS__) +#define cuMemAddressReserve(...) hipMemAddressReserve(__VA_ARGS__) +#define cuMemAddressFree(...) hipMemAddressFree(__VA_ARGS__) +#define cuMemGetAddressRange(...) hipMemGetAddressRange(__VA_ARGS__) +#define cuMemCreate(...) hipMemCreate(__VA_ARGS__) +#define cuMemRelease(...) hipMemRelease(__VA_ARGS__) +#define cuMemSetAccess(...) hipMemSetAccess(__VA_ARGS__) +#define cuMemMap(...) hipMemMap(__VA_ARGS__) +#define cuMemUnmap(...) hipMemUnmap(__VA_ARGS__) + +#else + +#include +#include + +#endif // defined(__HIP_PLATFORM_AMD__) + +// NVLS +#if !defined(__HIP_PLATFORM_AMD__) +#include +#define USE_NVLS ((CUDART_VERSION >= 12010) && (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 6, 0))) +#else // !defined(__HIP_PLATFORM_AMD__) +#define USE_NVLS 0 +#endif // !defined(__HIP_PLATFORM_AMD__) + +#endif // PPLX_HIP_DEFS_H_ \ No newline at end of file diff --git a/csrc/core/hip_dist_defs.h b/csrc/core/hip_dist_defs.h new file mode 100644 index 0000000..a58f79e --- /dev/null +++ b/csrc/core/hip_dist_defs.h @@ -0,0 +1,36 @@ +#pragma once + +#ifndef __HIP_PLATFORM_AMD__ +#define __HIP_PLATFORM_AMD__ +#endif + +#ifdef __HIP_PLATFORM_NVIDIA__ +#undef __HIP_PLATFORM_NVIDIA__ +#endif + +#if defined(__HIP_PLATFORM_AMD__) + +#include + +#include +#include + +#include + +// function alias + +#define nvshmem_init(...) rocshmem::rocshmem_init(__VA_ARGS__) + +#define nvshmem_malloc(...) rocshmem::rocshmem_malloc(__VA_ARGS__) + +#define nvshmem_free(...) rocshmem::rocshmem_free(__VA_ARGS__) + +#define HIP_MPI_CHECK(condition) { \ + hipError_t error = condition; \ + if(error != hipSuccess){ \ + fprintf(stderr,"HIP error: %d line: %d\n", error, __LINE__); \ + MPI_Abort(MPI_COMM_WORLD, error); \ + } \ + } + +#endif // defined(__HIP_PLATFORM_AMD__) diff --git a/csrc/core/hip_roctx_defs.h b/csrc/core/hip_roctx_defs.h new file mode 100644 index 0000000..d93e726 --- /dev/null +++ b/csrc/core/hip_roctx_defs.h @@ -0,0 +1,24 @@ +#pragma once + +#ifndef __HIP_PLATFORM_AMD__ +#define __HIP_PLATFORM_AMD__ +#endif + +#ifdef __HIP_PLATFORM_NVIDIA__ +#undef __HIP_PLATFORM_NVIDIA__ +#endif + +#if defined(__HIP_PLATFORM_AMD__) + +#include + +#include +#include + +#include + +#define nvtxRangePush(...) roctxRangePushA(__VA_ARGS__) + +#define nvtxRangePop(...) 
roctxRangePop(__VA_ARGS__)
+
+#endif // defined(__HIP_PLATFORM_AMD__)
diff --git a/csrc/core/kernels.cu b/csrc/core/kernels.cu
index c3af8e5..79d930e 100644
--- a/csrc/core/kernels.cu
+++ b/csrc/core/kernels.cu
@@ -1,13 +1,37 @@
 #include "core/cuda_utils.h"
 #include "kernels.h"
 
+#include
+
 #include
-#include
+// NOTE (yiakwy) : this include was removed
+// #include
+
+namespace amdgpu {
+
+  __device__ void usleep(unsigned long long delay_us)
+  {
+    // NOTE(yiakwy) : assume a 2000 MHz clock (MI350); MI300 runs at 1600 MHz
+    const double clock_frequency_mhz = 2000.0;
+
+    unsigned long long wait_cycles = static_cast<unsigned long long>(delay_us * clock_frequency_mhz);
+
+    unsigned long long start_clock = clock64();
+    while (clock64() - start_clock < wait_cycles) {
+      ; // Do nothing, just spin
+    }
+  }
+
+} // namespace amdgpu
 
 __global__ void sleep_kernel(uint64_t ms) {
   for (int i = 0; i < ms; i++) {
-    __nanosleep(1000000);
+#ifdef USE_ROCM
+    amdgpu::usleep(1000);
+#else
+    __nanosleep(1000000); // NOTE (yiakwy) : asm("nanosleep.u32 64;" ::: "memory"); is supported by NV PTX
+#endif
   }
 }
diff --git a/csrc/core/nvshmem_utils.h b/csrc/core/nvshmem_utils.h
index 5819176..682280f 100644
--- a/csrc/core/nvshmem_utils.h
+++ b/csrc/core/nvshmem_utils.h
@@ -1,9 +1,29 @@
 #pragma once
 
+#ifdef USE_ROCM
+
+#include
+
+#else
 #include
+#endif
 
 #include "core/cuda_utils.h"
 
+#ifdef USE_ROCM
+
+#define NVSHMEMCHECK(stmt)                                                                         \
+  do {                                                                                             \
+    int result = (stmt);                                                                           \
+    if (ROCSHMEM_SUCCESS != result) {                                                              \
+      fprintf(stderr, "[%s:%d] rocshmem failed with error %d \n", __FILE__, __LINE__, result);     \
+      MPI_Abort(MPI_COMM_WORLD, result);                                                           \
+      exit(-1);                                                                                    \
+    }                                                                                              \
+  } while (0)
+
+#else
+
 #define NVSHMEMCHECK(stmt)                                                                         \
   do {                                                                                             \
     int result = (stmt);                                                                           \
@@ -12,3 +32,5 @@
       exit(-1);                                                                                    \
     }                                                                                              \
   } while (0)
+
+#endif // USE_ROCM
diff --git a/helper_rocm.py b/helper_rocm.py
new file mode 100644
index 0000000..3d1bb32
--- /dev/null
+++ b/helper_rocm.py
@@ -0,0 +1,121 @@
+# Copyright 2025 FlashFloat authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+import re
+import subprocess
+
+import torch
+
+try:
+    from torch.utils.cpp_extension import ROCM_HOME
+except ImportError:
+    raise RuntimeError(
+        "Base env does not provide Torch with support of ROCM SDK. Exit."
+    )
+
+HIP_VERSION_PAT = r"HIP version: (\S+)"
+HIP_SDK_ROOT = "/opt/rocm"
+
+
+def is_hip(hip_sdk_root=None) -> "tuple[bool, list | None]":
+    SDK_ROOT = f"{hip_sdk_root or HIP_SDK_ROOT}"
+
+    def _check_sdk_installed() -> bool:
+        # return True if this path points to a directory or a symbolic link to one
+        return os.path.isdir(SDK_ROOT)
+
+    if not _check_sdk_installed():
+        return False, None
+
+    # the base env provides torch; check whether the installation can see a device
+    result = subprocess.run(
+        [f"{SDK_ROOT}/bin/rocminfo | grep -o -m1 'gfx.*'"],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        shell=True,
+    )
+
+    if result.returncode != 0:
+        print("Using AMD pytorch, but no devices were found!")
+        return False, None
+
+    target_amdgpu_arch = result.stdout.strip()
+    print(f"target AMD gpu arch {target_amdgpu_arch}")
+    return True, [target_amdgpu_arch]
+
+
+# currently only MI30X-series (MI308X, MI300X/A) datacenter accelerators are supported
+_is_hip, target_amdgpu_arch = is_hip()
+
+if _is_hip:
+    assert ROCM_HOME is not None, "ROCM_HOME is not set"
+
+    ROCM_HOME = os.environ.get("ROCM_HOME", ROCM_HOME)
+
+
+def get_hipcc_rocm_version(hip_sdk_root=None):
+    assert _is_hip
+
+    SDK_ROOT = f"{hip_sdk_root or HIP_SDK_ROOT}"
+
+    result = subprocess.run(
+        [f"{SDK_ROOT}/bin/hipcc", "--version"],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+    )
+
+    # Check if the command was executed successfully
+    if result.returncode != 0:
+        print("Error running 'hipcc --version'")
+        return None
+
+    # Extract the version using a regular expression
+    match = re.search(HIP_VERSION_PAT, result.stdout)
+    if match:
+        # Return the version string
+        return match.group(1)
+    else:
+        print("Could not find HIP version in the output")
+        return None
+
+
+amd_libraries = ["hiprtc", "amdhip64"]
+
+for flag in [
+    # "-D__HIP_NO_HALF_OPERATORS__=1",
+    # "-D__HIP_NO_HALF_CONVERSIONS__=1",
+]:
+    try:
+        from torch.utils.cpp_extension import COMMON_HIPCC_FLAGS
+
+        COMMON_HIPCC_FLAGS.remove(flag)
+    except ValueError:
+        pass
+
+hipcc_flags = [
+    "-D__HIP_PLATFORM_AMD__=1",
+    f"--offload-arch={';'.join(target_amdgpu_arch if target_amdgpu_arch is not None else [])}",
+]
+
+
+def get_hip_libraries():
+    return amd_libraries
+
+
+def get_hipcc_flags():
+    return hipcc_flags
diff --git a/pyproject.toml b/pyproject.toml
index 8651e88..a52c50a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
-name = "pplx-kernels"
+name = "pplx-kernels-rocm"
 version = "0.0.1"
-description = "Perplexity CUDA Kernels"
+description = "Perplexity HIP/CUDA Kernels"
 readme = "README.md"
 requires-python = ">=3.9"
 
@@ -14,7 +14,7 @@ line-length = 88
 [tool.ruff.lint.isort]
 combine-as-imports = true
-known-first-party = ["tests", "pplx_kernels"]
+# NOTE : isort expects importable module names here; the package under src/ is still pplx_kernels
+known-first-party = ["tests", "pplx_kernels"]
 
 [tool.ruff.lint]
 select = [
diff --git a/setup.py b/setup.py
index 43eaae3..234ea08 100644
--- a/setup.py
+++ b/setup.py
@@ -7,6 +7,17 @@
 
 from setuptools.command.build_ext import build_ext
 
+# NOTE(yiakwy)
+from helper_rocm import (
+    ROCM_HOME,
+    _is_hip,
+    target_amdgpu_arch,
+    get_hip_libraries,
+    get_hipcc_flags,
+    get_hipcc_rocm_version,
+)
+
+
 def _get_torch_cmake_prefix_path() -> str:
     import torch
 
@@ -43,7 +54,7 @@ def build_extension(self, ext: Extension) -> None:
                 "-G",
                 "Ninja",
                 "-DCMAKE_PREFIX_PATH=" + _get_torch_cmake_prefix_path(),
-                "-DTORCH_CUDA_ARCH_LIST=" + os.environ["TORCH_CUDA_ARCH_LIST"],
+                # NOTE : target_amdgpu_arch is a list, and the conditional must be
+                # parenthesized so the CUDA branch keeps its -D prefix
+                "-DTORCH_CUDA_ARCH_LIST="
+                + (";".join(target_amdgpu_arch) if _is_hip else os.environ["TORCH_CUDA_ARCH_LIST"]),
                 "-WITH_TESTS=OFF",
             ]
         )
@@ -80,7 +91,7 @@ def run(self) -> None:
 extensions = [
     Extension(
-        "pplx-kernels",
+        "pplx-kernels-rocm",
         sources=[],
     ),
 ]
@@ -89,7 +100,7 @@ def run(self) -> None:
     packages=find_packages(where="src"),
     package_dir={"": "src"},
     package_data={
-        "pplx_kernels": ["libpplx_kernels.so", "py.typed"],
+        # NOTE : this key must match the import package under src/, which is still pplx_kernels
+        "pplx_kernels": ["libpplx_kernels.so", "py.typed"],
     },
     cmdclass={
         "build_ext": CMakeBuild,