Generalized Tensor Parallelism (GTP) #3005
Greptile Summary
This PR introduces Generalized Tensor Parallelism (GTP) across TE's
Confidence Score: 3/5
Several previously-flagged defects remain unresolved in the changed files: the wrong expert_idx counter in base.py, the missing GTP dummy-wgrad guard in linear.py, the FC1 dgrad path in grouped_mlp.py that skips the weight all-gather for NVFP4, and the shape-mismatch and double-processing risks in distributed.py. These affect core forward/backward correctness on the GTP paths.
The TE-side hook registry and module-level plumbing are straightforward additions that look mechanically correct. The in-place .copy_() discipline in distributed.py is a sound approach for CUDA-graph pointer stability. However, multiple previously-identified defects in the same changed files have not been addressed: the expert_idx counter includes non-weight parameters (breaking GTP sharding for LayerNormLinear); linear.py silently discards the RS handle when fuse_wgrad_accumulation=True and the weight carries grad_added_to_main_grad; grouped_mlp.py produces wrong FC1 input gradients when use_nvfp4=True with GTP active; and distributed.py has a potential shape mismatch on the columnwise_scale_inv copy path and a double-execution risk in post_process_nvfp4_gather.
transformer_engine/pytorch/module/linear.py (missing GTP guard before dummy-wgrad block)
transformer_engine/pytorch/module/base.py (expert_idx counter includes non-weight params)
transformer_engine/pytorch/ops/fused/grouped_mlp.py (FC1 nvfp4 dgrad path skips weight all-gather)
transformer_engine/pytorch/distributed.py (columnwise_scale_inv shape mismatch and _synchronized not set in post_process_nvfp4_gather)
Sequence Diagram
%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant Mcore as Mcore Runtime
participant TE_Module as TE Module (Linear/GroupedLinear)
participant Dist as distributed.py helpers
participant NCCL as NCCL/GPU
Note over Mcore,TE_Module: Module init
Mcore->>TE_Module: register_gtp_hooks(slice_fn, finalize_fn, wrap_fn)
TE_Module->>TE_Module: "reset_parameters() → _gtp_slice_fn(expert_idx=idx)"
TE_Module->>Mcore: _gtp_wrap_fn(module, weight_names, gtp_group)
Note over TE_Module,NCCL: Forward pass
TE_Module->>Mcore: "weight.all_gather_and_prefetch(fwd=True)"
Mcore->>Dist: "gather_along_first_dim(..., grouped=True)"
Dist->>NCCL: all_gather_into_tensor (under outer coalescing mgr)
Dist-->>Mcore: _NVFP4AllGatherAsyncHandle / MXFP8 handle
TE_Module->>TE_Module: GEMM with gathered weight
TE_Module->>TE_Module: save sharded GTPShardedParam (not gathered weight)
Note over TE_Module,NCCL: Backward pass
TE_Module->>Mcore: saved_weight.all_gather_and_prefetch_bwd()
Mcore->>NCCL: async all_gather (columnwise layout for dgrad)
TE_Module->>TE_Module: dgrad GEMM
TE_Module->>TE_Module: wgrad GEMM (full-sized)
TE_Module->>Mcore: saved_weight.wgrad_reduce_scatter(wgrad)
Mcore->>NCCL: async reduce_scatter → main_grad shard
TE_Module-->>Mcore: wgrad handle (or dummy for MCore DDP path)
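The forward and backward steps in the diagram can be sketched as a single-process simulation. This is only an illustration under stated assumptions: the weight is sharded by rows across a group of two simulated ranks, each rank has its own input and output gradient, and plain Python lists stand in for tensors. The helper names (`all_gather_first_dim`, `reduce_scatter_first_dim`, `matvec`, `outer`) are hypothetical and are not the TE or Mcore functions.

```python
# Single-process simulation of the diagram's flow; no NCCL, no real tensors.
# Every name here is a hypothetical stand-in, not a TE or Mcore API.

def all_gather_first_dim(shards):
    """Concatenate per-rank row shards into the full weight."""
    return [row for shard in shards for row in shard]

def reduce_scatter_first_dim(per_rank_full, world_size):
    """Sum full-sized matrices across ranks, then hand each rank its row block."""
    n_rows = len(per_rank_full[0])
    n_cols = len(per_rank_full[0][0])
    summed = [
        [sum(t[r][c] for t in per_rank_full) for c in range(n_cols)]
        for r in range(n_rows)
    ]
    step = n_rows // world_size
    return [summed[i * step:(i + 1) * step] for i in range(world_size)]

def matvec(weight, x):
    """y = W @ x for a list-of-rows weight."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]

def outer(dy, x):
    """Full-sized weight gradient dW = dy x^T."""
    return [[g * v for v in x] for g in dy]

world_size = 2
weight_shards = [[[1.0, 2.0]], [[3.0, 4.0]]]  # each rank stores 1 of 2 rows
inputs = [[1.0, 1.0], [2.0, 0.0]]              # one input vector per rank
grad_outputs = [[1.0, 0.0], [0.0, 1.0]]        # one output gradient per rank

# Forward: materialize the full weight just in time, then run the GEMM.
full_weight = all_gather_first_dim(weight_shards)
outputs = [matvec(full_weight, x) for x in inputs]

# Backward: each rank computes a full-sized wgrad, then the reduce-scatter
# leaves every rank with only the gradient rows for its own shard.
wgrads = [outer(dy, x) for dy, x in zip(grad_outputs, inputs)]
grad_shards = reduce_scatter_first_dim(wgrads, world_size)
```

Only the shard is kept between forward and backward in this sketch; the gathered weight is rebuilt on demand, which mirrors the "save sharded GTPShardedParam (not gathered weight)" step.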
Reviews (13): Last reviewed commit: "Merge remote-tracking branch 'ADLR_githu..."
/te-ci L1 pytorch
3e70bdf to ed9ce68
Co-authored-by: Jieming Zhang <jiemingz@nvidia.com> Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
  # Fix the interleaved transposed data from gathering along first dim.
- out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
- out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
+ # In-place .copy_() (not `=` rebind) to keep the storage address stable
+ # for CUDA graph capture — replays see the same pointer they captured.
+ out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size))
+ out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size))

- # Optionally pad the scaling inverse if needed.
- out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
+ # Optionally pad the scaling inverse if needed (same in-place pattern).
+ out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv))
Shape mismatch in _post_process_nvfp4_gather breaks any K not a multiple of 128
out._columnwise_scale_inv is allocated by NVFP4Quantizer.make_empty with shape (round_up(K, 128), round_up(ceil(M_total/16), 4)) — the fully-padded shape. The intermediate result from _swap_first_dims(columnwise_scale_inv_interleaved, world_size) has the unpadded shape (K_stripped, world_size * unpadded_dim1), because the gather side strips padding before the NCCL collect. When K is not a multiple of 128 (e.g. K=64 → padded to 128), the dimensions diverge and out._columnwise_scale_inv.copy_(...) raises a RuntimeError at the first all-gather call.
The pre-PR code used = rebinding, which handled arbitrary shapes. Replacing it with .copy_() is only safe when the caller pre-allocates buffers with the correct unpadded intermediate shape — which make_empty does not do. The GTP-prefetched output_tensor path has the same problem on the step-1 copy before the pad_columnwise_scale_inv call can correct things.
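The shape arithmetic behind this comment can be checked without a GPU. The sketch below is an assumption-laden illustration: `padded_scale_inv_shape` encodes the padded shape quoted above, the unpadded source shape is supplied by hand (the real one depends on the gather path), and `checked_copy_shape` is a hypothetical stand-in for the shape check of an in-place copy, simplified to require exact equality.

```python
import math

def round_up(value, multiple):
    """Round value up to the nearest multiple."""
    return ((value + multiple - 1) // multiple) * multiple

def padded_scale_inv_shape(k, m_total):
    # Padded shape quoted in the review for the pre-allocated buffer.
    return (round_up(k, 128), round_up(math.ceil(m_total / 16), 4))

def checked_copy_shape(dst_shape, src_shape):
    """Hypothetical stand-in for an in-place copy's shape check (exact match only)."""
    if dst_shape != src_shape:
        raise RuntimeError(f"shape mismatch: dst {dst_shape} vs src {src_shape}")
    return dst_shape

# K = 64 is padded to 128 rows in the destination, but the intermediate
# result is assumed here to keep the unpadded 64 rows.
dst = padded_scale_inv_shape(64, 256)
src = (64, 16)  # assumed unpadded intermediate shape, for illustration only
try:
    checked_copy_shape(dst, src)
    mismatch_detected = False
except RuntimeError:
    mismatch_detected = True

# K = 128 needs no row padding, so the two shapes agree.
aligned = checked_copy_shape(padded_scale_inv_shape(128, 256), (128, 16))
```

With `=` rebinding the destination simply takes on the source's shape, which is why the mismatch only surfaces once the assignment becomes an in-place copy.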
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
@@ -1627,10 +1677,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
  with get_rng_state_tracker().fork():
      init_fn(param)

+ # GTP slice: shard the freshly-init weight into a GTPShardedParam;
Wrong expert_idx for LayerNormLinear (and GroupedLinear with bias) silently disables GTP weight slicing
expert_idx=idx uses the position of the parameter in named_parameters(recurse=False), which includes non-linear-weight parameters. For LayerNormLinear the iteration order is layer_norm_weight (idx=0), layer_norm_bias (idx=1 for non-RMSNorm), weight (idx=2 or 1). The linear weight therefore arrives at _gtp_slice_fn with expert_idx=2 (or 1 for RMSNorm) instead of expert_idx=0. A Mcore hook that maps expert_idx to a pre-registered shard slot would find no entry for idx=2 and return None, silently leaving the weight un-sharded while gtp_group is set — defeating GTP for the entire LayerNormLinear path this PR explicitly adds.
Similarly, for GroupedLinear with biases enabled, weight1 receives expert_idx=2 (interleaved with bias0), so every expert beyond the first is mis-indexed.
A correct counter only advances when gtp_sharded is not None, keeping it aligned with the weight-only registration slots.
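A minimal sketch of such a counter follows. It assumes a slice hook that returns None for parameters it does not shard; the function `slice_weights`, the hook signature, and the parameter names in the demo are hypothetical and only mirror the iteration orders described in this comment.

```python
def slice_weights(named_params, slice_fn):
    """Call slice_fn per parameter; advance expert_idx only after a real shard."""
    expert_idx = 0
    sharded = {}
    for name, param in named_params:
        gtp_sharded = slice_fn(name, param, expert_idx)
        if gtp_sharded is not None:
            sharded[name] = gtp_sharded
            expert_idx += 1  # non-weight params never consume an index
    return sharded

def toy_slice_fn(name, param, expert_idx):
    # Hypothetical hook: shards only parameters whose name starts with "weight".
    if not name.startswith("weight"):
        return None
    return (param, expert_idx)

# Iteration order described for LayerNormLinear (non-RMSNorm case).
layernorm_linear = [
    ("layer_norm_weight", "ln_w"),
    ("layer_norm_bias", "ln_b"),
    ("weight", "w"),
    ("bias", "b"),
]
# Iteration order described for GroupedLinear with biases enabled.
grouped_linear = [("weight0", "w0"), ("bias0", "b0"), ("weight1", "w1"), ("bias1", "b1")]

ln_result = slice_weights(layernorm_linear, toy_slice_fn)
grouped_result = slice_weights(grouped_linear, toy_slice_fn)
```

In this sketch the linear weight of the LayerNormLinear-style list receives index 0, and the two grouped weights receive 0 and 1, regardless of how many non-weight parameters sit between them.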
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
def post_process_nvfp4_gather(self) -> None:
    """Fix interleaved transposed data + pad scale_inv after the async AG completes.

    Idempotent: gated by ``_synchronized`` in :meth:`wait`.
    """
    _post_process_nvfp4_gather(
        self.output,
        self.columnwise_data_interleaved,
        self.columnwise_scale_inv_interleaved,
        self.world_size,
    )

def wait(self) -> None:
    """Wait for the async operation to complete and post-process the tensor."""
    if self._synchronized:
        return
    if self.async_handle is not None:
        self.async_handle.wait()
    self.post_process_nvfp4_gather()
    self._synchronized = True
post_process_nvfp4_gather doesn't set _synchronized, enabling double-processing via wait()
post_process_nvfp4_gather() is a newly public method intended for callers using an outer coalescing manager (the grouped=True path). In that flow the GTP runtime is expected to call this method once the outer manager flushes, then later may call wait() for finalization. Because post_process_nvfp4_gather never sets self._synchronized = True, the _synchronized guard in wait() does not fire, and wait() calls post_process_nvfp4_gather() a second time. A double _swap_first_dims(..., world_size) reverts the data back to the interleaved format, silently producing corrupt weights for any forward pass that follows.
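One possible way to close this gap is to move the guard into the public method itself, so either entry point marks the handle as done. The class below is a hypothetical stand-in rather than the handle in distributed.py: the real swap-and-pad step is replaced by a call counter so the single-execution property is easy to observe.

```python
class ToyGatherHandle:
    """Hypothetical stand-in for an async all-gather handle with post-processing."""

    def __init__(self, async_handle=None):
        self.async_handle = async_handle
        self._synchronized = False
        self.post_process_calls = 0

    def _do_post_process(self):
        # Stand-in for the real swap + pad work; counts how often it runs.
        self.post_process_calls += 1

    def post_process_nvfp4_gather(self):
        # Guard lives here, so a direct call and a later wait() cannot both run it.
        if self._synchronized:
            return
        self._do_post_process()
        self._synchronized = True

    def wait(self):
        if self._synchronized:
            return
        if self.async_handle is not None:
            self.async_handle.wait()
        self.post_process_nvfp4_gather()

# Grouped path: the runtime post-processes first, then finalizes with wait().
handle = ToyGatherHandle()
handle.post_process_nvfp4_gather()
handle.wait()
handle.wait()
```

The design choice here is that idempotence belongs to the operation being protected, not to one of its callers.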
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Signed-off-by: Shiqing Fan <shiqingf@nvidia.com>
Design doc: GTP.docx
Description
Core idea: add Generalized Tensor Parallelism (GTP), a flexible, fine-grained scheme for sharding and just-in-time materialization of both activations and parameters, with efficient computation-communication overlap.
Mission: improve LLM pretraining efficiency through generalized tensor parallelism, enabling high performance, memory efficiency, ease of use, and strong scalability.
Summary of features
How Mcore interacts with TE
① Mcore registers callbacks into TE at import time.
② TE calls back into Mcore runtime during te.Linear(gtp_group=…) init AND during fwd/bwd (weight.all_gather_and_prefetch / wgrad_reduce_scatter).
③ Mcore extensions forward gtp_group= at module init.
④ TE provides FP8 / MXFP8 / NVFP4 tensor types AND the quantize-then-AG / RS collectives (gather_along_first_dim, reduce_scatter_along_first_dim) — imported by Mcore runtime; GTP wraps them with its own schedule, buffer cache, and stream choreography.
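The callback pattern in steps ① to ③ can be illustrated with a toy registry. This is a sketch under assumptions, not the actual plumbing: `register_gtp_hooks` borrows its name and argument list from the sequence diagram, while `ToyLinear`, `slice_rows`, and the `(rank, world_size)` tuple used as `gtp_group` are hypothetical.

```python
# Toy hook registry; behavior is illustrative and not the TE or Mcore API.
_GTP_HOOKS = {}

def register_gtp_hooks(slice_fn, finalize_fn, wrap_fn):
    """Runtime side: install callbacks once, for example at import time."""
    _GTP_HOOKS.update(slice_fn=slice_fn, finalize_fn=finalize_fn, wrap_fn=wrap_fn)

class ToyLinear:
    """Library side: calls back into the runtime only when gtp_group is set."""

    def __init__(self, weight_rows, gtp_group=None):
        self.weight = weight_rows
        if gtp_group is not None and "slice_fn" in _GTP_HOOKS:
            self.weight = _GTP_HOOKS["slice_fn"](weight_rows, gtp_group)
            _GTP_HOOKS["wrap_fn"](self, ["weight"], gtp_group)

def slice_rows(weight_rows, gtp_group):
    """Hypothetical slice hook: keep this rank's contiguous block of rows."""
    rank, world_size = gtp_group
    step = len(weight_rows) // world_size
    return weight_rows[rank * step:(rank + 1) * step]

wrapped = []
register_gtp_hooks(
    slice_fn=slice_rows,
    finalize_fn=lambda module: None,
    wrap_fn=lambda module, names, group: wrapped.append((names, group)),
)

full = [[1, 2], [3, 4], [5, 6], [7, 8]]
sharded = ToyLinear(full, gtp_group=(1, 2))  # rank 1 of 2 keeps the last two rows
plain = ToyLinear(full)                      # no gtp_group, hooks are not invoked
```

Keeping the hooks in a registry means the library never imports the runtime; modules built without `gtp_group` behave exactly as before.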
Type of change
Changes
Please list the changes introduced in this PR:
wgrad_shape.
carving (with/without GTP);
Checklist: