[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API) #2443
908bbc2 to 69cf235
@@ -17,6 +18,12 @@

#define NVTE_COMM_OVERLAP_MAX_STREAMS 3

/* \brief Check if TE is built with cuBlasMp.
@@ -526,6 +514,11 @@ class CommOverlapHelper : torch::CustomClassHolder {
      ExtComm comm);

  void ub_barrier(ExtComm comm);

  int64_t get_nccl_comm_ptr(std::string comm_name) {
    NVTE_CHECK(backend_is_nccl, "Cannot get ncclComm_t ptr if backend is not NCCL.");
This error message could be more descriptive, e.g. something like "The chosen backend for the communication-computation overlap (cuBLASMp) requires an NCCL communicator, but the passed ProcessGroup uses a different backend."
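The suggestion above can be sketched roughly as follows. This is a minimal, self-contained illustration only: `nccl_backend_error` and `check_nccl_backend` are hypothetical names that do not exist in Transformer Engine, and `std::runtime_error` stands in for `NVTE_CHECK` so the snippet compiles on its own.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical helper (not part of Transformer Engine): builds the kind of
// descriptive message the reviewer asks for, naming the backend actually in use.
inline std::string nccl_backend_error(const std::string &actual_backend) {
  return "The chosen backend for communication-computation overlap (cuBLASMp) "
         "requires an NCCL communicator, but the passed ProcessGroup uses the '" +
         actual_backend + "' backend.";
}

// Stand-in for NVTE_CHECK: throws when the process group is not NCCL-backed.
inline void check_nccl_backend(bool backend_is_nccl, const std::string &actual_backend) {
  if (!backend_is_nccl) {
    throw std::runtime_error(nccl_backend_error(actual_backend));
  }
}
```

Including the offending backend name in the message is an assumption about what would be most useful to the caller; the reviewer's wording only says "a different backend".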
4596411 to b4ad546
Greptile Summary
This PR integrates the cuBLASMp backend into the Comm+GEMM overlap API across both the PyTorch and JAX frameworks, adding new constructors, Python bindings, and test infrastructure so callers can opt in via
Confidence Score: 3/5
Not safe to merge as-is: several correctness issues introduced by this PR remain unaddressed across the C++, Python, and test layers.

The warmup helper leaks up to three device allocations whenever the cuBLASMp matmul or stream-sync throws, which happens inside the object constructor and therefore silently survives into production.
The JAX NCCL ID file still has no per-run isolation token, so concurrent test jobs on a shared machine can corrupt each other's communicator bootstrap.
The test script's non-FP8 cuBLASMp path never assigns all_outputs2 for the atomic+AG+check-numerics combination, meaning that code path would crash at runtime.
The public comm_gemm_overlap.h header now pulls in nccl.h and comm_gemm.h unconditionally, forcing NCCL onto every downstream translation unit regardless of build flags, and test_comm_gemm.cu includes cublasmp.h without a compile-time guard, breaking non-cuBLASMp developer builds.

transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp (warmup memory leak)
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h (unconditional NCCL headers)
transformer_engine/jax/csrc/extensions/cgemm_helper.cpp (NCCL ID file race)
tests/pytorch/distributed/run_gemm_with_overlap.py (all_outputs2 undefined)
tests/cpp_distributed/test_comm_gemm.cu (unconditional cublasmp.h include)

Important Files Changed
Reviews (42): Last reviewed commit: "Merge remote-tracking branch 'upstream/m..."
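The warmup leak described in the summary is the kind of problem that owning handles usually address. The sketch below illustrates the general RAII pattern only and is not the PR's code: `warmup`, `DeviceBuffer`, and `fake_device_alloc` are hypothetical names, and `std::malloc`/`std::free` stand in for `cudaMalloc`/`cudaFree` so the example builds without CUDA.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <new>
#include <stdexcept>

// Counts live buffers so the example can show that nothing leaks.
inline int &live_buffers() {
  static int count = 0;
  return count;
}

// Stand-in for cudaMalloc (assumption: real code would check the CUDA status).
inline void *fake_device_alloc(std::size_t bytes) {
  void *ptr = std::malloc(bytes);
  if (ptr == nullptr) throw std::bad_alloc();
  ++live_buffers();
  return ptr;
}

// Stand-in for cudaFree, used as the unique_ptr deleter.
struct FakeDeviceDeleter {
  void operator()(void *ptr) const {
    std::free(ptr);
    --live_buffers();
  }
};

using DeviceBuffer = std::unique_ptr<void, FakeDeviceDeleter>;

// Hypothetical warmup routine: allocates three scratch buffers, then runs a
// step that may throw (standing in for the cuBLASMp matmul or stream sync).
inline void warmup(bool fail) {
  DeviceBuffer a(fake_device_alloc(256));
  DeviceBuffer b(fake_device_alloc(256));
  DeviceBuffer d(fake_device_alloc(256));
  if (fail) throw std::runtime_error("warmup matmul failed");
  // On both the normal and the exceptional path, the unique_ptr destructors
  // release all three buffers when this scope unwinds.
}
```

Because each buffer is owned as soon as it is allocated, an exception thrown from a later allocation or from the warmup step itself still releases everything acquired so far.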
147036f to c5471f8
…rk extensions Signed-off-by: Alp Dener <adener@nvidia.com>
…entirely Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
c5471f8 to d79bf21
364b416 to ee517d3
Signed-off-by: Alp Dener <adener@nvidia.com>
5cb8204 to 51b64fb
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <adener@nvidia.com>
/te-ci pytorch
Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <adener@nvidia.com>
…ormerEngine into common/tp-overlap-cublasmp
for more information, see https://pre-commit.ci
/te-ci
…d non-deterministic failures, removed XLA_FLAGS modifications for TE/JAX tests Signed-off-by: Alp Dener <adener@nvidia.com>
…ormerEngine into common/tp-overlap-cublasmp
for more information, see https://pre-commit.ci
/te-ci L1
void cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
                      TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                      bool grad, bool accumulate, cudaStream_t stream_main);

void cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
                      TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                      bool grad, bool accumulate, cudaStream_t stream_main);

void cublasmp_gemm_ar(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
                      TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                      bool grad, bool accumulate, cudaStream_t stream_main);
Why do we need these functions if we already have the nvte_* calls? In general, this is supposed to be a C API, so we should not add even more C++ to it or come to rely on it.
These exist only because the Userbuffers backend is all C++ and it's just keeping everything together.
I agree with you about the C API issue, but I think that's going to need to be a separate refactor PR entirely.
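For context on the C API concern raised in this thread, one common way to keep a public surface in C while the implementation stays in C++ is an opaque handle with C-linkage functions. The sketch below is a generic illustration of that pattern, not Transformer Engine's actual API: `ExampleOverlap` and every `example_overlap_*` name are hypothetical.

```cpp
#include <cassert>

// Hypothetical C++ implementation class (stands in for a C++-only backend).
class ExampleOverlap {
 public:
  explicit ExampleOverlap(int tp_size) : tp_size_(tp_size) {}
  int tp_size() const { return tp_size_; }

 private:
  int tp_size_;
};

// C-linkage surface: callers see only an opaque handle and plain functions,
// so no C++ types leak into the public header.
extern "C" {
typedef struct ExampleOverlapOpaque *ExampleOverlapHandle;

ExampleOverlapHandle example_overlap_create(int tp_size) {
  return reinterpret_cast<ExampleOverlapHandle>(new ExampleOverlap(tp_size));
}

int example_overlap_tp_size(ExampleOverlapHandle handle) {
  return reinterpret_cast<ExampleOverlap *>(handle)->tp_size();
}

void example_overlap_destroy(ExampleOverlapHandle handle) {
  delete reinterpret_cast<ExampleOverlap *>(handle);
}
}  // extern "C"
```

A production wrapper would also need to catch C++ exceptions before they cross the C boundary; that is omitted here to keep the sketch short.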
…tstrapping cuBLASMp Signed-off-by: Alp Dener <adener@nvidia.com>
…ormerEngine into common/tp-overlap-cublasmp
/te-ci L1
Signed-off-by: Alp Dener <adener@nvidia.com>
d010dad to f8261d3
/te-ci L1
Signed-off-by: Alp Dener <adener@nvidia.com>
/te-ci L1
Description
This PR adds support for the NVTE cuBlasMp bindings in the Comm+GEMM overlap API.
Type of change
Checklist: