Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/include/rma/rma.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct ncclRmaCollArgs {
// Only rmaArgs, rmaTaskQueueProxy and rmaTaskQueueCe fields are used in ncclRmaWork.
using ncclRmaWork = ncclKernelPlan;

constexpr int NCCL_RMA_COLL_MAX_STREAMS = 4;
constexpr int NCCL_RMA_COLL_MAX_STREAMS = 16;
static_assert(NCCL_RMA_COLL_MAX_STREAMS >= 4, "NCCL_RMA_COLL_MAX_STREAMS must be at least 4");
struct ncclRmaCollState {
bool initialized;
Expand Down
45 changes: 31 additions & 14 deletions src/rma/rma_coll.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ template <typename SetWorkFn>
static ncclResult_t launchRmaOpHelper(struct ncclComm* comm, struct ncclRmaCollState* rmaCollState,
struct ncclRmaArgs* rmaArgs, cudaStream_t mainStream, int taskCount/*tasks of particular type*/,
NcclRmaFunc_t func/*Rma funcName*/, SetWorkFn setWorkField/*Lambda for setting tmpWork*/,
int& opCnt) {
cudaEvent_t opEvent, int& opCnt) {
if (taskCount <= 0) {
return ncclSuccess; // no need to update opCnt
}
Expand All @@ -136,8 +136,7 @@ static ncclResult_t launchRmaOpHelper(struct ncclComm* comm, struct ncclRmaCollS
} else {
// Subsequent operations: launch on separate rmaCollStream with synchronization
cudaStream_t opStream = rmaCollState->rmaCollStream[opCnt - 1];
cudaEvent_t opEvent = rmaCollState->rmaCollEvent[opCnt - 1];
CUDACHECK(cudaEventRecord(opEvent, mainStream));
assert(opEvent != nullptr);
CUDACHECK(cudaStreamWaitEvent(opStream, opEvent, 0));
NCCLCHECK(func(comm, &tmpWork, opStream));
}
Expand Down Expand Up @@ -203,6 +202,20 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla
// }
int opCnt = 0; // Counter for number of operations launched in this batch

// Record one batch-level start event on main stream and reuse it for all
// secondary operation launches in this batch.
int activeOps = 0;
activeOps += (batch->nProxyPut > 0);
activeOps += (batch->nProxyWaitSignal > 0);
activeOps += batch->nCePut;
activeOps += (batch->nCeWaitSignal > 0);
cudaEvent_t batchStartEvent = nullptr;
if (activeOps > 1) {
// Use a dedicated event slot that is not used by per-stream completion sync.
batchStartEvent = rmaCollState->rmaCollEvent[NCCL_RMA_COLL_MAX_STREAMS - 1];
CUDACHECKGOTO(cudaEventRecord(batchStartEvent, mainStream), ret, fail);
}

// Launch the four types of RMA operations in parallel:
// 1. ProxyPut
NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream,
Expand All @@ -213,7 +226,7 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla
w.rmaArgs->nRmaTasksProxy = batch->nProxyPut;
w.rmaTaskQueueProxy = batch->proxyPutQueue;
},
opCnt), ret, fail);
batchStartEvent, opCnt), ret, fail);

// 2. ProxyWaitSignal
NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream,
Expand All @@ -223,17 +236,21 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla
w.rmaArgs->nRmaTasksProxy = batch->nProxyWaitSignal;
w.rmaTaskQueueProxy = batch->proxyWaitSignalQueue;
},
opCnt), ret, fail);
batchStartEvent, opCnt), ret, fail);

// 3. CePut
NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream,
batch->nCePut,
ncclRmaPutCe,
[&](ncclRmaWork& w) {
w.rmaArgs->nRmaTasksCe = batch->nCePut;
w.rmaTaskQueueCe = batch->cePutQueue;
},
opCnt), ret, fail);
for (int cePutIdx = 0; cePutIdx < batch->nCePut; cePutIdx++) {
struct ncclTaskRma* cePutTask = ncclIntruQueueDequeue(&batch->cePutQueue);
NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream,
1,
ncclRmaPutCe,
[&](ncclRmaWork& w) {
w.rmaArgs->nRmaTasksCe = 1;
ncclIntruQueueConstruct(&w.rmaTaskQueueCe);
ncclIntruQueueEnqueue(&w.rmaTaskQueueCe, cePutTask);
},
batchStartEvent, opCnt), ret, fail);
}

// 4. CeWaitSignal
NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream,
Expand All @@ -243,7 +260,7 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla
w.rmaArgs->nRmaTasksCe = batch->nCeWaitSignal;
w.rmaTaskQueueCe = batch->ceWaitSignalQueue;
},
opCnt), ret, fail);
batchStartEvent, opCnt), ret, fail);

// Synchronize all secondary streams back to main stream
for (int idx = 0; idx < opCnt - 1; idx++) {
Expand Down