From bde2e87c5e9de484a7b3792d41148a94a5d38877 Mon Sep 17 00:00:00 2001 From: Dongmin Ra Date: Wed, 1 Oct 2025 13:22:44 +0900 Subject: [PATCH 1/3] fix inconsistency in internode dispatch/combine --- .../dispatch_combine/test_dispatch_combine.py | 2 +- .../ops/dispatch_combine/dispatch_combine.hpp | 9 +- python/mori/ops/dispatch_combine.py | 8 +- src/ops/dispatch_combine/dispatch_combine.cpp | 6 +- src/ops/dispatch_combine/internode.hpp | 20 +- src/ops/dispatch_combine/intranode.hpp | 4 +- src/pybind/mori.cpp | 8 +- ...ispatch_combine_internode_inconsistency.py | 255 ++++++++++++++++++ 8 files changed, 286 insertions(+), 26 deletions(-) create mode 100644 tests/python/ops/test_dispatch_combine_internode_inconsistency.py diff --git a/examples/ops/dispatch_combine/test_dispatch_combine.py b/examples/ops/dispatch_combine/test_dispatch_combine.py index c444a799..8af67d57 100644 --- a/examples/ops/dispatch_combine/test_dispatch_combine.py +++ b/examples/ops/dispatch_combine/test_dispatch_combine.py @@ -244,7 +244,7 @@ def run_test_once(self, op, test_data): print("Dispatch Pass") total_recv_num_token = dispatch_recv_num_token[0].item() - combine_input = op.get_registered_input_buffer(self.config.data_type) + combine_input = op.get_registered_combine_input_buffer(self.config.data_type) combine_input[:total_recv_num_token, :].copy_( dispatch_output[:total_recv_num_token, :] ) diff --git a/include/mori/ops/dispatch_combine/dispatch_combine.hpp b/include/mori/ops/dispatch_combine/dispatch_combine.hpp index 3cdd773b..f19f415f 100644 --- a/include/mori/ops/dispatch_combine/dispatch_combine.hpp +++ b/include/mori/ops/dispatch_combine/dispatch_combine.hpp @@ -169,7 +169,8 @@ class EpDispatchCombineHandle { uint8_t* scalesBuf{nullptr}; // Registered buffers for tokens, shmemOutTokMemObj will be returned to user as output - mori::application::SymmMemObjPtr shmemInpTokMemObj; + mori::application::SymmMemObjPtr dispatchShmemInpTokMemObj; + mori::application::SymmMemObjPtr combineShmemInpTokMemObj; mori::application::SymmMemObjPtr shmemOutTokMemObj; mori::application::SymmMemObjPtr shmemStagingTokMemObj; @@ -227,7 +228,8 @@ struct EpDispatchCombineArgs { T* outTokenBuf{nullptr}; float* weightsBuf{nullptr}; uint8_t* scalesBuf{nullptr}; - mori::application::SymmMemObjPtr shmemInpTokMemObj; + mori::application::SymmMemObjPtr dispatchShmemInpTokMemObj; + mori::application::SymmMemObjPtr combineShmemInpTokMemObj; mori::application::SymmMemObjPtr shmemOutTokMemObj; mori::application::SymmMemObjPtr shmemStagingTokMemObj; mori::application::SymmMemObjPtr shmemInpWeightsMemObj; @@ -270,7 +272,8 @@ EpDispatchCombineArgs GetEpDispatchCombineArgs(const EpDispatchCombineHandle& args.scalesBuf = handle.scalesBuf; args.destPeTokenCounter = handle.destPeTokenCounter; args.localPeTokenCounter = handle.localPeTokenCounter; - args.shmemInpTokMemObj = handle.shmemInpTokMemObj; + args.dispatchShmemInpTokMemObj = handle.dispatchShmemInpTokMemObj; + args.combineShmemInpTokMemObj = handle.combineShmemInpTokMemObj; args.shmemOutTokMemObj = handle.shmemOutTokMemObj; args.shmemStagingTokMemObj = handle.shmemStagingTokMemObj; args.shmemInpWeightsMemObj = handle.shmemInpWeightsMemObj; diff --git a/python/mori/ops/dispatch_combine.py b/python/mori/ops/dispatch_combine.py index 32904ccb..77a2723a 100644 --- a/python/mori/ops/dispatch_combine.py +++ b/python/mori/ops/dispatch_combine.py @@ -90,12 +90,12 @@ def __init__(self, config): self._get_dispatch_receiver_token_idx_map_func = _cpp_dispatch_combine_factory( "get_dispatch_receiver_token_idx_map" ) - self._get_registered_input_buffer = _cpp_dispatch_combine_factory( - "get_registered_input_buffer" + self._get_registered_combine_input_buffer = _cpp_dispatch_combine_factory( + "get_registered_combine_input_buffer" ) - def get_registered_input_buffer(self, dtype: torch.dtype): - return self._get_registered_input_buffer(self._handle, dtype) + def get_registered_combine_input_buffer(self, dtype: torch.dtype): + return self._get_registered_combine_input_buffer(self._handle, dtype) def dispatch( self, diff --git a/src/ops/dispatch_combine/dispatch_combine.cpp b/src/ops/dispatch_combine/dispatch_combine.cpp index 7c2dbe00..465fa42b 100644 --- a/src/ops/dispatch_combine/dispatch_combine.cpp +++ b/src/ops/dispatch_combine/dispatch_combine.cpp @@ -70,7 +70,8 @@ void EpDispatchCombineHandle::InitializeShmemBuf() { (config.hiddenDim * config.maxTokenTypeSize + (sizeof(float) + sizeof(index_t)) * config.numExpertPerToken + config.scaleDim * config.scaleTypeSize); - shmemInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); + dispatchShmemInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); + combineShmemInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); shmemOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached); shmemStagingTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); @@ -90,7 +91,8 @@ void EpDispatchCombineHandle::InitializeShmemBuf() { } void EpDispatchCombineHandle::FinalizeShmemBuf() { - ShmemFree(shmemInpTokMemObj->localPtr); + ShmemFree(dispatchShmemInpTokMemObj->localPtr); + ShmemFree(combineShmemInpTokMemObj->localPtr); ShmemFree(shmemOutTokMemObj->localPtr); ShmemFree(shmemStagingTokMemObj->localPtr); ShmemFree(shmemInpWeightsMemObj->localPtr); diff --git a/src/ops/dispatch_combine/internode.hpp b/src/ops/dispatch_combine/internode.hpp index bf8f3e1f..4a9c8469 100644 --- a/src/ops/dispatch_combine/internode.hpp +++ b/src/ops/dispatch_combine/internode.hpp @@ -171,7 +171,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { tokenId * config.scaleDim * config.scaleTypeSize, config.scaleDim * config.scaleTypeSize); } - shmem::ShmemPutTypeNbiWarp(args.shmemInpTokMemObj, peSortedOffset, + shmem::ShmemPutTypeNbiWarp(args.dispatchShmemInpTokMemObj, peSortedOffset, args.shmemStagingTokMemObj, mapIdxOffset, stagingOffset, destPe); } @@ -202,7 +202,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { size_t srcOffset = srcIdx * stagingOffset; const index_t dstIdx = myPe * MaxNumTokensToRecvPerRank + startIdx + chunkOffset; size_t dstOffset = dstIdx * stagingOffset; - shmem::ShmemPutTypeNbiWarp(args.shmemInpTokMemObj, dstOffset, + shmem::ShmemPutTypeNbiWarp(args.dispatchShmemInpTokMemObj, dstOffset, args.shmemStagingTokMemObj, srcOffset, actualTokenNum * stagingOffset, destPe); @@ -297,23 +297,23 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { size_t peSortedTokenOffset = size_t(peSortedId) * stagingOffset; core::WarpCopy(args.shmemOutTokMemObj->template GetAs() + localTokenOffset, - args.shmemInpTokMemObj->template GetAs() + peSortedTokenOffset, + args.dispatchShmemInpTokMemObj->template GetAs() + peSortedTokenOffset, config.hiddenDim * sizeof(T)); core::WarpCopy( args.shmemOutWeightsMemObj->template GetAs() + localTokenIdx * config.numExpertPerToken * sizeof(float), - args.shmemInpTokMemObj->template GetAs() + peSortedTokenOffset + weightOffset, + args.dispatchShmemInpTokMemObj->template GetAs() + peSortedTokenOffset + weightOffset, config.numExpertPerToken * sizeof(float)); core::WarpCopy( args.shmemOutIndicesMemObj->template GetAs() + localTokenIdx * config.numExpertPerToken * sizeof(index_t), - args.shmemInpTokMemObj->template GetAs() + peSortedTokenOffset + indicesOffset, + args.dispatchShmemInpTokMemObj->template GetAs() + peSortedTokenOffset + indicesOffset, config.numExpertPerToken * sizeof(index_t)); if (args.scalesBuf && (config.scaleDim > 0) && (config.scaleTypeSize > 0)) { core::WarpCopy( args.shmemOutScalesMemObj->template GetAs() + localTokenIdx * config.scaleDim * config.scaleTypeSize, - args.shmemInpTokMemObj->template GetAs() + peSortedTokenOffset + scalesOffset, + args.dispatchShmemInpTokMemObj->template GetAs() + peSortedTokenOffset + scalesOffset, config.scaleDim * config.scaleTypeSize); } if (laneId == 0) { @@ -426,7 +426,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { weightSize); } - shmem::ShmemPutTypeNbiWarp(args.shmemInpTokMemObj, peSortedOffset, + shmem::ShmemPutTypeNbiWarp(args.combineShmemInpTokMemObj, peSortedOffset, args.shmemStagingTokMemObj, mapIdxOffset, tokenPackSize, srcPe); } @@ -457,7 +457,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { size_t srcOffset = srcIdx * tokenPackSize; const index_t dstIdx = myPe * MaxNumTokensToRecvPerRank + startIdx + chunkOffset; size_t dstOffset = dstIdx * tokenPackSize; - shmem::ShmemPutTypeNbiWarp(args.shmemInpTokMemObj, dstOffset, + shmem::ShmemPutTypeNbiWarp(args.combineShmemInpTokMemObj, dstOffset, args.shmemStagingTokMemObj, srcOffset, actualTokenNum * tokenPackSize, srcPe); @@ -531,9 +531,9 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { if (destPe < config.worldSize) { srcPtrs[j] = - reinterpret_cast(args.shmemInpTokMemObj->template GetAs() + byteOffset); + reinterpret_cast(args.combineShmemInpTokMemObj->template GetAs() + byteOffset); srcWeightsPtr[j] = reinterpret_cast( - args.shmemInpTokMemObj->template GetAs() + weightByteOffset); + args.combineShmemInpTokMemObj->template GetAs() + weightByteOffset); } else { srcPtrs[j] = nullptr; srcWeightsPtr[j] = nullptr; diff --git a/src/ops/dispatch_combine/intranode.hpp b/src/ops/dispatch_combine/intranode.hpp index a529c139..57112b12 100644 --- a/src/ops/dispatch_combine/intranode.hpp +++ b/src/ops/dispatch_combine/intranode.hpp @@ -211,7 +211,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { index_t totalRecvTokenNum = args.totalRecvTokenNum[0]; if (args.config.useExternalInpBuffer) { for (int i = globalWarpId; i < totalRecvTokenNum; i += globalWarpNum) { - core::WarpCopy(args.shmemInpTokMemObj->template GetAs() + i * config.hiddenDim, + core::WarpCopy(args.combineShmemInpTokMemObj->template GetAs() + i * config.hiddenDim, args.inpTokenBuf + i * config.hiddenDim, config.hiddenDim); } } @@ -252,7 +252,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { if (destPe < config.worldSize) { index_t destLocalTokId = destTokId - destPe * maxNumOutTokenPerRank; - srcPtrs[j] = args.shmemInpTokMemObj->template GetAs(destPe) + + srcPtrs[j] = args.combineShmemInpTokMemObj->template GetAs(destPe) + destLocalTokId * config.hiddenDim + hiddenDimOffset; srcWeightsPtr[j] = args.shmemInpWeightsMemObj->template GetAs(destPe) + destLocalTokId * config.numExpertPerToken; diff --git a/src/pybind/mori.cpp b/src/pybind/mori.cpp index a2b6a383..7eac7649 100644 --- a/src/pybind/mori.cpp +++ b/src/pybind/mori.cpp @@ -173,10 +173,10 @@ torch::Tensor GetDispatchReceiverTokenIdxMap(mori::moe::EpDispatchCombineHandle& return tensor; } -torch::Tensor GetRegisteredInputBuffer(mori::moe::EpDispatchCombineHandle& handle, +torch::Tensor GetRegisteredCombineInputBuffer(mori::moe::EpDispatchCombineHandle& handle, at::ScalarType scalarType) { torch::Tensor out = - torch::from_blob(handle.shmemInpTokMemObj->Get(), + torch::from_blob(handle.combineShmemInpTokMemObj->Get(), {handle.config.MaxNumTokensToRecv(), handle.config.hiddenDim}, torch::TensorOptions().dtype(scalarType).device(torch::kCUDA)); return out; @@ -209,8 +209,8 @@ void DeclareEpDispatchCombineHandle(pybind11::module& m) { funcName = std::string("get_dispatch_receiver_token_idx_map"); m.def(funcName.c_str(), &GetDispatchReceiverTokenIdxMap); - funcName = std::string("get_registered_input_buffer"); - m.def(funcName.c_str(), &GetRegisteredInputBuffer); + funcName = std::string("get_registered_combine_input_buffer"); + m.def(funcName.c_str(), &GetRegisteredCombineInputBuffer); } } // namespace diff --git a/tests/python/ops/test_dispatch_combine_internode_inconsistency.py b/tests/python/ops/test_dispatch_combine_internode_inconsistency.py new file mode 100644 index 00000000..4957302d --- /dev/null +++ b/tests/python/ops/test_dispatch_combine_internode_inconsistency.py @@ -0,0 +1,255 @@ +# Copyright © Advanced Micro Devices, Inc. All rights reserved. +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import pytest +import mori +import os +from tests.python.utils import TorchDistProcessManager +import torch +import torch.distributed as dist + + +class EpDispatchCombineTestCase: + def __init__(self, config): + self.config = config + self.device = torch.device("cuda", self.config.rank) + self.rng = torch.Generator(device=self.device) + self.rng.manual_seed(123) + + def sync(self): + torch.cuda.synchronize() + dist.barrier() + + def gen_test_data(self, use_max_token_num=False): + if use_max_token_num: + num_token = torch.tensor( + [ + self.config.max_num_inp_token_per_rank + for i in range(self.config.world_size) + ] + ).to(self.device) + else: + num_token = torch.randint( + 0, + self.config.max_num_inp_token_per_rank + 1, + [self.config.world_size], + generator=self.rng, + device=self.device, + ) + + # gen indices + all_rank_indices = [] + for r in range(self.config.world_size): + indices = torch.empty( + num_token[r], + self.config.num_experts_per_token, + dtype=torch.int64, + # device=self.device, + ) + for i in range(num_token[r]): + perm = torch.randperm( + self.config.num_experts_per_rank * self.config.world_size, + generator=self.rng, + device=self.device, + ) + indices[i] = perm[: self.config.num_experts_per_token] + all_rank_indices.append(indices.to(torch.int32).to(self.device)) + + # gen weights + all_rank_weights = [ + torch.rand( + num_token[r], + self.config.num_experts_per_token, + dtype=torch.float32, + generator=self.rng, + device=self.device, + ) + for r in range(self.config.world_size) + ] + + # gen scales + all_rank_scales = [ + torch.rand( + num_token[r], + self.config.scale_dim, + dtype=torch.float32, + generator=self.rng, + device=self.device, + ) + for r in range(self.config.world_size) + ] + if self.config.scale_type_size == 1: + all_rank_scales = [t.to(torch.float8_e4m3fnuz) for t in all_rank_scales] + + # gen input & output + # some functions such as randn and cat are not implemented for fp8 + all_rank_input = [] + for r in range(self.config.world_size): + all_rank_input.append( + torch.randn( + num_token[r], + self.config.hidden_dim, + dtype=torch.float32, + generator=self.rng, + device=self.device, + ).to(self.config.data_type) + ) + + return ( + num_token, + all_rank_indices, + all_rank_input, + all_rank_weights, + all_rank_scales, + ) + + def run_test_once(self, op, test_data): + ( + all_rank_num_token, + all_rank_indices, + all_rank_input, + all_rank_weights, + all_rank_scales, + ) = test_data + ( + dispatch_output, + dispatch_weights, + dispatch_scales, + dispatch_indices, + dispatch_recv_num_token, + ) = op.dispatch( + all_rank_input[self.config.rank], + all_rank_weights[self.config.rank], + all_rank_scales[self.config.rank], + all_rank_indices[self.config.rank], + ) + + recv_num_token = dispatch_recv_num_token.item() + max_expert_idx = dispatch_indices[:recv_num_token].max().item() + num_experts = self.config.num_experts_per_rank * self.config.world_size + if max_expert_idx >= num_experts: + print(f"Invalid expert id: {max_expert_idx}") + assert False + + combine_output, combine_output_weight = op.combine( + dispatch_output, dispatch_weights, dispatch_indices, call_reset=True + ) + self.sync() + + +@pytest.fixture(scope="session") +def torch_dist_process_manager(): + os.environ['MORI_DISABLE_P2P'] = '1' + try: + torch.multiprocessing.set_start_method("spawn", force=True) + print("Multiprocessing start method set to spawn") + except RuntimeError: + pass + manager = TorchDistProcessManager() + manager.start_workers(world_size=8) + yield manager + manager.shutdown() + + +def _test_dispatch_combine( + rank, + world_size, + data_type, + hidden_dim, + scale_dim, + scale_type_size, + max_num_inp_token_per_rank, + num_experts_per_rank, + num_experts_per_token, +): + config = mori.ops.EpDispatchCombineConfig( + data_type=data_type, + rank=rank, + world_size=world_size, + hidden_dim=hidden_dim, + scale_dim=scale_dim, + scale_type_size=scale_type_size, + max_num_inp_token_per_rank=max_num_inp_token_per_rank, + num_experts_per_rank=num_experts_per_rank, + num_experts_per_token=num_experts_per_token, + max_token_type_size=2, + block_num=16, + warp_num_per_block=1, + kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode, + ) + op = mori.ops.EpDispatchCombineOp(config) + test_case = EpDispatchCombineTestCase(config) + test_data = test_case.gen_test_data(True) + num_reps = 2048 + for idx in range(num_reps): + test_case.run_test_once(op, test_data) + if rank == 0: + print(f"Passed {idx}/{num_reps}") + + +# TODO: create a sub process group so that we can test worlds size < 8 +@pytest.mark.parametrize("world_size", (8,)) +@pytest.mark.parametrize("data_type", (torch.float8_e4m3fnuz,)) +@pytest.mark.parametrize("hidden_dim", (7168,)) +@pytest.mark.parametrize("scale_dim", (56,)) +@pytest.mark.parametrize("scale_type_size", (4,)) +@pytest.mark.parametrize("max_num_inp_token_per_rank", (256,)) +@pytest.mark.parametrize("num_experts_per_rank", (32,)) +@pytest.mark.parametrize("num_experts_per_token", (8,)) +def test_dispatch_combine( + torch_dist_process_manager, + world_size, + data_type, + hidden_dim, + scale_dim, + scale_type_size, + max_num_inp_token_per_rank, + num_experts_per_rank, + num_experts_per_token, +): + for i in range(world_size): + torch_dist_process_manager.task_queue.put( + ( + _test_dispatch_combine, + [ + world_size, + data_type, + hidden_dim, + scale_dim, + scale_type_size, + max_num_inp_token_per_rank, + num_experts_per_rank, + num_experts_per_token, + ], + ) + ) + + results = [] + for i in range(world_size): + ( + rank, + result, + ) = torch_dist_process_manager.result_queue.get() + results.append(result) + + for result in results: + if result is not None: + pytest.assume(False, result) From 4c18874c201db93ed4e09489d0ddd2beeaf22ef6 Mon Sep 17 00:00:00 2001 From: Dongmin Ra Date: Fri, 10 Oct 2025 16:59:48 +0900 Subject: [PATCH 2/3] Changed variable names --- .../ops/dispatch_combine/dispatch_combine.hpp | 12 ++--- src/ops/dispatch_combine/dispatch_combine.cpp | 10 ++-- src/ops/dispatch_combine/internode.hpp | 46 +++++++++---------- src/ops/dispatch_combine/intranode.hpp | 4 +- src/pybind/mori.cpp | 4 +- 5 files changed, 39 insertions(+), 37 deletions(-) diff --git a/include/mori/ops/dispatch_combine/dispatch_combine.hpp b/include/mori/ops/dispatch_combine/dispatch_combine.hpp index f19f415f..1be3c02d 100644 --- a/include/mori/ops/dispatch_combine/dispatch_combine.hpp +++ b/include/mori/ops/dispatch_combine/dispatch_combine.hpp @@ -169,8 +169,8 @@ class EpDispatchCombineHandle { uint8_t* scalesBuf{nullptr}; // Registered buffers for tokens, shmemOutTokMemObj will be returned to user as output - mori::application::SymmMemObjPtr dispatchShmemInpTokMemObj; - mori::application::SymmMemObjPtr combineShmemInpTokMemObj; + mori::application::SymmMemObjPtr shmemDispatchInpTokMemObj; + mori::application::SymmMemObjPtr shmemCombineInpTokMemObj; mori::application::SymmMemObjPtr shmemOutTokMemObj; mori::application::SymmMemObjPtr shmemStagingTokMemObj; @@ -228,8 +228,8 @@ struct EpDispatchCombineArgs { T* outTokenBuf{nullptr}; float* weightsBuf{nullptr}; uint8_t* scalesBuf{nullptr}; - mori::application::SymmMemObjPtr dispatchShmemInpTokMemObj; - mori::application::SymmMemObjPtr combineShmemInpTokMemObj; + mori::application::SymmMemObjPtr shmemDispatchInpTokMemObj; + mori::application::SymmMemObjPtr shmemCombineInpTokMemObj; mori::application::SymmMemObjPtr shmemOutTokMemObj; mori::application::SymmMemObjPtr shmemStagingTokMemObj; mori::application::SymmMemObjPtr shmemInpWeightsMemObj; @@ -272,8 +272,8 @@ EpDispatchCombineArgs GetEpDispatchCombineArgs(const EpDispatchCombineHandle& args.scalesBuf = handle.scalesBuf; args.destPeTokenCounter = handle.destPeTokenCounter; args.localPeTokenCounter = handle.localPeTokenCounter; - args.dispatchShmemInpTokMemObj = handle.dispatchShmemInpTokMemObj; - args.combineShmemInpTokMemObj = handle.combineShmemInpTokMemObj; + args.shmemDispatchInpTokMemObj = handle.shmemDispatchInpTokMemObj; + args.shmemCombineInpTokMemObj = handle.shmemCombineInpTokMemObj; args.shmemOutTokMemObj = handle.shmemOutTokMemObj; args.shmemStagingTokMemObj = handle.shmemStagingTokMemObj; args.shmemInpWeightsMemObj = handle.shmemInpWeightsMemObj; diff --git a/src/ops/dispatch_combine/dispatch_combine.cpp b/src/ops/dispatch_combine/dispatch_combine.cpp index 465fa42b..9331327c 100644 --- a/src/ops/dispatch_combine/dispatch_combine.cpp +++ b/src/ops/dispatch_combine/dispatch_combine.cpp @@ -70,8 +70,10 @@ void EpDispatchCombineHandle::InitializeShmemBuf() { (config.hiddenDim * config.maxTokenTypeSize + (sizeof(float) + sizeof(index_t)) * config.numExpertPerToken + config.scaleDim * config.scaleTypeSize); - dispatchShmemInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); - combineShmemInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); + shmemDispatchInpTokMemObj = + ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); + shmemCombineInpTokMemObj = + ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); shmemOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached); shmemStagingTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); @@ -91,8 +93,8 @@ void EpDispatchCombineHandle::InitializeShmemBuf() { } void EpDispatchCombineHandle::FinalizeShmemBuf() { - ShmemFree(dispatchShmemInpTokMemObj->localPtr); - ShmemFree(combineShmemInpTokMemObj->localPtr); + ShmemFree(shmemDispatchInpTokMemObj->localPtr); + ShmemFree(shmemCombineInpTokMemObj->localPtr); ShmemFree(shmemOutTokMemObj->localPtr); ShmemFree(shmemStagingTokMemObj->localPtr); ShmemFree(shmemInpWeightsMemObj->localPtr); diff --git a/src/ops/dispatch_combine/internode.hpp b/src/ops/dispatch_combine/internode.hpp index 4a9c8469..e4d11393 100644 --- a/src/ops/dispatch_combine/internode.hpp +++ b/src/ops/dispatch_combine/internode.hpp @@ -171,7 +171,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { tokenId * config.scaleDim * config.scaleTypeSize, config.scaleDim * config.scaleTypeSize); } - shmem::ShmemPutTypeNbiWarp(args.dispatchShmemInpTokMemObj, peSortedOffset, + shmem::ShmemPutTypeNbiWarp(args.shmemDispatchInpTokMemObj, peSortedOffset, args.shmemStagingTokMemObj, mapIdxOffset, stagingOffset, destPe); } @@ -202,7 +202,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { size_t srcOffset = srcIdx * stagingOffset; const index_t dstIdx = myPe * MaxNumTokensToRecvPerRank + startIdx + chunkOffset; size_t dstOffset = dstIdx * stagingOffset; - shmem::ShmemPutTypeNbiWarp(args.dispatchShmemInpTokMemObj, dstOffset, + shmem::ShmemPutTypeNbiWarp(args.shmemDispatchInpTokMemObj, dstOffset, args.shmemStagingTokMemObj, srcOffset, actualTokenNum * stagingOffset, destPe); @@ -297,24 +297,24 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { size_t peSortedTokenOffset = size_t(peSortedId) * stagingOffset; core::WarpCopy(args.shmemOutTokMemObj->template GetAs() + localTokenOffset, - args.dispatchShmemInpTokMemObj->template GetAs() + peSortedTokenOffset, + args.shmemDispatchInpTokMemObj->template GetAs() + peSortedTokenOffset, config.hiddenDim * sizeof(T)); - core::WarpCopy( - args.shmemOutWeightsMemObj->template GetAs() + - localTokenIdx * config.numExpertPerToken * sizeof(float), - args.dispatchShmemInpTokMemObj->template GetAs() + peSortedTokenOffset + weightOffset, - config.numExpertPerToken * sizeof(float)); - core::WarpCopy( - args.shmemOutIndicesMemObj->template GetAs() + - localTokenIdx * config.numExpertPerToken * sizeof(index_t), - args.dispatchShmemInpTokMemObj->template GetAs() + peSortedTokenOffset + indicesOffset, - config.numExpertPerToken * sizeof(index_t)); + core::WarpCopy(args.shmemOutWeightsMemObj->template GetAs() + + localTokenIdx * config.numExpertPerToken * sizeof(float), + args.shmemDispatchInpTokMemObj->template GetAs() + peSortedTokenOffset + + weightOffset, + config.numExpertPerToken * sizeof(float)); + core::WarpCopy(args.shmemOutIndicesMemObj->template GetAs() + + localTokenIdx * config.numExpertPerToken * sizeof(index_t), + args.shmemDispatchInpTokMemObj->template GetAs() + peSortedTokenOffset + + indicesOffset, + config.numExpertPerToken * sizeof(index_t)); if (args.scalesBuf && (config.scaleDim > 0) && (config.scaleTypeSize > 0)) { - core::WarpCopy( - args.shmemOutScalesMemObj->template GetAs() + - localTokenIdx * config.scaleDim * config.scaleTypeSize, - args.dispatchShmemInpTokMemObj->template GetAs() + peSortedTokenOffset + scalesOffset, - config.scaleDim * config.scaleTypeSize); + core::WarpCopy(args.shmemOutScalesMemObj->template GetAs() + + localTokenIdx * config.scaleDim * config.scaleTypeSize, + args.shmemDispatchInpTokMemObj->template GetAs() + peSortedTokenOffset + + scalesOffset, + config.scaleDim * config.scaleTypeSize); } if (laneId == 0) { args.dispReceiverIdxMap[localTokenIdx] = peSortedId; @@ -426,7 +426,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { weightSize); } - shmem::ShmemPutTypeNbiWarp(args.combineShmemInpTokMemObj, peSortedOffset, + shmem::ShmemPutTypeNbiWarp(args.shmemCombineInpTokMemObj, peSortedOffset, args.shmemStagingTokMemObj, mapIdxOffset, tokenPackSize, srcPe); } @@ -457,7 +457,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { size_t srcOffset = srcIdx * tokenPackSize; const index_t dstIdx = myPe * MaxNumTokensToRecvPerRank + startIdx + chunkOffset; size_t dstOffset = dstIdx * tokenPackSize; - shmem::ShmemPutTypeNbiWarp(args.combineShmemInpTokMemObj, dstOffset, + shmem::ShmemPutTypeNbiWarp(args.shmemCombineInpTokMemObj, dstOffset, args.shmemStagingTokMemObj, srcOffset, actualTokenNum * tokenPackSize, srcPe); @@ -530,10 +530,10 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { size_t weightByteOffset = size_t(peSortedId) * tokenPackSize + tokenSize; if (destPe < config.worldSize) { - srcPtrs[j] = - reinterpret_cast(args.combineShmemInpTokMemObj->template GetAs() + byteOffset); + srcPtrs[j] = reinterpret_cast(args.shmemCombineInpTokMemObj->template GetAs() + + byteOffset); srcWeightsPtr[j] = reinterpret_cast( - args.combineShmemInpTokMemObj->template GetAs() + weightByteOffset); + args.shmemCombineInpTokMemObj->template GetAs() + weightByteOffset); } else { srcPtrs[j] = nullptr; srcWeightsPtr[j] = nullptr; diff --git a/src/ops/dispatch_combine/intranode.hpp b/src/ops/dispatch_combine/intranode.hpp index 57112b12..f83ba003 100644 --- a/src/ops/dispatch_combine/intranode.hpp +++ b/src/ops/dispatch_combine/intranode.hpp @@ -211,7 +211,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { index_t totalRecvTokenNum = args.totalRecvTokenNum[0]; if (args.config.useExternalInpBuffer) { for (int i = globalWarpId; i < totalRecvTokenNum; i += globalWarpNum) { - core::WarpCopy(args.combineShmemInpTokMemObj->template GetAs() + i * config.hiddenDim, + core::WarpCopy(args.shmemCombineInpTokMemObj->template GetAs() + i * config.hiddenDim, args.inpTokenBuf + i * config.hiddenDim, config.hiddenDim); } } @@ -252,7 +252,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { if (destPe < config.worldSize) { index_t destLocalTokId = destTokId - destPe * maxNumOutTokenPerRank; - srcPtrs[j] = args.combineShmemInpTokMemObj->template GetAs(destPe) + + srcPtrs[j] = args.shmemCombineInpTokMemObj->template GetAs(destPe) + destLocalTokId * config.hiddenDim + hiddenDimOffset; srcWeightsPtr[j] = args.shmemInpWeightsMemObj->template GetAs(destPe) + destLocalTokId * config.numExpertPerToken; diff --git a/src/pybind/mori.cpp b/src/pybind/mori.cpp index 7eac7649..eaadfe9c 100644 --- a/src/pybind/mori.cpp +++ b/src/pybind/mori.cpp @@ -174,9 +174,9 @@ torch::Tensor GetDispatchReceiverTokenIdxMap(mori::moe::EpDispatchCombineHandle& } torch::Tensor GetRegisteredCombineInputBuffer(mori::moe::EpDispatchCombineHandle& handle, - at::ScalarType scalarType) { + at::ScalarType scalarType) { torch::Tensor out = - torch::from_blob(handle.combineShmemInpTokMemObj->Get(), + torch::from_blob(handle.shmemCombineInpTokMemObj->Get(), {handle.config.MaxNumTokensToRecv(), handle.config.hiddenDim}, torch::TensorOptions().dtype(scalarType).device(torch::kCUDA)); return out; From 176e18680011228c338e6c8fbb5204da5c36f6d9 Mon Sep 17 00:00:00 2001 From: Dongmin Ra Date: Tue, 14 Oct 2025 00:18:39 +0000 Subject: [PATCH 3/3] Changed test parameters --- .../test_dispatch_combine_internode_inconsistency.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/python/ops/test_dispatch_combine_internode_inconsistency.py b/tests/python/ops/test_dispatch_combine_internode_inconsistency.py index 4957302d..1e6acc3e 100644 --- a/tests/python/ops/test_dispatch_combine_internode_inconsistency.py +++ b/tests/python/ops/test_dispatch_combine_internode_inconsistency.py @@ -157,7 +157,7 @@ def run_test_once(self, op, test_data): @pytest.fixture(scope="session") def torch_dist_process_manager(): - os.environ['MORI_DISABLE_P2P'] = '1' + os.environ["MORI_DISABLE_P2P"] = "1" try: torch.multiprocessing.set_start_method("spawn", force=True) print("Multiprocessing start method set to spawn") @@ -192,7 +192,7 @@ def _test_dispatch_combine( num_experts_per_token=num_experts_per_token, max_token_type_size=2, block_num=16, - warp_num_per_block=1, + warp_num_per_block=16, kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode, ) op = mori.ops.EpDispatchCombineOp(config) @@ -200,9 +200,9 @@ def _test_dispatch_combine( test_data = test_case.gen_test_data(True) num_reps = 2048 for idx in range(num_reps): - test_case.run_test_once(op, test_data) - if rank == 0: - print(f"Passed {idx}/{num_reps}") + test_case.run_test_once(op, test_data) + if rank == 0: + print(f"Passed {idx}/{num_reps}") # TODO: create a sub process group so that we can test worlds size < 8 @@ -211,7 +211,7 @@ def _test_dispatch_combine( @pytest.mark.parametrize("hidden_dim", (7168,)) @pytest.mark.parametrize("scale_dim", (56,)) @pytest.mark.parametrize("scale_type_size", (4,)) -@pytest.mark.parametrize("max_num_inp_token_per_rank", (256,)) +@pytest.mark.parametrize("max_num_inp_token_per_rank", (4096,)) @pytest.mark.parametrize("num_experts_per_rank", (32,)) @pytest.mark.parametrize("num_experts_per_token", (8,)) def test_dispatch_combine(