diff --git a/src/include/rma/rma.h b/src/include/rma/rma.h index d045f28..8adf1c1 100644 --- a/src/include/rma/rma.h +++ b/src/include/rma/rma.h @@ -51,7 +51,7 @@ struct ncclRmaCollArgs { // Only rmaArgs, rmaTaskQueueProxy and rmaTaskQueueCe fields are used in ncclRmaWork. using ncclRmaWork = ncclKernelPlan; -constexpr int NCCL_RMA_COLL_MAX_STREAMS = 4; +constexpr int NCCL_RMA_COLL_MAX_STREAMS = 16; static_assert(NCCL_RMA_COLL_MAX_STREAMS >= 4, "NCCL_RMA_COLL_MAX_STREAMS must be at least 4"); struct ncclRmaCollState { bool initialized; diff --git a/src/rma/rma_coll.cc b/src/rma/rma_coll.cc index 23cd1b7..9f652e5 100644 --- a/src/rma/rma_coll.cc +++ b/src/rma/rma_coll.cc @@ -113,7 +113,7 @@ template static ncclResult_t launchRmaOpHelper(struct ncclComm* comm, struct ncclRmaCollState* rmaCollState, struct ncclRmaArgs* rmaArgs, cudaStream_t mainStream, int taskCount/*tasks of particular type*/, NcclRmaFunc_t func/*Rma funcName*/, SetWorkFn setWorkField/*Lambda for setting tmpWork*/, - int& opCnt) { + cudaEvent_t opEvent, int& opCnt) { if (taskCount <= 0) { return ncclSuccess; // no need to update opCnt } @@ -136,8 +136,7 @@ static ncclResult_t launchRmaOpHelper(struct ncclComm* comm, struct ncclRmaCollS } else { // Subsequent operations: launch on separate rmaCollStream with synchronization cudaStream_t opStream = rmaCollState->rmaCollStream[opCnt - 1]; - cudaEvent_t opEvent = rmaCollState->rmaCollEvent[opCnt - 1]; - CUDACHECK(cudaEventRecord(opEvent, mainStream)); + assert(opEvent != nullptr); CUDACHECK(cudaStreamWaitEvent(opStream, opEvent, 0)); NCCLCHECK(func(comm, &tmpWork, opStream)); } @@ -203,6 +202,20 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla // } int opCnt = 0; // Counter for number of operations launched in this batch + // Record one batch-level start event on main stream and reuse it for all + // secondary operation launches in this batch. + int activeOps = 0; + activeOps += (batch->nProxyPut > 0); + activeOps += (batch->nProxyWaitSignal > 0); + activeOps += batch->nCePut; + activeOps += (batch->nCeWaitSignal > 0); + cudaEvent_t batchStartEvent = nullptr; + if (activeOps > 1) { + // Use a dedicated event slot that is not used by per-stream completion sync. + batchStartEvent = rmaCollState->rmaCollEvent[NCCL_RMA_COLL_MAX_STREAMS - 1]; + CUDACHECKGOTO(cudaEventRecord(batchStartEvent, mainStream), ret, fail); + } + // Launch the four types of RMA operations in parallel: // 1. ProxyPut NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream, @@ -213,7 +226,7 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla w.rmaArgs->nRmaTasksProxy = batch->nProxyPut; w.rmaTaskQueueProxy = batch->proxyPutQueue; }, - opCnt), ret, fail); + batchStartEvent, opCnt), ret, fail); // 2. ProxyWaitSignal NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream, @@ -223,17 +236,21 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla w.rmaArgs->nRmaTasksProxy = batch->nProxyWaitSignal; w.rmaTaskQueueProxy = batch->proxyWaitSignalQueue; }, - opCnt), ret, fail); + batchStartEvent, opCnt), ret, fail); // 3. CePut - NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream, - batch->nCePut, - ncclRmaPutCe, - [&](ncclRmaWork& w) { - w.rmaArgs->nRmaTasksCe = batch->nCePut; - w.rmaTaskQueueCe = batch->cePutQueue; - }, - opCnt), ret, fail); + for (int cePutIdx = 0; cePutIdx < batch->nCePut; cePutIdx++) { + struct ncclTaskRma* cePutTask = ncclIntruQueueDequeue(&batch->cePutQueue); + NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream, + 1, + ncclRmaPutCe, + [&](ncclRmaWork& w) { + w.rmaArgs->nRmaTasksCe = 1; + ncclIntruQueueConstruct(&w.rmaTaskQueueCe); + ncclIntruQueueEnqueue(&w.rmaTaskQueueCe, cePutTask); + }, + batchStartEvent, opCnt), ret, fail); + } // 4. CeWaitSignal NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream, @@ -243,7 +260,7 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla w.rmaArgs->nRmaTasksCe = batch->nCeWaitSignal; w.rmaTaskQueueCe = batch->ceWaitSignalQueue; }, - opCnt), ret, fail); + batchStartEvent, opCnt), ret, fail); // Synchronize all secondary streams back to main stream for (int idx = 0; idx < opCnt - 1; idx++) {