|
6 | 6 |
|
7 | 7 | #include <folly/dynamic.h> |
8 | 8 |
|
9 | | -#include "comms/ctran/Ctran.h" |
| 9 | +#include "comms/ctran/algos/AllToAll/AllToAllPImpl.h" |
10 | 10 | #include "comms/ctran/algos/AllToAll/AllToAllvDynamicPImpl.h" |
11 | 11 | #include "comms/ctran/algos/common/GpeKernel.h" |
12 | 12 | #include "comms/ctran/gpe/CtranChecksum.h" |
|
17 | 17 | #include "comms/ctran/tracing/CollTraceWrapper.h" |
18 | 18 | #include "comms/ctran/tracing/MapperTrace.h" |
19 | 19 | #include "comms/ctran/utils/Checks.h" |
20 | | -#include "comms/ctran/utils/CudaGraphUtils.h" |
21 | 20 | #include "comms/ctran/utils/CudaWrap.h" |
22 | 21 | #include "comms/ctran/utils/Debug.h" |
23 | 22 | #include "comms/ctran/utils/Exception.h" |
@@ -150,7 +149,8 @@ commResult_t CtranGpe::Impl::submit( |
150 | 149 | opFunc func, |
151 | 150 | KernelConfig& kernelConfig, |
152 | 151 | const void* ncclKernel, |
153 | | - std::optional<std::chrono::milliseconds> timeout) { |
| 152 | + std::optional<std::chrono::milliseconds> timeout, |
| 153 | + PreLaunchGraphPrepareFn graphPrepareFn) { |
154 | 154 | commResult_t res = commSuccess; |
155 | 155 |
|
156 | 156 | // Reclaim once to gain back available flags |
@@ -250,49 +250,13 @@ commResult_t CtranGpe::Impl::submit( |
250 | 250 | } |
251 | 251 | cmd->coll.comm = comm; |
252 | 252 | } |
253 | | - |
254 | 253 | if (streamCaptureInfo.status == cudaStreamCaptureStatusActive) { |
| 254 | + FB_COMMCHECK(preLaunchGraphPrepare(cmd, graphPrepareFn)); |
255 | 255 | struct cmdCbPlan* plan = new struct cmdCbPlan; |
256 | 256 | // cudagraph-aware alltoall: transfer alltoall to alltoallPersistent for |
257 | 257 | // perf optimization |
258 | 258 | auto op = cmd->coll.opGroup.front().get(); |
259 | 259 | if (NCCL_CTRAN_ALLTOALL_CUDAGRAPH_AWARE_ENABLE && |
260 | | - op->type == OpElem::opType::ALLTOALL) { |
261 | | - CtranPersistentRequest* pReq; |
262 | | - // FIXME: update alltoall API to allow passing hints to skip/not skip |
263 | | - // ctrl msg exchange. |
264 | | - meta::comms::Hints hints; |
265 | | - hints.set("ncclx_alltoallp_skip_ctrl_msg_exchange", "true"); |
266 | | - // The init will submit a GPE op exchangeMemHandle that not captured by |
267 | | - // cudagraph. |
268 | | - // FIXME: for cudagraph, the sendbuff is also persistent, should record |
269 | | - // its handle in pReq and skip searchRegHandle in exec. |
270 | | - // FIXME: the gpe thread should call algo impl instead of user API to |
271 | | - // allow more flexibility in cudagraph mode. |
272 | | - ctran::AllToAllPInit( |
273 | | - op->alltoall.recvbuff, |
274 | | - op->alltoall.count * op->comm_->statex_->nRanks(), |
275 | | - hints, |
276 | | - op->alltoall.datatype, |
277 | | - op->comm_, |
278 | | - op->stream, |
279 | | - pReq); |
280 | | - |
281 | | - // Capture alltoallp op instead of alltoall because alltoall under |
282 | | - // cudagraph is essentially alltoallp. A new alltoallp op will be |
283 | | - // submitted inside AllToAllPExec so we can return once it's done. |
284 | | - // Release kernel args grabbed earlier |
285 | | - if (kernelFlag != nullptr) { |
286 | | - kernelFlag->reset(); |
287 | | - } |
288 | | - // Add callback for alltoallp cmd instead. |
289 | | - // FIXME: the gpe thread should call algo impl instead of user API to |
290 | | - // allow more flexibility in cudagraph mode. |
291 | | - FB_COMMCHECK(ctran::AllToAllPExec( |
292 | | - op->alltoall.sendbuff, op->alltoall.count, pReq)); |
293 | | - return commSuccess; |
294 | | - } else if ( |
295 | | - NCCL_CTRAN_ALLTOALL_CUDAGRAPH_AWARE_ENABLE && |
296 | 260 | op->type == OpElem::opType::ALLTOALLV_DYNAMIC_SPLIT_NON_CONTIG) { |
297 | 261 | // FIXME: this should be controlled by hints passed from the user instead of a CVAR |
298 | 262 | // so we can have per-collective control |
|
0 commit comments