Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions transformer_engine/pytorch/csrc/quantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,11 @@ std::pair<GroupedTensorWrapper, py::object> NoneQuantizer::create_grouped_tensor
py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal);
py::dict kwargs;
py::tuple args(0);
kwargs["shape"] = py::cast(std::vector<int64_t>{static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)});
const std::vector<int64_t> grouped_shape = {static_cast<int64_t>(logical_first_dim),

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to specify that it's a grouped shape?

Since we subclass a torch.Tensor, what is the shape field of a grouped tensor?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shape of the GroupedTensor is the logical shape.

static_cast<int64_t>(logical_last_dim)};
const std::vector<int64_t> grouped_stride = stride_from_shape(grouped_shape);
kwargs["shape"] = py::cast(grouped_shape);
kwargs["stride"] = py::cast(grouped_stride);
kwargs["dtype"] = py::cast(GetATenDType(dtype));
kwargs["num_tensors"] = py::cast(num_tensors);
kwargs["quantizer"] = quantizer;
Expand Down Expand Up @@ -386,8 +389,11 @@ std::pair<GroupedTensorWrapper, py::object> Float8Quantizer::create_grouped_tens
py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal);
py::dict kwargs;
py::tuple args(0);
kwargs["shape"] = py::cast(std::vector<int64_t>{static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)});
const std::vector<int64_t> grouped_shape = {static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)};
const std::vector<int64_t> grouped_stride = stride_from_shape(grouped_shape);
kwargs["shape"] = py::cast(grouped_shape);
kwargs["stride"] = py::cast(grouped_stride);
kwargs["dtype"] = py::cast(GetATenDType(dtype));
kwargs["num_tensors"] = py::cast(num_tensors);
kwargs["quantizer"] = quantizer;
Expand Down Expand Up @@ -704,8 +710,11 @@ std::pair<GroupedTensorWrapper, py::object> Float8CurrentScalingQuantizer::creat
py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal);
py::dict kwargs;
py::tuple args(0);
kwargs["shape"] = py::cast(std::vector<int64_t>{static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)});
const std::vector<int64_t> grouped_shape = {static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)};
const std::vector<int64_t> grouped_stride = stride_from_shape(grouped_shape);
kwargs["shape"] = py::cast(grouped_shape);
kwargs["stride"] = py::cast(grouped_stride);
kwargs["dtype"] = py::cast(GetATenDType(dtype));
kwargs["num_tensors"] = py::cast(num_tensors);
kwargs["quantizer"] = quantizer;
Expand Down Expand Up @@ -1062,8 +1071,11 @@ std::pair<GroupedTensorWrapper, py::object> Float8BlockQuantizer::create_grouped
py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal);
py::dict kwargs;
py::tuple args(0);
kwargs["shape"] = py::cast(std::vector<int64_t>{static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)});
const std::vector<int64_t> grouped_shape = {static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)};
const std::vector<int64_t> grouped_stride = stride_from_shape(grouped_shape);
kwargs["shape"] = py::cast(grouped_shape);
kwargs["stride"] = py::cast(grouped_stride);
kwargs["dtype"] = py::cast(GetATenDType(dtype));
kwargs["num_tensors"] = py::cast(num_tensors);
kwargs["quantizer"] = quantizer;
Expand Down Expand Up @@ -1478,8 +1490,11 @@ std::pair<GroupedTensorWrapper, py::object> MXFP8Quantizer::create_grouped_tenso
py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal);
py::dict kwargs;
py::tuple args(0);
kwargs["shape"] = py::cast(std::vector<int64_t>{static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)});
const std::vector<int64_t> grouped_shape = {static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)};
const std::vector<int64_t> grouped_stride = stride_from_shape(grouped_shape);
kwargs["shape"] = py::cast(grouped_shape);
kwargs["stride"] = py::cast(grouped_stride);
kwargs["dtype"] = py::cast(GetATenDType(dtype));
kwargs["num_tensors"] = py::cast(num_tensors);
kwargs["quantizer"] = quantizer;
Expand Down Expand Up @@ -1906,8 +1921,11 @@ std::pair<GroupedTensorWrapper, py::object> NVFP4Quantizer::create_grouped_tenso
py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal);
py::dict kwargs;
py::tuple args(0);
kwargs["shape"] = py::cast(std::vector<int64_t>{static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)});
const std::vector<int64_t> grouped_shape = {static_cast<int64_t>(logical_first_dim),
static_cast<int64_t>(logical_last_dim)};
const std::vector<int64_t> grouped_stride = stride_from_shape(grouped_shape);
kwargs["shape"] = py::cast(grouped_shape);
kwargs["stride"] = py::cast(grouped_stride);
kwargs["dtype"] = py::cast(GetATenDType(dtype));
kwargs["num_tensors"] = py::cast(num_tensors);
kwargs["quantizer"] = quantizer;
Expand Down
166 changes: 152 additions & 14 deletions transformer_engine/pytorch/tensor/grouped_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,34 @@
from .storage.grouped_tensor_storage import GroupedTensorStorage


