From 5222d755e7a894b8d029c27bf7747df321b5796c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 5 Mar 2026 04:00:00 +0000 Subject: [PATCH 1/5] Fix e2e execution of GroupedTensor in distributed settings Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/csrc/quantizer.cpp | 42 +++-- .../pytorch/tensor/grouped_tensor.py | 166 ++++++++++++++++-- .../tensor/storage/grouped_tensor_storage.py | 97 +++++++--- 3 files changed, 258 insertions(+), 47 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0135c7f01c..0214f7ff71 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -180,8 +180,11 @@ std::pair NoneQuantizer::create_grouped_tensor py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::dict kwargs; py::tuple args(0); - kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); kwargs["dtype"] = py::cast(GetATenDType(dtype)); kwargs["num_tensors"] = py::cast(num_tensors); kwargs["quantizer"] = quantizer; @@ -386,8 +389,11 @@ std::pair Float8Quantizer::create_grouped_tens py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::dict kwargs; py::tuple args(0); - kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); kwargs["dtype"] = py::cast(GetATenDType(dtype)); kwargs["num_tensors"] = py::cast(num_tensors); kwargs["quantizer"] = quantizer; @@ -704,8 +710,11 @@ std::pair Float8CurrentScalingQuantizer::creat py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::dict kwargs; py::tuple args(0); - kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); kwargs["dtype"] = py::cast(GetATenDType(dtype)); kwargs["num_tensors"] = py::cast(num_tensors); kwargs["quantizer"] = quantizer; @@ -1062,8 +1071,11 @@ std::pair Float8BlockQuantizer::create_grouped py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::dict kwargs; py::tuple args(0); - kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); kwargs["dtype"] = py::cast(GetATenDType(dtype)); kwargs["num_tensors"] = py::cast(num_tensors); kwargs["quantizer"] = quantizer; @@ -1478,8 +1490,11 @@ std::pair MXFP8Quantizer::create_grouped_tenso py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::dict kwargs; py::tuple args(0); - kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); kwargs["dtype"] = py::cast(GetATenDType(dtype)); kwargs["num_tensors"] = py::cast(num_tensors); kwargs["quantizer"] = quantizer; @@ -1906,8 +1921,11 @@ std::pair NVFP4Quantizer::create_grouped_tenso py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::dict kwargs; py::tuple args(0); - kwargs["shape"] = py::cast(std::vector{static_cast(logical_first_dim), - static_cast(logical_last_dim)}); + const std::vector grouped_shape = {static_cast(logical_first_dim), + static_cast(logical_last_dim)}; + const std::vector grouped_stride = stride_from_shape(grouped_shape); + kwargs["shape"] = py::cast(grouped_shape); + kwargs["stride"] = py::cast(grouped_stride); kwargs["dtype"] = py::cast(GetATenDType(dtype)); kwargs["num_tensors"] = py::cast(num_tensors); kwargs["quantizer"] = quantizer; diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index 767b0ccb35..ff2d1b311b 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -14,6 +14,34 @@ from .storage.grouped_tensor_storage import GroupedTensorStorage +def _stride_from_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Calculate contiguous stride from shape.""" + if len(shape) == 0: + return () + stride = [1] * len(shape) + for i in range(len(shape) - 2, -1, -1): + stride[i] = stride[i + 1] * shape[i + 1] + return tuple(stride) + + +class _GroupedIdentityFunc(torch.autograd.Function): + """Identity autograd function used to create a dummy grad_fn node.""" + + @staticmethod + def forward(ctx, tensor: "GroupedTensor") -> "GroupedTensor": + # pylint: disable=missing-function-docstring + ctx.input_dtype = tensor.dtype + return tensor.detach() + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + # pylint: disable=missing-function-docstring + grad_input = grad_output + if grad_input.dtype != ctx.input_dtype: + grad_input = grad_input.to(ctx.input_dtype) + return grad_input + + # For now, conservatively ban all shape manipulating ops. BANNED_SHAPE_OPS = { torch.ops.aten.view.default, @@ -34,8 +62,6 @@ torch.ops.aten.select.int, torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default, - torch.ops.aten.expand.default, - torch.ops.aten.expand_as.default, torch.ops.aten.cat.default, torch.ops.aten.stack.default, } @@ -48,6 +74,7 @@ def __new__( cls, shape: Tuple[int, int], dtype: torch.dtype, + *, num_tensors: int, shapes: Optional[List[Tuple[int, int]]] = None, quantizer: Optional[Quantizer] = None, @@ -64,12 +91,9 @@ def __new__( offsets: Optional[List[int]] = None, scale_inv_offsets: Optional[List[int]] = None, columnwise_scale_inv_offsets: Optional[List[int]] = None, + requires_grad: bool = False, + stride: Optional[List[int]] = None, ): - del quantizer - del offsets - del scale_inv_offsets - del columnwise_scale_inv_offsets - if ( shapes is not None and len(shapes) == num_tensors @@ -99,19 +123,41 @@ def __new__( if device is None: device = torch.device("cuda") - strides = [1] * len(wrapper_shape) - for i in range(len(wrapper_shape) - 2, -1, -1): - strides[i] = strides[i + 1] * wrapper_shape[i + 1] - return torch.Tensor._make_wrapper_subclass( + # Match QuantizedTensor __new__: accept externally-computed stride to + # avoid Python-side stride computation overhead for C++ construction. + strides = _stride_from_shape(tuple(wrapper_shape)) if stride is None else tuple(stride) + instance = torch.Tensor._make_wrapper_subclass( cls, wrapper_shape, - strides=tuple(strides), + strides=strides, storage_offset=0, dtype=dtype, layout=torch.strided, - requires_grad=False, + requires_grad=requires_grad, device=device, ) + GroupedTensorStorage._initialize_storage_fields( + instance=instance, + shape=shape, + dtype=dtype, + num_tensors=num_tensors, + shapes=shapes, + quantizer=quantizer, + data=data, + columnwise_data=columnwise_data, + scale_inv=scale_inv, + columnwise_scale_inv=columnwise_scale_inv, + amax=amax, + columnwise_amax=columnwise_amax, + scale=scale, + first_dims=first_dims, + last_dims=last_dims, + tensor_offsets=tensor_offsets, + offsets=offsets, + scale_inv_offsets=scale_inv_offsets, + columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, + ) + return instance @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -119,9 +165,94 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if kwargs is None: kwargs = {} + def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> None: + """Shallow-copy grouped-storage metadata onto wrapper outputs.""" + dst.num_tensors = src.num_tensors + dst.quantizer = src.quantizer + dst.tensor_shapes = src.tensor_shapes + dst.fake_dtype = src.fake_dtype + dst.rowwise_data = src.rowwise_data + dst.columnwise_data = src.columnwise_data + dst.scale_inv = src.scale_inv + dst.columnwise_scale_inv = src.columnwise_scale_inv + dst.amax = src.amax + dst.columnwise_amax = src.columnwise_amax + dst.scale = src.scale + dst.first_dims = src.first_dims + dst.last_dims = src.last_dims + dst.tensor_offsets = src.tensor_offsets + dst.offsets = src.offsets + dst.scale_inv_offsets = src.scale_inv_offsets + dst.columnwise_scale_inv_offsets = src.columnwise_scale_inv_offsets + dst.logical_shape = src.logical_shape + dst.quantized_tensors = src.quantized_tensors + + def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: + """Create a wrapper of the same type and tensor metadata as src.""" + out = torch.Tensor._make_wrapper_subclass( + type(src), + tuple(src.shape), + strides=tuple(src.stride()), + storage_offset=src.storage_offset(), + dtype=src.dtype, + layout=src.layout, + requires_grad=requires_grad, + device=src.device, + ) + copy_grouped_storage_metadata(out, src) + return out + # Parameter construction calls detach()/alias-like paths. if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default): - return args[0] + src = args[0] + assert isinstance(src, GroupedTensor) + if func == torch.ops.aten.detach.default: + return make_wrapper_like(src, requires_grad=False) + return make_wrapper_like(src, requires_grad=src.requires_grad) + + # Parameter construction may invoke aten.expand on tensor subclasses. + # Handle this explicitly so grouped parameters can be created safely. + if func == torch.ops.aten.expand.default: + src = args[0] + assert isinstance(src, GroupedTensor) + expanded_shape = tuple(args[1]) + src_shape = tuple(src.shape) + if len(expanded_shape) == len(src_shape): + normalized_shape = tuple( + src_shape[i] if dim == -1 else dim for i, dim in enumerate(expanded_shape) + ) + if normalized_shape == src_shape: + return make_wrapper_like(src, requires_grad=src.requires_grad) + return super().__torch_dispatch__(func, types, args, kwargs) + + # DDP and mcore use expand_as(self) to build a dummy autograd node and + # access gradient accumulators during parameter hook registration. + if func == torch.ops.aten.expand_as.default: + src = args[0] + other = args[1] + assert isinstance(src, GroupedTensor) + if other is src: + return _GroupedIdentityFunc.apply(src) + if tuple(other.shape) == tuple(src.shape): + return make_wrapper_like(src, requires_grad=src.requires_grad) + return super().__torch_dispatch__(func, types, args, kwargs) + + # Distributed optimizer flattens detached parameters via + # model_param.detach().view(-1). Support this path explicitly by + # returning a flat view of grouped backing storage. + if func in (torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default): + src = args[0] + assert isinstance(src, GroupedTensor) + target_shape = tuple(args[1]) + if target_shape == (-1,) or target_shape == (src.numel(),): + if src.rowwise_data is not None: + return src.rowwise_data.view(-1) + raise RuntimeError( + f"{cls.__name__} view(-1) requires rowwise_data to be initialized" + ) + raise RuntimeError( + f"{cls.__name__} only supports view(-1) for distributed optimizer flattening" + ) # Don't allow reshape/view etc. if func in BANNED_SHAPE_OPS: @@ -203,3 +334,10 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} # Do not force GroupedTensor on outputs. return torch._C._disabled_torch_function_impl(func, types, args, kwargs) + + def expand_as(self, other: torch.Tensor) -> torch.Tensor: + # pylint: disable=missing-function-docstring + # Needed during parameter creation/hook registration paths. + if other is self: + return _GroupedIdentityFunc.apply(self) + return super().expand_as(other) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 92006ba45b..f20d85e710 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -48,8 +48,9 @@ class GroupedTensorStorage: Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode. """ - def __init__( - self, + @staticmethod + def _initialize_storage_fields( + instance: "GroupedTensorStorage", shape: Tuple[int, int], dtype: torch.dtype, num_tensors: int, @@ -68,6 +69,8 @@ def __init__( offsets: Optional[List[int]] = None, scale_inv_offsets: Optional[List[int]] = None, columnwise_scale_inv_offsets: Optional[List[int]] = None, + requires_grad: bool = False, + stride: Optional[List[int]] = None, ) -> None: """ Initialize a GroupedTensor. @@ -90,31 +93,33 @@ def __init__( tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform) offsets: Vector of integer offsets for each tensor. """ - self.num_tensors = num_tensors - self.quantizer = quantizer - self.tensor_shapes = shapes - self.fake_dtype = dtype + del requires_grad + del stride + instance.num_tensors = num_tensors + instance.quantizer = quantizer + instance.tensor_shapes = shapes + instance.fake_dtype = dtype # Data buffers - self.rowwise_data = data - self.columnwise_data = columnwise_data - self.scale_inv = scale_inv - self.columnwise_scale_inv = columnwise_scale_inv - self.amax = amax - self.columnwise_amax = columnwise_amax - self.scale = scale + instance.rowwise_data = data + instance.columnwise_data = columnwise_data + instance.scale_inv = scale_inv + instance.columnwise_scale_inv = columnwise_scale_inv + instance.amax = amax + instance.columnwise_amax = columnwise_amax + instance.scale = scale # For convenient indexing for python GroupedTensor API. - self.scale_inv_offsets = scale_inv_offsets - self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets + instance.scale_inv_offsets = scale_inv_offsets + instance.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets # Shape information (OPTIONAL - None if dimension is uniform across all tensors) # first_dims[i] = first dimension of tensor i (None if all tensors have same first dim) # last_dims[i] = last dimension of tensor i (None if all tensors have same last dim) - self.first_dims = ( + instance.first_dims = ( first_dims # Device pointer to int64_t array of length num_tensors (or None) ) - self.last_dims = ( + instance.last_dims = ( last_dims # Device pointer to int64_t array of length num_tensors (or None) ) @@ -122,19 +127,69 @@ def __init__( # tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1) # Usage: tensor_i_ptr = data.data_ptr() + tensor_offsets[i] * element_size # If None and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions) - self.tensor_offsets = ( + instance.tensor_offsets = ( tensor_offsets # Device pointer to int64_t array of length num_tensors (or None) ) - self.offsets = offsets # Vector of integer offsets for each tensor. + instance.offsets = offsets # Vector of integer offsets for each tensor. # Logical shape: conceptual 2D shape of the grouped data (REQUIRED) # Represents how the 1D flattened data should be interpreted as 2D # Always 2D with positive dimensions - self.logical_shape = shape + instance.logical_shape = shape # Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor. # Used as a convenience. - self.quantized_tensors = None + instance.quantized_tensors = None + + def __new__( + cls, + shape: Tuple[int, int], + dtype: torch.dtype, + num_tensors: int, + shapes: Optional[List[Tuple[int, int]]] = None, + quantizer: Optional[Quantizer] = None, + data: Optional[torch.Tensor] = None, + columnwise_data: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + columnwise_scale_inv: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + columnwise_amax: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + first_dims: Optional[torch.Tensor] = None, + last_dims: Optional[torch.Tensor] = None, + tensor_offsets: Optional[torch.Tensor] = None, + offsets: Optional[List[int]] = None, + scale_inv_offsets: Optional[List[int]] = None, + columnwise_scale_inv_offsets: Optional[List[int]] = None, + *, + requires_grad: bool = False, + stride: Optional[List[int]] = None, + ): + instance = object.__new__(cls) + cls._initialize_storage_fields( + instance=instance, + shape=shape, + dtype=dtype, + num_tensors=num_tensors, + shapes=shapes, + quantizer=quantizer, + data=data, + columnwise_data=columnwise_data, + scale_inv=scale_inv, + columnwise_scale_inv=columnwise_scale_inv, + amax=amax, + columnwise_amax=columnwise_amax, + scale=scale, + first_dims=first_dims, + last_dims=last_dims, + tensor_offsets=tensor_offsets, + offsets=offsets, + scale_inv_offsets=scale_inv_offsets, + columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, + requires_grad=requires_grad, + stride=stride, + ) + return instance def has_data(self) -> bool: """ From 02c732d2da812b2f29e2beb5c2dc32d61e12a8fb Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 5 Mar 2026 04:22:18 +0000 Subject: [PATCH 2/5] Minor fixes Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/tensor/grouped_tensor.py | 2 +- .../pytorch/tensor/storage/grouped_tensor_storage.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index ff2d1b311b..fad8a770c8 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -244,7 +244,7 @@ def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: src = args[0] assert isinstance(src, GroupedTensor) target_shape = tuple(args[1]) - if target_shape == (-1,) or target_shape == (src.numel(),): + if target_shape in ((-1,), (src.numel(),)): if src.rowwise_data is not None: return src.rowwise_data.view(-1) raise RuntimeError( diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index f20d85e710..04935d1d4f 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -145,6 +145,7 @@ def __new__( cls, shape: Tuple[int, int], dtype: torch.dtype, + *, num_tensors: int, shapes: Optional[List[Tuple[int, int]]] = None, quantizer: Optional[Quantizer] = None, @@ -161,7 +162,6 @@ def __new__( offsets: Optional[List[int]] = None, scale_inv_offsets: Optional[List[int]] = None, columnwise_scale_inv_offsets: Optional[List[int]] = None, - *, requires_grad: bool = False, stride: Optional[List[int]] = None, ): From 63dd8e51b67c495f19d9d03b469188d756a57d41 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 5 Mar 2026 05:02:11 +0000 Subject: [PATCH 3/5] fix Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/tensor/grouped_tensor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index fad8a770c8..685b2c5548 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -42,10 +42,8 @@ def backward(ctx, grad_output: torch.Tensor): return grad_input -# For now, conservatively ban all shape manipulating ops. +# For now, conservatively ban 'most' shape manipulating ops. BANNED_SHAPE_OPS = { - torch.ops.aten.view.default, - torch.ops.aten._unsafe_view.default, torch.ops.aten.reshape.default, torch.ops.aten._reshape_alias.default, torch.ops.aten.flatten.using_ints, From 7f2c127b67738c728b9608db69595bcca06878fd Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 5 Mar 2026 10:33:12 +0530 Subject: [PATCH 4/5] Update transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani --- .../pytorch/tensor/storage/grouped_tensor_storage.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 04935d1d4f..a3f81a7290 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -93,8 +93,11 @@ def _initialize_storage_fields( tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform) offsets: Vector of integer offsets for each tensor. """ - del requires_grad - del stride +# `requires_grad` and `stride` are accepted for API symmetry with +# GroupedTensor.__new__ but are not relevant for storage-only +# initialization; they are intentionally ignored here. +del requires_grad +del stride instance.num_tensors = num_tensors instance.quantizer = quantizer instance.tensor_shapes = shapes From d6b758e41b9ac06defdbe20670654345391115f9 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 5 Mar 2026 05:08:43 +0000 Subject: [PATCH 5/5] fix greptile commit Signed-off-by: Kirthi Shankar Sivamani --- .../pytorch/tensor/storage/grouped_tensor_storage.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index a3f81a7290..3b7b9bc169 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -93,11 +93,12 @@ def _initialize_storage_fields( tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform) offsets: Vector of integer offsets for each tensor. """ -# `requires_grad` and `stride` are accepted for API symmetry with -# GroupedTensor.__new__ but are not relevant for storage-only -# initialization; they are intentionally ignored here. -del requires_grad -del stride + # `requires_grad` and `stride` are accepted for API symmetry with + # GroupedTensor.__new__ but are not relevant for storage-only + # initialization; they are intentionally ignored here. + del requires_grad + del stride + instance.num_tensors = num_tensors instance.quantizer = quantizer instance.tensor_shapes = shapes