Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
63b7683
Staging initial integration of kernel code.
timlee0212 Nov 18, 2025
1230273
Initial python interface. Need adjustment.
timlee0212 Nov 19, 2025
874c228
Refactor the interface.
timlee0212 Nov 19, 2025
17a1292
Staging changes, result is wrong.
timlee0212 Nov 20, 2025
4caf71a
Passing the test.
timlee0212 Nov 20, 2025
9a6beec
Remove debug prints and add compatibility interface.
timlee0212 Nov 20, 2025
a4d1a17
Incorporate 2056; Add test for legacy APIs
timlee0212 Nov 20, 2025
01564e9
Address review comments.
timlee0212 Nov 20, 2025
775918d
Address review comments.
timlee0212 Nov 20, 2025
45a5b82
Address review comments.
timlee0212 Nov 21, 2025
815aaf3
Rounding up workspace size according to allocation (page size).
timlee0212 Nov 21, 2025
68a9b9b
Fix rebasing errors.
timlee0212 Nov 26, 2025
9e11752
Fix rebase errors.
timlee0212 Nov 27, 2025
4a5faef
Refactor mcast device memory.
timlee0212 Nov 27, 2025
03700a2
Adapt the workspace creation API for unified backend.
timlee0212 Dec 4, 2025
ff84e87
Use threshold only for oneshot workspace size calculation.
timlee0212 Dec 4, 2025
1342a43
Merge branch 'main' into IKL-201
timlee0212 Dec 9, 2025
0f5b7b3
Document workspace creation behavior.
timlee0212 Dec 9, 2025
b9f4329
Added first non-working version
nvmbreughe Nov 21, 2025
000271e
Polished the interface
nvmbreughe Nov 21, 2025
0141ae0
Removed device param
nvmbreughe Nov 21, 2025
4ea74f3
Updated test with legacy vs unified API
nvmbreughe Nov 21, 2025
2e2fb25
Fixed unit test
nvmbreughe Nov 24, 2025
fe8b88c
Relaxed check on trtllm_ar
nvmbreughe Nov 24, 2025
b3b19a5
Made metadata mandatory in unified API, added workspace check functions
nvmbreughe Nov 24, 2025
ca83f12
Merged dtype and use_fp32_lamport params
nvmbreughe Nov 24, 2025
b031474
removed useless function
nvmbreughe Dec 1, 2025
0ee3fd6
Moved in the helper functions, rejected some patterns for mnnvl
nvmbreughe Dec 1, 2025
1c2a342
Made fusion pattern param mandatory
nvmbreughe Dec 1, 2025
abb06bc
Removed backend_kwargs and changed one_shot/two_shot
nvmbreughe Dec 8, 2025
c2a311b
Ensured that we can flatten the I/O tensors.
nvmbreughe Dec 8, 2025
686db76
Moved out the workspace base class, refactored for mnnvl
nvmbreughe Dec 9, 2025
10554e5
Removed backend decorator as it is not applicable with workspace creation
nvmbreughe Dec 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 69 additions & 56 deletions csrc/trtllm_mnnvl_allreduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,77 +26,90 @@ using tvm::ffi::Optional;
} \
}()

void trtllm_mnnvl_all_reduce(TensorView in, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev,
int64_t buffer_M, TensorView buffer_flags_mnnvl, int64_t nranks,
int64_t rank, bool wait_for_results, bool launch_with_pdl,
Optional<TensorView> out) {
ffi::CUDADeviceGuard device_guard(in.device().device_id);
auto stream = get_stream(in.device());
void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_ptr,
int64_t buffer_ptrs_dev, int64_t buffer_ptr_local,
TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank,
bool rmsnorm_fusion, bool launch_with_pdl, bool use_oneshot,
TensorView output, Optional<TensorView> residual_out,
Optional<TensorView> residual_in, Optional<TensorView> gamma,
Optional<double> epsilon) {
ffi::CUDADeviceGuard device_guard(input.device().device_id);
auto stream = get_stream(input.device());

DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in.dtype(), c_type, [&] {
DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(input.dtype(), c_type, [&] {
// Extract parameters from tensors
int64_t num_tokens = in.size(0);
int64_t token_dim = in.size(1);
int64_t num_tokens = input.size(0);
int64_t token_dim = input.size(1);

// Validate input parameters
TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float2) / sizeof(c_type)), 0)
<< "token_dim must be divisible by " << sizeof(float2) / sizeof(c_type);
TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float4) / sizeof(c_type)), 0)
<< "token_dim must be divisible by " << sizeof(float4) / sizeof(c_type);
TVM_FFI_ICHECK(output.size(0) == input.size(0) && output.size(1) == input.size(1))
<< "output shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
<< ") but got (" << output.size(0) << ", " << output.size(1) << ")";
TVM_FFI_ICHECK(nranks >= 2 && nranks <= 64)
<< "nranks must be between 2 and 64, got " << nranks;
TVM_FFI_ICHECK(rank >= 0 && rank < nranks)
<< "rank must be between 0 and nranks-1, got " << rank;
TVM_FFI_ICHECK(out.has_value() || !wait_for_results)
<< "out tensor must be provided if wait_for_results is true";
TVM_FFI_ICHECK((residual_in.has_value() && residual_out.has_value() && gamma.has_value() &&
epsilon.has_value()) ||
!rmsnorm_fusion)
<< "residual_in, residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is "
"true";