def _stride_from_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this function is already defined in some utils, that you can reuse.

"""Calculate contiguous stride from shape."""
if len(shape) == 0:
return ()
stride = [1] * len(shape)
for i in range(len(shape) - 2, -1, -1):
stride[i] = stride[i + 1] * shape[i + 1]
return tuple(stride)


class _GroupedIdentityFunc(torch.autograd.Function):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

@ksivaman ksivaman Mar 5, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mcore distopt does weight.expand_as(param) to force a graph edge (autograd). This is needed to safely implement that for GroupedTensor.

"""Identity autograd function used to create a dummy grad_fn node."""

@staticmethod
def forward(ctx, tensor: "GroupedTensor") -> "GroupedTensor":
# pylint: disable=missing-function-docstring
ctx.input_dtype = tensor.dtype
return tensor.detach()

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# pylint: disable=missing-function-docstring
grad_input = grad_output
if grad_input.dtype != ctx.input_dtype:
grad_input = grad_input.to(ctx.input_dtype)
return grad_input


# For now, conservatively ban all shape manipulating ops.
BANNED_SHAPE_OPS = {
torch.ops.aten.view.default,
Expand All @@ -34,8 +62,6 @@
torch.ops.aten.select.int,
torch.ops.aten.split.Tensor,
torch.ops.aten.chunk.default,
torch.ops.aten.expand.default,
torch.ops.aten.expand_as.default,
torch.ops.aten.cat.default,
torch.ops.aten.stack.default,
}
Comment thread
ksivaman marked this conversation as resolved.
Expand All @@ -48,6 +74,7 @@ def __new__(
cls,
shape: Tuple[int, int],
dtype: torch.dtype,
*,
num_tensors: int,
shapes: Optional[List[Tuple[int, int]]] = None,
quantizer: Optional[Quantizer] = None,
Expand All @@ -64,12 +91,9 @@ def __new__(
offsets: Optional[List[int]] = None,
scale_inv_offsets: Optional[List[int]] = None,
columnwise_scale_inv_offsets: Optional[List[int]] = None,
requires_grad: bool = False,
stride: Optional[List[int]] = None,
):
del quantizer
del offsets
del scale_inv_offsets
del columnwise_scale_inv_offsets

if (
shapes is not None
and len(shapes) == num_tensors
Expand Down Expand Up @@ -99,29 +123,136 @@ def __new__(
if device is None:
device = torch.device("cuda")

strides = [1] * len(wrapper_shape)
for i in range(len(wrapper_shape) - 2, -1, -1):
strides[i] = strides[i + 1] * wrapper_shape[i + 1]
return torch.Tensor._make_wrapper_subclass(
# Match QuantizedTensor __new__: accept externally-computed stride to
# avoid Python-side stride computation overhead for C++ construction.
strides = _stride_from_shape(tuple(wrapper_shape)) if stride is None else tuple(stride)
instance = torch.Tensor._make_wrapper_subclass(
cls,
wrapper_shape,
strides=tuple(strides),
strides=strides,
storage_offset=0,
dtype=dtype,
layout=torch.strided,
requires_grad=False,
requires_grad=requires_grad,
device=device,
)
GroupedTensorStorage._initialize_storage_fields(
instance=instance,
shape=shape,
dtype=dtype,
num_tensors=num_tensors,
shapes=shapes,
quantizer=quantizer,
data=data,
columnwise_data=columnwise_data,
scale_inv=scale_inv,
columnwise_scale_inv=columnwise_scale_inv,
amax=amax,
columnwise_amax=columnwise_amax,
scale=scale,
first_dims=first_dims,
last_dims=last_dims,
tensor_offsets=tensor_offsets,
offsets=offsets,
scale_inv_offsets=scale_inv_offsets,
columnwise_scale_inv_offsets=columnwise_scale_inv_offsets,
)
return instance

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
"""Dispatch by dequantizing grouped members, then requantizing writes."""
if kwargs is None:
kwargs = {}

def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> None:
"""Shallow-copy grouped-storage metadata onto wrapper outputs."""
dst.num_tensors = src.num_tensors
dst.quantizer = src.quantizer
dst.tensor_shapes = src.tensor_shapes
dst.fake_dtype = src.fake_dtype
dst.rowwise_data = src.rowwise_data
dst.columnwise_data = src.columnwise_data
dst.scale_inv = src.scale_inv
dst.columnwise_scale_inv = src.columnwise_scale_inv
dst.amax = src.amax
dst.columnwise_amax = src.columnwise_amax
dst.scale = src.scale
dst.first_dims = src.first_dims
dst.last_dims = src.last_dims
dst.tensor_offsets = src.tensor_offsets
dst.offsets = src.offsets
dst.scale_inv_offsets = src.scale_inv_offsets
dst.columnwise_scale_inv_offsets = src.columnwise_scale_inv_offsets
dst.logical_shape = src.logical_shape
dst.quantized_tensors = src.quantized_tensors

def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor:
"""Create a wrapper of the same type and tensor metadata as src."""
out = torch.Tensor._make_wrapper_subclass(
type(src),
tuple(src.shape),
strides=tuple(src.stride()),
storage_offset=src.storage_offset(),
dtype=src.dtype,
layout=src.layout,
requires_grad=requires_grad,
device=src.device,
)
copy_grouped_storage_metadata(out, src)
return out

# Parameter construction calls detach()/alias-like paths.
if func in (torch.ops.aten.detach.default, torch.ops.aten.alias.default):
return args[0]
src = args[0]
assert isinstance(src, GroupedTensor)
if func == torch.ops.aten.detach.default:
return make_wrapper_like(src, requires_grad=False)
return make_wrapper_like(src, requires_grad=src.requires_grad)

# Parameter construction may invoke aten.expand on tensor subclasses.
# Handle this explicitly so grouped parameters can be created safely.
if func == torch.ops.aten.expand.default:
src = args[0]
assert isinstance(src, GroupedTensor)
expanded_shape = tuple(args[1])
Comment on lines +215 to +216

@vthumbe1503 vthumbe1503 Mar 5, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am curious why all this torch dispatch logic is not needed in the MXFP8 tensor. That is, how does DDP work even with discrete MXFP8 weights, if MCore uses all these ops?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question for expand_as, view, unsafe_view

src_shape = tuple(src.shape)
if len(expanded_shape) == len(src_shape):
normalized_shape = tuple(
src_shape[i] if dim == -1 else dim for i, dim in enumerate(expanded_shape)
)
if normalized_shape == src_shape:
return make_wrapper_like(src, requires_grad=src.requires_grad)
return super().__torch_dispatch__(func, types, args, kwargs)

# DDP and mcore use expand_as(self) to build a dummy autograd node and
# access gradient accumulators during parameter hook registration.
if func == torch.ops.aten.expand_as.default:
src = args[0]
other = args[1]
assert isinstance(src, GroupedTensor)
if other is src:
return _GroupedIdentityFunc.apply(src)
if tuple(other.shape) == tuple(src.shape):
return make_wrapper_like(src, requires_grad=src.requires_grad)
return super().__torch_dispatch__(func, types, args, kwargs)

# Distributed optimizer flattens detached parameters via
# model_param.detach().view(-1). Support this path explicitly by
# returning a flat view of grouped backing storage.
if func in (torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default):
src = args[0]
assert isinstance(src, GroupedTensor)
target_shape = tuple(args[1])
if target_shape in ((-1,), (src.numel(),)):
if src.rowwise_data is not None:
return src.rowwise_data.view(-1)
Comment thread
ksivaman marked this conversation as resolved.
raise RuntimeError(
f"{cls.__name__} view(-1) requires rowwise_data to be initialized"
)
raise RuntimeError(
f"{cls.__name__} only supports view(-1) for distributed optimizer flattening"
)

# Don't allow reshape/view etc.
if func in BANNED_SHAPE_OPS:
Expand Down Expand Up @@ -203,3 +334,10 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {}
# Do not force GroupedTensor on outputs.
return torch._C._disabled_torch_function_impl(func, types, args, kwargs)

def expand_as(self, other: torch.Tensor) -> torch.Tensor:
    """Expand this tensor to the shape of ``other``.

    When ``other`` is this very tensor, the call is routed through
    ``_GroupedIdentityFunc`` instead of the regular expand; every other
    case defers to the parent class implementation.
    """
    # Needed during parameter creation/hook registration paths.
    if other is self:
        return _GroupedIdentityFunc.apply(self)
    return super().expand_as(other)
Loading
Loading