2 changes: 1 addition & 1 deletion examples/ops/dispatch_combine/test_dispatch_combine.py
@@ -244,7 +244,7 @@ def run_test_once(self, op, test_data):
         print("Dispatch Pass")

         total_recv_num_token = dispatch_recv_num_token[0].item()
-        combine_input = op.get_registered_input_buffer(self.config.data_type)
+        combine_input = op.get_registered_combine_input_buffer(self.config.data_type)
         combine_input[:total_recv_num_token, :].copy_(
             dispatch_output[:total_recv_num_token, :]
         )
9 changes: 6 additions & 3 deletions include/mori/ops/dispatch_combine/dispatch_combine.hpp
@@ -169,7 +169,8 @@ class EpDispatchCombineHandle {
   uint8_t* scalesBuf{nullptr};

   // Registered buffers for tokens, shmemOutTokMemObj will be returned to user as output
-  mori::application::SymmMemObjPtr shmemInpTokMemObj;
+  mori::application::SymmMemObjPtr dispatchShmemInpTokMemObj;
+  mori::application::SymmMemObjPtr combineShmemInpTokMemObj;
   mori::application::SymmMemObjPtr shmemOutTokMemObj;
   mori::application::SymmMemObjPtr shmemStagingTokMemObj;

@@ -227,7 +228,8 @@ struct EpDispatchCombineArgs {
   T* outTokenBuf{nullptr};
   float* weightsBuf{nullptr};
   uint8_t* scalesBuf{nullptr};
-  mori::application::SymmMemObjPtr shmemInpTokMemObj;
+  mori::application::SymmMemObjPtr dispatchShmemInpTokMemObj;
+  mori::application::SymmMemObjPtr combineShmemInpTokMemObj;
   mori::application::SymmMemObjPtr shmemOutTokMemObj;
   mori::application::SymmMemObjPtr shmemStagingTokMemObj;
   mori::application::SymmMemObjPtr shmemInpWeightsMemObj;
@@ -270,7 +272,8 @@ EpDispatchCombineArgs<T> GetEpDispatchCombineArgs(const EpDispatchCombineHandle&
   args.scalesBuf = handle.scalesBuf;
   args.destPeTokenCounter = handle.destPeTokenCounter;
   args.localPeTokenCounter = handle.localPeTokenCounter;
-  args.shmemInpTokMemObj = handle.shmemInpTokMemObj;
+  args.dispatchShmemInpTokMemObj = handle.dispatchShmemInpTokMemObj;
+  args.combineShmemInpTokMemObj = handle.combineShmemInpTokMemObj;
  args.shmemOutTokMemObj = handle.shmemOutTokMemObj;
  args.shmemStagingTokMemObj = handle.shmemStagingTokMemObj;
  args.shmemInpWeightsMemObj = handle.shmemInpWeightsMemObj;
8 changes: 4 additions & 4 deletions python/mori/ops/dispatch_combine.py
@@ -90,12 +90,12 @@ def __init__(self, config):
         self._get_dispatch_receiver_token_idx_map_func = _cpp_dispatch_combine_factory(
             "get_dispatch_receiver_token_idx_map"
         )
-        self._get_registered_input_buffer = _cpp_dispatch_combine_factory(
-            "get_registered_input_buffer"
+        self._get_registered_combine_input_buffer = _cpp_dispatch_combine_factory(
+            "get_registered_combine_input_buffer"
         )

-    def get_registered_input_buffer(self, dtype: torch.dtype):
-        return self._get_registered_input_buffer(self._handle, dtype)
+    def get_registered_combine_input_buffer(self, dtype: torch.dtype):
+        return self._get_registered_combine_input_buffer(self._handle, dtype)

     def dispatch(
         self,
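For context, the updated test in examples/ops/dispatch_combine/test_dispatch_combine.py (first hunk above) stages the dispatch output into the renamed registered buffer before running combine. A minimal usage sketch of the new accessor name follows; `op`, `config`, `dispatch_output`, and `dispatch_recv_num_token` are assumed to come from the surrounding test setup and a prior `op.dispatch(...)` call:

```python
# Number of tokens this rank actually received during dispatch.
total_recv_num_token = dispatch_recv_num_token[0].item()

# Renamed accessor: returns the pre-registered combine input buffer backed by
# symmetric memory, shaped [MaxNumTokensToRecv, hiddenDim] (see the pybind hunk below).
combine_input = op.get_registered_combine_input_buffer(config.data_type)

# Stage the dispatched tokens into the combine input buffer before calling combine.
combine_input[:total_recv_num_token, :].copy_(
    dispatch_output[:total_recv_num_token, :]
)
```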
6 changes: 4 additions & 2 deletions src/ops/dispatch_combine/dispatch_combine.cpp
@@ -70,7 +70,8 @@ void EpDispatchCombineHandle::InitializeShmemBuf() {
       (config.hiddenDim * config.maxTokenTypeSize +
        (sizeof(float) + sizeof(index_t)) * config.numExpertPerToken +
        config.scaleDim * config.scaleTypeSize);
-  shmemInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
+  dispatchShmemInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
+  combineShmemInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
   shmemOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached);
   shmemStagingTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);

@@ -90,7 +91,8 @@ void EpDispatchCombineHandle::InitializeShmemBuf() {
 }

 void EpDispatchCombineHandle::FinalizeShmemBuf() {
-  ShmemFree(shmemInpTokMemObj->localPtr);
+  ShmemFree(dispatchShmemInpTokMemObj->localPtr);
+  ShmemFree(combineShmemInpTokMemObj->localPtr);
   ShmemFree(shmemOutTokMemObj->localPtr);
   ShmemFree(shmemStagingTokMemObj->localPtr);
   ShmemFree(shmemInpWeightsMemObj->localPtr);
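Since the split allocates both the dispatch and combine staging buffers at maxStagingTokSize, this particular symmetric-heap allocation roughly doubles. A back-of-the-envelope sketch of the per-token payload term visible in the first hunk above; the parameter values below are purely hypothetical, and index_t is assumed to be a 32-bit integer:

```python
# Hypothetical config values, chosen only to illustrate the per-token payload
# term that feeds maxStagingTokSize in InitializeShmemBuf; not taken from this PR.
hidden_dim = 7168
max_token_type_size = 2      # e.g. bf16 token elements
num_expert_per_token = 8
scale_dim = 0                # no per-token scales in this example
scale_type_size = 0
sizeof_float = 4
sizeof_index_t = 4           # assumption: index_t is a 32-bit integer

per_token_bytes = (
    hidden_dim * max_token_type_size
    + (sizeof_float + sizeof_index_t) * num_expert_per_token
    + scale_dim * scale_type_size
)
print(per_token_bytes)  # 14400 bytes per staged token, before the max-token multiplier
```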
20 changes: 10 additions & 10 deletions src/ops/dispatch_combine/internode.hpp
@@ -171,7 +171,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs<T> args) {
                      tokenId * config.scaleDim * config.scaleTypeSize,
                      config.scaleDim * config.scaleTypeSize);
     }
-    shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemInpTokMemObj, peSortedOffset,
+    shmem::ShmemPutTypeNbiWarp<uint8_t>(args.dispatchShmemInpTokMemObj, peSortedOffset,
                                         args.shmemStagingTokMemObj, mapIdxOffset, stagingOffset,
                                         destPe);
   }
@@ -202,7 +202,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs<T> args) {
       size_t srcOffset = srcIdx * stagingOffset;
       const index_t dstIdx = myPe * MaxNumTokensToRecvPerRank + startIdx + chunkOffset;
       size_t dstOffset = dstIdx * stagingOffset;
-      shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemInpTokMemObj, dstOffset,
+      shmem::ShmemPutTypeNbiWarp<uint8_t>(args.dispatchShmemInpTokMemObj, dstOffset,
                                           args.shmemStagingTokMemObj, srcOffset,
                                           actualTokenNum * stagingOffset, destPe);

@@ -297,23 +297,23 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs<T> args) {
     size_t peSortedTokenOffset = size_t(peSortedId) * stagingOffset;

     core::WarpCopy(args.shmemOutTokMemObj->template GetAs<char*>() + localTokenOffset,
-                   args.shmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset,
+                   args.dispatchShmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset,
                    config.hiddenDim * sizeof(T));
     core::WarpCopy(
         args.shmemOutWeightsMemObj->template GetAs<char*>() +
             localTokenIdx * config.numExpertPerToken * sizeof(float),
-        args.shmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset + weightOffset,
+        args.dispatchShmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset + weightOffset,
         config.numExpertPerToken * sizeof(float));
     core::WarpCopy(
         args.shmemOutIndicesMemObj->template GetAs<char*>() +
             localTokenIdx * config.numExpertPerToken * sizeof(index_t),
-        args.shmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset + indicesOffset,
+        args.dispatchShmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset + indicesOffset,
         config.numExpertPerToken * sizeof(index_t));
     if (args.scalesBuf && (config.scaleDim > 0) && (config.scaleTypeSize > 0)) {
       core::WarpCopy(
           args.shmemOutScalesMemObj->template GetAs<char*>() +
               localTokenIdx * config.scaleDim * config.scaleTypeSize,
-          args.shmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset + scalesOffset,
+          args.dispatchShmemInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset + scalesOffset,
           config.scaleDim * config.scaleTypeSize);
     }
     if (laneId == 0) {
@@ -426,7 +426,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs<T> args) {
                      weightSize);
     }

-    shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemInpTokMemObj, peSortedOffset,
+    shmem::ShmemPutTypeNbiWarp<uint8_t>(args.combineShmemInpTokMemObj, peSortedOffset,
                                         args.shmemStagingTokMemObj, mapIdxOffset, tokenPackSize,
                                         srcPe);
   }
@@ -457,7 +457,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs<T> args) {
       size_t srcOffset = srcIdx * tokenPackSize;
       const index_t dstIdx = myPe * MaxNumTokensToRecvPerRank + startIdx + chunkOffset;
       size_t dstOffset = dstIdx * tokenPackSize;
-      shmem::ShmemPutTypeNbiWarp<uint8_t>(args.shmemInpTokMemObj, dstOffset,
+      shmem::ShmemPutTypeNbiWarp<uint8_t>(args.combineShmemInpTokMemObj, dstOffset,
                                           args.shmemStagingTokMemObj, srcOffset,
                                           actualTokenNum * tokenPackSize, srcPe);

@@ -531,9 +531,9 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs<T> args) {

       if (destPe < config.worldSize) {
         srcPtrs[j] =
-            reinterpret_cast<T*>(args.shmemInpTokMemObj->template GetAs<char*>() + byteOffset);
+            reinterpret_cast<T*>(args.combineShmemInpTokMemObj->template GetAs<char*>() + byteOffset);
         srcWeightsPtr[j] = reinterpret_cast<float*>(
-            args.shmemInpTokMemObj->template GetAs<char*>() + weightByteOffset);
+            args.combineShmemInpTokMemObj->template GetAs<char*>() + weightByteOffset);
       } else {
         srcPtrs[j] = nullptr;
         srcWeightsPtr[j] = nullptr;
4 changes: 2 additions & 2 deletions src/ops/dispatch_combine/intranode.hpp
@@ -211,7 +211,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs<T> args) {
   index_t totalRecvTokenNum = args.totalRecvTokenNum[0];
   if (args.config.useExternalInpBuffer) {
     for (int i = globalWarpId; i < totalRecvTokenNum; i += globalWarpNum) {
-      core::WarpCopy(args.shmemInpTokMemObj->template GetAs<T*>() + i * config.hiddenDim,
+      core::WarpCopy(args.combineShmemInpTokMemObj->template GetAs<T*>() + i * config.hiddenDim,
                      args.inpTokenBuf + i * config.hiddenDim, config.hiddenDim);
     }
   }
@@ -252,7 +252,7 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs<T> args) {

       if (destPe < config.worldSize) {
         index_t destLocalTokId = destTokId - destPe * maxNumOutTokenPerRank;
-        srcPtrs[j] = args.shmemInpTokMemObj->template GetAs<T*>(destPe) +
+        srcPtrs[j] = args.combineShmemInpTokMemObj->template GetAs<T*>(destPe) +
                      destLocalTokId * config.hiddenDim + hiddenDimOffset;
         srcWeightsPtr[j] = args.shmemInpWeightsMemObj->template GetAs<float*>(destPe) +
                            destLocalTokId * config.numExpertPerToken;
8 changes: 4 additions & 4 deletions src/pybind/mori.cpp
@@ -173,10 +173,10 @@ torch::Tensor GetDispatchReceiverTokenIdxMap(mori::moe::EpDispatchCombineHandle&
   return tensor;
 }

-torch::Tensor GetRegisteredInputBuffer(mori::moe::EpDispatchCombineHandle& handle,
+torch::Tensor GetRegisteredCombineInputBuffer(mori::moe::EpDispatchCombineHandle& handle,
                                        at::ScalarType scalarType) {
   torch::Tensor out =
-      torch::from_blob(handle.shmemInpTokMemObj->Get(),
+      torch::from_blob(handle.combineShmemInpTokMemObj->Get(),
                        {handle.config.MaxNumTokensToRecv(), handle.config.hiddenDim},
                        torch::TensorOptions().dtype(scalarType).device(torch::kCUDA));
   return out;
@@ -209,8 +209,8 @@ void DeclareEpDispatchCombineHandle(pybind11::module& m) {
   funcName = std::string("get_dispatch_receiver_token_idx_map");
   m.def(funcName.c_str(), &GetDispatchReceiverTokenIdxMap);

-  funcName = std::string("get_registered_input_buffer");
-  m.def(funcName.c_str(), &GetRegisteredInputBuffer);
+  funcName = std::string("get_registered_combine_input_buffer");
+  m.def(funcName.c_str(), &GetRegisteredCombineInputBuffer);
 }

 }  // namespace