[PyTorch] Support GroupedTensor torch ops for DDP and distributed optimizer #2736
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch
Greptile Summary
This PR extends … The implementation is correct and safe for the documented DDP and distributed-optimizer use cases. The …
Confidence Score: 4/5
Sequence Diagram
sequenceDiagram
participant DDP as DDP / DistOptimizer
participant GT as GroupedTensor
participant Dispatch as __torch_dispatch__
participant GIF as _GroupedIdentityFunc
DDP->>GT: param.expand_as(param)
GT->>GT: expand_as() override\n(other is self)
GT->>GIF: _GroupedIdentityFunc.apply(self)
GIF->>GT: forward: tensor.detach()
GT->>Dispatch: detach.default
Dispatch->>GT: make_wrapper_like(src, requires_grad=False)
GT-->>GIF: detached GroupedTensor
GIF-->>DDP: GroupedTensor with grad_fn (identity)
DDP->>GT: param.detach().view(-1)
GT->>Dispatch: detach.default
Dispatch->>GT: make_wrapper_like(src, requires_grad=False)
GT-->>DDP: detached GroupedTensor
DDP->>GT: view(-1)
GT->>Dispatch: view.default / _unsafe_view.default
Dispatch->>GT: rowwise_data.view(-1)
GT-->>DDP: flat 1-D tensor (raw backing storage)
Last reviewed commit: d6b758e
return tuple(stride)
class _GroupedIdentityFunc(torch.autograd.Function):
mcore distopt does weight.expand_as(param) to force a graph edge (autograd). This is needed to safely implement that for GroupedTensor.
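The graph-edge trick mentioned in this comment can be sketched with plain tensors. The snippet below is a minimal illustration, not the PR's code: `_IdentityFunc` is a hypothetical stand-in for `_GroupedIdentityFunc`, and it operates on an ordinary `torch.nn.Parameter` rather than a `GroupedTensor`. It shows that an identity op yields a `grad_fn` whose first input edge is the parameter's gradient accumulator, which is the node a distributed optimizer could attach a hook to.

```python
import torch


class _IdentityFunc(torch.autograd.Function):
    """Hypothetical identity op: returns its input, passes gradients through."""

    @staticmethod
    def forward(ctx, tensor):
        # Detach so the output is a new tensor object sharing the same storage.
        return tensor.detach()

    @staticmethod
    def backward(ctx, grad_output):
        # Identity in the backward pass as well.
        return grad_output


param = torch.nn.Parameter(torch.zeros(4, 3))

# For a plain tensor, expand_as already creates the graph edge.
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]

# The identity function produces the same kind of edge back to the parameter.
out = _IdentityFunc.apply(param)
grad_acc_identity = out.grad_fn.next_functions[0][0]

# Gradients flow through the identity op unchanged.
out.sum().backward()
```

The design point is that the output is not a leaf, so it carries a `grad_fn`, while its values and shape are identical to the parameter's.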
py::tuple args(0);
kwargs["shape"] = py::cast(std::vector<int64_t>{static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)});
const std::vector<int64_t> grouped_shape = {static_cast<int64_t>(logical_first_dim),
Why do we need to specify that it's a grouped shape?
Since we subclass a torch.Tensor, what is the shape field of a grouped tensor?
The shape of the GroupedTensor is the logical shape.
…ge.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
# For now, conservatively ban all shape manipulating ops.
def _stride_from_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
I think this function is already defined in a utils module that you can reuse.
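For reference, contiguous (row-major) strides can be derived from a shape as sketched below. This is an illustrative, hypothetical helper written without PyTorch. It is not necessarily the body of `_stride_from_shape` in the PR, nor the existing utility the reviewer has in mind.

```python
from typing import Tuple


def contiguous_strides(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return row-major strides (in elements) for a tensor of the given shape."""
    stride = []
    running = 1
    # Walk dimensions from last to first; each stride is the product of
    # the sizes of all dimensions to its right.
    for size in reversed(shape):
        stride.append(running)
        # Clamp to 1 so a zero-sized dimension does not zero out the
        # strides of the dimensions to its left.
        running *= max(size, 1)
    return tuple(reversed(stride))
```

For a 2-D logical shape such as `(4, 3)`, this yields a stride of 3 elements along the first dimension and 1 along the last.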
assert isinstance(src, GroupedTensor)
expanded_shape = tuple(args[1])
I am curious why all this torch dispatch logic is not needed for the MXFP8 tensor. That is, how does DDP work with discrete MXFP8 weights if MCore uses all these ops?
Same question for expand_as, view, unsafe_view
Description
As a follow-up to #2731, adds support for specific operations required for e2e execution using GroupedTensor. Also makes some minor optimizations and cleanup.
Type of change
Changes
expand, expand_as, and view for GroupedTensor.
requires_grad option during initialization.
Checklist: