2 changes: 1 addition & 1 deletion examples/ops/dispatch_combine/test_dispatch_combine.py
@@ -244,7 +244,7 @@ def run_test_once(self, op, test_data):
         print("Dispatch Pass")
 
         total_recv_num_token = dispatch_recv_num_token[0].item()
-        combine_input = op.get_registered_input_buffer(self.config.data_type)
+        combine_input = op.get_registered_combine_input_buffer(self.config.data_type)
         combine_input[:total_recv_num_token, :].copy_(
             dispatch_output[:total_recv_num_token, :]
         )
9 changes: 6 additions & 3 deletions include/mori/ops/dispatch_combine/dispatch_combine.hpp
@@ -169,7 +169,8 @@ class EpDispatchCombineHandle {
   uint8_t* scalesBuf{nullptr};
 
   // Registered buffers for tokens, shmemOutTokMemObj will be returned to user as output
-  mori::application::SymmMemObjPtr shmemInpTokMemObj;
+  mori::application::SymmMemObjPtr shmemDispatchInpTokMemObj;
+  mori::application::SymmMemObjPtr shmemCombineInpTokMemObj;
   mori::application::SymmMemObjPtr shmemOutTokMemObj;
   mori::application::SymmMemObjPtr shmemStagingTokMemObj;
 
@@ -227,7 +228,8 @@ struct EpDispatchCombineArgs {
   T* outTokenBuf{nullptr};
   float* weightsBuf{nullptr};
   uint8_t* scalesBuf{nullptr};
-  mori::application::SymmMemObjPtr shmemInpTokMemObj;
+  mori::application::SymmMemObjPtr shmemDispatchInpTokMemObj;
+  mori::application::SymmMemObjPtr shmemCombineInpTokMemObj;
   mori::application::SymmMemObjPtr shmemOutTokMemObj;
   mori::application::SymmMemObjPtr shmemStagingTokMemObj;
   mori::application::SymmMemObjPtr shmemInpWeightsMemObj;
@@ -270,7 +272,8 @@ EpDispatchCombineArgs<T> GetEpDispatchCombineArgs(const EpDispatchCombineHandle&
   args.scalesBuf = handle.scalesBuf;
   args.destPeTokenCounter = handle.destPeTokenCounter;
   args.localPeTokenCounter = handle.localPeTokenCounter;
-  args.shmemInpTokMemObj = handle.shmemInpTokMemObj;
+  args.shmemDispatchInpTokMemObj = handle.shmemDispatchInpTokMemObj;
+  args.shmemCombineInpTokMemObj = handle.shmemCombineInpTokMemObj;
   args.shmemOutTokMemObj = handle.shmemOutTokMemObj;
   args.shmemStagingTokMemObj = handle.shmemStagingTokMemObj;
   args.shmemInpWeightsMemObj = handle.shmemInpWeightsMemObj;
8 changes: 4 additions & 4 deletions python/mori/ops/dispatch_combine.py
@@ -90,12 +90,12 @@ def __init__(self, config):
         self._get_dispatch_receiver_token_idx_map_func = _cpp_dispatch_combine_factory(
             "get_dispatch_receiver_token_idx_map"
         )
-        self._get_registered_input_buffer = _cpp_dispatch_combine_factory(
-            "get_registered_input_buffer"
+        self._get_registered_combine_input_buffer = _cpp_dispatch_combine_factory(
+            "get_registered_combine_input_buffer"
         )
 
-    def get_registered_input_buffer(self, dtype: torch.dtype):
-        return self._get_registered_input_buffer(self._handle, dtype)
+    def get_registered_combine_input_buffer(self, dtype: torch.dtype):
+        return self._get_registered_combine_input_buffer(self._handle, dtype)
 
     def dispatch(
         self,
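For callers this is a pure rename. A minimal sketch of the updated round trip, modeled on the test change above — the `op` object and the `dispatch_output` / `dispatch_recv_num_token` tensors are assumed to come from an earlier `op.dispatch(...)` call, and the `combine` signature is an assumption, not part of this diff:

import torch

# Assumed setup: `op` is a constructed dispatch-combine op, and
# `dispatch_output` / `dispatch_recv_num_token` came from op.dispatch(...).
total_recv = dispatch_recv_num_token[0].item()

# Renamed accessor: this buffer feeds combine specifically, now that dispatch
# and combine no longer share a single registered input buffer.
combine_input = op.get_registered_combine_input_buffer(torch.bfloat16)
combine_input[:total_recv, :].copy_(dispatch_output[:total_recv, :])

combine_output = op.combine(combine_input)  # signature assumed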
8 changes: 6 additions & 2 deletions src/ops/dispatch_combine/dispatch_combine.cpp
@@ -70,7 +70,10 @@ void EpDispatchCombineHandle::InitializeShmemBuf() {
       (config.hiddenDim * config.maxTokenTypeSize +
        (sizeof(float) + sizeof(index_t)) * config.numExpertPerToken +
        config.scaleDim * config.scaleTypeSize);
-  shmemInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
+  shmemDispatchInpTokMemObj =
+      ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
+  shmemCombineInpTokMemObj =
+      ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
   shmemOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached);
   shmemStagingTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);

@@ -90,7 +93,8 @@ void EpDispatchCombineHandle::InitializeShmemBuf() {
 }
 
 void EpDispatchCombineHandle::FinalizeShmemBuf() {
-  ShmemFree(shmemInpTokMemObj->localPtr);
+  ShmemFree(shmemDispatchInpTokMemObj->localPtr);
+  ShmemFree(shmemCombineInpTokMemObj->localPtr);
   ShmemFree(shmemOutTokMemObj->localPtr);
   ShmemFree(shmemStagingTokMemObj->localPtr);
   ShmemFree(shmemInpWeightsMemObj->localPtr);
46 changes: 23 additions & 23 deletions src/ops/dispatch_combine/internode.hpp
@@ -171,7 +171,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs<T> args) {
                          tokenId * config.scaleDim * config.scaleTypeSize,
                          config.scaleDim * config.scaleTypeSize);
     }
-    shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemInpTokMemObj, peSortedOffset,
+    shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemDispatchInpTokMemObj, peSortedOffset,
                                         args.shmemStagingTokMemObj, mapIdxOffset, stagingOffset,
                                         destPe);
   }
@@ -202,7 +202,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs<T> args) {
       size_t srcOffset = srcIdx * stagingOffset;
       const index_t dstIdx = myPe * MaxNumTokensToRecvPerRank + startIdx + chunkOffset;
       size_t dstOffset = dstIdx * stagingOffset;
-      shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemInpTokMemObj, dstOffset,
+      shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemDispatchInpTokMemObj, dstOffset,
                                           args.shmemStagingTokMemObj, srcOffset,
                                           actualTokenNum * stagingOffset, destPe);
 
@@ -297,24 +297,24 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs<T> args) {
     size_t peSortedTokenOffset = size_t(peSortedId) * stagingOffset;
 
     core::WarpCopy(args.shmemOutTokMemObj->template GetAs<char*>() + localTokenOffset,
-                   args.shmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset,
+                   args.shmemDispatchInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset,
                    config.hiddenDim * sizeof(T));
-    core::WarpCopy(
-        args.shmemOutWeightsMemObj->template GetAs<char*>() +
-            localTokenIdx * config.numExpertPerToken * sizeof(float),
-        args.shmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset + weightOffset,
-        config.numExpertPerToken * sizeof(float));
-    core::WarpCopy(
-        args.shmemOutIndicesMemObj->template GetAs<char*>() +
-            localTokenIdx * config.numExpertPerToken * sizeof(index_t),
-        args.shmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset + indicesOffset,
-        config.numExpertPerToken * sizeof(index_t));
+    core::WarpCopy(args.shmemOutWeightsMemObj->template GetAs<char*>() +
+                       localTokenIdx * config.numExpertPerToken * sizeof(float),
+                   args.shmemDispatchInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset +
+                       weightOffset,
+                   config.numExpertPerToken * sizeof(float));
+    core::WarpCopy(args.shmemOutIndicesMemObj->template GetAs<char*>() +
+                       localTokenIdx * config.numExpertPerToken * sizeof(index_t),
+                   args.shmemDispatchInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset +
+                       indicesOffset,
+                   config.numExpertPerToken * sizeof(index_t));
     if (args.scalesBuf && (config.scaleDim > 0) && (config.scaleTypeSize > 0)) {
-      core::WarpCopy(
-          args.shmemOutScalesMemObj->template GetAs<char*>() +
-              localTokenIdx * config.scaleDim * config.scaleTypeSize,
-          args.shmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset + scalesOffset,
-          config.scaleDim * config.scaleTypeSize);
+      core::WarpCopy(args.shmemOutScalesMemObj->template GetAs<char*>() +
+                         localTokenIdx * config.scaleDim * config.scaleTypeSize,
+                     args.shmemDispatchInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset +
+                         scalesOffset,
+                     config.scaleDim * config.scaleTypeSize);
     }
     if (laneId == 0) {
       args.dispReceiverIdxMap[localTokenIdx] = peSortedId;
@@ -426,7 +426,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs<T> args) {
                      weightSize);
     }
 
-    shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemInpTokMemObj, peSortedOffset,
+    shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemCombineInpTokMemObj, peSortedOffset,
                                         args.shmemStagingTokMemObj, mapIdxOffset, tokenPackSize,
                                         srcPe);
   }
@@ -457,7 +457,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs<T> args) {
       size_t srcOffset = srcIdx * tokenPackSize;
      const index_t dstIdx = myPe * MaxNumTokensToRecvPerRank + startIdx + chunkOffset;
       size_t dstOffset = dstIdx * tokenPackSize;
-      shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemInpTokMemObj, dstOffset,
+      shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemCombineInpTokMemObj, dstOffset,
                                           args.shmemStagingTokMemObj, srcOffset,
                                           actualTokenNum * tokenPackSize, srcPe);
 
@@ -530,10 +530,10 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs<T> args) {
       size_t weightByteOffset = size_t(peSortedId) * tokenPackSize + tokenSize;
 
       if (destPe < config.worldSize) {
-        srcPtrs[j] =
-            reinterpret_cast<T*>(args.shmemInpTokMemObj->template GetAs<char*>() + byteOffset);
+        srcPtrs[j] = reinterpret_cast<T*>(args.shmemCombineInpTokMemObj->template GetAs<char*>() +
+                                          byteOffset);
         srcWeightsPtr[j] = reinterpret_cast<float*>(
-            args.shmemInpTokMemObj->template GetAs<char*>() + weightByteOffset);
+            args.shmemCombineInpTokMemObj->template GetAs<char*>() + weightByteOffset);
       } else {
         srcPtrs[j] = nullptr;
         srcWeightsPtr[j] = nullptr;
4 changes: 2 additions & 2 deletions src/ops/dispatch_combine/intranode.hpp
@@ -211,7 +211,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs<T> args) {
   index_t totalRecvTokenNum = args.totalRecvTokenNum[0];
   if (args.config.useExternalInpBuffer) {
     for (int i = globalWarpId; i < totalRecvTokenNum; i += globalWarpNum) {
-      core::WarpCopy(args.shmemInpTokMemObj->template GetAs<T*>() + i * config.hiddenDim,
+      core::WarpCopy(args.shmemCombineInpTokMemObj->template GetAs<T*>() + i * config.hiddenDim,
                      args.inpTokenBuf + i * config.hiddenDim, config.hiddenDim);
     }
   }
@@ -252,7 +252,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs<T> args) {
 
     if (destPe < config.worldSize) {
       index_t destLocalTokId = destTokId - destPe * maxNumOutTokenPerRank;
-      srcPtrs[j] = args.shmemInpTokMemObj->template GetAs<T*>(destPe) +
+      srcPtrs[j] = args.shmemCombineInpTokMemObj->template GetAs<T*>(destPe) +
                    destLocalTokId * config.hiddenDim + hiddenDimOffset;
       srcWeightsPtr[j] = args.shmemInpWeightsMemObj->template GetAs<float*>(destPe) +
                          destLocalTokId * config.numExpertPerToken;
10 changes: 5 additions & 5 deletions src/pybind/mori.cpp
@@ -173,10 +173,10 @@ torch::Tensor GetDispatchReceiverTokenIdxMap(mori::moe::EpDispatchCombineHandle&
   return tensor;
 }
 
-torch::Tensor GetRegisteredInputBuffer(mori::moe::EpDispatchCombineHandle& handle,
-                                       at::ScalarType scalarType) {
+torch::Tensor GetRegisteredCombineInputBuffer(mori::moe::EpDispatchCombineHandle& handle,
+                                              at::ScalarType scalarType) {
   torch::Tensor out =
-      torch::from_blob(handle.shmemInpTokMemObj->Get(),
+      torch::from_blob(handle.shmemCombineInpTokMemObj->Get(),
                        {handle.config.MaxNumTokensToRecv(), handle.config.hiddenDim},
                        torch::TensorOptions().dtype(scalarType).device(torch::kCUDA));
   return out;
@@ -209,8 +209,8 @@ void DeclareEpDispatchCombineHandle(pybind11::module& m) {
   funcName = std::string("get_dispatch_receiver_token_idx_map");
   m.def(funcName.c_str(), &GetDispatchReceiverTokenIdxMap);
 
-  funcName = std::string("get_registered_input_buffer");
-  m.def(funcName.c_str(), &GetRegisteredInputBuffer);
+  funcName = std::string("get_registered_combine_input_buffer");
+  m.def(funcName.c_str(), &GetRegisteredCombineInputBuffer);
 }
 
 }  // namespace
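Worth noting for users of the renamed binding: it wraps the symmetric-heap pointer with torch::from_blob, so the returned tensor aliases the registered memory rather than copying it. A small sanity-check sketch under that assumption, with `op` constructed as in the test above:

import torch

buf = op.get_registered_combine_input_buffer(torch.bfloat16)
assert buf.is_cuda and buf.ndim == 2  # [MaxNumTokensToRecv, hiddenDim], per the from_blob call

# Because the tensor aliases the registered buffer, in-place writes land
# directly in the memory the combine kernels read; no extra copy is needed.
buf.zero_()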