Generalized Tensor Parallelism (GTP) #3005

Open
fanshiqing wants to merge 7 commits into NVIDIA:main from fanshiqing:gtp_release

@fanshiqing fanshiqing commented May 18, 2026

Member

Design doc: GTP.docx

Description

Core idea: add Generalized Tensor Parallelism (GTP), a flexible, fine-grained scheme that shards both activations and parameters and materializes them just in time, with efficient computation-communication overlap.

Mission: improve LLM pretraining efficiency through generalized tensor parallelism, enabling high performance, memory efficiency, ease of use, and strong scalability.

Summary of features

  1. Fine-grained materialization & gradient reduction
    • Weight, gradient, and optimizer states are sharded along the GTP group.
    • Weights are temporarily materialized through prefetching in both the forward and backward passes.
  2. Composability with TP / SP / EP / DDP with efficient overlapping of computation and communication
    • GEMM + TP/EP communication + GTP communication + DDP communication.
  3. GTP + partial CUDA graphs with fine-grained synchronization across graphs
    • GTP reduce-scatter overlapping across graphs.
  4. Low-precision quantize-then-gather
    • MXFP8 / NVFP4
    • Auto-padding/stripping to satisfy low-precision alignment requirements.
  5. Parallel folding for MoE layer
    • Support configuring the GTP size for dense layers and MoE layers separately.
  6. Distributed checkpointing
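
The first feature above (shard along the GTP group, materialize the weight just before use, reduce-scatter the weight gradient) can be illustrated in a single process with plain Python lists. This is a minimal sketch only: `shard_rows`, `all_gather`, and `reduce_scatter` are hypothetical helpers, not TE or Mcore APIs, and the real implementation issues these steps as asynchronous collectives overlapped with compute.

```python
# Single-process illustration of shard -> all-gather -> reduce-scatter.
# All helper names are hypothetical; no TE/Mcore API is used here.

def shard_rows(weight, world_size):
    """Split a row-major matrix into equal row shards, one per rank."""
    rows_per_rank = len(weight) // world_size
    assert rows_per_rank * world_size == len(weight), "rows must divide evenly"
    return [weight[r * rows_per_rank:(r + 1) * rows_per_rank]
            for r in range(world_size)]

def all_gather(shards):
    """Materialize the full weight by concatenating shards along dim 0."""
    return [row for shard in shards for row in shard]

def reduce_scatter(per_rank_grads, world_size):
    """Sum full-sized gradients across ranks, then keep one row shard per rank."""
    n_rows, n_cols = len(per_rank_grads[0]), len(per_rank_grads[0][0])
    summed = [[sum(g[i][j] for g in per_rank_grads) for j in range(n_cols)]
              for i in range(n_rows)]
    return shard_rows(summed, world_size)

world_size = 2
weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

shards = shard_rows(weight, world_size)   # persistent state: one shard per rank
full = all_gather(shards)                 # temporary copy, needed only for the GEMM

# Each rank computes a full-sized wgrad; here rank r contributes (r + 1) everywhere.
wgrads = [[[float(r + 1)] * 2 for _ in range(4)] for r in range(world_size)]
grad_shards = reduce_scatter(wgrads, world_size)  # each rank keeps only its shard
```

Only `shards` and `grad_shards` persist per rank in this picture; `full` and the full-sized `wgrads` are the temporaries that just-in-time materialization is meant to keep short-lived.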

How Mcore interacts with TE

① Mcore registers callbacks into TE at import time.

② TE calls back into Mcore runtime during te.Linear(gtp_group=…) init AND during fwd/bwd (weight.all_gather_and_prefetch / wgrad_reduce_scatter).

③ Mcore extensions forward gtp_group= at module init.

④ TE provides FP8 / MXFP8 / NVFP4 tensor types AND the quantize-then-AG / RS collectives (gather_along_first_dim, reduce_scatter_along_first_dim) — imported by Mcore runtime; GTP wraps them with its own schedule, buffer cache, and stream choreography.
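
The callback flow in steps ① to ③ can be sketched as a small hook registry. The names `register_gtp_hooks` and `maybe_wrap_gtp` appear in this PR, but the bodies, argument lists, `ToyLinear`, and `runtime_wrap` below are assumptions made for illustration, not the actual TE or Mcore code.

```python
from typing import Callable, Optional

# "Library" side: hook slots stay empty until a runtime registers callbacks.
_gtp_slice_fn: Optional[Callable] = None
_gtp_finalize_fn: Optional[Callable] = None
_gtp_wrap_fn: Optional[Callable] = None

def register_gtp_hooks(slice_fn, finalize_fn, wrap_fn):
    """Step 1: the runtime hands its callbacks to the library (assumed signature)."""
    global _gtp_slice_fn, _gtp_finalize_fn, _gtp_wrap_fn
    _gtp_slice_fn, _gtp_finalize_fn, _gtp_wrap_fn = slice_fn, finalize_fn, wrap_fn

def maybe_wrap_gtp(module, weight_names, gtp_group):
    """Step 2: call back into the runtime only when GTP is requested and registered."""
    if gtp_group is None or _gtp_wrap_fn is None:
        return module
    return _gtp_wrap_fn(module, weight_names, gtp_group)

class ToyLinear:
    """Hypothetical stand-in for a module that accepts gtp_group= at init (step 3)."""
    def __init__(self, gtp_group=None):
        self.gtp_group = gtp_group
        self.wrapped_by = None
        maybe_wrap_gtp(self, ["weight"], gtp_group)

def runtime_wrap(module, weight_names, gtp_group):
    """Hypothetical runtime-side hook: records that it saw the module."""
    module.wrapped_by = ("runtime", tuple(weight_names), gtp_group)
    return module

plain = ToyLinear()  # no gtp_group: the module is left untouched
register_gtp_hooks(
    slice_fn=lambda param, expert_idx: None,  # placeholder, unused in this sketch
    finalize_fn=lambda module: None,          # placeholder, unused in this sketch
    wrap_fn=runtime_wrap,
)
sharded = ToyLinear(gtp_group="gtp0")  # the hook fires during init
```

Keeping the registry on the library side means the library never has to import the runtime; the dependency points one way, and modules without `gtp_group` behave as before.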

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • transformer_engine/pytorch/module/base.py (+76 / −2)
    • GTP hook registry: register_gtp_hooks(), maybe_wrap_gtp()
  • transformer_engine/pytorch/module/linear.py (+72 / −2)
    • Linear(gtp_group=…) kwarg
    • fwd: optional all_gather_and_prefetch rebind and skip workspace save;
    • bwd: re-gather + wgrad_reduce_scatter + main_grad write-back guard + sharded
      wgrad_shape.
  • transformer_engine/pytorch/module/layernorm_linear.py (+60 / −5)
    • same pattern mirrored for the fused LN+Linear path
  • transformer_engine/pytorch/module/grouped_linear.py (+115 / −16)
    • GroupedLinear(gtp_group=…) + maybe_wrap_gtp(..., is_grouped=True); dual saved-tensor
      carving (with/without GTP);
    • batched_all_gather_and_prefetch + batched_all_gather_and_prefetch_bwd + batched_wgrad_reduce_scatter
  • transformer_engine/pytorch/distributed.py (+142 / −53)
    • in-place .copy_() for amax/scale_inv/data so storage addresses stay stable across CUDA-graph replay.
    • GTP runtime depends on this for prefetch overlap.
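
The difference between rebinding an attribute and copying in place, which the `distributed.py` change relies on, can be shown with an analogy in plain Python. The `Out` class and both update functions are hypothetical; a list stands in for device storage and a saved reference stands in for the pointer a CUDA graph captured.

```python
class Out:
    """Hypothetical stand-in for a tensor wrapper holding one data buffer."""
    def __init__(self, n):
        self.data = [0.0] * n

def update_rebind(out, new_values):
    # Attribute rebinding: `out.data` now refers to new storage,
    # and the previously captured buffer is left untouched.
    out.data = list(new_values)

def update_in_place(out, new_values):
    # In-place copy: same storage, new contents. Needs matching sizes.
    assert len(new_values) == len(out.data), "in-place copy needs matching shape"
    out.data[:] = new_values

out_a, out_b = Out(4), Out(4)
captured_a, captured_b = out_a.data, out_b.data  # what a "replay" would read

update_rebind(out_a, [1.0, 2.0, 3.0, 4.0])
update_in_place(out_b, [1.0, 2.0, 3.0, 4.0])

rebind_visible = captured_a == [1.0, 2.0, 3.0, 4.0]    # captured buffer is stale
in_place_visible = captured_b == [1.0, 2.0, 3.0, 4.0]  # captured buffer was updated
```

The size assertion in `update_in_place` is the price of address stability: the destination must already have the right shape, whereas rebinding accepts any shape.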

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • [] I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented May 18, 2026

Contributor

Greptile Summary

This PR introduces Generalized Tensor Parallelism (GTP) across TE's Linear, LayerNormLinear, GroupedLinear, and fused GroupedMLP modules, enabling fine-grained sharding and just-in-time weight materialization with computation/communication overlap. The changes are substantial — adding gtp_group kwargs to module constructors, GTP hook callbacks in base.py, CUDA-graph-safe in-place .copy_() semantics in distributed.py, and new batched_* gather/RS paths throughout.

  • distributed.py adds output= / output_tensor= / grouped= parameters to the RS and AG helpers, switches from attribute rebinding to in-place .copy_() for CUDA-graph stability, and introduces _NVFP4AllGatherAsyncHandle.post_process_nvfp4_gather() for caller-driven post-processing.
  • module/base.py registers GTP hook slots (_gtp_slice_fn, _gtp_finalize_fn, _gtp_wrap_fn) and wires them into reset_parameters, passing expert_idx=idx (the raw enumeration counter) to the slice hook for each parameter.
  • module/linear.py and layernorm_linear.py mirror GTP all-gather + RS paths in their forward/backward functions; layernorm_linear.py correctly guards the dummy-wgrad fuse_wgrad_accumulation block with if ctx.gtp_size > 1: pass, but linear.py is missing the equivalent guard.
  • grouped_mlp.py moves FC2 weight preparation after the FC1 activation kernel (required for "one weight live at a time" GTP memory discipline), adds fine-grained per-op activation-offload markers, and places the FC1 backward GTP all-gather inside the if use_nvfp4: … else: branch, leaving the NVFP4 FC1 dgrad path without a weight gather.

Confidence Score: 3/5

Several previously-flagged defects remain unresolved in the changed files: the wrong expert_idx counter in base.py, the missing GTP dummy-wgrad guard in linear.py, the FC1 dgrad path in grouped_mlp.py that skips the weight all-gather for NVFP4, and the shape-mismatch and double-processing risks in distributed.py. These affect core forward/backward correctness on the GTP paths.

The TE-side hook registry and module-level plumbing are straightforward additions that look mechanically correct. The in-place .copy_() discipline in distributed.py is a sound approach for CUDA-graph pointer stability. However, multiple previously-identified defects in the same changed files have not been addressed: the expert_idx counter includes non-weight parameters (breaking GTP sharding for LayerNormLinear), linear.py silently discards the RS handle when fuse_wgrad_accumulation=True and the weight carries grad_added_to_main_grad, grouped_mlp.py produces wrong FC1 input gradients when use_nvfp4=True with GTP active, and distributed.py has a potential shape mismatch on the columnwise_scale_inv copy path and a double-execution risk in post_process_nvfp4_gather.

transformer_engine/pytorch/module/linear.py (missing GTP guard before dummy-wgrad block), transformer_engine/pytorch/module/base.py (expert_idx counter includes non-weight params), transformer_engine/pytorch/ops/fused/grouped_mlp.py (FC1 nvfp4 dgrad path skips weight all-gather), transformer_engine/pytorch/distributed.py (columnwise_scale_inv shape mismatch and _synchronized not set in post_process_nvfp4_gather)

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/distributed.py | Adds output-buffer and grouped-coalescing parameters to gather/RS helpers; switches amax/data/scale_inv to in-place .copy_() for CUDA-graph pointer stability. Shape-mismatch risk when K is not a multiple of 128 (flagged previously) and a double-processing path via post_process_nvfp4_gather (flagged previously) remain unresolved. |
| transformer_engine/pytorch/module/base.py | Adds GTP hook registry and wires _gtp_slice_fn into reset_parameters; expert_idx passed as the raw parameter-enumeration index, which includes non-weight params (LN weight/bias), so weights in LayerNormLinear and biased GroupedLinear arrive at the slice hook with incorrect indices (flagged previously). |
| transformer_engine/pytorch/module/linear.py | Adds gtp_group kwarg and GTP all-gather/RS paths to forward and backward; fuse_wgrad_accumulation dummy-wgrad block missing GTP guard that layernorm_linear.py has, risking silent discard of the RS handle when fuse_wgrad_accumulation=True (flagged previously). |
| transformer_engine/pytorch/module/layernorm_linear.py | Mirrors the GTP all-gather + wgrad RS pattern from linear.py and correctly guards the dummy-wgrad fuse_wgrad_accumulation block; GTP wiring looks consistent with linear.py modulo the intentional LN weight exclusion. |
| transformer_engine/pytorch/module/grouped_linear.py | Adds gtp_group kwarg; weight_names now assigned unconditionally (fixes prior AttributeError); GTP batched gather/RS logic in forward/backward looks consistent. Moves GTP all-gather before dgrad GEMM outside the requires_dgrad guard, which is correct for pipeline prefetch. |
| transformer_engine/pytorch/ops/fused/grouped_mlp.py | FC2 weight preparation moved after FC1 kernel for GTP memory discipline; fine-grained activation offload markers added. FC1 backward GTP all-gather placed only in the MXFP8 branch, leaving the NVFP4 FC1 dgrad path using raw sharded weights (flagged previously). |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Mcore as Mcore Runtime
    participant TE_Module as TE Module (Linear/GroupedLinear)
    participant Dist as distributed.py helpers
    participant NCCL as NCCL/GPU

    Note over Mcore,TE_Module: Module init
    Mcore->>TE_Module: register_gtp_hooks(slice_fn, finalize_fn, wrap_fn)
    TE_Module->>TE_Module: "reset_parameters() → _gtp_slice_fn(expert_idx=idx)"
    TE_Module->>Mcore: _gtp_wrap_fn(module, weight_names, gtp_group)

    Note over TE_Module,NCCL: Forward pass
    TE_Module->>Mcore: "weight.all_gather_and_prefetch(fwd=True)"
    Mcore->>Dist: "gather_along_first_dim(..., grouped=True)"
    Dist->>NCCL: all_gather_into_tensor (under outer coalescing mgr)
    Dist-->>Mcore: _NVFP4AllGatherAsyncHandle / MXFP8 handle
    TE_Module->>TE_Module: GEMM with gathered weight
    TE_Module->>TE_Module: save sharded GTPShardedParam (not gathered weight)

    Note over TE_Module,NCCL: Backward pass
    TE_Module->>Mcore: saved_weight.all_gather_and_prefetch_bwd()
    Mcore->>NCCL: async all_gather (columnwise layout for dgrad)
    TE_Module->>TE_Module: dgrad GEMM
    TE_Module->>TE_Module: wgrad GEMM (full-sized)
    TE_Module->>Mcore: saved_weight.wgrad_reduce_scatter(wgrad)
    Mcore->>NCCL: async reduce_scatter → main_grad shard
    TE_Module-->>Mcore: wgrad handle (or dummy for MCore DDP path)

Reviews (13): Last reviewed commit: "Merge remote-tracking branch 'ADLR_githu..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py Outdated
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py Outdated
Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py Outdated
Comment thread transformer_engine/pytorch/module/generalized_tensor_parallelism.py Outdated
@fanshiqing

Member Author

/te-ci L1 pytorch

Comment thread transformer_engine/pytorch/distributed.py
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Co-authored-by: Jieming Zhang <jiemingz@nvidia.com>
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Comment on lines 1287 to 1295
  # Fix the interleaved transposed data from gathering along first dim.
- out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
- out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
+ # In-place .copy_() (not `=` rebind) to keep the storage address stable
+ # for CUDA graph capture — replays see the same pointer they captured.
+ out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
+ out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))

- # Optionally pad the scaling inverse if needed.
- out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
+ # Optionally pad the scaling inverse if needed (same in-place pattern).
+ out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))

Contributor

P1 Shape mismatch in _post_process_nvfp4_gather breaks any K not a multiple of 128

out._columnwise_scale_inv is allocated by NVFP4Quantizer.make_empty with shape (round_up(K, 128), round_up(ceil(M_total/16), 4)) — the fully-padded shape. The intermediate result from _swap_first_dims(columnwise_scale_inv_interleaved, world_size) has the unpadded shape (K_stripped, world_size * unpadded_dim1), because the gather side strips padding before the NCCL collect. When K is not a multiple of 128 (e.g. K=64 → padded to 128), the dimensions diverge and out._columnwise_scale_inv.copy_(...) raises a RuntimeError at the first all-gather call.

The pre-PR code used = rebinding, which handled arbitrary shapes. Replacing it with .copy_() is only safe when the caller pre-allocates buffers with the correct unpadded intermediate shape — which make_empty does not do. The GTP-prefetched output_tensor path has the same problem on the step-1 copy before the pad_columnwise_scale_inv call can correct things.
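
The shape arithmetic behind this comment can be checked with a few lines of Python. The padded shape follows the formula quoted above; the unpadded shape is an assumption for illustration (no padding on either dimension), and `copy_would_succeed` is a hypothetical helper that models an in-place copy as requiring identical shapes.

```python
import math

def round_up(x, multiple):
    """Round x up to the nearest multiple."""
    return ((x + multiple - 1) // multiple) * multiple

def padded_scale_inv_shape(k, m_total):
    """Shape quoted in the review for the pre-allocated buffer."""
    return (round_up(k, 128), round_up(math.ceil(m_total / 16), 4))

def unpadded_scale_inv_shape(k, m_total):
    """Assumed shape of the gathered intermediate: no padding on either dim."""
    return (k, math.ceil(m_total / 16))

def copy_would_succeed(k, m_total):
    """Model: an in-place copy needs source and destination shapes to agree."""
    return padded_scale_inv_shape(k, m_total) == unpadded_scale_inv_shape(k, m_total)

aligned = copy_would_succeed(k=128, m_total=64)    # first dim already a multiple of 128
misaligned = copy_would_succeed(k=64, m_total=64)  # first dim 64 vs padded 128
```

With `k=64` the destination's first dimension is padded to 128 while the source stays at 64, which is the divergence the comment describes; rebinding would have accepted either shape.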

Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Comment on lines 1660 to +1680
@@ -1627,10 +1677,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
with get_rng_state_tracker().fork():
    init_fn(param)

# GTP slice: shard the freshly-init weight into a GTPShardedParam;

Contributor

P1 Wrong expert_idx for LayerNormLinear (and GroupedLinear with bias) silently disables GTP weight slicing

expert_idx=idx uses the position of the parameter in named_parameters(recurse=False), which includes non-linear-weight parameters. For LayerNormLinear the iteration order is layer_norm_weight (idx=0), layer_norm_bias (idx=1 for non-RMSNorm), weight (idx=2 or 1). The linear weight therefore arrives at _gtp_slice_fn with expert_idx=2 (or 1 for RMSNorm) instead of expert_idx=0. A Mcore hook that maps expert_idx to a pre-registered shard slot would find no entry for idx=2 and return None, silently leaving the weight un-sharded while gtp_group is set — defeating GTP for the entire LayerNormLinear path this PR explicitly adds.

Similarly, for GroupedLinear with biases enabled, weight1 receives expert_idx=2 (interleaved with bias0), so every expert beyond the first is mis-indexed.

A correct counter only advances when gtp_sharded is not None, keeping it aligned with the weight-only registration slots.
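
The two counting strategies can be compared with a small sketch. The parameter lists and `slice_hook` are hypothetical stand-ins (the hook here shards only names starting with `weight` and returns `None` otherwise); they are not the TE or Mcore implementation.

```python
def slice_hook(name, expert_idx):
    """Hypothetical hook: shards only linear weights, returns None otherwise."""
    if not name.startswith("weight"):
        return None
    return {"name": name, "expert_idx": expert_idx}

def raw_enumeration(names):
    """Pattern flagged in the review: the index counts every parameter."""
    results = []
    for idx, name in enumerate(names):
        sharded = slice_hook(name, idx)
        if sharded is not None:
            results.append(sharded)
    return results

def weight_only_counter(names):
    """Suggested pattern: advance the counter only after a successful shard."""
    results, weight_idx = [], 0
    for name in names:
        sharded = slice_hook(name, weight_idx)
        if sharded is not None:
            results.append(sharded)
            weight_idx += 1
    return results

ln_linear = ["layer_norm_weight", "layer_norm_bias", "weight"]
grouped = ["weight0", "bias0", "weight1", "bias1"]

raw_ln = [s["expert_idx"] for s in raw_enumeration(ln_linear)]
fixed_ln = [s["expert_idx"] for s in weight_only_counter(ln_linear)]
raw_grouped = [s["expert_idx"] for s in raw_enumeration(grouped)]
fixed_grouped = [s["expert_idx"] for s in weight_only_counter(grouped)]
```

In this sketch the raw enumeration hands the linear weight index 2 and `weight1` index 2, while the weight-only counter yields 0 and 0, 1 respectively, matching the indices described in the comment.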

Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Comment on lines +1316 to 1336
def post_process_nvfp4_gather(self) -> None:
    """Fix interleaved transposed data + pad scale_inv after the async AG completes.

    Idempotent: gated by ``_synchronized`` in :meth:`wait`.
    """
    _post_process_nvfp4_gather(
        self.output,
        self.columnwise_data_interleaved,
        self.columnwise_scale_inv_interleaved,
        self.world_size,
    )

def wait(self) -> None:
    """Wait for the async operation to complete and post-process the tensor."""
    if self._synchronized:
        return
    if self.async_handle is not None:
        self.async_handle.wait()
    self.post_process_nvfp4_gather()
    self._synchronized = True

Contributor

P1 post_process_nvfp4_gather doesn't set _synchronized, enabling double-processing via wait()

post_process_nvfp4_gather() is a newly public method intended for callers using an outer coalescing manager (the grouped=True path). In that flow the GTP runtime is expected to call this method once the outer manager flushes, then later may call wait() for finalization. Because post_process_nvfp4_gather never sets self._synchronized = True, the _synchronized guard in wait() does not fire, and wait() calls post_process_nvfp4_gather() a second time. A double _swap_first_dims(..., world_size) reverts the data back to the interleaved format, silently producing corrupt weights for any forward pass that follows.
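
The call sequence described here (post-process once via the outer manager, then `wait()`) can be reproduced with a toy handle. `ToyGatherHandle` is a hypothetical stand-in that only counts how often the post-processing step runs; the `mark_in_post_process` switch is an assumption used to contrast the flagged behavior with one possible fix, not the change the author made.

```python
class ToyGatherHandle:
    """Hypothetical stand-in for an async all-gather handle (not the TE class)."""

    def __init__(self, mark_in_post_process):
        self.mark_in_post_process = mark_in_post_process
        self._synchronized = False
        self.post_process_calls = 0

    def post_process(self):
        """Runs the (simulated) fix-up; optionally records that it has run."""
        if self._synchronized:
            return
        self.post_process_calls += 1
        if self.mark_in_post_process:
            self._synchronized = True

    def wait(self):
        """Finalize: skip everything if the handle is already synchronized."""
        if self._synchronized:
            return
        self.post_process()
        self._synchronized = True

# Flagged behavior: the public method never sets the flag.
flagged = ToyGatherHandle(mark_in_post_process=False)
flagged.post_process()   # caller-driven, after the outer manager flushes
flagged.wait()           # guard does not fire, so the fix-up runs again

# One possible fix: the public method marks the handle itself.
fixed = ToyGatherHandle(mark_in_post_process=True)
fixed.post_process()
fixed.wait()             # guard fires, no second fix-up
```

A separate "post-processed" flag would work equally well if `wait()` must still block on the async handle after a caller-driven post-process; the sketch only shows that the guard has to be set wherever the work is done.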

Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
@greptile-apps

greptile-apps Bot commented Jun 14, 2026

Contributor

Want your agent to iterate on Greptile's feedback? Try greploops.

Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>