Skip to content

[PyTorch] Support GroupedTensor torch ops for DDP and distributed optimizer#2736

Merged
ksivaman merged 5 commits into
NVIDIA:mainfrom
ksivaman:fix_e2e_grouped_tensor
Mar 5, 2026
Merged

[PyTorch] Support GroupedTensor torch ops for DDP and distributed optimizer#2736
ksivaman merged 5 commits into
NVIDIA:mainfrom
ksivaman:fix_e2e_grouped_tensor

Conversation

@ksivaman

@ksivaman ksivaman commented Mar 5, 2026

Copy link
Copy Markdown
Member

Description

As a follow-up to #2731, adds support for specific operations required for e2e execution using GroupedTensor. Also make some minor optimizations and cleanup.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Implement expand, expand_as, and view for GroupedTensor.
  • Calculate tensor strides in C++ in order to avoid high CPU overhead.
  • Add requires_grad option during initialization.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

ksivaman added 2 commits March 5, 2026 04:00
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman added the MoE label Mar 5, 2026
@ksivaman

ksivaman commented Mar 5, 2026

Copy link
Copy Markdown
Member Author

/te-ci pytorch

@greptile-apps

greptile-apps Bot commented Mar 5, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR extends GroupedTensor to support the subset of torch ops required by DDP and the distributed optimizer: identity expand/expand_as, flat view(-1), and proper detach/alias semantics. It also pre-computes contiguous strides in C++ (stride_from_shape) to reduce Python-side CPU overhead during tensor construction, and refactors GroupedTensorStorage.__init__ into a static _initialize_storage_fields helper so both the C++ construction path and the Python copy path share a single field-population routine.

The implementation is correct and safe for the documented DDP and distributed-optimizer use cases. The expand.default and expand_as.default handlers correctly implement identity operations via dedicated dispatch logic, stride passed from C++ is properly validated and consumed, and _GroupedIdentityFunc is well-scoped for hook plumbing. The C++ stride_from_shape helper is straightforward and correctly propagated to all quantizer creation sites.

Confidence Score: 4/5

  • Safe to merge. Core functionality (identity expand, view(-1), detach, alias) is correctly implemented for DDP and distributed-optimizer workflows.
  • The PR implements well-defined operations with clear scope: identity expand/expand_as handlers are explicit and correct, view(-1) intentionally returns flat backing storage for optimizer flattening, and detach/alias semantics preserve metadata. The C++ stride optimization is straightforward. All field initialization and storage propagation is consistent. The implementation is focused, well-commented, and addresses the documented use case without introducing regressions.
  • No files require special attention. All changes are safe and focused on the documented use case.

Sequence Diagram

sequenceDiagram
    participant DDP as DDP / DistOptimizer
    participant GT as GroupedTensor
    participant Dispatch as __torch_dispatch__
    participant GIF as _GroupedIdentityFunc

    DDP->>GT: param.expand_as(param)
    GT->>GT: expand_as() override\n(other is self)
    GT->>GIF: _GroupedIdentityFunc.apply(self)
    GIF->>GT: forward: tensor.detach()
    GT->>Dispatch: detach.default
    Dispatch->>GT: make_wrapper_like(src, requires_grad=False)
    GT-->>GIF: detached GroupedTensor
    GIF-->>DDP: GroupedTensor with grad_fn (identity)

    DDP->>GT: param.detach().view(-1)
    GT->>Dispatch: detach.default
    Dispatch->>GT: make_wrapper_like(src, requires_grad=False)
    GT-->>DDP: detached GroupedTensor
    DDP->>GT: view(-1)
    GT->>Dispatch: view.default / _unsafe_view.default
    Dispatch->>GT: rowwise_data.view(-1)
    GT-->>DDP: flat 1-D tensor (raw backing storage)
Loading

Last reviewed commit: d6b758e

Comment thread transformer_engine/pytorch/tensor/grouped_tensor.py
Comment thread transformer_engine/pytorch/tensor/grouped_tensor.py
Comment thread transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py
zhongbozhu
zhongbozhu previously approved these changes Mar 5, 2026

@zhongbozhu zhongbozhu left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

return tuple(stride)


class _GroupedIdentityFunc(torch.autograd.Function):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

@ksivaman ksivaman Mar 5, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mcore distopt does weight.expand_as(param) to force a graph edge (autograd). This is needed to safely implement that for GroupedTensor.

py::tuple args(0);
kwargs["shape"] = py::cast(std::vector<int64_t>{static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)});
const std::vector<int64_t> grouped_shape = {static_cast<int64_t>(logical_first_dim),

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to specify that it's a grouped shape?

Since we subclass a torch.Tensor, what is the shape field of a grouped tensor?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shape of the GroupedTensor is the logical shape.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
…ge.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment thread transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py Outdated
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

@zhongbozhu zhongbozhu left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pending CI to merge



# For now, conservatively ban all shape manipulating ops.
def _stride_from_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this function is already defined in some utils, that you can reuse.

Comment on lines +215 to +216
assert isinstance(src, GroupedTensor)
expanded_shape = tuple(args[1])

@vthumbe1503 vthumbe1503 Mar 5, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am curious why all this torch dispatch logic is not needed in MXFP8 tensor. As in how does DDP work even with discrete MXFP8 weights, if MCore uses all this ops?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question for expand_as, view, unsafe_view

@ksivaman ksivaman merged commit d9152b0 into NVIDIA:main Mar 5, 2026
9 of 12 checks passed
@ksivaman ksivaman deleted the fix_e2e_grouped_tensor branch March 5, 2026 07:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants