-
Notifications
You must be signed in to change notification settings - Fork 743
[PyTorch] Support GroupedTensor torch ops for DDP and distributed optimizer
#2736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
5222d75
02c732d
63dd8e5
7f2c127
d6b758e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,10 +14,36 @@ | |
| from .storage.grouped_tensor_storage import GroupedTensorStorage | ||
|
|
||
|
|
||
| # For now, conservatively ban all shape manipulating ops. | ||
| def _stride_from_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this function is already defined in some utils, that you can reuse. |
||
| """Calculate contiguous stride from shape.""" | ||
| if len(shape) == 0: | ||
| return () | ||
| stride = [1] * len(shape) | ||
| for i in range(len(shape) - 2, -1, -1): | ||
| stride[i] = stride[i + 1] * shape[i + 1] | ||
| return tuple(stride) | ||
|
|
||
|
|
||
| class _GroupedIdentityFunc(torch.autograd.Function): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need this?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mcore distopt does |
||
| """Identity autograd function used to create a dummy grad_fn node.""" | ||
|
|
||
| @staticmethod | ||
| def forward(ctx, tensor: "GroupedTensor") -> "GroupedTensor": | ||
| # pylint: disable=missing-function-docstring | ||
| ctx.input_dtype = tensor.dtype | ||
| return tensor.detach() | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, grad_output: torch.Tensor): | ||
| # pylint: disable=missing-function-docstring | ||
| grad_input = grad_output | ||
| if grad_input.dtype != ctx.input_dtype: | ||
| grad_input = grad_input.to(ctx.input_dtype) | ||
| return grad_input | ||
|
|
||
|
|
||
| # For now, conservatively ban 'most' shape manipulating ops. | ||
| BANNED_SHAPE_OPS = { | ||
| torch.ops.aten.view.default, | ||
| torch.ops.aten._unsafe_view.default, | ||
| torch.ops.aten.reshape.default, | ||
| torch.ops.aten._reshape_alias.default, | ||
| torch.ops.aten.flatten.using_ints, | ||
|
|
@@ -34,8 +60,6 @@ | |
| torch.ops.aten.select.int, | ||
| torch.ops.aten.split.Tensor, | ||
| torch.ops.aten.chunk.default, | ||
| torch.ops.aten.expand.default, | ||
| torch.ops.aten.expand_as.default, | ||
| torch.ops.aten.cat.default, | ||
| torch.ops.aten.stack.default, | ||
| } | ||
|
ksivaman marked this conversation as resolved.
|
||
|
|
@@ -48,6 +72,7 @@ def __new__( | |
| cls, | ||
| shape: Tuple[int, int], | ||
| dtype: torch.dtype, | ||
| *, | ||
| num_tensors: int, | ||
| shapes: Optional[List[Tuple[int, int]]] = None, | ||
| quantizer: Optional[Quantizer] = None, | ||
|
|
@@ -64,12 +89,9 @@ def __new__( | |
| offsets: Optional[List[int]] = None, | ||
| scale_inv_offsets: Optional[List[int]] = None, | ||
| columnwise_scale_inv_offsets: Optional[List[int]] = None, | ||
| requires_grad: bool = False, | ||
| stride: Optional[List[int]] = None, | ||
| ): | ||
| del quantizer | ||
| del offsets | ||
| del scale_inv_offsets | ||
| del columnwise_scale_inv_offsets | ||
|
|
||
| if ( | ||
| shapes is not None | ||
| and len(shapes) == num_tensors | ||
|
|
@@ -99,29 +121,136 @@ def __new__( | |
| if device is None: | ||
| device = torch.device("cuda") | ||
|
|
||
| strides = [1] * len(wrapper_shape) | ||
| for i in range(len(wrapper_shape) - 2, -1, -1): | ||
| strides[i] = strides[i + 1] * wrapper_shape[i + 1] | ||
| return torch.Tensor._make_wrapper_subclass( | ||
| # Match QuantizedTensor __new__: accept externally-computed stride to | ||
| # avoid Python-side stride computation overhead for C++ construction. | ||
| strides = _stride_from_shape(tuple(wrapper_shape)) if stride is None else tuple(stride) | ||
| instance = torch.Tensor._make_wrapper_subclass( | ||
| cls, | ||
| wrapper_shape, | ||
| strides=tuple(strides), | ||
| strides=strides, | ||
| storage_offset=0, | ||
| dtype=dtype, | ||
| layout=torch.strided, | ||
| requires_grad=False, | ||
| requires_grad=requires_grad, | ||
| device=device, | ||
| ) | ||
| GroupedTensorStorage._initialize_storage_fields( | ||
| instance=instance, | ||
| shape=shape, | ||
| dtype=dtype, | ||
| num_tensors=num_tensors, | ||
| shapes=shapes, | ||
| quantizer=quantizer, | ||
| data=data, | ||
| columnwise_data=columnwise_data, | ||
| scale_inv=scale_inv, | ||
| columnwise_scale_inv=columnwise_scale_inv, | ||
| amax=amax, | ||
| columnwise_amax=columnwise_amax, | ||
| scale=scale, | ||
| first_dims=first_dims, | ||
| last_dims=last_dims, | ||
| tensor_offsets=tensor_offsets, | ||
| offsets=offsets, | ||
| scale_inv_offsets=scale_inv_offsets, | ||
| columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, | ||
| ) | ||
| return instance | ||
|
|
||
| @classmethod | ||
| def __torch_dispatch__(cls, func, types, args, kwargs=None): | ||
| """Dispatch by dequantizing grouped members, then requantizing writes.""" | ||
| if kwargs is None: | ||
| kwargs = {} | ||
|
|
||
| def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> None: | ||
| """Shallow-copy grouped-storage metadata onto wrapper outputs.""" | ||
| dst.num_tensors = src.num_tensors | ||
| dst.quantizer = src.quantizer | ||
| dst.tensor_shapes = src.tensor_shapes | ||
| dst.fake_dtype = src.fake_dtype | ||
| dst.rowwise_data = src.rowwise_data | ||
| dst.columnwise_data = src.columnwise_data | ||
| dst.scale_inv = src.scale_inv | ||
| dst.columnwise_scale_inv = src.columnwise_scale_inv | ||
| dst.amax = src.amax | ||
| dst.columnwise_amax = src.columnwise_amax | ||
| dst.scale = src.scale | ||
| dst.first_dims = src.first_dims | ||
| dst.last_dims = src.last_dims | ||
| dst.tensor_offsets = src.tensor_offsets | ||
| dst.offsets = src.offsets | ||
| dst.scale_inv_offsets = src.scale_inv_offsets | ||
| dst.columnwise_scale_inv_offsets = src.columnwise_scale_inv_offsets | ||
| dst.logical_shape = src.logical_shape | ||
| dst.quantized_tensors = src.quantized_tensors | ||
|
|
||
| def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: | ||
| """Create a wrapper of the same type and tensor metadata as src.""" | ||
| out = torch.Tensor._make_wrapper_subclass( | ||
| type(src), | ||
| tuple(src.shape), | ||
| strides=tuple(src.stride()), | ||
| storage_offset=src.storage_offset(), | ||
| dtype=src.dtype, | ||
| layout=src.layout, | ||
| requires_grad=requires_grad, | ||
| device=src.device, | ||
| ) | ||
| copy_grouped_storage_metadata(out, src) | ||
| return out | ||
|
|
||
| # Parameter construction calls detach()/alias-like paths. | ||
| if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default): | ||
| return args[0] | ||
| src = args[0] | ||
| assert isinstance(src, GroupedTensor) | ||
| if func == torch.ops.aten.detach.default: | ||
| return make_wrapper_like(src, requires_grad=False) | ||
| return make_wrapper_like(src, requires_grad=src.requires_grad) | ||
|
|
||
| # Parameter construction may invoke aten.expand on tensor subclasses. | ||
| # Handle this explicitly so grouped parameters can be created safely. | ||
| if func == torch.ops.aten.expand.default: | ||
| src = args[0] | ||
| assert isinstance(src, GroupedTensor) | ||
| expanded_shape = tuple(args[1]) | ||
|
Comment on lines
+215
to
+216
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am curious why all this torch dispatch logic is not needed in MXFP8 tensor. As in how does DDP work even with discrete MXFP8 weights, if MCore uses all this ops?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same question for expand_as, view, unsafe_view |
||
| src_shape = tuple(src.shape) | ||
| if len(expanded_shape) == len(src_shape): | ||
| normalized_shape = tuple( | ||
| src_shape[i] if dim == -1 else dim for i, dim in enumerate(expanded_shape) | ||
| ) | ||
| if normalized_shape == src_shape: | ||
| return make_wrapper_like(src, requires_grad=src.requires_grad) | ||
| return super().__torch_dispatch__(func, types, args, kwargs) | ||
|
|
||
| # DDP and mcore use expand_as(self) to build a dummy autograd node and | ||
| # access gradient accumulators during parameter hook registration. | ||
| if func == torch.ops.aten.expand_as.default: | ||
| src = args[0] | ||
| other = args[1] | ||
| assert isinstance(src, GroupedTensor) | ||
| if other is src: | ||
| return _GroupedIdentityFunc.apply(src) | ||
| if tuple(other.shape) == tuple(src.shape): | ||
| return make_wrapper_like(src, requires_grad=src.requires_grad) | ||
| return super().__torch_dispatch__(func, types, args, kwargs) | ||
|
|
||
| # Distributed optimizer flattens detached parameters via | ||
| # model_param.detach().view(-1). Support this path explicitly by | ||
| # returning a flat view of grouped backing storage. | ||
| if func in (torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default): | ||
| src = args[0] | ||
| assert isinstance(src, GroupedTensor) | ||
| target_shape = tuple(args[1]) | ||
| if target_shape in ((-1,), (src.numel(),)): | ||
| if src.rowwise_data is not None: | ||
| return src.rowwise_data.view(-1) | ||
|
ksivaman marked this conversation as resolved.
|
||
| raise RuntimeError( | ||
| f"{cls.__name__} view(-1) requires rowwise_data to be initialized" | ||
| ) | ||
| raise RuntimeError( | ||
| f"{cls.__name__} only supports view(-1) for distributed optimizer flattening" | ||
| ) | ||
|
|
||
| # Don't allow reshape/view etc. | ||
| if func in BANNED_SHAPE_OPS: | ||
|
|
@@ -203,3 +332,10 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): | |
| kwargs = {} | ||
| # Do not force GroupedTensor on outputs. | ||
| return torch._C._disabled_torch_function_impl(func, types, args, kwargs) | ||
|
|
||
| def expand_as(self, other: torch.Tensor) -> torch.Tensor: | ||
| # pylint: disable=missing-function-docstring | ||
| # Needed during parameter creation/hook registration paths. | ||
| if other is self: | ||
| return _GroupedIdentityFunc.apply(self) | ||
| return super().expand_as(other) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to specify that it's a grouped shape?
Since we subclass a torch.Tensor, what is the shape field of a grouped tensor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The shape of the GroupedTensor is the logical shape.