diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index f70395a..089a9db 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -24,6 +24,10 @@ add_executable(concurrent_put_imm_thread shmem/concurrent_put_imm_thread.cpp)
 target_link_libraries(concurrent_put_imm_thread mori_shmem hip::host hip::device)
 
+add_executable(concurrent_put_signal_thread shmem/concurrent_put_signal_thread.cpp)
+target_link_libraries(concurrent_put_signal_thread mori_shmem hip::host
+                      hip::device)
+
 add_executable(atomic_nonfetch_thread shmem/atomic_nonfetch_thread.cpp)
 target_link_libraries(atomic_nonfetch_thread mori_shmem hip::host hip::device)
diff --git a/examples/shmem/concurrent_put_signal_thread.cpp b/examples/shmem/concurrent_put_signal_thread.cpp
new file mode 100644
index 0000000..54d26ad
--- /dev/null
+++ b/examples/shmem/concurrent_put_signal_thread.cpp
@@ -0,0 +1,191 @@
+// Copyright © Advanced Micro Devices, Inc. All rights reserved.
+//
+// MIT License
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+#include <hip/hip_runtime.h>
+
+#include <mpi.h>
+
+#include "mori/application/utils/check.hpp"
+#include "mori/shmem/shmem.hpp"
+
+using namespace mori::core;
+using namespace mori::shmem;
+using namespace mori::application;
+
+__global__ void ConcurrentPutSignalThreadKernelAdd(int myPe, const SymmMemObjPtr dataObj,
+                                                   const SymmMemObjPtr signalObj) {
+  constexpr int sendPe = 0;
+  constexpr int recvPe = 1;
+
+  int globalTid = blockIdx.x * blockDim.x + threadIdx.x;
+  int threadOffset = globalTid * sizeof(uint32_t);
+
+  if (myPe == sendPe) {
+    RdmaMemoryRegion source = dataObj->GetRdmaMemoryRegion(myPe);
+
+    // Test onlyOneSignal=true with AMO_ADD: only leader thread signals
+    ShmemPutMemNbiSignalThread<true>(dataObj, threadOffset, source, threadOffset,
+                                     sizeof(uint32_t), signalObj, 0, 1, atomicType::AMO_ADD,
+                                     recvPe, 0);
+    __threadfence_system();
+
+    ShmemQuietThread();
+  } else {
+    // Receiver: wait for all data to arrive by checking signal counter
+    if (threadIdx.x == 0) {
+      uint64_t* signalPtr = reinterpret_cast<uint64_t*>(signalObj->localPtr);
+      uint64_t expectedSignals = blockDim.x * gridDim.x / warpSize;  // One signal per warp
+      while (atomicAdd(signalPtr, 0) != expectedSignals) {
+        // Busy wait for all signals
+      }
+      printf("PE %d: AMO_ADD test - Received all %lu signals!\n", myPe, expectedSignals);
+    }
+    __syncthreads();
+
+    // Verify data
+    uint32_t receivedData = atomicAdd(reinterpret_cast<uint32_t*>(dataObj->localPtr) + globalTid, 0);
+    if (receivedData != sendPe) {
+      printf("PE %d, thread %d: Data mismatch! Expected %d, got %d\n", myPe, globalTid, sendPe,
+             receivedData);
+    }
+  }
+}
+
+__global__ void ConcurrentPutSignalThreadKernelSet(int myPe, const SymmMemObjPtr dataObj,
+                                                   const SymmMemObjPtr signalObj) {
+  constexpr int sendPe = 0;
+  constexpr int recvPe = 1;
+  constexpr uint64_t MAGIC_VALUE = 0xDEADBEEF;
+
+  int globalTid = blockIdx.x * blockDim.x + threadIdx.x;
+  int threadOffset = globalTid * sizeof(uint32_t);
+
+  if (myPe == sendPe) {
+    RdmaMemoryRegion source = dataObj->GetRdmaMemoryRegion(myPe);
+
+    // Test onlyOneSignal=true with AMO_SET: only leader thread signals
+    ShmemPutMemNbiSignalThread<true>(dataObj, threadOffset, source, threadOffset,
+                                     sizeof(uint32_t), signalObj, 0, MAGIC_VALUE,
+                                     atomicType::AMO_SET, recvPe, 0);
+    __threadfence_system();
+
+    ShmemQuietThread();
+  } else {
+    // Receiver: wait for signal to be set to magic value
+    if (threadIdx.x == 0) {
+      uint64_t* signalPtr = reinterpret_cast<uint64_t*>(signalObj->localPtr);
+      while (atomicAdd(signalPtr, 0) != MAGIC_VALUE) {
+        // Busy wait for signal
+      }
+      printf("PE %d: AMO_SET test - Received magic signal value 0x%lx!\n", myPe, MAGIC_VALUE);
+    }
+    __syncthreads();
+
+    // Verify data
+    uint32_t receivedData = atomicAdd(reinterpret_cast<uint32_t*>(dataObj->localPtr) + globalTid, 0);
+    if (receivedData != sendPe) {
+      printf("PE %d, thread %d: Data mismatch! Expected %d, got %d\n", myPe, globalTid, sendPe,
+             receivedData);
+    }
+  }
+}
+
+void ConcurrentPutSignalThread() {
+  int status;
+  MPI_Init(NULL, NULL);
+
+  status = ShmemMpiInit(MPI_COMM_WORLD);
+  assert(!status);
+
+  // Assume in same node
+  int myPe = ShmemMyPe();
+  int npes = ShmemNPes();
+  assert(npes == 2);
+
+  constexpr int threadNum = 128;
+  constexpr int blockNum = 3;
+
+  // Allocate data buffer
+  int numEle = threadNum * blockNum;
+  int buffSize = numEle * sizeof(uint32_t);
+
+  void* dataBuff = ShmemMalloc(buffSize);
+  HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<hipDeviceptr_t>(dataBuff), myPe, numEle));
+  HIP_RUNTIME_CHECK(hipDeviceSynchronize());
+
+  SymmMemObjPtr dataBuffObj = ShmemQueryMemObjPtr(dataBuff);
+  assert(dataBuffObj.IsValid());
+
+  // Allocate signal buffer
+  void* signalBuff = ShmemMalloc(sizeof(uint64_t));
+  HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<hipDeviceptr_t>(signalBuff), 0, 2));
+  HIP_RUNTIME_CHECK(hipDeviceSynchronize());
+
+  SymmMemObjPtr signalBuffObj = ShmemQueryMemObjPtr(signalBuff);
+  assert(signalBuffObj.IsValid());
+
+  MPI_Barrier(MPI_COMM_WORLD);
+
+  // Test 1: AMO_ADD signal operation
+  if (myPe == 0) {
+    printf("\n=== Test 1: PutMemNbi with Signal (AMO_ADD) ===\n");
+  }
+  MPI_Barrier(MPI_COMM_WORLD);
+
+  ConcurrentPutSignalThreadKernelAdd<<<blockNum, threadNum>>>(myPe, dataBuffObj, signalBuffObj);
+  HIP_RUNTIME_CHECK(hipDeviceSynchronize());
+  MPI_Barrier(MPI_COMM_WORLD);
+
+  if (myPe == 0) {
+    printf("Test 1 completed successfully!\n");
+  }
+
+  // Reset buffers for next test
+  HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<hipDeviceptr_t>(dataBuff), myPe, numEle));
+  HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<hipDeviceptr_t>(signalBuff), 0, 2));
+  HIP_RUNTIME_CHECK(hipDeviceSynchronize());
+  MPI_Barrier(MPI_COMM_WORLD);
+
+  // Test 2: AMO_SET signal operation
+  if (myPe == 0) {
+    printf("\n=== Test 2: PutMemNbi with Signal (AMO_SET) ===\n");
+  }
+  MPI_Barrier(MPI_COMM_WORLD);
+
+  ConcurrentPutSignalThreadKernelSet<<<blockNum, threadNum>>>(myPe, dataBuffObj, signalBuffObj);
+  HIP_RUNTIME_CHECK(hipDeviceSynchronize());
+  MPI_Barrier(MPI_COMM_WORLD);
+
+  if (myPe == 0) {
+    printf("Test 2 completed successfully!\n");
+    printf("\n=== All PutMemNbi with Signal tests passed! ===\n");
+  }
+
+  // Finalize
+  ShmemFree(dataBuff);
+  ShmemFree(signalBuff);
+  ShmemFinalize();
+}
+
+int main(int argc, char* argv[]) {
+  ConcurrentPutSignalThread();
+  return 0;
+}
\ No newline at end of file
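Editor's note: the receiver-side loops above poll the signal word with `atomicAdd(signalPtr, 0)` (an atomic read). The wait-until helpers touched later in this diff express the same idiom more directly. A minimal sketch, assuming a `ShmemUint64WaitUntilEquals` instantiation exists following the `DEFINE_SHMEM_TYPE_WAIT_UNTIL_*` naming pattern and the example's using-directives; the kernel name is hypothetical:

    // Hedged sketch: the same wait as in the example kernels, written against
    // the library's wait-until helper instead of a hand-rolled poll loop.
    __global__ void WaitSignalKernel(const SymmMemObjPtr signalObj, uint64_t expected) {
      if (blockIdx.x == 0 && threadIdx.x == 0) {
        uint64_t* signalPtr = reinterpret_cast<uint64_t*>(signalObj->localPtr);
        // Spins on a relaxed system-scope atomic load until the value matches.
        ShmemUint64WaitUntilEquals(signalPtr, expected);
      }
    }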
diff --git a/include/mori/core/utils.hpp b/include/mori/core/utils.hpp
index dbf2077..e3cfe49 100644
--- a/include/mori/core/utils.hpp
+++ b/include/mori/core/utils.hpp
@@ -151,6 +151,26 @@ inline __device__ void AtomicStoreSeqCstSystem(T* ptr, T val) {
   return __hip_atomic_store(ptr, val, __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_SYSTEM);
 }
 
+template <typename T>
+inline __device__ T AtomicAddSeqCst(T* ptr, T val) {
+  return __hip_atomic_fetch_add(ptr, val, __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_AGENT);
+}
+
+template <typename T>
+inline __device__ T AtomicAddSeqCstSystem(T* ptr, T val) {
+  return __hip_atomic_fetch_add(ptr, val, __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_SYSTEM);
+}
+
+template <typename T>
+inline __device__ T AtomicAddRelaxed(T* ptr, T val) {
+  return __hip_atomic_fetch_add(ptr, val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
+}
+
+template <typename T>
+inline __device__ T AtomicAddRelaxedSystem(T* ptr, T val) {
+  return __hip_atomic_fetch_add(ptr, val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
+}
+
 template <typename T>
 __device__ T AtomicCompareExchange(T* address, T* compare, T val) {
   __hip_atomic_compare_exchange_strong(address, compare, val, __ATOMIC_RELAXED, __ATOMIC_RELAXED,
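Editor's note: the four fetch-add wrappers above mirror the existing atomic load/store helpers. As a rough guide (my reading, not part of the patch): the agent-scope relaxed variant suits counters private to one GPU, while the system-scope seq_cst variant is what the P2P signal path below relies on so the signal update is observed atomically by the peer. An illustrative sketch, with both buffer names hypothetical:

    // Illustrative only: 'localCounter' is agent-local memory, 'peerSignal' is
    // peer-visible (e.g. symmetric heap) memory.
    __device__ void BumpCounters(uint64_t* localCounter, uint64_t* peerSignal) {
      mori::core::AtomicAddRelaxed(localCounter, uint64_t{1});     // agent-local bookkeeping
      mori::core::AtomicAddSeqCstSystem(peerSignal, uint64_t{1});  // peer-visible signal bump
    }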
diff --git a/include/mori/shmem/shmem_device_api.hpp b/include/mori/shmem/shmem_device_api.hpp
index 082e644..242857d 100644
--- a/include/mori/shmem/shmem_device_api.hpp
+++ b/include/mori/shmem/shmem_device_api.hpp
@@ -45,6 +45,17 @@ namespace shmem {
     assert(false); \
   }
 
+#define DISPATCH_TRANSPORT_TYPE_WITH_BOOL(func, boolParam, pe, ...) \
+  GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); \
+  application::TransportType transportType = globalGpuStates->transportTypes[pe]; \
+  if (transportType == application::TransportType::RDMA) { \
+    func<application::TransportType::RDMA, boolParam>(__VA_ARGS__); \
+  } else if (transportType == application::TransportType::P2P) { \
+    func<application::TransportType::P2P, boolParam>(__VA_ARGS__); \
+  } else { \
+    assert(false); \
+  }
+
 #define DISPATCH_TRANSPORT_DATA_TYPE_WITH_RETURN(func, pe, type, ...) \
   [&]() { \
     GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); \
@@ -77,6 +88,10 @@ inline __device__ void ShmemQuietThread(int pe, int qpId) {
 /* ---------------------------------------------------------------------------------------------- */
 /* Point-to-Point */
 /* ---------------------------------------------------------------------------------------------- */
+
+/* ---------------------------------------------------------------------------------------------- */
+/* PutNbi APIs */
+/* ---------------------------------------------------------------------------------------------- */
 #define DEFINE_SHMEM_PUT_MEM_NBI_API_TEMPLATE(Scope) \
   inline __device__ void ShmemPutMemNbi##Scope( \
       const application::SymmMemObjPtr dest, size_t destOffset, \
@@ -156,6 +171,9 @@ DEFINE_SHMEM_PUT_TYPE_NBI_API(Int64, int64_t, Warp)
 DEFINE_SHMEM_PUT_TYPE_NBI_API(Float, float, Warp)
 DEFINE_SHMEM_PUT_TYPE_NBI_API(Double, double, Warp)
 
+/* ---------------------------------------------------------------------------------------------- */
+/* PutNbi Inline APIs */
+/* ---------------------------------------------------------------------------------------------- */
 // TODO: deal with bytes count limit
 #define SHMEM_PUT_SIZE_IMM_NBI_API(Scope) \
   inline __device__ void ShmemPutSizeImmNbi##Scope(const application::SymmMemObjPtr dest, \
                                                    size_t destOffset, void* val, size_t bytes, \
@@ -204,6 +222,109 @@ DEFINE_SHMEM_PUT_TYPE_IMM_NBI_API(Int32, int32_t, Warp)
 DEFINE_SHMEM_PUT_TYPE_IMM_NBI_API(Uint64, uint64_t, Warp)
 DEFINE_SHMEM_PUT_TYPE_IMM_NBI_API(Int64, int64_t, Warp)
 
+/* ---------------------------------------------------------------------------------------------- */
+/* PutNbi with Signal APIs */
+/* ---------------------------------------------------------------------------------------------- */
+// PutNbi with Signal - Memory version
+#define DEFINE_SHMEM_PUT_MEM_NBI_SIGNAL_API_TEMPLATE(Scope) \
+  template <bool onlyOneSignal> \
+  inline __device__ void ShmemPutMemNbiSignal##Scope( \
+      const application::SymmMemObjPtr dest, size_t destOffset, \
+      const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, \
+      const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue, \
+      core::atomicType signalOp, int pe, int qpId = 0) { \
+    DISPATCH_TRANSPORT_TYPE_WITH_BOOL(ShmemPutMemNbiSignal##Scope##Kernel, onlyOneSignal, pe, \
+                                      dest, destOffset, source, sourceOffset, bytes, signalDest, \
+                                      signalDestOffset, signalValue, signalOp, pe, qpId); \
+  } \
+  template <bool onlyOneSignal> \
+  inline __device__ void ShmemPutMemNbiSignal##Scope( \
+      const application::SymmMemObjPtr dest, size_t destOffset, \
+      const application::SymmMemObjPtr source, size_t sourceOffset, size_t bytes, \
+      const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue, \
+      core::atomicType signalOp, int pe, int qpId = 0) { \
+    int rank = GetGlobalGpuStatesPtr()->rank; \
+    ShmemPutMemNbiSignal##Scope<onlyOneSignal>( \
+        dest, destOffset, source->GetRdmaMemoryRegion(rank), sourceOffset, bytes, signalDest, \
+        signalDestOffset, signalValue, signalOp, pe, qpId); \
+  }
+
+DEFINE_SHMEM_PUT_MEM_NBI_SIGNAL_API_TEMPLATE(Thread)
+DEFINE_SHMEM_PUT_MEM_NBI_SIGNAL_API_TEMPLATE(Warp)
+
+// PutNbi with Signal - Typed version
+#define DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API_TEMPLATE(Scope) \
+  template <typename T, bool onlyOneSignal> \
+  inline __device__ void ShmemPutTypeNbiSignal##Scope( \
+      const application::SymmMemObjPtr dest, size_t destElmOffset, \
+      const application::RdmaMemoryRegion& source, size_t srcElmOffset, size_t nelems, \
+      const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue, \
+      core::atomicType signalOp, int pe, int qpId = 0) { \
+    constexpr size_t typeSize = sizeof(T); \
+    ShmemPutMemNbiSignal##Scope<onlyOneSignal>( \
+        dest, destElmOffset * typeSize, source, srcElmOffset * typeSize, nelems * typeSize, \
+        signalDest, signalDestOffset, signalValue, signalOp, pe, qpId); \
+  } \
+  template <typename T, bool onlyOneSignal> \
+  inline __device__ void ShmemPutTypeNbiSignal##Scope( \
+      const application::SymmMemObjPtr dest, size_t destElmOffset, \
+      const application::SymmMemObjPtr source, size_t srcElmOffset, size_t nelems, \
+      const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue, \
+      core::atomicType signalOp, int pe, int qpId = 0) { \
+    int rank = GetGlobalGpuStatesPtr()->rank; \
+    ShmemPutTypeNbiSignal##Scope<T, onlyOneSignal>( \
+        dest, destElmOffset, source->GetRdmaMemoryRegion(rank), srcElmOffset, nelems, signalDest, \
+        signalDestOffset, signalValue, signalOp, pe, qpId); \
+  }
+
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API_TEMPLATE(Thread)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API_TEMPLATE(Warp)
+
+// PutNbi with Signal - Concrete typed versions
+#define DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(TypeName, T, Scope) \
+  template <bool onlyOneSignal> \
+  inline __device__ void ShmemPut##TypeName##NbiSignal##Scope( \
+      const application::SymmMemObjPtr dest, size_t destElmOffset, \
+      const application::RdmaMemoryRegion& source, size_t srcElmOffset, size_t nelems, \
+      const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue, \
+      core::atomicType signalOp, int pe, int qpId = 0) { \
+    ShmemPutTypeNbiSignal##Scope<T, onlyOneSignal>(dest, destElmOffset, source, srcElmOffset, \
+                                                   nelems, signalDest, signalDestOffset, \
+                                                   signalValue, signalOp, pe, qpId); \
+  } \
+  template <bool onlyOneSignal> \
+  inline __device__ void ShmemPut##TypeName##NbiSignal##Scope( \
+      const application::SymmMemObjPtr dest, size_t destElmOffset, \
+      const application::SymmMemObjPtr source, size_t srcElmOffset, size_t nelems, \
+      const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue, \
+      core::atomicType signalOp, int pe, int qpId = 0) { \
+    ShmemPutTypeNbiSignal##Scope<T, onlyOneSignal>(dest, destElmOffset, source, srcElmOffset, \
+                                                   nelems, signalDest, signalDestOffset, \
+                                                   signalValue, signalOp, pe, qpId); \
+  }
+
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Uint8, uint8_t, Thread)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Int8, int8_t, Thread)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Uint16, uint16_t, Thread)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Int16, int16_t, Thread)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Uint32, uint32_t, Thread)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Int32, int32_t, Thread)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Uint64, uint64_t, Thread)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Int64, int64_t, Thread)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Float, float, Thread)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Double, double, Thread)
+
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Uint8, uint8_t, Warp)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Int8, int8_t, Warp)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Uint16, uint16_t, Warp)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Int16, int16_t, Warp)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Uint32, uint32_t, Warp)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Int32, int32_t, Warp)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Uint64, uint64_t, Warp)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Int64, int64_t, Warp)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Float, float, Warp)
+DEFINE_SHMEM_PUT_TYPE_NBI_SIGNAL_API(Double, double, Warp)
+
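Editor's note: taken together, the macros above generate thread- and warp-scope entry points such as `ShmemPutUint32NbiSignalThread`. A minimal usage sketch, written as if inside the mori::shmem namespace; buffer names and offsets are illustrative:

    // Put 'nelems' uint32 elements to 'peerPe' and raise its signal word once
    // per active warp (onlyOneSignal=true), then drain outstanding operations.
    __device__ void PutThenSignal(const application::SymmMemObjPtr data,
                                  const application::SymmMemObjPtr signal,
                                  size_t nelems, int peerPe) {
      ShmemPutUint32NbiSignalThread<true>(data, /*destElmOffset=*/0, data, /*srcElmOffset=*/0,
                                          nelems, signal, /*signalDestOffset=*/0,
                                          /*signalValue=*/1, core::atomicType::AMO_ADD, peerPe);
      ShmemQuietThread();
    }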
 #define SHMEM_ATOMIC_SIZE_NONFETCH_API_TEMPLATE(Scope) \
   inline __device__ void ShmemAtomicSizeNonFetch##Scope( \
       const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, \
@@ -243,6 +364,9 @@ DEFINE_SHMEM_ATOMIC_TYPE_NONFETCH_API(Uint64, uint64_t, Warp)
 DEFINE_SHMEM_ATOMIC_TYPE_NONFETCH_API(Int32, int32_t, Warp)
 DEFINE_SHMEM_ATOMIC_TYPE_NONFETCH_API(Int64, int64_t, Warp)
 
+/* ---------------------------------------------------------------------------------------------- */
+/* Atomic Fetch APIs */
+/* ---------------------------------------------------------------------------------------------- */
 #define SHMEM_ATOMIC_TYPE_FETCH_API_TEMPLATE(Scope) \
   template <typename T> \
   inline __device__ T ShmemAtomicTypeFetch##Scope( \
@@ -274,6 +398,9 @@ DEFINE_SHMEM_ATOMIC_TYPE_FETCH_API(Uint64, uint64_t, Warp)
 DEFINE_SHMEM_ATOMIC_TYPE_FETCH_API(Int32, int32_t, Warp)
 DEFINE_SHMEM_ATOMIC_TYPE_FETCH_API(Int64, int64_t, Warp)
 
+/* ---------------------------------------------------------------------------------------------- */
+/* Wait Until Greater Than APIs */
+/* ---------------------------------------------------------------------------------------------- */
 template <typename T>
 inline __device__ T ShmemTypeWaitUntilGreaterThan(T* addr, T val) {
   T got;
@@ -297,6 +424,9 @@ DEFINE_SHMEM_TYPE_WAIT_UNTIL_GREATER_THAN(Int32, int32_t)
 DEFINE_SHMEM_TYPE_WAIT_UNTIL_GREATER_THAN(Uint64, uint64_t)
 DEFINE_SHMEM_TYPE_WAIT_UNTIL_GREATER_THAN(Int64, int64_t)
 
+/* ---------------------------------------------------------------------------------------------- */
+/* Wait Until Equal APIs */
+/* ---------------------------------------------------------------------------------------------- */
 template <typename T>
 inline __device__ void ShmemTypeWaitUntilEquals(T* addr, T val) {
   while (core::AtomicLoadRelaxedSystem(addr) != val) {
diff --git a/include/mori/shmem/shmem_device_kernels.hpp b/include/mori/shmem/shmem_device_kernels.hpp
index 1a118a8..ab8ce51 100644
--- a/include/mori/shmem/shmem_device_kernels.hpp
+++ b/include/mori/shmem/shmem_device_kernels.hpp
@@ -53,6 +53,20 @@ inline __device__ void ShmemPutSizeImmNbiWarpKernel(const application::SymmMemObjPtr dest,
                                                     size_t destOffset, void* val, size_t bytes,
                                                     int pe, int qpId = 0);
 
+template <application::TransportType TsptType, bool onlyOneSignal>
+inline __device__ void ShmemPutMemNbiSignalThreadKernel(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId = 0);
+
+template <application::TransportType TsptType, bool onlyOneSignal>
+inline __device__ void ShmemPutMemNbiSignalWarpKernel(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId = 0);
+
 template <application::TransportType TsptType>
 inline __device__ void ShmemAtomicSizeNonFetchThreadKernel(const application::SymmMemObjPtr dest,
                                                            size_t destOffset, void* val,
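Editor's note: the Thread/Warp split mirrors the existing put kernels — the thread variant serializes independent per-lane puts, while the warp variant assumes the whole warp cooperates on a single transfer. A hedged calling sketch for the warp-scope API generated earlier (identifiers other than the API names are illustrative; written as if inside the mori::shmem namespace):

    // All lanes of the warp enter together; the payload copy is cooperative
    // (WarpCopy on P2P; on RDMA lane 0 posts the WQEs), and with
    // onlyOneSignal=true exactly one signal is emitted for the whole warp.
    __device__ void WarpPutThenSignal(const application::SymmMemObjPtr data,
                                      const application::SymmMemObjPtr signal,
                                      size_t bytes, int peerPe) {
      ShmemPutMemNbiSignalWarp<true>(data, /*destOffset=*/0, data, /*sourceOffset=*/0, bytes,
                                     signal, /*signalDestOffset=*/0, /*signalValue=*/1,
                                     core::atomicType::AMO_ADD, peerPe);
    }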
diff --git a/include/mori/shmem/shmem_ibgda_kernels.hpp b/include/mori/shmem/shmem_ibgda_kernels.hpp
index 118e6c9..dea6d7e 100644
--- a/include/mori/shmem/shmem_ibgda_kernels.hpp
+++ b/include/mori/shmem/shmem_ibgda_kernels.hpp
@@ -80,6 +80,15 @@ namespace shmem {
   } \
   }()
 
+#define DISPATCH_PROVIDER_TYPE_COMPILE_TIME_WITH_BOOL(func, boolParam, ...) \
+  do { \
+    if constexpr (DISPATCH_BNXT == 1) { \
+      func<core::ProviderType::BNXT, boolParam>(__VA_ARGS__); \
+    } else { \
+      func<core::ProviderType::MLX5, boolParam>(__VA_ARGS__); \
+    } \
+  } while (0)
+
 /* ---------------------------------------------------------------------------------------------- */
 /* Synchronization */
 /* ---------------------------------------------------------------------------------------------- */
@@ -245,7 +254,7 @@ inline __device__ void ShmemQuietThreadKernel(
 }
 
 template <>
-inline __device__ void ShmemQuietThreadKernel<application::TransportType::RDMA>(int pe){
+inline __device__ void ShmemQuietThreadKernel<application::TransportType::RDMA>(int pe) {
   GpuStates* globalGpuStates = GetGlobalGpuStatesPtr();
   int rank = globalGpuStates->rank;
   if (pe == rank) return;
@@ -542,6 +551,234 @@ inline __device__ void ShmemPutSizeImmNbiWarpKernel
 
+template <core::ProviderType PrvdType, bool onlyOneSignal>
+inline __device__ void ShmemPutMemNbiSignalThreadKernelImpl(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId) {
+  if (bytes == 0) return;
+  uintptr_t laddr = source.addr + sourceOffset;
+  uintptr_t raddr = dest->peerPtrs[pe] + destOffset;
+  uintptr_t rkey = dest->peerRkeys[pe];
+
+  GpuStates* globalGpuStates = GetGlobalGpuStatesPtr();
+  application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints;
+  int epIndex = pe * globalGpuStates->numQpPerPe + (qpId % globalGpuStates->numQpPerPe);
+  core::WorkQueueHandle* wq = &ep[epIndex].wqHandle;
+  core::CompletionQueueHandle* cq = &ep[epIndex].cqHandle;
+  uint32_t qpn = ep[epIndex].handle.qpn;
+
+  uint64_t activemask = core::GetActiveLaneMask();
+  uint8_t num_active_lanes = core::GetActiveLaneCount(activemask);
+  uint8_t my_logical_lane_id = core::GetActiveLaneNum(activemask);
+  bool is_leader{my_logical_lane_id == num_active_lanes - 1};
+  const uint64_t leader_phys_lane_id = core::GetLastActiveLaneID(activemask);
+  uint32_t warp_sq_counter{0};
+  uint32_t warp_msntbl_counter{0}, warp_psn_counter{0};
+  uint32_t my_sq_counter{0}, my_msntbl_counter{0}, my_psn_counter{0};
+  uint32_t psnCnt = 0;
+  uint32_t num_wqes = onlyOneSignal ? num_active_lanes + 1 : num_active_lanes * 2;
+
+  if constexpr (PrvdType == core::ProviderType::BNXT) {
+    psnCnt = (bytes + wq->mtuSize - 1) / wq->mtuSize;
+  }
+  if (is_leader) {
+    if constexpr (PrvdType == core::ProviderType::MLX5) {
+      warp_sq_counter = __hip_atomic_fetch_add(&wq->postIdx, num_wqes, __ATOMIC_RELAXED,
+                                               __HIP_MEMORY_SCOPE_AGENT);
+    } else if constexpr (PrvdType == core::ProviderType::BNXT) {
+      core::atomic_add_packed_msn_and_psn(
+          &wq->msnPack, num_wqes,
+          psnCnt * num_active_lanes + (onlyOneSignal ? 1 : num_active_lanes), &warp_msntbl_counter,
+          &warp_psn_counter);
+      warp_sq_counter = warp_msntbl_counter;
+      __hip_atomic_fetch_max(&wq->postIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED,
+                             __HIP_MEMORY_SCOPE_AGENT);
+    } else {
+      assert(false);
+    }
+  }
+  warp_sq_counter = __shfl(warp_sq_counter, leader_phys_lane_id);
+  if constexpr (PrvdType == core::ProviderType::MLX5) {
+    my_sq_counter = warp_sq_counter + (onlyOneSignal ? my_logical_lane_id : my_logical_lane_id * 2);
+  } else if constexpr (PrvdType == core::ProviderType::BNXT) {
+    warp_msntbl_counter = __shfl(warp_msntbl_counter, leader_phys_lane_id);
+    warp_psn_counter = __shfl(warp_psn_counter, leader_phys_lane_id);
+    my_sq_counter = warp_sq_counter + (onlyOneSignal ? my_logical_lane_id : my_logical_lane_id * 2);
+    my_msntbl_counter =
+        warp_msntbl_counter + (onlyOneSignal ? my_logical_lane_id : my_logical_lane_id * 2);
+    my_psn_counter = warp_psn_counter + (onlyOneSignal ? psnCnt * my_logical_lane_id
+                                                       : (psnCnt + 1) * my_logical_lane_id);
+  } else {
+    assert(false);
+  }
+
+  while (true) {
+    uint64_t db_touched =
+        __hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
+    uint64_t db_done = __hip_atomic_load(&wq->doneIdx, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
+    uint64_t num_active_sq_entries = db_touched - db_done;
+    uint64_t num_free_entries = wq->sqWqeNum - num_active_sq_entries;
+    uint64_t num_entries_until_warp_last_entry = warp_sq_counter + num_wqes - db_touched;
+    if (num_free_entries > num_entries_until_warp_last_entry) {
+      break;
+    }
+    ShmemQuietThreadKernelImpl<PrvdType>(pe, qpId);
+  }
+  // putmem nbi
+  if constexpr (PrvdType == core::ProviderType::MLX5) {
+    wq->outstandingWqe[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter;
+    core::PostWrite<PrvdType>(*wq, my_sq_counter, my_sq_counter, my_sq_counter, false, qpn, laddr,
+                              source.lkey, raddr, rkey, bytes);
+  } else if constexpr (PrvdType == core::ProviderType::BNXT) {
+    wq->outstandingWqe[my_sq_counter % wq->sqWqeNum] = my_sq_counter;
+    core::PostWrite<PrvdType>(*wq, my_sq_counter, my_msntbl_counter, my_psn_counter, false, qpn,
+                              laddr, source.lkey, raddr, rkey, bytes);
+  } else {
+    assert(false);
+  }
+
+  // signal
+  uint64_t dbr_val;
+  uintptr_t signalRaddr = signalDest->peerPtrs[pe] + signalDestOffset;
+  uintptr_t signalRkey = signalDest->peerRkeys[pe];
+  if (signalOp == core::atomicType::AMO_SET || signalOp == core::atomicType::AMO_SIGNAL_SET) {
+    // TODO: not support masked atomic yet, use write inline for now
+    bool should_signal = onlyOneSignal ? is_leader : true;
+    if (should_signal) {
+      if constexpr (PrvdType == core::ProviderType::MLX5) {
+        wq->outstandingWqe[(my_sq_counter + 1) % OUTSTANDING_TABLE_SIZE] = my_sq_counter + 1;
+        dbr_val = core::PostWriteInline<PrvdType>(*wq, my_sq_counter + 1, my_sq_counter + 1,
+                                                  my_sq_counter + 1, is_leader, qpn, &signalValue,
+                                                  signalRaddr, signalRkey, sizeof(signalValue));
+      } else if constexpr (PrvdType == core::ProviderType::BNXT) {
+        wq->outstandingWqe[(my_sq_counter + 1) % wq->sqWqeNum] = my_sq_counter + 1;
+        dbr_val = core::PostWriteInline<PrvdType>(*wq, my_sq_counter + 1, my_msntbl_counter + 1,
+                                                  my_psn_counter + 1, is_leader, qpn, &signalValue,
+                                                  signalRaddr, signalRkey, sizeof(signalValue));
+      }
+    }
+
+  } else if (signalOp == core::atomicType::AMO_ADD ||
+             signalOp == core::atomicType::AMO_SIGNAL_ADD) {
+    core::IbufHandle* ibuf = &ep[epIndex].atomicIbuf;
+    bool should_signal = onlyOneSignal ? is_leader : true;
+    if (should_signal) {
+      if constexpr (PrvdType == core::ProviderType::MLX5) {
+        wq->outstandingWqe[(my_sq_counter + 1) % OUTSTANDING_TABLE_SIZE] = my_sq_counter + 1;
+        dbr_val = core::PostAtomic<PrvdType>(
+            *wq, my_sq_counter + 1, my_sq_counter + 1, my_sq_counter + 1, is_leader, qpn,
+            ibuf->addr, ibuf->lkey, signalRaddr, signalRkey, &signalValue, &signalValue,
+            sizeof(signalValue), core::atomicType::AMO_ADD);
+      } else if constexpr (PrvdType == core::ProviderType::BNXT) {
+        wq->outstandingWqe[(my_sq_counter + 1) % wq->sqWqeNum] = my_sq_counter + 1;
+        dbr_val = core::PostAtomic<PrvdType>(
+            *wq, my_sq_counter + 1, my_msntbl_counter + 1, my_psn_counter + 1, is_leader, qpn,
+            ibuf->addr, ibuf->lkey, signalRaddr, signalRkey, &signalValue, &signalValue,
+            sizeof(signalValue), core::atomicType::AMO_ADD);
+      }
+    }
+  } else {
+    assert(false && "signal unsupported atomic type");
+  }
+
+  // __threadfence_system();
+  if (is_leader) {
+    uint64_t db_touched{0};
+    do {
+      db_touched = __hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
+    } while (db_touched != warp_sq_counter);
+
+    core::UpdateSendDbrRecord<PrvdType>(wq->dbrRecAddr, warp_sq_counter + num_wqes);
+    // __threadfence_system();
+    core::RingDoorbell<PrvdType>(wq->dbrAddr, dbr_val);
+    // __threadfence_system();
+
+    __hip_atomic_fetch_add(&cq->needConsIdx, 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
+    __hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED,
+                       __HIP_MEMORY_SCOPE_AGENT);
+  }
+}
+
+template <>
+inline __device__ void ShmemPutMemNbiSignalThreadKernel<application::TransportType::RDMA, true>(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId) {
+  bool need_turn{true};
+  uint64_t turns = __ballot(need_turn);
+  while (turns) {
+    uint8_t lane = __ffsll((unsigned long long)turns) - 1;
+    int pe_turn = __shfl(pe, lane);
+    if (pe_turn == pe) {
+      DISPATCH_PROVIDER_TYPE_COMPILE_TIME_WITH_BOOL(
+          ShmemPutMemNbiSignalThreadKernelImpl, true, dest, destOffset, source, sourceOffset, bytes,
+          signalDest, signalDestOffset, signalValue, signalOp, pe, qpId);
+      need_turn = false;
+    }
+    turns = __ballot(need_turn);
+  }
+}
+
+template <>
+inline __device__ void ShmemPutMemNbiSignalThreadKernel<application::TransportType::RDMA, false>(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId) {
+  bool need_turn{true};
+  uint64_t turns = __ballot(need_turn);
+  while (turns) {
+    uint8_t lane = __ffsll((unsigned long long)turns) - 1;
+    int pe_turn = __shfl(pe, lane);
+    if (pe_turn == pe) {
+      DISPATCH_PROVIDER_TYPE_COMPILE_TIME_WITH_BOOL(
+          ShmemPutMemNbiSignalThreadKernelImpl, false, dest, destOffset, source, sourceOffset,
+          bytes, signalDest, signalDestOffset, signalValue, signalOp, pe, qpId);
+      need_turn = false;
+    }
+    turns = __ballot(need_turn);
+  }
+}
+
+template <core::ProviderType PrvdType, bool onlyOneSignal>
+inline __device__ void ShmemPutMemNbiSignalWarpKernelImpl(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId) {
+  int laneId = threadIdx.x & (warpSize - 1);
+  if (laneId == 0) {
+    ShmemPutMemNbiSignalThreadKernelImpl<PrvdType, onlyOneSignal>(
+        dest, destOffset, source, sourceOffset, bytes, signalDest, signalDestOffset, signalValue,
+        signalOp, pe, qpId);
+  }
+}
+
+template <>
+inline __device__ void ShmemPutMemNbiSignalWarpKernel<application::TransportType::RDMA, true>(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId) {
+  DISPATCH_PROVIDER_TYPE_COMPILE_TIME_WITH_BOOL(ShmemPutMemNbiSignalWarpKernelImpl, true, dest,
+                                                destOffset, source, sourceOffset, bytes, signalDest,
+                                                signalDestOffset, signalValue, signalOp, pe, qpId);
+}
+
+template <>
+inline __device__ void ShmemPutMemNbiSignalWarpKernel<application::TransportType::RDMA, false>(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId) {
+  DISPATCH_PROVIDER_TYPE_COMPILE_TIME_WITH_BOOL(ShmemPutMemNbiSignalWarpKernelImpl, false, dest,
+                                                destOffset, source, sourceOffset, bytes, signalDest,
+                                                signalDestOffset, signalValue, signalOp, pe, qpId);
+}
+
 template <core::ProviderType PrvdType>
 inline __device__ void ShmemAtomicSizeNonFetchThreadKernelImpl(
     const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes,
diff --git a/include/mori/shmem/shmem_p2p_kernels.hpp b/include/mori/shmem/shmem_p2p_kernels.hpp
index 37f3a06..3692fed 100644
--- a/include/mori/shmem/shmem_p2p_kernels.hpp
+++ b/include/mori/shmem/shmem_p2p_kernels.hpp
@@ -101,6 +101,127 @@ inline __device__ void ShmemPutSizeImmNbiWarpKernel
 
+template <>
+inline __device__ void ShmemPutMemNbiSignalThreadKernel<application::TransportType::P2P, true>(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId) {
+  if (bytes == 0) return;
+
+  // Execute put operation
+  uint8_t* srcPtr = reinterpret_cast<uint8_t*>(source.addr + sourceOffset);
+  uint8_t* destPtr = reinterpret_cast<uint8_t*>(dest->peerPtrs[pe] + destOffset);
+  core::ThreadCopy(destPtr, srcPtr, bytes);
+
+  // Execute signal operation (only once for onlyOneSignal=true)
+  uint64_t activemask = core::GetActiveLaneMask();
+  uint8_t num_active_lanes = core::GetActiveLaneCount(activemask);
+  uint8_t my_logical_lane_id = core::GetActiveLaneNum(activemask);
+  bool is_leader = (my_logical_lane_id == num_active_lanes - 1);
+
+  if (is_leader) {
+    uint64_t* signalPtr = reinterpret_cast<uint64_t*>(signalDest->peerPtrs[pe] + signalDestOffset);
+    if (signalOp == core::atomicType::AMO_SET || signalOp == core::atomicType::AMO_SIGNAL_SET) {
+      core::AtomicStoreSeqCstSystem(signalPtr, signalValue);
+    } else if (signalOp == core::atomicType::AMO_ADD ||
+               signalOp == core::atomicType::AMO_SIGNAL_ADD) {
+      core::AtomicAddSeqCstSystem(signalPtr, signalValue);
+    } else {
+      assert(false && "Unsupported signal operation");
+    }
+  }
+}
+
+template <>
+inline __device__ void ShmemPutMemNbiSignalThreadKernel<application::TransportType::P2P, false>(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId) {
+  if (bytes == 0) return;
+
+  // Execute put operation
+  uint8_t* srcPtr = reinterpret_cast<uint8_t*>(source.addr + sourceOffset);
+  uint8_t* destPtr = reinterpret_cast<uint8_t*>(dest->peerPtrs[pe] + destOffset);
+  core::ThreadCopy(destPtr, srcPtr, bytes);
+
+  // Execute signal operation (every thread signals for onlyOneSignal=false)
+  uint64_t* signalPtr = reinterpret_cast<uint64_t*>(signalDest->peerPtrs[pe] + signalDestOffset);
+  if (signalOp == core::atomicType::AMO_SET || signalOp == core::atomicType::AMO_SIGNAL_SET) {
+    core::AtomicStoreSeqCstSystem(signalPtr, signalValue);
+  } else if (signalOp == core::atomicType::AMO_ADD ||
+             signalOp == core::atomicType::AMO_SIGNAL_ADD) {
+    core::AtomicAddSeqCstSystem(signalPtr, signalValue);
+  } else {
+    assert(false && "Unsupported signal operation");
+  }
+}
+
+template <>
+inline __device__ void ShmemPutMemNbiSignalWarpKernel<application::TransportType::P2P, true>(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId) {
+  if (bytes == 0) return;
+
+  // Execute put operation (all lanes participate)
+  uint8_t* srcPtr = reinterpret_cast<uint8_t*>(source.addr + sourceOffset);
+  uint8_t* destPtr = reinterpret_cast<uint8_t*>(dest->peerPtrs[pe] + destOffset);
+  core::WarpCopy(destPtr, srcPtr, bytes);
+
+  // Execute signal operation (only lane 0 for onlyOneSignal=true)
+  int laneId = threadIdx.x & (warpSize - 1);
+  if (laneId == 0) {
+    uint64_t* signalPtr = reinterpret_cast<uint64_t*>(signalDest->peerPtrs[pe] + signalDestOffset);
+    if (signalOp == core::atomicType::AMO_SET || signalOp == core::atomicType::AMO_SIGNAL_SET) {
+      core::AtomicStoreSeqCstSystem(signalPtr, signalValue);
+    } else if (signalOp == core::atomicType::AMO_ADD ||
+               signalOp == core::atomicType::AMO_SIGNAL_ADD) {
+      core::AtomicAddSeqCstSystem(signalPtr, signalValue);
+    } else {
+      assert(false && "Unsupported signal operation");
+    }
+  }
+}
+
+template <>
+inline __device__ void ShmemPutMemNbiSignalWarpKernel<application::TransportType::P2P, false>(
+    const application::SymmMemObjPtr dest, size_t destOffset,
+    const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes,
+    const application::SymmMemObjPtr signalDest, size_t signalDestOffset, uint64_t signalValue,
+    core::atomicType signalOp, int pe, int qpId) {
+  if (bytes == 0) return;
+
+  // Execute put operation (all lanes participate)
+  uint8_t* srcPtr = reinterpret_cast<uint8_t*>(source.addr + sourceOffset);
+  uint8_t* destPtr = reinterpret_cast<uint8_t*>(dest->peerPtrs[pe] + destOffset);
+  core::WarpCopy(destPtr, srcPtr, bytes);
+
+  // Execute signal operation (lane 0 only, but signals once per active lane)
+  int laneId = threadIdx.x & (warpSize - 1);
+  if (laneId == 0) {
+    uint64_t activemask = core::GetActiveLaneMask();
+    uint8_t num_active_lanes = core::GetActiveLaneCount(activemask);
+
+    uint64_t* signalPtr = reinterpret_cast<uint64_t*>(signalDest->peerPtrs[pe] + signalDestOffset);
+    if (signalOp == core::atomicType::AMO_SET || signalOp == core::atomicType::AMO_SIGNAL_SET) {
+      core::AtomicStoreSeqCstSystem(signalPtr, signalValue);
+    } else if (signalOp == core::atomicType::AMO_ADD ||
+               signalOp == core::atomicType::AMO_SIGNAL_ADD) {
+      core::AtomicAddSeqCstSystem(signalPtr, signalValue * num_active_lanes);
+    } else {
+      assert(false && "Unsupported signal operation");
+    }
+  }
+}
+
 template <>
 inline __device__ void ShmemAtomicSizeNonFetchThreadKernel<application::TransportType::P2P>(
     const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes,
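Editor's note: one accounting subtlety holds across both transports — with `onlyOneSignal=true` the receiver should expect one signal per active warp, with `onlyOneSignal=false` one per active lane (the P2P warp variant folds the per-lane adds into a single `signalValue * num_active_lanes`). A host-side sketch of the expected AMO_ADD total, mirroring the example program's `expectedSignals` arithmetic; assumes full warps and signalValue == 1:

    uint64_t ExpectedSignalTotal(uint64_t blocks, uint64_t threadsPerBlock,
                                 uint64_t warpSize, bool onlyOneSignal) {
      uint64_t totalThreads = blocks * threadsPerBlock;
      return onlyOneSignal ? totalThreads / warpSize  // one signal per warp
                           : totalThreads;            // one signal per lane
    }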
diff --git a/src/application/transport/rdma/providers/bnxt/bnxt.cpp b/src/application/transport/rdma/providers/bnxt/bnxt.cpp
index b8e31ba..ec77eca 100644
--- a/src/application/transport/rdma/providers/bnxt/bnxt.cpp
+++ b/src/application/transport/rdma/providers/bnxt/bnxt.cpp
@@ -273,15 +273,6 @@ BnxtQpContainer::BnxtQpContainer(ibv_context* context, const RdmaEndpointConfig&
     memset(atomicIbufAddr, 0, atomicIbufSize);
     assert(!status);
   }
-  if (config.onGpu) {
-    HIP_RUNTIME_CHECK(
-        hipExtMallocWithFlags(&atomicIbufAddr, atomicIbufSize, hipDeviceMallocUncached));
-    HIP_RUNTIME_CHECK(hipMemset(atomicIbufAddr, 0, atomicIbufSize));
-  } else {
-    err = posix_memalign(&atomicIbufAddr, config.alignment, atomicIbufSize);
-    memset(atomicIbufAddr, 0, atomicIbufSize);
-    assert(!err);
-  }
 
   // Register atomic ibuf as independent memory region
   atomicIbufMr = ibv_reg_mr(pd, atomicIbufAddr, atomicIbufSize,