diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c1d45511df..ac304d3379 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -16,7 +16,10 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor +from transformer_engine.pytorch.tensor.grouped_tensor import ( + GroupedTensor, + GroupedTensorStorage, +) from .base import ( get_dummy_wgrad, quantize_weight, @@ -135,9 +138,9 @@ def _make_grouped_tensor( base_split_offsets: torch.Tensor, last_dim: int, dtype: torch.dtype, - ) -> GroupedTensor: - """Wrap a packed 2D buffer as a varying-first-dimension GroupedTensor.""" - return GroupedTensor( + ) -> GroupedTensorStorage: + """Wrap a packed 2D buffer as a varying-first-dimension GroupedTensorStorage.""" + return GroupedTensorStorage( shape=(data.size(0), last_dim), dtype=dtype, num_tensors=num_gemms, @@ -154,13 +157,13 @@ def _make_grouped_bias( num_gemms: int, out_features: int, dtype: torch.dtype, - ) -> GroupedTensor: + ) -> GroupedTensorStorage: """Pack per-GEMM biases into the grouped GEMM bias format.""" bias_data = torch.stack( [_GroupedLinear._maybe_dequantize(bias, dtype) for bias in biases], dim=0, ).contiguous() - return GroupedTensor( + return GroupedTensorStorage( shape=(num_gemms, out_features), dtype=dtype, num_tensors=num_gemms, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 95e0440303..6b17d66fcd 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -1050,6 +1050,9 @@ def op_forward( saved_input = x_local saved_weight = w if is_cpu_offload_enabled(): + # No special CPU offloading logic is needed for weights. 
saved_weight is + # either self.weight (nn.Parameter, auto-excluded from offload) or a + # workspace freshly created each forward pass. mark_activation_offload(saved_input) ctx.save_for_backward(saved_input, saved_weight) ctx.with_quantized_compute = with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 521ee59fa0..73e328c9d1 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -23,6 +23,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe from ...quantized_tensor import QuantizedTensorStorage from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer @@ -783,7 +784,7 @@ def _get_grouped_weight_for_gemm( columnwise_usage: bool, with_quantized_compute: bool, dtype: torch.dtype, - ) -> GroupedTensor: + ) -> GroupedTensorStorage: """Prepare weights for ``general_grouped_gemm_for_grouped_tensor``. Supports MXFP8/BF16/FP16 compute paths. 
""" @@ -800,7 +801,7 @@ def _get_grouped_weight_for_gemm( weight_parts = weight_param.split_into_quantized_tensors() dequantized = [maybe_dequantize(w, dtype) for w in weight_parts] weight_data = torch.stack(dequantized, dim=0).contiguous() - return GroupedTensor( + return GroupedTensorStorage( shape=(num_groups * self.out_features, self.in_features), dtype=dtype, num_tensors=num_groups, @@ -814,7 +815,7 @@ def _get_grouped_weight_for_gemm( if weight_param.rowwise_data.dtype == dtype: return weight_param weight_data = weight_param.rowwise_data.to(dtype=dtype) - return GroupedTensor( + return GroupedTensorStorage( shape=(num_groups * self.out_features, self.in_features), dtype=dtype, num_tensors=num_groups, @@ -866,8 +867,8 @@ def _get_weight_tensors(self) -> list[torch.nn.Parameter]: def _get_grouped_bias_for_gemm( self, dtype: torch.dtype, - ) -> Optional[torch.Tensor]: - """Build a uniform GroupedTensor of per-group biases for the cublas + ) -> Optional[GroupedTensorStorage]: + """Build a uniform GroupedTensorStorage of per-group biases for the cublas grouped GEMM. Each group expects a (1, out_features) bias vector. Returns ``None`` @@ -888,7 +889,7 @@ def _get_grouped_bias_for_gemm( ] bias_data = torch.stack(bias_list, dim=0).contiguous() - return GroupedTensor( + return GroupedTensorStorage( shape=(num_groups, self.out_features), dtype=dtype, num_tensors=num_groups, @@ -1026,6 +1027,25 @@ def fuser_forward_save_ctx( return ctx = basic_op_ctxs[0] + + # Activation CPU offloading + # Note: No special logic is needed for weights. They are + # either nn.Parameter (auto-excluded from offload) or are + # temporary workspaces freshly created in each forward pass. 
+ if is_cpu_offload_enabled(): + saved = tensors_to_save[0] + offset = 4 if self._scale_bias else 3 + if use_grouped_tensor_path: + # Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights] + grouped_x = saved[offset] + if grouped_x is not None: + mark_activation_offload(grouped_x) + else: + # Layout: [split_sizes, None, None, (scales?), *xs, *ws] + live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None] + if live_xs: + mark_activation_offload(*live_xs) + ctx.save_for_backward(*tensors_to_save[0]) num_groups = self.num_groups @@ -1110,6 +1130,10 @@ def _fuser_forward_split_quantize( xs = tex.split_quantize(x, split_sizes_int, input_quantizers) else: xs = torch.split(x, split_sizes_int) + if is_cpu_offload_enabled(): + live_xs = [t for t in xs if t is not None] + if live_xs: + start_offload(*live_xs) # Allocate output tensor in_shape = list(input_.size()) @@ -1205,7 +1229,7 @@ def _fuser_forward_grouped_tensor( grouped_x = tex.group_quantize(x, input_quantizer, num_groups, split_sizes) else: # No quantize: wrap the contiguous high-precision buffer. - grouped_x = GroupedTensor( + grouped_x = GroupedTensorStorage( shape=(total_tokens, self.in_features), dtype=dtype, num_tensors=num_groups, @@ -1215,6 +1239,9 @@ def _fuser_forward_grouped_tensor( tensor_offsets=base_split_offsets * self.in_features, ) + if is_cpu_offload_enabled() and grouped_x is not None: + start_offload(grouped_x) + # Build the weight GroupedTensor / list. if self.single_grouped_weight: # GroupedTensor @@ -1238,7 +1265,7 @@ def _fuser_forward_grouped_tensor( # Allocate output buffer and wrap as a GroupedTensor view. 
out_shape = original_shape[:-1] + [self.out_features] out = torch.empty(out_shape, dtype=dtype, device=device) - grouped_out = GroupedTensor( + grouped_out = GroupedTensorStorage( shape=(total_tokens, self.out_features), dtype=dtype, num_tensors=num_groups, @@ -1566,7 +1593,7 @@ def _fuser_backward_grouped_tensor( else: dy_2d = maybe_dequantize(dy_2d, dtype) # Wrap BF16/FP16 buffer as a GroupedTensor for grouped gemm - grouped_dy = GroupedTensor( + grouped_dy = GroupedTensorStorage( shape=(total_tokens, self.out_features), dtype=dtype, num_tensors=num_groups, @@ -1602,7 +1629,7 @@ def _fuser_backward_grouped_tensor( if ctx.input_requires_grad: grad_input_shape = list(grad_output.size())[:-1] + [self.in_features] grad_input = torch.empty(grad_input_shape, dtype=dtype, device=device) - grouped_grad_input = GroupedTensor( + grouped_grad_input = GroupedTensorStorage( shape=(total_tokens, self.in_features), dtype=dtype, num_tensors=num_groups, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index ece670a539..f03ccc15b5 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,6 +13,7 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor from ...quantization import Recipe from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer @@ -23,6 +24,7 @@ mark_grouped_tensor, ) from ...tensor.grouped_tensor import GroupedTensor +from ...tensor.storage.grouped_tensor_storage import GroupedTensorStorage from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU @@ -316,6 +318,7 @@ def fuser_forward( # 
Group-quantize input tensor and convert dtypes if needed fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) fc1_input_quantizer.optimize_for_gemm = True + fc1_input_quantizer.internal = True input_quantizer = getattr(input_, "quantizer", None) if isinstance(input_, GroupedTensor) and ( isinstance(fc1_input_quantizer, MXFP8Quantizer) @@ -323,7 +326,36 @@ def fuser_forward( or isinstance(fc1_input_quantizer, NVFP4Quantizer) and isinstance(input_quantizer, NVFP4Quantizer) ): - grouped_fc1_x = input_ + # GroupedTensor is a torch.Tensor subclass, so the CPU offload + # infrastructure's prepare_for_saving treats it as a plain tensor + # and does not decompose it into its component data tensors. By + # repacking into a GroupedTensorStorage (not a torch.Tensor), we + # ensure the fuser's prepare_for_saving call correctly decomposes + # the activation before save_for_backward. + grouped_fc1_x = GroupedTensorStorage( + shape=input_.logical_shape, + dtype=input_.fake_dtype, + num_tensors=input_.num_tensors, + shapes=input_.tensor_shapes, + quantizer=input_.quantizer, + data=input_.rowwise_data, + columnwise_data=input_.columnwise_data, + scale_inv=input_.scale_inv, + columnwise_scale_inv=input_.columnwise_scale_inv, + amax=input_.amax, + columnwise_amax=input_.columnwise_amax, + scale=input_.scale, + first_dims=input_.first_dims, + last_dims=input_.last_dims, + tensor_offsets=input_.tensor_offsets, + offsets=input_.offsets, + scale_inv_offsets=input_.scale_inv_offsets, + columnwise_scale_inv_offsets=input_.columnwise_scale_inv_offsets, + with_gemm_swizzled_scales=input_._with_gemm_swizzled_scales, + row_scaled_nvfp4=input_.row_scaled_nvfp4, + nvfp4_use_4over6=input_.nvfp4_use_4over6, + nvfp4_e4m3_max=input_.nvfp4_e4m3_max, + ) else: fc1_x = maybe_dequantize(input_, dtype) grouped_fc1_x = _group_quantize_for_grouped_mlp( @@ -587,7 +619,7 @@ def fuser_forward( else: fc2_out_buf = fc2_out_buf + token_bias else: - fc2_out_grouped = GroupedTensor( + 
fc2_out_grouped = GroupedTensorStorage( shape=(in_shape[0], fc2_weight_shape[0]), dtype=dtype, num_tensors=num_groups, @@ -616,7 +648,7 @@ def fuser_forward( fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3) - grouped_fc2_x = GroupedTensor( + grouped_fc2_x = GroupedTensorStorage( shape=(in_shape[0], fc2_weight_shape[1]), dtype=dtype, num_tensors=num_groups, @@ -695,6 +727,7 @@ def fuser_forward( if requires_grad: mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) activation_op = self.basic_ops[1] + cpu_offloading = is_cpu_offload_enabled() activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( getattr(activation_op, "activation_recompute_in_mlp", False) @@ -716,6 +749,13 @@ def fuser_forward( grouped_fc_x.rowwise_data = None grouped_fc_x.scale_inv = None + if cpu_offloading: + activation_tensors = [ + t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None + ] + start_offload(*activation_tensors) + mark_activation_offload(*activation_tensors) + # FC1 saved-tensor layout. 
# [split_sizes, base_split_offsets, split_points, # grouped_fc1_x, *fc1_weight_tensors] @@ -755,7 +795,7 @@ def fuser_forward( fc2_weight_tensors = ( [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight ) - fc2_saved: list[Optional[torch.Tensor]] = [ + fc2_saved: list[Optional[torch.Tensor | GroupedTensorStorage]] = [ split_sizes, base_split_offsets, split_points, diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 438e124021..c112634024 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -387,6 +387,21 @@ def restore_from_saved( self.tensor_offsets = tensors[9] return tensors[10:] + def get_data_tensors(self): + """Return the ten tensor fields, in a fixed order, that may be saved or offloaded.""" + return ( + self.rowwise_data, + self.columnwise_data, + self.scale_inv, + self.columnwise_scale_inv, + self.amax, + self.columnwise_amax, + self.scale, + self.first_dims, + self.last_dims, + self.tensor_offsets, + ) + def clear(self) -> None: """ Reset tensor data and clear all buffers.