Generalized Tensor Parallelism (GTP) #3005

Open
fanshiqing wants to merge 3 commits into NVIDIA:main from fanshiqing:gtp_release

@fanshiqing fanshiqing commented May 18, 2026

Design doc: GTP.docx

Description

Core idea: add Generalized Tensor Parallelism (GTP), a flexible, fine-grained scheme for sharding and just-in-time materialization of both activations and parameters, with efficient computation-communication overlap.

Mission: improve LLM pretraining efficiency through generalized tensor parallelism, enabling high performance, memory efficiency, ease of use, and strong scalability.

Summary of features

  1. Fine-grained materialization & gradient reduction
  • Weights, gradients, and optimizer states are sharded along the GTP group.
  • Weights are temporarily materialized through prefetching in both the forward and backward passes.
  2. Composability with TP / SP / EP / DDP with efficient overlapping of computation and communication
  • GEMM + TP/EP communication + GTP communication + DDP communication.
  3. GTP + partial CUDA graphs with fine-grained synchronization across graphs
  • GTP reduce-scatter overlapping across graphs.
  4. Low-precision quantize-then-gather
  • MXFP8 / NVFP4
  • Auto-padding/stripping to satisfy low-precision alignment requirements.
  5. Parallel folding for MoE layer
  • Support configuring the GTP size for dense layers and MoE layers separately.
  6. Distributed checkpointing
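
To make the first feature concrete, here is a minimal single-process sketch of the shard / all-gather / reduce-scatter cycle. It simulates the ranks with plain Python lists and is not the GTP runtime: the helper names `shard`, `all_gather`, and `reduce_scatter` are hypothetical, and no real collectives, streams, or prefetching are involved.

```python
from typing import List


def shard(weight: List[float], world_size: int) -> List[List[float]]:
    """Split a flat weight into equal contiguous shards, one per simulated rank."""
    assert len(weight) % world_size == 0, "pad the weight first if it is not divisible"
    n = len(weight) // world_size
    return [weight[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards: List[List[float]]) -> List[float]:
    """Materialize the full weight by concatenating every rank's shard."""
    return [x for s in shards for x in s]


def reduce_scatter(per_rank_grads: List[List[float]], world_size: int) -> List[List[float]]:
    """Sum full-size gradients element-wise across ranks, then hand each rank its slice."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return shard(summed, world_size)


world_size = 4
weight = [float(i) for i in range(8)]

shards = shard(weight, world_size)   # what each rank stores persistently
full = all_gather(shards)            # temporary full weight used by the GEMM

# Each simulated rank produces a full-size wgrad; rank r contributes (r + 1) everywhere.
grads = [[float(r + 1)] * len(weight) for r in range(world_size)]
grad_shards = reduce_scatter(grads, world_size)  # each rank keeps only its gradient slice
```

In this toy setup each rank holds 2 of the 8 weight elements between steps, and the reduced gradient for every element is 1 + 2 + 3 + 4 = 10.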

How Mcore interacts with TE

① Mcore registers callbacks into TE at import time.

② TE calls back into Mcore runtime during te.Linear(gtp_group=…) init AND during fwd/bwd (weight.all_gather_and_prefetch / wgrad_reduce_scatter).

③ Mcore extensions forward gtp_group= at module init.

④ TE provides FP8 / MXFP8 / NVFP4 tensor types AND the quantize-then-AG / RS collectives (gather_along_first_dim, reduce_scatter_along_first_dim) — imported by Mcore runtime; GTP wraps them with its own schedule, buffer cache, and stream choreography.
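
The callback arrangement in steps ① to ③ can be sketched in plain Python as follows. This is an illustration under assumptions, not the code in base.py: the hook names and the `(module, weight_names, gtp_group)` argument order follow the sequence diagram in this thread, while the registry dictionary, `FakeLinear`, and `mcore_wrap` are hypothetical stand-ins.

```python
from typing import Any, Callable, Dict, List, Optional

# Registry that stays empty until a framework (here: a fake "Mcore") registers hooks.
_GTP_HOOKS: Dict[str, Optional[Callable[..., Any]]] = {
    "slice": None,
    "finalize": None,
    "wrap": None,
}


def register_gtp_hooks(slice_fn, finalize_fn, wrap_fn) -> None:
    """Step 1: the framework installs its callbacks once, at import time."""
    _GTP_HOOKS.update(slice=slice_fn, finalize=finalize_fn, wrap=wrap_fn)


def maybe_wrap_gtp(module, weight_names: List[str], gtp_group) -> None:
    """Step 2: call back into the framework, but only if GTP was requested."""
    wrap_fn = _GTP_HOOKS["wrap"]
    if gtp_group is None or wrap_fn is None:
        return
    wrap_fn(module, weight_names, gtp_group)


class FakeLinear:
    """Hypothetical stand-in for a module that accepts gtp_group= at init (step 3)."""

    def __init__(self, gtp_group=None):
        self.weight_names = ["weight"]
        maybe_wrap_gtp(self, self.weight_names, gtp_group)


calls = []


def mcore_wrap(module, weight_names, gtp_group):
    # Record the call instead of building real sharded parameters.
    calls.append((type(module).__name__, tuple(weight_names), gtp_group))


register_gtp_hooks(slice_fn=None, finalize_fn=None, wrap_fn=mcore_wrap)
FakeLinear(gtp_group="gtp0")  # hook fires
FakeLinear()                  # no gtp_group, so the hook must not fire
```

The point of the indirection is that the library side never imports the framework; it only calls whatever was registered, and does nothing when no hooks or no group are present.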

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • transformer_engine/pytorch/module/base.py (+76 / −2)
    • GTP hook registry: register_gtp_hooks(), maybe_wrap_gtp()
  • transformer_engine/pytorch/module/linear.py (+72 / −2)
    • Linear(gtp_group=…) kwarg
    • fwd: optional all_gather_and_prefetch rebind and skip workspace save;
    • bwd: re-gather + wgrad_reduce_scatter + main_grad write-back guard + sharded
      wgrad_shape.
  • transformer_engine/pytorch/module/layernorm_linear.py (+60 / −5)
    • same pattern mirrored for the fused LN+Linear path
  • transformer_engine/pytorch/module/grouped_linear.py (+115 / −16)
    • GroupedLinear(gtp_group=…) + maybe_wrap_gtp(..., is_grouped=True); dual saved-tensor
      carving (with/without GTP);
    • batched_all_gather_and_prefetch + batched_all_gather_and_prefetch_bwd + batched_wgrad_reduce_scatter
  • transformer_engine/pytorch/distributed.py (+142 / −53)
    • in-place .copy_() for amax/scale_inv/data so storage addresses stay stable across CUDA-graph replay.
    • GTP runtime depends on this for prefetch overlap.
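
Why rebinding breaks replay while an in-place copy does not can be shown with a small pure-Python analogy. This is only a model of the idea, assuming a captured graph behaves like a closure that remembers the buffer object it saw at capture time; `Out`, `capture`, and the byte values are hypothetical, and no CUDA or PyTorch is involved.

```python
class Out:
    """Hypothetical output object that owns one data buffer."""

    def __init__(self, n: int):
        self.data = bytearray(n)


def capture(out: Out):
    """Pretend to capture a graph: remember the buffer object that exists right now."""
    captured = out.data

    def replay() -> bytes:
        # A replay reads from the captured storage, not from whatever out.data is today.
        return bytes(captured)

    return replay


new_values = b"\x01\x02\x03\x04"

# Rebinding: out.data now points at fresh storage, so the replay still sees the old zeros.
out_rebind = Out(4)
replay_rebind = capture(out_rebind)
out_rebind.data = bytearray(new_values)

# In-place write: the storage object is unchanged, so the replay sees the new values.
out_inplace = Out(4)
replay_inplace = capture(out_inplace)
out_inplace.data[:] = new_values
```

Here `replay_rebind()` still returns four zero bytes, while `replay_inplace()` returns the updated values, which is the behavior the in-place `.copy_()` change is meant to preserve.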

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • [] I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps Bot commented May 18, 2026

Greptile Summary

This PR adds Generalized Tensor Parallelism (GTP) to TransformerEngine's Linear, LayerNormLinear, and GroupedLinear modules, enabling fine-grained sharding of weights/gradients with compute–communication overlap via a callback-based integration with Megatron-Core.

  • GTP hook registry (register_gtp_hooks, maybe_wrap_gtp) added to base.py decouples TE from Mcore at import time; per-weight _gtp_slice_fn fires inside reset_parameters, and _gtp_wrap_fn fires at module init.
  • Forward/backward plumbing in all three Linear modules adds all_gather_and_prefetch / wgrad_reduce_scatter calls, guards FP8 quantize for sharded params, and saves sharded references for backward re-gather.
  • distributed.py switches NVFP4/MXFP8 gather post-processing to in-place .copy_() for CUDA-graph pointer stability and adds output_tensor/grouped parameters to support GTP buffer caching and batched NCCL coalescing.

Confidence Score: 3/5

The backward path in linear.py may silently discard the weight gradient produced by wgrad_reduce_scatter when GTP is active alongside Megatron's custom-DDP fuse_wgrad_accumulation, due to a missing guard that layernorm_linear.py correctly includes.

The core wgrad path in _linear_backward lacks the if gtp_size > 1: pass guard that layernorm_linear.py uses to skip the custom-DDP dummy-wgrad section. If GTPShardedParam exposes grad_added_to_main_grad (which the shape-fix comment implies it does), the RS result is overwritten with a zero/dummy tensor on every Linear backward under GTP + fuse_wgrad_accumulation, silently zeroing the weight gradient. The two modules implement the same GTP pattern differently, and one of them is wrong.

transformer_engine/pytorch/module/linear.py — the dummy-wgrad section around line 1259 needs the same if bwd_args.gtp_size > 1: pass guard present in layernorm_linear.py.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/pytorch/module/linear.py | Adds gtp_group parameter and GTP forward/backward hooks to Linear; missing GTP guard in the dummy-wgrad section may silently discard the wgrad_reduce_scatter result when GTPShardedParam exposes grad_added_to_main_grad. |
| transformer_engine/pytorch/module/layernorm_linear.py | Mirrors the Linear GTP pattern for the fused LN+Linear path; correctly guards the custom-DDP dummy-wgrad section with if ctx.gtp_size > 1: pass, which linear.py omits. |
| transformer_engine/pytorch/module/grouped_linear.py | Adds gtp_group to GroupedLinear, implements batched AG/RS hooks, and fixes the weight_names attribute ordering so it is always assigned before maybe_wrap_gtp. |
| transformer_engine/pytorch/module/base.py | Introduces the GTP hook registry and integrates the per-weight _gtp_slice_fn call inside reset_parameters; expert_idx=idx counter issue already flagged in previous review threads. |
| transformer_engine/pytorch/distributed.py | Switches NVFP4/MXFP8 gather post-processing to in-place .copy_() for CUDA-graph stability; MXFP8 grouped=True + async_op=True silently returns None handle, asymmetric with NVFP4's safe wrapper. |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Defers FC2 weight preparation to after the FC1 GEMM launch to overlap EGTP all-gather with computation; saves sharded params for backward re-gather. |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Adds EGTP backward AG/RS hooks; batched_wgrad_reduce_scatter return value is discarded before dummy wgrads are returned, inconsistent with the grouped_linear.py call site. |

Sequence Diagram

sequenceDiagram
    participant Mcore
    participant TE_Module as TE Linear/GroupedLinear
    participant GTPShardedParam
    participant NCCL

    Note over Mcore,TE_Module: Init
    Mcore->>TE_Module: register_gtp_hooks(slice_fn, finalize_fn, wrap_fn)
    TE_Module->>GTPShardedParam: _gtp_slice_fn(module, name, param, expert_idx)
    TE_Module->>Mcore: _gtp_wrap_fn(module, weight_names, gtp_group)

    Note over Mcore,NCCL: Forward Pass
    TE_Module->>GTPShardedParam: weight.setup(weight_quantizer)
    GTPShardedParam->>NCCL: "all_gather_and_prefetch(fwd=True)"
    TE_Module->>TE_Module: GEMM(input, gathered_weight)
    TE_Module->>TE_Module: save sharded param refs for bwd

    Note over Mcore,NCCL: Backward Pass
    TE_Module->>GTPShardedParam: saved_weight.all_gather_and_prefetch_bwd()
    TE_Module->>TE_Module: dgrad GEMM(grad_output, gathered_weight)
    TE_Module->>TE_Module: wgrad GEMM(input, grad_output)
    GTPShardedParam->>NCCL: wgrad_reduce_scatter(wgrad) async on rs_stream
    NCCL-->>GTPShardedParam: shard lands in main_grad buffer

Comments Outside Diff (1)

  1. transformer_engine/pytorch/module/linear.py, line 1257-1278 (link)

    P1 Missing GTP guard before dummy-wgrad section — inconsistent with layernorm_linear.py

    layernorm_linear.py line 1108 has an explicit early exit: if ctx.gtp_size > 1: pass # GTP: skip — wgrad RS already produced the correct shard. which prevents the fuse_wgrad_accumulation dummy-wgrad path from running for GTP. linear.py has no equivalent guard; when GTPShardedParam exposes a grad_added_to_main_grad attribute (for Megatron custom-DDP compatibility) and fuse_wgrad_accumulation=True, wgrad — the tensor returned by saved_weight.wgrad_reduce_scatter(wgrad) on line 1217 — is overwritten here with get_dummy_wgrad(...) and the RS result is silently discarded. The shape-fix comment ("Use the param's local shape (sharded under GTP)") confirms GTP was anticipated at this call site, making the asymmetry with layernorm_linear.py harder to explain as intentional.
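
    The control-flow difference described above can be sketched as a small pure-Python function. This is a simplified model of the reviewer's description, not the code in either module: `finish_wgrad`, its parameters, and the zero-filled list standing in for `get_dummy_wgrad(...)` are hypothetical.

```python
from typing import List


def finish_wgrad(
    wgrad_shard: List[float],
    *,
    fuse_wgrad_accumulation: bool,
    has_grad_added_to_main_grad: bool,
    gtp_size: int,
    guard_gtp: bool,
) -> List[float]:
    """Return what the backward pass would hand back as the weight gradient."""
    if guard_gtp and gtp_size > 1:
        # Guarded variant: the reduce-scatter already produced the correct shard.
        return wgrad_shard
    if fuse_wgrad_accumulation and has_grad_added_to_main_grad:
        # Custom-DDP path: replace wgrad with a dummy (zeros stand in for the real helper).
        return [0.0] * len(wgrad_shard)
    return wgrad_shard


rs_result = [0.5, -1.0]  # pretend output of the wgrad reduce-scatter

unguarded = finish_wgrad(
    rs_result,
    fuse_wgrad_accumulation=True,
    has_grad_added_to_main_grad=True,
    gtp_size=2,
    guard_gtp=False,
)
guarded = finish_wgrad(
    rs_result,
    fuse_wgrad_accumulation=True,
    has_grad_added_to_main_grad=True,
    gtp_size=2,
    guard_gtp=True,
)
```

    Under these assumptions the unguarded call returns zeros and drops the reduce-scatter result, while the guarded call returns the shard unchanged.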

Reviews (9): Last reviewed commit: "GTP + gmm fusion"

@fanshiqing

/te-ci L1 pytorch

Co-authored-by: Jieming Zhang <jiemingz@nvidia.com>
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Comment on lines 1287 to 1295
  # Fix the interleaved transposed data from gathering along first dim.
- out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
- out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
+ # In-place .copy_() (not `=` rebind) to keep the storage address stable
+ # for CUDA graph capture — replays see the same pointer they captured.
+ out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
+ out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))

- # Optionally pad the scaling inverse if needed.
- out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
+ # Optionally pad the scaling inverse if needed (same in-place pattern).
+ out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))

P1 Shape mismatch in _post_process_nvfp4_gather breaks any K not a multiple of 128

out._columnwise_scale_inv is allocated by NVFP4Quantizer.make_empty with shape (round_up(K, 128), round_up(ceil(M_total/16), 4)) — the fully-padded shape. The intermediate result from _swap_first_dims(columnwise_scale_inv_interleaved, world_size) has the unpadded shape (K_stripped, world_size * unpadded_dim1), because the gather side strips padding before the NCCL collect. When K is not a multiple of 128 (e.g. K=64 → padded to 128), the dimensions diverge and out._columnwise_scale_inv.copy_(...) raises a RuntimeError at the first all-gather call.

The pre-PR code used = rebinding, which handled arbitrary shapes. Replacing it with .copy_() is only safe when the caller pre-allocates buffers with the correct unpadded intermediate shape — which make_empty does not do. The GTP-prefetched output_tensor path has the same problem on the step-1 copy before the pad_columnwise_scale_inv call can correct things.
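
The first-dimension mismatch can be checked with a few lines of arithmetic. The sketch below only encodes the padded-shape formula quoted in this comment; `round_up` and `padded_columnwise_scale_inv_shape` are hypothetical helper names, `M_total = 256` is an arbitrary example value, and the unpadded second dimension is left out because its exact formula is not given here.

```python
import math
from typing import Tuple


def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple."""
    return ((x + multiple - 1) // multiple) * multiple


def padded_columnwise_scale_inv_shape(K: int, M_total: int) -> Tuple[int, int]:
    """Fully padded shape as stated in the review: (round_up(K, 128), round_up(ceil(M_total/16), 4))."""
    return (round_up(K, 128), round_up(math.ceil(M_total / 16), 4))


K, M_total = 64, 256
padded = padded_columnwise_scale_inv_shape(K, M_total)

# The gathered intermediate keeps the stripped first dimension K, per the review,
# so an in-place copy into the padded buffer needs padded[0] == K.
dim0_matches = padded[0] == K
```

For K = 64 the destination buffer has 128 rows while the source has 64, so the copy cannot succeed; for K = 128 the first dimensions agree.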

Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Comment on lines 1660 to +1680
@@ -1627,10 +1677,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
  with get_rng_state_tracker().fork():
      init_fn(param)

+ # GTP slice: shard the freshly-init weight into a GTPShardedParam;

P1 Wrong expert_idx for LayerNormLinear (and GroupedLinear with bias) silently disables GTP weight slicing

expert_idx=idx uses the position of the parameter in named_parameters(recurse=False), which includes non-linear-weight parameters. For LayerNormLinear the iteration order is layer_norm_weight (idx=0), layer_norm_bias (idx=1 for non-RMSNorm), weight (idx=2 or 1). The linear weight therefore arrives at _gtp_slice_fn with expert_idx=2 (or 1 for RMSNorm) instead of expert_idx=0. A Mcore hook that maps expert_idx to a pre-registered shard slot would find no entry for idx=2 and return None, silently leaving the weight un-sharded while gtp_group is set — defeating GTP for the entire LayerNormLinear path this PR explicitly adds.

Similarly, for GroupedLinear with biases enabled, weight1 receives expert_idx=2 (interleaved with bias0), so every expert beyond the first is mis-indexed.

A correct counter only advances when gtp_sharded is not None, keeping it aligned with the weight-only registration slots.
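
The two indexing schemes can be compared with a short pure-Python sketch. It is a model of the issue described above rather than the code in base.py: `fake_slice_fn` is a hypothetical hook that shards only parameters whose name starts with "weight", and the parameter name lists mirror the iteration orders given in this comment.

```python
from typing import Dict, List, Optional, Tuple


def fake_slice_fn(name: str, expert_idx: int) -> Optional[Tuple[str, int]]:
    """Hypothetical hook: shards linear weights only, returns None for everything else."""
    return (name, expert_idx) if name.startswith("weight") else None


def index_by_position(names: List[str]) -> Dict[str, int]:
    """Current behavior as described: expert_idx is the position among all parameters."""
    seen = {}
    for idx, name in enumerate(names):
        if fake_slice_fn(name, expert_idx=idx) is not None:
            seen[name] = idx
    return seen


def index_by_weight_counter(names: List[str]) -> Dict[str, int]:
    """Suggested behavior: the counter advances only when a weight was actually sharded."""
    seen, counter = {}, 0
    for name in names:
        if fake_slice_fn(name, expert_idx=counter) is not None:
            seen[name] = counter
            counter += 1
    return seen


ln_linear = ["layer_norm_weight", "layer_norm_bias", "weight", "bias"]
grouped = ["weight0", "bias0", "weight1", "bias1"]

ln_position = index_by_position(ln_linear)
ln_counter = index_by_weight_counter(ln_linear)
grouped_position = index_by_position(grouped)
grouped_counter = index_by_weight_counter(grouped)
```

With positional indexing the LayerNormLinear weight is reported as index 2 and the second grouped weight as index 2; with the weight-only counter they become 0 and 1 respectively.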

Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
@fanshiqing fanshiqing requested a review from timmoon10 as a code owner June 4, 2026 04:57