diff --git a/examples/ops/dispatch_combine/test_dispatch_combine.py b/examples/ops/dispatch_combine/test_dispatch_combine.py
index c444a799..8af67d57 100644
--- a/examples/ops/dispatch_combine/test_dispatch_combine.py
+++ b/examples/ops/dispatch_combine/test_dispatch_combine.py
@@ -244,7 +244,7 @@ def run_test_once(self, op, test_data):
             print("Dispatch Pass")
 
         total_recv_num_token = dispatch_recv_num_token[0].item()
-        combine_input = op.get_registered_input_buffer(self.config.data_type)
+        combine_input = op.get_registered_combine_input_buffer(self.config.data_type)
         combine_input[:total_recv_num_token, :].copy_(
             dispatch_output[:total_recv_num_token, :]
         )
diff --git a/include/mori/ops/dispatch_combine/dispatch_combine.hpp b/include/mori/ops/dispatch_combine/dispatch_combine.hpp
index 3cdd773b..1be3c02d 100644
--- a/include/mori/ops/dispatch_combine/dispatch_combine.hpp
+++ b/include/mori/ops/dispatch_combine/dispatch_combine.hpp
@@ -169,7 +169,8 @@ class EpDispatchCombineHandle {
   uint8_t* scalesBuf{nullptr};
 
   // Registered buffers for tokens, shmemOutTokMemObj will be returned to user as output
-  mori::application::SymmMemObjPtr shmemInpTokMemObj;
+  mori::application::SymmMemObjPtr shmemDispatchInpTokMemObj;
+  mori::application::SymmMemObjPtr shmemCombineInpTokMemObj;
   mori::application::SymmMemObjPtr shmemOutTokMemObj;
   mori::application::SymmMemObjPtr shmemStagingTokMemObj;
 
@@ -227,7 +228,8 @@ struct EpDispatchCombineArgs {
   T* outTokenBuf{nullptr};
   float* weightsBuf{nullptr};
   uint8_t* scalesBuf{nullptr};
-  mori::application::SymmMemObjPtr shmemInpTokMemObj;
+  mori::application::SymmMemObjPtr shmemDispatchInpTokMemObj;
+  mori::application::SymmMemObjPtr shmemCombineInpTokMemObj;
   mori::application::SymmMemObjPtr shmemOutTokMemObj;
   mori::application::SymmMemObjPtr shmemStagingTokMemObj;
   mori::application::SymmMemObjPtr shmemInpWeightsMemObj;
@@ -270,7 +272,8 @@ EpDispatchCombineArgs GetEpDispatchCombineArgs(const EpDispatchCombineHandle&
   args.scalesBuf = handle.scalesBuf;
   args.destPeTokenCounter = handle.destPeTokenCounter;
   args.localPeTokenCounter = handle.localPeTokenCounter;
-  args.shmemInpTokMemObj = handle.shmemInpTokMemObj;
+  args.shmemDispatchInpTokMemObj = handle.shmemDispatchInpTokMemObj;
+  args.shmemCombineInpTokMemObj = handle.shmemCombineInpTokMemObj;
   args.shmemOutTokMemObj = handle.shmemOutTokMemObj;
   args.shmemStagingTokMemObj = handle.shmemStagingTokMemObj;
   args.shmemInpWeightsMemObj = handle.shmemInpWeightsMemObj;
diff --git a/python/mori/ops/dispatch_combine.py b/python/mori/ops/dispatch_combine.py
index 32904ccb..77a2723a 100644
--- a/python/mori/ops/dispatch_combine.py
+++ b/python/mori/ops/dispatch_combine.py
@@ -90,12 +90,12 @@ def __init__(self, config):
         self._get_dispatch_receiver_token_idx_map_func = _cpp_dispatch_combine_factory(
             "get_dispatch_receiver_token_idx_map"
         )
-        self._get_registered_input_buffer = _cpp_dispatch_combine_factory(
-            "get_registered_input_buffer"
+        self._get_registered_combine_input_buffer = _cpp_dispatch_combine_factory(
+            "get_registered_combine_input_buffer"
         )
 
-    def get_registered_input_buffer(self, dtype: torch.dtype):
-        return self._get_registered_input_buffer(self._handle, dtype)
+    def get_registered_combine_input_buffer(self, dtype: torch.dtype):
+        return self._get_registered_combine_input_buffer(self._handle, dtype)
 
     def dispatch(
         self,
diff --git a/src/ops/dispatch_combine/dispatch_combine.cpp b/src/ops/dispatch_combine/dispatch_combine.cpp
index 7c2dbe00..9331327c 100644
--- a/src/ops/dispatch_combine/dispatch_combine.cpp
+++ b/src/ops/dispatch_combine/dispatch_combine.cpp
@@ -70,7 +70,10 @@ void EpDispatchCombineHandle::InitializeShmemBuf() {
       (config.hiddenDim * config.maxTokenTypeSize +
        (sizeof(float) + sizeof(index_t)) * config.numExpertPerToken +
        config.scaleDim * config.scaleTypeSize);
-  shmemInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
+  shmemDispatchInpTokMemObj =
+      ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
+  shmemCombineInpTokMemObj =
+      ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
   shmemOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached);
   shmemStagingTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
 
@@ -90,7 +93,8 @@ void EpDispatchCombineHandle::InitializeShmemBuf() {
 }
 
 void EpDispatchCombineHandle::FinalizeShmemBuf() {
-  ShmemFree(shmemInpTokMemObj->localPtr);
+  ShmemFree(shmemDispatchInpTokMemObj->localPtr);
+  ShmemFree(shmemCombineInpTokMemObj->localPtr);
   ShmemFree(shmemOutTokMemObj->localPtr);
   ShmemFree(shmemStagingTokMemObj->localPtr);
   ShmemFree(shmemInpWeightsMemObj->localPtr);
diff --git a/src/ops/dispatch_combine/internode.hpp b/src/ops/dispatch_combine/internode.hpp
index bf8f3e1f..e4d11393 100644
--- a/src/ops/dispatch_combine/internode.hpp
+++ b/src/ops/dispatch_combine/internode.hpp
@@ -171,7 +171,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) {
                      tokenId * config.scaleDim * config.scaleTypeSize,
                      config.scaleDim * config.scaleTypeSize);
     }
-    shmem::ShmemPutTypeNbiWarp(args.shmemInpTokMemObj, peSortedOffset,
+    shmem::ShmemPutTypeNbiWarp(args.shmemDispatchInpTokMemObj, peSortedOffset,
                                args.shmemStagingTokMemObj, mapIdxOffset, stagingOffset, destPe);
   }
 
@@ -202,7 +202,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) {
       size_t srcOffset = srcIdx * stagingOffset;
       const index_t dstIdx = myPe * MaxNumTokensToRecvPerRank + startIdx + chunkOffset;
       size_t dstOffset = dstIdx * stagingOffset;
-      shmem::ShmemPutTypeNbiWarp(args.shmemInpTokMemObj, dstOffset,
+      shmem::ShmemPutTypeNbiWarp(args.shmemDispatchInpTokMemObj, dstOffset,
                                  args.shmemStagingTokMemObj, srcOffset,
                                  actualTokenNum * stagingOffset, destPe);
 
@@ -297,24 +297,24 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) {
 
     size_t peSortedTokenOffset = size_t(peSortedId) * stagingOffset;
     core::WarpCopy(args.shmemOutTokMemObj->template GetAs() + localTokenOffset,
-                   args.shmemInpTokMemObj->template GetAs() + peSortedTokenOffset,
+                   args.shmemDispatchInpTokMemObj->template GetAs() + peSortedTokenOffset,
                    config.hiddenDim * sizeof(T));
-    core::WarpCopy(
-        args.shmemOutWeightsMemObj->template GetAs() +
-            localTokenIdx * config.numExpertPerToken * sizeof(float),
-        args.shmemInpTokMemObj->template GetAs() + peSortedTokenOffset + weightOffset,
-        config.numExpertPerToken * sizeof(float));
-    core::WarpCopy(
-        args.shmemOutIndicesMemObj->template GetAs() +
-            localTokenIdx * config.numExpertPerToken * sizeof(index_t),
-        args.shmemInpTokMemObj->template GetAs() + peSortedTokenOffset + indicesOffset,
-        config.numExpertPerToken * sizeof(index_t));
+    core::WarpCopy(args.shmemOutWeightsMemObj->template GetAs() +
+                       localTokenIdx * config.numExpertPerToken * sizeof(float),
+                   args.shmemDispatchInpTokMemObj->template GetAs() + peSortedTokenOffset +
+                       weightOffset,
+                   config.numExpertPerToken * sizeof(float));
+    core::WarpCopy(args.shmemOutIndicesMemObj->template GetAs() +
+                       localTokenIdx * config.numExpertPerToken * sizeof(index_t),
+                   args.shmemDispatchInpTokMemObj->template GetAs() + peSortedTokenOffset +
+                       indicesOffset,
+                   config.numExpertPerToken * sizeof(index_t));
     if (args.scalesBuf && (config.scaleDim > 0) && (config.scaleTypeSize > 0)) {
-      core::WarpCopy(
-          args.shmemOutScalesMemObj->template GetAs() +
-              localTokenIdx * config.scaleDim * config.scaleTypeSize,
-          args.shmemInpTokMemObj->template GetAs() + peSortedTokenOffset + scalesOffset,
-          config.scaleDim * config.scaleTypeSize);
+      core::WarpCopy(args.shmemOutScalesMemObj->template GetAs() +
+                         localTokenIdx * config.scaleDim * config.scaleTypeSize,
+                     args.shmemDispatchInpTokMemObj->template GetAs() + peSortedTokenOffset +
+                         scalesOffset,
+                     config.scaleDim * config.scaleTypeSize);
     }
     if (laneId == 0) {
       args.dispReceiverIdxMap[localTokenIdx] = peSortedId;
@@ -426,7 +426,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) {
                      weightSize);
     }
 
-    shmem::ShmemPutTypeNbiWarp(args.shmemInpTokMemObj, peSortedOffset,
+    shmem::ShmemPutTypeNbiWarp(args.shmemCombineInpTokMemObj, peSortedOffset,
                                args.shmemStagingTokMemObj, mapIdxOffset, tokenPackSize, srcPe);
   }
 
@@ -457,7 +457,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) {
       size_t srcOffset = srcIdx * tokenPackSize;
      const index_t dstIdx = myPe * MaxNumTokensToRecvPerRank + startIdx + chunkOffset;
       size_t dstOffset = dstIdx * tokenPackSize;
-      shmem::ShmemPutTypeNbiWarp(args.shmemInpTokMemObj, dstOffset,
+      shmem::ShmemPutTypeNbiWarp(args.shmemCombineInpTokMemObj, dstOffset,
                                  args.shmemStagingTokMemObj, srcOffset,
                                  actualTokenNum * tokenPackSize, srcPe);
 
@@ -530,10 +530,10 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) {
       size_t weightByteOffset = size_t(peSortedId) * tokenPackSize + tokenSize;
 
       if (destPe < config.worldSize) {
-        srcPtrs[j] =
-            reinterpret_cast(args.shmemInpTokMemObj->template GetAs() + byteOffset);
+        srcPtrs[j] = reinterpret_cast(args.shmemCombineInpTokMemObj->template GetAs() +
+                                      byteOffset);
         srcWeightsPtr[j] = reinterpret_cast(
-            args.shmemInpTokMemObj->template GetAs() + weightByteOffset);
+            args.shmemCombineInpTokMemObj->template GetAs() + weightByteOffset);
       } else {
         srcPtrs[j] = nullptr;
         srcWeightsPtr[j] = nullptr;
diff --git a/src/ops/dispatch_combine/intranode.hpp b/src/ops/dispatch_combine/intranode.hpp
index a529c139..f83ba003 100644
--- a/src/ops/dispatch_combine/intranode.hpp
+++ b/src/ops/dispatch_combine/intranode.hpp
@@ -211,7 +211,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) {
   index_t totalRecvTokenNum = args.totalRecvTokenNum[0];
   if (args.config.useExternalInpBuffer) {
     for (int i = globalWarpId; i < totalRecvTokenNum; i += globalWarpNum) {
-      core::WarpCopy(args.shmemInpTokMemObj->template GetAs() + i * config.hiddenDim,
+      core::WarpCopy(args.shmemCombineInpTokMemObj->template GetAs() + i * config.hiddenDim,
                      args.inpTokenBuf + i * config.hiddenDim, config.hiddenDim);
     }
   }
@@ -252,7 +252,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) {
 
       if (destPe < config.worldSize) {
         index_t destLocalTokId = destTokId - destPe * maxNumOutTokenPerRank;
-        srcPtrs[j] = args.shmemInpTokMemObj->template GetAs(destPe) +
+        srcPtrs[j] = args.shmemCombineInpTokMemObj->template GetAs(destPe) +
                      destLocalTokId * config.hiddenDim + hiddenDimOffset;
         srcWeightsPtr[j] = args.shmemInpWeightsMemObj->template GetAs(destPe) +
                            destLocalTokId * config.numExpertPerToken;
diff --git a/src/pybind/mori.cpp b/src/pybind/mori.cpp
index a2b6a383..eaadfe9c 100644
--- a/src/pybind/mori.cpp
+++ b/src/pybind/mori.cpp
@@ -173,10 +173,10 @@ torch::Tensor GetDispatchReceiverTokenIdxMap(mori::moe::EpDispatchCombineHandle&
   return tensor;
 }
 
-torch::Tensor GetRegisteredInputBuffer(mori::moe::EpDispatchCombineHandle& handle,
-                                       at::ScalarType scalarType) {
+torch::Tensor GetRegisteredCombineInputBuffer(mori::moe::EpDispatchCombineHandle& handle,
+                                              at::ScalarType scalarType) {
   torch::Tensor out =
-      torch::from_blob(handle.shmemInpTokMemObj->Get(),
+      torch::from_blob(handle.shmemCombineInpTokMemObj->Get(),
                        {handle.config.MaxNumTokensToRecv(), handle.config.hiddenDim},
                        torch::TensorOptions().dtype(scalarType).device(torch::kCUDA));
   return out;
@@ -209,8 +209,8 @@ void DeclareEpDispatchCombineHandle(pybind11::module& m) {
   funcName = std::string("get_dispatch_receiver_token_idx_map");
   m.def(funcName.c_str(), &GetDispatchReceiverTokenIdxMap);
 
-  funcName = std::string("get_registered_input_buffer");
-  m.def(funcName.c_str(), &GetRegisteredInputBuffer);
+  funcName = std::string("get_registered_combine_input_buffer");
+  m.def(funcName.c_str(), &GetRegisteredCombineInputBuffer);
 }
 
 }  // namespace
diff --git a/tests/python/ops/test_dispatch_combine_internode_inconsistency.py b/tests/python/ops/test_dispatch_combine_internode_inconsistency.py
new file mode 100644
index 00000000..1e6acc3e
--- /dev/null
+++ b/tests/python/ops/test_dispatch_combine_internode_inconsistency.py
@@ -0,0 +1,255 @@
+# Copyright © Advanced Micro Devices, Inc. All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import pytest
+import mori
+import os
+from tests.python.utils import TorchDistProcessManager
+import torch
+import torch.distributed as dist
+
+
+class EpDispatchCombineTestCase:
+    def __init__(self, config):
+        self.config = config
+        self.device = torch.device("cuda", self.config.rank)
+        self.rng = torch.Generator(device=self.device)
+        self.rng.manual_seed(123)
+
+    def sync(self):
+        torch.cuda.synchronize()
+        dist.barrier()
+
+    def gen_test_data(self, use_max_token_num=False):
+        if use_max_token_num:
+            num_token = torch.tensor(
+                [
+                    self.config.max_num_inp_token_per_rank
+                    for i in range(self.config.world_size)
+                ]
+            ).to(self.device)
+        else:
+            num_token = torch.randint(
+                0,
+                self.config.max_num_inp_token_per_rank + 1,
+                [self.config.world_size],
+                generator=self.rng,
+                device=self.device,
+            )
+
+        # gen indices
+        all_rank_indices = []
+        for r in range(self.config.world_size):
+            indices = torch.empty(
+                num_token[r],
+                self.config.num_experts_per_token,
+                dtype=torch.int64,
+                # device=self.device,
+            )
+            for i in range(num_token[r]):
+                perm = torch.randperm(
+                    self.config.num_experts_per_rank * self.config.world_size,
+                    generator=self.rng,
+                    device=self.device,
+                )
+                indices[i] = perm[: self.config.num_experts_per_token]
+            all_rank_indices.append(indices.to(torch.int32).to(self.device))
+
+        # gen weights
+        all_rank_weights = [
+            torch.rand(
+                num_token[r],
+                self.config.num_experts_per_token,
+                dtype=torch.float32,
+                generator=self.rng,
+                device=self.device,
+            )
+            for r in range(self.config.world_size)
+        ]
+
+        # gen scales
+        all_rank_scales = [
+            torch.rand(
+                num_token[r],
+                self.config.scale_dim,
+                dtype=torch.float32,
+                generator=self.rng,
+                device=self.device,
+            )
+            for r in range(self.config.world_size)
+        ]
+        if self.config.scale_type_size == 1:
+            all_rank_scales = [t.to(torch.float8_e4m3fnuz) for t in all_rank_scales]
+
+        # gen input & output
+        # some functions such as randn and cat are not implemented for fp8
+        all_rank_input = []
+        for r in range(self.config.world_size):
+            all_rank_input.append(
+                torch.randn(
+                    num_token[r],
+                    self.config.hidden_dim,
+                    dtype=torch.float32,
+                    generator=self.rng,
+                    device=self.device,
+                ).to(self.config.data_type)
+            )
+
+        return (
+            num_token,
+            all_rank_indices,
+            all_rank_input,
+            all_rank_weights,
+            all_rank_scales,
+        )
+
+    def run_test_once(self, op, test_data):
+        (
+            all_rank_num_token,
+            all_rank_indices,
+            all_rank_input,
+            all_rank_weights,
+            all_rank_scales,
+        ) = test_data
+        (
+            dispatch_output,
+            dispatch_weights,
+            dispatch_scales,
+            dispatch_indices,
+            dispatch_recv_num_token,
+        ) = op.dispatch(
+            all_rank_input[self.config.rank],
+            all_rank_weights[self.config.rank],
+            all_rank_scales[self.config.rank],
+            all_rank_indices[self.config.rank],
+        )
+
+        recv_num_token = dispatch_recv_num_token.item()
+        max_expert_idx = dispatch_indices[:recv_num_token].max().item()
+        num_experts = self.config.num_experts_per_rank * self.config.world_size
+        if max_expert_idx >= num_experts:
+            print(f"Invalid expert id: {max_expert_idx}")
+            assert False
+
+        combine_output, combine_output_weight = op.combine(
+            dispatch_output, dispatch_weights, dispatch_indices, call_reset=True
+        )
+        self.sync()
+
+
+@pytest.fixture(scope="session")
+def torch_dist_process_manager():
+    os.environ["MORI_DISABLE_P2P"] = "1"
+    try:
+        torch.multiprocessing.set_start_method("spawn", force=True)
+        print("Multiprocessing start method set to spawn")
+    except RuntimeError:
+        pass
+    manager = TorchDistProcessManager()
+    manager.start_workers(world_size=8)
+    yield manager
+    manager.shutdown()
+
+
+def _test_dispatch_combine(
+    rank,
+    world_size,
+    data_type,
+    hidden_dim,
+    scale_dim,
+    scale_type_size,
+    max_num_inp_token_per_rank,
+    num_experts_per_rank,
+    num_experts_per_token,
+):
+    config = mori.ops.EpDispatchCombineConfig(
+        data_type=data_type,
+        rank=rank,
+        world_size=world_size,
+        hidden_dim=hidden_dim,
+        scale_dim=scale_dim,
+        scale_type_size=scale_type_size,
+        max_num_inp_token_per_rank=max_num_inp_token_per_rank,
+        num_experts_per_rank=num_experts_per_rank,
+        num_experts_per_token=num_experts_per_token,
+        max_token_type_size=2,
+        block_num=16,
+        warp_num_per_block=16,
+        kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode,
+    )
+    op = mori.ops.EpDispatchCombineOp(config)
+    test_case = EpDispatchCombineTestCase(config)
+    test_data = test_case.gen_test_data(True)
+    num_reps = 2048
+    for idx in range(num_reps):
+        test_case.run_test_once(op, test_data)
+        if rank == 0:
+            print(f"Passed {idx}/{num_reps}")
+
+
+# TODO: create a sub process group so that we can test world sizes < 8
+@pytest.mark.parametrize("world_size", (8,))
+@pytest.mark.parametrize("data_type", (torch.float8_e4m3fnuz,))
+@pytest.mark.parametrize("hidden_dim", (7168,))
+@pytest.mark.parametrize("scale_dim", (56,))
+@pytest.mark.parametrize("scale_type_size", (4,))
+@pytest.mark.parametrize("max_num_inp_token_per_rank", (4096,))
+@pytest.mark.parametrize("num_experts_per_rank", (32,))
+@pytest.mark.parametrize("num_experts_per_token", (8,))
+def test_dispatch_combine(
+    torch_dist_process_manager,
+    world_size,
+    data_type,
+    hidden_dim,
+    scale_dim,
+    scale_type_size,
+    max_num_inp_token_per_rank,
+    num_experts_per_rank,
+    num_experts_per_token,
+):
+    for i in range(world_size):
+        torch_dist_process_manager.task_queue.put(
+            (
+                _test_dispatch_combine,
+                [
+                    world_size,
+                    data_type,
+                    hidden_dim,
+                    scale_dim,
+                    scale_type_size,
+                    max_num_inp_token_per_rank,
+                    num_experts_per_rank,
+                    num_experts_per_token,
+                ],
+            )
+        )
+
+    results = []
+    for i in range(world_size):
+        (
+            rank,
+            result,
+        ) = torch_dist_process_manager.result_queue.get()
+        results.append(result)
+
+    for result in results:
+        if result is not None:
+            pytest.assume(False, result)
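
Usage sketch (not part of the patch): callers that staged combine input through
op.get_registered_input_buffer() only need the renamed accessor; the staging logic itself is
unchanged. The snippet mirrors the updated hunk in
examples/ops/dispatch_combine/test_dispatch_combine.py and assumes that example's run_test_once
variables (op, self.config, dispatch_output, dispatch_recv_num_token); what happens to the staged
tensor afterwards is an assumption, since the combine call sits outside the hunks in this patch.

    # Renamed accessor for the combine-side registered (symmetric) input buffer.
    combine_input = op.get_registered_combine_input_buffer(self.config.data_type)

    # Stage only the tokens this rank actually received from dispatch.
    total_recv_num_token = dispatch_recv_num_token[0].item()
    combine_input[:total_recv_num_token, :].copy_(
        dispatch_output[:total_recv_num_token, :]
    )
    # The staged combine_input is then consumed by the subsequent op.combine(...) call in the
    # example (assumed; not shown in this patch).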