Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import transformer_engine_torch as tex

from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor
from transformer_engine.pytorch.tensor.grouped_tensor import (
GroupedTensor,
GroupedTensorStorage,
)
from .base import (
get_dummy_wgrad,
quantize_weight,
Expand Down Expand Up @@ -135,9 +138,9 @@ def _make_grouped_tensor(
base_split_offsets: torch.Tensor,
last_dim: int,
dtype: torch.dtype,
) -> GroupedTensor:
"""Wrap a packed 2D buffer as a varying-first-dimension GroupedTensor."""
return GroupedTensor(
) -> GroupedTensorStorage:
"""Wrap a packed 2D buffer as a varying-first-dimension GroupedTensorStorage."""
return GroupedTensorStorage(
shape=(data.size(0), last_dim),
dtype=dtype,
num_tensors=num_gemms,
Expand All @@ -154,13 +157,13 @@ def _make_grouped_bias(
num_gemms: int,
out_features: int,
dtype: torch.dtype,
) -> GroupedTensor:
) -> GroupedTensorStorage:
"""Pack per-GEMM biases into the grouped GEMM bias format."""
bias_data = torch.stack(
[_GroupedLinear._maybe_dequantize(bias, dtype) for bias in biases],
dim=0,
).contiguous()
return GroupedTensor(
return GroupedTensorStorage(
shape=(num_gemms, out_features),
dtype=dtype,
num_tensors=num_gemms,
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,9 @@ def op_forward(
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
# No special CPU offloading logic is needed for weights. saved_weight is
# either self.weight (nn.Parameter, auto-excluded from offload) or a
# workspace freshly created each forward pass.
mark_activation_offload(saved_input)
ctx.save_for_backward(saved_input, saved_weight)
ctx.with_quantized_compute = with_quantized_compute and backward_override is None
Expand Down
47 changes: 37 additions & 10 deletions transformer_engine/pytorch/ops/basic/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe
from ...quantized_tensor import QuantizedTensorStorage
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
Expand Down Expand Up @@ -783,7 +784,7 @@ def _get_grouped_weight_for_gemm(
columnwise_usage: bool,
with_quantized_compute: bool,
dtype: torch.dtype,
) -> GroupedTensor:
) -> GroupedTensorStorage:
"""Prepare weights for ``general_grouped_gemm_for_grouped_tensor``.
Supports MXFP8/BF16/FP16 compute paths.
"""
Expand All @@ -800,7 +801,7 @@ def _get_grouped_weight_for_gemm(
weight_parts = weight_param.split_into_quantized_tensors()
dequantized = [maybe_dequantize(w, dtype) for w in weight_parts]
weight_data = torch.stack(dequantized, dim=0).contiguous()
return GroupedTensor(
return GroupedTensorStorage(
shape=(num_groups * self.out_features, self.in_features),
dtype=dtype,
num_tensors=num_groups,
Expand All @@ -814,7 +815,7 @@ def _get_grouped_weight_for_gemm(
if weight_param.rowwise_data.dtype == dtype:
return weight_param
weight_data = weight_param.rowwise_data.to(dtype=dtype)
return GroupedTensor(
return GroupedTensorStorage(
shape=(num_groups * self.out_features, self.in_features),
dtype=dtype,
num_tensors=num_groups,
Expand Down Expand Up @@ -866,8 +867,8 @@ def _get_weight_tensors(self) -> list[torch.nn.Parameter]:
def _get_grouped_bias_for_gemm(
self,
dtype: torch.dtype,
) -> Optional[torch.Tensor]:
"""Build a uniform GroupedTensor of per-group biases for the cublas
) -> Optional[GroupedTensorStorage]:
"""Build a uniform GroupedTensorStorage of per-group biases for the cublas
grouped GEMM.

Each group expects a (1, out_features) bias vector. Returns ``None``
Expand All @@ -888,7 +889,7 @@ def _get_grouped_bias_for_gemm(
]
bias_data = torch.stack(bias_list, dim=0).contiguous()

return GroupedTensor(
return GroupedTensorStorage(
shape=(num_groups, self.out_features),
dtype=dtype,
num_tensors=num_groups,
Expand Down Expand Up @@ -1026,6 +1027,25 @@ def fuser_forward_save_ctx(
return

ctx = basic_op_ctxs[0]

# Activation CPU offloading
# Note: No special logic is needed for weights. They are
# either nn.Parameter (auto-excluded from offload) or are
# temporary workspaces freshly created in each forward pass.
if is_cpu_offload_enabled():
saved = tensors_to_save[0]
offset = 4 if self._scale_bias else 3
if use_grouped_tensor_path:
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
grouped_x = saved[offset]
if grouped_x is not None:
mark_activation_offload(grouped_x)
else:
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None]
if live_xs:
mark_activation_offload(*live_xs)

ctx.save_for_backward(*tensors_to_save[0])

num_groups = self.num_groups
Expand Down Expand Up @@ -1110,6 +1130,10 @@ def _fuser_forward_split_quantize(
xs = tex.split_quantize(x, split_sizes_int, input_quantizers)
else:
xs = torch.split(x, split_sizes_int)
if is_cpu_offload_enabled():
live_xs = [t for t in xs if t is not None]
if live_xs:
start_offload(*live_xs)

# Allocate output tensor
in_shape = list(input_.size())
Expand Down Expand Up @@ -1205,7 +1229,7 @@ def _fuser_forward_grouped_tensor(
grouped_x = tex.group_quantize(x, input_quantizer, num_groups, split_sizes)
else:
# No quantize: wrap the contiguous high-precision buffer.
grouped_x = GroupedTensor(
grouped_x = GroupedTensorStorage(
shape=(total_tokens, self.in_features),
dtype=dtype,
num_tensors=num_groups,
Expand All @@ -1215,6 +1239,9 @@ def _fuser_forward_grouped_tensor(
tensor_offsets=base_split_offsets * self.in_features,
)

if is_cpu_offload_enabled() and grouped_x is not None:
start_offload(grouped_x)

# Build the weight GroupedTensor / list.
if self.single_grouped_weight:
# GroupedTensor
Expand All @@ -1238,7 +1265,7 @@ def _fuser_forward_grouped_tensor(
# Allocate output buffer and wrap as a GroupedTensor view.
out_shape = original_shape[:-1] + [self.out_features]
out = torch.empty(out_shape, dtype=dtype, device=device)
grouped_out = GroupedTensor(
grouped_out = GroupedTensorStorage(
shape=(total_tokens, self.out_features),
dtype=dtype,
num_tensors=num_groups,
Expand Down Expand Up @@ -1566,7 +1593,7 @@ def _fuser_backward_grouped_tensor(
else:
dy_2d = maybe_dequantize(dy_2d, dtype)
# Wrap BF16/FP16 buffer as a GroupedTensor for grouped gemm
grouped_dy = GroupedTensor(
grouped_dy = GroupedTensorStorage(
shape=(total_tokens, self.out_features),
dtype=dtype,
num_tensors=num_groups,
Expand Down Expand Up @@ -1602,7 +1629,7 @@ def _fuser_backward_grouped_tensor(
if ctx.input_requires_grad:
grad_input_shape = list(grad_output.size())[:-1] + [self.in_features]
grad_input = torch.empty(grad_input_shape, dtype=dtype, device=device)
grouped_grad_input = GroupedTensor(
grouped_grad_input = GroupedTensorStorage(
shape=(total_tokens, self.in_features),
dtype=dtype,
num_tensors=num_groups,
Expand Down
48 changes: 44 additions & 4 deletions transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch

import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor
from ...quantization import Recipe
from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer
Expand All @@ -23,6 +24,7 @@
mark_grouped_tensor,
)
from ...tensor.grouped_tensor import GroupedTensor
from ...tensor.storage.grouped_tensor_storage import GroupedTensorStorage
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE
from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU
Expand Down Expand Up @@ -316,14 +318,44 @@ def fuser_forward(
# Group-quantize input tensor and convert dtypes if needed
fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
fc1_input_quantizer.optimize_for_gemm = True
fc1_input_quantizer.internal = True
input_quantizer = getattr(input_, "quantizer", None)
if isinstance(input_, GroupedTensor) and (
isinstance(fc1_input_quantizer, MXFP8Quantizer)
and isinstance(input_quantizer, MXFP8Quantizer)
or isinstance(fc1_input_quantizer, NVFP4Quantizer)
and isinstance(input_quantizer, NVFP4Quantizer)
):
grouped_fc1_x = input_
# GroupedTensor is a torch.Tensor subclass, so the CPU offload
# infrastructure's prepare_for_saving treats it as a plain tensor
# and does not decompose it into its component data tensors. By
# repacking into a GroupedTensorStorage (not a torch.Tensor), we
# ensure the fuser's prepare_for_saving call correctly decomposes
# the activation before save_for_backward.
grouped_fc1_x = GroupedTensorStorage(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we set the input quantizer with .internal = True, isn't it redundant to repack grouped_fc1_x into a GroupedTensorStorage?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they are targeting different cases. The input quantizer with .interal= True```` takes effects on bf16 input, where we need to quantize it by fc1_input_quantizer. The second case is that the input is already a quantized fp8 tensor, where we need to repack it into a GroupedTensorStorage```.

Copy link
Copy Markdown
Member

@timmoon10 timmoon10 Jun 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How could it already be quantized? The only way to create it is from the quantizer, either just now or from a previous step (e.g. with activation recompute). If the quantizer has .internal=True, it can only be GroupedTensorStorage.

If something is incorrectly producing GroupedTensor, then that's a bug. Fixing it here is papering over the real problem.

Actually, on second thought, it makes sense that input_ can be a GroupedTensor since it comes from outside the op. It would be useful to know what use-case hit this bug though. Activation recompute?

Really the root cause is that CPU offloading doesn't handle GroupedTensor gracefully, but that would be a more involved effort.

shape=input_.logical_shape,
dtype=input_.fake_dtype,
num_tensors=input_.num_tensors,
shapes=input_.tensor_shapes,
quantizer=input_.quantizer,
data=input_.rowwise_data,
columnwise_data=input_.columnwise_data,
scale_inv=input_.scale_inv,
columnwise_scale_inv=input_.columnwise_scale_inv,
amax=input_.amax,
columnwise_amax=input_.columnwise_amax,
scale=input_.scale,
first_dims=input_.first_dims,
last_dims=input_.last_dims,
tensor_offsets=input_.tensor_offsets,
offsets=input_.offsets,
scale_inv_offsets=input_.scale_inv_offsets,
columnwise_scale_inv_offsets=input_.columnwise_scale_inv_offsets,
with_gemm_swizzled_scales=input_._with_gemm_swizzled_scales,
row_scaled_nvfp4=input_.row_scaled_nvfp4,
nvfp4_use_4over6=input_.nvfp4_use_4over6,
nvfp4_e4m3_max=input_.nvfp4_e4m3_max,
)
Comment thread
lhb8125 marked this conversation as resolved.
else:
fc1_x = maybe_dequantize(input_, dtype)
grouped_fc1_x = _group_quantize_for_grouped_mlp(
Expand Down Expand Up @@ -587,7 +619,7 @@ def fuser_forward(
else:
fc2_out_buf = fc2_out_buf + token_bias
else:
fc2_out_grouped = GroupedTensor(
fc2_out_grouped = GroupedTensorStorage(
shape=(in_shape[0], fc2_weight_shape[0]),
dtype=dtype,
num_tensors=num_groups,
Expand Down Expand Up @@ -616,7 +648,7 @@ def fuser_forward(
fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"]
fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3)

grouped_fc2_x = GroupedTensor(
grouped_fc2_x = GroupedTensorStorage(
shape=(in_shape[0], fc2_weight_shape[1]),
dtype=dtype,
num_tensors=num_groups,
Expand Down Expand Up @@ -695,6 +727,7 @@ def fuser_forward(
if requires_grad:
mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x)
activation_op = self.basic_ops[1]
cpu_offloading = is_cpu_offload_enabled()
activation_is_srelu = isinstance(activation_op, ScaledSReLU)
activation_recompute_in_mlp = bool(
getattr(activation_op, "activation_recompute_in_mlp", False)
Expand All @@ -716,6 +749,13 @@ def fuser_forward(
grouped_fc_x.rowwise_data = None
grouped_fc_x.scale_inv = None

if cpu_offloading:
activation_tensors = [
t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None
]
start_offload(*activation_tensors)
mark_activation_offload(*activation_tensors)

# FC1 saved-tensor layout.
# [split_sizes, base_split_offsets, split_points,
# grouped_fc1_x, *fc1_weight_tensors]
Expand Down Expand Up @@ -755,7 +795,7 @@ def fuser_forward(
fc2_weight_tensors = (
[grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight
)
fc2_saved: list[Optional[torch.Tensor]] = [
fc2_saved: list[Optional[torch.Tensor | GroupedTensorStorage]] = [
split_sizes,
base_split_offsets,
split_points,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,21 @@ def restore_from_saved(
self.tensor_offsets = tensors[9]
return tensors[10:]

def get_data_tensors(self):
"""Get tensor fields that may be saved or offloaded."""
return (
self.rowwise_data,
self.columnwise_data,
self.scale_inv,
self.columnwise_scale_inv,
self.amax,
self.columnwise_amax,
self.scale,
self.first_dims,
self.last_dims,
self.tensor_offsets,
)

def clear(self) -> None:
"""
Reset tensor data and clear all buffers.
Expand Down
Loading