if (rmsnorm_fusion) {
TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens &&
residual_in.value().size(1) == token_dim)
<< "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
<< ") but got (" << residual_in.value().size(0) << ", " << residual_in.value().size(1)
<< ")";
TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens &&
residual_out.value().size(1) == token_dim)
<< "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
<< ") but got (" << residual_out.value().size(0) << ", " << residual_out.value().size(1)
<< ")";
TVM_FFI_ICHECK(gamma.value().size(0) == token_dim)
<< "gamma must have the same shape as token dimension (" << token_dim << ") but got ("
<< gamma.value().size(0) << ")";
}

// Create the parameters struct
AllReduceParams<c_type> params;
params.nranks = nranks;
params.rank = rank;
params.buffer_M = buffer_M;
params.num_tokens = num_tokens;
params.token_dim = token_dim;
params.buffer_ptrs_dev = reinterpret_cast<void**>(buffer_ptrs_dev);
params.multicast_ptr = reinterpret_cast<void*>(multicast_buffer_ptr);
params.buffer_flags = buffer_flags_mnnvl.data_ptr();
params.wait_for_results = wait_for_results;
params.launch_with_pdl = launch_with_pdl;
params.input = in.data_ptr();
params.output = out.has_value() ? out.value().data_ptr() : nullptr;
params.stream = stream;
AllReduceFusionParams params;

auto status = twoshot_allreduce_dispatch_world_size<c_type>(params);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "twoshot_allreduce_dispatch_world_size failed with error code "
<< cudaGetErrorString(status);
});
}
// Aux Information
params.nRanks = nranks;
params.rank = rank;
params.numTokens = num_tokens;
params.tokenDim = token_dim;
params.bufferPtrsDev = reinterpret_cast<void**>(buffer_ptrs_dev);
params.bufferPtrLocal = reinterpret_cast<void*>(buffer_ptr_local);
params.multicastPtr = reinterpret_cast<void*>(multicast_buffer_ptr);
params.bufferFlags = reinterpret_cast<uint32_t*>(buffer_flags_mnnvl.data_ptr());
params.rmsNormFusion = rmsnorm_fusion;
params.launchWithPdl = launch_with_pdl;

void trtllm_mnnvl_rmsnorm(int64_t multicast_buffer_ptr, TensorView prenorm_output,
TensorView normed_output, TensorView gamma, double epsilon,
TensorView residual, TensorView buffer_flags, bool launch_with_pdl) {
ffi::CUDADeviceGuard device_guard(prenorm_output.device().device_id);
auto stream = get_stream(prenorm_output.device());
// input data
params.input = const_cast<void const*>(input.data_ptr());
params.residualIn =
residual_in.has_value() ? const_cast<void const*>(residual_in.value().data_ptr()) : nullptr;
params.gamma = gamma.has_value() ? const_cast<void const*>(gamma.value().data_ptr()) : nullptr;
params.epsilon = epsilon.has_value() ? epsilon.value() : 1e-5;

DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(prenorm_output.dtype(), c_type, [&] {
// Create the parameters struct
RMSNormParams<c_type> params;
params.residual_output = prenorm_output.data_ptr();
params.output = normed_output.data_ptr();
params.input = reinterpret_cast<void const*>(multicast_buffer_ptr);
params.gamma = gamma.data_ptr();
params.epsilon = epsilon;
params.residual = residual.data_ptr();
params.buffer_flags = reinterpret_cast<uint32_t*>(buffer_flags.data_ptr());
params.batch = normed_output.size(0);
params.hidden_dim = normed_output.size(1);
// output data
params.output = const_cast<void*>(output.data_ptr());
params.residualOut =
residual_out.has_value() ? const_cast<void*>(residual_out.value().data_ptr()) : nullptr;
params.stream = stream;
params.launch_with_pdl = launch_with_pdl;
auto status = twoshot_rmsnorm_dispatch_hidden_dim<c_type>(params);

cudaError_t status;
if (use_oneshot) {
status = oneshotAllreduceFusionDispatch<c_type>(params);
} else {
status = twoshotAllreduceFusionDispatch<c_type>(params);
}
TVM_FFI_ICHECK(status == cudaSuccess)
<< "twoshot_rmsnorm_dispatch_hidden_dim failed with error code "
<< cudaGetErrorString(status);
<< "trtllm_mnnvl_allreduce_fusion failed with error code " << cudaGetErrorString(status);
});
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_all_reduce, trtllm_mnnvl_all_reduce);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_rmsnorm, trtllm_mnnvl_rmsnorm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mnnvl_allreduce_fusion, trtllm_mnnvl_allreduce_fusion);
11 changes: 11 additions & 0 deletions flashinfer/comm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,15 @@
from .vllm_ar import register_buffer as vllm_register_buffer
from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers

# Unified AllReduce Fusion API
from .allreduce import AllReduceFusionWorkspace as AllReduceFusionWorkspace
from .trtllm_mnnvl_ar import (
MNNVLAllReduceFusionWorkspace as MNNVLAllReduceFusionWorkspace,
)
from .allreduce import TRTLLMAllReduceFusionWorkspace as TRTLLMAllReduceFusionWorkspace
from .allreduce import allreduce_fusion as allreduce_fusion
from .allreduce import (
create_allreduce_fusion_workspace as create_allreduce_fusion_workspace,
)

# from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
Loading