From f1700a27f0facd501b3c6f82c71b4b1b64bbe275 Mon Sep 17 00:00:00 2001 From: Yan Zhang Date: Tue, 3 Mar 2026 14:21:59 +0800 Subject: [PATCH 1/2] change ops serial execution to parallel execution --- src/rma/rma_coll.cc | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/rma/rma_coll.cc b/src/rma/rma_coll.cc index 23cd1b7..d4aea59 100644 --- a/src/rma/rma_coll.cc +++ b/src/rma/rma_coll.cc @@ -113,7 +113,7 @@ template static ncclResult_t launchRmaOpHelper(struct ncclComm* comm, struct ncclRmaCollState* rmaCollState, struct ncclRmaArgs* rmaArgs, cudaStream_t mainStream, int taskCount/*tasks of particular type*/, NcclRmaFunc_t func/*Rma funcName*/, SetWorkFn setWorkField/*Lambda for setting tmpWork*/, - int& opCnt) { + cudaEvent_t opEvent, int& opCnt) { if (taskCount <= 0) { return ncclSuccess; // no need to update opCnt } @@ -136,8 +136,7 @@ static ncclResult_t launchRmaOpHelper(struct ncclComm* comm, struct ncclRmaCollS } else { // Subsequent operations: launch on separate rmaCollStream with synchronization cudaStream_t opStream = rmaCollState->rmaCollStream[opCnt - 1]; - cudaEvent_t opEvent = rmaCollState->rmaCollEvent[opCnt - 1]; - CUDACHECK(cudaEventRecord(opEvent, mainStream)); + assert(opEvent != nullptr); CUDACHECK(cudaStreamWaitEvent(opStream, opEvent, 0)); NCCLCHECK(func(comm, &tmpWork, opStream)); } @@ -203,6 +202,20 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla // } int opCnt = 0; // Counter for number of operations launched in this batch + // Record one batch-level start event on main stream and reuse it for all + // secondary operation launches in this batch. + int activeOpTypes = 0; + activeOpTypes += (batch->nProxyPut > 0); + activeOpTypes += (batch->nProxyWaitSignal > 0); + activeOpTypes += (batch->nCePut > 0); + activeOpTypes += (batch->nCeWaitSignal > 0); + cudaEvent_t batchStartEvent = nullptr; + if (activeOpTypes > 1) { + // Use a dedicated event slot that is not used by per-stream completion sync. + batchStartEvent = rmaCollState->rmaCollEvent[NCCL_RMA_COLL_MAX_STREAMS - 1]; + CUDACHECKGOTO(cudaEventRecord(batchStartEvent, mainStream), ret, fail); + } + // Launch the four types of RMA operations in parallel: // 1. ProxyPut NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream, @@ -213,7 +226,7 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla w.rmaArgs->nRmaTasksProxy = batch->nProxyPut; w.rmaTaskQueueProxy = batch->proxyPutQueue; }, - opCnt), ret, fail); + batchStartEvent, opCnt), ret, fail); // 2. ProxyWaitSignal NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream, @@ -223,7 +236,7 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla w.rmaArgs->nRmaTasksProxy = batch->nProxyWaitSignal; w.rmaTaskQueueProxy = batch->proxyWaitSignalQueue; }, - opCnt), ret, fail); + batchStartEvent, opCnt), ret, fail); // 3. CePut NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream, @@ -233,7 +246,7 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla w.rmaArgs->nRmaTasksCe = batch->nCePut; w.rmaTaskQueueCe = batch->cePutQueue; }, - opCnt), ret, fail); + batchStartEvent, opCnt), ret, fail); // 4. CeWaitSignal NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream, @@ -243,7 +256,7 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla w.rmaArgs->nRmaTasksCe = batch->nCeWaitSignal; w.rmaTaskQueueCe = batch->ceWaitSignalQueue; }, - opCnt), ret, fail); + batchStartEvent, opCnt), ret, fail); // Synchronize all secondary streams back to main stream for (int idx = 0; idx < opCnt - 1; idx++) { From 575a57f7fcf3b89c8403d97b38ed74c219de1169 Mon Sep 17 00:00:00 2001 From: Yan Zhang Date: Wed, 4 Mar 2026 15:18:14 +0800 Subject: [PATCH 2/2] temp fix to enable full parallel. --- src/include/rma/rma.h | 2 +- src/rma/rma_coll.cc | 32 ++++++++++++++++++-------------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/include/rma/rma.h b/src/include/rma/rma.h index d045f28..8adf1c1 100644 --- a/src/include/rma/rma.h +++ b/src/include/rma/rma.h @@ -51,7 +51,7 @@ struct ncclRmaCollArgs { // Only rmaArgs, rmaTaskQueueProxy and rmaTaskQueueCe fields are used in ncclRmaWork. using ncclRmaWork = ncclKernelPlan; -constexpr int NCCL_RMA_COLL_MAX_STREAMS = 4; +constexpr int NCCL_RMA_COLL_MAX_STREAMS = 16; static_assert(NCCL_RMA_COLL_MAX_STREAMS >= 4, "NCCL_RMA_COLL_MAX_STREAMS must be at least 4"); struct ncclRmaCollState { bool initialized; diff --git a/src/rma/rma_coll.cc b/src/rma/rma_coll.cc index d4aea59..9f652e5 100644 --- a/src/rma/rma_coll.cc +++ b/src/rma/rma_coll.cc @@ -204,13 +204,13 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla // Record one batch-level start event on main stream and reuse it for all // secondary operation launches in this batch. - int activeOpTypes = 0; - activeOpTypes += (batch->nProxyPut > 0); - activeOpTypes += (batch->nProxyWaitSignal > 0); - activeOpTypes += (batch->nCePut > 0); - activeOpTypes += (batch->nCeWaitSignal > 0); + int activeOps = 0; + activeOps += (batch->nProxyPut > 0); + activeOps += (batch->nProxyWaitSignal > 0); + activeOps += batch->nCePut; + activeOps += (batch->nCeWaitSignal > 0); cudaEvent_t batchStartEvent = nullptr; - if (activeOpTypes > 1) { + if (activeOps > 1) { // Use a dedicated event slot that is not used by per-stream completion sync. batchStartEvent = rmaCollState->rmaCollEvent[NCCL_RMA_COLL_MAX_STREAMS - 1]; CUDACHECKGOTO(cudaEventRecord(batchStartEvent, mainStream), ret, fail); @@ -239,14 +239,18 @@ ncclResult_t ncclLaunchRmaColl(struct ncclComm* comm, struct ncclKernelPlan* pla batchStartEvent, opCnt), ret, fail); // 3. CePut - NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream, - batch->nCePut, - ncclRmaPutCe, - [&](ncclRmaWork& w) { - w.rmaArgs->nRmaTasksCe = batch->nCePut; - w.rmaTaskQueueCe = batch->cePutQueue; - }, - batchStartEvent, opCnt), ret, fail); + for (int cePutIdx = 0; cePutIdx < batch->nCePut; cePutIdx++) { + struct ncclTaskRma* cePutTask = ncclIntruQueueDequeue(&batch->cePutQueue); + NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream, + 1, + ncclRmaPutCe, + [&](ncclRmaWork& w) { + w.rmaArgs->nRmaTasksCe = 1; + ncclIntruQueueConstruct(&w.rmaTaskQueueCe); + ncclIntruQueueEnqueue(&w.rmaTaskQueueCe, cePutTask); + }, + batchStartEvent, opCnt), ret, fail); + } // 4. CeWaitSignal NCCLCHECKGOTO(launchRmaOpHelper(comm, rmaCollState, rmaArgs, mainStream,