
[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API)#2443

Merged
denera merged 86 commits into NVIDIA:main from denera:common/tp-overlap-cublasmp on Jun 4, 2026

Conversation

@denera
Collaborator

@denera denera commented Dec 2, 2025

Description

This PR adds support for the NVTE cuBLASMp bindings in the Comm+GEMM overlap API.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera self-assigned this Dec 2, 2025
@denera denera force-pushed the common/tp-overlap-cublasmp branch 2 times, most recently from 908bbc2 to 69cf235 Compare December 2, 2025 20:12
Comment thread transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -17,6 +18,12 @@

#define NVTE_COMM_OVERLAP_MAX_STREAMS 3

/* \brief Check if TE is built with cuBlasMp.
Member

nit: cuBLASMp

@@ -526,6 +514,11 @@ class CommOverlapHelper : torch::CustomClassHolder {
ExtComm comm);

void ub_barrier(ExtComm comm);

int64_t get_nccl_comm_ptr(std::string comm_name) {
NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL.");
Member

This error message could be more descriptive - e.g. something like "chosen backend for the communication-computation overlap (cuBLASMp) requires NCCL communicator, but the passed ProcessGroup uses a different backend."
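
A message along the lines suggested here would name both the overlap backend and the backend the process group actually uses. The sketch below is Python rather than the C++ `NVTE_CHECK` call under review, and `require_nccl_backend` and its parameters are hypothetical names used only to show the shape of such a message.

```python
def require_nccl_backend(pg_backend: str, overlap_backend: str = "cuBLASMp") -> None:
    """Raise a descriptive error if the process group is not NCCL-backed.

    Hypothetical helper: it is not part of Transformer Engine.
    """
    if pg_backend.lower() != "nccl":
        raise RuntimeError(
            f"The chosen comm+GEMM overlap backend ({overlap_backend}) requires an "
            f"NCCL communicator, but the passed ProcessGroup uses the "
            f"'{pg_backend}' backend."
        )


# Example: a non-NCCL process group produces a message that names both sides.
try:
    require_nccl_backend("gloo")
except RuntimeError as err:
    message = str(err)
```

Naming the offending backend in the message lets the caller see which process group to fix without reading the source.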

@denera denera force-pushed the common/tp-overlap-cublasmp branch from 4596411 to b4ad546 Compare December 16, 2025 19:04
@denera denera marked this pull request as ready for review December 16, 2025 22:58
@greptile-apps
Contributor

greptile-apps Bot commented Dec 16, 2025

Greptile Summary

This PR integrates the cuBLASMp backend into the Comm+GEMM overlap API across both the PyTorch and JAX frameworks, adding new constructors, Python bindings, and test infrastructure so callers can opt in via initialize_ub(with_cublasmp=True) or --use-cublasmp flags.

  • C++ layer: Adds cuBLASMp constructors to CommOverlapCore, CommOverlapBase, and CommOverlapP2PBase that accept a raw ncclComm_t, plus a cublasmp_capture_warmup helper in the PyTorch extension layer that runs a dummy matmul to force cuBLASMp's lazy NCCL window registration and workspace allocation before any CUDA-graph capture occurs. CommOverlapHelper is extended to bootstrap and manage NCCL communicators via ncclCommInitRank when built with NVTE_WITH_CUBLASMP=1.
  • Python layer (PyTorch): initialize_ub() gains a with_cublasmp flag; module classes (Linear, LayerNormLinear, LayerNormMLP) add runtime guards that disable bulk overlaps and emit warnings for the cuBLASMp path, and the forward/backward functions are updated to read results from the GEMM output tensor rather than the extra-output buffer.
  • Python layer (JAX): cgemm_helper.cpp is updated to select between Userbuffers and cuBLASMp executors based on CgemmConfig::use_cublasmp, including use_cublasmp in the plan-ID hash for correct executor caching.

Confidence Score: 3/5

Not safe to merge as-is: several correctness issues introduced by this PR remain unaddressed across the C++, Python, and test layers.

The warmup helper leaks up to three device allocations whenever the cuBLASMp matmul or stream-sync throws, which happens inside the object constructor and therefore silently survives into production. The JAX NCCL ID file still has no per-run isolation token, so concurrent test jobs on a shared machine can corrupt each other's communicator bootstrap. The test script's non-FP8 cuBLASMp path never assigns all_outputs2 for the atomic+AG+check-numerics combination, meaning that code path would crash at runtime. The public comm_gemm_overlap.h header now pulls in nccl.h and comm_gemm.h unconditionally, forcing NCCL onto every downstream translation unit regardless of build flags, and test_comm_gemm.cu includes cublasmp.h without a compile-time guard, breaking non-cuBLASMp developer builds.
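
The warmup leak described above is the usual cleanup-on-exception problem. The actual helper is C++ and frees device memory with `cudaFree`; the Python sketch below only illustrates the pattern of tying each release to scope exit so that a failing GEMM or stream sync cannot skip it. `fake_alloc`, `fake_free`, and `warmup` are hypothetical stand-ins, and no CUDA calls are made.

```python
from contextlib import ExitStack

freed = []  # records which buffers were released, for illustration only


def fake_alloc(name: str) -> str:
    """Hypothetical stand-in for a device allocation."""
    return name


def fake_free(buf: str) -> None:
    """Hypothetical stand-in for releasing a device allocation."""
    freed.append(buf)


def warmup(fail: bool) -> None:
    with ExitStack() as stack:
        for name in ("a", "b", "d"):
            buf = fake_alloc(name)
            # Register the release right after each allocation succeeds, so it
            # runs on both normal and exceptional exit (in reverse order).
            stack.callback(fake_free, buf)
        if fail:
            raise RuntimeError("dummy matmul failed")


try:
    warmup(fail=True)
except RuntimeError:
    pass
```

In C++ the analogous tool is an RAII owner, for example a `std::unique_ptr` with a custom deleter, in place of bare free calls at the end of the function.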

transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp (warmup memory leak), transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h (unconditional NCCL headers), transformer_engine/jax/csrc/extensions/cgemm_helper.cpp (NCCL ID file race), tests/pytorch/distributed/run_gemm_with_overlap.py (all_outputs2 undefined), tests/cpp_distributed/test_comm_gemm.cu (unconditional cublasmp.h include)
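
The `all_outputs2` item is an instance of a name that is bound on only some branches. The sketch below is a hypothetical reduction of that pattern, not the logic of `run_gemm_with_overlap.py`: `collect_outputs` and its flags are invented for illustration, and the actual fix in the script may differ.

```python
def collect_outputs(quantized: bool, atomic_ag: bool):
    """Hypothetical reduction of a test's output bookkeeping."""
    all_outputs = ["main_output"]
    # Bind the second result on every path so later code can test it safely
    # instead of hitting a NameError on the branch that never assigns it.
    all_outputs2 = None
    if quantized and atomic_ag:
        all_outputs2 = ["second_output"]
    return all_outputs, all_outputs2


# The non-quantized path now yields None rather than an unbound name.
outs, outs2 = collect_outputs(quantized=False, atomic_ag=True)
has_second = outs2 is not None
```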

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp | Adds cuBLASMp bootstrap logic to CommOverlapHelper (ncclCommInitRank, shared_ptr lifetime management) and the cublasmp_capture_warmup helper; CUDA memory is freed with bare cudaFree calls that will leak on any exception thrown by the warmup GEMM or stream sync. |
| transformer_engine/pytorch/module/linear.py | Forward path gains cuBLASMp-aware output routing and an unconditional out.view() that reshapes the output for every caller (not just cuBLASMp), which changes output shape for existing 3-D input users. |
| transformer_engine/pytorch/module/base.py | initialize_ub() gains with_cublasmp parameter; correctly bypasses multicast check, disables bulk/external methods, and sets _ub_initialized flag. CommOverlapHelper now always takes two process-group args for the single-domain case. |
| transformer_engine/pytorch/module/layernorm_linear.py | Bulk-overlap flags are correctly disabled at runtime via using_cublasmp_backend(); cuBLASMp-aware output routing added for forward and backward; ln_out_total aliasing guard prevents premature deallocation. |
| transformer_engine/pytorch/module/layernorm_mlp.py | any() call now correctly takes a list; bulk flags cleared under cuBLASMp; ln_out_total alias guard added; fc2 wgrad re-gather for cuBLASMp path looks correct. |
| transformer_engine/jax/csrc/extensions/cgemm_helper.cpp | use_cublasmp is now included in the plan_id hash and the correct tp_rank accessor is used; NCCL ID file path still lacks a per-run isolation token, creating a cross-job race on shared machines. |
| tests/pytorch/distributed/run_gemm_with_overlap.py | torch.transpose and tuple-unpack bugs fixed; but all_outputs2 is still never assigned for the non-FP8 cuBLASMp path (--quantization none --atomic --comm-type AG --check-numerics), causing a NameError at line 829. |
| transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py | cuBLASMp column-parallel backward path: async x gather started before dgrad GEMM+RS, waited before wgrad; FP8/MXFP8 quantizer usage split looks correct. |
| transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h | Adds cuBLASMp constructors and _with_cublasmp/_cublasmp_ctx members; nccl.h and comm_gemm.h included unconditionally without #ifdef guards, requiring NCCL in all translation units that include this header. |
| tests/cpp_distributed/test_comm_gemm.cu | cublasmp.h included unconditionally; CMakeLists.txt appends $ENV{CUBLASMP_HOME}/include without guarding, so non-cuBLASMp builds fail to compile this test. |
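
One way to read the "per-run isolation token" remark for the NCCL ID file is sketched below. This is an assumption about a possible fix, not what `cgemm_helper.cpp` does: the function name `nccl_id_path`, the `run_token` parameter, and the file naming scheme are all hypothetical. The idea is only that every rank of one job derives the same path while different jobs on the same machine derive different ones.

```python
import os
import tempfile


def nccl_id_path(run_token: str, domain_id: int) -> str:
    """Build an ID-file path that is unique per job (hypothetical scheme).

    run_token must be identical on all ranks of one job and different across
    jobs, for example a launcher-provided job ID.
    """
    if not run_token:
        raise ValueError("run_token must be a non-empty, job-unique string")
    fname = f"te_nccl_id_{run_token}_{domain_id}.bin"
    return os.path.join(tempfile.gettempdir(), fname)


# Two ranks of the same job agree; a concurrent job gets a different file.
job_a_rank0 = nccl_id_path("job-1234", 0)
job_a_rank1 = nccl_id_path("job-1234", 0)
job_b_rank0 = nccl_id_path("job-5678", 0)
```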

Reviews (42): Last reviewed commit: "Merge remote-tracking branch 'upstream/m..."

greptile-apps[bot]

This comment was marked as outdated.

@denera denera force-pushed the common/tp-overlap-cublasmp branch from 147036f to c5471f8 Compare December 17, 2025 02:15
denera and others added 6 commits December 17, 2025 02:16
…rk extensions

Signed-off-by: Alp Dener <adener@nvidia.com>
…entirely

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from c5471f8 to d79bf21 Compare December 17, 2025 02:16
greptile-apps[bot]

This comment was marked as outdated.

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from 364b416 to ee517d3 Compare December 17, 2025 02:50
greptile-apps[bot]

This comment was marked as outdated.

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from 5cb8204 to 51b64fb Compare December 17, 2025 03:36
Comment thread transformer_engine/common/CMakeLists.txt Outdated
Comment thread transformer_engine/jax/cpp_extensions/gemm.py Outdated
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera
Collaborator Author

denera commented May 22, 2026

/te-ci pytorch

@denera denera requested a review from timmoon10 as a code owner May 27, 2026 05:14
@denera
Collaborator Author

denera commented May 27, 2026

/te-ci

@denera denera requested a review from ptrendx May 27, 2026 15:06
denera and others added 3 commits May 27, 2026 20:07
…d non-deterministic failures, removed XLA_FLAGS modifications for TE/JAX tests

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera
Collaborator Author

denera commented May 28, 2026

/te-ci L1

Comment on lines +131 to +142
void cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
bool grad, bool accumulate, cudaStream_t stream_main);

void cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
bool grad, bool accumulate, cudaStream_t stream_main);

void cublasmp_gemm_ar(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
bool grad, bool accumulate, cudaStream_t stream_main);

Member

Why do we need these functions if we already have the nvte_* calls? In general this is supposed to be a C API, so we should not add and rely on even more C++ things in it.

Collaborator Author

These exist only because the Userbuffers backend is all C++ and it's just keeping everything together.

I agree with you about the C API issue, but I think that's going to need to be a separate refactor PR entirely.

@denera
Collaborator Author

denera commented Jun 1, 2026

/te-ci L1

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from d010dad to f8261d3 Compare June 1, 2026 16:34
@denera
Collaborator Author

denera commented Jun 1, 2026

/te-ci L1

@denera
Collaborator Author

denera commented Jun 2, 2026

/te-ci L1

@denera denera merged commit 815bf36 into NVIDIA:main Jun 4, 2026
47 of 53 checks passed


7 participants