Skip to content

Commit 06b1862

Browse files
Regina8023meta-codesync[bot]
authored andcommitted
Refactor cudagraph-aware AllToAll: move cudagraph-aware logic out from GPE submit path
Summary: Refactored cudagraph-aware alltoall D77554973 to be similar to cudagraph-aware alltoallvDynamic D78133900: moved cudagraph-aware to a function, this helps simplify GPE submit function and also remove the Ctran.h dependency from ctranGPE Reviewed By: minsii Differential Revision: D83850249 fbshipit-source-id: 5516c0a1d5166fc65f12c8f349f69786b68b65fe
1 parent 4cfa7a2 commit 06b1862

File tree

10 files changed

+163
-83
lines changed

10 files changed

+163
-83
lines changed

comms/ctran/algos/AllToAll/AllToAll.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "comms/ctran/CtranComm.h"
88
#include "comms/ctran/algos/AllToAll/AllToAllImpl.h"
9+
#include "comms/ctran/algos/AllToAll/AllToAllPImpl.h"
910
#include "comms/ctran/algos/AllToAll/AllToAllvImpl.h"
1011
#include "comms/ctran/algos/CtranAlgo.h"
1112
#include "comms/ctran/gpe/CtranGpe.h"
@@ -125,12 +126,17 @@ commResult_t ctranAllToAll(
125126
std::vector<std::unique_ptr<struct OpElem>> opGroup;
126127
FB_COMMCHECK(setupGpeOp(
127128
sendbuff, recvbuff, count, datatype, comm, stream, opCount, opGroup));
128-
129+
ctran::PreLaunchGraphPrepareFn graphPrepareFn = nullptr;
130+
if (NCCL_CTRAN_ALLTOALL_CUDAGRAPH_AWARE_ENABLE) {
131+
graphPrepareFn = ctran::alltoallp::prepareCudagraphAwareAllToAll;
132+
}
129133
FB_COMMCHECK(comm->ctran_->gpe->submit(
130134
std::move(opGroup),
131135
opIbImpl,
132136
config,
133-
reinterpret_cast<void*>(ctran::alltoall::alltoallKerns[datatype])));
137+
reinterpret_cast<void*>(ctran::alltoall::alltoallKerns[datatype]),
138+
std::nullopt, /* timeout */
139+
graphPrepareFn));
134140

135141
return commSuccess;
136142
}

comms/ctran/algos/AllToAll/AllToAllP.cc

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ commResult_t AllToAllPInit(
3030
const auto nRanks = statex->nRanks();
3131

3232
SetCudaDevRAII setCudaDev(statex->cudaDev());
33-
size_t size = maxRecvCount * commTypeSize(datatype);
34-
void* regHdl{nullptr};
35-
bool localReg = false;
3633
AlgoImpl* algo = new AlgoImpl(comm, stream);
3734
if (!algo) {
3835
return commSystemError;
@@ -43,29 +40,10 @@ commResult_t AllToAllPInit(
4340
delete algo;
4441
}
4542
});
46-
// TODO: Pass-in a flag searchOnly to avoid dynamic register instead of reg
47-
// then deregister.
48-
FB_COMMCHECK(comm->ctran_->mapper->searchRegHandle(
49-
recvbuff, size, &regHdl, &localReg));
50-
if (localReg) {
51-
comm->ctran_->mapper->deregDynamic(regHdl);
52-
CLOGF(
53-
ERR,
54-
"recvbuff is not registered. Pointer: {} length: {}",
55-
recvbuff,
56-
size);
57-
return commInternalError;
58-
}
59-
6043
std::string skip_ctrl_msg;
6144
hints.get("ncclx_alltoallp_skip_ctrl_msg_exchange", skip_ctrl_msg);
62-
algo->pArgs = {
63-
.recvbuff = recvbuff,
64-
.recvHdl = regHdl,
65-
.maxRecvCount = maxRecvCount,
66-
.datatype = datatype,
67-
.skipCtrlMsg = (skip_ctrl_msg == "true"),
68-
};
45+
FB_COMMCHECK(algo->setPArgs(
46+
recvbuff, maxRecvCount, skip_ctrl_msg == "true", datatype));
6947
FB_COMMCHECK(algo->init());
7048
request = new CtranPersistentRequest(
7149
CtranPersistentRequest::Type::ALLTOALL_P, comm, stream);
@@ -77,11 +55,10 @@ commResult_t AllToAllPInit(
7755
CLOGF_SUBSYS(
7856
INFO,
7957
COLL,
80-
"AllToAllPInit: rank {} initialized request {}, recvbuff {} recvHdl {}, comm {} commHash {:x} commDesc {} [nranks={}, localRanks={}] stream={}",
58+
"AllToAllPInit: rank {} initialized request {}, recvbuff {}, comm {} commHash {:x} commDesc {} [nranks={}, localRanks={}] stream={}",
8159
statex->rank(),
8260
(void*)request,
8361
(void*)recvbuff,
84-
(void*)regHdl,
8562
(void*)comm,
8663
statex->commHash(),
8764
statex->commDesc(),

comms/ctran/algos/AllToAll/AllToAllPImpl.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Meta Platforms, Inc. and affiliates.
22

33
#include "comms/ctran/algos/AllToAll/AllToAllPImpl.h"
4+
#include "Types.h"
45
#include "comms/ctran/CtranComm.h"
56
#include "comms/ctran/algos/AllToAll/AllToAllImpl.h"
67
#include "comms/ctran/algos/AllToAll/AllToAllvImpl.h"
@@ -373,4 +374,48 @@ commResult_t AlgoImpl::exec(const void* sendbuff, const size_t count) {
373374
reinterpret_cast<void*>(ctran::alltoall::alltoallKerns[datatype])));
374375
return commSuccess;
375376
}
377+
378+
commResult_t AlgoImpl::updatePersistentFuncAndOp(
379+
opFunc& opFunc,
380+
struct OpElem* op) {
381+
opFunc = gpeFn;
382+
op->type = OpElem::opType::ALLTOALLP;
383+
op->alltoallP.sendbuff = op->alltoall.sendbuff;
384+
op->alltoallP.count = op->alltoall.count;
385+
op->alltoallP.pArgs = &pArgs;
386+
CLOGF_TRACE(
387+
COLL,
388+
"AllToAllP: rank {} updated op to {} and gpeFn to persistent version.",
389+
comm_->statex_->rank(),
390+
(void*)op);
391+
return commSuccess;
392+
}
393+
394+
commResult_t prepareCudagraphAwareAllToAll(
395+
opFunc& opFunc,
396+
struct OpElem* op,
397+
PersistentObj& pObj) {
398+
pObj = std::make_unique<AlgoImpl>(op->comm_, op->stream);
399+
auto algoImplPtr = std::get<std::unique_ptr<AlgoImpl>>(pObj).get();
400+
if (!algoImplPtr) {
401+
return commSystemError;
402+
}
403+
404+
FB_COMMCHECK(algoImplPtr->setPArgs(
405+
op->alltoall.recvbuff,
406+
op->alltoall.count * op->comm_->statex_->nRanks(),
407+
true /* skipCtrlMsg */,
408+
op->alltoall.datatype));
409+
410+
// Exchange mem handles and record in pArgs. This will not be captured
411+
// by cudagraph.
412+
FB_COMMCHECK(algoImplPtr->init());
413+
414+
// Replace gpe func by the persistent version (skip exchanging mem
415+
// handle); and OpGroup by the persistent op which has the remote
416+
// handles recorded.
417+
418+
FB_COMMCHECK(algoImplPtr->updatePersistentFuncAndOp(opFunc, op));
419+
return commSuccess;
420+
}
376421
} // namespace ctran::alltoallp

comms/ctran/algos/AllToAll/AllToAllPImpl.h

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,16 @@
33
#pragma once
44

55
#include <folly/synchronization/CallOnce.h>
6+
#include "Types.h"
67
#include "comms/ctran/CtranComm.h"
8+
#include "comms/ctran/gpe/CtranGpe.h"
9+
#include "comms/ctran/hints/Hints.h"
10+
#include "comms/ctran/mapper/CtranMapper.h"
711
#include "comms/ctran/mapper/CtranMapperTypes.h"
12+
#include "comms/ctran/utils/ExtUtils.h"
813
#include "comms/utils/cvars/nccl_cvars.h"
914

1015
namespace ctran::alltoallp {
11-
struct PersistArgs {
12-
void* recvbuff;
13-
void* recvHdl;
14-
size_t maxRecvCount;
15-
commDataType_t datatype;
16-
bool skipCtrlMsg;
17-
std::vector<void*> remoteRecvBuffs;
18-
std::vector<struct CtranMapperRemoteAccessKey> remoteAccessKeys;
19-
};
20-
2116
class AlgoImpl {
2217
public:
2318
PersistArgs pArgs;
@@ -30,6 +25,40 @@ class AlgoImpl {
3025

3126
commResult_t exec(const void* sendbuff, const size_t count);
3227

28+
inline commResult_t setPArgs(
29+
void* recvbuff,
30+
const size_t maxRecvCount,
31+
bool skipCtrlMsg,
32+
commDataType_t datatype) {
33+
size_t size = maxRecvCount * commTypeSize(datatype);
34+
void* regHdl{nullptr};
35+
bool localReg = false;
36+
// TODO: Pass-in a flag searchOnly to avoid dynamic register instead of reg
37+
// then deregister.
38+
FB_COMMCHECK(comm_->ctran_->mapper->searchRegHandle(
39+
recvbuff, size, &regHdl, &localReg));
40+
if (localReg) {
41+
comm_->ctran_->mapper->deregDynamic(regHdl);
42+
CLOGF(
43+
ERR,
44+
"recvbuff is not registered. Pointer: {} length: {}",
45+
recvbuff,
46+
size);
47+
return commInternalError;
48+
}
49+
50+
pArgs = {
51+
.recvbuff = recvbuff,
52+
.recvHdl = regHdl,
53+
.maxRecvCount = maxRecvCount,
54+
.datatype = datatype,
55+
.skipCtrlMsg = skipCtrlMsg,
56+
};
57+
return commSuccess;
58+
}
59+
60+
commResult_t updatePersistentFuncAndOp(opFunc& opFunc, struct OpElem* op);
61+
3362
static inline const std::string algoName(enum NCCL_ALLTOALL_ALGO algo) {
3463
switch (algo) {
3564
case NCCL_ALLTOALL_ALGO::ctran:
@@ -43,4 +72,9 @@ class AlgoImpl {
4372
CtranComm* comm_{nullptr};
4473
cudaStream_t stream_{nullptr};
4574
};
75+
76+
commResult_t prepareCudagraphAwareAllToAll(
77+
opFunc& opFunc,
78+
struct OpElem* op,
79+
PersistentObj& pObj);
4680
} // namespace ctran::alltoallp

comms/ctran/algos/AllToAll/Types.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
#pragma once
4+
#include <vector>
5+
6+
#include "comms/ctran/mapper/CtranMapperTypes.h"
7+
#include "comms/utils/commSpecs.h"
8+
9+
namespace ctran::alltoallp {
10+
struct PersistArgs {
11+
void* recvbuff;
12+
void* recvHdl;
13+
size_t maxRecvCount;
14+
commDataType_t datatype;
15+
bool skipCtrlMsg;
16+
std::vector<void*> remoteRecvBuffs;
17+
std::vector<struct CtranMapperRemoteAccessKey> remoteAccessKeys;
18+
};
19+
20+
class AlgoImpl;
21+
} // namespace ctran::alltoallp

comms/ctran/gpe/CtranGpe.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,14 +373,16 @@ commResult_t CtranGpe::submit(
373373
opFunc func,
374374
KernelConfig& kernelConfig,
375375
const void* ncclKernel,
376-
std::optional<std::chrono::milliseconds> timeout) {
376+
std::optional<std::chrono::milliseconds> timeout,
377+
PreLaunchGraphPrepareFn graphPrepareFn) {
377378
return this->pimpl->submit(
378379
CtranGpeCmd::TypeEnum::GRAPH_ENQUEUE,
379380
std::move(opGroup),
380381
func,
381382
kernelConfig,
382383
ncclKernel,
383-
timeout);
384+
timeout,
385+
graphPrepareFn);
384386
}
385387

386388
commResult_t CtranGpe::submitHost(

comms/ctran/gpe/CtranGpe.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "comms/ctran/CtranComm.h"
1414
#include "comms/ctran/CtranExImpl.h"
15+
#include "comms/ctran/algos/AllToAll/Types.h"
1516
#include "comms/ctran/algos/CtranAlgoDev.h"
1617
#include "comms/ctran/algos/common/GpeKernelSync.h"
1718
#include "comms/ctran/gpe/CtranGpeDev.h"
@@ -20,6 +21,13 @@
2021
typedef commResult_t (*opFunc)(
2122
const std::vector<std::unique_ptr<struct OpElem>>& opGroup);
2223

24+
namespace ctran {
25+
using PersistentObj =
26+
std::variant<std::monostate, std::unique_ptr<ctran::alltoallp::AlgoImpl>>;
27+
using PreLaunchGraphPrepareFn =
28+
commResult_t (*)(opFunc& opFunc, struct OpElem* op, PersistentObj& pObj);
29+
} // namespace ctran
30+
2331
struct OpElem {
2432
enum opType {
2533
ALLGATHER,
@@ -369,7 +377,8 @@ class CtranGpe {
369377
opFunc func,
370378
KernelConfig& kernelConfig,
371379
const void* ncclKernel,
372-
std::optional<std::chrono::milliseconds> timeout = std::nullopt);
380+
std::optional<std::chrono::milliseconds> timeout = std::nullopt,
381+
ctran::PreLaunchGraphPrepareFn graphPrepareFn = nullptr);
373382

374383
// Submit host mem communication. No kernel is launched, and only the host
375384
// side func will be submitted to the GPE thread. Also the op won't be

comms/ctran/gpe/CtranGpeImpl.cc

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#include <folly/dynamic.h>
88

9-
#include "comms/ctran/Ctran.h"
9+
#include "comms/ctran/algos/AllToAll/AllToAllPImpl.h"
1010
#include "comms/ctran/algos/AllToAll/AllToAllvDynamicPImpl.h"
1111
#include "comms/ctran/algos/common/GpeKernel.h"
1212
#include "comms/ctran/gpe/CtranChecksum.h"
@@ -17,7 +17,6 @@
1717
#include "comms/ctran/tracing/CollTraceWrapper.h"
1818
#include "comms/ctran/tracing/MapperTrace.h"
1919
#include "comms/ctran/utils/Checks.h"
20-
#include "comms/ctran/utils/CudaGraphUtils.h"
2120
#include "comms/ctran/utils/CudaWrap.h"
2221
#include "comms/ctran/utils/Debug.h"
2322
#include "comms/ctran/utils/Exception.h"
@@ -150,7 +149,8 @@ commResult_t CtranGpe::Impl::submit(
150149
opFunc func,
151150
KernelConfig& kernelConfig,
152151
const void* ncclKernel,
153-
std::optional<std::chrono::milliseconds> timeout) {
152+
std::optional<std::chrono::milliseconds> timeout,
153+
PreLaunchGraphPrepareFn graphPrepareFn) {
154154
commResult_t res = commSuccess;
155155

156156
// Reclaim once to gain back available flags
@@ -250,49 +250,13 @@ commResult_t CtranGpe::Impl::submit(
250250
}
251251
cmd->coll.comm = comm;
252252
}
253-
254253
if (streamCaptureInfo.status == cudaStreamCaptureStatusActive) {
254+
FB_COMMCHECK(preLaunchGraphPrepare(cmd, graphPrepareFn));
255255
struct cmdCbPlan* plan = new struct cmdCbPlan;
256256
// cudagraph-aware alltoall: transfer alltoall to alltoallPersistent for
257257
// perf optimization
258258
auto op = cmd->coll.opGroup.front().get();
259259
if (NCCL_CTRAN_ALLTOALL_CUDAGRAPH_AWARE_ENABLE &&
260-
op->type == OpElem::opType::ALLTOALL) {
261-
CtranPersistentRequest* pReq;
262-
// FIXME: update alltoall API to allow passing hints to skip/not skip
263-
// ctrl msg exchange.
264-
meta::comms::Hints hints;
265-
hints.set("ncclx_alltoallp_skip_ctrl_msg_exchange", "true");
266-
// The init will submit a GPE op exchangeMemHandle that not captured by
267-
// cudagraph.
268-
// FIXME: for cudagraph, the sendbuff is also persistent, should record
269-
// its handle in pReq and skip searchRegHandle in exec.
270-
// FIXME: the gpe thread should call algo impl instead of user API to
271-
// allow more flexibility in cudagraph mode.
272-
ctran::AllToAllPInit(
273-
op->alltoall.recvbuff,
274-
op->alltoall.count * op->comm_->statex_->nRanks(),
275-
hints,
276-
op->alltoall.datatype,
277-
op->comm_,
278-
op->stream,
279-
pReq);
280-
281-
// Capture alltoallp op instead of alltoall because alltoall under
282-
// cudagraph is essentially alltoallp. A new alltoallp op will be
283-
// submitted inside AllToAllPExec so we can return once it's done.
284-
// Release kernel args grabbed earlier
285-
if (kernelFlag != nullptr) {
286-
kernelFlag->reset();
287-
}
288-
// Add callback for alltoallp cmd instead.
289-
// FIXME: the gpe thread should call algo impl instead of user API to
290-
// allow more flexibility in cudagraph mode.
291-
FB_COMMCHECK(ctran::AllToAllPExec(
292-
op->alltoall.sendbuff, op->alltoall.count, pReq));
293-
return commSuccess;
294-
} else if (
295-
NCCL_CTRAN_ALLTOALL_CUDAGRAPH_AWARE_ENABLE &&
296260
op->type == OpElem::opType::ALLTOALLV_DYNAMIC_SPLIT_NON_CONTIG) {
297261
// FIXME: this should control by hints passed from user instead of CVAR
298262
// so we can have per-collective control

0 commit comments

Comments
 (0)