From 63b7683e7015cf2afcdbe91795605fe637fa0f99 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Mon, 17 Nov 2025 17:21:31 -0800 Subject: [PATCH 01/32] Staging initial integration of kernel code. --- csrc/trtllm_mnnvl_allreduce.cu | 117 +- .../comm/trtllm_mnnvl_allreduce.cuh | 1483 +++++++++++------ include/flashinfer/utils.cuh | 12 + 3 files changed, 1085 insertions(+), 527 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index 6bac5372a8..05a1684aa0 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -26,77 +26,84 @@ using tvm::ffi::Optional; } \ }() -void trtllm_mnnvl_all_reduce(TensorView in, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev, - int64_t buffer_M, TensorView buffer_flags_mnnvl, int64_t nranks, - int64_t rank, bool wait_for_results, bool launch_with_pdl, - Optional out) { - cudaSetDevice(in.device().device_id); - auto stream = get_stream(in.device()); +// FIXME: is bool flag for oneshot a good idea? Trying to avoid defining a new type/enum at this +// level +void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_ptr, + int64_t buffer_ptrs_dev, int64_t buffer_ptr_local, + TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank, + bool rmsnorm_fusion, bool launch_with_pdl, bool use_oneshot, + TensorView output, Optional residual_out, + Optional gamma, Optional epsilon) { + cudaSetDevice(input.device().device_id); + auto stream = get_stream(input.device()); - DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in.dtype(), c_type, [&] { + DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(input.dtype(), c_type, [&] { // Extract parameters from tensors - int64_t num_tokens = in.size(0); - int64_t token_dim = in.size(1); + int64_t num_tokens = input.size(0); + int64_t token_dim = input.size(1); // Validate input parameters - TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float2) / sizeof(c_type)), 0) - << "token_dim must be divisible by " << sizeof(float2) / sizeof(c_type); + TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float4) / sizeof(c_type)), 0) + << "token_dim must be divisible by " << sizeof(float4) / sizeof(c_type); + TVM_FFI_ICHECK(output.size(0) == input.size(0) && output.size(1) == input.size(1)) + << "output shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << output.size(0) << ", " << output.size(1) << ")"; TVM_FFI_ICHECK(nranks >= 2 && nranks <= 64) << "nranks must be between 2 and 64, got " << nranks; TVM_FFI_ICHECK(rank >= 0 && rank < nranks) << "rank must be between 0 and nranks-1, got " << rank; - TVM_FFI_ICHECK(out.has_value() || !wait_for_results) - << "out tensor must be provided if wait_for_results is true"; + TVM_FFI_ICHECK((residual_out.has_value() && gamma.has_value() && epsilon.has_value()) || + !rmsnorm_fusion) + << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true"; + + if (rmsnorm_fusion) { + TVM_FFI_ICHECK(residual_out.size(0) == num_tokens && residual_out.size(1) == token_dim) + << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << residual_out.size(0) << ", " << residual_out.size(1) << ")"; + TVM_FFI_ICHECK(gamma.size(0) == token_dim) + << "gamma must have the same shape as token dimension (" << token_dim << ") but got (" + << gamma.size(0) << ")"; + } // Create the parameters struct - AllReduceParams params; - params.nranks = nranks; - params.rank = rank; - params.buffer_M = buffer_M; - params.num_tokens = num_tokens; - params.token_dim = token_dim; - params.buffer_ptrs_dev = reinterpret_cast(buffer_ptrs_dev); - params.multicast_ptr = reinterpret_cast(multicast_buffer_ptr); - params.buffer_flags = buffer_flags_mnnvl.data_ptr(); - params.wait_for_results = wait_for_results; - params.launch_with_pdl = launch_with_pdl; - params.input = in.data_ptr(); - params.output = out.has_value() ? out.value().data_ptr() : nullptr; - params.stream = stream; + AllReduceFusionParams params; - auto status = twoshot_allreduce_dispatch_world_size(params); - TVM_FFI_ICHECK(status == cudaSuccess) - << "twoshot_allreduce_dispatch_world_size failed with error code " - << cudaGetErrorString(status); - }); -} + // Aux Information + params.nRanks = nranks; + params.rank = rank; + params.numTokens = num_tokens; + params.tokenDim = token_dim; + params.bufferPtrsDev = reinterpret_cast(buffer_ptrs_dev); + params.bufferPtrLocal = reinterpret_cast(buffer_ptr_local); + params.multicastPtr = reinterpret_cast(multicast_buffer_ptr); + params.bufferFlags = reinterpret_cast(buffer_flags_mnnvl.data_ptr()); + params.rmsNormFusion = rmsnorm_fusion; + params.launchWithPdl = launch_with_pdl; -void trtllm_mnnvl_rmsnorm(int64_t multicast_buffer_ptr, TensorView prenorm_output, - TensorView normed_output, TensorView gamma, double epsilon, - TensorView residual, TensorView buffer_flags, bool launch_with_pdl) { - cudaSetDevice(prenorm_output.device().device_id); - auto stream = get_stream(prenorm_output.device()); + // input data + params.input = const_cast(input.data_ptr()); + params.residualIn = residual_out.has_value() + ? const_cast(residual_out.value().data_ptr()) + : nullptr; + params.gamma = gamma.has_value() ? const_cast(gamma.value().data_ptr()) : nullptr; + params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5; - DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(prenorm_output.dtype(), c_type, [&] { - // Create the parameters struct - RMSNormParams params; - params.residual_output = prenorm_output.data_ptr(); - params.output = normed_output.data_ptr(); - params.input = reinterpret_cast(multicast_buffer_ptr); - params.gamma = gamma.data_ptr(); - params.epsilon = epsilon; - params.residual = residual.data_ptr(); - params.buffer_flags = reinterpret_cast(buffer_flags.data_ptr()); - params.batch = normed_output.size(0); - params.hidden_dim = normed_output.size(1); + // output data + params.output = const_cast(output.data_ptr()); + params.residualOut = + residual_out.has_value() ? const_cast(residual_out.value().data_ptr()) : nullptr; params.stream = stream; - params.launch_with_pdl = launch_with_pdl; - auto status = twoshot_rmsnorm_dispatch_hidden_dim(params); + + cudaError_t status; + if (use_oneshot) { + status = oneshotAllreduceFusionDispatch(params); + } else { + status = twoshotAllreduceFusionDispatch(params); + } TVM_FFI_ICHECK(status == cudaSuccess) - << "twoshot_rmsnorm_dispatch_hidden_dim failed with error code " + << "twoshot_allreduce_dispatch_world_size failed with error code " << cudaGetErrorString(status); }); } -TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_all_reduce, trtllm_mnnvl_all_reduce); -TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_rmsnorm, trtllm_mnnvl_rmsnorm); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_allreduce_fusion, trtllm_mnnvl_allreduce_fusion); diff --git a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh index 3dbed4b649..9198df8775 100644 --- a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh +++ b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh @@ -18,52 +18,54 @@ #include #include #include +#include #include +#include #include "../exception.h" #include "../logging.h" +#include "../utils.cuh" namespace flashinfer { namespace trtllm_mnnvl_allreduce { -template -struct AllReduceParams { - int nranks; +struct AllReduceFusionParams { + int nRanks; int rank; - int buffer_M; - int num_tokens; - int token_dim; - void** buffer_ptrs_dev; - void* multicast_ptr; - void* buffer_flags; - bool wait_for_results; - bool launch_with_pdl; - - void* input; - void* output; - cudaStream_t stream; -}; + int numTokens; + int tokenDim; + void** bufferPtrsDev; + void* bufferPtrLocal; + void* multicastPtr; + uint32_t* bufferFlags; + bool rmsNormFusion; + bool launchWithPdl; -template -struct RMSNormParams { - void* residual_output; - void* output; void const* input; + void const* residualIn; void const* gamma; double epsilon; - void* residual; - uint32_t* buffer_flags; - int batch; - int hidden_dim; - cudaStream_t stream; - bool launch_with_pdl; + + void* residualOut; + void* output; + cudaStream_t stream = nullptr; }; -__device__ bool isNegZero(float v) { return v == 0.f && signbit(v); } +namespace utils { + +constexpr uint16_t kNEGZERO_FP16 = 0x8000U; + +template +union Fp16BitCast { + T mFp; + uint16_t mInt; + + constexpr Fp16BitCast() : mInt(0) {} -__device__ bool isNegZero(__nv_bfloat16 val) { return isNegZero(__bfloat162float(val)); } + constexpr Fp16BitCast(T val) : mFp(val) {} -__device__ bool isNegZero(__nv_half val) { return isNegZero(__half2float(val)); } + constexpr Fp16BitCast(uint16_t val) : mInt(val) {} +}; template inline __device__ float toFloat(T val) { @@ -74,7 +76,6 @@ template <> inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val) { return __bfloat162float(val); } - template <> inline __device__ float toFloat<__nv_half>(__nv_half val) { return __half2float(val); @@ -95,581 +96,1119 @@ inline __device__ __nv_half fromFloat<__nv_half>(float val) { return __float2half(val); } -inline __device__ float2 loadfloat2(void const* ptr) { - float2 return_value; - asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" - : "=f"(return_value.x), "=f"(return_value.y) - : "l"(ptr)); - return return_value; +template +static constexpr __device__ __host__ T negZero() { + if constexpr (std::is_same_v) { + return -0.0F; + } else if constexpr (std::is_same_v || std::is_same_v) { + return Fp16BitCast(kNEGZERO_FP16).mFp; + } else { + static_assert(sizeof(T) == 0, "negativeZero not specialized for this type"); + } + return T{}; // Never reached, but needed for compilation } template -inline __device__ T divUp(T val, T divisor) { - return (val + divisor - 1) / divisor; +static inline __device__ bool isNegZero(T val) { + if constexpr (std::is_same_v) { + return val == 0.F && signbit(val); + } else if constexpr (std::is_same_v || std::is_same_v) { + return Fp16BitCast(val).mInt == kNEGZERO_FP16; + } else { + static_assert(sizeof(T) == 0, "isNegZero not specialized for this type"); + } + return false; // Never reached, but needed for compilation } +template +constexpr __device__ __host__ PackedType getPackedLamportInit() { + static_assert(sizeof(PackedType) % sizeof(T) == 0, "PackedType size must be divisible by T size"); + constexpr int kNumElements = sizeof(PackedType) / sizeof(T); + + union PackedT { + PackedType mPacked; + std::array mElements; + + constexpr PackedT() : mElements{} { + for (int i = 0; i < kNumElements; i++) { + mElements[i] = negZero(); + } + } + }; + + PackedT initValue{}; + return initValue.mPacked; +} + +// A helper class to get the correct base pointer for a given layout +struct LamportBufferLayout { + uint32_t numStages = 1; + uint32_t bytesPerBuffer = 0; + static constexpr uint32_t sNumLamportBuffers = 3; + + // Implicitly inlined + [[nodiscard]] __device__ __host__ size_t getTotalBytes() const { + return numStages * static_cast(bytesPerBuffer / numStages) * sNumLamportBuffers; + } + + // Implicitly inlined + [[nodiscard]] __device__ __host__ void* getStagePtr(void* bufferBasePtr, uint32_t lamportIndex, + uint32_t stageIndex) const { + // Typecast to avoid warnings + return reinterpret_cast( + reinterpret_cast(bufferBasePtr) + + static_cast((lamportIndex * numStages + stageIndex) * + static_cast(bytesPerBuffer / numStages))); + } +}; +// Current Index +// Dirty Index +// bytes_per_buffer +// Dirty num_stages +// Dirty bytes_to_clear = {stage0, stage1, stage2, stage3} # We fix this to 4 stages +// offset_access_ptr + +namespace cg = cooperative_groups; + +// PackedType is the one used in kernel for Lamport buffer (LDG.128 or LDG.64) +template __device__ struct __attribute__((aligned(32))) LamportFlags { - uint32_t buffer_size; - uint32_t input_offset; - uint32_t clear_offset; - uint32_t num_tokens_prev; - uint32_t* offset_access_ptr; - uint32_t* buffer_flags; - - __device__ explicit LamportFlags(uint32_t* buffer_flags) - : offset_access_ptr(&buffer_flags[4]), buffer_flags(buffer_flags) { - uint4 flag = reinterpret_cast(buffer_flags)[0]; - buffer_size = flag.z; - input_offset = flag.x * (buffer_size << 1U); - clear_offset = flag.y * (buffer_size << 1U); - num_tokens_prev = flag.w; - } - - __device__ void cta_arrive() { + public: + __device__ explicit LamportFlags(uint32_t* bufferFlags, uint32_t numStages = 1) + : mBufferFlagsPtr(bufferFlags), mFlagAccessPtr(&bufferFlags[8]) { + mCurBufferLayout.numStages = numStages; + uint4 flag = reinterpret_cast(bufferFlags)[0]; + mCurrentIndex = flag.x; + mDirtyIndex = flag.y; + // Buffer size is unchanged as the flag should be coupled to each buffer + mCurBufferLayout.bytesPerBuffer = flag.z; + mDirtyBufferLayout.bytesPerBuffer = flag.z; + mDirtyBufferLayout.numStages = flag.w; + *reinterpret_cast(&mBytesToClear) = reinterpret_cast(bufferFlags)[1]; + } + + // Return the base pointer of the lamport buffer indexed by mCurrentIndex and the stageIdx + [[nodiscard]] __device__ void* getCurLamportBuf(void* bufferBasePtr, int stageIdx = 0) const { + return mCurBufferLayout.getStagePtr(bufferBasePtr, mCurrentIndex, stageIdx); + } + + // Fill the dirty lamport buffer with the init value; Use stageIdx to select the stage to clear, + // -1 to clear all + // FIXME: Current kernel may use less stages than the dirty numStages; How to guarantee the + // correctness? CAUTION: This function requires all threads in the grid to participate and ASSUME + // 1D thread block layout! + __device__ void clearDirtyLamportBuf(void* bufferBasePtr, int stageIdx = -1) { + // Rasterize the threads to 1D for flexible clearing + + uint32_t globalCtaIdx = blockIdx.x * gridDim.y + blockIdx.y; + uint32_t globalTid = globalCtaIdx * blockDim.x + threadIdx.x; + uint32_t numThreads = gridDim.x * gridDim.y * blockDim.x; + + if (stageIdx == -1) { + // Clear all stages + for (uint32_t i = 0; i < mDirtyBufferLayout.numStages; i++) { + clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[i], mDirtyIndex, i); + } + } else if (stageIdx < mDirtyBufferLayout.numStages) { + clearPackedBuf(bufferBasePtr, globalTid, numThreads, mBytesToClear[stageIdx], mDirtyIndex, + stageIdx); + } + } + + __device__ void ctaArrive() { + int tid{0}; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + + cg::cluster_group cluster = cg::this_cluster(); + // We update the atomic counter per cluster + tid = cluster.thread_rank(); + cluster.sync(); +#else + tid = threadIdx.x; __syncthreads(); - if (threadIdx.x == 0) { +#endif + if (tid == 0) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) - asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) + asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) + : "memory"); +#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) + asm volatile("red.release.global.gpu.add.u32 [%0], %1;" ::"l"(mFlagAccessPtr), "r"(1) : "memory"); -#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory"); #else - atomicAdd(offset_access_ptr, 1); + atomicAdd(mFlagAccessPtr, 1); #endif } } - __device__ void wait_and_update(uint32_t num_tokens) { - if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0) { - while (*reinterpret_cast(offset_access_ptr) < gridDim.x * gridDim.y) { + __device__ void waitAndUpdate(uint4 bytesToClearPerStage) { + bool isLastCtaT0{false}; + int targetCount{0}; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cg::grid_group grid = cg::this_grid(); + // Use the first thread instead of the last thread as the last thread may exit early + isLastCtaT0 = grid.thread_rank() == 0; + targetCount = grid.num_clusters(); +#else + isLastCtaT0 = threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0; + targetCount = gridDim.x * gridDim.y; +#endif + if (isLastCtaT0) { + uint4* flagPtr = reinterpret_cast(mBufferFlagsPtr); + while (*reinterpret_cast(mFlagAccessPtr) < targetCount) { } - uint4 flag = reinterpret_cast(buffer_flags)[0]; - buffer_flags[0] = (flag.x + 1) % 3; - buffer_flags[1] = (flag.y + 1) % 3; - buffer_flags[3] = num_tokens; - *(offset_access_ptr) = 0; + // 'Current' becomes 'Dirty' + flagPtr[0] = {(mCurrentIndex + 1) % 3, // Current index + mCurrentIndex, // Dirty index + mCurBufferLayout.bytesPerBuffer, // Buffer size + mCurBufferLayout.numStages}; // Dirty - Number of stages + flagPtr[1] = bytesToClearPerStage; + *mFlagAccessPtr = 0; + } + } + + private: + uint32_t* mBufferFlagsPtr; + uint32_t* mFlagAccessPtr; + + uint32_t mCurrentIndex, mDirtyIndex; + // So that we can access it with uint4 + alignas(16) std::array mBytesToClear; + LamportBufferLayout mCurBufferLayout, mDirtyBufferLayout; + + inline __device__ void clearPackedBuf(void* bufferBasePtr, uint32_t globalTid, + uint32_t numThreads, uint32_t bytesToClear, + uint8_t dirtyIndex, uint8_t stageIdx) { + // Round up to the float4 boundary + uint32_t clearBoundary = ceil_div(bytesToClear, sizeof(PackedType)); + for (uint32_t packedIdx = globalTid; packedIdx < clearBoundary; packedIdx += numThreads) { + reinterpret_cast( + mDirtyBufferLayout.getStagePtr(bufferBasePtr, dirtyIndex, stageIdx))[packedIdx] = + getPackedLamportInit(); + } + } +}; + +template +union PackedVec { + PackedType packed; + T elements[sizeof(PackedType) / sizeof(T)]; + + __device__ PackedVec& operator+=(PackedVec& other) { +#pragma unroll + for (int i = 0; i < sizeof(PackedType) / sizeof(T); i++) { + elements[i] += other.elements[i]; + } + return *this; + } + + __device__ PackedVec operator+(PackedVec& other) { + PackedVec result; +#pragma unroll + for (int i = 0; i < sizeof(PackedType) / sizeof(T); i++) { + result.elements[i] = elements[i] + other.elements[i]; } + return result; } }; -template -__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, - int num_tokens, int buffer_M, int token_dim, int rank, - uint32_t* buffer_flags, bool wait_for_results) { - int elt = blockIdx.y * blockDim.x + threadIdx.x; +template +inline __device__ PackedType loadPacked(T* ptr) { + return *reinterpret_cast(ptr); +} + +template +inline __device__ const PackedType loadPacked(T const* ptr) { + return *reinterpret_cast(ptr); +} + +template +inline __device__ PackedType loadPackedVolatile(void const* ptr) { + static_assert(sizeof(PackedType) == 0, "Not implemented"); + return PackedType{}; +} + +template <> +inline __device__ float4 loadPackedVolatile(void const* ptr) { + float4 returnValue; + asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(returnValue.x), "=f"(returnValue.y), "=f"(returnValue.z), "=f"(returnValue.w) + : "l"(ptr)); + return returnValue; +} + +template <> +inline __device__ float2 loadPackedVolatile(void const* ptr) { + float2 returnValue; + asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" + : "=f"(returnValue.x), "=f"(returnValue.y) + : "l"(ptr)); + return returnValue; +} - if (elt >= token_dim) return; +template +inline __device__ void copyF4(T_IN* dst, T_IN const* src) { + float4* dst4 = reinterpret_cast(dst); + float4 const* src4 = reinterpret_cast(src); + __pipeline_memcpy_async(dst4, src4, sizeof(float4)); +} + +uint32_t constexpr kWARP_SIZE = 32U; +uint32_t constexpr kLOG2_WARP_SIZE = 5U; +uint32_t constexpr kLANE_ID_MASK = 0x1f; +uint32_t constexpr kFINAL_MASK = 0xffffffff; + +template +inline __device__ T warpReduceSumFull(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + val += __shfl_xor_sync(kFINAL_MASK, val, mask, kWARP_SIZE); + } + return val; +} + +template +inline __device__ T warpReduceSumPartial(T val) { + int laneId = threadIdx.x & kLANE_ID_MASK; + // We make sure only the last warp will call this function + int warpSize = blockDim.x - (threadIdx.x & ~(kWARP_SIZE - 1)); + unsigned int active_mask = (1U << warpSize) - 1; + +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + int targetLane = laneId ^ mask; + auto tmp = __shfl_xor_sync(active_mask, val, mask, kWARP_SIZE); + val += targetLane < warpSize ? tmp : 0; + } + return val; +} + +// SYNC: +// - True: share the sum across all threads +// - False: only thread 0 get the sum; Other thread's value is undefined. +template +inline __device__ T blockReduceSumPartial(T val) { + __shared__ T smem[kWARP_SIZE]; + int laneId = threadIdx.x & kLANE_ID_MASK; + int warpId = threadIdx.x >> kLOG2_WARP_SIZE; + int warpNum = (blockDim.x + kWARP_SIZE - 1) >> + kLOG2_WARP_SIZE; // Ceiling division to include partial warps + + val = (warpId == warpNum - 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val); + if (laneId == 0) { + smem[warpId] = val; + } + __syncthreads(); + + if (warpId == 0) { + val = (laneId < warpNum) ? smem[laneId] : (T)0.f; + // Need to consider the corner case where we only have one warp and it is partial + val = (warpNum == 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val); + + if constexpr (SYNC) { + if (laneId == 0) { + smem[warpId] = val; + } + } + } + if constexpr (SYNC) { + __syncthreads(); + val = smem[0]; + } + return val; +} + +template +inline __device__ T blockReduceSumFull(T val) { + __shared__ T smem[kWARP_SIZE]; + int lane_id = threadIdx.x & kLANE_ID_MASK; + int warp_id = threadIdx.x >> kLOG2_WARP_SIZE; + int warp_num = blockDim.x >> kLOG2_WARP_SIZE; + + val = warpReduceSumFull(val); + if (lane_id == 0) { + smem[warp_id] = val; + } + __syncthreads(); + + val = (lane_id < warp_num) ? smem[lane_id] : (T)0.f; + val = warpReduceSumFull(val); + + return val; +} + +template +inline __device__ T blockReduceSum(T val) { + bool hasPartialWarp = (blockDim.x & kLANE_ID_MASK) != 0; + if (hasPartialWarp) { + return blockReduceSumPartial(val); + } else { + return blockReduceSumFull(val); + } +} +// A helper function to tune the grid configuration for fused oneshot and rmsnorm kernels +// Return (block_size, cluster_size, loads_per_thread) +std::tuple adjustGridConfig(int numTokens, int dim, int eltsPerThread) { + // Start with preferred block_size and cluster_size +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + int clusterSize = 8; +#else + int clusterSize = 1; +#endif + int blockSize = 128; + // ========================== Adjust the grid configuration ========================== + int threadsNeeded = ceil_div(dim, eltsPerThread); + int loadsPerThread = 1; + + blockSize = ceil_div(threadsNeeded, clusterSize); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + while (threadsNeeded % clusterSize != 0 && clusterSize > 1) { + clusterSize /= 2; + } + blockSize = ceil_div(threadsNeeded, clusterSize); + while (blockSize < 128 && clusterSize >= 2) { + blockSize *= 2; + clusterSize /= 2; + } + int smCount = GetCudaMultiProcessorCount(); + while (numTokens * clusterSize > smCount && clusterSize > 1 && blockSize <= 512) { + blockSize *= 2; + clusterSize /= 2; + } +#endif + + // Trying to scale up use multiple loads or CGA + while (blockSize > 1024) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if (clusterSize < 8) { + clusterSize = clusterSize << 1; + } else { + break; + } +#else + if (loadsPerThread < 8) { + loadsPerThread += 1; + } else { + break; + } +#endif + blockSize = ceil_div(threadsNeeded, clusterSize * loadsPerThread); + } + return {blockSize, clusterSize, loadsPerThread}; +} +}; // namespace utils + +using utils::blockReduceSum; +using utils::fromFloat; +using utils::isNegZero; +using utils::LamportFlags; +using utils::loadPacked; +using utils::loadPackedVolatile; +using utils::PackedVec; +using utils::toFloat; + +template +__global__ void __launch_bounds__(1024) + oneshotAllreduceFusionKernel(T* outputPtr, T* prenormedPtr, T const* shardPtr, + T const* residualInPtr, T const* gammaPtr, T** inputPtrs, + T* mcastPtr, int const numTokens, int const tokenDim, + float epsilon, int const rank, uint32_t* bufferFlags) { + constexpr int kELTS_PER_THREAD = sizeof(PackedType) / sizeof(T); + constexpr int kLAMPORT_ELTS_PER_PACKED = sizeof(PackedType) / sizeof(float); + constexpr uint32_t kELT_SIZE = sizeof(T); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + int packedIdx = cluster.thread_rank(); int token = blockIdx.x; + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); +#else + int packedIdx = blockIdx.y * blockDim.x + threadIdx.x; + int token = blockIdx.x; + // Offset w.r.t. the input shard + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; #endif - LamportFlags flags(buffer_flags); - - // Capture the number of tokens in previous iteration so that we can properly clear the buffer - // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up - uint32_t clr_toks_cta = - divUp(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, - WORLD_SIZE) * - WORLD_SIZE; - clr_toks_cta = divUp(clr_toks_cta, gridDim.x); - - if (elt < token_dim) { - // Scatter token - int dest_rank = token % WORLD_SIZE; - int dest_token_offset = token / WORLD_SIZE; - T val = shard_ptr[token * token_dim + elt]; - if (isNegZero(val)) val = fromFloat(0.f); - input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + - rank * token_dim + elt] = val; - - // Clear the buffer used by the previous call. Note the number of tokens to clear could be - // larger than the - // number of tokens in the current call. - for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) { - uint32_t clr_token_idx = token + clr_tok * gridDim.x; - if (clr_token_idx < buffer_M) { - input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat(-0.f); + // We only use 1 stage for the oneshot allreduce + LamportFlags flag(bufferFlags, 1); + T* stagePtrMcast = reinterpret_cast(flag.getCurLamportBuf(mcastPtr, 0)); + T* stagePtrLocal = reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], 0)); + + if (packedIdx * kELTS_PER_THREAD >= tokenDim) { + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); + return; + } + + // ==================== Broadcast tokens to each rank ============================= + PackedVec val; + val.packed = loadPacked(&shardPtr[threadOffset]); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + if (isNegZero(val.elements[i])) val.elements[i] = toFloat(0.f); + } + + reinterpret_cast( + &stagePtrMcast[token * tokenDim * WorldSize + rank * tokenDim])[packedIdx] = val.packed; + + flag.ctaArrive(); + // ======================= Lamport Sync and clear the output buffer from previous iteration + // ============================= + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); + + PackedVec valuesLamport[WorldSize]; + while (1) { + bool valid = true; +#pragma unroll + for (int r = 0; r < WorldSize; r++) { + valuesLamport[r].packed = loadPackedVolatile( + &stagePtrLocal[token * tokenDim * WorldSize + r * tokenDim + + packedIdx * kELTS_PER_THREAD]); + +#pragma unroll + for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) { + valid &= !isNegZero(valuesLamport[r].elements[i]); } } + if (valid) { + break; + } + } - // Reduce and broadcast - if ((token % WORLD_SIZE) == rank) { - int local_token = token / WORLD_SIZE; - float accum = 0.f; - - T values[WORLD_SIZE]; - - while (1) { - bool valid = true; - for (int r = 0; r < WORLD_SIZE; r++) { - T volatile* lamport_ptr = - (T volatile*)&input_ptrs[rank] - [flags.input_offset + local_token * token_dim * WORLD_SIZE + - r * token_dim + elt]; - values[r] = *lamport_ptr; - valid &= !isNegZero(values[r]); - } - if (valid) break; - } - for (int r = 0; r < WORLD_SIZE; r++) { - accum += toFloat(values[r]); - } - mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] = - fromFloat(accum); + auto values = reinterpret_cast*>(valuesLamport); + // ======================= Reduction ============================= + float accum[kELTS_PER_THREAD]; + PackedVec packedAccum; + +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] = toFloat(values[0].elements[i]); + } + +#pragma unroll + for (int r = 1; r < WorldSize; r++) { +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] += toFloat(values[r].elements[i]); } } -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(accum[i]); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif - - // Similarly clear broadcast buffer here - for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) { - uint32_t clr_token_idx = token + clr_tok * gridDim.x; - if (clr_token_idx < buffer_M) { - input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + - elt] = fromFloat(-0.f); + if constexpr (RMSNormFusion) { + // =============================== Residual =============================== + PackedVec residualIn; + residualIn.packed = *reinterpret_cast(&residualInPtr[threadOffset]); + packedAccum += residualIn; + *reinterpret_cast(&prenormedPtr[threadOffset]) = packedAccum.packed; + // =============================== Rmsnorm ================================ + PackedVec gamma; + gamma.packed = *reinterpret_cast(&gammaPtr[packedIdx * kELTS_PER_THREAD]); + + float threadSum = 0.F; +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + // FIXME: Use float square if accuracy issue + threadSum += toFloat(packedAccum.elements[i] * packedAccum.elements[i]); } - } + float blockSum = blockReduceSum(threadSum); - // Optionally wait for results if the next layer isn't doing the Lamport check - if (wait_for_results) { - // Update the atomic counter to indicate the block has read the offsets - flags.cta_arrive(); - // Only use a set of CTAs for lamport sync, reargange the grid - constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T); - // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32) - if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD)) { - uint64_t current_pos = - blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD; - - void* lamport_ptr = - (void*)&input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos]; - // We have 2 assumptions here: - // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be - // aligned to 8B - // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) - float2 val = loadfloat2(lamport_ptr); - while (isNegZero(*(T*)&val)) { - val = loadfloat2(lamport_ptr); + __shared__ float sharedVal[8]; // Temporary variable to share the sum within block + float fullSum = blockSum; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + int const numBlocks = cluster.num_blocks(); + if (numBlocks > 1) { + fullSum = 0.F; + // Need to reduce over the entire cluster + int const blockRank = cluster.block_rank(); + if (threadIdx.x < numBlocks) { + cluster.map_shared_rank(&sharedVal[0], threadIdx.x)[blockRank] = blockSum; } - if (output_ptr) { - *((float2*)&output_ptr[current_pos]) = val; + cluster.barrier_wait(cluster.barrier_arrive()); + for (int i = 0; i < numBlocks; ++i) { + fullSum += sharedVal[i]; } } - - // Update the buffer flags - flags.wait_and_update(num_tokens); +#endif + float rcpRms = rsqrtf(fullSum / tokenDim + epsilon); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(toFloat(packedAccum.elements[i]) * rcpRms * + fromFloat(gamma.elements[i])); + } } + reinterpret_cast(&outputPtr[threadOffset])[0] = packedAccum.packed; + flag.waitAndUpdate( + {static_cast(numTokens * tokenDim * WorldSize * kELT_SIZE), 0, 0, 0}); } -// Template-based dispatch functions following the same pattern as trtllm_allreduce.cuh -template -cudaError_t twoshot_allreduce_dispatch(AllReduceParams& params) { - int const num_threads = 128; - int const num_blocks = (params.token_dim + num_threads - 1) / num_threads; - - dim3 grid(params.num_tokens, num_blocks); - - cudaLaunchConfig_t config; - cudaLaunchAttribute attrs[1]; - config.dynamicSmemBytes = 0; - config.stream = params.stream; - config.gridDim = grid; - config.blockDim = num_threads; - config.attrs = attrs; +using utils::adjustGridConfig; + +template +cudaError_t oneshotAllreduceFusionDispatch(AllReduceFusionParams const& params) { + int const numTokens = params.numTokens; + int const tokenDim = params.tokenDim; + int const eltsPerThread = sizeof(float4) / sizeof(T); + + auto [blockSize, clusterSize, loadsPerThread] = + adjustGridConfig(numTokens, tokenDim, eltsPerThread); + dim3 grid(numTokens, clusterSize, 1); + + FLASHINFER_CHECK(blockSize <= 1024 && loadsPerThread == 1, + "Hidden Dimension %d exceeds the maximum supported hidden dimension (%d)", + tokenDim, +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + 1024 * 8 * eltsPerThread); +#else + 1024 * eltsPerThread); +#endif + + FLASHINFER_LOG_DEBUG( + "[MNNVL AllReduceOneShot] Dispatch: grid size: (%d, %d, 1), block_size: %d, cluster_size: " + "%d, " + "loads_per_thread: %d, " + "threads_needed: %d", + numTokens, clusterSize, blockSize, clusterSize, loadsPerThread, + ceil_div(tokenDim, eltsPerThread)); + + cudaLaunchAttribute attrs[2]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = params.launch_with_pdl ? 1 : 0; - config.numAttrs = 1; + attrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + attrs[1].id = cudaLaunchAttributeClusterDimension; + attrs[1].val.clusterDim.x = 1; + attrs[1].val.clusterDim.y = clusterSize; + attrs[1].val.clusterDim.z = 1; +#endif - cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel, - reinterpret_cast(params.output), reinterpret_cast(params.input), - reinterpret_cast(params.buffer_ptrs_dev), - reinterpret_cast(params.multicast_ptr), params.num_tokens, params.buffer_M, - params.token_dim, params.rank, - reinterpret_cast(params.buffer_flags), params.wait_for_results); + cudaLaunchConfig_t config{ + .gridDim = grid, + .blockDim = blockSize, + .dynamicSmemBytes = 0, + .stream = params.stream, + .attrs = attrs, +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + .numAttrs = 2, +#else + .numAttrs = 1, +#endif + }; + +#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, RMSNORM) \ + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( \ + &config, &oneshotAllreduceFusionKernel, output, residualOut, input, \ + residualIn, gamma, ucPtrs, mcPtr, numTokens, tokenDim, static_cast(params.epsilon), \ + params.rank, params.bufferFlags)); +#define DISPATCH_ALLREDUCE_KERNEL(WORLD_SIZE) \ + if (params.rmsNormFusion) { \ + LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, true); \ + } else { \ + LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE, false); \ + } - return cudaSuccess; -} + T** ucPtrs = reinterpret_cast(params.bufferPtrsDev); + T* mcPtr = reinterpret_cast(params.multicastPtr); + T* output = reinterpret_cast(params.output); + T* residualOut = reinterpret_cast(params.residualOut); + T const* input = reinterpret_cast(params.input); + T const* residualIn = reinterpret_cast(params.residualIn); + T const* gamma = reinterpret_cast(params.gamma); -template -cudaError_t twoshot_allreduce_dispatch_world_size(AllReduceParams& params) { - FLASHINFER_LOG_DEBUG("twoshot_allreduce_dispatch_world_size"); - switch (params.nranks) { + switch (params.nRanks) { + // FIXME: Do we need other world sizes? case 2: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(2); case 4: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(4); case 8: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(8); case 16: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(16); case 32: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(32); case 64: - return twoshot_allreduce_dispatch(params); + DISPATCH_ALLREDUCE_KERNEL(64); default: FLASHINFER_ERROR("MNNVL AllReduce: unsupported world_size " + std::to_string(params.nranks) + ". Supported sizes: {2, 4, 8, 16, 32, 64}"); return cudaErrorInvalidValue; } +#undef LAUNCH_ALLREDUCE_KERNEL + return cudaSuccess; } -template -__device__ void copy_f4(T_IN* dst, T_IN const* src) { - float4* dst4 = (float4*)dst; - float4 const* src4 = (float4 const*)src; - __pipeline_memcpy_async(dst4, src4, sizeof(float4)); -} - -template -__device__ void copy_f4_ldg(T_IN* dst, T_IN const* src) { - float4* dst4 = (float4*)dst; - float4 const* src4 = (float4*)src; - *dst4 = *src4; -} - -__device__ float4 loadfloat4(void const* ptr) { - // Check alignment - ptr should be 16-byte aligned for safe float4 load - if (reinterpret_cast(ptr) % 16 != 0) { - // Fall back to scalar loads if not aligned - float4 return_value; - float const* float_ptr = reinterpret_cast(ptr); - return_value.x = float_ptr[0]; - return_value.y = float_ptr[1]; - return_value.z = float_ptr[2]; - return_value.w = float_ptr[3]; - return return_value; - } +enum MNNVLTwoShotStage : uint8_t { + SCATTER = 0, + BROADCAST = 1, + NUM_STAGES = 2, +}; - float4 return_value; +template +__global__ __launch_bounds__(128) void twoshotAllreduceKernel( + T* outputPtr, T const* shardPtr, T** inputPtrs, T* mcastPtr, uint32_t const numTokens, + uint32_t const tokenDim, uint32_t const rank, uint32_t* bufferFlags, + bool const wait_for_results) { + constexpr int kELTS_PER_THREAD = sizeof(PackedType) / sizeof(T); + constexpr int kLAMPORT_ELTS_PER_PACKED = sizeof(PackedType) / sizeof(float); + constexpr uint32_t kELT_SIZE = sizeof(T); - asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" - : "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), - "=f"(return_value.w) - : "l"(ptr)); + int packedIdx = blockIdx.y * blockDim.x + threadIdx.x; + int token = blockIdx.x; + // Offset w.r.t. the input shard + int threadOffset = token * tokenDim + packedIdx * kELTS_PER_THREAD; - return return_value; -} + int destRank = token % WorldSize; + int destTokenOffset = token / WorldSize; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + LamportFlags flag(bufferFlags, MNNVLTwoShotStage::NUM_STAGES); -// Safer version that checks bounds before loading -template -__device__ float4 loadfloat4_safe(T const* ptr, int remaining_elements) { - float return_value[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + T* scatterBufLocal = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER)); + T* scatterBufDest = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[destRank], MNNVLTwoShotStage::SCATTER)); + T* broadcastBufW = + reinterpret_cast(flag.getCurLamportBuf(mcastPtr, MNNVLTwoShotStage::BROADCAST)); + T* broadcastBufR = + reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::BROADCAST)); - if (remaining_elements <= 0) { - return *(float4*)return_value; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + // Make sure the clear function is called before OOB thread exits + if (packedIdx * kELTS_PER_THREAD >= tokenDim) { + flag.clearDirtyLamportBuf(inputPtrs[rank], -1); + return; } - // Check alignment - ptr should be 16-byte aligned for safe float4 load - bool is_aligned = (reinterpret_cast(ptr) % 16 == 0); + // =============================== Scatter =============================== - if (is_aligned && remaining_elements >= 4) { - // Safe to do vectorized load - asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" - : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), - "=f"(return_value[3]) - : "l"(ptr)); - } else { - // Fall back to scalar loads with bounds checking - float const* float_ptr = reinterpret_cast(ptr); - for (int i = 0; i < 4 && i < remaining_elements; i++) { - return_value[i] = toFloat(float_ptr[i]); + // Load vectorized data + PackedVec val; + val.packed = loadPacked(&shardPtr[threadOffset]); +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + if (isNegZero(val.elements[i])) { + val.elements[i] = fromFloat(0.F); } } - return *(float4*)return_value; -} + // Store vectorized data + reinterpret_cast( + &scatterBufDest[destTokenOffset * tokenDim * WorldSize + rank * tokenDim])[packedIdx] = + val.packed; -template -inline __device__ T add(T a, T b) { - return a + b; -} + flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::SCATTER); -#define FINAL_MASK 0xffffffff + // =============================== Reduction and Broadcast =============================== -template -__inline__ __device__ T warpReduceSum(T val) { + if ((token % WorldSize) == rank) { + int localToken = token / WorldSize; + float accum[kELTS_PER_THREAD] = {0.F}; + + // Use float as we only check each float value for validity + PackedVec valuesLamport[WorldSize]; + while (1) { + bool valid = true; #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, - 32)); //__shfl_sync bf16 return float when sm < 80 - return val; -} + for (int r = 0; r < WorldSize; r++) { + valuesLamport[r].packed = loadPackedVolatile( + &scatterBufLocal[localToken * tokenDim * WorldSize + r * tokenDim + + packedIdx * kELTS_PER_THREAD]); -inline __device__ float block_reduce_sum(float val) { - __shared__ float smem[32]; - int lane_id = threadIdx.x % 32, warp_id = threadIdx.x / 32, warp_num = blockDim.x / 32; - val = warpReduceSum(val); - if (lane_id == 0) { - smem[warp_id] = val; - } - __syncthreads(); - val = lane_id < warp_num ? smem[lane_id] : 0.f; - val = warpReduceSum(val); - return val; -} + // Check validity across all elements +#pragma unroll + for (int i = 0; i < kLAMPORT_ELTS_PER_PACKED; i++) { + valid &= !isNegZero(valuesLamport[r].elements[i]); + } + } + if (valid) { + break; + } + } -template -__global__ void __launch_bounds__(128, 1) - RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, - T_IN const* gamma, float epsilon, T_IN const* residual, int batch_size, - uint32_t* buffer_flags) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Now we view it as the value for reduction + auto values = reinterpret_cast*>(valuesLamport); +#pragma unroll + for (int r = 0; r < WorldSize; r++) { +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + accum[i] += toFloat(values[r].elements[i]); + } + } - static bool const LAMPORT = true; + // Store vectorized result + PackedVec packedAccum; +#pragma unroll + for (int i = 0; i < kELTS_PER_THREAD; i++) { + packedAccum.elements[i] = fromFloat(accum[i]); + } + reinterpret_cast(&broadcastBufW[token * tokenDim])[packedIdx] = packedAccum.packed; + } - extern __shared__ uint8_t smem[]; + flag.clearDirtyLamportBuf(inputPtrs[rank], MNNVLTwoShotStage::BROADCAST); - int sample = blockIdx.y; + // Optionally wait for results if the next layer isn't doing the Lamport check + if (wait_for_results) { + // Update the atomic counter to indicate the block has read the offsets + flag.ctaArrive(); - static int const CGA_THREADS = NUM_THREADS * 1; + PackedVec valLamport; + valLamport.packed = loadPackedVolatile(&broadcastBufR[threadOffset]); + while (isNegZero(valLamport.elements[0])) { + valLamport.packed = loadPackedVolatile(&broadcastBufR[threadOffset]); + } + if (outputPtr) { + reinterpret_cast(&outputPtr[threadOffset])[0] = valLamport.packed; + } - static int const ITERS = DIM / CGA_THREADS; - float r_input[ITERS]; - float r_gamma[ITERS]; + // Update the buffer flags + flag.waitAndUpdate( + {static_cast(round_up(numTokens, WorldSize) * tokenDim * + kELT_SIZE), // Clear Size for scatter stage + static_cast(numTokens * tokenDim * kELT_SIZE), // Clear Size for broadcast stage + 0, 0}); + // If not wait for results, we will rely on the following kernel to update the buffer + } +} - T_IN* sh_input = (T_IN*)&smem[0]; - T_IN* sh_residual = (T_IN*)&smem[NUM_INPUTS * NUM_THREADS * ITERS * sizeof(T_IN)]; - T_IN* sh_gamma = (T_IN*)&smem[(NUM_INPUTS + 1) * NUM_THREADS * ITERS * sizeof(T_IN)]; +using utils::copyF4; +// This kernel works performant when loads_per_thread is 1. +// For this mode, we are able to support up to 1024 (threads) x 8 (elements) = 8192 hidden +// dimension. There are two options for further scaling up: +// 1. Use CGA if supported. It expands the hidden dimension to 8k x 8 = 64k. +// 2. Set loads_per_thread >1. Which can be used if CGA is not supported. Note that this will +// be limited by the shared memory size and register count. +template +__global__ __launch_bounds__(1024) void rmsNormLamport(T_IN* outputPreNorm, T_OUT* outputNorm, + T_IN* bufferInput, T_IN const* gamma, + float epsilon, T_IN const* residual, + uint32_t numTokens, uint32_t dim, + uint32_t worldSize, uint32_t* bufferFlags) { + static_assert(std::is_same_v, "T_IN and T_OUT must be the same type"); + static int const kELTS_PER_LOAD = sizeof(float4) / sizeof(T_IN); + + uint32_t const token = blockIdx.x; + uint32_t const blockSize = blockDim.x; + uint32_t const threadOffset = threadIdx.x; + + uint32_t numThreads = blockSize; + uint32_t clusterSize = 1; + uint32_t blockOffset = 0; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + namespace cg = cooperative_groups; + cg::cluster_group cluster = cg::this_cluster(); + numThreads = cluster.num_threads(); + clusterSize = cluster.num_blocks(); + blockOffset = cluster.block_rank(); +#endif + uint32_t const dimPadded = round_up(dim, kELTS_PER_LOAD * numThreads); + uint32_t const elemsPerThread = dimPadded / numThreads; + uint32_t const loadStride = blockSize; - static int const ELTS_PER_THREAD = sizeof(float4) / sizeof(T_IN); + extern __shared__ uint8_t smem[]; + float rInput[LoadsPerThread * kELTS_PER_LOAD]; + uint32_t offsets[LoadsPerThread * kELTS_PER_LOAD]; - int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)]; + uint32_t const smemBufferSize = blockSize * elemsPerThread * sizeof(T_IN); + T_IN* smemInput = (T_IN*)&smem[0]; + T_IN* smemResidual = (T_IN*)&smem[smemBufferSize]; + T_IN* smemGamma = (T_IN*)&smem[2 * smemBufferSize]; - LamportFlags flags(buffer_flags); - T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size]; + LamportFlags flag(bufferFlags, MNNVLTwoShotStage::NUM_STAGES); + T_IN* input = reinterpret_cast( + flag.getCurLamportBuf(reinterpret_cast(bufferInput), MNNVLTwoShotStage::BROADCAST)); -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif + // The offset that current thread should load from. Note that the hidden dimension is split by CGA + // size and each block loads a contiguous chunk; The size of chunk that each block processes + uint32_t const blockChunkSize = ceil_div(dim, clusterSize * kELTS_PER_LOAD) * kELTS_PER_LOAD; + uint32_t const blockLoadOffset = token * dim + blockOffset * blockChunkSize; - for (int i = 0; i < NUM_INPUTS; i++) { - for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; - offsets[i][j] = - i * batch_size * DIM + sample * DIM + blockIdx.x * DIM / 1 + k * ELTS_PER_THREAD; - } +#pragma unroll + for (uint32_t i = 0; i < LoadsPerThread; i++) { + // Each block load a contiguous chunk of tokens + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + offsets[i] = blockLoadOffset + threadLoadOffset; } #pragma unroll - for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++) { - int i = j * NUM_THREADS + threadIdx.x; - copy_f4(&sh_residual[i * ELTS_PER_THREAD], - &residual[sample * DIM + blockIdx.x * DIM + i * ELTS_PER_THREAD]); + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + copyF4(&smemResidual[threadLoadOffset], &residual[blockLoadOffset + threadLoadOffset]); + } } - __pipeline_commit(); - #pragma unroll - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int i = j * NUM_THREADS + threadIdx.x; - copy_f4(&sh_gamma[i * ELTS_PER_THREAD], &gamma[blockIdx.x * DIM + i * ELTS_PER_THREAD]); + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t const threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + copyF4(&smemGamma[threadLoadOffset], &gamma[blockOffset * blockChunkSize + threadLoadOffset]); + } } - __pipeline_commit(); - flags.cta_arrive(); - // Load all inputs + flag.ctaArrive(); bool valid = false; - -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if (!LAMPORT) cudaGridDependencySynchronize(); -#endif - + // ACQBLK if not lamport while (!valid) { valid = true; #pragma unroll - for (int i = 0; i < NUM_INPUTS; i++) { - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; - - float4* dst4 = (float4*)&sh_input[i * NUM_THREADS * ITERS + k * ELTS_PER_THREAD]; - - // Calculate the absolute element offset from the start of buffer_input - int element_offset = offsets[i][j]; - - // The input pointer is already offset to: &buffer_input[buffer_offset + buffer_size] - // So the actual pointer we're accessing is: input + element_offset - // Which equals: &buffer_input[buffer_offset + buffer_size + element_offset] - - float4* src4 = (float4*)&input[element_offset]; - - float4 value; - // Check if we have enough elements remaining for a safe float4 load - if (element_offset >= 0 && element_offset + ELTS_PER_THREAD <= flags.buffer_size) { - value = loadfloat4(src4); - } else { - // Use safe load for boundary cases or out-of-bounds - int remaining_elements = flags.buffer_size - element_offset; - if (remaining_elements <= 0) { - // Completely out of bounds, return zeros - float4 return_value = {0.0f, 0.0f, 0.0f, 0.0f}; - value = return_value; - } else { - value = loadfloat4_safe(reinterpret_cast(src4), remaining_elements); - } - } - - if (LAMPORT) { - // Assume that the 16B were written atomically, so we only need to check one value - T_IN lowest_val = *(T_IN*)&value; - valid &= !isNegZero(lowest_val); - } - *dst4 = value; - } - } - } - - __syncthreads(); - - // Perform the initial input reduction - if (NUM_INPUTS > 0) { - T_IN accum[ELTS_PER_THREAD]; - float4* accum4 = (float4*)&accum; - - for (int j = 0; j < DIM / (ELTS_PER_THREAD * NUM_THREADS); j++) { - int k = j * NUM_THREADS + threadIdx.x; + for (uint32_t i = 0; i < LoadsPerThread; i++) { + uint32_t threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; - *accum4 = *(float4*)&sh_input[k * ELTS_PER_THREAD]; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + float4* dst4 = reinterpret_cast(&smemInput[threadLoadOffset]); + float4 const* src4 = reinterpret_cast(&input[offsets[i]]); - for (int i = 1; i < NUM_INPUTS; i++) { - float4 data = *(float4*)&sh_input[i * NUM_THREADS * ITERS + k * ELTS_PER_THREAD]; - T_IN* p_d = (T_IN*)&data; - for (int x = 0; x < ELTS_PER_THREAD; x++) { - accum[x] += p_d[x]; - } + float4 value = loadPackedVolatile(src4); + // Assume that the 16B were written atomically, so we only need to check one value + valid &= !isNegZero(value.x); + *dst4 = value; } - - // Write back to input 0's staging location. No sync needed since all data localized to - // thread. - *(float4*)&sh_input[k * ELTS_PER_THREAD] = *accum4; } } - // Wait for residual __pipeline_wait_prior(1); __syncthreads(); - float thread_sum = 0.f; - + float threadSum = 0.f; #pragma unroll - for (int io = 0; io < ITERS / ELTS_PER_THREAD; io++) { - float4 inp4 = - *(float4*)&sh_input[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; - float4 res4 = - *(float4*)&sh_residual[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; - - T_IN* r_inp = (T_IN*)&inp4; - T_IN* r_res = (T_IN*)&res4; + for (int i = 0; i < LoadsPerThread; i++) { + int threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + PackedVec inp{.packed = loadPacked(&smemInput[threadLoadOffset])}; + PackedVec res{.packed = loadPacked(&smemResidual[threadLoadOffset])}; - float4 out4; - - T_IN* r_out = (T_IN*)&out4; - - for (int ii = 0; ii < ELTS_PER_THREAD; ii++) { - int i = io * ELTS_PER_THREAD + ii; - - T_IN inp_plus_resid = r_inp[ii] + r_res[ii]; - r_out[ii] = inp_plus_resid; - r_input[i] = toFloat(inp_plus_resid); + PackedVec inp_plus_res = inp + res; +#pragma unroll + for (int j = 0; j < kELTS_PER_LOAD; j++) { + rInput[i * kELTS_PER_LOAD + j] = toFloat(inp_plus_res.elements[j]); + threadSum += toFloat(inp_plus_res.elements[j] * inp_plus_res.elements[j]); + } - // Accumulate the squares for RMSNorm - thread_sum += toFloat(inp_plus_resid * inp_plus_resid); + *reinterpret_cast(&outputPreNorm[blockLoadOffset + threadLoadOffset]) = + inp_plus_res.packed; } - - *(float4*)&input_plus_residual[sample * DIM + blockIdx.x * DIM + - io * NUM_THREADS * ELTS_PER_THREAD + - threadIdx.x * ELTS_PER_THREAD] = out4; } - // Wait for Gamma. There will be a global synchronization as part of the reduction __pipeline_wait_prior(0); - float cluster_sum = block_reduce_sum(thread_sum); - - float rcp_rms = rsqrtf(cluster_sum / DIM + epsilon); + float blockSum = blockReduceSum(threadSum); -#pragma unroll - for (int io = 0; io < ITERS / ELTS_PER_THREAD; io++) { - float4 gamma4 = - *(float4*)&sh_gamma[io * NUM_THREADS * ELTS_PER_THREAD + threadIdx.x * ELTS_PER_THREAD]; - T_IN* r_g4 = (T_IN*)&gamma4; - - float4 out4; - // FIXME: this only works if T_OUT == T_IN - T_OUT* r_out = (T_OUT*)&out4; - - for (int ii = 0; ii < ELTS_PER_THREAD; ii++) { - int i = io * ELTS_PER_THREAD + ii; - r_gamma[i] = toFloat(r_g4[ii]); - r_out[ii] = fromFloat(r_gamma[i] * r_input[i] * rcp_rms); + float fullSum = blockSum; + __shared__ float sharedVal[8]; + // Use CGA Reduction if supported +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + int const numBlocks = cluster.num_blocks(); + if (numBlocks > 1) { + fullSum = 0.F; + // Need to reduce over the entire cluster + int const blockRank = cluster.block_rank(); + if (threadIdx.x < numBlocks) { + cluster.map_shared_rank(&sharedVal[0], threadIdx.x)[blockRank] = blockSum; + } + cluster.barrier_wait(cluster.barrier_arrive()); + for (int i = 0; i < numBlocks; ++i) { + fullSum += sharedVal[i]; } - - *(float4*)&output_norm[sample * DIM + blockIdx.x * DIM + io * NUM_THREADS * ELTS_PER_THREAD + - threadIdx.x * ELTS_PER_THREAD] = out4; } - // Update the buffer pointers - flags.wait_and_update(batch_size); #endif -} -template -cudaError_t twoshot_rmsnorm_dispatch(RMSNormParams& params) { - static constexpr int NUM_THREADS = 128; - static constexpr int CGA_THREADS = NUM_THREADS; - constexpr int iters = H_DIM / CGA_THREADS; + float rcpRms = rsqrtf(fullSum / dim + epsilon); - dim3 grid(1, params.batch, 1); - - cudaLaunchConfig_t config; - cudaLaunchAttribute attrs[1]; - config.stream = params.stream; - config.gridDim = grid; - config.blockDim = NUM_THREADS; - config.attrs = attrs; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = params.launch_with_pdl ? 1 : 0; - config.numAttrs = 1; - - size_t shmem_size = 3 * NUM_THREADS * iters * sizeof(T); - config.dynamicSmemBytes = shmem_size; +#pragma unroll + for (int i = 0; i < LoadsPerThread; i++) { + PackedVec r_out; + uint32_t threadLoadOffset = (i * loadStride + threadOffset) * kELTS_PER_LOAD; + if (blockOffset * blockChunkSize + threadLoadOffset < dim) { + PackedVec gamma = {.packed = loadPacked(&smemGamma[threadLoadOffset])}; - cudaFuncSetAttribute(&RMSNorm, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); +#pragma unroll + for (uint32_t j = 0; j < kELTS_PER_LOAD; j++) { + r_out.elements[j] = fromFloat(toFloat(gamma.elements[j]) * + rInput[i * kELTS_PER_LOAD + j] * rcpRms); + } - cudaLaunchKernelEx( - &config, &RMSNorm, reinterpret_cast(params.residual_output), - reinterpret_cast(params.output), reinterpret_cast(params.input), - reinterpret_cast(params.gamma), static_cast(params.epsilon), - reinterpret_cast(params.residual), params.batch, params.buffer_flags); + *reinterpret_cast(&outputNorm[blockLoadOffset + threadLoadOffset]) = r_out.packed; + } + } + constexpr int kELTS_SIZE = sizeof(T_IN); - return cudaSuccess; + // Update the buffer pointers + flag.waitAndUpdate({static_cast(round_up(numTokens, worldSize) * dim * kELTS_SIZE), + static_cast(numTokens * dim * kELTS_SIZE), 0, 0}); } template -cudaError_t twoshot_rmsnorm_dispatch_hidden_dim(RMSNormParams& params) { - FLASHINFER_LOG_DEBUG("twoshot_rmsnorm_dispatch_hidden_dim"); - switch (params.hidden_dim) { - case 2048: - return twoshot_rmsnorm_dispatch(params); - case 4096: - return twoshot_rmsnorm_dispatch(params); - case 5120: - return twoshot_rmsnorm_dispatch(params); // Llama-4 - case 7168: - return twoshot_rmsnorm_dispatch(params); // DeepSeek - case 8192: - return twoshot_rmsnorm_dispatch(params); +cudaError_t twoshotAllreduceFusionDispatch(AllReduceFusionParams const& params) { + int const numTokens = params.numTokens; + int const tokenDim = params.tokenDim; + int const numEltsPerThread = sizeof(float4) / sizeof(T); + FLASHINFER_CHECK(tokenDim % numEltsPerThread == 0, + "[MNNVL AllReduceTwoShot] token_dim must be divisible by %d", numEltsPerThread); + + int const arNumThreads = ceil_div(tokenDim, numEltsPerThread); + int const arNumBlocksPerToken = ceil_div(arNumThreads, 128); + + dim3 arGrid(numTokens, arNumBlocksPerToken); + + cudaLaunchAttribute arAttrs[1]; + arAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + arAttrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; + + cudaLaunchConfig_t arConfig{ + .gridDim = arGrid, + .blockDim = 128, + .dynamicSmemBytes = 0, + .stream = params.stream, + .attrs = arAttrs, + .numAttrs = 1, + }; + + FLASHINFER_LOG_DEBUG("[MNNVL AllReduceTwoShot] Dispatch: grid size: (%d, %d, 1), block_size: 128", + numTokens, arNumBlocksPerToken); + +#define LAUNCH_ALLREDUCE_KERNEL(WORLD_SIZE) \ + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( \ + &arConfig, &twoshotAllreduceKernel, output, input, ucPtrs, mcastPtr, \ + numTokens, tokenDim, params.rank, params.bufferFlags, (!params.rmsNormFusion))); + T** ucPtrs = reinterpret_cast(params.bufferPtrsDev); + T* mcastPtr = reinterpret_cast(params.multicastPtr); + T* output = reinterpret_cast(params.output); + T const* input = reinterpret_cast(params.input); + switch (params.nRanks) { + case 2: + LAUNCH_ALLREDUCE_KERNEL(2); + break; + case 4: + LAUNCH_ALLREDUCE_KERNEL(4); + break; + case 8: + LAUNCH_ALLREDUCE_KERNEL(8); + break; + case 16: + LAUNCH_ALLREDUCE_KERNEL(16); + break; + case 32: + LAUNCH_ALLREDUCE_KERNEL(32); + break; + case 64: + LAUNCH_ALLREDUCE_KERNEL(64); + break; default: - FLASHINFER_ERROR("MNNVL TwoShot RMSNorm: unsupported hidden_dim " + - std::to_string(params.hidden_dim) + - ". Supported sizes: {2048, 4096, 5120, 7168, 8192}"); + FLASHINFER_ERROR("[MNNVL AllReduceTwoShot] Unsupported world_size" + + std::to_string(params.nRanks) + ". Supported sizes: {2, 4, 8, 16, 32, 64}"); return cudaErrorInvalidValue; } -} +#undef LAUNCH_ALLREDUCE_KERNEL + + // Launch the rmsnorm lamport kernel if fusion is enabled + if (params.rmsNormFusion) { + auto gridConfig = adjustGridConfig(numTokens, tokenDim, numEltsPerThread); + int rnBlockSize = std::get<0>(gridConfig); + int rnClusterSize = std::get<1>(gridConfig); + int rnLoadsPerThread = std::get<2>(gridConfig); + + int rnNumThreads = rnClusterSize * rnBlockSize; + dim3 rnGrid(numTokens, rnClusterSize, 1); + cudaLaunchConfig_t rnConfig; + cudaLaunchAttribute rnAttrs[2]; + rnConfig.stream = params.stream; + rnConfig.gridDim = rnGrid; + rnConfig.blockDim = rnBlockSize; + rnConfig.attrs = rnAttrs; + rnAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + rnAttrs[0].val.programmaticStreamSerializationAllowed = params.launchWithPdl ? 1 : 0; +#ifndef DISABLE_CGA + rnAttrs[1].id = cudaLaunchAttributeClusterDimension; + rnAttrs[1].val.clusterDim.x = 1; + rnAttrs[1].val.clusterDim.y = rnClusterSize; + rnAttrs[1].val.clusterDim.z = 1; + rnConfig.numAttrs = 2; +#else + rnConfig.numAttrs = 1; +#endif + bool const rnUseCGA = rnClusterSize > 1; + int const dimPadded = round_up(tokenDim, numEltsPerThread * rnNumThreads); + int const iters = dimPadded / rnNumThreads; + + size_t const smemSize = 3 * rnBlockSize * iters * getDTypeSize(params.dType); + + FLASHINFER_LOG_DEBUG( + "[MNNVL AllReduceTwoShotRMSNorm] Dispatch: grid size: (%d, %d, 1), block_size: %d, " + "cluster_size: %d, " + "loads_per_thread: %d, " + "threads_needed: %d", + numTokens, rnClusterSize, rnBlockSize, rnClusterSize, rnLoadsPerThread, + ceil_div(tokenDim, numEltsPerThread)); + +#define RUN_RMSNORM_KERNEL(LOADS_PER_THREAD) \ + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(&rmsNormLamport, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + smemSize)); \ + rnConfig.dynamicSmemBytes = smemSize; \ + FLASHINFER_CUDA_CALL( \ + cudaLaunchKernelEx(&rnConfig, &rmsNormLamport, residualOut, output, \ + bufferInput, gamma, static_cast(params.epsilon), residualIn, \ + numTokens, tokenDim, params.nRanks, params.bufferFlags)); + + T* residualOut = reinterpret_cast(params.residualOut); + T* output = reinterpret_cast(params.output); + T* bufferInput = reinterpret_cast(params.bufferPtrLocal); + T const* gamma = reinterpret_cast(params.gamma); + T const* residualIn = reinterpret_cast(params.residualIn); + if (rnUseCGA) { + RUN_RMSNORM_KERNEL(1); + } else { + switch (rnLoadsPerThread) { + case 1: + RUN_RMSNORM_KERNEL(1); + break; + case 2: + RUN_RMSNORM_KERNEL(2); + break; + case 3: + RUN_RMSNORM_KERNEL(3); + break; + case 4: + RUN_RMSNORM_KERNEL(4); + break; + case 5: + RUN_RMSNORM_KERNEL(5); + break; + case 6: + RUN_RMSNORM_KERNEL(6); + break; + case 7: + RUN_RMSNORM_KERNEL(7); + break; + case 8: + RUN_RMSNORM_KERNEL(8); + break; + default: + FLASHINFER_ERROR("[MNNVL AllReduceTwoShotRMSNorm] Unsupported loads_per_thread" + + std::to_string(rnLoadsPerThread) + + ". Supported sizes: {1, 2, 3, 4, 5, 6, 7, 8}"); + return cudaErrorInvalidValue; + } // switch (rnLoadsPerThread) + } // if (rnUseCGA) +#undef RUN_RMSNORM_KERNEL + + } // if (params.rmsNormFusion) + return cudaSuccess; +} } // namespace trtllm_mnnvl_allreduce } // namespace flashinfer diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 0471bd1081..20c19a0eae 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -335,6 +335,18 @@ inline std::pair GetCudaComputeCapability() { return std::make_pair(major, minor); } +inline int GetCudaMultiProcessorCount() { + static int sm_count = 0; + if (sm_count == 0) { + int device_id; + cudaGetDevice(&device_id); + cudaDeviceProp device_prop; + cudaGetDeviceProperties(&device_prop, device_id); + sm_count = device_prop.multiProcessorCount; + } + return sm_count; +} + template inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") { std::vector host_array(size); From 1230273a478dc6e194891e9345cf27ca6b2dcb58 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Tue, 18 Nov 2025 17:13:28 -0800 Subject: [PATCH 02/32] Initial python interface. Need adjustment. --- csrc/trtllm_mnnvl_allreduce.cu | 8 +- flashinfer/comm/mnnvl.py | 18 +- flashinfer/comm/trtllm_mnnvl_ar.py | 418 ++++++++++++++--------------- 3 files changed, 230 insertions(+), 214 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index 05a1684aa0..ad23037ff3 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -33,7 +33,8 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank, bool rmsnorm_fusion, bool launch_with_pdl, bool use_oneshot, TensorView output, Optional residual_out, - Optional gamma, Optional epsilon) { + Optional residual_in, Optional gamma, + Optional epsilon) { cudaSetDevice(input.device().device_id); auto stream = get_stream(input.device()); @@ -82,9 +83,8 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt // input data params.input = const_cast(input.data_ptr()); - params.residualIn = residual_out.has_value() - ? const_cast(residual_out.value().data_ptr()) - : nullptr; + params.residualIn = + residual_in.has_value() ? const_cast(residual_in.value().data_ptr()) : nullptr; params.gamma = gamma.has_value() ? const_cast(gamma.value().data_ptr()) : nullptr; params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5; diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 2d280a68e8..a1a8c58d02 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -1005,7 +1005,7 @@ def __init__( def lamport_initialize(self, rank: int, dtype: torch.dtype): self.mcast_device_memory.lamport_initialize(rank, dtype) - def get_mc_buffer( + def get_multicast_buffer( self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 ) -> torch.Tensor: """ @@ -1019,12 +1019,28 @@ def get_mc_buffer( Returns: A PyTorch tensor wrapping the multicast buffer section """ + + # FIXME: Is this needed? As the behavior of reading from mc_ptr is undefined. + raise NotImplementedError("Not implemented yet") + + def get_unicast_buffer( + self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 + ) -> torch.Tensor: + """ + Returns a PyTorch tensor view of the unicast buffer portion. + """ + + # TODO: How can I warp a raw pointer to a tensor in python level? raise NotImplementedError("Not implemented yet") def get_multicast_ptr(self) -> int: """Get the raw multicast pointer""" return self.mcast_device_memory.get_multicast_ptr() + def get_unicast_ptr(self, rank: int) -> int: + """Get the raw unicast pointer to a given rank""" + return self.mcast_device_memory.get_unicast_ptr(rank) + def get_buffer_ptrs_dev(self) -> int: """Get the buffer pointers device array""" return self.mcast_device_memory.get_buffer_ptrs_dev() diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 84a9c150de..f26a37d069 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -5,9 +5,10 @@ import functools import math -import os +import logging from types import SimpleNamespace -from typing import Optional, Tuple +from typing import Optional +from enum import Enum import torch @@ -25,238 +26,241 @@ def mpi_barrier(): MPI.COMM_WORLD.Barrier() +class MNNVLAllreduceFusionStrategy(Enum): + ONESHOT = 0 + TWOSHOT = 1 + AUTO = 99 + + @staticmethod + def is_one_shot(tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype) -> bool: + elem_size = torch.tensor([], dtype=dtype).element_size() + return num_tokens * hidden_dim * tp_size * elem_size <= kMNNVLOneShotThreshold + + +# Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size +kMNNVLOneShotThreshold = 64 * 1024 * 8 * 2 + + +class MNNVLAllreduceFusionWorkspace: + NUM_LAMPORT_BUFFERS = 3 + + def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None): + """ + Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD. + + Args: + mapping: Mapping configuration containing rank info + buffer_size_in_bytes: The size in bytes for each lamport buffer. The actual allocation size will be NUM_LAMPORT_BUFFERS * buffer_size_in_bytes. + """ + if buffer_size_in_bytes is None: + # Default to 16MB workspace size if not provided + buffer_size_in_bytes = 16 * (1024**2) + else: + # Round up to the nearest multiple of 8MB + buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2)) * (8 * (1024**2))) + + if buffer_size_in_bytes > (2**32 - 1): + raise ValueError( + f"The buffer size in bytes {buffer_size_in_bytes} is greater than the maximum supported size (UINT32_MAX)." + ) + + self.buffer_size_bytes = buffer_size_in_bytes + self.workspace_size_bytes = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS + self.rank = mapping.tp_rank + self.tp_size = mapping.tp_size + logging.debug( + f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with size {buffer_size_in_bytes} bytes." + ) + self.mcast_buffer_handle = McastGPUBuffer( + self.workspace_size_bytes, + mapping.tp_size, + mapping.tp_rank, + torch.device("cuda", mapping.local_rank), + mapping.is_multi_node(), + ) + + # We use FP32 for sentinel value regardless of the real dtype + self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32) + # Wait until the initialization is done + torch.cuda.synchronize() + # FIXME: We are assuming using the COMM_WORLD. + mpi_barrier() + + # This is a buffer to maintain the state of this allreduce Op + # Should have the same lifetime with self._buffer + # The flag should be binded to each buffer allocation + # Layout: [cur idx, dirty idx, bytes per buffer, dirty num stages, numBytesToClear[4], access count ptr] + num_bytes_to_clear = [0] * 4 + self.buffer_flags = torch.tensor( + [0, 2, self.buffer_size_bytes, 0, *num_bytes_to_clear, 0], + dtype=torch.uint32, + device=torch.device("cuda", mapping.local_rank), + ) + + self.uc_ptrs_dev = self.mcast_buffer_handle.get_buffer_ptrs_dev() + self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank) + self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr() + + @staticmethod + def get_required_buffer_size_bytes( + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, + ) -> int: + """ + Calculate the required buffer size for a given problem size. + """ + elem_size = torch.tensor([], dtype=dtype).element_size() + is_one_shot = MNNVLAllreduceFusionStrategy.is_one_shot(tp_size, num_tokens, hidden_dim, dtype) + if strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( + strategy == MNNVLAllreduceFusionStrategy.AUTO and is_one_shot + ): + # For one-shot, each rank needs to store num_tokens * tp_size tokens + buffer_size = num_tokens * hidden_dim * tp_size * elem_size + else: + # For two-shot, each rank stores a slices of tokens. We need to round up to the nearest tp_size. + # 2 Stage is required for the two-shot allreduce. + buffer_size = 2 * math.ceil(num_tokens / tp_size) * tp_size * hidden_dim * elem_size + return buffer_size + + @functools.cache def get_trtllm_mnnvl_comm_module(): module = gen_trtllm_mnnvl_comm_module().build_and_load() @register_custom_op( - "flashinfer::trtllm_mnnvl_all_reduce", + "flashinfer::trtllm_mnnvl_allreduce_fusion", mutates_args=[ "inp", "multicast_buffer_ptr", "buffer_ptrs_dev", - "buffer_mnnvl", + "buffer_ptr_local", "buffer_flags_mnnvl", "nranks", "rank", - "wait_for_results", + "rmsnorm_fusion", "launch_with_pdl", + "use_oneshot", "out", + "residual_out", + "residual_in", + "gamma", + "epsilon", ], ) - def trtllm_mnnvl_all_reduce( + def trtllm_mnnvl_allreduce_fusion( inp: torch.Tensor, multicast_buffer_ptr: int, # Pointer address as integer buffer_ptrs_dev: int, # Pointer address as integer - buffer_mnnvl: torch.Tensor, + buffer_ptr_local: int, # Pointer address as integer buffer_flags_mnnvl: torch.Tensor, nranks: int, rank: int, - wait_for_results: bool, + rmsnorm_fusion: bool, launch_with_pdl: bool, + use_oneshot: bool, out: Optional[torch.Tensor], + residual_out: Optional[torch.Tensor], + residual_in: Optional[torch.Tensor], + gamma: Optional[torch.Tensor], + epsilon: Optional[float], ) -> None: - module.trtllm_mnnvl_all_reduce( + """ + Perform a multi-node NVLink all-reduce operation with fusion. + Args: + inp: Input tensor + multicast_buffer_ptr: Pointer to the multicast buffer as an integer + buffer_ptrs_dev: Pointer to the device array of buffer pointers as an integer + buffer_ptr_local: Pointer to local buffer as an integer + buffer_flags_mnnvl: Buffer flags tensor for synchronization + nranks: Total number of ranks participating in the all-reduce + rank: Current process rank + rmsnorm_fusion: Whether to perform RMSNorm fusion + launch_with_pdl: Whether to launch with PDL + use_oneshot: Whether to use one-shot (true) or two-shot (false) + outp: Output tensor + residual_out: Residual output tensor (if rmsnorm) + gamma: Gamma tensor (if rmsnorm) + epsilon: Epsilon value (if rmsnorm) + """ + module.trtllm_mnnvl_allreduce_fusion( inp, multicast_buffer_ptr, buffer_ptrs_dev, - buffer_mnnvl, + buffer_ptr_local, buffer_flags_mnnvl, nranks, rank, - wait_for_results, + rmsnorm_fusion, launch_with_pdl, + use_oneshot, out, - ) - - @register_custom_op( - "flashinfer::trtllm_mnnvl_rmsnorm", - mutates_args=[ - "mcast_buffer_input", - "prenorm_output", - "normed_output", - "gamma", - "epsilon", - "residual", - "buffer_flags", - "launch_with_pdl", - ], - ) - def trtllm_mnnvl_rmsnorm( - mcast_buffer_input: int, - prenorm_output: torch.Tensor, - normed_output: torch.Tensor, - gamma: torch.Tensor, - epsilon: float, - residual: torch.Tensor, - buffer_flags: torch.Tensor, - launch_with_pdl: bool, - ) -> None: - """Performs MNNVL TwoShot RMSNorm on the communication buffer. - - Args: - prenorm_output: Output tensor for prenorm results - normed_output: Output tensor for normalized results - mcast_buffer_input: Input tensor - gamma: The gamma parameter for RMSNorm - epsilon: The epsilon parameter for RMSNorm - residual: The residual tensor to add - buffer_flags: Buffer flags for synchronization - launch_with_pdl: Whether to launch with PDL - """ - return module.trtllm_mnnvl_rmsnorm( - mcast_buffer_input, - prenorm_output, - normed_output, + residual_out, + residual_in, gamma, epsilon, - residual, - buffer_flags, - launch_with_pdl, ) return SimpleNamespace( - trtllm_mnnvl_all_reduce=trtllm_mnnvl_all_reduce, - trtllm_mnnvl_rmsnorm=trtllm_mnnvl_rmsnorm, - ) - - -def get_allreduce_mnnvl_workspace( - mapping: Mapping, - dtype: torch.dtype, - comm_backend_for_handle_transfer: Optional[CommBackend] = None, - buffer_size_in_bytes: Optional[int] = None, -) -> Tuple[McastGPUBuffer, torch.Tensor, int]: - """Get workspace buffers needed for multi-node NVLink all-reduce operation. - - This function allocates and initializes the workspace buffers required for performing - multi-node NVLink all-reduce operations. It creates: - 1. A multicast GPU buffer for communication between nodes - 2. A flags tensor to track buffer state - 3. Maximum number of elements that can fit in the buffer - - The buffer size is calculated to efficiently handle common hidden dimensions - (2048, 4096, 5120, 7168, 8192) by using their LCM of 286720. - - Args: - mapping: Tensor parallel mapping configuration containing rank info - dtype: Data type of the tensors being reduced - comm: Optional communication backend for multi-node synchronization - buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens - - Returns: - Tuple containing: - - McastGPUBuffer: Multicast buffer for inter-node communication - - torch.Tensor: Buffer flags tensor tracking state - - int: Maximum number of elements that can fit in buffer - """ - force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1" - - # buffer shape: [3, 2, buffer_tokens, hidden_dim] - stride = 3 * 2 * dtype.itemsize - # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 - # max_num_elements must be a multiple of 286720 - lcm_hidden_dim = 286720 - TARGET_WORKSPACE_SIZE_BYTES = ( - buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 - ) - buffer_size_in_bytes = math.ceil( - TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) - ) * (lcm_hidden_dim * stride) - max_num_elements = buffer_size_in_bytes // stride - - mcast_buffer = McastGPUBuffer( - buffer_size_in_bytes, - mapping.tp_size, - mapping.tp_rank, - torch.device("cuda", mapping.local_rank), - mapping.is_multi_node() or force_mn, - comm_backend_for_handle_transfer=comm_backend_for_handle_transfer, - ) - - # Initialize the unicast buffer with -0.0 - mcast_buffer.lamport_initialize(mapping.tp_rank, dtype) - - # CPU barrier since we assume this should not be called in cuda graph - torch.cuda.synchronize() - if comm_backend_for_handle_transfer is None: - mpi_barrier() - else: - comm_backend_for_handle_transfer.barrier() - - # This is a buffer to maintain the state of this allreduce Op - # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter] - buffer_flags = torch.tensor( - [0, 2, max_num_elements, 0, 0], - dtype=torch.uint32, - device=torch.device("cuda", mapping.local_rank), - ) - - return ( - mcast_buffer, - buffer_flags, - max_num_elements, + trtllm_mnnvl_allreduce_fusion=trtllm_mnnvl_allreduce_fusion, ) def trtllm_mnnvl_all_reduce( inp: torch.Tensor, - multicast_buffer_ptr: int, # Pointer address as integer - buffer_ptrs_dev: int, # Pointer address as integer - buffer_M: int, - buffer_flags_mnnvl: torch.Tensor, - nranks: int, - rank: int, - wait_for_results: bool, + workspace: MNNVLAllreduceFusionWorkspace, launch_with_pdl: bool, out: Optional[torch.Tensor] = None, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, ) -> None: """Perform a multi-node NVLink all-reduce operation across multiple GPUs. This function performs an all-reduce (sum) operation using NVIDIA's multi-node NVLink (MNNVL) technology to efficiently combine tensors across multiple GPUs and nodes. - There are 3 steps: - 1. scatter each GPU's input shard to the right unicast buffer - 2. perform all-reduce on each GPU - 3. broadcast the result to all GPUs + There are 2 variants: One-shot and Two-shot: + - One-shot: Each rank stores local shard to all other ranks. Each ranks will receive all shards at the end of the communication round and perfom local reduction. Suitable for small data size and is optimized for low latency. + - Two-shot: There will be 3 steps: + 1. Scatter each GPU's input shard to other ranks. Each rank will received all shards of a slice of tokens. + 2. Each rank perform reduction on the local tokens. + 3. Each rank broadcast the result to all ranks. + Suitable for large data size and is optimized for balancing throughput and latency. Args: - inp: Local Input Shard - multicast_buffer_ptr: Pointer to the multicast buffer as an integer - buffer_ptrs_dev: Pointer to device buffer pointers as an integer - buffer_M: Maximum number of elements // hidden_dim - buffer_flags_mnnvl: Tensor containing buffer state flags - nranks: Total number of ranks participating in the all-reduce - rank: Current process rank - wait_for_results: If True, store the result to out - launch_with_pdl: If True, launch using Programmatic Dependent Launch - [Optional] out: Output tensor to store the result (required if wait_for_results is True) - + inp: Local Input Shard [num_tokens, hidden_dim] + workspace: MNNVLAllreduceFusionWorkspace + launch_with_pdl: Whether to launch with PDL + out: Output tensor to store the result + strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. """ if len(inp.shape) != 2: - raise ValueError( - f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}." - ) - - if inp.shape[0] > buffer_M: - raise ValueError( - f"The number of tokens in the input tensor {inp.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}." - ) + raise ValueError(f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}.") module = get_trtllm_mnnvl_comm_module() - module.trtllm_mnnvl_all_reduce( + + use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( + strategy == MNNVLAllreduceFusionStrategy.AUTO + and MNNVLAllreduceFusionStrategy.is_one_shot(workspace.tp_size, inp.shape[0], inp.shape[1], inp.dtype) + ) + module.trtllm_mnnvl_allreduce_fusion( inp, - multicast_buffer_ptr, - int(buffer_ptrs_dev), - buffer_M, - buffer_flags_mnnvl, - nranks, - rank, - wait_for_results, + workspace.mc_ptr, + workspace.uc_ptrs_dev, + workspace.uc_ptr_local, + workspace.buffer_flags, + workspace.tp_size, + workspace.rank, + False, # No RMSNorm Fusion launch_with_pdl, + use_oneshot, out, + None, + None, + None, + None, ) @@ -264,19 +268,14 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( prenorm_output: torch.Tensor, normed_output: torch.Tensor, shard_input: torch.Tensor, - multicast_buffer_ptr: int, # Pointer address as integer - buffer_ptrs_dev: int, # Pointer address as integer - unicast_ptr: int, # Local unicast buffer pointer - buffer_M: int, - buffer_flags_mnnvl: torch.Tensor, - nranks: int, - rank: int, + workspace: MNNVLAllreduceFusionWorkspace, gamma: torch.Tensor, epsilon: float, residual: torch.Tensor, launch_with_pdl: bool, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, ) -> None: - """Performs MNNVL TwoShot Allreduce + RMSNorm. + """Performs MNNVL Allreduce + RMSNorm. This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_all_reduce on the shard_input. After this, it performs RMSNorm on the all-reduced result, reading it directly from the multicast buffer. @@ -286,43 +285,44 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( prenorm_output: Output tensor for prenorm results normed_output: Output tensor for normalized results shard_input: Input tensor shard - multicast_buffer_ptr: Pointer address as integer for multicast buffer - buffer_ptrs_dev: Pointer address as integer for device buffer pointers - unicast_ptr: Pointer address as integer for unicast buffer - buffer_M: Maximum number of elements // hidden_dim - buffer_flags_mnnvl: Buffer flags for synchronization - nranks: Number of ranks in the tensor parallel group - rank: Current rank in the tensor parallel group + workspace: MNNVLAllreduceFusionWorkspace gamma: The gamma (norm weight) parameter for RMSNorm epsilon: The epsilon parameter for RMSNorm residual: The residual tensor to add launch_with_pdl: Whether to launch with PDL """ - # allreduce_result = Σ(shard_input across all ranks) - trtllm_mnnvl_all_reduce( - shard_input, - multicast_buffer_ptr, - buffer_ptrs_dev, - buffer_M, - buffer_flags_mnnvl, - nranks, - rank, - False, # No need to wait to write AR results here as we are not writing them - launch_with_pdl, - None, # out parameter - None since wait_for_results=False + if len(shard_input.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(shard_input.shape)}D. The shape is {shard_input.shape}." + ) + + module = get_trtllm_mnnvl_comm_module() + + use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( + strategy == MNNVLAllreduceFusionStrategy.AUTO + and MNNVLAllreduceFusionStrategy.is_one_shot( + workspace.tp_size, + shard_input.shape[0], + shard_input.shape[1], + shard_input.dtype, + ) ) - # prenorm_output = AllReduce(shard_input) + residual - # rms = sqrt(mean(prenorm_output²) + epsilon) - # normed_output = (prenorm_output / rms) * gamma - get_trtllm_mnnvl_comm_module().trtllm_mnnvl_rmsnorm( - unicast_ptr, - prenorm_output, + module.trtllm_mnnvl_allreduce_fusion( + shard_input, + workspace.mc_ptr, + workspace.uc_ptrs_dev, + workspace.uc_ptr_local, + workspace.buffer_flags, + workspace.tp_size, + workspace.rank, + True, # RMSNorm Fusion + launch_with_pdl, + use_oneshot, normed_output, + prenorm_output, + residual, gamma, epsilon, - residual, - buffer_flags_mnnvl, - launch_with_pdl, ) From 874c228f4ae74f20932dcd4e7751307c163bc57d Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Tue, 18 Nov 2025 17:40:17 -0800 Subject: [PATCH 03/32] Refactor the interface. --- flashinfer/comm/trtllm_mnnvl_ar.py | 126 ++++++++++++++++++----------- 1 file changed, 80 insertions(+), 46 deletions(-) diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index f26a37d069..839e03411c 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -7,7 +7,7 @@ import math import logging from types import SimpleNamespace -from typing import Optional +from typing import Optional, Tuple from enum import Enum import torch @@ -34,11 +34,12 @@ class MNNVLAllreduceFusionStrategy(Enum): @staticmethod def is_one_shot(tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype) -> bool: elem_size = torch.tensor([], dtype=dtype).element_size() - return num_tokens * hidden_dim * tp_size * elem_size <= kMNNVLOneShotThreshold + return num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD # Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size -kMNNVLOneShotThreshold = 64 * 1024 * 8 * 2 +# TODO(Refactor): Consider moving this to a configuration class or file +MNNVL_ONE_SHOT_THRESHOLD = 64 * 1024 * 8 * 2 class MNNVLAllreduceFusionWorkspace: @@ -133,7 +134,7 @@ def get_trtllm_mnnvl_comm_module(): @register_custom_op( "flashinfer::trtllm_mnnvl_allreduce_fusion", mutates_args=[ - "inp", + "input", "multicast_buffer_ptr", "buffer_ptrs_dev", "buffer_ptr_local", @@ -143,7 +144,7 @@ def get_trtllm_mnnvl_comm_module(): "rmsnorm_fusion", "launch_with_pdl", "use_oneshot", - "out", + "output", "residual_out", "residual_in", "gamma", @@ -151,7 +152,7 @@ def get_trtllm_mnnvl_comm_module(): ], ) def trtllm_mnnvl_allreduce_fusion( - inp: torch.Tensor, + input: torch.Tensor, multicast_buffer_ptr: int, # Pointer address as integer buffer_ptrs_dev: int, # Pointer address as integer buffer_ptr_local: int, # Pointer address as integer @@ -161,7 +162,7 @@ def trtllm_mnnvl_allreduce_fusion( rmsnorm_fusion: bool, launch_with_pdl: bool, use_oneshot: bool, - out: Optional[torch.Tensor], + output: torch.Tensor, residual_out: Optional[torch.Tensor], residual_in: Optional[torch.Tensor], gamma: Optional[torch.Tensor], @@ -170,7 +171,7 @@ def trtllm_mnnvl_allreduce_fusion( """ Perform a multi-node NVLink all-reduce operation with fusion. Args: - inp: Input tensor + input: Input tensor multicast_buffer_ptr: Pointer to the multicast buffer as an integer buffer_ptrs_dev: Pointer to the device array of buffer pointers as an integer buffer_ptr_local: Pointer to local buffer as an integer @@ -180,13 +181,13 @@ def trtllm_mnnvl_allreduce_fusion( rmsnorm_fusion: Whether to perform RMSNorm fusion launch_with_pdl: Whether to launch with PDL use_oneshot: Whether to use one-shot (true) or two-shot (false) - outp: Output tensor + output: Output tensor residual_out: Residual output tensor (if rmsnorm) gamma: Gamma tensor (if rmsnorm) epsilon: Epsilon value (if rmsnorm) """ module.trtllm_mnnvl_allreduce_fusion( - inp, + input, multicast_buffer_ptr, buffer_ptrs_dev, buffer_ptr_local, @@ -196,7 +197,7 @@ def trtllm_mnnvl_allreduce_fusion( rmsnorm_fusion, launch_with_pdl, use_oneshot, - out, + output, residual_out, residual_in, gamma, @@ -208,13 +209,13 @@ def trtllm_mnnvl_allreduce_fusion( ) -def trtllm_mnnvl_all_reduce( - inp: torch.Tensor, +def trtllm_mnnvl_allreduce( + input: torch.Tensor, workspace: MNNVLAllreduceFusionWorkspace, launch_with_pdl: bool, - out: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, -) -> None: +) -> torch.Tensor: """Perform a multi-node NVLink all-reduce operation across multiple GPUs. This function performs an all-reduce (sum) operation using NVIDIA's multi-node NVLink (MNNVL) @@ -229,24 +230,32 @@ def trtllm_mnnvl_all_reduce( Suitable for large data size and is optimized for balancing throughput and latency. Args: - inp: Local Input Shard [num_tokens, hidden_dim] + input: Local Input Shard [num_tokens, hidden_dim] workspace: MNNVLAllreduceFusionWorkspace launch_with_pdl: Whether to launch with PDL - out: Output tensor to store the result + output: Output tensor to store the result, empty tensor will be created if not provided. strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. + Returns: + output: Reduced tensor [num_tokens, hidden_dim] """ - if len(inp.shape) != 2: - raise ValueError(f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}.") + # Check ndims here as the shape check is done in the kernel launch code. + if len(input.shape) != 2: + raise ValueError(f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}.") + + if output is None: + output = torch.empty_like(input) + elif len(output.shape) != 2: + raise ValueError(f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}.") module = get_trtllm_mnnvl_comm_module() use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( strategy == MNNVLAllreduceFusionStrategy.AUTO - and MNNVLAllreduceFusionStrategy.is_one_shot(workspace.tp_size, inp.shape[0], inp.shape[1], inp.dtype) + and MNNVLAllreduceFusionStrategy.is_one_shot(workspace.tp_size, input.shape[0], input.shape[1], input.dtype) ) module.trtllm_mnnvl_allreduce_fusion( - inp, + input, workspace.mc_ptr, workspace.uc_ptrs_dev, workspace.uc_ptr_local, @@ -256,7 +265,7 @@ def trtllm_mnnvl_all_reduce( False, # No RMSNorm Fusion launch_with_pdl, use_oneshot, - out, + output, None, None, None, @@ -265,36 +274,60 @@ def trtllm_mnnvl_all_reduce( def trtllm_mnnvl_fused_allreduce_rmsnorm( - prenorm_output: torch.Tensor, - normed_output: torch.Tensor, - shard_input: torch.Tensor, - workspace: MNNVLAllreduceFusionWorkspace, + input: torch.Tensor, + residual_in: torch.Tensor, gamma: torch.Tensor, - epsilon: float, - residual: torch.Tensor, - launch_with_pdl: bool, + workspace: MNNVLAllreduceFusionWorkspace, + epsilon: Optional[float] = None, + output: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + launch_with_pdl: bool = False, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, -) -> None: +) -> Tuple[torch.Tensor, torch.Tensor]: """Performs MNNVL Allreduce + RMSNorm. - This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_all_reduce on the shard_input. + This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_allreduce on the shard_input. After this, it performs RMSNorm on the all-reduced result, reading it directly from the multicast buffer. Note: multicast buffer is the same as the unicast buffer for the current rank. Args: - prenorm_output: Output tensor for prenorm results - normed_output: Output tensor for normalized results - shard_input: Input tensor shard + input: Input tensor [num_tokens, hidden_dim] + residual_in: Residual input tensor [num_tokens, hidden_dim] + gamma: Gamma tensor [hidden_dim] workspace: MNNVLAllreduceFusionWorkspace - gamma: The gamma (norm weight) parameter for RMSNorm - epsilon: The epsilon parameter for RMSNorm - residual: The residual tensor to add + epsilon: The epsilon parameter for RMSNorm, torch.finfo.eps will be used if not provided. + output: Output tensor for normalized results [num_tokens, hidden_dim], empty tensor will be created if not provided. + residual_out: Residual output tensor [num_tokens, hidden_dim], empty tensor will be created if not provided. launch_with_pdl: Whether to launch with PDL + strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. + Returns: + output: Normalized tensor [num_tokens, hidden_dim] + residual_out: Residual output tensor [num_tokens, hidden_dim] """ - if len(shard_input.shape) != 2: + + if epsilon is None: + epsilon = torch.finfo(input.dtype).eps + + if len(input.shape) != 2: + raise ValueError(f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}.") + if len(residual_in.shape) != 2: + raise ValueError( + f"The residual input tensor must be 2D, got {len(residual_in.shape)}D. The shape is {residual_in.shape}." + ) + if gamma.numel() != input.shape[1]: + raise ValueError( + f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {input.shape[1]} elements." + ) + if output is None: + output = torch.empty_like(input) + elif len(output.shape) != 2: + raise ValueError(f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}.") + if residual_out is None: + residual_out = torch.empty_like(residual_in) + elif len(residual_out.shape) != 2: raise ValueError( - f"The input tensor must be 2D, got {len(shard_input.shape)}D. The shape is {shard_input.shape}." + f"The residual output tensor must be 2D, got {len(residual_out.shape)}D. The shape is {residual_out.shape}." ) module = get_trtllm_mnnvl_comm_module() @@ -303,14 +336,14 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( strategy == MNNVLAllreduceFusionStrategy.AUTO and MNNVLAllreduceFusionStrategy.is_one_shot( workspace.tp_size, - shard_input.shape[0], - shard_input.shape[1], - shard_input.dtype, + input.shape[0], + input.shape[1], + input.dtype, ) ) module.trtllm_mnnvl_allreduce_fusion( - shard_input, + input, workspace.mc_ptr, workspace.uc_ptrs_dev, workspace.uc_ptr_local, @@ -320,9 +353,10 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( True, # RMSNorm Fusion launch_with_pdl, use_oneshot, - normed_output, - prenorm_output, - residual, + output, + residual_out, + residual_in, gamma, epsilon, ) + return output, residual_out From 17a129207dec879eb1236d4195c641430285cbed Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 19 Nov 2025 16:40:48 -0800 Subject: [PATCH 04/32] Staging changes, result is wrong. --- csrc/trtllm_mnnvl_allreduce.cu | 10 +- flashinfer/comm/mnnvl.py | 415 ++++++++++-------- flashinfer/comm/trtllm_mnnvl_ar.py | 9 +- flashinfer/jit/comm.py | 1 + .../comm/trtllm_mnnvl_allreduce.cuh | 15 +- tests/comm/test_trtllm_mnnvl_allreduce.py | 265 +++++------ 6 files changed, 363 insertions(+), 352 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index ad23037ff3..c7215a4241 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -58,12 +58,14 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true"; if (rmsnorm_fusion) { - TVM_FFI_ICHECK(residual_out.size(0) == num_tokens && residual_out.size(1) == token_dim) + TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens && + residual_out.value().size(1) == token_dim) << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1) - << ") but got (" << residual_out.size(0) << ", " << residual_out.size(1) << ")"; - TVM_FFI_ICHECK(gamma.size(0) == token_dim) + << ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1) + << ")"; + TVM_FFI_ICHECK(gamma.value().size(0) == token_dim) << "gamma must have the same shape as token dimension (" << token_dim << ") but got (" - << gamma.size(0) << ")"; + << gamma.value().size(0) << ")"; } // Create the parameters struct diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index a1a8c58d02..520f6e4880 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -16,6 +16,12 @@ import ctypes import logging import os +import socket +import array +import random + +import contextlib + from abc import ABC, abstractmethod from dataclasses import dataclass import platform @@ -35,8 +41,7 @@ from cuda import cuda except ImportError as e: raise ImportError( - "Could not import the 'cuda' module. " - "Please install cuda-python that matches your CUDA version." + "Could not import the 'cuda' module. " "Please install cuda-python that matches your CUDA version." ) from e from ..cuda_utils import checkCudaErrors @@ -57,9 +62,7 @@ def round_up(val: int, gran: int) -> int: return (val + gran - 1) & ~(gran - 1) -def create_tensor_from_cuda_memory( - ptr: int, shape: tuple, dtype: torch.dtype, device_id: int -) -> torch.Tensor: +def create_tensor_from_cuda_memory(ptr: int, shape: tuple, dtype: torch.dtype, device_id: int) -> torch.Tensor: """ Create a PyTorch tensor from a CUDA memory pointer using DLPack. @@ -81,9 +84,7 @@ def create_tensor_from_cuda_memory( element_size = torch.tensor([], dtype=dtype).element_size() # Create DLPack capsule for contiguous memory (stride = element_size, num_segments = numel) - capsule_wrapper = create_dlpack_capsule( - ptr, element_size, element_size, numel, dtype, device_id - ) + capsule_wrapper = create_dlpack_capsule(ptr, element_size, element_size, numel, dtype, device_id) # Convert to tensor and reshape tensor = torch.utils.dlpack.from_dlpack(capsule_wrapper.capsule) @@ -123,24 +124,25 @@ def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool: return False -def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]: +def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: """ A helper function that allocates memory on cuda and copies the data from the host to the device. """ if not host_ptr_array: return None + for addr in host_ptr_array: + print(f"DEBUG: ptr_array: 0x{addr:x}") + ArrayType = ctypes.c_uint64 * len(host_ptr_array) c_array = ArrayType(*host_ptr_array) size_in_bytes = ctypes.sizeof(c_array) device_ptr: cuda.CUdeviceptr = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes)) - checkCudaErrors( - cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes) - ) + checkCudaErrors(cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes)) # c_array should be freed by GC - return device_ptr + return int(device_ptr) class CommBackend(ABC): @@ -155,6 +157,9 @@ def Get_size(self) -> int: ... @abstractmethod def allgather(self, data: int) -> List[int]: ... + @abstractmethod + def bcast(self, data: Any, root: int) -> Any: ... + @abstractmethod def barrier(self) -> None: ... @@ -212,6 +217,9 @@ def Get_size(self) -> int: def allgather(self, data: int) -> List[int]: return self._mpicomm.allgather(data) + def bcast(self, data: Any, root: int) -> Any: + return self._mpicomm.bcast(data, root) + def barrier(self): self._mpicomm.Barrier() @@ -287,18 +295,14 @@ def initialize(): @staticmethod def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None): MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend()) # type: ignore[attr-defined] - comm = config.comm_backend.Split( - mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank - ) + comm = config.comm_backend.Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank) MnnvlMemory.comm = comm # type: ignore[assignment] @staticmethod def get_comm(mapping: Mapping): if MnnvlMemory.comm is not None: return MnnvlMemory.comm - comm = MpiComm().Split( - mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank - ) + comm = MpiComm().Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank) MnnvlMemory.comm = comm return comm @@ -314,9 +318,7 @@ def get_allocation_prop(dev_id: int): arch = platform.machine().lower() is_on_aarch64 = "aarch64" in arch if is_on_aarch64: - allocation_prop.requestedHandleTypes = ( - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC - ) + allocation_prop.requestedHandleTypes = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC else: allocation_prop.requestedHandleTypes = ( cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR @@ -332,27 +334,19 @@ def get_allocation_granularity(dev_id: int): option = cuda.CUmemAllocationGranularity_flags( cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED ) - granularity = checkCudaErrors( - cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option) - ) + granularity = checkCudaErrors(cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option)) MnnvlMemory.allocation_granularity = granularity return MnnvlMemory.allocation_granularity @staticmethod def new_mnnvl_memory_address(mapping: Mapping, size: int): - page_count = ( - size + MnnvlMemory.fabric_page_size - 1 - ) // MnnvlMemory.fabric_page_size + page_count = (size + MnnvlMemory.fabric_page_size - 1) // MnnvlMemory.fabric_page_size current_rank_stride = page_count * MnnvlMemory.fabric_page_size - logging.info( - f"[MnnvlMemory] creating address with stride={current_rank_stride}" - ) + logging.info(f"[MnnvlMemory] creating address with stride={current_rank_stride}") comm = MnnvlMemory.get_comm(mapping) comm_size = comm.Get_size() address_size = current_rank_stride * comm_size - ptr = checkCudaErrors( - cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0) - ) + ptr = checkCudaErrors(cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0)) MnnvlMemory.current_start_address = int(ptr) MnnvlMemory.current_rank_stride = current_rank_stride MnnvlMemory.current_mem_offset = 0 @@ -363,44 +357,29 @@ def open_mnnvl_memory(mapping: Mapping, size: int): dev_id = int(dev) if MnnvlMemory.dev_id is None: MnnvlMemory.dev_id = dev_id - assert dev_id == MnnvlMemory.dev_id, ( - f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}" - ) + assert ( + dev_id == MnnvlMemory.dev_id + ), f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}" comm = MnnvlMemory.get_comm(mapping) comm_rank = comm.Get_rank() comm_size = comm.Get_size() all_rank_allocate_sizes = comm.allgather(size) assert len(all_rank_allocate_sizes) == comm_size - assert all(x == size for x in all_rank_allocate_sizes), ( - "Not all rank allocating same size." - ) + assert all(x == size for x in all_rank_allocate_sizes), "Not all rank allocating same size." granularity = MnnvlMemory.get_allocation_granularity(dev_id) aligned_size = (size + granularity - 1) // granularity * granularity - if ( - MnnvlMemory.current_mem_offset + aligned_size - > MnnvlMemory.current_rank_stride - ): + if MnnvlMemory.current_mem_offset + aligned_size > MnnvlMemory.current_rank_stride: MnnvlMemory.new_mnnvl_memory_address(mapping, aligned_size) - assert ( - MnnvlMemory.current_mem_offset + aligned_size - <= MnnvlMemory.current_rank_stride - ) + assert MnnvlMemory.current_mem_offset + aligned_size <= MnnvlMemory.current_rank_stride allocation_prop = MnnvlMemory.get_allocation_prop(dev_id) - allocated_mem_handle = checkCudaErrors( - cuda.cuMemCreate(aligned_size, allocation_prop, flags=0) - ) + allocated_mem_handle = checkCudaErrors(cuda.cuMemCreate(aligned_size, allocation_prop, flags=0)) exported_fabric_handle = checkCudaErrors( - cuda.cuMemExportToShareableHandle( - allocated_mem_handle, allocation_prop.requestedHandleTypes, 0 - ) + cuda.cuMemExportToShareableHandle(allocated_mem_handle, allocation_prop.requestedHandleTypes, 0) ) - if ( - allocation_prop.requestedHandleTypes - == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC - ): + if allocation_prop.requestedHandleTypes == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: all_handles_data = comm.allgather(exported_fabric_handle.data) else: all_handles_data = comm.allgather(exported_fabric_handle) @@ -414,9 +393,7 @@ def open_mnnvl_memory(mapping: Mapping, size: int): pidfd = syscall(SYS_pidfd_open, pid, 0) if pidfd < 0: err = ctypes.get_errno() - raise RuntimeError( - f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}" - ) + raise RuntimeError(f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}") pidfds.append(pidfd) remote_fds = [] @@ -431,9 +408,7 @@ def open_mnnvl_memory(mapping: Mapping, size: int): "to your docker run command." ) else: - error_msg += ( - " This may be due to kernel version (requires Linux 5.6+)." - ) + error_msg += " This may be due to kernel version (requires Linux 5.6+)." raise RuntimeError(error_msg) remote_fds.append(remote_fd) @@ -449,27 +424,19 @@ def open_mnnvl_memory(mapping: Mapping, size: int): for i, remote_handle_data in enumerate(all_handles_data): rank_ptr = ( - MnnvlMemory.current_start_address - + MnnvlMemory.current_rank_stride * i - + MnnvlMemory.current_mem_offset + MnnvlMemory.current_start_address + MnnvlMemory.current_rank_stride * i + MnnvlMemory.current_mem_offset ) if i == comm_rank: # Local memory mapping mem_handles[i] = allocated_mem_handle - checkCudaErrors( - cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0) - ) + checkCudaErrors(cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0)) else: # Fabric memory mapping imported_mem_handle = checkCudaErrors( - cuda.cuMemImportFromShareableHandle( - remote_handle_data, allocation_prop.requestedHandleTypes - ) + cuda.cuMemImportFromShareableHandle(remote_handle_data, allocation_prop.requestedHandleTypes) ) mem_handles[i] = imported_mem_handle - checkCudaErrors( - cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0) - ) + checkCudaErrors(cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0)) checkCudaErrors(cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1)) @@ -526,20 +493,14 @@ def support_nvlink(need_all_up: bool = True): available_links = 0 for link_idx in range(link_count): try: - if pynvml.nvmlDeviceGetNvLinkCapability( - handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED - ): + if pynvml.nvmlDeviceGetNvLinkCapability(handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED): available_links += 1 is_active = pynvml.nvmlDeviceGetNvLinkState(handle, link_idx) if is_active: active_links += 1 except pynvml.NVMLError_NotSupported: continue - return ( - active_links == available_links and available_links > 0 - if need_all_up - else available_links > 0 - ) + return active_links == available_links and available_links > 0 if need_all_up else available_links > 0 @staticmethod def supports_mnnvl() -> bool: @@ -551,6 +512,103 @@ def supports_mnnvl() -> bool: return support_nvlink_and_all_up +# The helper class for passing the FD handle over the socket. +class IpcSocket: + """Unix Domain Socket for IPC file descriptor passing""" + + def __init__(self, rank: int, op_id: int, use_abstract=True): + """ + Initialize IPC socket + + Args: + rank: Process rank + op_id: Unique operation ID (hash) + use_abstract: Use Linux abstract socket namespace + """ + self.rank = rank + self.op_id = op_id + self.use_abstract = use_abstract + + # Create Unix domain socket (DGRAM for compatibility with C code) + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + + # Create unique socket name + socket_name = f"/tmp/mcastmem-socket-{rank}-{op_id:x}" + + if use_abstract: + # Linux abstract socket: prepend null byte + self.socket_path = "\0" + socket_name + else: + self.socket_path = socket_name + # Remove existing socket file if it exists + with contextlib.suppress(FileNotFoundError): + os.unlink(socket_name) + + # Bind socket + self.sock.bind(self.socket_path) + + def send_fd(self, fd: int, dest_rank: int, dest_op_id: Optional[int] = None): + """ + Send a file descriptor to another process + + Args: + fd: File descriptor to send + dest_rank: Destination process rank + dest_op_id: Destination operation ID + """ + # Construct destination socket path + dest_op_id = dest_op_id or self.op_id + dest_socket_name = f"/tmp/mcastmem-socket-{dest_rank}-{dest_op_id:x}" + + if self.use_abstract: + dest_path = "\0" + dest_socket_name + else: + dest_path = dest_socket_name + + # Prepare message with file descriptor + # Send dummy byte as data (required) + dummy_data = b"\x00" + + # Pack file descriptor in ancillary data (SCM_RIGHTS) + fds = array.array("i", [fd]) + ancillary = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds.tobytes())] + + # Send message with file descriptor + self.sock.sendmsg([dummy_data], ancillary, 0, dest_path) + + def recv_fd(self): + """ + Receive a file descriptor from another process + + Returns: + int: Received file descriptor + """ + # Receive message with ancillary data + # Maximum size for ancillary data containing one fd + fds = array.array("i") + msg, ancdata, flags, addr = self.sock.recvmsg( + 1, + socket.CMSG_SPACE(fds.itemsize), # Buffer size for dummy data # Ancillary data size + ) + + # Extract file descriptor from ancillary data + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: + fds = array.array("i") + fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + return fds[0] + + raise RuntimeError("No file descriptor received") + + def close(self): + """Close the socket""" + self.sock.close() + if not self.use_abstract and self.socket_path: + with contextlib.suppress(FileNotFoundError): + os.unlink(self.socket_path) + + +# TODO: This class follows similar logic with MnnvlMemory, but the latter use single instance mode to manage the memory allocation. class McastDeviceMemory: """Python port of McastDeviceMemory from TensorRT-LLM""" @@ -562,6 +620,7 @@ def __init__( device_idx: int, is_multi_node: bool = True, comm_backend_for_handle_transfer: Optional[CommBackend] = None, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) @@ -588,6 +647,7 @@ def __init__( self.buf_size = buf_size self.signal_pad_offset = 0 self.allocation_size = 0 + self.comm_backend = comm_backend_for_handle_transfer or MPIBackend() # CUDA memory handles and pointers self.mc_ptr = 0 # CUdeviceptr mMcPtr @@ -596,9 +656,9 @@ def __init__( self.signal_pads_dev = 0 # std::vector mSignalPadsDev self.uc_ptrs_dev = 0 self.mc_handle = 0 # CUmemGenericAllocationHandle mMcHandle - self.uc_handles: List[ - int - ] = [] # std::vector mUcHandles + self.uc_handles: List[int] = [] # std::vector mUcHandles + + self._shareable_handle_type = None # Signal pad constants self.SIGNAL_PAD_ALIGNMENT = 16 @@ -612,9 +672,7 @@ def __init__( ) ) if multicast_supported == 0: - raise RuntimeError( - "[McastDeviceMemory] Device does not support multicasting." - ) + raise RuntimeError("[McastDeviceMemory] Device does not support multicasting.") # Calculate signal pad offset with alignment (matching C++ exactly) self.signal_pad_offset = round_up(buf_size, self.SIGNAL_PAD_ALIGNMENT) @@ -634,23 +692,21 @@ def __init__( ) ) if fabric_handle_supported == 0: - raise RuntimeError( - "[McastDeviceMemory] Device does not support fabric handle." - ) - - self._alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer) + raise RuntimeError("[McastDeviceMemory] Device does not support fabric handle.") + # Use fabric handle for multi-node NVLS + self._shareable_handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC else: - # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem - raise NotImplementedError("Single-node NVLS allocation not implemented yet") + self._init_ipc_socket() + # Use NVLink handle for single-node NVLS + self._shareable_handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + self._alloc_mn_mcast_mem(buf_size) # Initialize signal pads self.signal_pads = [0] * self.group_size for i in range(self.group_size): self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset if i == self.group_rank: - checkCudaErrors( - cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE) - ) + checkCudaErrors(cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE)) # Create device pointers self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads) @@ -693,29 +749,19 @@ def __del__(self): checkCudaErrors(cuda.cuMemRelease(self.uc_handles[rank])) # Unmap the vmem if rank < len(self.uc_ptrs) and self.uc_ptrs[rank]: - checkCudaErrors( - cuda.cuMemUnmap( - self.uc_ptrs[rank], self.allocation_size - ) - ) + checkCudaErrors(cuda.cuMemUnmap(self.uc_ptrs[rank], self.allocation_size)) except Exception as e: - print( - f"Destructor: Failed to release UC handle for rank {rank}: {e}" - ) + print(f"Destructor: Failed to release UC handle for rank {rank}: {e}") # Free the UC address space if hasattr(self, "uc_base_ptr") and self.uc_base_ptr: - checkCudaErrors( - cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size) - ) + checkCudaErrors(cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size)) # Release MC handle if hasattr(self, "mc_handle") and self.mc_handle and self.mc_handle != 0: try: checkCudaErrors(cuda.cuMemUnmap(self.mc_ptr, self.allocation_size)) - checkCudaErrors( - cuda.cuMemAddressFree(self.mc_ptr, self.allocation_size) - ) + checkCudaErrors(cuda.cuMemAddressFree(self.mc_ptr, self.allocation_size)) checkCudaErrors(cuda.cuMemRelease(self.mc_handle)) except Exception as e: print(f"Destructor: Failed to release MC handle: {e}") @@ -760,9 +806,16 @@ def get_world_size(self) -> int: """Get the total number of devices in the group""" return self.group_size - def _alloc_mn_mcast_mem( - self, buf_size: int, comm_backend_for_handle_transfer: Any = None - ): + def _init_ipc_socket(self): + if self.group_rank == 0: + # Gnerate the opId + opId = random.randint(0, 2**64 - 1) + else: + opId = None + opId = self.comm_backend.bcast(opId, root=0) + self._ipc_socket = IpcSocket(self.group_rank, opId) + + def _alloc_mn_mcast_mem(self, buf_size: int): """Allocate multi-node multicast memory using MNNVL""" # Verify CUDA context @@ -770,25 +823,16 @@ def _alloc_mn_mcast_mem( current_device = checkCudaErrors(cuda.cuCtxGetDevice()) if int(current_device) != self.device_idx: - print( - f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}" - ) + print(f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}") except Exception as e: print(f"Error checking CUDA context: {e}") - if comm_backend_for_handle_transfer is None: - comm = MpiComm() - else: - comm = comm_backend_for_handle_transfer - # Set up allocation properties - handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + # Set up allocation properties allocation_prop = cuda.CUmemAllocationProp() - allocation_prop.requestedHandleTypes = handle_type + allocation_prop.requestedHandleTypes = self._shareable_handle_type allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED allocation_prop.location = cuda.CUmemLocation() - allocation_prop.location.type = ( - cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE - ) + allocation_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE allocation_prop.location.id = self.device_idx allocation_prop.allocFlags.gpuDirectRDMACapable = 1 @@ -802,15 +846,13 @@ def _alloc_mn_mcast_mem( ) # mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity); - self.allocation_size = round_up( - buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity - ) + self.allocation_size = round_up(buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity) # Set up multicast properties mc_prop = cuda.CUmulticastObjectProp() mc_prop.numDevices = self.group_size mc_prop.size = self.allocation_size - mc_prop.handleTypes = handle_type + mc_prop.handleTypes = self._shareable_handle_type # Get multicast granularity mc_granularity = checkCudaErrors( @@ -826,30 +868,43 @@ def _alloc_mn_mcast_mem( self.uc_handles = [0] * self.group_size # Allocate local GPU memory - self.uc_handles[self.group_rank] = checkCudaErrors( - cuda.cuMemCreate(self.allocation_size, allocation_prop, 0) - ) + self.uc_handles[self.group_rank] = checkCudaErrors(cuda.cuMemCreate(self.allocation_size, allocation_prop, 0)) # Export local handle to fabric handle - my_fabric_handle = checkCudaErrors( + local_shareable_uc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.uc_handles[self.group_rank], - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + self._shareable_handle_type, 0, ) ) - # All-gather fabric handles - all_fabric_handles = comm.allgather(my_fabric_handle.data) + if self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: + # All-gather fabric handles + all_shareable_uc_handles = self.comm_backend.allgather(local_shareable_uc_handle.data) + else: + # Implement the allgather logic with ipc socket + # TODO: Do we need to model ipc socket as a comm backend? My tenative answer is no as it is not able to perform bootstrap without other communicator's help. + all_shareable_uc_handles = [None] * self.group_size + for i in range(self.group_size): + self.comm_backend.barrier() + # Send to peer at offset i + dest_rank = (self.group_rank + i) % self.group_size + self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank) + # Receive from peer at offset -i + src_rank = (self.group_rank + self.group_size - i) % self.group_size + all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd() cuda.cuCtxSynchronize() + print(f"[Rank {self.group_rank}] all_shareable_uc_handles: {all_shareable_uc_handles}") + # Import remote handles for p in range(self.group_size): if p != self.group_rank: self.uc_handles[p] = checkCudaErrors( cuda.cuMemImportFromShareableHandle( - all_fabric_handles[p], - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + all_shareable_uc_handles[p], + self._shareable_handle_type, ) ) @@ -858,29 +913,43 @@ def _alloc_mn_mcast_mem( # Create multicast object self.mc_handle = checkCudaErrors(cuda.cuMulticastCreate(mc_prop)) - # Export multicast handle - mc_fabric_handle = checkCudaErrors( + # Export multicast handle, there's only one handle for the entire group + shareable_mc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.mc_handle, - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + self._shareable_handle_type, 0, ) ) else: - mc_fabric_handle = None - - # Broadcast multicast handle - mc_fabric_handle_data = comm.bcast( - mc_fabric_handle.data if mc_fabric_handle else None, root=0 - ) + shareable_mc_handle = None + if self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: + # Broadcast multicast handle + shareable_mc_handle = self.comm_backend.bcast( + shareable_mc_handle.data if shareable_mc_handle else None, root=0 + ) + else: + # Implement bcast logic with ipc socket + if self.group_rank == 0: + for p in range(1, self.group_size): + self.comm_backend.barrier() + self._ipc_socket.send_fd(shareable_mc_handle, p) + else: + # Other ranks receive from rank 0 + # We need to order the receive to avoid a race condition bug we encountered. If driver fixed this issue, the additional barriers used for ordering can be removed. + for _ in range(self.group_rank): + self.comm_backend.barrier() + shareable_mc_handle = self._ipc_socket.recv_fd() + for _ in range(self.group_size - self.group_rank - 1): + self.comm_backend.barrier() # Sync device to ensure broadcast is complete cuda.cuCtxSynchronize() # Import multicast handle for non-root ranks if self.group_rank != 0: self.mc_handle = checkCudaErrors( cuda.cuMemImportFromShareableHandle( - mc_fabric_handle_data, - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, + shareable_mc_handle, + self._shareable_handle_type, ) ) @@ -893,9 +962,7 @@ def _alloc_mn_mcast_mem( # Reserve address space for UC pointers total_uc_size = self.allocation_size * self.group_size self.total_uc_size = total_uc_size - uc_base_ptr = checkCudaErrors( - cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0) - ) + uc_base_ptr = checkCudaErrors(cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0)) self.uc_base_ptr = uc_base_ptr # Store for cleanup # Set up memory access descriptor @@ -909,27 +976,15 @@ def _alloc_mn_mcast_mem( for i in range(self.group_size): offset = self.allocation_size * i self.uc_ptrs[i] = int(uc_base_ptr) + offset - checkCudaErrors( - cuda.cuMemMap( - self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0 - ) - ) + checkCudaErrors(cuda.cuMemMap(self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0)) # Set memory access permissions - checkCudaErrors( - cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) - ) + checkCudaErrors(cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1)) # Bind MC pointer - self.mc_ptr = checkCudaErrors( - cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0) - ) - checkCudaErrors( - cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0) - ) - checkCudaErrors( - cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1) - ) + self.mc_ptr = checkCudaErrors(cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0)) + checkCudaErrors(cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0)) + checkCudaErrors(cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1)) # Bind memory to multicast checkCudaErrors( @@ -958,9 +1013,7 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype): # Calculate number of elements that fit in allocation_size num_elements = self.allocation_size // dsize - checkCudaErrors( - memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements) - ) + checkCudaErrors(memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements)) class McastGPUBuffer: @@ -1005,9 +1058,7 @@ def __init__( def lamport_initialize(self, rank: int, dtype: torch.dtype): self.mcast_device_memory.lamport_initialize(rank, dtype) - def get_multicast_buffer( - self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 - ) -> torch.Tensor: + def get_multicast_buffer(self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0) -> torch.Tensor: """ Returns a PyTorch tensor view of the multicast buffer portion. @@ -1023,9 +1074,7 @@ def get_multicast_buffer( # FIXME: Is this needed? As the behavior of reading from mc_ptr is undefined. raise NotImplementedError("Not implemented yet") - def get_unicast_buffer( - self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 - ) -> torch.Tensor: + def get_unicast_buffer(self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0) -> torch.Tensor: """ Returns a PyTorch tensor view of the unicast buffer portion. """ diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 839e03411c..0b5db72628 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -58,7 +58,7 @@ def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None) buffer_size_in_bytes = 16 * (1024**2) else: # Round up to the nearest multiple of 8MB - buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2)) * (8 * (1024**2))) + buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2)) if buffer_size_in_bytes > (2**32 - 1): raise ValueError( @@ -186,6 +186,9 @@ def trtllm_mnnvl_allreduce_fusion( gamma: Gamma tensor (if rmsnorm) epsilon: Epsilon value (if rmsnorm) """ + print( + f"[Rank {rank}] Inside Kernel: multicast_buffer_ptr: {multicast_buffer_ptr:x}, buffer_ptrs_dev: {buffer_ptrs_dev:x}, buffer_ptr_local: {buffer_ptr_local:x}, buffer_flags_mnnvl: {buffer_flags_mnnvl}" + ) module.trtllm_mnnvl_allreduce_fusion( input, multicast_buffer_ptr, @@ -342,6 +345,10 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( ) ) + print( + f"[Rank {workspace.rank}] workspace.mc_ptr: {workspace.mc_ptr}, workspace.uc_ptrs_dev: {workspace.uc_ptrs_dev}, workspace.uc_ptr_local: {workspace.uc_ptr_local}" + ) + module.trtllm_mnnvl_allreduce_fusion( input, workspace.mc_ptr, diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py index 27661b1fe2..4f59c8930e 100644 --- a/flashinfer/jit/comm.py +++ b/flashinfer/jit/comm.py @@ -36,6 +36,7 @@ def gen_trtllm_mnnvl_comm_module() -> JitSpec: [ jit_env.FLASHINFER_CSRC_DIR / "trtllm_mnnvl_allreduce.cu", ], + extra_cuda_cflags=["-lineinfo"], ) diff --git a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh index 9198df8775..2177cfc618 100644 --- a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh +++ b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh @@ -536,6 +536,7 @@ __global__ void __launch_bounds__(1024) T* stagePtrLocal = reinterpret_cast(flag.getCurLamportBuf(inputPtrs[rank], 0)); if (packedIdx * kELTS_PER_THREAD >= tokenDim) { + flag.ctaArrive(); flag.clearDirtyLamportBuf(inputPtrs[rank], -1); return; } @@ -545,7 +546,7 @@ __global__ void __launch_bounds__(1024) val.packed = loadPacked(&shardPtr[threadOffset]); #pragma unroll for (int i = 0; i < kELTS_PER_THREAD; i++) { - if (isNegZero(val.elements[i])) val.elements[i] = toFloat(0.f); + if (isNegZero(val.elements[i])) val.elements[i] = fromFloat(0.f); } reinterpret_cast( @@ -641,7 +642,7 @@ __global__ void __launch_bounds__(1024) #pragma unroll for (int i = 0; i < kELTS_PER_THREAD; i++) { packedAccum.elements[i] = fromFloat(toFloat(packedAccum.elements[i]) * rcpRms * - fromFloat(gamma.elements[i])); + toFloat(gamma.elements[i])); } } reinterpret_cast(&outputPtr[threadOffset])[0] = packedAccum.packed; @@ -725,18 +726,24 @@ cudaError_t oneshotAllreduceFusionDispatch(AllReduceFusionParams const& params) // FIXME: Do we need other world sizes? case 2: DISPATCH_ALLREDUCE_KERNEL(2); + break; case 4: DISPATCH_ALLREDUCE_KERNEL(4); + break; case 8: DISPATCH_ALLREDUCE_KERNEL(8); + break; case 16: DISPATCH_ALLREDUCE_KERNEL(16); + break; case 32: DISPATCH_ALLREDUCE_KERNEL(32); + break; case 64: DISPATCH_ALLREDUCE_KERNEL(64); + break; default: - FLASHINFER_ERROR("MNNVL AllReduce: unsupported world_size " + std::to_string(params.nranks) + + FLASHINFER_ERROR("MNNVL AllReduce: unsupported world_size " + std::to_string(params.nRanks) + ". Supported sizes: {2, 4, 8, 16, 32, 64}"); return cudaErrorInvalidValue; } @@ -1145,7 +1152,7 @@ cudaError_t twoshotAllreduceFusionDispatch(AllReduceFusionParams const& params) int const dimPadded = round_up(tokenDim, numEltsPerThread * rnNumThreads); int const iters = dimPadded / rnNumThreads; - size_t const smemSize = 3 * rnBlockSize * iters * getDTypeSize(params.dType); + size_t const smemSize = 3 * rnBlockSize * iters * sizeof(T); FLASHINFER_LOG_DEBUG( "[MNNVL AllReduceTwoShotRMSNorm] Dispatch: grid size: (%d, %d, 1), block_size: %d, " diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index e7274c46f0..e0758c271c 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -5,6 +5,10 @@ import torch from mpi4py import MPI # Added MPI import +from flashinfer.utils import set_log_level + +set_log_level("debug") + import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping from flashinfer.comm.mnnvl import CommBackend, MpiComm @@ -24,19 +28,13 @@ def row_linear_residual_norm_fusion_forward( mapping: Mapping, fusion: bool, reference_output: tuple[torch.Tensor, ...], - multicast_ptr: int, - buffer_ptrs_dev: int, - unicast_ptr: int, - max_num_elements_mnnvl: int, - buffer_flags_mnnvl: torch.Tensor, - comm_backend_for_handle_transfer: Optional[CommBackend] = None, + workspace: trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace, ): x = x.cuda() residual = residual.cuda() norm_weight = norm_weight.cuda() reference_output = tuple(t.cuda() for t in reference_output) - tensor_parallel_size = mapping.tp_size tensor_parallel_rank = mapping.tp_rank if comm_backend_for_handle_transfer is None: comm = MpiComm() @@ -50,75 +48,40 @@ def func( norm_weight, eps, enable_fusion, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - max_num_elements_mnnvl, + workspace, ): # For both fused and unfused cases: shape = input.shape - - assert max_num_elements_mnnvl % hidden_size == 0 - input = input.view(-1, shape[-1]) - - buffer_M = max_num_elements_mnnvl // hidden_size + use_pdl = True if enable_fusion: - use_pdl = True - - prenorm_output = torch.empty_like(residual) - normed_output = torch.empty_like(residual) - trtllm_mnnvl_ar.mpi_barrier() - trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( - prenorm_output, - normed_output, + output, residual_out = trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( input, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - buffer_M, - buffer_flags_mnnvl, - tensor_parallel_size, - tensor_parallel_rank, + residual, norm_weight, + workspace, eps, - residual, - use_pdl, + launch_with_pdl=use_pdl, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.ONESHOT, ) - return normed_output.view(shape), prenorm_output.view(shape) + return output.view(shape), residual_out.view(shape) else: output = torch.empty_like(input) - trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce( + output = trtllm_mnnvl_ar.trtllm_mnnvl_allreduce( input, - multicast_ptr, - buffer_ptrs_dev, - buffer_M, - buffer_flags_mnnvl, - tensor_parallel_size, - tensor_parallel_rank, - True, # wait_for_results - False, # launch_with_pdl - output, # Need to provide output tensor since we are writing them out. + workspace, + launch_with_pdl=use_pdl, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.ONESHOT, ) return (output.view(shape),) - output = func( - x.clone(), - residual.clone(), - norm_weight, - eps, - fusion, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - max_num_elements_mnnvl, - ) + output = func(x.clone(), residual.clone(), norm_weight, eps, fusion, workspace) assert output[0].shape == reference_output[0].shape @@ -173,7 +136,8 @@ def run_mnnvl_ar_full( hidden_size: Hidden dimension size explicit_workspace_bytes: If provided, use this workspace size instead of default """ - monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. + if monkeypatch is not None: + monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. # Get MPI info rank = MPI.COMM_WORLD.Get_rank() @@ -198,43 +162,32 @@ def run_mnnvl_ar_full( torch.cuda.set_device(mapping.local_rank) if mapping.local_rank == 0: - print( - f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" - ) - print( - f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" - ) + print(f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks") + print(f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}") tensor_parallel_size = world_size eps = 1e-5 - torch.manual_seed(42) + torch.manual_seed(42 + rank) # Track if this rank failed rank_failed = False failure_message = "" try: - # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list - # This workspace is sized for the maximum expected sequence length and can be reused within each list - # Each parameterized list gets its own fresh workspace allocation - mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( - trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( - mapping, dtype, buffer_size_in_bytes=explicit_workspace_bytes - ) - ) - - multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() - buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() - unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( - mapping.tp_rank + required_workspace_bytes = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes( + mapping.tp_size, + max(seq_lens), + hidden_size, + dtype, + trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.TWOSHOT, ) + workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace(mapping, required_workspace_bytes) # Test each sequence length with the same workspace (reusing allocated buffers within this list) for seq_len in seq_lens: if rank == 0: - print( - f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" - ) + print(f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}") + print(f"[Rank {rank}] Buffer flags: {workspace.buffer_flags}") # Generate test data (same on all ranks due to same seed) x_full = torch.randn( @@ -242,12 +195,8 @@ def run_mnnvl_ar_full( dtype=dtype, device=torch.device("cuda"), ) - residual = torch.randn( - (seq_len, hidden_size), dtype=dtype, device=torch.device("cuda") - ) - norm_weight = torch.randn( - (hidden_size,), dtype=dtype, device=torch.device("cuda") - ) + residual = torch.randn((seq_len, hidden_size), dtype=dtype, device=torch.device("cuda")) + norm_weight = torch.randn((hidden_size,), dtype=dtype, device=torch.device("cuda")) # Each rank gets its slice of the input x = x_full[rank, :, :] @@ -258,11 +207,7 @@ def run_mnnvl_ar_full( # Fused case: AllReduce + Residual Add + RMS Norm allreduce_result = torch.sum(x_full, dim=0) # AllReduce result residual_out = allreduce_result + residual # Add residual - print( - "Device of residual_out:{}, norm_weight:{}".format( - residual_out.device, norm_weight.device - ) - ) + print("Device of residual_out:{}, norm_weight:{}".format(residual_out.device, norm_weight.device)) norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) reference_output = (norm_out, residual_out) @@ -282,24 +227,21 @@ def run_mnnvl_ar_full( mapping, fusion, reference_output, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - max_num_elements_mnnvl, - buffer_flags_mnnvl, + workspace, ) # Synchronize before next test trtllm_mnnvl_ar.mpi_barrier() - print( - f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}" - ) + print(f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}") except Exception as e: rank_failed = True failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}" print(failure_message) + import traceback + + print(traceback.format_exc()) # Gather failure status from all ranks for logging all_failures = MPI.COMM_WORLD.allgather(rank_failed) @@ -310,16 +252,16 @@ def run_mnnvl_ar_full( print(f"Test failed on ranks: {failed_ranks}") # Cleanup before re-raising - if "mcast_buffer_mnnvl" in locals(): - del mcast_buffer_mnnvl + if "workspace" in locals(): + del workspace # Re-raise the original exception so it can be caught by pytest.raises in negative tests raise finally: # Ensure cleanup happens for this list's workspace - if "mcast_buffer_mnnvl" in locals(): - del mcast_buffer_mnnvl + if "workspace" in locals(): + del workspace # Final synchronization and check for failures across all ranks trtllm_mnnvl_ar.mpi_barrier() @@ -348,61 +290,64 @@ def test_mnnvl_allreduce_default_workspace( run_mnnvl_ar_full(monkeypatch, seq_lens, fusion, dtype, hidden_size) -"""Test with explicit workspace size""" - - -@pytest.mark.parametrize( - "seq_lens", - [ - [1, 4, 180], - ], -) -@pytest.mark.parametrize("fusion", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) -def test_mnnvl_allreduce_explicit_workspace( - monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int -): - """Test MNNVL AllReduce with explicitly calculated workspace size.""" - # Calculate workspace to fit the maximum sequence length - # buffer shape: [3, 2, buffer_tokens, hidden_dim] - explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max(seq_lens) - run_mnnvl_ar_full( - monkeypatch, - seq_lens, - fusion, - dtype, - hidden_size, - explicit_workspace_bytes=explicit_workspace_bytes, - ) - - -"""Negative test: workspace too small""" - - -@pytest.mark.parametrize("fusion", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [2048, 4096]) -def test_mnnvl_allreduce_workspace_too_small( - monkeypatch, fusion: bool, dtype: torch.dtype, hidden_size: int -): - """Test that MNNVL AllReduce fails gracefully when workspace is too small.""" - # Use a large sequence length that won't fit in a small workspace - seq_len = 180 - - # Create a workspace that's too small (only enough for 10 tokens) - small_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * 10 - - # Expect a ValueError with a message about buffer_M being too small - with pytest.raises((ValueError, RuntimeError)) as exc_info: - run_mnnvl_ar_full( - monkeypatch, - [seq_len], - fusion, - dtype, - hidden_size, - explicit_workspace_bytes=small_workspace_bytes, - ) - - # Verify the error message contains the expected text - assert "greater than the buffer_M" in str(exc_info.value) +if __name__ == "__main__": + run_mnnvl_ar_full(None, [15], False, torch.bfloat16, 4096) + +# """Test with explicit workspace size""" + + +# @pytest.mark.parametrize( +# "seq_lens", +# [ +# [1, 4, 180], +# ], +# ) +# @pytest.mark.parametrize("fusion", [False, True]) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) +# def test_mnnvl_allreduce_explicit_workspace( +# monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int +# ): +# """Test MNNVL AllReduce with explicitly calculated workspace size.""" +# # Calculate workspace to fit the maximum sequence length +# # buffer shape: [3, 2, buffer_tokens, hidden_dim] +# explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max(seq_lens) +# run_mnnvl_ar_full( +# monkeypatch, +# seq_lens, +# fusion, +# dtype, +# hidden_size, +# explicit_workspace_bytes=explicit_workspace_bytes, +# ) + + +# """Negative test: workspace too small""" + + +# @pytest.mark.parametrize("fusion", [False, True]) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("hidden_size", [2048, 4096]) +# def test_mnnvl_allreduce_workspace_too_small( +# monkeypatch, fusion: bool, dtype: torch.dtype, hidden_size: int +# ): +# """Test that MNNVL AllReduce fails gracefully when workspace is too small.""" +# # Use a large sequence length that won't fit in a small workspace +# seq_len = 180 + +# # Create a workspace that's too small (only enough for 10 tokens) +# small_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * 10 + +# # Expect a ValueError with a message about buffer_M being too small +# with pytest.raises((ValueError, RuntimeError)) as exc_info: +# run_mnnvl_ar_full( +# monkeypatch, +# [seq_len], +# fusion, +# dtype, +# hidden_size, +# explicit_workspace_bytes=small_workspace_bytes, +# ) + +# # Verify the error message contains the expected text +# assert "greater than the buffer_M" in str(exc_info.value) From 4caf71aa32bec398b5b6bfcb05f2abdd93d7cfc3 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 19 Nov 2025 19:32:26 -0800 Subject: [PATCH 05/32] Passing the test. --- flashinfer/comm/trtllm_mnnvl_ar.py | 2 + flashinfer/jit/comm.py | 1 - tests/comm/test_trtllm_mnnvl_allreduce.py | 179 ++++++++++------------ 3 files changed, 81 insertions(+), 101 deletions(-) diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 0b5db72628..eae919e5e0 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -275,6 +275,8 @@ def trtllm_mnnvl_allreduce( None, ) + return output + def trtllm_mnnvl_fused_allreduce_rmsnorm( input: torch.Tensor, diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py index 4f59c8930e..27661b1fe2 100644 --- a/flashinfer/jit/comm.py +++ b/flashinfer/jit/comm.py @@ -36,7 +36,6 @@ def gen_trtllm_mnnvl_comm_module() -> JitSpec: [ jit_env.FLASHINFER_CSRC_DIR / "trtllm_mnnvl_allreduce.cu", ], - extra_cuda_cflags=["-lineinfo"], ) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index e0758c271c..6b89661650 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -5,10 +5,6 @@ import torch from mpi4py import MPI # Added MPI import -from flashinfer.utils import set_log_level - -set_log_level("debug") - import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping from flashinfer.comm.mnnvl import CommBackend, MpiComm @@ -23,24 +19,21 @@ def row_linear_residual_norm_fusion_forward( residual: torch.Tensor, norm_weight: torch.Tensor, eps: float, - hidden_size: int, - dtype: torch.dtype, mapping: Mapping, fusion: bool, reference_output: tuple[torch.Tensor, ...], workspace: trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace, ): - x = x.cuda() - residual = residual.cuda() - norm_weight = norm_weight.cuda() - reference_output = tuple(t.cuda() for t in reference_output) - tensor_parallel_rank = mapping.tp_rank +<<<<<<< HEAD if comm_backend_for_handle_transfer is None: comm = MpiComm() else: comm = comm_backend_for_handle_transfer comm.barrier() +======= + MPI.COMM_WORLD.barrier() +>>>>>>> bca4f5d9 (Passing the test.) def func( input, @@ -65,7 +58,7 @@ def func( workspace, eps, launch_with_pdl=use_pdl, - strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.ONESHOT, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, ) return output.view(shape), residual_out.view(shape) @@ -77,7 +70,7 @@ def func( input, workspace, launch_with_pdl=use_pdl, - strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.ONESHOT, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, ) return (output.view(shape),) @@ -118,13 +111,49 @@ def func( """Helper function to run the core MNNVL AllReduce test logic""" +def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion: bool): + # Communicator used for passing data between ranks + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + if rank == 0: + x_full = torch.randn((world_size, seq_len, hidden_size), dtype=dtype) + residual = torch.randn((seq_len, hidden_size), dtype=dtype) + norm_weight = torch.randn((hidden_size,), dtype=dtype) + else: + x_full = None + residual = None + norm_weight = None + + # Use lowercase bcast() for Python object broadcasting + x_full = comm.bcast(x_full, root=0) + residual = comm.bcast(residual, root=0) + norm_weight = comm.bcast(norm_weight, root=0) + + x_full = x_full.cuda() + residual = residual.cuda() + norm_weight = norm_weight.cuda() + + x_local = x_full[rank, :, :] + reference_output: Tuple[torch.Tensor, ...] = None + if fusion: + # Fused case: AllReduce + Residual Add + RMS Norm + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + residual_out = allreduce_result + residual # Add residual + norm_out = rmsnorm( + residual_out, norm_weight, torch.finfo(dtype).eps, enable_pdl=False + ) + + reference_output = (norm_out, residual_out) + else: + # Non-fused case: Only AllReduce + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + reference_output = (allreduce_result,) + return (x_local, residual, norm_weight), reference_output + + def run_mnnvl_ar_full( - monkeypatch, - seq_lens: list[int], - fusion: bool, - dtype: torch.dtype, - hidden_size: int, - explicit_workspace_bytes: int | None = None, + monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int ): """Core test logic for MNNVL AllReduce operations. @@ -136,18 +165,15 @@ def run_mnnvl_ar_full( hidden_size: Hidden dimension size explicit_workspace_bytes: If provided, use this workspace size instead of default """ - if monkeypatch is not None: - monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. + comm = MPI.COMM_WORLD # Get MPI info - rank = MPI.COMM_WORLD.Get_rank() - world_size = MPI.COMM_WORLD.Get_size() + rank = comm.Get_rank() + world_size = comm.Get_size() gpus_per_node = torch.cuda.device_count() if gpus_per_node == 0: pytest.skip("MNNVL allreduce test requires at least one CUDA device per node") - - # Ensure we have exactly 2 ranks for this test if world_size < 2: pytest.skip(f"This test requires at least 2 MPI ranks, got {world_size}") @@ -162,10 +188,19 @@ def run_mnnvl_ar_full( torch.cuda.set_device(mapping.local_rank) if mapping.local_rank == 0: +<<<<<<< HEAD print(f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks") print(f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}") tensor_parallel_size = world_size +======= + print( + f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" + ) + print( + f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" + ) +>>>>>>> bca4f5d9 (Passing the test.) eps = 1e-5 torch.manual_seed(42 + rank) @@ -179,13 +214,23 @@ def run_mnnvl_ar_full( max(seq_lens), hidden_size, dtype, - trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.TWOSHOT, + trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, ) workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace(mapping, required_workspace_bytes) - # Test each sequence length with the same workspace (reusing allocated buffers within this list) + test_data = [] for seq_len in seq_lens: + (x_local, residual, norm_weight), reference_output = prepare_test_data( + seq_len, hidden_size, dtype, fusion + ) + test_data.append( + (seq_len, x_local, residual, norm_weight, reference_output) + ) + + # Test each sequence length with the same workspace (reusing allocated buffers within this list) + for seq_len, x, residual, norm_weight, reference_output in test_data: if rank == 0: +<<<<<<< HEAD print(f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}") print(f"[Rank {rank}] Buffer flags: {workspace.buffer_flags}") @@ -215,6 +260,11 @@ def run_mnnvl_ar_full( # Non-fused case: Only AllReduce allreduce_result = torch.sum(x_full, dim=0) # AllReduce result reference_output = (allreduce_result,) +======= + print( + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" + ) +>>>>>>> bca4f5d9 (Passing the test.) # Run the test with the same workspace row_linear_residual_norm_fusion_forward( @@ -222,8 +272,6 @@ def run_mnnvl_ar_full( residual, norm_weight, eps, - hidden_size, - dtype, mapping, fusion, reference_output, @@ -272,82 +320,13 @@ def run_mnnvl_ar_full( @pytest.mark.parametrize( "seq_lens", - [ - [1], - [4], - [15], - [27, 11, 24], - [127], - ], + [[1], [4], [15], [27, 11, 24, 256], [127], [998, 2048]], ) @pytest.mark.parametrize("fusion", [False, True]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) +@pytest.mark.parametrize("hidden_size", [2880, 5120, 7168, 8192]) def test_mnnvl_allreduce_default_workspace( monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int ): """Test MNNVL AllReduce with default workspace size.""" run_mnnvl_ar_full(monkeypatch, seq_lens, fusion, dtype, hidden_size) - - -if __name__ == "__main__": - run_mnnvl_ar_full(None, [15], False, torch.bfloat16, 4096) - -# """Test with explicit workspace size""" - - -# @pytest.mark.parametrize( -# "seq_lens", -# [ -# [1, 4, 180], -# ], -# ) -# @pytest.mark.parametrize("fusion", [False, True]) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) -# def test_mnnvl_allreduce_explicit_workspace( -# monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int -# ): -# """Test MNNVL AllReduce with explicitly calculated workspace size.""" -# # Calculate workspace to fit the maximum sequence length -# # buffer shape: [3, 2, buffer_tokens, hidden_dim] -# explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max(seq_lens) -# run_mnnvl_ar_full( -# monkeypatch, -# seq_lens, -# fusion, -# dtype, -# hidden_size, -# explicit_workspace_bytes=explicit_workspace_bytes, -# ) - - -# """Negative test: workspace too small""" - - -# @pytest.mark.parametrize("fusion", [False, True]) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("hidden_size", [2048, 4096]) -# def test_mnnvl_allreduce_workspace_too_small( -# monkeypatch, fusion: bool, dtype: torch.dtype, hidden_size: int -# ): -# """Test that MNNVL AllReduce fails gracefully when workspace is too small.""" -# # Use a large sequence length that won't fit in a small workspace -# seq_len = 180 - -# # Create a workspace that's too small (only enough for 10 tokens) -# small_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * 10 - -# # Expect a ValueError with a message about buffer_M being too small -# with pytest.raises((ValueError, RuntimeError)) as exc_info: -# run_mnnvl_ar_full( -# monkeypatch, -# [seq_len], -# fusion, -# dtype, -# hidden_size, -# explicit_workspace_bytes=small_workspace_bytes, -# ) - -# # Verify the error message contains the expected text -# assert "greater than the buffer_M" in str(exc_info.value) From 9a6beec1a41a8b24092d01321bbabd3fca433f35 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 19 Nov 2025 19:54:54 -0800 Subject: [PATCH 06/32] Remove debug prints and add compatability interface. --- flashinfer/comm/mnnvl.py | 9 - flashinfer/comm/trtllm_mnnvl_ar.py | 237 +++++++++++++++++++++- tests/comm/test_trtllm_mnnvl_allreduce.py | 18 +- 3 files changed, 238 insertions(+), 26 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 520f6e4880..787f243995 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -131,9 +131,6 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: if not host_ptr_array: return None - for addr in host_ptr_array: - print(f"DEBUG: ptr_array: 0x{addr:x}") - ArrayType = ctypes.c_uint64 * len(host_ptr_array) c_array = ArrayType(*host_ptr_array) size_in_bytes = ctypes.sizeof(c_array) @@ -719,9 +716,6 @@ def __del__(self): if not hasattr(self, "is_multi_node"): return - if not self.is_multi_node: - return - # Skip cleanup during Python finalization to avoid segfaults # Especially cause the CUDA context could be destroyed at this point. if sys.is_finalizing(): @@ -884,7 +878,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int): all_shareable_uc_handles = self.comm_backend.allgather(local_shareable_uc_handle.data) else: # Implement the allgather logic with ipc socket - # TODO: Do we need to model ipc socket as a comm backend? My tenative answer is no as it is not able to perform bootstrap without other communicator's help. all_shareable_uc_handles = [None] * self.group_size for i in range(self.group_size): self.comm_backend.barrier() @@ -896,8 +889,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int): all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd() cuda.cuCtxSynchronize() - print(f"[Rank {self.group_rank}] all_shareable_uc_handles: {all_shareable_uc_handles}") - # Import remote handles for p in range(self.group_size): if p != self.group_rank: diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index eae919e5e0..a9c7a026e4 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -11,6 +11,7 @@ from enum import Enum import torch +from typing_extensions import deprecated from flashinfer.comm.mapping import Mapping @@ -278,7 +279,7 @@ def trtllm_mnnvl_allreduce( return output -def trtllm_mnnvl_fused_allreduce_rmsnorm( +def trtllm_mnnvl_fused_allreduce_add_rmsnorm( input: torch.Tensor, residual_in: torch.Tensor, gamma: torch.Tensor, @@ -289,10 +290,10 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( launch_with_pdl: bool = False, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Performs MNNVL Allreduce + RMSNorm. + """Performs MNNVL Allreduce + Residual + RMSNorm. This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_allreduce on the shard_input. - After this, it performs RMSNorm on the all-reduced result, reading it directly from the multicast buffer. + After this, it performs residual addition and RMSNorm on the all-reduced result, reading it directly from the multicast buffer. Note: multicast buffer is the same as the unicast buffer for the current rank. Args: @@ -307,8 +308,8 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. Returns: - output: Normalized tensor [num_tokens, hidden_dim] - residual_out: Residual output tensor [num_tokens, hidden_dim] + output: Add-residual and normalized tensor [num_tokens, hidden_dim] + residual_out: Add-residual tensor [num_tokens, hidden_dim] """ if epsilon is None: @@ -347,10 +348,6 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( ) ) - print( - f"[Rank {workspace.rank}] workspace.mc_ptr: {workspace.mc_ptr}, workspace.uc_ptrs_dev: {workspace.uc_ptrs_dev}, workspace.uc_ptr_local: {workspace.uc_ptr_local}" - ) - module.trtllm_mnnvl_allreduce_fusion( input, workspace.mc_ptr, @@ -369,3 +366,225 @@ def trtllm_mnnvl_fused_allreduce_rmsnorm( epsilon, ) return output, residual_out + + +# Legacy API that has been deprecated; Left for backward compatibility +@deprecated( + "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllreduceFusionWorkspace class to manage the workspace instead" +) +def get_allreduce_mnnvl_workspace( + mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None +) -> Tuple[McastGPUBuffer, torch.Tensor, int]: + """Get workspace buffers needed for multi-node NVLink all-reduce operation. + + This function allocates and initializes the workspace buffers required for performing + multi-node NVLink all-reduce operations. It creates: + 1. A multicast GPU buffer for communication between nodes + 2. A flags tensor to track buffer state + 3. Maximum number of elements that can fit in the buffer + + The buffer size is calculated to efficiently handle common hidden dimensions + (2048, 4096, 5120, 7168, 8192) by using their LCM of 286720. + + Args: + mapping: Tensor parallel mapping configuration containing rank info + dtype: Data type of the tensors being reduced + buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens + + Returns: + Tuple containing: + - McastGPUBuffer: Multicast buffer for inter-node communication + - torch.Tensor: Buffer flags tensor tracking state + - int: Maximum number of elements that can fit in buffer + """ + # buffer shape: [3, 2, buffer_tokens, hidden_dim] + stride = 3 * 2 * dtype.itemsize + # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 + # max_num_elements must be a multiple of 286720 + lcm_hidden_dim = 286720 + TARGET_WORKSPACE_SIZE_BYTES = ( + buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 + ) + buffer_size_in_bytes = math.ceil( + TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) + ) * (lcm_hidden_dim * stride) + + # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. + workspace = MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes) + + mcast_buffer = workspace.mcast_buffer_handle + buffer_flags = workspace.buffer_flags + max_num_elements = workspace.buffer_size_bytes // stride + + return ( + mcast_buffer, + buffer_flags, + max_num_elements, + ) + + +@deprecated( + "trtllm_mnnvl_all_reduce is deprecated, use trtllm_mnnvl_allreduce instead. This function will be removed in the future." +) +def trtllm_mnnvl_all_reduce( + inp: torch.Tensor, + multicast_buffer_ptr: int, # Pointer address as integer + buffer_ptrs_dev: int, # Pointer address as integer + buffer_M: int, + buffer_flags_mnnvl: torch.Tensor, + nranks: int, + rank: int, + wait_for_results: bool, + launch_with_pdl: bool, + out: Optional[torch.Tensor] = None, +) -> None: + """Perform a multi-node NVLink all-reduce operation across multiple GPUs. + + This function performs an all-reduce (sum) operation using NVIDIA's multi-node NVLink (MNNVL) + technology to efficiently combine tensors across multiple GPUs and nodes. + + There are 3 steps: + 1. scatter each GPU's input shard to the right unicast buffer + 2. perform all-reduce on each GPU + 3. broadcast the result to all GPUs + + Args: + inp: Local Input Shard + multicast_buffer_ptr: Pointer to the multicast buffer as an integer + buffer_ptrs_dev: Pointer to device buffer pointers as an integer + buffer_M: Maximum number of elements // hidden_dim + buffer_flags_mnnvl: Tensor containing buffer state flags + nranks: Total number of ranks participating in the all-reduce + rank: Current process rank + wait_for_results: If True, store the result to out + launch_with_pdl: If True, launch using Programmatic Dependent Launch + [Optional] out: Output tensor to store the result (required if wait_for_results is True) + + """ + + if len(inp.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}." + ) + + # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. + if inp.shape[0] > buffer_M: + raise ValueError( + f"The number of tokens in the input tensor {inp.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}." + ) + + # Even in legacy code, this should only be used when we implement the fused allreduce+rmsnorm. + assert wait_for_results and (out is not None), ( + "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." + ) + module = get_trtllm_mnnvl_comm_module() + module.trtllm_mnnvl_allreduce_fusion( + input, + multicast_buffer_ptr, + buffer_ptrs_dev, + 0, # Allreduce kernel itself does not use this local pointer; still this could be risky but it is only used for legacy code compatibility. + buffer_flags_mnnvl, + nranks, + rank, + False, # No RMSNorm Fusion + launch_with_pdl, + False, # Use two-shot + out, + None, + None, + None, + None, + ) + + +@deprecated( + "trtllm_mnnvl_fused_allreduce_rmsnorm is deprecated, use trtllm_mnnvl_fused_allreduce_add_rmsnorm instead. This function will be removed in the future." +) +def trtllm_mnnvl_fused_allreduce_rmsnorm( + prenorm_output: torch.Tensor, + normed_output: torch.Tensor, + shard_input: torch.Tensor, + multicast_buffer_ptr: int, # Pointer address as integer + buffer_ptrs_dev: int, # Pointer address as integer + unicast_ptr: int, # Local unicast buffer pointer + buffer_M: int, + buffer_flags_mnnvl: torch.Tensor, + nranks: int, + rank: int, + gamma: torch.Tensor, + epsilon: float, + residual: torch.Tensor, + launch_with_pdl: bool, +) -> None: + """Performs MNNVL TwoShot Allreduce + RMSNorm. + + This function performs a multi-node all-reduce (sum) operation by first calling trtllm_mnnvl_all_reduce on the shard_input. + After this, it performs RMSNorm on the all-reduced result, reading it directly from the multicast buffer. + Note: multicast buffer is the same as the unicast buffer for the current rank. + + Args: + prenorm_output: Output tensor for prenorm results + normed_output: Output tensor for normalized results + shard_input: Input tensor shard + multicast_buffer_ptr: Pointer address as integer for multicast buffer + buffer_ptrs_dev: Pointer address as integer for device buffer pointers + unicast_ptr: Pointer address as integer for unicast buffer + buffer_M: Maximum number of elements // hidden_dim + buffer_flags_mnnvl: Buffer flags for synchronization + nranks: Number of ranks in the tensor parallel group + rank: Current rank in the tensor parallel group + gamma: The gamma (norm weight) parameter for RMSNorm + epsilon: The epsilon parameter for RMSNorm + residual: The residual tensor to add + launch_with_pdl: Whether to launch with PDL + + """ + if len(shard_input.shape) != 2: + raise ValueError( + f"The input tensor must be 2D, got {len(shard_input.shape)}D. The shape is {shard_input.shape}." + ) + + # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. + if shard_input.shape[0] > buffer_M: + raise ValueError( + f"The number of tokens in the input tensor {shard_input.shape[0]} is greater than the buffer_M {buffer_M}. This is not supported. Please increase the workspace size, or decrease the amount of tokens to at most {buffer_M}." + ) + + if len(residual.shape) != 2: + raise ValueError( + f"The residual input tensor must be 2D, got {len(residual.shape)}D. The shape is {residual.shape}." + ) + if gamma.numel() != shard_input.shape[1]: + raise ValueError( + f"The gamma tensor must have the same number of elements as the hidden dimension, got {gamma.numel()} elements but expected {shard_input.shape[1]} elements." + ) + + if len(normed_output.shape) != 2: + raise ValueError( + f"The output tensor must be 2D, got {len(normed_output.shape)}D. The shape is {normed_output.shape}." + ) + + if len(prenorm_output.shape) != 2: + raise ValueError( + f"The prenorm output tensor must be 2D, got {len(prenorm_output.shape)}D. The shape is {prenorm_output.shape}." + ) + + module = get_trtllm_mnnvl_comm_module() + + module.trtllm_mnnvl_allreduce_fusion( + shard_input, + multicast_buffer_ptr, + buffer_ptrs_dev, + unicast_ptr, + buffer_flags_mnnvl, + nranks, + rank, + True, # RMSNorm Fusion + launch_with_pdl, + False, + normed_output, + prenorm_output, + residual, + gamma, + epsilon, + ) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index 6b89661650..b77c5a91d1 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -51,14 +51,16 @@ def func( if enable_fusion: trtllm_mnnvl_ar.mpi_barrier() - output, residual_out = trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( - input, - residual, - norm_weight, - workspace, - eps, - launch_with_pdl=use_pdl, - strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, + output, residual_out = ( + trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_add_rmsnorm( + input, + residual, + norm_weight, + workspace, + eps, + launch_with_pdl=use_pdl, + strategy=trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, + ) ) return output.view(shape), residual_out.view(shape) From a4d1a1757e007c2abbda400c48b3a368289aa277 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 19 Nov 2025 20:12:22 -0800 Subject: [PATCH 07/32] Incorporate 2056; Add test for legacy APIs --- flashinfer/comm/mnnvl.py | 4 + flashinfer/comm/trtllm_mnnvl_ar.py | 44 +++-- tests/comm/test_trtllm_mnnvl_allreduce.py | 227 +++++++++++++++++++--- 3 files changed, 230 insertions(+), 45 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 787f243995..6ca3a4b866 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -1033,7 +1033,11 @@ def __init__( group_rank: The rank of the local process within the group device: The CUDA device for buffer allocation mn_nvlink: Flag indicating if multi-node NVLink is used +<<<<<<< HEAD comm_backend_for_handle_transfer: Communication backend for handle transfer +======= + comm_backend_for_handle_transfer: The communicator to use for handle transfer +>>>>>>> a2670e8c (Incorporate 2056; Add test for legacy APIs) """ self.mcast_device_memory = McastDeviceMemory( buf_size, diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index a9c7a026e4..82c40a7c83 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -17,7 +17,7 @@ from ..jit import gen_trtllm_mnnvl_comm_module from ..utils import register_custom_op -from .mnnvl import McastGPUBuffer, CommBackend +from .mnnvl import McastGPUBuffer, CommBackend, MPIBackend def mpi_barrier(): @@ -39,14 +39,18 @@ def is_one_shot(tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dty # Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size -# TODO(Refactor): Consider moving this to a configuration class or file MNNVL_ONE_SHOT_THRESHOLD = 64 * 1024 * 8 * 2 class MNNVLAllreduceFusionWorkspace: NUM_LAMPORT_BUFFERS = 3 - def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None): + def __init__( + self, + mapping: Mapping, + buffer_size_in_bytes: Optional[int] = None, + comm_backend: Optional[CommBackend] = None, + ): """ Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD. @@ -60,7 +64,8 @@ def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None) else: # Round up to the nearest multiple of 8MB buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2)) - + if comm_backend is None: + comm_backend = MPIBackend() if buffer_size_in_bytes > (2**32 - 1): raise ValueError( f"The buffer size in bytes {buffer_size_in_bytes} is greater than the maximum supported size (UINT32_MAX)." @@ -79,14 +84,14 @@ def __init__(self, mapping: Mapping, buffer_size_in_bytes: Optional[int] = None) mapping.tp_rank, torch.device("cuda", mapping.local_rank), mapping.is_multi_node(), + comm_backend, ) # We use FP32 for sentinel value regardless of the real dtype self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32) # Wait until the initialization is done torch.cuda.synchronize() - # FIXME: We are assuming using the COMM_WORLD. - mpi_barrier() + comm_backend.barrier() # This is a buffer to maintain the state of this allreduce Op # Should have the same lifetime with self._buffer @@ -373,7 +378,10 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllreduceFusionWorkspace class to manage the workspace instead" ) def get_allreduce_mnnvl_workspace( - mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None + mapping: Mapping, + dtype: torch.dtype, + comm_backend_for_handle_transfer: Optional[CommBackend] = None, + buffer_size_in_bytes: Optional[int] = None, ) -> Tuple[McastGPUBuffer, torch.Tensor, int]: """Get workspace buffers needed for multi-node NVLink all-reduce operation. @@ -402,15 +410,13 @@ def get_allreduce_mnnvl_workspace( # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 # max_num_elements must be a multiple of 286720 lcm_hidden_dim = 286720 - TARGET_WORKSPACE_SIZE_BYTES = ( - buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 + TARGET_WORKSPACE_SIZE_BYTES = buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 + buffer_size_in_bytes = math.ceil(TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride)) * ( + lcm_hidden_dim * stride ) - buffer_size_in_bytes = math.ceil( - TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) - ) * (lcm_hidden_dim * stride) # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. - workspace = MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes) + workspace = MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes, comm_backend_for_handle_transfer) mcast_buffer = workspace.mcast_buffer_handle buffer_flags = workspace.buffer_flags @@ -463,9 +469,7 @@ def trtllm_mnnvl_all_reduce( """ if len(inp.shape) != 2: - raise ValueError( - f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}." - ) + raise ValueError(f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}.") # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. if inp.shape[0] > buffer_M: @@ -474,12 +478,12 @@ def trtllm_mnnvl_all_reduce( ) # Even in legacy code, this should only be used when we implement the fused allreduce+rmsnorm. - assert wait_for_results and (out is not None), ( - "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." - ) + assert wait_for_results and ( + out is not None + ), "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." module = get_trtllm_mnnvl_comm_module() module.trtllm_mnnvl_allreduce_fusion( - input, + inp, multicast_buffer_ptr, buffer_ptrs_dev, 0, # Allreduce kernel itself does not use this local pointer; still this could be risky but it is only used for legacy code compatibility. diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index b77c5a91d1..461e1527ac 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -110,6 +110,131 @@ def func( ) +@torch.inference_mode() +def row_linear_residual_norm_fusion_forward_legacy( + x: torch.Tensor, + residual: torch.Tensor, + norm_weight: torch.Tensor, + eps: float, + hidden_size: int, + dtype: torch.dtype, + mapping: Mapping, + fusion: bool, + reference_output: tuple[torch.Tensor, ...], + multicast_ptr: int, + buffer_ptrs_dev: int, + unicast_ptr: int, + max_num_elements_mnnvl: int, + buffer_flags_mnnvl: torch.Tensor, +): + tensor_parallel_size = mapping.tp_size + tensor_parallel_rank = mapping.tp_rank + MPI.COMM_WORLD.barrier() + + def func( + input, + residual, + norm_weight, + eps, + enable_fusion, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + ): + # For both fused and unfused cases: + shape = input.shape + input = input.view(-1, shape[-1]) + buffer_M = max_num_elements_mnnvl // hidden_size + + if enable_fusion: + use_pdl = True + + prenorm_output = torch.empty_like(residual) + normed_output = torch.empty_like(residual) + + trtllm_mnnvl_ar.mpi_barrier() + + trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( + prenorm_output, + normed_output, + input, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + buffer_M, + buffer_flags_mnnvl, + tensor_parallel_size, + tensor_parallel_rank, + norm_weight, + eps, + residual, + use_pdl, + ) + + return normed_output.view(shape), prenorm_output.view(shape) + + else: + output = torch.empty_like(input) + + trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce( + input, + multicast_ptr, + buffer_ptrs_dev, + buffer_M, + buffer_flags_mnnvl, + tensor_parallel_size, + tensor_parallel_rank, + True, # wait_for_results + False, # launch_with_pdl + output, # Need to provide output tensor since we are writing them out. + ) + return (output.view(shape),) + + output = func( + x.clone(), + residual.clone(), + norm_weight, + eps, + fusion, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + ) + + assert output[0].shape == reference_output[0].shape + + if tensor_parallel_rank == 0: + print("output[0] (first 10 values):", output[0].flatten()[:10]) + print( + "reference_output[0] (first 10 values):", + reference_output[0].flatten()[:10], + ) + + if fusion: + print("output[1] (first 10 values):", output[1].flatten()[:10]) + print( + "reference_output[1] (first 10 values):", + reference_output[1].flatten()[:10], + ) + + torch.testing.assert_close( + output[0], + reference_output[0], + rtol=0.05, + atol=0.15, + ) + + if fusion: + torch.testing.assert_close( + output[1], + reference_output[1], + rtol=0.05, + atol=0.15, + ) + + """Helper function to run the core MNNVL AllReduce test logic""" @@ -155,7 +280,13 @@ def prepare_test_data(seq_len: int, hidden_size: int, dtype: torch.dtype, fusion def run_mnnvl_ar_full( - monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int + monkeypatch, + seq_lens: list[int], + fusion: bool, + dtype: torch.dtype, + hidden_size: int, + legacy_explicit_workspace_bytes: int = None, + legacy_api: bool = False, ): """Core test logic for MNNVL AllReduce operations. @@ -211,14 +342,30 @@ def run_mnnvl_ar_full( failure_message = "" try: - required_workspace_bytes = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes( - mapping.tp_size, - max(seq_lens), - hidden_size, - dtype, - trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, - ) - workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace(mapping, required_workspace_bytes) + if legacy_api: + mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( + trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace( + mapping, dtype, buffer_size_in_bytes=legacy_explicit_workspace_bytes + ) + ) + + multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() + buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() + unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( + mapping.tp_rank + ) + + else: + required_workspace_bytes = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes( + mapping.tp_size, + max(seq_lens), + hidden_size, + dtype, + trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, + ) + workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace( + mapping, required_workspace_bytes + ) test_data = [] for seq_len in seq_lens: @@ -266,19 +413,34 @@ def run_mnnvl_ar_full( print( f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" ) ->>>>>>> bca4f5d9 (Passing the test.) - - # Run the test with the same workspace - row_linear_residual_norm_fusion_forward( - x, - residual, - norm_weight, - eps, - mapping, - fusion, - reference_output, - workspace, - ) + if legacy_api: + row_linear_residual_norm_fusion_forward_legacy( + x, + residual, + norm_weight, + eps, + hidden_size, + dtype, + mapping, + fusion, + reference_output, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + buffer_flags_mnnvl, + ) + else: + row_linear_residual_norm_fusion_forward( + x, + residual, + norm_weight, + eps, + mapping, + fusion, + reference_output, + workspace, + ) # Synchronize before next test trtllm_mnnvl_ar.mpi_barrier() @@ -327,8 +489,23 @@ def run_mnnvl_ar_full( @pytest.mark.parametrize("fusion", [False, True]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [2880, 5120, 7168, 8192]) -def test_mnnvl_allreduce_default_workspace( +def test_mnnvl_allreduce_refactored( + monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int +): + """Test MNNVL AllReduce with refactored API.""" + run_mnnvl_ar_full( + monkeypatch, seq_lens, fusion, dtype, hidden_size, legacy_api=False + ) + + +@pytest.mark.parametrize("seq_lens", [[1], [4], [15], [27, 11, 24], [127]]) +@pytest.mark.parametrize("fusion", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) +def test_mnnvl_allreduce_legacy( monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int ): - """Test MNNVL AllReduce with default workspace size.""" - run_mnnvl_ar_full(monkeypatch, seq_lens, fusion, dtype, hidden_size) + """Test MNNVL AllReduce with legacy API.""" + run_mnnvl_ar_full( + monkeypatch, seq_lens, fusion, dtype, hidden_size, legacy_api=True + ) From 01564e97d007485bc971018a4ca89e6f8cb40093 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 19 Nov 2025 20:31:37 -0800 Subject: [PATCH 08/32] Address review comments. --- csrc/trtllm_mnnvl_allreduce.cu | 3 +-- flashinfer/comm/trtllm_mnnvl_ar.py | 3 --- tests/comm/test_trtllm_mnnvl_allreduce.py | 5 ++--- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index c7215a4241..5049344872 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -103,8 +103,7 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt status = twoshotAllreduceFusionDispatch(params); } TVM_FFI_ICHECK(status == cudaSuccess) - << "twoshot_allreduce_dispatch_world_size failed with error code " - << cudaGetErrorString(status); + << "trtllm_mnnvl_allreduce_fusion failed with error code " << cudaGetErrorString(status); }); } diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 82c40a7c83..7770cb815e 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -192,9 +192,6 @@ def trtllm_mnnvl_allreduce_fusion( gamma: Gamma tensor (if rmsnorm) epsilon: Epsilon value (if rmsnorm) """ - print( - f"[Rank {rank}] Inside Kernel: multicast_buffer_ptr: {multicast_buffer_ptr:x}, buffer_ptrs_dev: {buffer_ptrs_dev:x}, buffer_ptr_local: {buffer_ptr_local:x}, buffer_flags_mnnvl: {buffer_flags_mnnvl}" - ) module.trtllm_mnnvl_allreduce_fusion( input, multicast_buffer_ptr, diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index 461e1527ac..cb5425f5e2 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -1,5 +1,6 @@ # Check torch version: -from typing import Tuple, Optional +import traceback +from typing import Tuple import pytest import torch @@ -451,8 +452,6 @@ def run_mnnvl_ar_full( rank_failed = True failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}" print(failure_message) - import traceback - print(traceback.format_exc()) # Gather failure status from all ranks for logging From 775918d5d8d6581815129ca72f7638f098ce1e72 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 20 Nov 2025 13:36:41 -0800 Subject: [PATCH 09/32] Address review comments. --- csrc/trtllm_mnnvl_allreduce.cu | 11 +++++++++-- flashinfer/comm/mnnvl.py | 18 ++++++++++++++++-- include/flashinfer/utils.cuh | 13 +++++++++---- tests/comm/test_trtllm_mnnvl_allreduce.py | 4 ++-- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index 5049344872..dea2ddd039 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -53,11 +53,18 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt << "nranks must be between 2 and 64, got " << nranks; TVM_FFI_ICHECK(rank >= 0 && rank < nranks) << "rank must be between 0 and nranks-1, got " << rank; - TVM_FFI_ICHECK((residual_out.has_value() && gamma.has_value() && epsilon.has_value()) || + TVM_FFI_ICHECK((residual_in.has_value() && residual_out.has_value() && gamma.has_value() && + epsilon.has_value()) || !rmsnorm_fusion) - << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true"; + << "residual_in, residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is " + "true"; if (rmsnorm_fusion) { + TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens && + residual_in.value().size(1) == token_dim) + << "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1) + << ") but got (" << residual_in.value().size(0) << ", " << residual_in.value().size(1) + << ")"; TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens && residual_out.value().size(1) == token_dim) << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 6ca3a4b866..b6bbdd3906 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -716,6 +716,9 @@ def __del__(self): if not hasattr(self, "is_multi_node"): return + if hasattr(self, "_ipc_socket"): + self._ipc_socket.close() + # Skip cleanup during Python finalization to avoid segfaults # Especially cause the CUDA context could be destroyed at this point. if sys.is_finalizing(): @@ -864,7 +867,7 @@ def _alloc_mn_mcast_mem(self, buf_size: int): # Allocate local GPU memory self.uc_handles[self.group_rank] = checkCudaErrors(cuda.cuMemCreate(self.allocation_size, allocation_prop, 0)) - # Export local handle to fabric handle + # Export local handle to fabric handle or FD local_shareable_uc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.uc_handles[self.group_rank], @@ -898,6 +901,12 @@ def _alloc_mn_mcast_mem(self, buf_size: int): self._shareable_handle_type, ) ) + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ): + # Close FD after import + os.close(all_shareable_uc_handles[p]) # Initialize multicasting if self.group_rank == 0: @@ -943,7 +952,12 @@ def _alloc_mn_mcast_mem(self, buf_size: int): self._shareable_handle_type, ) ) - + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ): + # Close FD after import + os.close(shareable_mc_handle) # Add device to multicast checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx)) diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 20c19a0eae..8481aabf39 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -335,16 +336,20 @@ inline std::pair GetCudaComputeCapability() { return std::make_pair(major, minor); } +// This function is thread-safe and cached the sm_count. +// But it will only check the current CUDA device, thus assuming each process handles single GPU. inline int GetCudaMultiProcessorCount() { - static int sm_count = 0; - if (sm_count == 0) { + static std::atomic sm_count{0}; + int cached = sm_count.load(std::memory_order_relaxed); + if (cached == 0) { int device_id; cudaGetDevice(&device_id); cudaDeviceProp device_prop; cudaGetDeviceProperties(&device_prop, device_id); - sm_count = device_prop.multiProcessorCount; + cached = device_prop.multiProcessorCount; + sm_count.store(cached, std::memory_order_relaxed); } - return sm_count; + return cached; } template diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index cb5425f5e2..43437faf4b 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -1,6 +1,6 @@ # Check torch version: import traceback -from typing import Tuple +from typing import Tuple, Optional import pytest import torch @@ -286,7 +286,7 @@ def run_mnnvl_ar_full( fusion: bool, dtype: torch.dtype, hidden_size: int, - legacy_explicit_workspace_bytes: int = None, + legacy_explicit_workspace_bytes: Optional[int] = None, legacy_api: bool = False, ): """Core test logic for MNNVL AllReduce operations. From 45a5b828c9f3a94ad25b5fe350bb978c16296431 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Fri, 21 Nov 2025 14:37:32 -0800 Subject: [PATCH 10/32] Address review comments. --- csrc/trtllm_mnnvl_allreduce.cu | 2 - flashinfer/comm/trtllm_mnnvl_ar.py | 83 +++++++++++++++++++++++++----- 2 files changed, 69 insertions(+), 16 deletions(-) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index dea2ddd039..e1c998d8ea 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -26,8 +26,6 @@ using tvm::ffi::Optional; } \ }() -// FIXME: is bool flag for oneshot a good idea? Trying to avoid defining a new type/enum at this -// level void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev, int64_t buffer_ptr_local, TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank, diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 7770cb815e..4a69b83d8a 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -33,9 +33,18 @@ class MNNVLAllreduceFusionStrategy(Enum): AUTO = 99 @staticmethod +<<<<<<< HEAD def is_one_shot(tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype) -> bool: +======= + def select_strategy( + tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype + ) -> "MNNVLAllreduceFusionStrategy": +>>>>>>> c6ed1472 (Address review comments.) elem_size = torch.tensor([], dtype=dtype).element_size() - return num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD + if num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD: + return MNNVLAllreduceFusionStrategy.ONESHOT + else: + return MNNVLAllreduceFusionStrategy.TWOSHOT # Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size @@ -52,15 +61,15 @@ def __init__( comm_backend: Optional[CommBackend] = None, ): """ - Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD. + Initialize the MNNVL Allreduce Fusion Workspace. comm_backend will be used for creating the workspace and synchronization. If not provided, MPIBackend will be used which will use COMM_WORLD for synchronization. Args: mapping: Mapping configuration containing rank info buffer_size_in_bytes: The size in bytes for each lamport buffer. The actual allocation size will be NUM_LAMPORT_BUFFERS * buffer_size_in_bytes. """ if buffer_size_in_bytes is None: - # Default to 16MB workspace size if not provided - buffer_size_in_bytes = 16 * (1024**2) + # Default to 512MB workspace size if not provided + buffer_size_in_bytes = 512 * (1024**2) else: # Round up to the nearest multiple of 8MB buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2)) @@ -108,7 +117,28 @@ def __init__( self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank) self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr() + @functools.cache + def is_buffer_size_sufficient( + self, + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, + ) -> bool: + """ + Calculate the required buffer size for a given problem size. + """ + required_buffer_size = self.get_required_buffer_size_bytes( + tp_size, num_tokens, hidden_dim, dtype, strategy + ) + if required_buffer_size > self.buffer_size_bytes: + return False + else: + return True + @staticmethod + @functools.cache def get_required_buffer_size_bytes( tp_size: int, num_tokens: int, @@ -120,10 +150,19 @@ def get_required_buffer_size_bytes( Calculate the required buffer size for a given problem size. """ elem_size = torch.tensor([], dtype=dtype).element_size() +<<<<<<< HEAD is_one_shot = MNNVLAllreduceFusionStrategy.is_one_shot(tp_size, num_tokens, hidden_dim, dtype) if strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( strategy == MNNVLAllreduceFusionStrategy.AUTO and is_one_shot ): +======= + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + tp_size, num_tokens, hidden_dim, dtype + ) + + if strategy == MNNVLAllreduceFusionStrategy.ONESHOT: +>>>>>>> c6ed1472 (Address review comments.) # For one-shot, each rank needs to store num_tokens * tp_size tokens buffer_size = num_tokens * hidden_dim * tp_size * elem_size else: @@ -256,10 +295,25 @@ def trtllm_mnnvl_allreduce( module = get_trtllm_mnnvl_comm_module() +<<<<<<< HEAD use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( strategy == MNNVLAllreduceFusionStrategy.AUTO and MNNVLAllreduceFusionStrategy.is_one_shot(workspace.tp_size, input.shape[0], input.shape[1], input.dtype) ) +======= + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype + ) + + if not workspace.is_buffer_size_sufficient( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy + ): + raise ValueError( + f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes." + ) + +>>>>>>> c6ed1472 (Address review comments.) module.trtllm_mnnvl_allreduce_fusion( input, workspace.mc_ptr, @@ -270,7 +324,7 @@ def trtllm_mnnvl_allreduce( workspace.rank, False, # No RMSNorm Fusion launch_with_pdl, - use_oneshot, + strategy == MNNVLAllreduceFusionStrategy.ONESHOT, output, None, None, @@ -340,15 +394,16 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( module = get_trtllm_mnnvl_comm_module() - use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( - strategy == MNNVLAllreduceFusionStrategy.AUTO - and MNNVLAllreduceFusionStrategy.is_one_shot( - workspace.tp_size, - input.shape[0], - input.shape[1], - input.dtype, + if strategy == MNNVLAllreduceFusionStrategy.AUTO: + strategy = MNNVLAllreduceFusionStrategy.select_strategy( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype + ) + if not workspace.is_buffer_size_sufficient( + workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy + ): + raise ValueError( + f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes." ) - ) module.trtllm_mnnvl_allreduce_fusion( input, @@ -360,7 +415,7 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( workspace.rank, True, # RMSNorm Fusion launch_with_pdl, - use_oneshot, + strategy == MNNVLAllreduceFusionStrategy.ONESHOT, output, residual_out, residual_in, From 815aaf33dc6f3479db57e1c066a679e827fd2ca2 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Fri, 21 Nov 2025 15:20:35 -0800 Subject: [PATCH 11/32] Rounding up workspace size according to allocation (page size). --- flashinfer/comm/mnnvl.py | 19 ++++++++++++++----- flashinfer/comm/trtllm_mnnvl_ar.py | 30 +++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index b6bbdd3906..66c47e4c9c 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -803,6 +803,14 @@ def get_world_size(self) -> int: """Get the total number of devices in the group""" return self.group_size + def get_allocation_size(self) -> int: + """Get the total allocation size (including signal pad)""" + return self.allocation_size + + def get_usable_buffer_size(self) -> int: + """Get the usable buffer size (excluding signal pad)""" + return self.allocation_size - self.SIGNAL_PAD_SIZE + def _init_ipc_socket(self): if self.group_rank == 0: # Gnerate the opId @@ -838,7 +846,7 @@ def _alloc_mn_mcast_mem(self, buf_size: int): alloc_granularity = checkCudaErrors( cuda.cuMemGetAllocationGranularity( allocation_prop, - cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM, + cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED, ) ) @@ -1015,8 +1023,8 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype): else: raise ValueError(f"Unsupported dtype: {dtype}") - # Calculate number of elements that fit in allocation_size - num_elements = self.allocation_size // dsize + # Calculate number of elements that fit in allocation_size; We don't want to include the signal pad. + num_elements = (self.allocation_size - self.SIGNAL_PAD_SIZE) // dsize checkCudaErrors(memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements)) @@ -1042,7 +1050,7 @@ def __init__( Constructor for McastGpuBuffer. Args: - buf_size: The total size of the buffer in bytes + buf_size: The requested size of the buffer in bytes. The actual usable size may differ due to alignment requirements. group_size: The number of ranks in the communication group group_rank: The rank of the local process within the group device: The CUDA device for buffer allocation @@ -1061,7 +1069,8 @@ def __init__( mn_nvlink, comm_backend_for_handle_transfer, ) - self.buf_size = buf_size + # Update buf_size to reflect the actual usable buffer size after allocation + self.buf_size = self.mcast_device_memory.get_usable_buffer_size() self.local_device = device def lamport_initialize(self, rank: int, dtype: torch.dtype): diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 4a69b83d8a..4244e00aa4 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -65,11 +65,11 @@ def __init__( Args: mapping: Mapping configuration containing rank info - buffer_size_in_bytes: The size in bytes for each lamport buffer. The actual allocation size will be NUM_LAMPORT_BUFFERS * buffer_size_in_bytes. + buffer_size_in_bytes: The requested size in bytes for each lamport buffer. The actual allocation size may be larger due to alignment requirements. The actual usable size will be NUM_LAMPORT_BUFFERS * actual_buffer_size_per_lamport_buffer. """ if buffer_size_in_bytes is None: - # Default to 512MB workspace size if not provided - buffer_size_in_bytes = 512 * (1024**2) + # Default to 16MB workspace size if not provided + buffer_size_in_bytes = 16 * (1024**2) else: # Round up to the nearest multiple of 8MB buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2)) @@ -80,15 +80,18 @@ def __init__( f"The buffer size in bytes {buffer_size_in_bytes} is greater than the maximum supported size (UINT32_MAX)." ) - self.buffer_size_bytes = buffer_size_in_bytes - self.workspace_size_bytes = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS + # Calculate total requested workspace size + requested_workspace_size = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS + self.rank = mapping.tp_rank self.tp_size = mapping.tp_size logging.debug( - f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with size {buffer_size_in_bytes} bytes." + f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with requested size {buffer_size_in_bytes} bytes per buffer." ) + + # Allocate the workspace self.mcast_buffer_handle = McastGPUBuffer( - self.workspace_size_bytes, + requested_workspace_size, mapping.tp_size, mapping.tp_rank, torch.device("cuda", mapping.local_rank), @@ -96,6 +99,19 @@ def __init__( comm_backend, ) + # Get the actual usable buffer size after allocation (buf_size is updated by McastGPUBuffer) + allocated_size = self.mcast_buffer_handle.buf_size + # We want the buffer size to be aligned to 16B which is the granularity for buffer management. + self.buffer_size_bytes = ( + math.floor(allocated_size / self.NUM_LAMPORT_BUFFERS) // 16 * 16 + ) + # This workspace size is used for checking the buffer. We need to set it to the actual size in use. The buffer free logic does not rely on this size. + self.workspace_size_bytes = self.buffer_size_bytes * self.NUM_LAMPORT_BUFFERS + + logging.debug( + f"[MNNVL Allreduce] Actual allocated size: {allocated_size} bytes, Actual buffer size per lamport buffer: {self.buffer_size_bytes} bytes, total workspace: {self.workspace_size_bytes} bytes." + ) + # We use FP32 for sentinel value regardless of the real dtype self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32) # Wait until the initialization is done From 68a9b9b8ef33705d293d1e5e3030e0cc6d1aa45c Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 26 Nov 2025 15:56:04 -0800 Subject: [PATCH 12/32] Fix rebasing errors. --- flashinfer/comm/mnnvl.py | 233 +++++++++++++++++++++-------- flashinfer/comm/trtllm_mnnvl_ar.py | 64 ++++---- 2 files changed, 204 insertions(+), 93 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 66c47e4c9c..3128a9874a 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -41,7 +41,8 @@ from cuda import cuda except ImportError as e: raise ImportError( - "Could not import the 'cuda' module. " "Please install cuda-python that matches your CUDA version." + "Could not import the 'cuda' module. " + "Please install cuda-python that matches your CUDA version." ) from e from ..cuda_utils import checkCudaErrors @@ -62,7 +63,9 @@ def round_up(val: int, gran: int) -> int: return (val + gran - 1) & ~(gran - 1) -def create_tensor_from_cuda_memory(ptr: int, shape: tuple, dtype: torch.dtype, device_id: int) -> torch.Tensor: +def create_tensor_from_cuda_memory( + ptr: int, shape: tuple, dtype: torch.dtype, device_id: int +) -> torch.Tensor: """ Create a PyTorch tensor from a CUDA memory pointer using DLPack. @@ -84,7 +87,9 @@ def create_tensor_from_cuda_memory(ptr: int, shape: tuple, dtype: torch.dtype, d element_size = torch.tensor([], dtype=dtype).element_size() # Create DLPack capsule for contiguous memory (stride = element_size, num_segments = numel) - capsule_wrapper = create_dlpack_capsule(ptr, element_size, element_size, numel, dtype, device_id) + capsule_wrapper = create_dlpack_capsule( + ptr, element_size, element_size, numel, dtype, device_id + ) # Convert to tensor and reshape tensor = torch.utils.dlpack.from_dlpack(capsule_wrapper.capsule) @@ -136,7 +141,9 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> int: size_in_bytes = ctypes.sizeof(c_array) device_ptr: cuda.CUdeviceptr = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes)) - checkCudaErrors(cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes)) + checkCudaErrors( + cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes) + ) # c_array should be freed by GC return int(device_ptr) @@ -292,14 +299,18 @@ def initialize(): @staticmethod def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None): MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend()) # type: ignore[attr-defined] - comm = config.comm_backend.Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank) + comm = config.comm_backend.Split( + mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank + ) MnnvlMemory.comm = comm # type: ignore[assignment] @staticmethod def get_comm(mapping: Mapping): if MnnvlMemory.comm is not None: return MnnvlMemory.comm - comm = MpiComm().Split(mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank) + comm = MpiComm().Split( + mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank + ) MnnvlMemory.comm = comm return comm @@ -315,7 +326,9 @@ def get_allocation_prop(dev_id: int): arch = platform.machine().lower() is_on_aarch64 = "aarch64" in arch if is_on_aarch64: - allocation_prop.requestedHandleTypes = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + allocation_prop.requestedHandleTypes = ( + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ) else: allocation_prop.requestedHandleTypes = ( cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR @@ -331,19 +344,27 @@ def get_allocation_granularity(dev_id: int): option = cuda.CUmemAllocationGranularity_flags( cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED ) - granularity = checkCudaErrors(cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option)) + granularity = checkCudaErrors( + cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option) + ) MnnvlMemory.allocation_granularity = granularity return MnnvlMemory.allocation_granularity @staticmethod def new_mnnvl_memory_address(mapping: Mapping, size: int): - page_count = (size + MnnvlMemory.fabric_page_size - 1) // MnnvlMemory.fabric_page_size + page_count = ( + size + MnnvlMemory.fabric_page_size - 1 + ) // MnnvlMemory.fabric_page_size current_rank_stride = page_count * MnnvlMemory.fabric_page_size - logging.info(f"[MnnvlMemory] creating address with stride={current_rank_stride}") + logging.info( + f"[MnnvlMemory] creating address with stride={current_rank_stride}" + ) comm = MnnvlMemory.get_comm(mapping) comm_size = comm.Get_size() address_size = current_rank_stride * comm_size - ptr = checkCudaErrors(cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0)) + ptr = checkCudaErrors( + cuda.cuMemAddressReserve(address_size, MnnvlMemory.fabric_page_size, 0, 0) + ) MnnvlMemory.current_start_address = int(ptr) MnnvlMemory.current_rank_stride = current_rank_stride MnnvlMemory.current_mem_offset = 0 @@ -354,29 +375,44 @@ def open_mnnvl_memory(mapping: Mapping, size: int): dev_id = int(dev) if MnnvlMemory.dev_id is None: MnnvlMemory.dev_id = dev_id - assert ( - dev_id == MnnvlMemory.dev_id - ), f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}" + assert dev_id == MnnvlMemory.dev_id, ( + f"Different dev_id found dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}" + ) comm = MnnvlMemory.get_comm(mapping) comm_rank = comm.Get_rank() comm_size = comm.Get_size() all_rank_allocate_sizes = comm.allgather(size) assert len(all_rank_allocate_sizes) == comm_size - assert all(x == size for x in all_rank_allocate_sizes), "Not all rank allocating same size." + assert all(x == size for x in all_rank_allocate_sizes), ( + "Not all rank allocating same size." + ) granularity = MnnvlMemory.get_allocation_granularity(dev_id) aligned_size = (size + granularity - 1) // granularity * granularity - if MnnvlMemory.current_mem_offset + aligned_size > MnnvlMemory.current_rank_stride: + if ( + MnnvlMemory.current_mem_offset + aligned_size + > MnnvlMemory.current_rank_stride + ): MnnvlMemory.new_mnnvl_memory_address(mapping, aligned_size) - assert MnnvlMemory.current_mem_offset + aligned_size <= MnnvlMemory.current_rank_stride + assert ( + MnnvlMemory.current_mem_offset + aligned_size + <= MnnvlMemory.current_rank_stride + ) allocation_prop = MnnvlMemory.get_allocation_prop(dev_id) - allocated_mem_handle = checkCudaErrors(cuda.cuMemCreate(aligned_size, allocation_prop, flags=0)) + allocated_mem_handle = checkCudaErrors( + cuda.cuMemCreate(aligned_size, allocation_prop, flags=0) + ) exported_fabric_handle = checkCudaErrors( - cuda.cuMemExportToShareableHandle(allocated_mem_handle, allocation_prop.requestedHandleTypes, 0) + cuda.cuMemExportToShareableHandle( + allocated_mem_handle, allocation_prop.requestedHandleTypes, 0 + ) ) - if allocation_prop.requestedHandleTypes == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: + if ( + allocation_prop.requestedHandleTypes + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ): all_handles_data = comm.allgather(exported_fabric_handle.data) else: all_handles_data = comm.allgather(exported_fabric_handle) @@ -390,7 +426,9 @@ def open_mnnvl_memory(mapping: Mapping, size: int): pidfd = syscall(SYS_pidfd_open, pid, 0) if pidfd < 0: err = ctypes.get_errno() - raise RuntimeError(f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}") + raise RuntimeError( + f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}" + ) pidfds.append(pidfd) remote_fds = [] @@ -405,7 +443,9 @@ def open_mnnvl_memory(mapping: Mapping, size: int): "to your docker run command." ) else: - error_msg += " This may be due to kernel version (requires Linux 5.6+)." + error_msg += ( + " This may be due to kernel version (requires Linux 5.6+)." + ) raise RuntimeError(error_msg) remote_fds.append(remote_fd) @@ -421,19 +461,27 @@ def open_mnnvl_memory(mapping: Mapping, size: int): for i, remote_handle_data in enumerate(all_handles_data): rank_ptr = ( - MnnvlMemory.current_start_address + MnnvlMemory.current_rank_stride * i + MnnvlMemory.current_mem_offset + MnnvlMemory.current_start_address + + MnnvlMemory.current_rank_stride * i + + MnnvlMemory.current_mem_offset ) if i == comm_rank: # Local memory mapping mem_handles[i] = allocated_mem_handle - checkCudaErrors(cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0)) + checkCudaErrors( + cuda.cuMemMap(rank_ptr, aligned_size, 0, allocated_mem_handle, 0) + ) else: # Fabric memory mapping imported_mem_handle = checkCudaErrors( - cuda.cuMemImportFromShareableHandle(remote_handle_data, allocation_prop.requestedHandleTypes) + cuda.cuMemImportFromShareableHandle( + remote_handle_data, allocation_prop.requestedHandleTypes + ) ) mem_handles[i] = imported_mem_handle - checkCudaErrors(cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0)) + checkCudaErrors( + cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0) + ) checkCudaErrors(cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1)) @@ -490,14 +538,20 @@ def support_nvlink(need_all_up: bool = True): available_links = 0 for link_idx in range(link_count): try: - if pynvml.nvmlDeviceGetNvLinkCapability(handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED): + if pynvml.nvmlDeviceGetNvLinkCapability( + handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED + ): available_links += 1 is_active = pynvml.nvmlDeviceGetNvLinkState(handle, link_idx) if is_active: active_links += 1 except pynvml.NVMLError_NotSupported: continue - return active_links == available_links and available_links > 0 if need_all_up else available_links > 0 + return ( + active_links == available_links and available_links > 0 + if need_all_up + else available_links > 0 + ) @staticmethod def supports_mnnvl() -> bool: @@ -585,14 +639,18 @@ def recv_fd(self): fds = array.array("i") msg, ancdata, flags, addr = self.sock.recvmsg( 1, - socket.CMSG_SPACE(fds.itemsize), # Buffer size for dummy data # Ancillary data size + socket.CMSG_SPACE( + fds.itemsize + ), # Buffer size for dummy data # Ancillary data size ) # Extract file descriptor from ancillary data for cmsg_level, cmsg_type, cmsg_data in ancdata: if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: fds = array.array("i") - fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + fds.frombytes( + cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)] + ) return fds[0] raise RuntimeError("No file descriptor received") @@ -617,7 +675,6 @@ def __init__( device_idx: int, is_multi_node: bool = True, comm_backend_for_handle_transfer: Optional[CommBackend] = None, - comm_backend_for_handle_transfer: Optional[CommBackend] = None, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) @@ -653,7 +710,9 @@ def __init__( self.signal_pads_dev = 0 # std::vector mSignalPadsDev self.uc_ptrs_dev = 0 self.mc_handle = 0 # CUmemGenericAllocationHandle mMcHandle - self.uc_handles: List[int] = [] # std::vector mUcHandles + self.uc_handles: List[ + int + ] = [] # std::vector mUcHandles self._shareable_handle_type = None @@ -669,7 +728,9 @@ def __init__( ) ) if multicast_supported == 0: - raise RuntimeError("[McastDeviceMemory] Device does not support multicasting.") + raise RuntimeError( + "[McastDeviceMemory] Device does not support multicasting." + ) # Calculate signal pad offset with alignment (matching C++ exactly) self.signal_pad_offset = round_up(buf_size, self.SIGNAL_PAD_ALIGNMENT) @@ -689,13 +750,19 @@ def __init__( ) ) if fabric_handle_supported == 0: - raise RuntimeError("[McastDeviceMemory] Device does not support fabric handle.") + raise RuntimeError( + "[McastDeviceMemory] Device does not support fabric handle." + ) # Use fabric handle for multi-node NVLS - self._shareable_handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + self._shareable_handle_type = ( + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ) else: self._init_ipc_socket() # Use NVLink handle for single-node NVLS - self._shareable_handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + self._shareable_handle_type = ( + cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + ) self._alloc_mn_mcast_mem(buf_size) # Initialize signal pads @@ -703,7 +770,9 @@ def __init__( for i in range(self.group_size): self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset if i == self.group_rank: - checkCudaErrors(cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE)) + checkCudaErrors( + cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE) + ) # Create device pointers self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads) @@ -746,19 +815,29 @@ def __del__(self): checkCudaErrors(cuda.cuMemRelease(self.uc_handles[rank])) # Unmap the vmem if rank < len(self.uc_ptrs) and self.uc_ptrs[rank]: - checkCudaErrors(cuda.cuMemUnmap(self.uc_ptrs[rank], self.allocation_size)) + checkCudaErrors( + cuda.cuMemUnmap( + self.uc_ptrs[rank], self.allocation_size + ) + ) except Exception as e: - print(f"Destructor: Failed to release UC handle for rank {rank}: {e}") + print( + f"Destructor: Failed to release UC handle for rank {rank}: {e}" + ) # Free the UC address space if hasattr(self, "uc_base_ptr") and self.uc_base_ptr: - checkCudaErrors(cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size)) + checkCudaErrors( + cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size) + ) # Release MC handle if hasattr(self, "mc_handle") and self.mc_handle and self.mc_handle != 0: try: checkCudaErrors(cuda.cuMemUnmap(self.mc_ptr, self.allocation_size)) - checkCudaErrors(cuda.cuMemAddressFree(self.mc_ptr, self.allocation_size)) + checkCudaErrors( + cuda.cuMemAddressFree(self.mc_ptr, self.allocation_size) + ) checkCudaErrors(cuda.cuMemRelease(self.mc_handle)) except Exception as e: print(f"Destructor: Failed to release MC handle: {e}") @@ -828,7 +907,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): current_device = checkCudaErrors(cuda.cuCtxGetDevice()) if int(current_device) != self.device_idx: - print(f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}") + print( + f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}" + ) except Exception as e: print(f"Error checking CUDA context: {e}") @@ -837,7 +918,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): allocation_prop.requestedHandleTypes = self._shareable_handle_type allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED allocation_prop.location = cuda.CUmemLocation() - allocation_prop.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + allocation_prop.location.type = ( + cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + ) allocation_prop.location.id = self.device_idx allocation_prop.allocFlags.gpuDirectRDMACapable = 1 @@ -851,7 +934,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) # mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity); - self.allocation_size = round_up(buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity) + self.allocation_size = round_up( + buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity + ) # Set up multicast properties mc_prop = cuda.CUmulticastObjectProp() @@ -873,7 +958,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): self.uc_handles = [0] * self.group_size # Allocate local GPU memory - self.uc_handles[self.group_rank] = checkCudaErrors(cuda.cuMemCreate(self.allocation_size, allocation_prop, 0)) + self.uc_handles[self.group_rank] = checkCudaErrors( + cuda.cuMemCreate(self.allocation_size, allocation_prop, 0) + ) # Export local handle to fabric handle or FD local_shareable_uc_handle = checkCudaErrors( @@ -884,9 +971,14 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) ) - if self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ): # All-gather fabric handles - all_shareable_uc_handles = self.comm_backend.allgather(local_shareable_uc_handle.data) + all_shareable_uc_handles = self.comm_backend.allgather( + local_shareable_uc_handle.data + ) else: # Implement the allgather logic with ipc socket all_shareable_uc_handles = [None] * self.group_size @@ -931,7 +1023,10 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) else: shareable_mc_handle = None - if self._shareable_handle_type == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC: + if ( + self._shareable_handle_type + == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + ): # Broadcast multicast handle shareable_mc_handle = self.comm_backend.bcast( shareable_mc_handle.data if shareable_mc_handle else None, root=0 @@ -975,7 +1070,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): # Reserve address space for UC pointers total_uc_size = self.allocation_size * self.group_size self.total_uc_size = total_uc_size - uc_base_ptr = checkCudaErrors(cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0)) + uc_base_ptr = checkCudaErrors( + cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0) + ) self.uc_base_ptr = uc_base_ptr # Store for cleanup # Set up memory access descriptor @@ -989,15 +1086,27 @@ def _alloc_mn_mcast_mem(self, buf_size: int): for i in range(self.group_size): offset = self.allocation_size * i self.uc_ptrs[i] = int(uc_base_ptr) + offset - checkCudaErrors(cuda.cuMemMap(self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0)) + checkCudaErrors( + cuda.cuMemMap( + self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0 + ) + ) # Set memory access permissions - checkCudaErrors(cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1)) + checkCudaErrors( + cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) + ) # Bind MC pointer - self.mc_ptr = checkCudaErrors(cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0)) - checkCudaErrors(cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0)) - checkCudaErrors(cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1)) + self.mc_ptr = checkCudaErrors( + cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0) + ) + checkCudaErrors( + cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0) + ) + checkCudaErrors( + cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1) + ) # Bind memory to multicast checkCudaErrors( @@ -1026,7 +1135,9 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype): # Calculate number of elements that fit in allocation_size; We don't want to include the signal pad. num_elements = (self.allocation_size - self.SIGNAL_PAD_SIZE) // dsize - checkCudaErrors(memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements)) + checkCudaErrors( + memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements) + ) class McastGPUBuffer: @@ -1055,11 +1166,7 @@ def __init__( group_rank: The rank of the local process within the group device: The CUDA device for buffer allocation mn_nvlink: Flag indicating if multi-node NVLink is used -<<<<<<< HEAD comm_backend_for_handle_transfer: Communication backend for handle transfer -======= - comm_backend_for_handle_transfer: The communicator to use for handle transfer ->>>>>>> a2670e8c (Incorporate 2056; Add test for legacy APIs) """ self.mcast_device_memory = McastDeviceMemory( buf_size, @@ -1076,7 +1183,9 @@ def __init__( def lamport_initialize(self, rank: int, dtype: torch.dtype): self.mcast_device_memory.lamport_initialize(rank, dtype) - def get_multicast_buffer(self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0) -> torch.Tensor: + def get_multicast_buffer( + self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 + ) -> torch.Tensor: """ Returns a PyTorch tensor view of the multicast buffer portion. @@ -1092,7 +1201,9 @@ def get_multicast_buffer(self, sizes: tuple, dtype: torch.dtype, storage_offset: # FIXME: Is this needed? As the behavior of reading from mc_ptr is undefined. raise NotImplementedError("Not implemented yet") - def get_unicast_buffer(self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0) -> torch.Tensor: + def get_unicast_buffer( + self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 + ) -> torch.Tensor: """ Returns a PyTorch tensor view of the unicast buffer portion. """ diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 4244e00aa4..afdd580910 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -33,13 +33,9 @@ class MNNVLAllreduceFusionStrategy(Enum): AUTO = 99 @staticmethod -<<<<<<< HEAD - def is_one_shot(tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype) -> bool: -======= def select_strategy( tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype ) -> "MNNVLAllreduceFusionStrategy": ->>>>>>> c6ed1472 (Address review comments.) elem_size = torch.tensor([], dtype=dtype).element_size() if num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD: return MNNVLAllreduceFusionStrategy.ONESHOT @@ -72,7 +68,9 @@ def __init__( buffer_size_in_bytes = 16 * (1024**2) else: # Round up to the nearest multiple of 8MB - buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2)) + buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * ( + 8 * (1024**2) + ) if comm_backend is None: comm_backend = MPIBackend() if buffer_size_in_bytes > (2**32 - 1): @@ -166,25 +164,20 @@ def get_required_buffer_size_bytes( Calculate the required buffer size for a given problem size. """ elem_size = torch.tensor([], dtype=dtype).element_size() -<<<<<<< HEAD - is_one_shot = MNNVLAllreduceFusionStrategy.is_one_shot(tp_size, num_tokens, hidden_dim, dtype) - if strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( - strategy == MNNVLAllreduceFusionStrategy.AUTO and is_one_shot - ): -======= if strategy == MNNVLAllreduceFusionStrategy.AUTO: strategy = MNNVLAllreduceFusionStrategy.select_strategy( tp_size, num_tokens, hidden_dim, dtype ) if strategy == MNNVLAllreduceFusionStrategy.ONESHOT: ->>>>>>> c6ed1472 (Address review comments.) # For one-shot, each rank needs to store num_tokens * tp_size tokens buffer_size = num_tokens * hidden_dim * tp_size * elem_size else: # For two-shot, each rank stores a slices of tokens. We need to round up to the nearest tp_size. # 2 Stage is required for the two-shot allreduce. - buffer_size = 2 * math.ceil(num_tokens / tp_size) * tp_size * hidden_dim * elem_size + buffer_size = ( + 2 * math.ceil(num_tokens / tp_size) * tp_size * hidden_dim * elem_size + ) return buffer_size @@ -302,21 +295,19 @@ def trtllm_mnnvl_allreduce( # Check ndims here as the shape check is done in the kernel launch code. if len(input.shape) != 2: - raise ValueError(f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}.") + raise ValueError( + f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}." + ) if output is None: output = torch.empty_like(input) elif len(output.shape) != 2: - raise ValueError(f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}.") + raise ValueError( + f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}." + ) module = get_trtllm_mnnvl_comm_module() -<<<<<<< HEAD - use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or ( - strategy == MNNVLAllreduceFusionStrategy.AUTO - and MNNVLAllreduceFusionStrategy.is_one_shot(workspace.tp_size, input.shape[0], input.shape[1], input.dtype) - ) -======= if strategy == MNNVLAllreduceFusionStrategy.AUTO: strategy = MNNVLAllreduceFusionStrategy.select_strategy( workspace.tp_size, input.shape[0], input.shape[1], input.dtype @@ -329,7 +320,6 @@ def trtllm_mnnvl_allreduce( f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes." ) ->>>>>>> c6ed1472 (Address review comments.) module.trtllm_mnnvl_allreduce_fusion( input, workspace.mc_ptr, @@ -388,7 +378,9 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( epsilon = torch.finfo(input.dtype).eps if len(input.shape) != 2: - raise ValueError(f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}.") + raise ValueError( + f"The input tensor must be 2D, got {len(input.shape)}D. The shape is {input.shape}." + ) if len(residual_in.shape) != 2: raise ValueError( f"The residual input tensor must be 2D, got {len(residual_in.shape)}D. The shape is {residual_in.shape}." @@ -400,7 +392,9 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( if output is None: output = torch.empty_like(input) elif len(output.shape) != 2: - raise ValueError(f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}.") + raise ValueError( + f"The output tensor must be 2D, got {len(output.shape)}D. The shape is {output.shape}." + ) if residual_out is None: residual_out = torch.empty_like(residual_in) elif len(residual_out.shape) != 2: @@ -478,13 +472,17 @@ def get_allreduce_mnnvl_workspace( # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 # max_num_elements must be a multiple of 286720 lcm_hidden_dim = 286720 - TARGET_WORKSPACE_SIZE_BYTES = buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 - buffer_size_in_bytes = math.ceil(TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride)) * ( - lcm_hidden_dim * stride + TARGET_WORKSPACE_SIZE_BYTES = ( + buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 ) + buffer_size_in_bytes = math.ceil( + TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) + ) * (lcm_hidden_dim * stride) # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. - workspace = MNNVLAllreduceFusionWorkspace(mapping, buffer_size_in_bytes, comm_backend_for_handle_transfer) + workspace = MNNVLAllreduceFusionWorkspace( + mapping, buffer_size_in_bytes, comm_backend_for_handle_transfer + ) mcast_buffer = workspace.mcast_buffer_handle buffer_flags = workspace.buffer_flags @@ -537,7 +535,9 @@ def trtllm_mnnvl_all_reduce( """ if len(inp.shape) != 2: - raise ValueError(f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}.") + raise ValueError( + f"The input tensor must be 2D, got {len(inp.shape)}D. The shape is {inp.shape}." + ) # buffer_M is no longer used in this kernel but let's keep this check for consistency in behavior. if inp.shape[0] > buffer_M: @@ -546,9 +546,9 @@ def trtllm_mnnvl_all_reduce( ) # Even in legacy code, this should only be used when we implement the fused allreduce+rmsnorm. - assert wait_for_results and ( - out is not None - ), "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." + assert wait_for_results and (out is not None), ( + "Calling the legacy trtllm_mnnvl_all_reduce with wait_for_results=False is not supported. Please use trtllm_mnnvl_allreduce instead." + ) module = get_trtllm_mnnvl_comm_module() module.trtllm_mnnvl_allreduce_fusion( inp, From 9e11752cbe5872a0be82928e30e735ed2eae85fd Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 26 Nov 2025 16:20:14 -0800 Subject: [PATCH 13/32] Fix rebase errors. --- tests/comm/test_trtllm_mnnvl_allreduce.py | 51 ++--------------------- 1 file changed, 3 insertions(+), 48 deletions(-) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index 43437faf4b..cf93b1af6c 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -8,7 +8,6 @@ import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping -from flashinfer.comm.mnnvl import CommBackend, MpiComm # Use flashinfer.norm.rmsnorm as reference implementation. from flashinfer.norm import rmsnorm @@ -26,15 +25,7 @@ def row_linear_residual_norm_fusion_forward( workspace: trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace, ): tensor_parallel_rank = mapping.tp_rank -<<<<<<< HEAD - if comm_backend_for_handle_transfer is None: - comm = MpiComm() - else: - comm = comm_backend_for_handle_transfer - comm.barrier() -======= MPI.COMM_WORLD.barrier() ->>>>>>> bca4f5d9 (Passing the test.) def func( input, @@ -322,19 +313,12 @@ def run_mnnvl_ar_full( torch.cuda.set_device(mapping.local_rank) if mapping.local_rank == 0: -<<<<<<< HEAD - print(f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks") - print(f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}") - - tensor_parallel_size = world_size -======= print( f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" ) print( f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" ) ->>>>>>> bca4f5d9 (Passing the test.) eps = 1e-5 torch.manual_seed(42 + rank) @@ -380,37 +364,6 @@ def run_mnnvl_ar_full( # Test each sequence length with the same workspace (reusing allocated buffers within this list) for seq_len, x, residual, norm_weight, reference_output in test_data: if rank == 0: -<<<<<<< HEAD - print(f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}") - print(f"[Rank {rank}] Buffer flags: {workspace.buffer_flags}") - - # Generate test data (same on all ranks due to same seed) - x_full = torch.randn( - (tensor_parallel_size, seq_len, hidden_size), - dtype=dtype, - device=torch.device("cuda"), - ) - residual = torch.randn((seq_len, hidden_size), dtype=dtype, device=torch.device("cuda")) - norm_weight = torch.randn((hidden_size,), dtype=dtype, device=torch.device("cuda")) - - # Each rank gets its slice of the input - x = x_full[rank, :, :] - - # Compute reference output based on fusion mode - reference_output: Tuple[torch.Tensor, ...] = None - if fusion: - # Fused case: AllReduce + Residual Add + RMS Norm - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result - residual_out = allreduce_result + residual # Add residual - print("Device of residual_out:{}, norm_weight:{}".format(residual_out.device, norm_weight.device)) - norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) - - reference_output = (norm_out, residual_out) - else: - # Non-fused case: Only AllReduce - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result - reference_output = (allreduce_result,) -======= print( f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" ) @@ -446,7 +399,9 @@ def run_mnnvl_ar_full( # Synchronize before next test trtllm_mnnvl_ar.mpi_barrier() - print(f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}") + print( + f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}" + ) except Exception as e: rank_failed = True From 4a5faeff6509c97d454bfce7a9369565f93dcd79 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 26 Nov 2025 16:24:27 -0800 Subject: [PATCH 14/32] Refactor mcast device memory. --- flashinfer/comm/mnnvl.py | 314 +++++++++++++++++++++++---------------- 1 file changed, 185 insertions(+), 129 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 3128a9874a..13ca4f534d 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -663,6 +663,107 @@ def close(self): os.unlink(self.socket_path) +class HandleExchanger(ABC): + """Abstract interface for exchanging CUDA shareable handles across ranks.""" + + def __init__(self, comm_backend: "CommBackend", group_rank: int, group_size: int): + self.comm = comm_backend + self.rank = group_rank + self.size = group_size + + @property + @abstractmethod + def handle_type(self) -> cuda.CUmemAllocationHandleType: + """The CUDA handle type this exchanger works with.""" + ... + + @abstractmethod + def allgather(self, local_handle) -> List: + """All-gather shareable handles from all ranks.""" + ... + + @abstractmethod + def broadcast(self, handle, root: int): + """Broadcast a handle from root to all ranks.""" + ... + + @abstractmethod + def cleanup(self, handle) -> None: ... + + @abstractmethod + def close(self) -> None: ... + + +class FabricHandleExchanger(HandleExchanger): + """Handle exchange using CUDA Fabric handles via MPI/collective backend.""" + + @property + def handle_type(self) -> cuda.CUmemAllocationHandleType: + return cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + + def allgather(self, local_handle) -> List: + return self.comm.allgather(local_handle.data) + + def broadcast(self, handle, root: int): + return self.comm.bcast(handle.data if handle else None, root=root) + + def cleanup(self, handle) -> None: + pass # No cleanup needed for Fabric handles. + + def close(self) -> None: + pass # No close needed for Fabric handles. + + +class PosixFDHandleExchanger(HandleExchanger): + """Handle exchange using POSIX file descriptors via IPC sockets.""" + + def __init__(self, comm_backend: "CommBackend", group_rank: int, group_size: int): + super().__init__(comm_backend, group_rank, group_size) + self._socket = self._init_ipc_socket() + + def _init_ipc_socket(self) -> IpcSocket: + if self.rank == 0: + opId = random.randint(0, 2**64 - 1) + else: + opId = None + opId = self.comm.bcast(opId, root=0) + return IpcSocket(self.rank, opId) + + @property + def handle_type(self) -> cuda.CUmemAllocationHandleType: + return cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + + def allgather(self, local_handle) -> List: + result = [None] * self.size + for i in range(self.size): + self.comm.barrier() + self._socket.send_fd(local_handle, (self.rank + i) % self.size) + src = (self.rank + self.size - i) % self.size + result[src] = self._socket.recv_fd() + return result + + def broadcast(self, handle, root: int): + if self.rank == root: + for p in range(1, self.size): + self.comm.barrier() + self._socket.send_fd(handle, p) + return handle + else: + # Ordered receive to avoid race condition + for _ in range(self.rank): + self.comm.barrier() + result = self._socket.recv_fd() + for _ in range(self.size - self.rank - 1): + self.comm.barrier() + return result + + def cleanup(self, handle) -> None: + os.close(handle) + + def close(self) -> None: + self._socket.close() + + # TODO: This class follows similar logic with MnnvlMemory, but the latter use single instance mode to manage the memory allocation. class McastDeviceMemory: """Python port of McastDeviceMemory from TensorRT-LLM""" @@ -714,8 +815,6 @@ def __init__( int ] = [] # std::vector mUcHandles - self._shareable_handle_type = None - # Signal pad constants self.SIGNAL_PAD_ALIGNMENT = 16 self.SIGNAL_PAD_SIZE = SIGNAL_PAD_SIZE @@ -741,6 +840,7 @@ def __init__( f"Signal pad offset: {self.signal_pad_offset}" ) + # Create handle exchanger based on multi-node mode if self.is_multi_node: # Check if fabric handle is supported fabric_handle_supported = checkCudaErrors( @@ -753,15 +853,12 @@ def __init__( raise RuntimeError( "[McastDeviceMemory] Device does not support fabric handle." ) - # Use fabric handle for multi-node NVLS - self._shareable_handle_type = ( - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC + self._exchanger: HandleExchanger = FabricHandleExchanger( + self.comm_backend, self.group_rank, self.group_size ) else: - self._init_ipc_socket() - # Use NVLink handle for single-node NVLS - self._shareable_handle_type = ( - cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR + self._exchanger = PosixFDHandleExchanger( + self.comm_backend, self.group_rank, self.group_size ) self._alloc_mn_mcast_mem(buf_size) @@ -785,8 +882,8 @@ def __del__(self): if not hasattr(self, "is_multi_node"): return - if hasattr(self, "_ipc_socket"): - self._ipc_socket.close() + if hasattr(self, "_exchanger"): + self._exchanger.close() # Skip cleanup during Python finalization to avoid segfaults # Especially cause the CUDA context could be destroyed at this point. @@ -890,22 +987,23 @@ def get_usable_buffer_size(self) -> int: """Get the usable buffer size (excluding signal pad)""" return self.allocation_size - self.SIGNAL_PAD_SIZE - def _init_ipc_socket(self): - if self.group_rank == 0: - # Gnerate the opId - opId = random.randint(0, 2**64 - 1) - else: - opId = None - opId = self.comm_backend.bcast(opId, root=0) - self._ipc_socket = IpcSocket(self.group_rank, opId) - def _alloc_mn_mcast_mem(self, buf_size: int): """Allocate multi-node multicast memory using MNNVL""" + self._verify_cuda_context() + + # Compute allocation size and get allocation properties + allocation_prop, mc_prop = self._get_allocation_prop(buf_size) + + # Allocate, exchange, and map unicast buffers + self._allocate_unicast_buffers(allocation_prop) + + # Setup multicast object, exchange handles, map and bind memory + self._setup_multicast(mc_prop) - # Verify CUDA context + def _verify_cuda_context(self): + """Verify CUDA context is set to the correct device.""" try: current_device = checkCudaErrors(cuda.cuCtxGetDevice()) - if int(current_device) != self.device_idx: print( f"CUDA context device mismatch! Current: {current_device}, Expected: {self.device_idx}" @@ -913,16 +1011,16 @@ def _alloc_mn_mcast_mem(self, buf_size: int): except Exception as e: print(f"Error checking CUDA context: {e}") - # Set up allocation properties + def _get_allocation_prop(self, buf_size: int): + """Compute allocation size and return allocation/multicast properties.""" allocation_prop = cuda.CUmemAllocationProp() - allocation_prop.requestedHandleTypes = self._shareable_handle_type + allocation_prop.requestedHandleTypes = self._exchanger.handle_type allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED allocation_prop.location = cuda.CUmemLocation() allocation_prop.location.type = ( cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE ) allocation_prop.location.id = self.device_idx - allocation_prop.allocFlags.gpuDirectRDMACapable = 1 # Get allocation granularity @@ -933,7 +1031,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) ) - # mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity); self.allocation_size = round_up( buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity ) @@ -942,18 +1039,21 @@ def _alloc_mn_mcast_mem(self, buf_size: int): mc_prop = cuda.CUmulticastObjectProp() mc_prop.numDevices = self.group_size mc_prop.size = self.allocation_size - mc_prop.handleTypes = self._shareable_handle_type + mc_prop.handleTypes = self._exchanger.handle_type - # Get multicast granularity - mc_granularity = checkCudaErrors( + # Get multicast granularity and adjust allocation size + self._mc_granularity = checkCudaErrors( cuda.cuMulticastGetGranularity( mc_prop, cuda.CUmulticastGranularity_flags.CU_MULTICAST_GRANULARITY_RECOMMENDED, ) ) + self.allocation_size = round_up(self.allocation_size, self._mc_granularity) - self.allocation_size = round_up(self.allocation_size, mc_granularity) + return allocation_prop, mc_prop + def _allocate_unicast_buffers(self, allocation_prop): + """Allocate local UC memory, exchange handles with peers, and map memory.""" # Initialize UC handles list self.uc_handles = [0] * self.group_size @@ -962,34 +1062,17 @@ def _alloc_mn_mcast_mem(self, buf_size: int): cuda.cuMemCreate(self.allocation_size, allocation_prop, 0) ) - # Export local handle to fabric handle or FD + # Export local handle to shareable handle local_shareable_uc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.uc_handles[self.group_rank], - self._shareable_handle_type, + self._exchanger.handle_type, 0, ) ) - if ( - self._shareable_handle_type - == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC - ): - # All-gather fabric handles - all_shareable_uc_handles = self.comm_backend.allgather( - local_shareable_uc_handle.data - ) - else: - # Implement the allgather logic with ipc socket - all_shareable_uc_handles = [None] * self.group_size - for i in range(self.group_size): - self.comm_backend.barrier() - # Send to peer at offset i - dest_rank = (self.group_rank + i) % self.group_size - self._ipc_socket.send_fd(local_shareable_uc_handle, dest_rank) - # Receive from peer at offset -i - src_rank = (self.group_rank + self.group_size - i) % self.group_size - all_shareable_uc_handles[src_rank] = self._ipc_socket.recv_fd() + # All-gather shareable handles + all_shareable_uc_handles = self._exchanger.allgather(local_shareable_uc_handle) cuda.cuCtxSynchronize() # Import remote handles @@ -998,117 +1081,81 @@ def _alloc_mn_mcast_mem(self, buf_size: int): self.uc_handles[p] = checkCudaErrors( cuda.cuMemImportFromShareableHandle( all_shareable_uc_handles[p], - self._shareable_handle_type, + self._exchanger.handle_type, ) ) - if ( - self._shareable_handle_type - == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR - ): - # Close FD after import - os.close(all_shareable_uc_handles[p]) + self._exchanger.cleanup(all_shareable_uc_handles[p]) + + # Reserve address space for UC pointers + self.uc_ptrs = [0] * self.group_size + total_uc_size = self.allocation_size * self.group_size + self.total_uc_size = total_uc_size + uc_base_ptr = checkCudaErrors( + cuda.cuMemAddressReserve(total_uc_size, self._mc_granularity, 0, 0) + ) + self.uc_base_ptr = uc_base_ptr + + # Map UC memory + for i in range(self.group_size): + offset = self.allocation_size * i + self.uc_ptrs[i] = int(uc_base_ptr) + offset + checkCudaErrors( + cuda.cuMemMap( + self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0 + ) + ) - # Initialize multicasting + # Set memory access permissions for UC + access_desc = self._get_mem_access_desc() + checkCudaErrors( + cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) + ) + + def _setup_multicast(self, mc_prop): + """Create multicast object, exchange handle, map memory, and bind.""" + # Rank 0 creates the multicast object if self.group_rank == 0: - # Create multicast object self.mc_handle = checkCudaErrors(cuda.cuMulticastCreate(mc_prop)) - - # Export multicast handle, there's only one handle for the entire group shareable_mc_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.mc_handle, - self._shareable_handle_type, + self._exchanger.handle_type, 0, ) ) else: shareable_mc_handle = None - if ( - self._shareable_handle_type - == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC - ): - # Broadcast multicast handle - shareable_mc_handle = self.comm_backend.bcast( - shareable_mc_handle.data if shareable_mc_handle else None, root=0 - ) - else: - # Implement bcast logic with ipc socket - if self.group_rank == 0: - for p in range(1, self.group_size): - self.comm_backend.barrier() - self._ipc_socket.send_fd(shareable_mc_handle, p) - else: - # Other ranks receive from rank 0 - # We need to order the receive to avoid a race condition bug we encountered. If driver fixed this issue, the additional barriers used for ordering can be removed. - for _ in range(self.group_rank): - self.comm_backend.barrier() - shareable_mc_handle = self._ipc_socket.recv_fd() - for _ in range(self.group_size - self.group_rank - 1): - self.comm_backend.barrier() - # Sync device to ensure broadcast is complete + + # Broadcast multicast handle from rank 0 + shareable_mc_handle = self._exchanger.broadcast(shareable_mc_handle, root=0) cuda.cuCtxSynchronize() + # Import multicast handle for non-root ranks if self.group_rank != 0: self.mc_handle = checkCudaErrors( cuda.cuMemImportFromShareableHandle( shareable_mc_handle, - self._shareable_handle_type, + self._exchanger.handle_type, ) ) - if ( - self._shareable_handle_type - == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR - ): - # Close FD after import - os.close(shareable_mc_handle) + self._exchanger.cleanup(shareable_mc_handle) + # Add device to multicast checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx)) - # Bind memory addresses - self.uc_ptrs = [0] * self.group_size - - # Reserve address space for UC pointers - total_uc_size = self.allocation_size * self.group_size - self.total_uc_size = total_uc_size - uc_base_ptr = checkCudaErrors( - cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0) - ) - self.uc_base_ptr = uc_base_ptr # Store for cleanup - - # Set up memory access descriptor - access_desc = cuda.CUmemAccessDesc() - access_desc.location = cuda.CUmemLocation() - access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE - access_desc.location.id = self.device_idx - access_desc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE - - # Map UC memory - for i in range(self.group_size): - offset = self.allocation_size * i - self.uc_ptrs[i] = int(uc_base_ptr) + offset - checkCudaErrors( - cuda.cuMemMap( - self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0 - ) - ) - - # Set memory access permissions - checkCudaErrors( - cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) - ) - - # Bind MC pointer + # Reserve and map MC pointer self.mc_ptr = checkCudaErrors( - cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0) + cuda.cuMemAddressReserve(self.allocation_size, self._mc_granularity, 0, 0) ) checkCudaErrors( cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0) ) + access_desc = self._get_mem_access_desc() checkCudaErrors( cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1) ) - # Bind memory to multicast + # Bind local memory to multicast checkCudaErrors( cuda.cuMulticastBindMem( self.mc_handle, @@ -1120,6 +1167,15 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) ) + def _get_mem_access_desc(self): + """Create memory access descriptor for this device.""" + access_desc = cuda.CUmemAccessDesc() + access_desc.location = cuda.CUmemLocation() + access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + access_desc.location.id = self.device_idx + access_desc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE + return access_desc + def lamport_initialize(self, rank: int, dtype: torch.dtype): if dtype == torch.bfloat16 or dtype == torch.float16: neg_zero = 0x8000 From 03700a2dcddaab11b2bdd6a7389517feda260dab Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 4 Dec 2025 13:29:16 -0800 Subject: [PATCH 15/32] =?UTF-8?q?Adapt=20the=20workspace=20creation=20API?= =?UTF-8?q?=20for=20unified=20backend=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- flashinfer/comm/trtllm_mnnvl_ar.py | 44 +++++++++++++++++++---- tests/comm/test_trtllm_mnnvl_allreduce.py | 12 +++---- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index afdd580910..6121e88335 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -53,6 +53,9 @@ class MNNVLAllreduceFusionWorkspace: def __init__( self, mapping: Mapping, + max_num_tokens: Optional[int] = None, + hidden_dim: Optional[int] = None, + dtype: Optional[torch.dtype] = None, buffer_size_in_bytes: Optional[int] = None, comm_backend: Optional[CommBackend] = None, ): @@ -63,14 +66,40 @@ def __init__( mapping: Mapping configuration containing rank info buffer_size_in_bytes: The requested size in bytes for each lamport buffer. The actual allocation size may be larger due to alignment requirements. The actual usable size will be NUM_LAMPORT_BUFFERS * actual_buffer_size_per_lamport_buffer. """ + if buffer_size_in_bytes is None: - # Default to 16MB workspace size if not provided - buffer_size_in_bytes = 16 * (1024**2) + assert ( + max_num_tokens is not None + and hidden_dim is not None + and dtype is not None + ), ( + "max_num_tokens, hidden_dim, and dtype must be provided if buffer_size_in_bytes is not provided." + ) + one_shot_size_bytes = self.get_required_buffer_size_bytes( + mapping.tp_size, + max_num_tokens, + hidden_dim, + dtype, + MNNVLAllreduceFusionStrategy.ONESHOT, + ) + if max_num_tokens > MNNVL_ONE_SHOT_THRESHOLD: + two_shot_size_bytes = self.get_required_buffer_size_bytes( + mapping.tp_size, + max_num_tokens, + hidden_dim, + dtype, + MNNVLAllreduceFusionStrategy.TWOSHOT, + ) + else: + two_shot_size_bytes = 0 + + # We don't do roundup here as it will happen at the allocation. + buffer_size_in_bytes = max(one_shot_size_bytes, two_shot_size_bytes) else: - # Round up to the nearest multiple of 8MB - buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * ( - 8 * (1024**2) + logging.debug( + f"[MNNVL Allreduce] Using provided buffer size override in bytes: {buffer_size_in_bytes} bytes." ) + if comm_backend is None: comm_backend = MPIBackend() if buffer_size_in_bytes > (2**32 - 1): @@ -481,11 +510,14 @@ def get_allreduce_mnnvl_workspace( # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. workspace = MNNVLAllreduceFusionWorkspace( - mapping, buffer_size_in_bytes, comm_backend_for_handle_transfer + mapping, + buffer_size_in_bytes=buffer_size_in_bytes, + comm_backend=comm_backend_for_handle_transfer, ) mcast_buffer = workspace.mcast_buffer_handle buffer_flags = workspace.buffer_flags + # this is calculated using the legacy behavior. We do not use the actual allocated size. max_num_elements = workspace.buffer_size_bytes // stride return ( diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index cf93b1af6c..78ce392b7a 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -341,15 +341,11 @@ def run_mnnvl_ar_full( ) else: - required_workspace_bytes = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace.get_required_buffer_size_bytes( - mapping.tp_size, - max(seq_lens), - hidden_size, - dtype, - trtllm_mnnvl_ar.MNNVLAllreduceFusionStrategy.AUTO, - ) workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace( - mapping, required_workspace_bytes + mapping, + max_num_tokens=max(seq_lens), + hidden_dim=hidden_size, + dtype=dtype, ) test_data = [] From ff84e871d592b67ce93eb1562e6b33af2dc9ef34 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 4 Dec 2025 13:43:28 -0800 Subject: [PATCH 16/32] Use threshold only for onshot workspace size calculation. --- flashinfer/comm/trtllm_mnnvl_ar.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 6121e88335..5ce705abd6 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -75,24 +75,28 @@ def __init__( ), ( "max_num_tokens, hidden_dim, and dtype must be provided if buffer_size_in_bytes is not provided." ) + + # If the user want to explictly use one-shot pass the threshold, which requires larger workspace size, + # We expect the user to set workspace size manually. + elem_size = torch.tensor([], dtype=dtype).element_size() + oneshot_max_num_tokens = min( + MNNVL_ONE_SHOT_THRESHOLD // (mapping.tp_size * elem_size * hidden_dim), + max_num_tokens, + ) one_shot_size_bytes = self.get_required_buffer_size_bytes( mapping.tp_size, - max_num_tokens, + oneshot_max_num_tokens, hidden_dim, dtype, MNNVLAllreduceFusionStrategy.ONESHOT, ) - if max_num_tokens > MNNVL_ONE_SHOT_THRESHOLD: - two_shot_size_bytes = self.get_required_buffer_size_bytes( - mapping.tp_size, - max_num_tokens, - hidden_dim, - dtype, - MNNVLAllreduceFusionStrategy.TWOSHOT, - ) - else: - two_shot_size_bytes = 0 - + two_shot_size_bytes = self.get_required_buffer_size_bytes( + mapping.tp_size, + max_num_tokens, + hidden_dim, + dtype, + MNNVLAllreduceFusionStrategy.TWOSHOT, + ) # We don't do roundup here as it will happen at the allocation. buffer_size_in_bytes = max(one_shot_size_bytes, two_shot_size_bytes) else: From 0f5b7b31f1bdd2542250d1f05c75a11b808473ae Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Tue, 9 Dec 2025 05:22:32 -0800 Subject: [PATCH 17/32] Document worjkspace creation behavior. --- flashinfer/comm/trtllm_mnnvl_ar.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 5ce705abd6..308ee4bda6 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -60,10 +60,19 @@ def __init__( comm_backend: Optional[CommBackend] = None, ): """ - Initialize the MNNVL Allreduce Fusion Workspace. comm_backend will be used for creating the workspace and synchronization. If not provided, MPIBackend will be used which will use COMM_WORLD for synchronization. + Initialize the MNNVL Allreduce Fusion Workspace. The workspace will be allocated and initialized based on the provided problem size. If max_num_tokens is larger than the one-shot threshold, the workspace will be created according to the max of required one-shot size at threshold, or the required two-shot size. Note that the workspace is not bind to the given problem size. It can be reused for different problem size without reinitialization given the allocated size is sufficient. + + If the buffer_size_in_bytes is provided, the workspace will be created according to the provided size. The user is expected to use the utility function get_required_buffer_size_bytes to calculate the required size. The actual allocation size may be larger due to alignment requirements. This covers the advanced used case, for example, the user may want to enforce oneshot strategy and ignore the heuristics. + + Either max_num_tokens or buffer_size_in_bytes must be provided. + + comm_backend will be used for creating the workspace and synchronization. If not provided, MPIBackend will be used which will use COMM_WORLD for synchronization. Args: mapping: Mapping configuration containing rank info + max_num_tokens: The maximum number of tokens in the input tensor. + hidden_dim: The hidden dimension of the tensors to be reduced. + dtype: The data type of the tensors to be reduced. buffer_size_in_bytes: The requested size in bytes for each lamport buffer. The actual allocation size may be larger due to alignment requirements. The actual usable size will be NUM_LAMPORT_BUFFERS * actual_buffer_size_per_lamport_buffer. """ From b9f43293b15b8a91d742d618945874c60e3e234f Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Thu, 20 Nov 2025 16:09:53 -0800 Subject: [PATCH 18/32] Added first non-working version --- flashinfer/comm/__init__.py | 13 + flashinfer/comm/allreduce.py | 809 +++++++++++++++++++++++++++++++++++ 2 files changed, 822 insertions(+) create mode 100644 flashinfer/comm/allreduce.py diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index f7ae3754ac..b0e9dfd0a4 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -39,4 +39,17 @@ from .vllm_ar import register_buffer as vllm_register_buffer from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers +# Unified AllReduce Fusion API +from .allreduce import AllReduceFusionContext as AllReduceFusionContext +from .allreduce import AllReduceFusionWorkspace as AllReduceFusionWorkspace +from .allreduce import MNNVLAllReduceFusionWorkspace as MNNVLAllReduceFusionWorkspace +from .allreduce import TRTLLMAllReduceFusionWorkspace as TRTLLMAllReduceFusionWorkspace +from .allreduce import allreduce_fusion as allreduce_fusion +from .allreduce import ( + create_allreduce_fusion_workspace as create_allreduce_fusion_workspace, +) +from .allreduce import ( + destroy_allreduce_fusion_workspace as destroy_allreduce_fusion_workspace, +) + # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py new file mode 100644 index 0000000000..a8f6498c0c --- /dev/null +++ b/flashinfer/comm/allreduce.py @@ -0,0 +1,809 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Unified AllReduce Fusion API + +This module provides a unified interface for AllReduce + RMSNorm fusion operations +across different backends (TensorRT-LLM, MNNVL). + +Example usage: + >>> # Auto-select best backend based on topology + >>> workspace = create_allreduce_fusion_workspace( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) + >>> + >>> # Perform AllReduce + RMSNorm fusion + >>> prenorm = torch.empty_like(hidden_states) + >>> normed = torch.empty_like(hidden_states) + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... launch_with_pdl=True, + ... residual_out=prenorm, + ... norm_out=normed, + ... residual_in=residual, + ... rms_gamma=norm_weight + ... ) + >>> + >>> destroy_allreduce_fusion_workspace(workspace) +""" + +from typing import Union, Literal, Optional +from abc import ABC, abstractmethod + +import torch + +from ..utils import backend_requirement, supported_compute_capability + + +# ============================================================================ +# WORKSPACE BASE CLASS +# ============================================================================ + + +class AllReduceFusionWorkspace(ABC): + """Base class for AllReduce fusion workspaces.""" + + def __init__(self, world_size: int, rank: int): + self.world_size = world_size + self.rank = rank + + @property + @abstractmethod + def backend(self) -> str: + """Return backend name.""" + pass + + +class TRTLLMAllReduceFusionWorkspace(AllReduceFusionWorkspace): + """TensorRT-LLM workspace for AllReduce fusion.""" + + def __init__(self, world_size: int, rank: int, workspace_ptrs, metadata): + super().__init__(world_size, rank) + self.workspace_ptrs = workspace_ptrs + self.metadata = metadata + + @property + def backend(self) -> str: + return "trtllm" + + +class MNNVLAllReduceFusionWorkspace(AllReduceFusionWorkspace): + """MNNVL workspace for AllReduce fusion.""" + + def __init__( + self, + world_size: int, + rank: int, + multicast_buffer_ptr: int, + buffer_ptrs_dev: int, + unicast_ptr: int, + buffer_M: int, + buffer_flags, + ): + super().__init__(world_size, rank) + self.multicast_buffer_ptr = multicast_buffer_ptr + self.buffer_ptrs_dev = buffer_ptrs_dev + self.unicast_ptr = unicast_ptr + self.buffer_M = buffer_M + self.buffer_flags = buffer_flags + + @property + def backend(self) -> str: + return "mnnvl" + + +# ============================================================================ +# BACKEND CHECKS - Hard requirements for decorator +# ============================================================================ + + +@supported_compute_capability([80, 86, 89, 90, 100]) +def _trtllm_workspace_check( + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + device: Optional[torch.device], + topology: str, + **kwargs, +) -> bool: + """ + Check if trtllm backend CAN be used for workspace creation. + + Hard requirements: + - SM80+ compute capability (checked by decorator) + - Single-node topology + - Module availability + """ + # trtllm is optimized for single-node + if topology == "multi_node": + return False + + return True + + +@supported_compute_capability([90, 100]) +def _mnnvl_workspace_check( + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + device: Optional[torch.device], + topology: str, + **kwargs, +) -> bool: + """ + Check if mnnvl backend CAN be used for workspace creation. + + Hard requirements: + - SM90+ compute capability (checked by decorator) + - Multi-node topology + - Module availability + """ + # MNNVL is designed for multi-node + if topology == "single_node": + return False + + return True + + +# ============================================================================ +# HEURISTIC - Performance-based selection for decorator +# ============================================================================ + + +def _workspace_creation_heuristic( + suitable_backends: list[str], + backend: str, + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + device: Optional[torch.device], + topology: str, + **kwargs, +) -> list[str]: + """ + Select best backend for workspace creation based on performance. + + Called by decorator after checking which backends pass requirements. + Uses benchmarking data to pick fastest option. + + Args: + suitable_backends: List of backends that passed hard requirement checks + backend: Requested backend ("auto", "trtllm", or "mnnvl") + world_size: Number of ranks + rank: Current rank + max_token_num: Maximum number of tokens + hidden_dim: Hidden dimension size + dtype: Data type + device: CUDA device + topology: Network topology ("single_node" or "multi_node") + **kwargs: Additional arguments + + Returns: + List containing the selected backend (single element) + """ + if not suitable_backends: + return [] + + if len(suitable_backends) == 1: + return suitable_backends + + # Decision tree based on benchmark data + # TODO: Replace with actual benchmarking results + + # Multi-node: MNNVL is designed for this + if topology == "multi_node": + if "mnnvl" in suitable_backends: + return ["mnnvl"] + + # Single-node scenarios + problem_size = max_token_num * hidden_dim + + # Large problems (>4M elements): trtllm optimized for throughput + if problem_size > 4 * 1024 * 1024: + if "trtllm" in suitable_backends: + return ["trtllm"] + + # Small token counts (<128): trtllm one-shot has better latency + if max_token_num < 128: + if "trtllm" in suitable_backends: + return ["trtllm"] + + # Small world sizes (<=4): trtllm one-shot efficient + if world_size <= 4: + if "trtllm" in suitable_backends: + return ["trtllm"] + + # Default: return first available + return [suitable_backends[0]] + + +# ============================================================================ +# WORKSPACE CREATION - Uses decorator for all validation +# ============================================================================ + + +@backend_requirement( + backend_checks={ + "trtllm": _trtllm_workspace_check, + "mnnvl": _mnnvl_workspace_check, + }, + heuristic_func=_workspace_creation_heuristic, +) +def create_allreduce_fusion_workspace( + backend: Literal["trtllm", "mnnvl", "auto"] = "auto", + world_size: int = None, + rank: int = None, + max_token_num: int = None, + hidden_dim: int = None, + dtype: torch.dtype = None, + device: Optional[torch.device] = None, + topology: str = "single_node", + process_group: Optional["torch.distributed.ProcessGroup"] = None, + **backend_kwargs, +) -> AllReduceFusionWorkspace: + """ + Create workspace for AllReduce fusion operations. + + Backend selection (checks + heuristics) handled by @backend_requirement decorator. + + Args: + backend: Backend to use ("trtllm", "mnnvl", or "auto") + "auto" uses heuristic to select best backend based on topology + and problem size + world_size: Number of ranks in the process group + rank: Current rank ID + max_token_num: Maximum number of tokens to support + hidden_dim: Hidden dimension size + dtype: Data type for communication tensors + device: CUDA device (defaults to current CUDA device) + topology: Network topology hint for backend selection + "single_node" - All ranks on one node (default) + "multi_node" - Ranks span multiple nodes + process_group: PyTorch distributed process group + **backend_kwargs: Additional backend-specific arguments + + Returns: + Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace) + The workspace type determines which backend will be used in allreduce_fusion() + + Raises: + BackendSupportedError: If no suitable backend available for the configuration + ValueError: If problem size not supported for the specified backend + + Examples: + >>> # Auto-select best backend based on topology + >>> workspace = create_allreduce_fusion_workspace( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) + >>> print(workspace.backend) # "trtllm" + + >>> # Explicit backend selection + >>> workspace = create_allreduce_fusion_workspace( + ... backend="mnnvl", + ... world_size=16, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="multi_node" + ... ) + >>> print(workspace.backend) # "mnnvl" + """ + if device is None: + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + # Decorator has validated backend - now create workspace + # If backend="auto", decorator has selected the best one and stored it + + # Get actual backend (decorator resolved "auto" to concrete backend) + if backend == "auto": + # Decorator stored the selected backend in suitable_auto_backends + actual_backend = create_allreduce_fusion_workspace.suitable_auto_backends[0] + else: + actual_backend = backend + + # Create workspace for selected backend + if actual_backend == "trtllm": + from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion + + workspace = trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_size=world_size, + tp_rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + device=device, + process_group=process_group, + **backend_kwargs, + ) + # Ensure workspace has required attributes for our API + if not hasattr(workspace, "world_size"): + workspace.world_size = world_size + if not hasattr(workspace, "rank"): + workspace.rank = rank + return workspace + + elif actual_backend == "mnnvl": + # TODO: Implement create_mnnvl_allreduce_fusion_workspace + # For now, raise NotImplementedError with instructions + raise NotImplementedError( + "MNNVL workspace creation needs to be implemented. " + "Expected function: trtllm_mnnvl_ar.create_mnnvl_allreduce_fusion_workspace" + ) + # from .trtllm_mnnvl_ar import create_mnnvl_allreduce_fusion_workspace + # return create_mnnvl_allreduce_fusion_workspace( + # world_size=world_size, + # rank=rank, + # max_token_num=max_token_num, + # hidden_dim=hidden_dim, + # dtype=dtype, + # device=device, + # **backend_kwargs + # ) + else: + raise RuntimeError(f"Unknown backend: {actual_backend}") + + +# ============================================================================ +# WORKSPACE DESTRUCTION +# ============================================================================ + + +def destroy_allreduce_fusion_workspace(workspace: AllReduceFusionWorkspace) -> None: + """ + Destroy workspace and free resources. + + Automatically detects workspace type from the object and calls + appropriate cleanup function. + + Args: + workspace: Workspace object to destroy + + Example: + >>> workspace = create_allreduce_fusion_workspace(...) + >>> # ... use workspace ... + >>> destroy_allreduce_fusion_workspace(workspace) + """ + if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): + from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion + + trtllm_destroy_ipc_workspace_for_all_reduce_fusion(workspace) + elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): + # TODO: Implement MNNVL workspace destruction + raise NotImplementedError("MNNVL workspace destruction not yet implemented") + # from .trtllm_mnnvl_ar import destroy_mnnvl_allreduce_fusion_workspace + # destroy_mnnvl_allreduce_fusion_workspace(workspace) + else: + raise TypeError(f"Unknown workspace type: {type(workspace)}") + + +# ============================================================================ +# MAIN API - NO backend parameter, infers from workspace type +# ============================================================================ + + +def allreduce_fusion( + input: torch.Tensor, + workspace: AllReduceFusionWorkspace, + launch_with_pdl: bool = False, + # ===== OUTPUT tensors (pre-allocated, will be filled) ===== + output: Optional[torch.Tensor] = None, + residual_out: Optional[torch.Tensor] = None, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: Optional[torch.Tensor] = None, + # ===== INPUT parameters ===== + residual_in: Optional[torch.Tensor] = None, + rms_gamma: Optional[torch.Tensor] = None, + rms_eps: float = 1e-6, + scale_factor: Optional[Union[torch.Tensor, float]] = None, + layout_code: Optional[int] = None, + # ===== Control parameters ===== + pattern: Optional[int] = None, + use_oneshot: Optional[bool] = None, + fp32_acc: bool = False, + metadata: Optional[dict] = None, +) -> torch.Tensor: + """ + AllReduce + RMSNorm fusion operation. + + Backend is automatically determined from workspace type. + No backend parameter needed! + + Supports multiple fusion patterns: + - AllReduce only + - AllReduce + Residual + RMSNorm + - AllReduce + Residual + RMSNorm + Quantization (FP8/FP4) + + Args: + input: Input tensor [token_num, hidden_dim] + workspace: Workspace object (type determines backend) + launch_with_pdl: Use Persistent Device Launch + + # ===== OUTPUT tensors (pre-allocated, filled by function) ===== + output: AllReduce output [token_num, hidden_dim] + residual_out: Prenorm output (after residual add, before norm) [token_num, hidden_dim] + norm_out: Normalized output [token_num, hidden_dim] + quant_out: Quantized output [token_num, hidden_dim] [trtllm only] + scale_out: Quantization scale factors [trtllm only] + + # ===== INPUT parameters ===== + residual_in: Residual tensor to ADD [token_num, hidden_dim] + rms_gamma: RMSNorm weight [hidden_dim] + rms_eps: RMSNorm epsilon for numerical stability + scale_factor: Input scale factor for quantization [trtllm only] + layout_code: Scale factor layout (QuantizationSFLayout) [trtllm only] + + # ===== Control parameters ===== + pattern: Fusion pattern (AllReduceFusionPattern) + If None, auto-detected based on provided output tensors + use_oneshot: [trtllm only] Use oneshot strategy vs twoshot + If None, uses internal heuristics + fp32_acc: [trtllm only] Use FP32 accumulation for AllReduce + metadata: [trtllm only] Workspace metadata for validation + + Returns: + Output tensor (typically norm_out for fusion cases, output otherwise) + + Examples: + >>> # Basic AllReduce + Residual + RMSNorm + >>> workspace = create_allreduce_fusion_workspace( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) + >>> + >>> # Pre-allocate output tensors + >>> prenorm = torch.empty_like(hidden_states) + >>> normed = torch.empty_like(hidden_states) + >>> + >>> # Call fusion - backend inferred from workspace type + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... launch_with_pdl=True, + ... residual_out=prenorm, + ... norm_out=normed, + ... residual_in=residual, + ... rms_gamma=norm_weight + ... ) + >>> # output == normed (final result) + + >>> # With FP8 quantization + >>> quant = torch.empty_like(hidden_states, dtype=torch.float8_e4m3fn) + >>> scales = torch.empty(token_num * hidden_dim // 16, dtype=torch.float16) + >>> + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... norm_out=normed, + ... quant_out=quant, + ... scale_out=scales, + ... residual_in=residual, + ... rms_gamma=norm_weight, + ... scale_factor=scale_tensor + ... ) + """ + # Auto-detect pattern if not provided + if pattern is None: + pattern = _infer_fusion_pattern( + output, residual_in, residual_out, norm_out, quant_out, scale_out + ) + + # Infer backend from workspace type and dispatch + if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): + return _allreduce_fusion_trtllm( + input=input, + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=output, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=layout_code, + pattern=pattern, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, + metadata=metadata, + ) + elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): + return _allreduce_fusion_mnnvl( + input=input, + workspace=workspace, + launch_with_pdl=launch_with_pdl, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + ) + else: + raise TypeError( + f"Unknown workspace type: {type(workspace)}. " + f"Expected TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace" + ) + + +# ============================================================================ +# HELPER FUNCTIONS +# ============================================================================ + + +def _infer_fusion_pattern( + output, residual_in, residual_out, norm_out, quant_out, scale_out +) -> int: + """ + Automatically infer fusion pattern from provided tensors. + + Returns AllReduceFusionPattern value based on which output tensors are provided. + """ + from .trtllm_ar import AllReduceFusionPattern + + if quant_out is not None: + # Quantization patterns + if norm_out is not None and residual_out is not None: + # Has separate norm output and residual output + return AllReduceFusionPattern.kARResidualRMSNormOutFP8Quant # 4 + else: + # Quant without separate outputs + return AllReduceFusionPattern.kARResidualRMSNormFP8Quant # 2 + elif norm_out is not None: + # RMS Norm without quantization + return AllReduceFusionPattern.kARResidualRMSNorm # 1 + else: + # Just AllReduce + return AllReduceFusionPattern.kAllReduce # 0 + + +def _allreduce_fusion_trtllm( + input: torch.Tensor, + workspace: TRTLLMAllReduceFusionWorkspace, + launch_with_pdl: bool, + output: Optional[torch.Tensor], + residual_in: Optional[torch.Tensor], + residual_out: Optional[torch.Tensor], + norm_out: Optional[torch.Tensor], + quant_out: Optional[torch.Tensor], + scale_out: Optional[torch.Tensor], + rms_gamma: Optional[torch.Tensor], + rms_eps: float, + scale_factor: Optional[Union[torch.Tensor, float]], + layout_code: Optional[int], + pattern: int, + use_oneshot: Optional[bool], + fp32_acc: bool, + metadata: Optional[dict], +) -> torch.Tensor: + """TensorRT-LLM backend implementation.""" + from .trtllm_ar import trtllm_allreduce_fusion + + token_num, hidden_dim = input.shape + + if output is None: + output = torch.empty_like(input) + + trtllm_allreduce_fusion( + allreduce_in=input, + world_size=workspace.world_size, + world_rank=workspace.rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace.workspace_ptrs, + launch_with_pdl=launch_with_pdl, + trigger_completion_at_end=launch_with_pdl, # Same meaning + fp32_acc=fp32_acc, + pattern_code=pattern, + use_oneshot=use_oneshot, + allreduce_out=output, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=layout_code, + metadata=metadata, + ) + + # Return the most downstream output + if norm_out is not None: + return norm_out + elif quant_out is not None: + return quant_out + else: + return output + + +def _allreduce_fusion_mnnvl( + input: torch.Tensor, + workspace: MNNVLAllReduceFusionWorkspace, + launch_with_pdl: bool, + residual_in: Optional[torch.Tensor], + residual_out: Optional[torch.Tensor], + norm_out: Optional[torch.Tensor], + rms_gamma: Optional[torch.Tensor], + rms_eps: float, +) -> torch.Tensor: + """ + MNNVL backend implementation. + + Calls trtllm_mnnvl_fused_allreduce_rmsnorm which performs: + 1. AllReduce on input + 2. Add residual + 3. RMSNorm + """ + from .trtllm_mnnvl_ar import trtllm_mnnvl_fused_allreduce_rmsnorm + + # Validate required parameters for RMS fusion + if residual_in is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires residual_in") + if residual_out is None: + raise ValueError( + "MNNVL AllReduce+RMS fusion requires residual_out (prenorm_output)" + ) + if norm_out is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires norm_out (normed_output)") + if rms_gamma is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires rms_gamma") + + # Call the MNNVL fusion function + trtllm_mnnvl_fused_allreduce_rmsnorm( + prenorm_output=residual_out, + normed_output=norm_out, + shard_input=input, + multicast_buffer_ptr=workspace.multicast_buffer_ptr, + buffer_ptrs_dev=workspace.buffer_ptrs_dev, + unicast_ptr=workspace.unicast_ptr, + buffer_M=workspace.buffer_M, + buffer_flags_mnnvl=workspace.buffer_flags, + nranks=workspace.world_size, + rank=workspace.rank, + gamma=rms_gamma, + epsilon=rms_eps, + residual=residual_in, + launch_with_pdl=launch_with_pdl, + ) + + return norm_out + + +# ============================================================================ +# CONTEXT MANAGER +# ============================================================================ + + +class AllReduceFusionContext: + """ + Context manager with automatic workspace management. + + This provides a convenient high-level API that handles workspace + creation and cleanup automatically. + + Example: + >>> with AllReduceFusionContext( + ... backend="auto", + ... world_size=8, + ... rank=0, + ... max_token_num=2048, + ... hidden_dim=4096, + ... dtype=torch.bfloat16, + ... topology="single_node" + ... ) as ctx: + ... for batch in training_loop: + ... prenorm = torch.empty_like(batch.hidden_states) + ... normed = torch.empty_like(batch.hidden_states) + ... + ... output = ctx.allreduce_fusion( + ... input=batch.hidden_states, + ... residual_out=prenorm, + ... norm_out=normed, + ... residual_in=batch.residual, + ... rms_gamma=model.norm_weight, + ... launch_with_pdl=True + ... ) + >>> # Workspace automatically cleaned up + """ + + def __init__( + self, + backend: Literal["trtllm", "mnnvl", "auto"] = "auto", + world_size: int = None, + rank: int = None, + max_token_num: int = None, + hidden_dim: int = None, + dtype: torch.dtype = None, + device: Optional[torch.device] = None, + topology: str = "single_node", + **kwargs, + ): + """ + Initialize context manager. + + Args: + backend: Backend to use ("trtllm", "mnnvl", or "auto") + world_size: Number of ranks + rank: Current rank + max_token_num: Maximum tokens to support + hidden_dim: Hidden dimension + dtype: Data type + device: CUDA device + topology: Network topology ("single_node" or "multi_node") + **kwargs: Additional backend-specific arguments + """ + # Workspace creation does all the selection logic via decorator + self.workspace = create_allreduce_fusion_workspace( + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + device=device, + topology=topology, + **kwargs, + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + destroy_allreduce_fusion_workspace(self.workspace) + + def allreduce_fusion(self, input: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Call allreduce_fusion with the managed workspace. + + Args: + input: Input tensor + **kwargs: Additional arguments passed to allreduce_fusion() + + Returns: + Output tensor + """ + return allreduce_fusion(input=input, workspace=self.workspace, **kwargs) From 000271eb2d9892a59d2def51e54d282007c298f9 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Fri, 21 Nov 2025 11:33:21 -0800 Subject: [PATCH 19/32] Polished the interface --- flashinfer/comm/allreduce.py | 349 ++++++++++++++++++----------------- 1 file changed, 181 insertions(+), 168 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index a8f6498c0c..e30ca03606 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -54,10 +54,24 @@ import torch from ..utils import backend_requirement, supported_compute_capability +from .trtllm_ar import trtllm_allreduce_fusion +from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion +from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion # ============================================================================ -# WORKSPACE BASE CLASS +# WORKSPACE BASE CLASS AND IMPLEMENTATIONS +# ============================================================================ +# +# Workspace classes wrap the underlying backend workspace implementations: +# - TRTLLMAllReduceFusionWorkspace: Wraps trtllm_create_ipc_workspace_for_all_reduce_fusion +# - MNNVLAllReduceFusionWorkspace: Wraps MNNVL workspace (to be implemented) +# +# Each workspace: +# 1. Calls the backend-specific workspace creation function in __init__ +# 2. Stores the internal workspace as _internal_workspace +# 3. Exposes essential attributes for the unified API +# 4. Can be destroyed using destroy_allreduce_fusion_workspace() # ============================================================================ @@ -67,6 +81,7 @@ class AllReduceFusionWorkspace(ABC): def __init__(self, world_size: int, rank: int): self.world_size = world_size self.rank = rank + self._destroyed = False @property @abstractmethod @@ -74,14 +89,105 @@ def backend(self) -> str: """Return backend name.""" pass + @abstractmethod + def destroy(self) -> None: + """ + Destroy workspace and free resources. + + This should be called explicitly when done using the workspace. + Prefer using AllReduceFusionContext context manager for automatic cleanup. + """ + pass + + def __del__(self): + """ + Destructor - safety net if destroy() wasn't called explicitly. + + Warns if cleanup wasn't done properly. Not recommended to rely on this + as __del__ timing is non-deterministic and can cause issues with + distributed/CUDA resources. + """ + if not self._destroyed: + import warnings + + warnings.warn( + f"{self.__class__.__name__} was not explicitly destroyed. " + f"Call workspace.destroy() or use AllReduceFusionContext to ensure " + f"proper cleanup of distributed/CUDA resources.", + ResourceWarning, + stacklevel=2, + ) + try: + self.destroy() + except Exception as e: + # Can't raise in __del__, just warn + warnings.warn( + f"Error during automatic cleanup of {self.__class__.__name__}: {e}", + ResourceWarning, + stacklevel=2, + ) + class TRTLLMAllReduceFusionWorkspace(AllReduceFusionWorkspace): """TensorRT-LLM workspace for AllReduce fusion.""" - def __init__(self, world_size: int, rank: int, workspace_ptrs, metadata): - super().__init__(world_size, rank) - self.workspace_ptrs = workspace_ptrs - self.metadata = metadata + def __init__( + self, + tp_size: int, + tp_rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + device: torch.device, + process_group: Optional["torch.distributed.ProcessGroup"] = None, + **kwargs, + ): + """ + Create TensorRT-LLM AllReduce fusion workspace. + + Args: + tp_size: Tensor parallel size (world size) + tp_rank: Tensor parallel rank + max_token_num: Maximum number of tokens + hidden_dim: Hidden dimension size + dtype: Data type + device: CUDA device + process_group: PyTorch distributed process group + **kwargs: Additional arguments for workspace creation + """ + super().__init__(tp_size, tp_rank) + + # Call the actual workspace creation function + self._internal_workspace = trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_size=tp_size, + tp_rank=tp_rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + device=device, + process_group=process_group, + **kwargs, + ) + + # Store essential attributes for easy access + self.workspace_ptrs = self._internal_workspace.workspace_ptrs + self.metadata = self._internal_workspace.metadata + + def __getattr__(self, name): + """Delegate attribute access to internal workspace if not found.""" + if name.startswith("_"): + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + return getattr(self._internal_workspace, name) + + def destroy(self) -> None: + """Destroy workspace and free resources.""" + if self._destroyed: + return # Already destroyed, nothing to do + + trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self._internal_workspace) + self._destroyed = True @property def backend(self) -> str: @@ -95,18 +201,64 @@ def __init__( self, world_size: int, rank: int, - multicast_buffer_ptr: int, - buffer_ptrs_dev: int, - unicast_ptr: int, - buffer_M: int, - buffer_flags, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + device: torch.device, + **kwargs, ): + """ + Create MNNVL AllReduce fusion workspace. + + Args: + world_size: Number of ranks + rank: Current rank + max_token_num: Maximum number of tokens + hidden_dim: Hidden dimension size + dtype: Data type + device: CUDA device + **kwargs: Additional arguments for workspace creation + """ super().__init__(world_size, rank) - self.multicast_buffer_ptr = multicast_buffer_ptr - self.buffer_ptrs_dev = buffer_ptrs_dev - self.unicast_ptr = unicast_ptr - self.buffer_M = buffer_M - self.buffer_flags = buffer_flags + + # TODO: Import and call the actual MNNVL workspace creation function + # For now, raise NotImplementedError + raise NotImplementedError( + "MNNVL workspace creation needs to be implemented in trtllm_mnnvl_ar.py. " + "Expected function: create_mnnvl_allreduce_fusion_workspace" + ) + + # When implemented, should look like: + # from .trtllm_mnnvl_ar import create_mnnvl_allreduce_fusion_workspace + # + # self._internal_workspace = create_mnnvl_allreduce_fusion_workspace( + # world_size=world_size, + # rank=rank, + # max_token_num=max_token_num, + # hidden_dim=hidden_dim, + # dtype=dtype, + # device=device, + # **kwargs, + # ) + # + # # Store essential attributes for easy access + # self.multicast_buffer_ptr = self._internal_workspace.multicast_buffer_ptr + # self.buffer_ptrs_dev = self._internal_workspace.buffer_ptrs_dev + # self.unicast_ptr = self._internal_workspace.unicast_ptr + # self.buffer_M = self._internal_workspace.buffer_M + # self.buffer_flags = self._internal_workspace.buffer_flags + + def destroy(self) -> None: + """Destroy workspace and free resources.""" + if self._destroyed: + return # Already destroyed, nothing to do + + # TODO: Implement MNNVL workspace destruction + self._destroyed = True + raise NotImplementedError("MNNVL workspace destruction not yet implemented") + # from .trtllm_mnnvl_ar import destroy_mnnvl_allreduce_fusion_workspace + # destroy_mnnvl_allreduce_fusion_workspace(self._internal_workspace) + # self._destroyed = True @property def backend(self) -> str: @@ -337,11 +489,9 @@ def create_allreduce_fusion_workspace( else: actual_backend = backend - # Create workspace for selected backend + # Create workspace for selected backend using workspace constructors if actual_backend == "trtllm": - from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion - - workspace = trtllm_create_ipc_workspace_for_all_reduce_fusion( + return TRTLLMAllReduceFusionWorkspace( tp_size=world_size, tp_rank=rank, max_token_num=max_token_num, @@ -351,30 +501,17 @@ def create_allreduce_fusion_workspace( process_group=process_group, **backend_kwargs, ) - # Ensure workspace has required attributes for our API - if not hasattr(workspace, "world_size"): - workspace.world_size = world_size - if not hasattr(workspace, "rank"): - workspace.rank = rank - return workspace elif actual_backend == "mnnvl": - # TODO: Implement create_mnnvl_allreduce_fusion_workspace - # For now, raise NotImplementedError with instructions - raise NotImplementedError( - "MNNVL workspace creation needs to be implemented. " - "Expected function: trtllm_mnnvl_ar.create_mnnvl_allreduce_fusion_workspace" + return MNNVLAllReduceFusionWorkspace( + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + device=device, + **backend_kwargs, ) - # from .trtllm_mnnvl_ar import create_mnnvl_allreduce_fusion_workspace - # return create_mnnvl_allreduce_fusion_workspace( - # world_size=world_size, - # rank=rank, - # max_token_num=max_token_num, - # hidden_dim=hidden_dim, - # dtype=dtype, - # device=device, - # **backend_kwargs - # ) else: raise RuntimeError(f"Unknown backend: {actual_backend}") @@ -388,8 +525,7 @@ def destroy_allreduce_fusion_workspace(workspace: AllReduceFusionWorkspace) -> N """ Destroy workspace and free resources. - Automatically detects workspace type from the object and calls - appropriate cleanup function. + This is a convenience function that calls the workspace's destroy() method. Args: workspace: Workspace object to destroy @@ -398,18 +534,9 @@ def destroy_allreduce_fusion_workspace(workspace: AllReduceFusionWorkspace) -> N >>> workspace = create_allreduce_fusion_workspace(...) >>> # ... use workspace ... >>> destroy_allreduce_fusion_workspace(workspace) + >>> # Or call directly: workspace.destroy() """ - if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): - from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion - - trtllm_destroy_ipc_workspace_for_all_reduce_fusion(workspace) - elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): - # TODO: Implement MNNVL workspace destruction - raise NotImplementedError("MNNVL workspace destruction not yet implemented") - # from .trtllm_mnnvl_ar import destroy_mnnvl_allreduce_fusion_workspace - # destroy_mnnvl_allreduce_fusion_workspace(workspace) - else: - raise TypeError(f"Unknown workspace type: {type(workspace)}") + workspace.destroy() # ============================================================================ @@ -619,7 +746,6 @@ def _allreduce_fusion_trtllm( metadata: Optional[dict], ) -> torch.Tensor: """TensorRT-LLM backend implementation.""" - from .trtllm_ar import trtllm_allreduce_fusion token_num, hidden_dim = input.shape @@ -678,8 +804,6 @@ def _allreduce_fusion_mnnvl( 2. Add residual 3. RMSNorm """ - from .trtllm_mnnvl_ar import trtllm_mnnvl_fused_allreduce_rmsnorm - # Validate required parameters for RMS fusion if residual_in is None: raise ValueError("MNNVL AllReduce+RMS fusion requires residual_in") @@ -693,117 +817,6 @@ def _allreduce_fusion_mnnvl( raise ValueError("MNNVL AllReduce+RMS fusion requires rms_gamma") # Call the MNNVL fusion function - trtllm_mnnvl_fused_allreduce_rmsnorm( - prenorm_output=residual_out, - normed_output=norm_out, - shard_input=input, - multicast_buffer_ptr=workspace.multicast_buffer_ptr, - buffer_ptrs_dev=workspace.buffer_ptrs_dev, - unicast_ptr=workspace.unicast_ptr, - buffer_M=workspace.buffer_M, - buffer_flags_mnnvl=workspace.buffer_flags, - nranks=workspace.world_size, - rank=workspace.rank, - gamma=rms_gamma, - epsilon=rms_eps, - residual=residual_in, - launch_with_pdl=launch_with_pdl, - ) + raise NotImplementedError("MNNVL AllReduce+RMS fusion is not implemented") return norm_out - - -# ============================================================================ -# CONTEXT MANAGER -# ============================================================================ - - -class AllReduceFusionContext: - """ - Context manager with automatic workspace management. - - This provides a convenient high-level API that handles workspace - creation and cleanup automatically. - - Example: - >>> with AllReduceFusionContext( - ... backend="auto", - ... world_size=8, - ... rank=0, - ... max_token_num=2048, - ... hidden_dim=4096, - ... dtype=torch.bfloat16, - ... topology="single_node" - ... ) as ctx: - ... for batch in training_loop: - ... prenorm = torch.empty_like(batch.hidden_states) - ... normed = torch.empty_like(batch.hidden_states) - ... - ... output = ctx.allreduce_fusion( - ... input=batch.hidden_states, - ... residual_out=prenorm, - ... norm_out=normed, - ... residual_in=batch.residual, - ... rms_gamma=model.norm_weight, - ... launch_with_pdl=True - ... ) - >>> # Workspace automatically cleaned up - """ - - def __init__( - self, - backend: Literal["trtllm", "mnnvl", "auto"] = "auto", - world_size: int = None, - rank: int = None, - max_token_num: int = None, - hidden_dim: int = None, - dtype: torch.dtype = None, - device: Optional[torch.device] = None, - topology: str = "single_node", - **kwargs, - ): - """ - Initialize context manager. - - Args: - backend: Backend to use ("trtllm", "mnnvl", or "auto") - world_size: Number of ranks - rank: Current rank - max_token_num: Maximum tokens to support - hidden_dim: Hidden dimension - dtype: Data type - device: CUDA device - topology: Network topology ("single_node" or "multi_node") - **kwargs: Additional backend-specific arguments - """ - # Workspace creation does all the selection logic via decorator - self.workspace = create_allreduce_fusion_workspace( - backend=backend, - world_size=world_size, - rank=rank, - max_token_num=max_token_num, - hidden_dim=hidden_dim, - dtype=dtype, - device=device, - topology=topology, - **kwargs, - ) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - destroy_allreduce_fusion_workspace(self.workspace) - - def allreduce_fusion(self, input: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Call allreduce_fusion with the managed workspace. - - Args: - input: Input tensor - **kwargs: Additional arguments passed to allreduce_fusion() - - Returns: - Output tensor - """ - return allreduce_fusion(input=input, workspace=self.workspace, **kwargs) From 0141ae0ad3fd7a96b9db4598e513428ff7c12f34 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Fri, 21 Nov 2025 14:07:33 -0800 Subject: [PATCH 20/32] Removed device param --- flashinfer/comm/allreduce.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index e30ca03606..e7699c8d0b 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -138,7 +138,6 @@ def __init__( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - device: torch.device, process_group: Optional["torch.distributed.ProcessGroup"] = None, **kwargs, ): @@ -151,7 +150,6 @@ def __init__( max_token_num: Maximum number of tokens hidden_dim: Hidden dimension size dtype: Data type - device: CUDA device process_group: PyTorch distributed process group **kwargs: Additional arguments for workspace creation """ @@ -159,12 +157,10 @@ def __init__( # Call the actual workspace creation function self._internal_workspace = trtllm_create_ipc_workspace_for_all_reduce_fusion( - tp_size=tp_size, tp_rank=tp_rank, + tp_size=tp_size, max_token_num=max_token_num, hidden_dim=hidden_dim, - dtype=dtype, - device=device, process_group=process_group, **kwargs, ) @@ -204,7 +200,6 @@ def __init__( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - device: torch.device, **kwargs, ): """ @@ -216,7 +211,6 @@ def __init__( max_token_num: Maximum number of tokens hidden_dim: Hidden dimension size dtype: Data type - device: CUDA device **kwargs: Additional arguments for workspace creation """ super().__init__(world_size, rank) @@ -229,15 +223,12 @@ def __init__( ) # When implemented, should look like: - # from .trtllm_mnnvl_ar import create_mnnvl_allreduce_fusion_workspace - # # self._internal_workspace = create_mnnvl_allreduce_fusion_workspace( # world_size=world_size, # rank=rank, # max_token_num=max_token_num, # hidden_dim=hidden_dim, # dtype=dtype, - # device=device, # **kwargs, # ) # @@ -278,7 +269,6 @@ def _trtllm_workspace_check( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - device: Optional[torch.device], topology: str, **kwargs, ) -> bool: @@ -305,7 +295,6 @@ def _mnnvl_workspace_check( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - device: Optional[torch.device], topology: str, **kwargs, ) -> bool: @@ -337,7 +326,6 @@ def _workspace_creation_heuristic( max_token_num: int, hidden_dim: int, dtype: torch.dtype, - device: Optional[torch.device], topology: str, **kwargs, ) -> list[str]: @@ -355,7 +343,6 @@ def _workspace_creation_heuristic( max_token_num: Maximum number of tokens hidden_dim: Hidden dimension size dtype: Data type - device: CUDA device topology: Network topology ("single_node" or "multi_node") **kwargs: Additional arguments @@ -417,7 +404,6 @@ def create_allreduce_fusion_workspace( max_token_num: int = None, hidden_dim: int = None, dtype: torch.dtype = None, - device: Optional[torch.device] = None, topology: str = "single_node", process_group: Optional["torch.distributed.ProcessGroup"] = None, **backend_kwargs, @@ -436,7 +422,6 @@ def create_allreduce_fusion_workspace( max_token_num: Maximum number of tokens to support hidden_dim: Hidden dimension size dtype: Data type for communication tensors - device: CUDA device (defaults to current CUDA device) topology: Network topology hint for backend selection "single_node" - All ranks on one node (default) "multi_node" - Ranks span multiple nodes @@ -476,9 +461,6 @@ def create_allreduce_fusion_workspace( ... ) >>> print(workspace.backend) # "mnnvl" """ - if device is None: - device = torch.device(f"cuda:{torch.cuda.current_device()}") - # Decorator has validated backend - now create workspace # If backend="auto", decorator has selected the best one and stored it @@ -496,8 +478,6 @@ def create_allreduce_fusion_workspace( tp_rank=rank, max_token_num=max_token_num, hidden_dim=hidden_dim, - dtype=dtype, - device=device, process_group=process_group, **backend_kwargs, ) @@ -509,7 +489,6 @@ def create_allreduce_fusion_workspace( max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, - device=device, **backend_kwargs, ) else: @@ -580,7 +559,7 @@ def allreduce_fusion( Args: input: Input tensor [token_num, hidden_dim] workspace: Workspace object (type determines backend) - launch_with_pdl: Use Persistent Device Launch + launch_with_pdl: Use Persistent Dependency Launch # ===== OUTPUT tensors (pre-allocated, filled by function) ===== output: AllReduce output [token_num, hidden_dim] From 4ea74f3e9a6dacf31ba7d1e7feb91c57abfdafa6 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Fri, 21 Nov 2025 15:06:53 -0800 Subject: [PATCH 21/32] Updated test with legacy vs unified API --- flashinfer/comm/__init__.py | 1 - flashinfer/comm/allreduce.py | 40 ++-- tests/comm/test_trtllm_allreduce_fusion.py | 236 +++++++++++++++------ 3 files changed, 195 insertions(+), 82 deletions(-) diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index b0e9dfd0a4..3ffad2505d 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -40,7 +40,6 @@ from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers # Unified AllReduce Fusion API -from .allreduce import AllReduceFusionContext as AllReduceFusionContext from .allreduce import AllReduceFusionWorkspace as AllReduceFusionWorkspace from .allreduce import MNNVLAllReduceFusionWorkspace as MNNVLAllReduceFusionWorkspace from .allreduce import TRTLLMAllReduceFusionWorkspace as TRTLLMAllReduceFusionWorkspace diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index e7699c8d0b..693d052723 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -137,7 +137,6 @@ def __init__( tp_rank: int, max_token_num: int, hidden_dim: int, - dtype: torch.dtype, process_group: Optional["torch.distributed.ProcessGroup"] = None, **kwargs, ): @@ -161,13 +160,14 @@ def __init__( tp_size=tp_size, max_token_num=max_token_num, hidden_dim=hidden_dim, - process_group=process_group, + group=process_group, **kwargs, ) # Store essential attributes for easy access - self.workspace_ptrs = self._internal_workspace.workspace_ptrs - self.metadata = self._internal_workspace.metadata + self.ipc_handles = self._internal_workspace[0] + self.workspace_tensor = self._internal_workspace[1] + self.metadata = self._internal_workspace[2] def __getattr__(self, name): """Delegate attribute access to internal workspace if not found.""" @@ -726,37 +726,49 @@ def _allreduce_fusion_trtllm( ) -> torch.Tensor: """TensorRT-LLM backend implementation.""" + # Extract shape from 2D input token_num, hidden_dim = input.shape + # Allocate output if needed (keep 2D shape) if output is None: output = torch.empty_like(input) + # Flatten all tensors to 1D for legacy trtllm_allreduce_fusion API + # The legacy API expects flattened tensors and explicit token_num/hidden_dim + input_flat = input.flatten() + output_flat = output.flatten() + residual_in_flat = residual_in.flatten() if residual_in is not None else None + residual_out_flat = residual_out.flatten() if residual_out is not None else None + norm_out_flat = norm_out.flatten() if norm_out is not None else None + quant_out_flat = quant_out.flatten() if quant_out is not None else None + + # Call legacy API with flattened tensors trtllm_allreduce_fusion( - allreduce_in=input, + allreduce_in=input_flat, world_size=workspace.world_size, world_rank=workspace.rank, token_num=token_num, hidden_dim=hidden_dim, - workspace_ptrs=workspace.workspace_ptrs, + workspace_ptrs=workspace.workspace_tensor, launch_with_pdl=launch_with_pdl, trigger_completion_at_end=launch_with_pdl, # Same meaning fp32_acc=fp32_acc, pattern_code=pattern, use_oneshot=use_oneshot, - allreduce_out=output, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, + allreduce_out=output_flat, + residual_in=residual_in_flat, + residual_out=residual_out_flat, + norm_out=norm_out_flat, + quant_out=quant_out_flat, + scale_out=scale_out, # scale_out is not reshaped + rms_gamma=rms_gamma, # 1D tensor, no reshape needed rms_eps=rms_eps, scale_factor=scale_factor, layout_code=layout_code, metadata=metadata, ) - # Return the most downstream output + # Return the most downstream output (already in 2D shape from input views) if norm_out is not None: return norm_out elif quant_out is not None: diff --git a/tests/comm/test_trtllm_allreduce_fusion.py b/tests/comm/test_trtllm_allreduce_fusion.py index c3aa8c8252..17d9d7c2d4 100644 --- a/tests/comm/test_trtllm_allreduce_fusion.py +++ b/tests/comm/test_trtllm_allreduce_fusion.py @@ -22,7 +22,9 @@ SCALE_FACTOR_RANGE = (-1, 1) -def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_init_port): +def _run_correctness_worker( + world_size, rank, dtype, hidden_dim, distributed_init_port, legacy_api=True +): device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) distributed_init_method = f"tcp://localhost:{distributed_init_port}" @@ -57,18 +59,37 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini lamport_use_fp32 = dtype == torch.float32 - # create workspace for allreduce fusion with metadata - ipc_handles, workspace_tensor, workspace_metadata = ( - comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( - rank, - world_size, - MAX_TOKEN_NUM, - hidden_dim, - group=group, + # Create workspace - choose between legacy and new API + if legacy_api: + # Legacy API: create workspace for allreduce fusion with metadata + ipc_handles, workspace_tensor, workspace_metadata = ( + comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, + world_size, + MAX_TOKEN_NUM, + hidden_dim, + group=group, + use_fp32_lamport=lamport_use_fp32, + create_metadata=True, # Get metadata for validation + ) + ) + else: + workspace = None + # New unified API: create workspace + workspace = comm.create_allreduce_fusion_workspace( + backend="trtllm", + world_size=world_size, + rank=rank, + max_token_num=MAX_TOKEN_NUM, + hidden_dim=hidden_dim, + dtype=dtype, + topology="single_node", + process_group=group, use_fp32_lamport=lamport_use_fp32, - create_metadata=True, # Get metadata for validation + create_metadata=True, ) - ) + # Extract metadata for compatibility with tests + workspace_metadata = workspace.metadata test_loop = 5 @@ -163,60 +184,130 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(test_loop): - comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - world_size=world_size, - world_rank=rank, - token_num=token_num, - hidden_dim=hidden_dim, - workspace_ptrs=workspace_tensor, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=all_reduce_out, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - scale_factor=scale_factor, - layout_code=swizzled_layout_code, - metadata=workspace_metadata, - ) + if legacy_api: + # Legacy API - uses flattened tensors + comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + world_size=world_size, + world_rank=rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace_tensor, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=all_reduce_out, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + metadata=workspace_metadata, + ) + else: + # New unified API - expects 2D tensors [token_num, hidden_dim] + comm.allreduce_fusion( + input=allreduce_in.view( + token_num, hidden_dim + ), + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=all_reduce_out.view( + token_num, hidden_dim + ), + residual_in=residual_in.view( + token_num, hidden_dim + ), + residual_out=residual_out.view( + token_num, hidden_dim + ), + norm_out=norm_out.view( + token_num, hidden_dim + ), + quant_out=quant_out.view( + token_num, hidden_dim + ), + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + pattern=pattern_code, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, + metadata=workspace_metadata, + ) # NOTE: in real case, you dont have to set all optional params. You could set those required by fusion pattern. # capture g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): for _ in range(test_loop): - comm.trtllm_allreduce_fusion( - allreduce_in=allreduce_in, - world_size=world_size, - world_rank=rank, - token_num=token_num, - hidden_dim=hidden_dim, - workspace_ptrs=workspace_tensor, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - trigger_completion_at_end=trigger_completion_at_end, - fp32_acc=fp32_acc, - pattern_code=pattern_code, - allreduce_out=all_reduce_out, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - scale_factor=scale_factor, - layout_code=swizzled_layout_code, - metadata=workspace_metadata, - ) + if legacy_api: + # Legacy API - uses flattened tensors + comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + world_size=world_size, + world_rank=rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace_tensor, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=all_reduce_out, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + metadata=workspace_metadata, + ) + else: + # New unified API - expects 2D tensors [token_num, hidden_dim] + comm.allreduce_fusion( + input=allreduce_in.view( + token_num, hidden_dim + ), + workspace=workspace, + launch_with_pdl=launch_with_pdl, + output=all_reduce_out.view( + token_num, hidden_dim + ), + residual_in=residual_in.view( + token_num, hidden_dim + ), + residual_out=residual_out.view( + token_num, hidden_dim + ), + norm_out=norm_out.view( + token_num, hidden_dim + ), + quant_out=quant_out.view( + token_num, hidden_dim + ), + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=swizzled_layout_code, + pattern=pattern_code, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, + metadata=workspace_metadata, + ) # replay g.replay() torch.cuda.synchronize() @@ -307,9 +398,14 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini finally: dist.barrier(group=group) - comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion( - ipc_handles, group=group - ) + # Destroy workspace - choose between legacy and new API + if legacy_api: + comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion( + ipc_handles, group=group + ) + elif workspace is not None: + # New unified API + workspace.destroy() dist.destroy_process_group(group=group) @@ -358,7 +454,8 @@ def multi_process_parallel( @pytest.mark.parametrize("world_size", [2, 4, 8]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_dim", [1024, 2048, 4096, 7168, 8192]) -def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim): +@pytest.mark.parametrize("legacy_api", [True, False]) +def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim, legacy_api): np.random.seed(42) torch.manual_seed(42) torch.cuda.manual_seed_all(42) @@ -367,17 +464,22 @@ def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim): pytest.skip( f"world_size {world_size} is greater than available_gpus {available_gpus}" ) - print(f"Running test for world_size={world_size}") + api_str = "legacy" if legacy_api else "unified" + print(f"Running test for world_size={world_size} with {api_str} API") multi_process_parallel( world_size, dtype, hidden_dim, _run_correctness_worker, - target_args=(), + target_args=(legacy_api,), ) - print(f"allreduce fusion tp = {world_size}: OK") + print(f"allreduce fusion tp = {world_size} ({api_str} API): OK") if __name__ == "__main__": - test_trtllm_allreduce_fusion(2, torch.float16, 1024) + # Test both legacy and unified APIs + print("Testing legacy API...") + test_trtllm_allreduce_fusion(2, torch.float16, 1024, legacy_api=True) + print("\nTesting unified API...") + test_trtllm_allreduce_fusion(2, torch.float16, 1024, legacy_api=False) From 2e2fb25b5761186486174a6e99223bfa5eb8efbd Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 24 Nov 2025 09:18:46 -0800 Subject: [PATCH 22/32] Fixed unit test --- flashinfer/comm/allreduce.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 693d052723..89e5edfebb 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -182,7 +182,7 @@ def destroy(self) -> None: if self._destroyed: return # Already destroyed, nothing to do - trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self._internal_workspace) + trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles) self._destroyed = True @property From fe8b88c11b9c5eff5c7fffe582aa53ab77f6807f Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 24 Nov 2025 12:51:38 -0800 Subject: [PATCH 23/32] Relaxed check on trtllm_ar --- flashinfer/comm/trtllm_ar.py | 92 +++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 33bb7ac97b..e0b4369f12 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -804,6 +804,51 @@ def _should_use_oneshot( return comm_size_mb <= _use_oneshot_heuristics[world_size] +def _check_workspace_metadata( + token_num: int, + hidden_dim: int, + world_size: int, + dtype: torch.dtype, + metadata: dict, +) -> None: + errors = [] + required_keys = ["max_token_num", "tp_size", "hidden_dim", "use_fp32_lamport"] + for key in required_keys: + if key not in metadata: + errors.append(f"Workspace metadata is missing required key: {key}") + if errors: + error_msg = "Workspace metadata validation failed:\n" + "\n".join( + f" - {e}" for e in errors + ) + raise ValueError(error_msg) + + # world_size must match tp_size (flag size depends on it) + if world_size != metadata["tp_size"]: + errors.append( + f"world_size ({world_size}) does not match workspace tp_size ({metadata['tp_size']}). " + f"Workspace was created for tp_size={metadata['tp_size']}." + ) + + # token_num * hidden_dim must not exceed max_token_num * hidden_dim + if token_num * hidden_dim > metadata["max_token_num"] * metadata["hidden_dim"]: + errors.append( + f"token_num ({token_num}) * hidden_dim ({hidden_dim}) exceeds workspace max_token_num ({metadata['max_token_num']}) * hidden_dim ({metadata['hidden_dim']}). " + f"This may cause Illegal Memory Access." + ) + + # use_fp32_lamport must match + if metadata["use_fp32_lamport"] != (dtype == torch.float32): + errors.append( + f"use_fp32_lamport ({metadata['use_fp32_lamport']}) does not match allreduce_in.dtype ({dtype}). " + f"Workspace was created for use_fp32_lamport={metadata['use_fp32_lamport']}." + ) + if errors: + error_msg = "Workspace validation failed:\n" + "\n".join( + f" - {e}" for e in errors + ) + raise ValueError(error_msg) + + def trtllm_allreduce_fusion( allreduce_in: torch.Tensor, world_size: int, @@ -858,50 +903,9 @@ def trtllm_allreduce_fusion( # Validate against workspace metadata if provided if metadata is not None: - errors = [] - required_keys = ["max_token_num", "tp_size", "hidden_dim", "use_fp32_lamport"] - for key in required_keys: - if key not in metadata: - errors.append(f"Workspace metadata is missing required key: {key}") - if errors: - error_msg = "Workspace metadata validation failed:\n" + "\n".join( - f" - {e}" for e in errors - ) - raise ValueError(error_msg) - - # Check 1: token_num must not exceed max_token_num - if token_num > metadata["max_token_num"]: - errors.append( - f"token_num ({token_num}) exceeds workspace max_token_num ({metadata['max_token_num']}). " - f"This may cause Illegal Memory Access." - ) - - # Check 2: world_size must match tp_size - if world_size != metadata["tp_size"]: - errors.append( - f"world_size ({world_size}) does not match workspace tp_size ({metadata['tp_size']}). " - f"Workspace was created for tp_size={metadata['tp_size']}." - ) - - # Check 3: hidden_dim must match - if hidden_dim != metadata["hidden_dim"]: - errors.append( - f"hidden_dim ({hidden_dim}) does not match workspace hidden_dim ({metadata['hidden_dim']}). " - f"Workspace was created for hidden_dim={metadata['hidden_dim']}." - ) - - # Check 4: use_fp32_lamport must match - if metadata["use_fp32_lamport"] != (allreduce_in.dtype == torch.float32): - errors.append( - f"use_fp32_lamport ({metadata['use_fp32_lamport']}) does not match allreduce_in.dtype ({allreduce_in.dtype}). " - f"Workspace was created for use_fp32_lamport={metadata['use_fp32_lamport']}." - ) - - if errors: - error_msg = "Workspace validation failed:\n" + "\n".join( - f" - {e}" for e in errors - ) - raise ValueError(error_msg) + _check_workspace_metadata( + token_num, hidden_dim, world_size, allreduce_in.dtype, metadata + ) if use_oneshot is None: use_oneshot = _should_use_oneshot( From b3b19a54cffda358e37db895db13257cda7c3cda Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 24 Nov 2025 15:01:21 -0800 Subject: [PATCH 24/32] Made metadata mandatory in unified API, added workspace check functions --- flashinfer/comm/allreduce.py | 77 +++++++++++++++++----- flashinfer/comm/trtllm_ar.py | 4 +- tests/comm/test_trtllm_allreduce_fusion.py | 1 - 3 files changed, 62 insertions(+), 20 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 89e5edfebb..1227dc4b06 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -48,7 +48,7 @@ >>> destroy_allreduce_fusion_workspace(workspace) """ -from typing import Union, Literal, Optional +from typing import Union, Literal, Optional, Tuple, List, cast from abc import ABC, abstractmethod import torch @@ -57,6 +57,11 @@ from .trtllm_ar import trtllm_allreduce_fusion from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion +from .trtllm_ar import check_trtllm_allreduce_fusion_workspace_metadata + +# Note: AllReduceFusionPattern and QuantizationSFLayout are pseudo-types (classes with int constants) +# Import them for runtime use but type hint as int for mypy compatibility +from .trtllm_ar import AllReduceFusionPattern # ============================================================================ @@ -161,13 +166,22 @@ def __init__( max_token_num=max_token_num, hidden_dim=hidden_dim, group=process_group, + create_metadata=True, **kwargs, ) # Store essential attributes for easy access - self.ipc_handles = self._internal_workspace[0] - self.workspace_tensor = self._internal_workspace[1] - self.metadata = self._internal_workspace[2] + # Cast to 3-tuple to make linter happy, since we always call with create_metadata=True + workspace_tuple = cast( + Tuple[List[List[int]], torch.Tensor, dict], self._internal_workspace + ) + self.ipc_handles = workspace_tuple[0] + self.workspace_tensor = workspace_tuple[1] + self.metadata = workspace_tuple[2] + + @property + def backend(self) -> str: + return "trtllm" def __getattr__(self, name): """Delegate attribute access to internal workspace if not found.""" @@ -177,6 +191,18 @@ def __getattr__(self, name): ) return getattr(self._internal_workspace, name) + def is_sufficient_for( + self, token_num: int, hidden_dim: int, tp_size: int, dtype: torch.dtype + ) -> bool: + try: + check_trtllm_allreduce_fusion_workspace_metadata( + token_num, hidden_dim, tp_size, dtype, self.metadata + ) + return True + except ValueError as e: + print(f"Workspace is insufficient for problem size. {e}") + return False + def destroy(self) -> None: """Destroy workspace and free resources.""" if self._destroyed: @@ -185,10 +211,6 @@ def destroy(self) -> None: trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles) self._destroyed = True - @property - def backend(self) -> str: - return "trtllm" - class MNNVLAllReduceFusionWorkspace(AllReduceFusionWorkspace): """MNNVL workspace for AllReduce fusion.""" @@ -214,7 +236,6 @@ def __init__( **kwargs: Additional arguments for workspace creation """ super().__init__(world_size, rank) - # TODO: Import and call the actual MNNVL workspace creation function # For now, raise NotImplementedError raise NotImplementedError( @@ -239,6 +260,10 @@ def __init__( # self.buffer_M = self._internal_workspace.buffer_M # self.buffer_flags = self._internal_workspace.buffer_flags + @property + def backend(self) -> str: + return "mnnvl" + def destroy(self) -> None: """Destroy workspace and free resources.""" if self._destroyed: @@ -251,10 +276,6 @@ def destroy(self) -> None: # destroy_mnnvl_allreduce_fusion_workspace(self._internal_workspace) # self._destroyed = True - @property - def backend(self) -> str: - return "mnnvl" - # ============================================================================ # BACKEND CHECKS - Hard requirements for decorator @@ -413,6 +434,19 @@ def create_allreduce_fusion_workspace( Backend selection (checks + heuristics) handled by @backend_requirement decorator. + **Important: Workspace Reusability** + The workspace is allocated based on the total size (max_token_num * hidden_dim * dtype_size). + You can reuse the same workspace with different shapes as long as the total size fits: + + - Workspace(max_token_num=2048, hidden_dim=4096) can handle: + - (token_num=2048, hidden_dim=4096) ✓ + - (token_num=1024, hidden_dim=4096) ✓ + - (token_num=4096, hidden_dim=2048) ✓ (same total size) + - (token_num=1024, hidden_dim=8192) ✓ (same total size) + - (token_num=4096, hidden_dim=4096) ✗ (too large) + + Use `workspace.is_sufficient_for(token_num, hidden_dim, dtype)` to check before use. + Args: backend: Backend to use ("trtllm", "mnnvl", or "auto") "auto" uses heuristic to select best backend based on topology @@ -448,6 +482,11 @@ def create_allreduce_fusion_workspace( ... topology="single_node" ... ) >>> print(workspace.backend) # "trtllm" + >>> print(workspace.get_workspace_capacity()) # 8388608 elements + + >>> # Check if workspace can handle different problem sizes + >>> workspace.is_sufficient_for(1024, 4096, 8, torch.bfloat16) # True + >>> workspace.is_sufficient_for(4096, 2048, 8, torch.bfloat16) # True (same total) >>> # Explicit backend selection >>> workspace = create_allreduce_fusion_workspace( @@ -556,6 +595,10 @@ def allreduce_fusion( - AllReduce + Residual + RMSNorm - AllReduce + Residual + RMSNorm + Quantization (FP8/FP4) + **Note on Workspace Reusability:** + You can reuse the same workspace with different (token_num, hidden_dim) combinations + as long as `workspace.is_sufficient_for(token_num, hidden_dim, tp_size, dtype)` returns True. + Args: input: Input tensor [token_num, hidden_dim] workspace: Workspace object (type determines backend) @@ -685,9 +728,8 @@ def _infer_fusion_pattern( """ Automatically infer fusion pattern from provided tensors. - Returns AllReduceFusionPattern value based on which output tensors are provided. + Returns AllReduceFusionPattern value (as int) based on which output tensors are provided. """ - from .trtllm_ar import AllReduceFusionPattern if quant_out is not None: # Quantization patterns @@ -743,6 +785,7 @@ def _allreduce_fusion_trtllm( quant_out_flat = quant_out.flatten() if quant_out is not None else None # Call legacy API with flattened tensors + # Note: pattern and layout_code are ints but legacy API uses pseudo-type hints trtllm_allreduce_fusion( allreduce_in=input_flat, world_size=workspace.world_size, @@ -753,7 +796,7 @@ def _allreduce_fusion_trtllm( launch_with_pdl=launch_with_pdl, trigger_completion_at_end=launch_with_pdl, # Same meaning fp32_acc=fp32_acc, - pattern_code=pattern, + pattern_code=pattern, # type: ignore[arg-type] use_oneshot=use_oneshot, allreduce_out=output_flat, residual_in=residual_in_flat, @@ -764,7 +807,7 @@ def _allreduce_fusion_trtllm( rms_gamma=rms_gamma, # 1D tensor, no reshape needed rms_eps=rms_eps, scale_factor=scale_factor, - layout_code=layout_code, + layout_code=layout_code, # type: ignore[arg-type] metadata=metadata, ) diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index e0b4369f12..87246f739a 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -804,7 +804,7 @@ def _should_use_oneshot( return comm_size_mb <= _use_oneshot_heuristics[world_size] -def _check_workspace_metadata( +def check_trtllm_allreduce_fusion_workspace_metadata( token_num: int, hidden_dim: int, world_size: int, @@ -903,7 +903,7 @@ def trtllm_allreduce_fusion( # Validate against workspace metadata if provided if metadata is not None: - _check_workspace_metadata( + check_trtllm_allreduce_fusion_workspace_metadata( token_num, hidden_dim, world_size, allreduce_in.dtype, metadata ) diff --git a/tests/comm/test_trtllm_allreduce_fusion.py b/tests/comm/test_trtllm_allreduce_fusion.py index 17d9d7c2d4..31ddc1518b 100644 --- a/tests/comm/test_trtllm_allreduce_fusion.py +++ b/tests/comm/test_trtllm_allreduce_fusion.py @@ -86,7 +86,6 @@ def _run_correctness_worker( topology="single_node", process_group=group, use_fp32_lamport=lamport_use_fp32, - create_metadata=True, ) # Extract metadata for compatibility with tests workspace_metadata = workspace.metadata From ca83f12805f242f08fb46e743d4d77431eb295a7 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 24 Nov 2025 15:12:22 -0800 Subject: [PATCH 25/32] Merged dtype and use_fp32_lamport params --- flashinfer/comm/allreduce.py | 8 ++++---- tests/comm/test_trtllm_allreduce_fusion.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 1227dc4b06..70a307187e 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -142,8 +142,8 @@ def __init__( tp_rank: int, max_token_num: int, hidden_dim: int, + dtype: torch.dtype = torch.float16, process_group: Optional["torch.distributed.ProcessGroup"] = None, - **kwargs, ): """ Create TensorRT-LLM AllReduce fusion workspace. @@ -167,7 +167,7 @@ def __init__( hidden_dim=hidden_dim, group=process_group, create_metadata=True, - **kwargs, + use_fp32_lamport=dtype == torch.float32, ) # Store essential attributes for easy access @@ -427,7 +427,7 @@ def create_allreduce_fusion_workspace( dtype: torch.dtype = None, topology: str = "single_node", process_group: Optional["torch.distributed.ProcessGroup"] = None, - **backend_kwargs, + **backend_kwargs, # TODO(nvmbreughe): remove this ) -> AllReduceFusionWorkspace: """ Create workspace for AllReduce fusion operations. @@ -517,8 +517,8 @@ def create_allreduce_fusion_workspace( tp_rank=rank, max_token_num=max_token_num, hidden_dim=hidden_dim, + dtype=dtype, process_group=process_group, - **backend_kwargs, ) elif actual_backend == "mnnvl": diff --git a/tests/comm/test_trtllm_allreduce_fusion.py b/tests/comm/test_trtllm_allreduce_fusion.py index 31ddc1518b..601bddbb91 100644 --- a/tests/comm/test_trtllm_allreduce_fusion.py +++ b/tests/comm/test_trtllm_allreduce_fusion.py @@ -85,7 +85,6 @@ def _run_correctness_worker( dtype=dtype, topology="single_node", process_group=group, - use_fp32_lamport=lamport_use_fp32, ) # Extract metadata for compatibility with tests workspace_metadata = workspace.metadata From b0314749ff897a19e9dad427b555c0aeda2649c0 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 1 Dec 2025 13:29:02 -0800 Subject: [PATCH 26/32] removed useless function --- flashinfer/comm/__init__.py | 3 -- flashinfer/comm/allreduce.py | 53 ++++++++++-------------------------- 2 files changed, 15 insertions(+), 41 deletions(-) diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index 3ffad2505d..8a1d6f9f0c 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -47,8 +47,5 @@ from .allreduce import ( create_allreduce_fusion_workspace as create_allreduce_fusion_workspace, ) -from .allreduce import ( - destroy_allreduce_fusion_workspace as destroy_allreduce_fusion_workspace, -) # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 70a307187e..54f1b17c8b 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -32,20 +32,20 @@ ... topology="single_node" ... ) >>> - >>> # Perform AllReduce + RMSNorm fusion - >>> prenorm = torch.empty_like(hidden_states) - >>> normed = torch.empty_like(hidden_states) - >>> output = allreduce_fusion( - ... input=hidden_states, - ... workspace=workspace, - ... launch_with_pdl=True, - ... residual_out=prenorm, - ... norm_out=normed, - ... residual_in=residual, - ... rms_gamma=norm_weight - ... ) - >>> - >>> destroy_allreduce_fusion_workspace(workspace) + >>> # Perform AllReduce + RMSNorm fusion + >>> prenorm = torch.empty_like(hidden_states) + >>> normed = torch.empty_like(hidden_states) + >>> output = allreduce_fusion( + ... input=hidden_states, + ... workspace=workspace, + ... launch_with_pdl=True, + ... residual_out=prenorm, + ... norm_out=normed, + ... residual_in=residual, + ... rms_gamma=norm_weight + ... ) + >>> + >>> workspace.destroy() """ from typing import Union, Literal, Optional, Tuple, List, cast @@ -76,7 +76,7 @@ # 1. Calls the backend-specific workspace creation function in __init__ # 2. Stores the internal workspace as _internal_workspace # 3. Exposes essential attributes for the unified API -# 4. Can be destroyed using destroy_allreduce_fusion_workspace() +# 4. Can be destroyed using workspace.destroy() # ============================================================================ @@ -534,29 +534,6 @@ def create_allreduce_fusion_workspace( raise RuntimeError(f"Unknown backend: {actual_backend}") -# ============================================================================ -# WORKSPACE DESTRUCTION -# ============================================================================ - - -def destroy_allreduce_fusion_workspace(workspace: AllReduceFusionWorkspace) -> None: - """ - Destroy workspace and free resources. - - This is a convenience function that calls the workspace's destroy() method. - - Args: - workspace: Workspace object to destroy - - Example: - >>> workspace = create_allreduce_fusion_workspace(...) - >>> # ... use workspace ... - >>> destroy_allreduce_fusion_workspace(workspace) - >>> # Or call directly: workspace.destroy() - """ - workspace.destroy() - - # ============================================================================ # MAIN API - NO backend parameter, infers from workspace type # ============================================================================ From 0ee3fd6f06d734d1b4b2e2f9351bd50243b1ffdb Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 1 Dec 2025 13:41:53 -0800 Subject: [PATCH 27/32] Moved in the helper functions, rejected some patterns for mnnvl --- flashinfer/comm/allreduce.py | 250 ++++++++++++----------------------- 1 file changed, 84 insertions(+), 166 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 54f1b17c8b..af8ae3433a 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -651,183 +651,101 @@ def allreduce_fusion( """ # Auto-detect pattern if not provided if pattern is None: - pattern = _infer_fusion_pattern( - output, residual_in, residual_out, norm_out, quant_out, scale_out - ) + if quant_out is not None: + # Quantization patterns + if norm_out is not None and residual_out is not None: + pattern = AllReduceFusionPattern.kARResidualRMSNormOutFP8Quant # 4 + else: + pattern = AllReduceFusionPattern.kARResidualRMSNormFP8Quant # 2 + elif norm_out is not None: + pattern = AllReduceFusionPattern.kARResidualRMSNorm # 1 + else: + pattern = AllReduceFusionPattern.kAllReduce # 0 - # Infer backend from workspace type and dispatch + # Dispatch based on workspace type if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): - return _allreduce_fusion_trtllm( - input=input, - workspace=workspace, + # TensorRT-LLM backend implementation + # Extract shape from 2D input + token_num, hidden_dim = input.shape + + # Allocate output if needed (keep 2D shape) + if output is None: + output = torch.empty_like(input) + + # Flatten all tensors to 1D for legacy trtllm_allreduce_fusion API + # The legacy API expects flattened tensors and explicit token_num/hidden_dim + input_flat = input.flatten() + output_flat = output.flatten() + residual_in_flat = residual_in.flatten() if residual_in is not None else None + residual_out_flat = residual_out.flatten() if residual_out is not None else None + norm_out_flat = norm_out.flatten() if norm_out is not None else None + quant_out_flat = quant_out.flatten() if quant_out is not None else None + + # Call legacy API with flattened tensors + # Note: pattern and layout_code are ints but legacy API uses pseudo-type hints + trtllm_allreduce_fusion( + allreduce_in=input_flat, + world_size=workspace.world_size, + world_rank=workspace.rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace.workspace_tensor, launch_with_pdl=launch_with_pdl, - output=output, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - quant_out=quant_out, - scale_out=scale_out, - rms_gamma=rms_gamma, + trigger_completion_at_end=launch_with_pdl, # Same meaning + fp32_acc=fp32_acc, + pattern_code=pattern, # type: ignore[arg-type] + use_oneshot=use_oneshot, + allreduce_out=output_flat, + residual_in=residual_in_flat, + residual_out=residual_out_flat, + norm_out=norm_out_flat, + quant_out=quant_out_flat, + scale_out=scale_out, # scale_out is not reshaped + rms_gamma=rms_gamma, # 1D tensor, no reshape needed rms_eps=rms_eps, scale_factor=scale_factor, - layout_code=layout_code, - pattern=pattern, - use_oneshot=use_oneshot, - fp32_acc=fp32_acc, + layout_code=layout_code, # type: ignore[arg-type] metadata=metadata, ) - elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): - return _allreduce_fusion_mnnvl( - input=input, - workspace=workspace, - launch_with_pdl=launch_with_pdl, - residual_in=residual_in, - residual_out=residual_out, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - ) - else: - raise TypeError( - f"Unknown workspace type: {type(workspace)}. " - f"Expected TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace" - ) - - -# ============================================================================ -# HELPER FUNCTIONS -# ============================================================================ - - -def _infer_fusion_pattern( - output, residual_in, residual_out, norm_out, quant_out, scale_out -) -> int: - """ - Automatically infer fusion pattern from provided tensors. - Returns AllReduceFusionPattern value (as int) based on which output tensors are provided. - """ - - if quant_out is not None: - # Quantization patterns - if norm_out is not None and residual_out is not None: - # Has separate norm output and residual output - return AllReduceFusionPattern.kARResidualRMSNormOutFP8Quant # 4 + # Return the most downstream output (already in 2D shape from input views) + if norm_out is not None: + return norm_out + elif quant_out is not None: + return quant_out else: - # Quant without separate outputs - return AllReduceFusionPattern.kARResidualRMSNormFP8Quant # 2 - elif norm_out is not None: - # RMS Norm without quantization - return AllReduceFusionPattern.kARResidualRMSNorm # 1 - else: - # Just AllReduce - return AllReduceFusionPattern.kAllReduce # 0 + return output + elif isinstance(workspace, MNNVLAllReduceFusionWorkspace): + if ( + pattern != AllReduceFusionPattern.kARResidualRMSNorm + and pattern != AllReduceFusionPattern.kAllReduce + ): + raise ValueError( + f"MNNVL AllReduce+RMS fusion does not support pattern {pattern}" + ) + # MNNVL backend implementation + # Validate required parameters for RMS fusion + if residual_in is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires residual_in") + if residual_out is None: + raise ValueError( + "MNNVL AllReduce+RMS fusion requires residual_out (prenorm_output)" + ) + if norm_out is None: + raise ValueError( + "MNNVL AllReduce+RMS fusion requires norm_out (normed_output)" + ) + if rms_gamma is None: + raise ValueError("MNNVL AllReduce+RMS fusion requires rms_gamma") -def _allreduce_fusion_trtllm( - input: torch.Tensor, - workspace: TRTLLMAllReduceFusionWorkspace, - launch_with_pdl: bool, - output: Optional[torch.Tensor], - residual_in: Optional[torch.Tensor], - residual_out: Optional[torch.Tensor], - norm_out: Optional[torch.Tensor], - quant_out: Optional[torch.Tensor], - scale_out: Optional[torch.Tensor], - rms_gamma: Optional[torch.Tensor], - rms_eps: float, - scale_factor: Optional[Union[torch.Tensor, float]], - layout_code: Optional[int], - pattern: int, - use_oneshot: Optional[bool], - fp32_acc: bool, - metadata: Optional[dict], -) -> torch.Tensor: - """TensorRT-LLM backend implementation.""" - - # Extract shape from 2D input - token_num, hidden_dim = input.shape - - # Allocate output if needed (keep 2D shape) - if output is None: - output = torch.empty_like(input) - - # Flatten all tensors to 1D for legacy trtllm_allreduce_fusion API - # The legacy API expects flattened tensors and explicit token_num/hidden_dim - input_flat = input.flatten() - output_flat = output.flatten() - residual_in_flat = residual_in.flatten() if residual_in is not None else None - residual_out_flat = residual_out.flatten() if residual_out is not None else None - norm_out_flat = norm_out.flatten() if norm_out is not None else None - quant_out_flat = quant_out.flatten() if quant_out is not None else None - - # Call legacy API with flattened tensors - # Note: pattern and layout_code are ints but legacy API uses pseudo-type hints - trtllm_allreduce_fusion( - allreduce_in=input_flat, - world_size=workspace.world_size, - world_rank=workspace.rank, - token_num=token_num, - hidden_dim=hidden_dim, - workspace_ptrs=workspace.workspace_tensor, - launch_with_pdl=launch_with_pdl, - trigger_completion_at_end=launch_with_pdl, # Same meaning - fp32_acc=fp32_acc, - pattern_code=pattern, # type: ignore[arg-type] - use_oneshot=use_oneshot, - allreduce_out=output_flat, - residual_in=residual_in_flat, - residual_out=residual_out_flat, - norm_out=norm_out_flat, - quant_out=quant_out_flat, - scale_out=scale_out, # scale_out is not reshaped - rms_gamma=rms_gamma, # 1D tensor, no reshape needed - rms_eps=rms_eps, - scale_factor=scale_factor, - layout_code=layout_code, # type: ignore[arg-type] - metadata=metadata, - ) - - # Return the most downstream output (already in 2D shape from input views) - if norm_out is not None: - return norm_out - elif quant_out is not None: - return quant_out - else: - return output - + # Call the MNNVL fusion function + raise NotImplementedError("MNNVL AllReduce+RMS fusion is not implemented") -def _allreduce_fusion_mnnvl( - input: torch.Tensor, - workspace: MNNVLAllReduceFusionWorkspace, - launch_with_pdl: bool, - residual_in: Optional[torch.Tensor], - residual_out: Optional[torch.Tensor], - norm_out: Optional[torch.Tensor], - rms_gamma: Optional[torch.Tensor], - rms_eps: float, -) -> torch.Tensor: - """ - MNNVL backend implementation. + return norm_out - Calls trtllm_mnnvl_fused_allreduce_rmsnorm which performs: - 1. AllReduce on input - 2. Add residual - 3. RMSNorm - """ - # Validate required parameters for RMS fusion - if residual_in is None: - raise ValueError("MNNVL AllReduce+RMS fusion requires residual_in") - if residual_out is None: - raise ValueError( - "MNNVL AllReduce+RMS fusion requires residual_out (prenorm_output)" + else: + raise TypeError( + f"Unknown workspace type: {type(workspace)}. " + f"Expected TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace" ) - if norm_out is None: - raise ValueError("MNNVL AllReduce+RMS fusion requires norm_out (normed_output)") - if rms_gamma is None: - raise ValueError("MNNVL AllReduce+RMS fusion requires rms_gamma") - - # Call the MNNVL fusion function - raise NotImplementedError("MNNVL AllReduce+RMS fusion is not implemented") - - return norm_out From 1c2a3420d2966641f0f3b86780a4fc0318c660af Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 1 Dec 2025 13:46:53 -0800 Subject: [PATCH 28/32] Made fusion pattern param mandatory --- flashinfer/comm/allreduce.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index af8ae3433a..5fb28753bf 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -542,6 +542,7 @@ def create_allreduce_fusion_workspace( def allreduce_fusion( input: torch.Tensor, workspace: AllReduceFusionWorkspace, + pattern: int, launch_with_pdl: bool = False, # ===== OUTPUT tensors (pre-allocated, will be filled) ===== output: Optional[torch.Tensor] = None, @@ -556,7 +557,6 @@ def allreduce_fusion( scale_factor: Optional[Union[torch.Tensor, float]] = None, layout_code: Optional[int] = None, # ===== Control parameters ===== - pattern: Optional[int] = None, use_oneshot: Optional[bool] = None, fp32_acc: bool = False, metadata: Optional[dict] = None, @@ -579,6 +579,14 @@ def allreduce_fusion( Args: input: Input tensor [token_num, hidden_dim] workspace: Workspace object (type determines backend) + pattern: Fusion pattern (AllReduceFusionPattern constant, 0-5) + - kAllReduce = 0 + - kARResidualRMSNorm = 1 + - kARResidualRMSNormFP8Quant = 2 + - kARResidualRMSNormFP4Quant = 3 + - kARResidualRMSNormOutFP8Quant = 4 + - kARResidualRMSNormOutFP4Quant = 5 + Note: MNNVL only supports patterns 0 and 1 launch_with_pdl: Use Persistent Dependency Launch # ===== OUTPUT tensors (pre-allocated, filled by function) ===== @@ -596,8 +604,6 @@ def allreduce_fusion( layout_code: Scale factor layout (QuantizationSFLayout) [trtllm only] # ===== Control parameters ===== - pattern: Fusion pattern (AllReduceFusionPattern) - If None, auto-detected based on provided output tensors use_oneshot: [trtllm only] Use oneshot strategy vs twoshot If None, uses internal heuristics fp32_acc: [trtllm only] Use FP32 accumulation for AllReduce @@ -626,6 +632,7 @@ def allreduce_fusion( >>> output = allreduce_fusion( ... input=hidden_states, ... workspace=workspace, + ... pattern=AllReduceFusionPattern.kARResidualRMSNorm, ... launch_with_pdl=True, ... residual_out=prenorm, ... norm_out=normed, @@ -641,6 +648,7 @@ def allreduce_fusion( >>> output = allreduce_fusion( ... input=hidden_states, ... workspace=workspace, + ... pattern=AllReduceFusionPattern.kARResidualRMSNormFP8Quant, ... norm_out=normed, ... quant_out=quant, ... scale_out=scales, @@ -649,19 +657,6 @@ def allreduce_fusion( ... scale_factor=scale_tensor ... ) """ - # Auto-detect pattern if not provided - if pattern is None: - if quant_out is not None: - # Quantization patterns - if norm_out is not None and residual_out is not None: - pattern = AllReduceFusionPattern.kARResidualRMSNormOutFP8Quant # 4 - else: - pattern = AllReduceFusionPattern.kARResidualRMSNormFP8Quant # 2 - elif norm_out is not None: - pattern = AllReduceFusionPattern.kARResidualRMSNorm # 1 - else: - pattern = AllReduceFusionPattern.kAllReduce # 0 - # Dispatch based on workspace type if isinstance(workspace, TRTLLMAllReduceFusionWorkspace): # TensorRT-LLM backend implementation From abb06bc55e8d0d0de41c80082d84c2a47e1a2a49 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 8 Dec 2025 13:50:33 -0800 Subject: [PATCH 29/32] Removed backend_kwargs and changed one_shot/two_shot --- flashinfer/comm/allreduce.py | 9 ++++----- tests/comm/test_trtllm_allreduce_fusion.py | 4 ---- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 5fb28753bf..3bd7069a58 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -559,7 +559,6 @@ def allreduce_fusion( # ===== Control parameters ===== use_oneshot: Optional[bool] = None, fp32_acc: bool = False, - metadata: Optional[dict] = None, ) -> torch.Tensor: """ AllReduce + RMSNorm fusion operation. @@ -604,10 +603,10 @@ def allreduce_fusion( layout_code: Scale factor layout (QuantizationSFLayout) [trtllm only] # ===== Control parameters ===== - use_oneshot: [trtllm only] Use oneshot strategy vs twoshot - If None, uses internal heuristics + use_oneshot: Use oneshot strategy vs twoshot + If None, uses internal heuristics. + Note that the MNNVL backend needs to be initialized with a sufficiently large workspace if one_shot is used. fp32_acc: [trtllm only] Use FP32 accumulation for AllReduce - metadata: [trtllm only] Workspace metadata for validation Returns: Output tensor (typically norm_out for fusion cases, output otherwise) @@ -700,7 +699,7 @@ def allreduce_fusion( rms_eps=rms_eps, scale_factor=scale_factor, layout_code=layout_code, # type: ignore[arg-type] - metadata=metadata, + metadata=workspace.metadata, ) # Return the most downstream output (already in 2D shape from input views) diff --git a/tests/comm/test_trtllm_allreduce_fusion.py b/tests/comm/test_trtllm_allreduce_fusion.py index 601bddbb91..dab4877fb9 100644 --- a/tests/comm/test_trtllm_allreduce_fusion.py +++ b/tests/comm/test_trtllm_allreduce_fusion.py @@ -86,8 +86,6 @@ def _run_correctness_worker( topology="single_node", process_group=group, ) - # Extract metadata for compatibility with tests - workspace_metadata = workspace.metadata test_loop = 5 @@ -239,7 +237,6 @@ def _run_correctness_worker( pattern=pattern_code, use_oneshot=use_oneshot, fp32_acc=fp32_acc, - metadata=workspace_metadata, ) # NOTE: in real case, you dont have to set all optional params. You could set those required by fusion pattern. @@ -304,7 +301,6 @@ def _run_correctness_worker( pattern=pattern_code, use_oneshot=use_oneshot, fp32_acc=fp32_acc, - metadata=workspace_metadata, ) # replay g.replay() From c2a311bc22ded4cb12c0c7d3f3970cd11fbe28be Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 8 Dec 2025 14:02:19 -0800 Subject: [PATCH 30/32] Ensured that we can flattend the I/O tensors. --- flashinfer/comm/allreduce.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 3bd7069a58..18a62ea83e 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -668,12 +668,31 @@ def allreduce_fusion( # Flatten all tensors to 1D for legacy trtllm_allreduce_fusion API # The legacy API expects flattened tensors and explicit token_num/hidden_dim - input_flat = input.flatten() - output_flat = output.flatten() - residual_in_flat = residual_in.flatten() if residual_in is not None else None - residual_out_flat = residual_out.flatten() if residual_out is not None else None - norm_out_flat = norm_out.flatten() if norm_out is not None else None - quant_out_flat = quant_out.flatten() if quant_out is not None else None + # We require contiguous tensors so that view(-1) creates a view (not a copy), + # ensuring writes to the flattened tensors are reflected in the original 2D tensors + def _flatten_checked(t, name): + if not t.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + return t.view(-1) + + input_flat = _flatten_checked(input, "input") + output_flat = _flatten_checked(output, "output") + residual_in_flat = ( + _flatten_checked(residual_in, "residual_in") + if residual_in is not None + else None + ) + residual_out_flat = ( + _flatten_checked(residual_out, "residual_out") + if residual_out is not None + else None + ) + norm_out_flat = ( + _flatten_checked(norm_out, "norm_out") if norm_out is not None else None + ) + quant_out_flat = ( + _flatten_checked(quant_out, "quant_out") if quant_out is not None else None + ) # Call legacy API with flattened tensors # Note: pattern and layout_code are ints but legacy API uses pseudo-type hints From 686db76f7a1a224ef592347162e88132892b8b4e Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Tue, 9 Dec 2025 11:24:01 -0800 Subject: [PATCH 31/32] Moved out the workspace base class, refactored for mnnvl --- flashinfer/comm/__init__.py | 4 +- flashinfer/comm/allreduce.py | 162 ++++------------------ flashinfer/comm/trtllm_mnnvl_ar.py | 28 +++- flashinfer/comm/workspace_base.py | 84 +++++++++++ tests/comm/test_trtllm_mnnvl_allreduce.py | 4 +- 5 files changed, 140 insertions(+), 142 deletions(-) create mode 100644 flashinfer/comm/workspace_base.py diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index 8a1d6f9f0c..6d945980be 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -41,7 +41,9 @@ # Unified AllReduce Fusion API from .allreduce import AllReduceFusionWorkspace as AllReduceFusionWorkspace -from .allreduce import MNNVLAllReduceFusionWorkspace as MNNVLAllReduceFusionWorkspace +from .trtllm_mnnvl_ar import ( + MNNVLAllReduceFusionWorkspace as MNNVLAllReduceFusionWorkspace, +) from .allreduce import TRTLLMAllReduceFusionWorkspace as TRTLLMAllReduceFusionWorkspace from .allreduce import allreduce_fusion as allreduce_fusion from .allreduce import ( diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index 18a62ea83e..b6e02de652 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -48,8 +48,8 @@ >>> workspace.destroy() """ -from typing import Union, Literal, Optional, Tuple, List, cast -from abc import ABC, abstractmethod +from typing import Union, Literal, Optional, Tuple, List, cast, Any +from .workspace_base import AllReduceFusionWorkspace import torch @@ -59,18 +59,22 @@ from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion from .trtllm_ar import check_trtllm_allreduce_fusion_workspace_metadata +from .mapping import Mapping + +from .mnnvl import CommBackend + # Note: AllReduceFusionPattern and QuantizationSFLayout are pseudo-types (classes with int constants) # Import them for runtime use but type hint as int for mypy compatibility from .trtllm_ar import AllReduceFusionPattern - +from .trtllm_mnnvl_ar import MNNVLAllReduceFusionWorkspace # ============================================================================ -# WORKSPACE BASE CLASS AND IMPLEMENTATIONS +# WORKSPACE IMPLEMENTATIONS # ============================================================================ # # Workspace classes wrap the underlying backend workspace implementations: # - TRTLLMAllReduceFusionWorkspace: Wraps trtllm_create_ipc_workspace_for_all_reduce_fusion -# - MNNVLAllReduceFusionWorkspace: Wraps MNNVL workspace (to be implemented) +# - MNNVLAllReduceFusionWorkspace: Wraps MNNVL workspace (see trtllm_mnnvl_ar.py) # # Each workspace: # 1. Calls the backend-specific workspace creation function in __init__ @@ -80,59 +84,6 @@ # ============================================================================ -class AllReduceFusionWorkspace(ABC): - """Base class for AllReduce fusion workspaces.""" - - def __init__(self, world_size: int, rank: int): - self.world_size = world_size - self.rank = rank - self._destroyed = False - - @property - @abstractmethod - def backend(self) -> str: - """Return backend name.""" - pass - - @abstractmethod - def destroy(self) -> None: - """ - Destroy workspace and free resources. - - This should be called explicitly when done using the workspace. - Prefer using AllReduceFusionContext context manager for automatic cleanup. - """ - pass - - def __del__(self): - """ - Destructor - safety net if destroy() wasn't called explicitly. - - Warns if cleanup wasn't done properly. Not recommended to rely on this - as __del__ timing is non-deterministic and can cause issues with - distributed/CUDA resources. - """ - if not self._destroyed: - import warnings - - warnings.warn( - f"{self.__class__.__name__} was not explicitly destroyed. " - f"Call workspace.destroy() or use AllReduceFusionContext to ensure " - f"proper cleanup of distributed/CUDA resources.", - ResourceWarning, - stacklevel=2, - ) - try: - self.destroy() - except Exception as e: - # Can't raise in __del__, just warn - warnings.warn( - f"Error during automatic cleanup of {self.__class__.__name__}: {e}", - ResourceWarning, - stacklevel=2, - ) - - class TRTLLMAllReduceFusionWorkspace(AllReduceFusionWorkspace): """TensorRT-LLM workspace for AllReduce fusion.""" @@ -191,12 +142,17 @@ def __getattr__(self, name): ) return getattr(self._internal_workspace, name) - def is_sufficient_for( - self, token_num: int, hidden_dim: int, tp_size: int, dtype: torch.dtype + def is_buffer_size_sufficient( + self, + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + use_oneshot: Optional[Any] = None, ) -> bool: try: check_trtllm_allreduce_fusion_workspace_metadata( - token_num, hidden_dim, tp_size, dtype, self.metadata + num_tokens, hidden_dim, tp_size, dtype, self.metadata ) return True except ValueError as e: @@ -212,71 +168,6 @@ def destroy(self) -> None: self._destroyed = True -class MNNVLAllReduceFusionWorkspace(AllReduceFusionWorkspace): - """MNNVL workspace for AllReduce fusion.""" - - def __init__( - self, - world_size: int, - rank: int, - max_token_num: int, - hidden_dim: int, - dtype: torch.dtype, - **kwargs, - ): - """ - Create MNNVL AllReduce fusion workspace. - - Args: - world_size: Number of ranks - rank: Current rank - max_token_num: Maximum number of tokens - hidden_dim: Hidden dimension size - dtype: Data type - **kwargs: Additional arguments for workspace creation - """ - super().__init__(world_size, rank) - # TODO: Import and call the actual MNNVL workspace creation function - # For now, raise NotImplementedError - raise NotImplementedError( - "MNNVL workspace creation needs to be implemented in trtllm_mnnvl_ar.py. " - "Expected function: create_mnnvl_allreduce_fusion_workspace" - ) - - # When implemented, should look like: - # self._internal_workspace = create_mnnvl_allreduce_fusion_workspace( - # world_size=world_size, - # rank=rank, - # max_token_num=max_token_num, - # hidden_dim=hidden_dim, - # dtype=dtype, - # **kwargs, - # ) - # - # # Store essential attributes for easy access - # self.multicast_buffer_ptr = self._internal_workspace.multicast_buffer_ptr - # self.buffer_ptrs_dev = self._internal_workspace.buffer_ptrs_dev - # self.unicast_ptr = self._internal_workspace.unicast_ptr - # self.buffer_M = self._internal_workspace.buffer_M - # self.buffer_flags = self._internal_workspace.buffer_flags - - @property - def backend(self) -> str: - return "mnnvl" - - def destroy(self) -> None: - """Destroy workspace and free resources.""" - if self._destroyed: - return # Already destroyed, nothing to do - - # TODO: Implement MNNVL workspace destruction - self._destroyed = True - raise NotImplementedError("MNNVL workspace destruction not yet implemented") - # from .trtllm_mnnvl_ar import destroy_mnnvl_allreduce_fusion_workspace - # destroy_mnnvl_allreduce_fusion_workspace(self._internal_workspace) - # self._destroyed = True - - # ============================================================================ # BACKEND CHECKS - Hard requirements for decorator # ============================================================================ @@ -427,7 +318,8 @@ def create_allreduce_fusion_workspace( dtype: torch.dtype = None, topology: str = "single_node", process_group: Optional["torch.distributed.ProcessGroup"] = None, - **backend_kwargs, # TODO(nvmbreughe): remove this + gpus_per_node: int = None, + comm_backend: Optional[CommBackend] = None, ) -> AllReduceFusionWorkspace: """ Create workspace for AllReduce fusion operations. @@ -459,8 +351,9 @@ def create_allreduce_fusion_workspace( topology: Network topology hint for backend selection "single_node" - All ranks on one node (default) "multi_node" - Ranks span multiple nodes - process_group: PyTorch distributed process group - **backend_kwargs: Additional backend-specific arguments + process_group: PyTorch distributed process group (for trtllm backend). + gpus_per_node: Number of GPUs per node (for multi-node topology). + comm_backend: Communication backend to use (for multi-node topology). Returns: Workspace object (TRTLLMAllReduceFusionWorkspace or MNNVLAllReduceFusionWorkspace) @@ -522,13 +415,18 @@ def create_allreduce_fusion_workspace( ) elif actual_backend == "mnnvl": - return MNNVLAllReduceFusionWorkspace( + mapping = Mapping( world_size=world_size, rank=rank, - max_token_num=max_token_num, + gpus_per_node=gpus_per_node, + tp_size=world_size, + ) + return MNNVLAllReduceFusionWorkspace( + mapping=mapping, + max_num_tokens=max_token_num, hidden_dim=hidden_dim, dtype=dtype, - **backend_kwargs, + comm_backend=comm_backend, ) else: raise RuntimeError(f"Unknown backend: {actual_backend}") diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 308ee4bda6..827883c2f6 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -18,6 +18,7 @@ from ..jit import gen_trtllm_mnnvl_comm_module from ..utils import register_custom_op from .mnnvl import McastGPUBuffer, CommBackend, MPIBackend +from .workspace_base import AllReduceFusionWorkspace def mpi_barrier(): @@ -47,7 +48,7 @@ def select_strategy( MNNVL_ONE_SHOT_THRESHOLD = 64 * 1024 * 8 * 2 -class MNNVLAllreduceFusionWorkspace: +class MNNVLAllReduceFusionWorkspace(AllReduceFusionWorkspace): NUM_LAMPORT_BUFFERS = 3 def __init__( @@ -222,6 +223,19 @@ def get_required_buffer_size_bytes( ) return buffer_size + @property + def backend(self) -> str: + return "mnnvl" + + @property + def destroy(self) -> None: + """Destroy workspace and free resources.""" + if self._destroyed: + return # Already destroyed, nothing to do + + print("TODO: Implement this properly!") + self._destroyed = True + @functools.cache def get_trtllm_mnnvl_comm_module(): @@ -307,7 +321,7 @@ def trtllm_mnnvl_allreduce_fusion( def trtllm_mnnvl_allreduce( input: torch.Tensor, - workspace: MNNVLAllreduceFusionWorkspace, + workspace: MNNVLAllReduceFusionWorkspace, launch_with_pdl: bool, output: Optional[torch.Tensor] = None, strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO, @@ -327,7 +341,7 @@ def trtllm_mnnvl_allreduce( Args: input: Local Input Shard [num_tokens, hidden_dim] - workspace: MNNVLAllreduceFusionWorkspace + workspace: MNNVLAllReduceFusionWorkspace launch_with_pdl: Whether to launch with PDL output: Output tensor to store the result, empty tensor will be created if not provided. strategy: MNNVLAllreduceFusionStrategy. Internal heuristics will be used if not provided. @@ -387,7 +401,7 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( input: torch.Tensor, residual_in: torch.Tensor, gamma: torch.Tensor, - workspace: MNNVLAllreduceFusionWorkspace, + workspace: MNNVLAllReduceFusionWorkspace, epsilon: Optional[float] = None, output: Optional[torch.Tensor] = None, residual_out: Optional[torch.Tensor] = None, @@ -404,7 +418,7 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( input: Input tensor [num_tokens, hidden_dim] residual_in: Residual input tensor [num_tokens, hidden_dim] gamma: Gamma tensor [hidden_dim] - workspace: MNNVLAllreduceFusionWorkspace + workspace: MNNVLAllReduceFusionWorkspace epsilon: The epsilon parameter for RMSNorm, torch.finfo.eps will be used if not provided. output: Output tensor for normalized results [num_tokens, hidden_dim], empty tensor will be created if not provided. residual_out: Residual output tensor [num_tokens, hidden_dim], empty tensor will be created if not provided. @@ -479,7 +493,7 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm( # Legacy API that has been deprecated; Left for backward compatibility @deprecated( - "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllreduceFusionWorkspace class to manage the workspace instead" + "get_allreduce_mnnvl_workspace is deprecated, use MNNVLAllReduceFusionWorkspace class to manage the workspace instead" ) def get_allreduce_mnnvl_workspace( mapping: Mapping, @@ -522,7 +536,7 @@ def get_allreduce_mnnvl_workspace( ) * (lcm_hidden_dim * stride) # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. - workspace = MNNVLAllreduceFusionWorkspace( + workspace = MNNVLAllReduceFusionWorkspace( mapping, buffer_size_in_bytes=buffer_size_in_bytes, comm_backend=comm_backend_for_handle_transfer, diff --git a/flashinfer/comm/workspace_base.py b/flashinfer/comm/workspace_base.py new file mode 100644 index 0000000000..cfd8dbe72f --- /dev/null +++ b/flashinfer/comm/workspace_base.py @@ -0,0 +1,84 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from abc import ABC, abstractmethod +from typing import Optional, Any + +import torch + + +class AllReduceFusionWorkspace(ABC): + """Base class for AllReduce fusion workspaces.""" + + def __init__(self, world_size: int, rank: int): + self.world_size = world_size + self.rank = rank + self._destroyed = False + + @property + @abstractmethod + def backend(self) -> str: + """Return backend name.""" + pass + + @abstractmethod + def destroy(self) -> None: + """ + Destroy workspace and free resources. + + This should be called explicitly when done using the workspace. + Prefer using AllReduceFusionContext context manager for automatic cleanup. + """ + pass + + @abstractmethod + def is_buffer_size_sufficient( + self, + tp_size: int, + num_tokens: int, + hidden_dim: int, + dtype: torch.dtype, + use_oneshot: Optional[Any] = None, + ) -> bool: + pass + + def __del__(self): + """ + Destructor - safety net if destroy() wasn't called explicitly. + + Warns if cleanup wasn't done properly. Not recommended to rely on this + as __del__ timing is non-deterministic and can cause issues with + distributed/CUDA resources. + """ + if not self._destroyed: + import warnings + + warnings.warn( + f"{self.__class__.__name__} was not explicitly destroyed. " + f"Call workspace.destroy() or use AllReduceFusionContext to ensure " + f"proper cleanup of distributed/CUDA resources.", + ResourceWarning, + stacklevel=2, + ) + try: + self.destroy() + except Exception as e: + # Can't raise in __del__, just warn + warnings.warn( + f"Error during automatic cleanup of {self.__class__.__name__}: {e}", + ResourceWarning, + stacklevel=2, + ) diff --git a/tests/comm/test_trtllm_mnnvl_allreduce.py b/tests/comm/test_trtllm_mnnvl_allreduce.py index 78ce392b7a..ce7880e406 100644 --- a/tests/comm/test_trtllm_mnnvl_allreduce.py +++ b/tests/comm/test_trtllm_mnnvl_allreduce.py @@ -22,7 +22,7 @@ def row_linear_residual_norm_fusion_forward( mapping: Mapping, fusion: bool, reference_output: tuple[torch.Tensor, ...], - workspace: trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace, + workspace: trtllm_mnnvl_ar.MNNVLAllReduceFusionWorkspace, ): tensor_parallel_rank = mapping.tp_rank MPI.COMM_WORLD.barrier() @@ -341,7 +341,7 @@ def run_mnnvl_ar_full( ) else: - workspace = trtllm_mnnvl_ar.MNNVLAllreduceFusionWorkspace( + workspace = trtllm_mnnvl_ar.MNNVLAllReduceFusionWorkspace( mapping, max_num_tokens=max(seq_lens), hidden_dim=hidden_size, From 10554e5fd38f464b9e94793d9b12e39d6ceb6de2 Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Tue, 9 Dec 2025 15:08:38 -0800 Subject: [PATCH 32/32] Removed backend decorator as it is not appicable with workspace creation --- flashinfer/comm/allreduce.py | 114 ++++++++++++++++++------------ flashinfer/comm/workspace_base.py | 5 ++ 2 files changed, 73 insertions(+), 46 deletions(-) diff --git a/flashinfer/comm/allreduce.py b/flashinfer/comm/allreduce.py index b6e02de652..234058059d 100644 --- a/flashinfer/comm/allreduce.py +++ b/flashinfer/comm/allreduce.py @@ -53,7 +53,6 @@ import torch -from ..utils import backend_requirement, supported_compute_capability from .trtllm_ar import trtllm_allreduce_fusion from .trtllm_ar import trtllm_create_ipc_workspace_for_all_reduce_fusion from .trtllm_ar import trtllm_destroy_ipc_workspace_for_all_reduce_fusion @@ -161,7 +160,7 @@ def is_buffer_size_sufficient( def destroy(self) -> None: """Destroy workspace and free resources.""" - if self._destroyed: + if self._destroyed is True: return # Already destroyed, nothing to do trtllm_destroy_ipc_workspace_for_all_reduce_fusion(self.ipc_handles) @@ -169,11 +168,10 @@ def destroy(self) -> None: # ============================================================================ -# BACKEND CHECKS - Hard requirements for decorator +# BACKEND CHECKS - Hard requirements for backend selection # ============================================================================ -@supported_compute_capability([80, 86, 89, 90, 100]) def _trtllm_workspace_check( backend: str, world_size: int, @@ -188,9 +186,8 @@ def _trtllm_workspace_check( Check if trtllm backend CAN be used for workspace creation. Hard requirements: - - SM80+ compute capability (checked by decorator) - - Single-node topology - - Module availability + - Single-node topology (multi-node not supported) + """ # trtllm is optimized for single-node if topology == "multi_node": @@ -199,7 +196,6 @@ def _trtllm_workspace_check( return True -@supported_compute_capability([90, 100]) def _mnnvl_workspace_check( backend: str, world_size: int, @@ -213,20 +209,13 @@ def _mnnvl_workspace_check( """ Check if mnnvl backend CAN be used for workspace creation. - Hard requirements: - - SM90+ compute capability (checked by decorator) - - Multi-node topology - - Module availability """ - # MNNVL is designed for multi-node - if topology == "single_node": - return False return True # ============================================================================ -# HEURISTIC - Performance-based selection for decorator +# HEURISTIC - Performance-based backend selection # ============================================================================ @@ -239,6 +228,7 @@ def _workspace_creation_heuristic( hidden_dim: int, dtype: torch.dtype, topology: str, + # TODO(nvmbreughe): Remove this **kwargs, ) -> list[str]: """ @@ -276,39 +266,33 @@ def _workspace_creation_heuristic( return ["mnnvl"] # Single-node scenarios - problem_size = max_token_num * hidden_dim + return ["mnnvl"] + # problem_size = max_token_num * hidden_dim - # Large problems (>4M elements): trtllm optimized for throughput - if problem_size > 4 * 1024 * 1024: - if "trtllm" in suitable_backends: - return ["trtllm"] + # # Large problems (>4M elements): trtllm optimized for throughput + # if problem_size > 4 * 1024 * 1024: + # if "trtllm" in suitable_backends: + # return ["trtllm"] - # Small token counts (<128): trtllm one-shot has better latency - if max_token_num < 128: - if "trtllm" in suitable_backends: - return ["trtllm"] + # # Small token counts (<128): trtllm one-shot has better latency + # if max_token_num < 128: + # if "trtllm" in suitable_backends: + # return ["trtllm"] - # Small world sizes (<=4): trtllm one-shot efficient - if world_size <= 4: - if "trtllm" in suitable_backends: - return ["trtllm"] + # # Small world sizes (<=4): trtllm one-shot efficient + # if world_size <= 4: + # if "trtllm" in suitable_backends: + # return ["trtllm"] - # Default: return first available - return [suitable_backends[0]] + # # Default: return first available + # return [suitable_backends[0]] # ============================================================================ -# WORKSPACE CREATION - Uses decorator for all validation +# WORKSPACE CREATION # ============================================================================ -@backend_requirement( - backend_checks={ - "trtllm": _trtllm_workspace_check, - "mnnvl": _mnnvl_workspace_check, - }, - heuristic_func=_workspace_creation_heuristic, -) def create_allreduce_fusion_workspace( backend: Literal["trtllm", "mnnvl", "auto"] = "auto", world_size: int = None, @@ -324,7 +308,7 @@ def create_allreduce_fusion_workspace( """ Create workspace for AllReduce fusion operations. - Backend selection (checks + heuristics) handled by @backend_requirement decorator. + Backend selection uses topology-based checks and heuristics. **Important: Workspace Reusability** The workspace is allocated based on the total size (max_token_num * hidden_dim * dtype_size). @@ -393,13 +377,51 @@ def create_allreduce_fusion_workspace( ... ) >>> print(workspace.backend) # "mnnvl" """ - # Decorator has validated backend - now create workspace - # If backend="auto", decorator has selected the best one and stored it - - # Get actual backend (decorator resolved "auto" to concrete backend) + if gpus_per_node is None: + gpus_per_node = min(torch.cuda.device_count(), world_size) + # Determine the actual backend to use if backend == "auto": - # Decorator stored the selected backend in suitable_auto_backends - actual_backend = create_allreduce_fusion_workspace.suitable_auto_backends[0] + # Find suitable backends based on topology (anny CC check needs to be checked at kernel runtime, since there are no tensor available at this point) + suitable_backends = [] + if _trtllm_workspace_check( + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + topology=topology, + ): + suitable_backends.append("trtllm") + if _mnnvl_workspace_check( + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + topology=topology, + ): + suitable_backends.append("mnnvl") + + if not suitable_backends: + raise ValueError( + f"No suitable backend found for topology={topology}. " + f"trtllm requires single_node topology, mnnvl works with both." + ) + + # Apply heuristic to select best backend + selected = _workspace_creation_heuristic( + suitable_backends=suitable_backends, + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + topology=topology, + ) + actual_backend = selected[0] if selected else suitable_backends[0] else: actual_backend = backend diff --git a/flashinfer/comm/workspace_base.py b/flashinfer/comm/workspace_base.py index cfd8dbe72f..5de8d07483 100644 --- a/flashinfer/comm/workspace_base.py +++ b/flashinfer/comm/workspace_base.py @@ -23,6 +23,11 @@ class AllReduceFusionWorkspace(ABC): """Base class for AllReduce fusion workspaces.""" + # Explicit type annotations for mypy (needed due to __getattr__ in subclasses) + world_size: int + rank: int + _destroyed: bool + def __init__(self, world_size: int, rank: int): self.world_size = world_size self.rank = rank