From 3a6683f8d43f8c1248c74a711606faa394e78e68 Mon Sep 17 00:00:00 2001 From: hongbinl Date: Tue, 2 Jun 2026 03:58:34 -0700 Subject: [PATCH 1/6] Support selective offload for fused grouped MLP Signed-off-by: hongbinl --- .../pytorch/ops/fused/forward_grouped_mlp.py | 47 +++++++++++++++++-- .../tensor/storage/grouped_tensor_storage.py | 15 ++++++ 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index ece670a539..7d8efa77da 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,6 +13,7 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_not_offload from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor from ...quantization import Recipe from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer @@ -23,6 +24,7 @@ mark_grouped_tensor, ) from ...tensor.grouped_tensor import GroupedTensor +from ...tensor.storage.grouped_tensor_storage import GroupedTensorStorage from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledSReLU, ScaledClampedQGeGLU @@ -316,6 +318,7 @@ def fuser_forward( # Group-quantize input tensor and convert dtypes if needed fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) fc1_input_quantizer.optimize_for_gemm = True + fc1_input_quantizer.internal = True input_quantizer = getattr(input_, "quantizer", None) if isinstance(input_, GroupedTensor) and ( isinstance(fc1_input_quantizer, MXFP8Quantizer) @@ -323,7 +326,30 @@ def fuser_forward( or isinstance(fc1_input_quantizer, NVFP4Quantizer) and isinstance(input_quantizer, NVFP4Quantizer) ): - grouped_fc1_x = input_ + grouped_fc1_x = GroupedTensorStorage( + shape=input_.logical_shape, + dtype=input_.fake_dtype, + num_tensors=input_.num_tensors, + shapes=input_.tensor_shapes, + quantizer=input_.quantizer, + data=input_.rowwise_data, + columnwise_data=input_.columnwise_data, + scale_inv=input_.scale_inv, + columnwise_scale_inv=input_.columnwise_scale_inv, + amax=input_.amax, + columnwise_amax=input_.columnwise_amax, + scale=input_.scale, + first_dims=input_.first_dims, + last_dims=input_.last_dims, + tensor_offsets=input_.tensor_offsets, + offsets=input_.offsets, + scale_inv_offsets=input_.scale_inv_offsets, + columnwise_scale_inv_offsets=input_.columnwise_scale_inv_offsets, + with_gemm_swizzled_scales=input_._with_gemm_swizzled_scales, + row_scaled_nvfp4=input_.row_scaled_nvfp4, + nvfp4_use_4over6=input_.nvfp4_use_4over6, + nvfp4_e4m3_max=input_.nvfp4_e4m3_max, + ) else: fc1_x = maybe_dequantize(input_, dtype) grouped_fc1_x = _group_quantize_for_grouped_mlp( @@ -616,7 +642,7 @@ def fuser_forward( fc2_in_col_scale = fc1_kernel_out["sfd_col_tensor"] fc2_in_col_scale = fc2_in_col_scale.permute(5, 2, 4, 0, 1, 3) - grouped_fc2_x = GroupedTensor( + grouped_fc2_x = GroupedTensorStorage( shape=(in_shape[0], fc2_weight_shape[1]), dtype=dtype, num_tensors=num_groups, @@ -695,6 +721,9 @@ def fuser_forward( if requires_grad: mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) activation_op = self.basic_ops[1] + cpu_offloading = is_cpu_offload_enabled() + no_offload_fc1_activation = bool(getattr(fc1_op, "no_offload_activation", False)) + no_offload_moe_activation = bool(getattr(activation_op, "no_offload_activation", False)) activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( getattr(activation_op, "activation_recompute_in_mlp", False) @@ -722,6 +751,10 @@ def fuser_forward( fc1_weight_tensors = ( [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight ) + if cpu_offloading: + if no_offload_fc1_activation: + mark_not_offload(grouped_fc1_x) + mark_not_offload(*fc1_weight_tensors) fc1_ctx.save_for_backward( split_sizes, base_split_offsets, @@ -740,6 +773,8 @@ def fuser_forward( fc1_ctx.weight_requires_grad = weight_requires_grad # Activation + if cpu_offloading and no_offload_moe_activation: + mark_not_offload(activation_in, scales) activation_ctx.save_for_backward(activation_in, scales) activation_ctx.extra_input_requires_grad = True activation_ctx.input_requires_grad = True @@ -755,7 +790,13 @@ def fuser_forward( fc2_weight_tensors = ( [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight ) - fc2_saved: list[Optional[torch.Tensor]] = [ + if cpu_offloading: + if saved_grouped_fc2_x is not None: + # FC2 input is saved for FC2 wgrad, but it is not the Megatron moe_act + # activation target controlled above. Keep this extra saved tensor resident. + mark_not_offload(saved_grouped_fc2_x) + mark_not_offload(*fc2_weight_tensors) + fc2_saved: list[Optional[torch.Tensor | GroupedTensorStorage]] = [ split_sizes, base_split_offsets, split_points, diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 438e124021..c112634024 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -387,6 +387,21 @@ def restore_from_saved( self.tensor_offsets = tensors[9] return tensors[10:] + def get_data_tensors(self): + """Get tensor fields that may be saved or offloaded.""" + return ( + self.rowwise_data, + self.columnwise_data, + self.scale_inv, + self.columnwise_scale_inv, + self.amax, + self.columnwise_amax, + self.scale, + self.first_dims, + self.last_dims, + self.tensor_offsets, + ) + def clear(self) -> None: """ Reset tensor data and clear all buffers. From 376d28ceb63aa2a160c0e5474d3588ac529075c6 Mon Sep 17 00:00:00 2001 From: hongbinl Date: Fri, 5 Jun 2026 01:51:45 -0700 Subject: [PATCH 2/6] Add no_offload_activation to grouped MLP ops Signed-off-by: hongbinl --- .../pytorch/ops/basic/activation.py | 34 ++++++++++++++--- .../pytorch/ops/basic/grouped_linear.py | 27 +++++++++++++ .../pytorch/ops/basic/swiglu.py | 38 +++++++++++++++++-- .../pytorch/ops/fused/forward_grouped_mlp.py | 4 +- 4 files changed, 92 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index f4beffe90c..73bbb2c536 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -13,7 +13,7 @@ import transformer_engine_torch as tex from ...constants import DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext @@ -60,12 +60,21 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): pass. This will typically reduce memory usage but require extra compute and increase numerical error. This feature is highly experimental. + no_offload_activation : bool, default = ``False`` + Keep saved activation tensors resident on GPU when CPU offload + is enabled. """ - def __init__(self, *, cache_quantized_input: bool = False): + def __init__( + self, + *, + cache_quantized_input: bool = False, + no_offload_activation: bool = False, + ): super().__init__() self.cache_quantized_input: bool = cache_quantized_input + self.no_offload_activation: bool = no_offload_activation @abc.abstractmethod def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: @@ -115,7 +124,10 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x) + if self.no_offload_activation: + mark_not_offload(x) + else: + mark_activation_offload(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -359,13 +371,22 @@ class ScaledSReLU(BasicOperation): activation_recompute_in_mlp : bool, default = ``False`` Enable fused grouped MLP kernels to recompute activation outputs during backward when supported instead of saving them. + no_offload_activation : bool, default = ``False`` + Keep saved activation tensors resident on GPU when CPU offload + is enabled. """ num_extra_inputs: int = 1 - def __init__(self, *, activation_recompute_in_mlp: bool = False) -> None: + def __init__( + self, + *, + activation_recompute_in_mlp: bool = False, + no_offload_activation: bool = False, + ) -> None: super().__init__() self.activation_recompute_in_mlp: bool = activation_recompute_in_mlp + self.no_offload_activation: bool = no_offload_activation def op_forward(self, *args, **kwargs) -> None: raise RuntimeError( @@ -415,7 +436,10 @@ def fuser_forward( ctx = basic_op_ctxs[0] if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x) + if self.no_offload_activation: + mark_not_offload(x, scales) + else: + mark_activation_offload(x) ctx.input_requires_grad = True ctx.extra_input_requires_grad = extra_input.requires_grad ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 521ee59fa0..f7da9a8263 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -15,6 +15,7 @@ import transformer_engine_torch as tex from ...constants import DType +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...cpp_extensions import general_grouped_gemm, general_grouped_gemm_for_grouped_tensor from ...distributed import CudaRNGStatesTracker from ...module._common import WeightGradStore @@ -104,6 +105,9 @@ class GroupedLinear(BasicOperation): additional extra input and adds ``bias * scales`` instead of ``bias`` in the forward pass. The scale tensor has shape ``(total_tokens,)`` and is split according to the split sizes. + no_offload_activation : bool, default = ``False`` + Keep saved input activation tensors resident on GPU when CPU offload + is enabled. """ @@ -125,10 +129,12 @@ def __init__( single_grouped_bias: bool = False, delay_wgrad_compute: bool = False, scale_bias: bool = False, + no_offload_activation: bool = False, ) -> None: super().__init__() self._scale_bias: bool = scale_bias and bias + self.no_offload_activation: bool = no_offload_activation if self._scale_bias: self.num_extra_inputs = 2 @@ -1026,6 +1032,27 @@ def fuser_forward_save_ctx( return ctx = basic_op_ctxs[0] + + if is_cpu_offload_enabled(): + saved_tensors = tensors_to_save[0] + activation_start = 4 if self._scale_bias else 3 + activation_count = 1 if use_grouped_tensor_path else self.num_groups + activation_end = activation_start + activation_count + activation_tensors = tuple( + tensor + for tensor in saved_tensors[activation_start:activation_end] + if tensor is not None + ) + weight_tensors = tuple( + tensor for tensor in saved_tensors[activation_end:] if tensor is not None + ) + + if self.no_offload_activation: + mark_not_offload(*activation_tensors) + else: + mark_activation_offload(*activation_tensors) + mark_not_offload(*weight_tensors) + ctx.save_for_backward(*tensors_to_save[0]) num_groups = self.num_groups diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index 02f330ede3..a189d53d4f 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -12,7 +12,7 @@ import transformer_engine_torch as tex from ...constants import DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext @@ -67,6 +67,9 @@ class SwiGLU(BasicOperation): when the interleave size is 2). This data format is highly experiental and is primarily intended to support some advanced fused kernels. + no_offload_activation : bool, default = ``False`` + Keep saved activation tensors resident on GPU when CPU offload + is enabled. """ @@ -75,10 +78,12 @@ def __init__( *, cache_quantized_input: bool = False, glu_interleave_size: Optional[int] = None, + no_offload_activation: bool = False, ): super().__init__() self.cache_quantized_input: bool = cache_quantized_input self.glu_interleave_size: Optional[int] = glu_interleave_size + self.no_offload_activation: bool = no_offload_activation def op_forward( self, @@ -128,7 +133,10 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(input_) + if self.no_offload_activation: + mark_not_offload(input_) + else: + mark_activation_offload(input_) ctx.save_for_backward(input_) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -218,6 +226,9 @@ class ClampedSwiGLU(BasicOperation): When set, the GLU activations will use an experimental block interleaved format. See the corresponding option in the SwiGLU operation for more details. + no_offload_activation : bool, default = ``False`` + Keep saved activation tensors resident on GPU when CPU offload + is enabled. """ @@ -229,6 +240,7 @@ def __init__( glu_linear_offset: float = 1.0, cache_quantized_input: bool = False, glu_interleave_size: Optional[int] = None, + no_offload_activation: bool = False, ): super().__init__() self.limit: float = limit @@ -236,6 +248,7 @@ def __init__( self.glu_linear_offset: float = glu_linear_offset self.cache_quantized_input: bool = cache_quantized_input self.glu_interleave_size: Optional[int] = glu_interleave_size + self.no_offload_activation: bool = no_offload_activation def _tex_clamped_swiglu_forward( self, @@ -312,7 +325,10 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x) + if self.no_offload_activation: + mark_not_offload(x) + else: + mark_activation_offload(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -382,10 +398,12 @@ def __init__( glu_interleave_size: Optional[int] = None, *, activation_recompute_in_mlp: bool = False, + no_offload_activation: bool = False, ) -> None: super().__init__() self.glu_interleave_size: Optional[int] = glu_interleave_size self.activation_recompute_in_mlp: bool = activation_recompute_in_mlp + self.no_offload_activation: bool = no_offload_activation def _glu_forward(self, swiglu_in: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -463,7 +481,10 @@ def fuser_forward( ctx = basic_op_ctxs[0] if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(input_) + if self.no_offload_activation: + mark_not_offload(input_, scales) + else: + mark_activation_offload(input_) ctx.input_requires_grad = True ctx.extra_input_requires_grad = extra_input.requires_grad ctx.dtype = dtype @@ -555,6 +576,9 @@ class ScaledSwiGLU(_ScaledGLU): activation_recompute_in_mlp : bool, default = ``False`` Enable fused grouped MLP kernels to recompute activation outputs during backward when supported instead of saving them. + no_offload_activation : bool, default = ``False`` + Keep saved activation tensors resident on GPU when CPU offload + is enabled. """ @@ -585,6 +609,9 @@ class ScaledClampedQGeGLU(_ScaledGLU): activation_recompute_in_mlp : bool, default = ``False`` Enable fused grouped MLP kernels to recompute activation outputs during backward when supported instead of saving them. + no_offload_activation : bool, default = ``False`` + Keep saved activation tensors resident on GPU when CPU offload + is enabled. limit : float, default ``7.0`` Clamp limit (see :class:`ClampedSwiGLU`). alpha : float, default ``1.702`` @@ -600,6 +627,7 @@ def __init__( glu_interleave_size: Optional[int] = None, *, activation_recompute_in_mlp: bool = False, + no_offload_activation: bool = False, limit: float = 7.0, alpha: float = 1.702, glu_linear_offset: float = 1.0, @@ -607,11 +635,13 @@ def __init__( super().__init__( glu_interleave_size, activation_recompute_in_mlp=activation_recompute_in_mlp, + no_offload_activation=no_offload_activation, ) self._clamped: ClampedSwiGLU = ClampedSwiGLU( limit=limit, alpha=alpha, glu_linear_offset=glu_linear_offset, + no_offload_activation=no_offload_activation, ) def _glu_forward(self, swiglu_in: torch.Tensor) -> torch.Tensor: diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 7d8efa77da..f0eef41cd8 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -722,8 +722,8 @@ def fuser_forward( mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) activation_op = self.basic_ops[1] cpu_offloading = is_cpu_offload_enabled() - no_offload_fc1_activation = bool(getattr(fc1_op, "no_offload_activation", False)) - no_offload_moe_activation = bool(getattr(activation_op, "no_offload_activation", False)) + no_offload_fc1_activation = fc1_op.no_offload_activation + no_offload_moe_activation = activation_op.no_offload_activation activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( getattr(activation_op, "activation_recompute_in_mlp", False) From 933d64b8210e5326e084c0ca0580b5f312e730fb Mon Sep 17 00:00:00 2001 From: hongbinl Date: Fri, 5 Jun 2026 02:03:44 -0700 Subject: [PATCH 3/6] Use offload_activation API for activation offload control Signed-off-by: hongbinl --- .../pytorch/module/grouped_linear.py | 19 ++++++- .../pytorch/module/layernorm_linear.py | 11 +++- .../pytorch/module/layernorm_mlp.py | 32 ++++++++++- transformer_engine/pytorch/module/linear.py | 12 +++- .../pytorch/ops/basic/activation.py | 30 +++++----- .../pytorch/ops/basic/basic_linear.py | 12 +++- .../pytorch/ops/basic/dropout.py | 10 +++- .../pytorch/ops/basic/grouped_linear.py | 15 +++-- .../pytorch/ops/basic/l2normalization.py | 11 +++- .../pytorch/ops/basic/layer_norm.py | 11 +++- .../pytorch/ops/basic/rmsnorm.py | 11 +++- .../pytorch/ops/basic/swiglu.py | 56 +++++++++---------- .../pytorch/ops/fused/forward_grouped_mlp.py | 17 ++++-- .../fused/forward_linear_bias_activation.py | 8 ++- .../ops/fused/forward_linear_bias_add.py | 8 ++- .../ops/fused/forward_linear_scale_add.py | 8 ++- .../ops/fused/userbuffers_forward_linear.py | 8 ++- transformer_engine/pytorch/ops/linear.py | 4 ++ 18 files changed, 198 insertions(+), 85 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 15ec3fe322..5c31670c4d 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -49,7 +49,12 @@ ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo -from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload +from ..cpu_offload import ( + is_cpu_offload_enabled, + mark_activation_offload, + mark_not_offload, + start_offload, +) from ..triton.grouped_dbias_dscales import compute_grouped_dbias from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer @@ -402,6 +407,7 @@ def forward( grad_output_quantizers, fuse_wgrad_accumulation, cpu_offloading, + offload_activation, sequence_parallel, activation_dtype, is_grad_enabled, @@ -627,6 +633,12 @@ def forward( else: inputmats = [None] * num_gemms + if cpu_offloading: + if offload_activation: + mark_activation_offload(*inputmats) + else: + mark_not_offload(*inputmats) + # Original weights are only needed by high_precision dgrad. The weakrefs # used for fused wgrad accumulation serve a different purpose: restoring # Python parameter attributes without keeping the parameter alive here. @@ -1232,6 +1244,8 @@ class GroupedLinear(TransformerEngineBaseModule): EXPERIMENTAL and subject to change. Gated by the ``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var is not set this argument is forced to ``False`` with a warning. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. Notes ----- @@ -1264,6 +1278,7 @@ def __init__( save_original_input: bool = False, single_grouped_weight: bool = False, single_grouped_bias: bool = False, + offload_activation: bool = True, name: Optional[str] = None, ) -> None: super().__init__(name) @@ -1280,6 +1295,7 @@ def __init__( self.ub_overlap_ag = ub_overlap_ag self.ub_name = ub_name self.save_original_input = save_original_input + self.offload_activation = offload_activation single_grouped_weight, single_grouped_bias = resolve_grouped_linear_single_param_flags( single_grouped_weight, single_grouped_bias ) @@ -1749,6 +1765,7 @@ def forward( grad_output_quantizers, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), + self.offload_activation, self.sequence_parallel, self.activation_dtype, is_grad_enabled, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7fc96d4779..66b9bb79f9 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -119,6 +119,7 @@ def forward( grad_weight_quantizer, grad_output_quantizer, cpu_offloading, + offload_activation, tp_group, tp_size, sequence_parallel, @@ -458,7 +459,10 @@ def forward( ln_out.update_usage(rowwise_usage=False) if cpu_offloading: - mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save) + if offload_activation: + mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save) + else: + mark_not_offload(inputmat, mu, rsigma, ln_out_to_save) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1250,6 +1254,8 @@ class LayerNormLinear(TransformerEngineBaseModule): This can help in latency bound communication situations. Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce is used. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ def __init__( @@ -1281,6 +1287,7 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, + offload_activation: bool = True, name: Optional[str] = None, ) -> None: super().__init__(name) @@ -1300,6 +1307,7 @@ def __init__( ) self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type + self.offload_activation = offload_activation self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) @@ -1730,6 +1738,7 @@ def forward( grad_weight_quantizer, grad_output_quantizer, is_cpu_offload_enabled(), + self.offload_activation, self.tp_group, self.tp_size, self.sequence_parallel, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6c6cca74ef..e51e608c6e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -210,6 +210,7 @@ def _forward( fc2_grad_weight_quantizer, fc2_grad_output_quantizer, cpu_offloading, + offload_activation, tp_group, tp_size, sequence_parallel, @@ -279,6 +280,20 @@ def _forward( # save the initial state for recomputation by bwd if save_for_checkpoint: + if cpu_offloading: + if offload_activation: + mark_activation_offload(inp) + else: + mark_not_offload(inp) + mark_not_offload( + ln_weight, + ln_bias, + fc1_weight, + fc1_bias, + fc2_weight, + fc2_bias, + ) + # save tensors tensors_to_save, tensor_objects = prepare_for_saving( inp, @@ -312,6 +327,7 @@ def _forward( "fc2_grad_weight_quantizer": fc2_grad_weight_quantizer, "fc2_grad_output_quantizer": fc2_grad_output_quantizer, "cpu_offloading": cpu_offloading, + "offload_activation": offload_activation, "tp_group": tp_group, "tp_size": tp_size, "sequence_parallel": sequence_parallel, @@ -740,9 +756,14 @@ def _forward( if not checkpoint: # regular path, no selective activation checkpointing if cpu_offloading: - mark_activation_offload( - inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out - ) + if offload_activation: + mark_activation_offload( + inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out + ) + else: + mark_not_offload( + inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out + ) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1928,6 +1949,8 @@ class LayerNormMLP(TransformerEngineBaseModule): whether to use selective activation checkpointing, where activations are not saved for bwd, and instead are recomputed (skipping fc2, as it is not needed for backward). Trades compute for memory. default is false, in which activations are saved in fwd. not supported for onnx forward + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ def __init__( @@ -1964,6 +1987,7 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, checkpoint: bool = False, + offload_activation: bool = True, ) -> None: super().__init__(name) @@ -1985,6 +2009,7 @@ def __init__( self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type self.checkpoint = checkpoint + self.offload_activation = offload_activation # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap self.gemm_gelu_fusion = ( @@ -2378,6 +2403,7 @@ def forward( fc2_grad_weight_quantizer, fc2_grad_output_quantizer, is_cpu_offload_enabled(), + self.offload_activation, self.tp_group, self.tp_size, self.sequence_parallel, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6c2d98d160..6a594e6c29 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -153,6 +153,7 @@ class LinearFwdArgs: # --- Misc --- cpu_offloading: bool + offload_activation: bool is_grad_enabled: bool @@ -269,6 +270,7 @@ def _linear_forward_impl( is_first_microbatch = args.is_first_microbatch fp8 = args.fp8 cpu_offloading = args.cpu_offloading + offload_activation = args.offload_activation tp_group = args.tp_group sequence_parallel = args.sequence_parallel activation_dtype = args.activation_dtype @@ -565,7 +567,10 @@ def _linear_forward_impl( saved_inputmat = inputmat if cpu_offloading and saved_inputmat is not None: - mark_activation_offload(saved_inputmat) + if offload_activation: + mark_activation_offload(saved_inputmat) + else: + mark_not_offload(saved_inputmat) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights @@ -1456,6 +1461,8 @@ class Linear(TransformerEngineBaseModule): cast tensor. In some scenarios, the input tensor is used by multiple modules, and saving the original input tensor may reduce the memory usage. Cannot work with FP8 DelayedScaling recipe. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ def __init__( @@ -1484,6 +1491,7 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, + offload_activation: bool = True, name: Optional[str] = None, ) -> None: super().__init__(name) @@ -1499,6 +1507,7 @@ def __init__( self.rng_tracker_name = rng_tracker_name self.symmetric_ar_type = symmetric_ar_type self.save_original_input = save_original_input + self.offload_activation = offload_activation self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) @@ -1947,6 +1956,7 @@ def forward( wgrad_store=wgrad_store, # misc cpu_offloading=is_cpu_offload_enabled(), + offload_activation=self.offload_activation, is_grad_enabled=is_grad_enabled, ) out, new_weight_workspace = linear_fn( diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 73bbb2c536..2f25ffd082 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -60,9 +60,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): pass. This will typically reduce memory usage but require extra compute and increase numerical error. This feature is highly experimental. - no_offload_activation : bool, default = ``False`` - Keep saved activation tensors resident on GPU when CPU offload - is enabled. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ @@ -70,11 +69,11 @@ def __init__( self, *, cache_quantized_input: bool = False, - no_offload_activation: bool = False, + offload_activation: bool = True, ): super().__init__() self.cache_quantized_input: bool = cache_quantized_input - self.no_offload_activation: bool = no_offload_activation + self.offload_activation: bool = offload_activation @abc.abstractmethod def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: @@ -124,10 +123,10 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.no_offload_activation: - mark_not_offload(x) - else: + if self.offload_activation: mark_activation_offload(x) + else: + mark_not_offload(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -371,9 +370,8 @@ class ScaledSReLU(BasicOperation): activation_recompute_in_mlp : bool, default = ``False`` Enable fused grouped MLP kernels to recompute activation outputs during backward when supported instead of saving them. - no_offload_activation : bool, default = ``False`` - Keep saved activation tensors resident on GPU when CPU offload - is enabled. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ num_extra_inputs: int = 1 @@ -382,11 +380,11 @@ def __init__( self, *, activation_recompute_in_mlp: bool = False, - no_offload_activation: bool = False, + offload_activation: bool = True, ) -> None: super().__init__() self.activation_recompute_in_mlp: bool = activation_recompute_in_mlp - self.no_offload_activation: bool = no_offload_activation + self.offload_activation: bool = offload_activation def op_forward(self, *args, **kwargs) -> None: raise RuntimeError( @@ -436,10 +434,10 @@ def fuser_forward( ctx = basic_op_ctxs[0] if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.no_offload_activation: - mark_not_offload(x, scales) - else: + if self.offload_activation: mark_activation_offload(x) + else: + mark_not_offload(x, scales) ctx.input_requires_grad = True ctx.extra_input_requires_grad = extra_input.requires_grad ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 95e0440303..93bd17a013 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -13,7 +13,7 @@ import torch from ...cpp_extensions import general_gemm -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...distributed import ( CudaRNGStatesTracker, gather_along_first_dim, @@ -91,6 +91,8 @@ class BasicLinear(BasicOperation): Options for overlapping tensor-parallel communication with compute using Userbuffers. This feature is highly experimental. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ @@ -107,8 +109,10 @@ def __init__( rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, accumulate_into_main_grad: bool = False, userbuffers_options: Optional[dict[str, Any]] = None, + offload_activation: bool = True, ) -> None: super().__init__() + self.offload_activation: bool = offload_activation # Weight tensor dimensions self.in_features: int = in_features @@ -1050,7 +1054,11 @@ def op_forward( saved_input = x_local saved_weight = w if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) + if self.offload_activation: + mark_activation_offload(saved_input) + else: + mark_not_offload(saved_input) + mark_not_offload(saved_weight) ctx.save_for_backward(saved_input, saved_weight) ctx.with_quantized_compute = with_quantized_compute and backward_override is None ctx.backward_override = backward_override diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index 8850604aad..fc113183cb 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -9,7 +9,7 @@ import torch import transformer_engine_torch as tex -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...tensor import Quantizer from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_autocast_dtype, maybe_dequantize @@ -25,9 +25,10 @@ class Dropout(BasicOperation): """ - def __init__(self, p: float) -> None: + def __init__(self, p: float, *, offload_activation: bool = True) -> None: super().__init__() self.dropout_probability: float = p + self.offload_activation: bool = offload_activation def op_forward( self, @@ -72,7 +73,10 @@ def op_forward( # Save context for backward if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(mask) + if self.offload_activation: + mark_activation_offload(mask) + else: + mark_not_offload(mask) ctx.save_for_backward(mask) ctx.impl = impl ctx.dropout_probability = self.dropout_probability diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index f7da9a8263..3ac7c2e69e 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -105,9 +105,8 @@ class GroupedLinear(BasicOperation): additional extra input and adds ``bias * scales`` instead of ``bias`` in the forward pass. The scale tensor has shape ``(total_tokens,)`` and is split according to the split sizes. - no_offload_activation : bool, default = ``False`` - Keep saved input activation tensors resident on GPU when CPU offload - is enabled. + offload_activation : bool, default = ``True`` + Offload saved input activation tensors when CPU offload is enabled. """ @@ -129,12 +128,12 @@ def __init__( single_grouped_bias: bool = False, delay_wgrad_compute: bool = False, scale_bias: bool = False, - no_offload_activation: bool = False, + offload_activation: bool = True, ) -> None: super().__init__() self._scale_bias: bool = scale_bias and bias - self.no_offload_activation: bool = no_offload_activation + self.offload_activation: bool = offload_activation if self._scale_bias: self.num_extra_inputs = 2 @@ -1047,10 +1046,10 @@ def fuser_forward_save_ctx( tensor for tensor in saved_tensors[activation_end:] if tensor is not None ) - if self.no_offload_activation: - mark_not_offload(*activation_tensors) - else: + if self.offload_activation: mark_activation_offload(*activation_tensors) + else: + mark_not_offload(*activation_tensors) mark_not_offload(*weight_tensors) ctx.save_for_backward(*tensors_to_save[0]) diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py index be155c9356..bdf4ef8c51 100644 --- a/transformer_engine/pytorch/ops/basic/l2normalization.py +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -11,7 +11,7 @@ import torch from ...torch_version import torch_version -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...jit import ( l2normalization_fused, l2normalization_fwd_fused, @@ -48,6 +48,8 @@ class L2Normalization(BasicOperation): batch size per training step. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for forward propagation and activation recompute phase. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ @@ -57,9 +59,11 @@ def __init__( eps: float = 1e-6, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, + offload_activation: bool = True, ) -> None: super().__init__() self.eps: float = eps + self.offload_activation: bool = offload_activation # JIT warmup for L2Normalization fused operations if seq_length and micro_batch_size: @@ -103,7 +107,10 @@ def op_forward( # Save state for backward pass if requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x, rsqrt_norm) + if self.offload_activation: + mark_activation_offload(x, rsqrt_norm) + else: + mark_not_offload(x, rsqrt_norm) ctx.save_for_backward(x, rsqrt_norm) return y diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 3fda5145c6..74005163e8 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -14,7 +14,7 @@ from transformer_engine_torch import layernorm_bwd, layernorm_fwd from ...constants import TE_DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...export import is_in_onnx_export_mode from ...tensor import Quantizer from ...utils import ( @@ -64,6 +64,8 @@ class LayerNorm(BasicOperation): For more fine-grained control, provide a dict with the SM margin at each compute stage ("forward", "backward", "inference"). + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ @@ -76,10 +78,12 @@ def __init__( dtype: Optional[torch.dtype] = None, zero_centered_gamma: bool = False, sm_margin: int | dict[str, int] = 0, + offload_activation: bool = True, ) -> None: super().__init__() self.eps: float = eps self.zero_centered_gamma: bool = zero_centered_gamma + self.offload_activation: bool = offload_activation # Parameter shape if not isinstance(normalized_shape, Iterable): @@ -217,7 +221,10 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x, means, rstdevs) + if self.offload_activation: + mark_activation_offload(x, means, rstdevs) + else: + mark_not_offload(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 1d8d8be971..c7d8fe1e69 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -14,7 +14,7 @@ from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from ...constants import TE_DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...export import is_in_onnx_export_mode from ...tensor import Quantizer from ...utils import ( @@ -63,6 +63,8 @@ class RMSNorm(BasicOperation): For more fine-grained control, provide a dict with the SM margin at each compute stage ("forward", "backward", "inference"). + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ @@ -75,10 +77,12 @@ def __init__( dtype: Optional[torch.dtype] = None, zero_centered_gamma: bool = False, sm_margin: int = 0, + offload_activation: bool = True, ) -> None: super().__init__() self.eps: float = eps self.zero_centered_gamma: bool = zero_centered_gamma + self.offload_activation: bool = offload_activation # Parameter shape if not isinstance(normalized_shape, Iterable): @@ -198,7 +202,10 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x, rstdevs) + if self.offload_activation: + mark_activation_offload(x, rstdevs) + else: + mark_not_offload(x, rstdevs) ctx.save_for_backward(x, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index a189d53d4f..0df720311c 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -67,9 +67,8 @@ class SwiGLU(BasicOperation): when the interleave size is 2). This data format is highly experiental and is primarily intended to support some advanced fused kernels. - no_offload_activation : bool, default = ``False`` - Keep saved activation tensors resident on GPU when CPU offload - is enabled. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ @@ -78,12 +77,12 @@ def __init__( *, cache_quantized_input: bool = False, glu_interleave_size: Optional[int] = None, - no_offload_activation: bool = False, + offload_activation: bool = True, ): super().__init__() self.cache_quantized_input: bool = cache_quantized_input self.glu_interleave_size: Optional[int] = glu_interleave_size - self.no_offload_activation: bool = no_offload_activation + self.offload_activation: bool = offload_activation def op_forward( self, @@ -133,10 +132,10 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.no_offload_activation: - mark_not_offload(input_) - else: + if self.offload_activation: mark_activation_offload(input_) + else: + mark_not_offload(input_) ctx.save_for_backward(input_) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -226,9 +225,8 @@ class ClampedSwiGLU(BasicOperation): When set, the GLU activations will use an experimental block interleaved format. See the corresponding option in the SwiGLU operation for more details. - no_offload_activation : bool, default = ``False`` - Keep saved activation tensors resident on GPU when CPU offload - is enabled. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ @@ -240,7 +238,7 @@ def __init__( glu_linear_offset: float = 1.0, cache_quantized_input: bool = False, glu_interleave_size: Optional[int] = None, - no_offload_activation: bool = False, + offload_activation: bool = True, ): super().__init__() self.limit: float = limit @@ -248,7 +246,7 @@ def __init__( self.glu_linear_offset: float = glu_linear_offset self.cache_quantized_input: bool = cache_quantized_input self.glu_interleave_size: Optional[int] = glu_interleave_size - self.no_offload_activation: bool = no_offload_activation + self.offload_activation: bool = offload_activation def _tex_clamped_swiglu_forward( self, @@ -325,10 +323,10 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.no_offload_activation: - mark_not_offload(x) - else: + if self.offload_activation: mark_activation_offload(x) + else: + mark_not_offload(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -398,12 +396,12 @@ def __init__( glu_interleave_size: Optional[int] = None, *, activation_recompute_in_mlp: bool = False, - no_offload_activation: bool = False, + offload_activation: bool = True, ) -> None: super().__init__() self.glu_interleave_size: Optional[int] = glu_interleave_size self.activation_recompute_in_mlp: bool = activation_recompute_in_mlp - self.no_offload_activation: bool = no_offload_activation + self.offload_activation: bool = offload_activation def _glu_forward(self, swiglu_in: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -481,10 +479,10 @@ def fuser_forward( ctx = basic_op_ctxs[0] if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.no_offload_activation: - mark_not_offload(input_, scales) - else: + if self.offload_activation: mark_activation_offload(input_) + else: + mark_not_offload(input_, scales) ctx.input_requires_grad = True ctx.extra_input_requires_grad = extra_input.requires_grad ctx.dtype = dtype @@ -576,9 +574,8 @@ class ScaledSwiGLU(_ScaledGLU): activation_recompute_in_mlp : bool, default = ``False`` Enable fused grouped MLP kernels to recompute activation outputs during backward when supported instead of saving them. - no_offload_activation : bool, default = ``False`` - Keep saved activation tensors resident on GPU when CPU offload - is enabled. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ @@ -609,9 +606,8 @@ class ScaledClampedQGeGLU(_ScaledGLU): activation_recompute_in_mlp : bool, default = ``False`` Enable fused grouped MLP kernels to recompute activation outputs during backward when supported instead of saving them. - no_offload_activation : bool, default = ``False`` - Keep saved activation tensors resident on GPU when CPU offload - is enabled. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. limit : float, default ``7.0`` Clamp limit (see :class:`ClampedSwiGLU`). alpha : float, default ``1.702`` @@ -627,7 +623,7 @@ def __init__( glu_interleave_size: Optional[int] = None, *, activation_recompute_in_mlp: bool = False, - no_offload_activation: bool = False, + offload_activation: bool = True, limit: float = 7.0, alpha: float = 1.702, glu_linear_offset: float = 1.0, @@ -635,13 +631,13 @@ def __init__( super().__init__( glu_interleave_size, activation_recompute_in_mlp=activation_recompute_in_mlp, - no_offload_activation=no_offload_activation, + offload_activation=offload_activation, ) self._clamped: ClampedSwiGLU = ClampedSwiGLU( limit=limit, alpha=alpha, glu_linear_offset=glu_linear_offset, - no_offload_activation=no_offload_activation, + offload_activation=offload_activation, ) def _glu_forward(self, swiglu_in: torch.Tensor) -> torch.Tensor: diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index f0eef41cd8..0e34131d2a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,7 +13,7 @@ import torch import transformer_engine_torch as tex -from ...cpu_offload import is_cpu_offload_enabled, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor from ...quantization import Recipe from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer @@ -722,8 +722,8 @@ def fuser_forward( mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) activation_op = self.basic_ops[1] cpu_offloading = is_cpu_offload_enabled() - no_offload_fc1_activation = fc1_op.no_offload_activation - no_offload_moe_activation = activation_op.no_offload_activation + offload_fc1_activation = fc1_op.offload_activation + offload_moe_activation = activation_op.offload_activation activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( getattr(activation_op, "activation_recompute_in_mlp", False) @@ -752,7 +752,9 @@ def fuser_forward( [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight ) if cpu_offloading: - if no_offload_fc1_activation: + if offload_fc1_activation: + mark_activation_offload(grouped_fc1_x) + else: mark_not_offload(grouped_fc1_x) mark_not_offload(*fc1_weight_tensors) fc1_ctx.save_for_backward( @@ -773,8 +775,11 @@ def fuser_forward( fc1_ctx.weight_requires_grad = weight_requires_grad # Activation - if cpu_offloading and no_offload_moe_activation: - mark_not_offload(activation_in, scales) + if cpu_offloading: + if offload_moe_activation: + mark_activation_offload(activation_in, scales) + else: + mark_not_offload(activation_in, scales) activation_ctx.save_for_backward(activation_in, scales) activation_ctx.extra_input_requires_grad = True activation_ctx.input_requires_grad = True diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 8df929f799..f7724b153c 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -10,7 +10,7 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import BasicLinear, Bias @@ -130,7 +130,11 @@ def fuser_forward( saved_input = x_local saved_weight = w if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) + if linear_op.offload_activation: + mark_activation_offload(saved_input) + else: + mark_not_offload(saved_input) + mark_not_offload(saved_weight) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 5376a7d264..71bb65a5a3 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -10,7 +10,7 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, Bias @@ -127,7 +127,11 @@ def fuser_forward( saved_input = x_local saved_weight = w if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) + if linear_op.offload_activation: + mark_activation_offload(saved_input) + else: + mark_not_offload(saved_input) + mark_not_offload(saved_weight) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index abeb39adfa..803227c204 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -10,7 +10,7 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, ConstantScale @@ -108,7 +108,11 @@ def fuser_forward( saved_input = x_local saved_weight = w if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) + if linear_op.offload_activation: + mark_activation_offload(saved_input) + else: + mark_not_offload(saved_input) + mark_not_offload(saved_weight) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 3a8ff5438d..323225baa8 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -12,7 +12,7 @@ from transformer_engine_torch import CommOverlapType from ...cpp_extensions import general_gemm -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...distributed import get_distributed_world_size from ...quantization import FP8GlobalStateManager from ...module.base import ( @@ -355,7 +355,11 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x_local) + if linear_op.offload_activation: + mark_activation_offload(x_local) + else: + mark_not_offload(x_local) + mark_not_offload(w) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index c6ca4786b8..7b2fc0f502 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -56,6 +56,8 @@ class Linear(FusedOperation): there is no guarantee that ``grad`` will be set or be meaningful. This is primarily intended to integrate with Megatron-LM. + offload_activation : bool, default = ``True`` + Offload saved activation tensors when CPU offload is enabled. """ @@ -72,6 +74,7 @@ def __init__( sequence_parallel: bool = False, rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, accumulate_into_main_grad: bool = False, + offload_activation: bool = True, ) -> None: # Tensor parallel configuration @@ -104,6 +107,7 @@ def __init__( "sequence_parallel": sequence_parallel, "rng_state_tracker_function": rng_state_tracker_function, "accumulate_into_main_grad": accumulate_into_main_grad, + "offload_activation": offload_activation, } bias_kwargs = { "size": out_features, From 1ce8fd28d63577afab67d007276a68151671504e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 5 Jun 2026 21:25:36 +0000 Subject: [PATCH 4/6] Fix CPU offloading correctness in ops layer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Revert per-module offload_activation API added in commits 376d28ceb and 933d64b82; that belongs in a separate PR. - ops/basic/grouped_linear: add start_offload on input tensors before the GEMM, and mark_activation_offload / mark_not_offload in fuser_forward_save_ctx for both the split-quantize and grouped-tensor paths. - ops/fused/forward_grouped_mlp: remove no_offload_activation attribute lookups and the activation mark_not_offload calls that gated on them; add start_offload + mark_activation_offload for all saved activation tensors (grouped_fc1_x, activation_in, saved_grouped_fc2_x) and keep mark_not_offload only for weight tensors. Document why grouped_fc1_x is repacked into GroupedTensorStorage. - ops/basic/basic_linear: no change needed beyond the existing mark_activation_offload — unlike te.Linear there is no persistent weight cache, so the quantized weight workspace can be freely offloaded. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Tim Moon --- .../pytorch/module/grouped_linear.py | 19 +------- .../pytorch/module/layernorm_linear.py | 11 +---- .../pytorch/module/layernorm_mlp.py | 32 ++----------- transformer_engine/pytorch/module/linear.py | 12 +---- .../pytorch/ops/basic/activation.py | 32 +++---------- .../pytorch/ops/basic/basic_linear.py | 15 +++---- .../pytorch/ops/basic/dropout.py | 10 ++--- .../pytorch/ops/basic/grouped_linear.py | 45 ++++++++++--------- .../pytorch/ops/basic/l2normalization.py | 11 +---- .../pytorch/ops/basic/layer_norm.py | 11 +---- .../pytorch/ops/basic/rmsnorm.py | 11 +---- .../pytorch/ops/basic/swiglu.py | 34 ++------------ .../pytorch/ops/fused/forward_grouped_mlp.py | 35 +++++++-------- .../fused/forward_linear_bias_activation.py | 8 +--- .../ops/fused/forward_linear_bias_add.py | 8 +--- .../ops/fused/forward_linear_scale_add.py | 8 +--- .../ops/fused/userbuffers_forward_linear.py | 8 +--- transformer_engine/pytorch/ops/linear.py | 4 -- 18 files changed, 75 insertions(+), 239 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 5c31670c4d..15ec3fe322 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -49,12 +49,7 @@ ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo -from ..cpu_offload import ( - is_cpu_offload_enabled, - mark_activation_offload, - mark_not_offload, - start_offload, -) +from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload from ..triton.grouped_dbias_dscales import compute_grouped_dbias from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer @@ -407,7 +402,6 @@ def forward( grad_output_quantizers, fuse_wgrad_accumulation, cpu_offloading, - offload_activation, sequence_parallel, activation_dtype, is_grad_enabled, @@ -633,12 +627,6 @@ def forward( else: inputmats = [None] * num_gemms - if cpu_offloading: - if offload_activation: - mark_activation_offload(*inputmats) - else: - mark_not_offload(*inputmats) - # Original weights are only needed by high_precision dgrad. The weakrefs # used for fused wgrad accumulation serve a different purpose: restoring # Python parameter attributes without keeping the parameter alive here. @@ -1244,8 +1232,6 @@ class GroupedLinear(TransformerEngineBaseModule): EXPERIMENTAL and subject to change. Gated by the ``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var is not set this argument is forced to ``False`` with a warning. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. Notes ----- @@ -1278,7 +1264,6 @@ def __init__( save_original_input: bool = False, single_grouped_weight: bool = False, single_grouped_bias: bool = False, - offload_activation: bool = True, name: Optional[str] = None, ) -> None: super().__init__(name) @@ -1295,7 +1280,6 @@ def __init__( self.ub_overlap_ag = ub_overlap_ag self.ub_name = ub_name self.save_original_input = save_original_input - self.offload_activation = offload_activation single_grouped_weight, single_grouped_bias = resolve_grouped_linear_single_param_flags( single_grouped_weight, single_grouped_bias ) @@ -1765,7 +1749,6 @@ def forward( grad_output_quantizers, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), - self.offload_activation, self.sequence_parallel, self.activation_dtype, is_grad_enabled, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 66b9bb79f9..7fc96d4779 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -119,7 +119,6 @@ def forward( grad_weight_quantizer, grad_output_quantizer, cpu_offloading, - offload_activation, tp_group, tp_size, sequence_parallel, @@ -459,10 +458,7 @@ def forward( ln_out.update_usage(rowwise_usage=False) if cpu_offloading: - if offload_activation: - mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save) - else: - mark_not_offload(inputmat, mu, rsigma, ln_out_to_save) + mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1254,8 +1250,6 @@ class LayerNormLinear(TransformerEngineBaseModule): This can help in latency bound communication situations. Requires PyTorch version 2.7.0 or higher. When set to ``None``, standard all-reduce is used. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ def __init__( @@ -1287,7 +1281,6 @@ def __init__( ub_name: Optional[str] = None, delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, - offload_activation: bool = True, name: Optional[str] = None, ) -> None: super().__init__(name) @@ -1307,7 +1300,6 @@ def __init__( ) self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type - self.offload_activation = offload_activation self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) @@ -1738,7 +1730,6 @@ def forward( grad_weight_quantizer, grad_output_quantizer, is_cpu_offload_enabled(), - self.offload_activation, self.tp_group, self.tp_size, self.sequence_parallel, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index e51e608c6e..6c6cca74ef 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -210,7 +210,6 @@ def _forward( fc2_grad_weight_quantizer, fc2_grad_output_quantizer, cpu_offloading, - offload_activation, tp_group, tp_size, sequence_parallel, @@ -280,20 +279,6 @@ def _forward( # save the initial state for recomputation by bwd if save_for_checkpoint: - if cpu_offloading: - if offload_activation: - mark_activation_offload(inp) - else: - mark_not_offload(inp) - mark_not_offload( - ln_weight, - ln_bias, - fc1_weight, - fc1_bias, - fc2_weight, - fc2_bias, - ) - # save tensors tensors_to_save, tensor_objects = prepare_for_saving( inp, @@ -327,7 +312,6 @@ def _forward( "fc2_grad_weight_quantizer": fc2_grad_weight_quantizer, "fc2_grad_output_quantizer": fc2_grad_output_quantizer, "cpu_offloading": cpu_offloading, - "offload_activation": offload_activation, "tp_group": tp_group, "tp_size": tp_size, "sequence_parallel": sequence_parallel, @@ -756,14 +740,9 @@ def _forward( if not checkpoint: # regular path, no selective activation checkpointing if cpu_offloading: - if offload_activation: - mark_activation_offload( - inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out - ) - else: - mark_not_offload( - inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out - ) + mark_activation_offload( + inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out + ) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1949,8 +1928,6 @@ class LayerNormMLP(TransformerEngineBaseModule): whether to use selective activation checkpointing, where activations are not saved for bwd, and instead are recomputed (skipping fc2, as it is not needed for backward). Trades compute for memory. default is false, in which activations are saved in fwd. not supported for onnx forward - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ def __init__( @@ -1987,7 +1964,6 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, checkpoint: bool = False, - offload_activation: bool = True, ) -> None: super().__init__(name) @@ -2009,7 +1985,6 @@ def __init__( self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type self.checkpoint = checkpoint - self.offload_activation = offload_activation # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap self.gemm_gelu_fusion = ( @@ -2403,7 +2378,6 @@ def forward( fc2_grad_weight_quantizer, fc2_grad_output_quantizer, is_cpu_offload_enabled(), - self.offload_activation, self.tp_group, self.tp_size, self.sequence_parallel, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6a594e6c29..6c2d98d160 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -153,7 +153,6 @@ class LinearFwdArgs: # --- Misc --- cpu_offloading: bool - offload_activation: bool is_grad_enabled: bool @@ -270,7 +269,6 @@ def _linear_forward_impl( is_first_microbatch = args.is_first_microbatch fp8 = args.fp8 cpu_offloading = args.cpu_offloading - offload_activation = args.offload_activation tp_group = args.tp_group sequence_parallel = args.sequence_parallel activation_dtype = args.activation_dtype @@ -567,10 +565,7 @@ def _linear_forward_impl( saved_inputmat = inputmat if cpu_offloading and saved_inputmat is not None: - if offload_activation: - mark_activation_offload(saved_inputmat) - else: - mark_not_offload(saved_inputmat) + mark_activation_offload(saved_inputmat) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights @@ -1461,8 +1456,6 @@ class Linear(TransformerEngineBaseModule): cast tensor. In some scenarios, the input tensor is used by multiple modules, and saving the original input tensor may reduce the memory usage. Cannot work with FP8 DelayedScaling recipe. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ def __init__( @@ -1491,7 +1484,6 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, - offload_activation: bool = True, name: Optional[str] = None, ) -> None: super().__init__(name) @@ -1507,7 +1499,6 @@ def __init__( self.rng_tracker_name = rng_tracker_name self.symmetric_ar_type = symmetric_ar_type self.save_original_input = save_original_input - self.offload_activation = offload_activation self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) @@ -1956,7 +1947,6 @@ def forward( wgrad_store=wgrad_store, # misc cpu_offloading=is_cpu_offload_enabled(), - offload_activation=self.offload_activation, is_grad_enabled=is_grad_enabled, ) out, new_weight_workspace = linear_fn( diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 2f25ffd082..f4beffe90c 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -13,7 +13,7 @@ import transformer_engine_torch as tex from ...constants import DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext @@ -60,20 +60,12 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): pass. This will typically reduce memory usage but require extra compute and increase numerical error. This feature is highly experimental. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ - def __init__( - self, - *, - cache_quantized_input: bool = False, - offload_activation: bool = True, - ): + def __init__(self, *, cache_quantized_input: bool = False): super().__init__() self.cache_quantized_input: bool = cache_quantized_input - self.offload_activation: bool = offload_activation @abc.abstractmethod def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: @@ -123,10 +115,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.offload_activation: - mark_activation_offload(x) - else: - mark_not_offload(x) + mark_activation_offload(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -370,21 +359,13 @@ class ScaledSReLU(BasicOperation): activation_recompute_in_mlp : bool, default = ``False`` Enable fused grouped MLP kernels to recompute activation outputs during backward when supported instead of saving them. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ num_extra_inputs: int = 1 - def __init__( - self, - *, - activation_recompute_in_mlp: bool = False, - offload_activation: bool = True, - ) -> None: + def __init__(self, *, activation_recompute_in_mlp: bool = False) -> None: super().__init__() self.activation_recompute_in_mlp: bool = activation_recompute_in_mlp - self.offload_activation: bool = offload_activation def op_forward(self, *args, **kwargs) -> None: raise RuntimeError( @@ -434,10 +415,7 @@ def fuser_forward( ctx = basic_op_ctxs[0] if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.offload_activation: - mark_activation_offload(x) - else: - mark_not_offload(x, scales) + mark_activation_offload(x) ctx.input_requires_grad = True ctx.extra_input_requires_grad = extra_input.requires_grad ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 93bd17a013..6b17d66fcd 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -13,7 +13,7 @@ import torch from ...cpp_extensions import general_gemm -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import ( CudaRNGStatesTracker, gather_along_first_dim, @@ -91,8 +91,6 @@ class BasicLinear(BasicOperation): Options for overlapping tensor-parallel communication with compute using Userbuffers. This feature is highly experimental. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ @@ -109,10 +107,8 @@ def __init__( rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, accumulate_into_main_grad: bool = False, userbuffers_options: Optional[dict[str, Any]] = None, - offload_activation: bool = True, ) -> None: super().__init__() - self.offload_activation: bool = offload_activation # Weight tensor dimensions self.in_features: int = in_features @@ -1054,11 +1050,10 @@ def op_forward( saved_input = x_local saved_weight = w if is_cpu_offload_enabled(): - if self.offload_activation: - mark_activation_offload(saved_input) - else: - mark_not_offload(saved_input) - mark_not_offload(saved_weight) + # No special CPU offloading logic is needed for weights. saved_weight is + # either self.weight (nn.Parameter, auto-excluded from offload) or a + # workspace freshly created each forward pass. + mark_activation_offload(saved_input) ctx.save_for_backward(saved_input, saved_weight) ctx.with_quantized_compute = with_quantized_compute and backward_override is None ctx.backward_override = backward_override diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index fc113183cb..8850604aad 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -9,7 +9,7 @@ import torch import transformer_engine_torch as tex -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Quantizer from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_autocast_dtype, maybe_dequantize @@ -25,10 +25,9 @@ class Dropout(BasicOperation): """ - def __init__(self, p: float, *, offload_activation: bool = True) -> None: + def __init__(self, p: float) -> None: super().__init__() self.dropout_probability: float = p - self.offload_activation: bool = offload_activation def op_forward( self, @@ -73,10 +72,7 @@ def op_forward( # Save context for backward if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.offload_activation: - mark_activation_offload(mask) - else: - mark_not_offload(mask) + mark_activation_offload(mask) ctx.save_for_backward(mask) ctx.impl = impl ctx.dropout_probability = self.dropout_probability diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 3ac7c2e69e..70720c2eb7 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -15,7 +15,6 @@ import transformer_engine_torch as tex from ...constants import DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload from ...cpp_extensions import general_grouped_gemm, general_grouped_gemm_for_grouped_tensor from ...distributed import CudaRNGStatesTracker from ...module._common import WeightGradStore @@ -24,6 +23,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe from ...quantized_tensor import QuantizedTensorStorage from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer @@ -105,8 +105,6 @@ class GroupedLinear(BasicOperation): additional extra input and adds ``bias * scales`` instead of ``bias`` in the forward pass. The scale tensor has shape ``(total_tokens,)`` and is split according to the split sizes. - offload_activation : bool, default = ``True`` - Offload saved input activation tensors when CPU offload is enabled. """ @@ -128,12 +126,10 @@ def __init__( single_grouped_bias: bool = False, delay_wgrad_compute: bool = False, scale_bias: bool = False, - offload_activation: bool = True, ) -> None: super().__init__() self._scale_bias: bool = scale_bias and bias - self.offload_activation: bool = offload_activation if self._scale_bias: self.num_extra_inputs = 2 @@ -1032,25 +1028,23 @@ def fuser_forward_save_ctx( ctx = basic_op_ctxs[0] + # Activation CPU offloading + # Note: No special logic is needed for weights. They are + # either nn.Parameter (auto-excluded from offload) or are + # temporary workspaces freshly created in each forward pass. if is_cpu_offload_enabled(): - saved_tensors = tensors_to_save[0] - activation_start = 4 if self._scale_bias else 3 - activation_count = 1 if use_grouped_tensor_path else self.num_groups - activation_end = activation_start + activation_count - activation_tensors = tuple( - tensor - for tensor in saved_tensors[activation_start:activation_end] - if tensor is not None - ) - weight_tensors = tuple( - tensor for tensor in saved_tensors[activation_end:] if tensor is not None - ) - - if self.offload_activation: - mark_activation_offload(*activation_tensors) + saved = tensors_to_save[0] + offset = 4 if self._scale_bias else 3 + if use_grouped_tensor_path: + # Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights] + grouped_x = saved[offset] + if grouped_x is not None: + mark_activation_offload(grouped_x) else: - mark_not_offload(*activation_tensors) - mark_not_offload(*weight_tensors) + # Layout: [split_sizes, None, None, (scales?), *xs, *ws] + live_xs = [t for t in saved[offset:offset + self.num_groups] if t is not None] + if live_xs: + mark_activation_offload(*live_xs) ctx.save_for_backward(*tensors_to_save[0]) @@ -1136,6 +1130,10 @@ def _fuser_forward_split_quantize( xs = tex.split_quantize(x, split_sizes_int, input_quantizers) else: xs = torch.split(x, split_sizes_int) + if is_cpu_offload_enabled(): + live_xs = [t for t in xs if t is not None] + if live_xs: + start_offload(*live_xs) # Allocate output tensor in_shape = list(input_.size()) @@ -1241,6 +1239,9 @@ def _fuser_forward_grouped_tensor( tensor_offsets=base_split_offsets * self.in_features, ) + if is_cpu_offload_enabled() and grouped_x is not None: + start_offload(grouped_x) + # Build the weight GroupedTensor / list. if self.single_grouped_weight: # GroupedTensor diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py index bdf4ef8c51..be155c9356 100644 --- a/transformer_engine/pytorch/ops/basic/l2normalization.py +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -11,7 +11,7 @@ import torch from ...torch_version import torch_version -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...jit import ( l2normalization_fused, l2normalization_fwd_fused, @@ -48,8 +48,6 @@ class L2Normalization(BasicOperation): batch size per training step. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for forward propagation and activation recompute phase. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ @@ -59,11 +57,9 @@ def __init__( eps: float = 1e-6, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, - offload_activation: bool = True, ) -> None: super().__init__() self.eps: float = eps - self.offload_activation: bool = offload_activation # JIT warmup for L2Normalization fused operations if seq_length and micro_batch_size: @@ -107,10 +103,7 @@ def op_forward( # Save state for backward pass if requires_grad: if is_cpu_offload_enabled(): - if self.offload_activation: - mark_activation_offload(x, rsqrt_norm) - else: - mark_not_offload(x, rsqrt_norm) + mark_activation_offload(x, rsqrt_norm) ctx.save_for_backward(x, rsqrt_norm) return y diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 74005163e8..3fda5145c6 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -14,7 +14,7 @@ from transformer_engine_torch import layernorm_bwd, layernorm_fwd from ...constants import TE_DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...export import is_in_onnx_export_mode from ...tensor import Quantizer from ...utils import ( @@ -64,8 +64,6 @@ class LayerNorm(BasicOperation): For more fine-grained control, provide a dict with the SM margin at each compute stage ("forward", "backward", "inference"). - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ @@ -78,12 +76,10 @@ def __init__( dtype: Optional[torch.dtype] = None, zero_centered_gamma: bool = False, sm_margin: int | dict[str, int] = 0, - offload_activation: bool = True, ) -> None: super().__init__() self.eps: float = eps self.zero_centered_gamma: bool = zero_centered_gamma - self.offload_activation: bool = offload_activation # Parameter shape if not isinstance(normalized_shape, Iterable): @@ -221,10 +217,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.offload_activation: - mark_activation_offload(x, means, rstdevs) - else: - mark_not_offload(x, means, rstdevs) + mark_activation_offload(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index c7d8fe1e69..1d8d8be971 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -14,7 +14,7 @@ from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from ...constants import TE_DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...export import is_in_onnx_export_mode from ...tensor import Quantizer from ...utils import ( @@ -63,8 +63,6 @@ class RMSNorm(BasicOperation): For more fine-grained control, provide a dict with the SM margin at each compute stage ("forward", "backward", "inference"). - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ @@ -77,12 +75,10 @@ def __init__( dtype: Optional[torch.dtype] = None, zero_centered_gamma: bool = False, sm_margin: int = 0, - offload_activation: bool = True, ) -> None: super().__init__() self.eps: float = eps self.zero_centered_gamma: bool = zero_centered_gamma - self.offload_activation: bool = offload_activation # Parameter shape if not isinstance(normalized_shape, Iterable): @@ -202,10 +198,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.offload_activation: - mark_activation_offload(x, rstdevs) - else: - mark_not_offload(x, rstdevs) + mark_activation_offload(x, rstdevs) ctx.save_for_backward(x, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index 0df720311c..02f330ede3 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -12,7 +12,7 @@ import transformer_engine_torch as tex from ...constants import DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext @@ -67,8 +67,6 @@ class SwiGLU(BasicOperation): when the interleave size is 2). This data format is highly experiental and is primarily intended to support some advanced fused kernels. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ @@ -77,12 +75,10 @@ def __init__( *, cache_quantized_input: bool = False, glu_interleave_size: Optional[int] = None, - offload_activation: bool = True, ): super().__init__() self.cache_quantized_input: bool = cache_quantized_input self.glu_interleave_size: Optional[int] = glu_interleave_size - self.offload_activation: bool = offload_activation def op_forward( self, @@ -132,10 +128,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.offload_activation: - mark_activation_offload(input_) - else: - mark_not_offload(input_) + mark_activation_offload(input_) ctx.save_for_backward(input_) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -225,8 +218,6 @@ class ClampedSwiGLU(BasicOperation): When set, the GLU activations will use an experimental block interleaved format. See the corresponding option in the SwiGLU operation for more details. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ @@ -238,7 +229,6 @@ def __init__( glu_linear_offset: float = 1.0, cache_quantized_input: bool = False, glu_interleave_size: Optional[int] = None, - offload_activation: bool = True, ): super().__init__() self.limit: float = limit @@ -246,7 +236,6 @@ def __init__( self.glu_linear_offset: float = glu_linear_offset self.cache_quantized_input: bool = cache_quantized_input self.glu_interleave_size: Optional[int] = glu_interleave_size - self.offload_activation: bool = offload_activation def _tex_clamped_swiglu_forward( self, @@ -323,10 +312,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.offload_activation: - mark_activation_offload(x) - else: - mark_not_offload(x) + mark_activation_offload(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -396,12 +382,10 @@ def __init__( glu_interleave_size: Optional[int] = None, *, activation_recompute_in_mlp: bool = False, - offload_activation: bool = True, ) -> None: super().__init__() self.glu_interleave_size: Optional[int] = glu_interleave_size self.activation_recompute_in_mlp: bool = activation_recompute_in_mlp - self.offload_activation: bool = offload_activation def _glu_forward(self, swiglu_in: torch.Tensor) -> torch.Tensor: raise NotImplementedError @@ -479,10 +463,7 @@ def fuser_forward( ctx = basic_op_ctxs[0] if ctx.requires_grad: if is_cpu_offload_enabled(): - if self.offload_activation: - mark_activation_offload(input_) - else: - mark_not_offload(input_, scales) + mark_activation_offload(input_) ctx.input_requires_grad = True ctx.extra_input_requires_grad = extra_input.requires_grad ctx.dtype = dtype @@ -574,8 +555,6 @@ class ScaledSwiGLU(_ScaledGLU): activation_recompute_in_mlp : bool, default = ``False`` Enable fused grouped MLP kernels to recompute activation outputs during backward when supported instead of saving them. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ @@ -606,8 +585,6 @@ class ScaledClampedQGeGLU(_ScaledGLU): activation_recompute_in_mlp : bool, default = ``False`` Enable fused grouped MLP kernels to recompute activation outputs during backward when supported instead of saving them. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. limit : float, default ``7.0`` Clamp limit (see :class:`ClampedSwiGLU`). alpha : float, default ``1.702`` @@ -623,7 +600,6 @@ def __init__( glu_interleave_size: Optional[int] = None, *, activation_recompute_in_mlp: bool = False, - offload_activation: bool = True, limit: float = 7.0, alpha: float = 1.702, glu_linear_offset: float = 1.0, @@ -631,13 +607,11 @@ def __init__( super().__init__( glu_interleave_size, activation_recompute_in_mlp=activation_recompute_in_mlp, - offload_activation=offload_activation, ) self._clamped: ClampedSwiGLU = ClampedSwiGLU( limit=limit, alpha=alpha, glu_linear_offset=glu_linear_offset, - offload_activation=offload_activation, ) def _glu_forward(self, swiglu_in: torch.Tensor) -> torch.Tensor: diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 0e34131d2a..2d777022a1 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,7 +13,7 @@ import torch import transformer_engine_torch as tex -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor from ...quantization import Recipe from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer @@ -326,6 +326,12 @@ def fuser_forward( or isinstance(fc1_input_quantizer, NVFP4Quantizer) and isinstance(input_quantizer, NVFP4Quantizer) ): + # GroupedTensor is a torch.Tensor subclass, so the CPU offload + # infrastructure's prepare_for_saving treats it as a plain tensor + # and does not decompose it into its component data tensors. By + # repacking into a GroupedTensorStorage (not a torch.Tensor), we + # ensure the fuser's prepare_for_saving call correctly decomposes + # the activation before save_for_backward. grouped_fc1_x = GroupedTensorStorage( shape=input_.logical_shape, dtype=input_.fake_dtype, @@ -722,8 +728,6 @@ def fuser_forward( mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) activation_op = self.basic_ops[1] cpu_offloading = is_cpu_offload_enabled() - offload_fc1_activation = fc1_op.offload_activation - offload_moe_activation = activation_op.offload_activation activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( getattr(activation_op, "activation_recompute_in_mlp", False) @@ -745,18 +749,20 @@ def fuser_forward( grouped_fc_x.rowwise_data = None grouped_fc_x.scale_inv = None + if cpu_offloading: + activation_tensors = [ + t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) + if t is not None + ] + start_offload(*activation_tensors) + mark_activation_offload(*activation_tensors) + # FC1 saved-tensor layout. # [split_sizes, base_split_offsets, split_points, # grouped_fc1_x, *fc1_weight_tensors] fc1_weight_tensors = ( [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight ) - if cpu_offloading: - if offload_fc1_activation: - mark_activation_offload(grouped_fc1_x) - else: - mark_not_offload(grouped_fc1_x) - mark_not_offload(*fc1_weight_tensors) fc1_ctx.save_for_backward( split_sizes, base_split_offsets, @@ -775,11 +781,6 @@ def fuser_forward( fc1_ctx.weight_requires_grad = weight_requires_grad # Activation - if cpu_offloading: - if offload_moe_activation: - mark_activation_offload(activation_in, scales) - else: - mark_not_offload(activation_in, scales) activation_ctx.save_for_backward(activation_in, scales) activation_ctx.extra_input_requires_grad = True activation_ctx.input_requires_grad = True @@ -795,12 +796,6 @@ def fuser_forward( fc2_weight_tensors = ( [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight ) - if cpu_offloading: - if saved_grouped_fc2_x is not None: - # FC2 input is saved for FC2 wgrad, but it is not the Megatron moe_act - # activation target controlled above. Keep this extra saved tensor resident. - mark_not_offload(saved_grouped_fc2_x) - mark_not_offload(*fc2_weight_tensors) fc2_saved: list[Optional[torch.Tensor | GroupedTensorStorage]] = [ split_sizes, base_split_offsets, diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index f7724b153c..8df929f799 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -10,7 +10,7 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import BasicLinear, Bias @@ -130,11 +130,7 @@ def fuser_forward( saved_input = x_local saved_weight = w if is_cpu_offload_enabled(): - if linear_op.offload_activation: - mark_activation_offload(saved_input) - else: - mark_not_offload(saved_input) - mark_not_offload(saved_weight) + mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 71bb65a5a3..5376a7d264 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -10,7 +10,7 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, Bias @@ -127,11 +127,7 @@ def fuser_forward( saved_input = x_local saved_weight = w if is_cpu_offload_enabled(): - if linear_op.offload_activation: - mark_activation_offload(saved_input) - else: - mark_not_offload(saved_input) - mark_not_offload(saved_weight) + mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 803227c204..abeb39adfa 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -10,7 +10,7 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, ConstantScale @@ -108,11 +108,7 @@ def fuser_forward( saved_input = x_local saved_weight = w if is_cpu_offload_enabled(): - if linear_op.offload_activation: - mark_activation_offload(saved_input) - else: - mark_not_offload(saved_input) - mark_not_offload(saved_weight) + mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 323225baa8..3a8ff5438d 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -12,7 +12,7 @@ from transformer_engine_torch import CommOverlapType from ...cpp_extensions import general_gemm -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, mark_not_offload +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import get_distributed_world_size from ...quantization import FP8GlobalStateManager from ...module.base import ( @@ -355,11 +355,7 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: if is_cpu_offload_enabled(): - if linear_op.offload_activation: - mark_activation_offload(x_local) - else: - mark_not_offload(x_local) - mark_not_offload(w) + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index 7b2fc0f502..c6ca4786b8 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -56,8 +56,6 @@ class Linear(FusedOperation): there is no guarantee that ``grad`` will be set or be meaningful. This is primarily intended to integrate with Megatron-LM. - offload_activation : bool, default = ``True`` - Offload saved activation tensors when CPU offload is enabled. """ @@ -74,7 +72,6 @@ def __init__( sequence_parallel: bool = False, rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, accumulate_into_main_grad: bool = False, - offload_activation: bool = True, ) -> None: # Tensor parallel configuration @@ -107,7 +104,6 @@ def __init__( "sequence_parallel": sequence_parallel, "rng_state_tracker_function": rng_state_tracker_function, "accumulate_into_main_grad": accumulate_into_main_grad, - "offload_activation": offload_activation, } bias_kwargs = { "size": out_features, From 1da42fd3d64572db00194283bafe25c7d9fd2170 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Jun 2026 23:38:59 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/basic/grouped_linear.py | 2 +- transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 70720c2eb7..8bec796fb5 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -1042,7 +1042,7 @@ def fuser_forward_save_ctx( mark_activation_offload(grouped_x) else: # Layout: [split_sizes, None, None, (scales?), *xs, *ws] - live_xs = [t for t in saved[offset:offset + self.num_groups] if t is not None] + live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None] if live_xs: mark_activation_offload(*live_xs) diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 2d777022a1..e22652d7b6 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -751,8 +751,7 @@ def fuser_forward( if cpu_offloading: activation_tensors = [ - t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) - if t is not None + t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None ] start_offload(*activation_tensors) mark_activation_offload(*activation_tensors) From aae2f2ef7d8bd1bf3129bdd021ecfa0306bd59f8 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 6 Jun 2026 02:11:07 +0000 Subject: [PATCH 6/6] Construct internal grouped tensors within grouped linear and grouped MLP GroupedTensor should only be used when exposed externally. Otherwise GroupedTensorStorage has less CPU overhead. There also seems to be some issue with CPU offloading that has not yet been root-caused. Signed-off-by: Tim Moon --- .../pytorch/module/grouped_linear.py | 15 ++++++++------ .../pytorch/ops/basic/grouped_linear.py | 20 +++++++++---------- .../pytorch/ops/fused/forward_grouped_mlp.py | 2 +- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c1d45511df..ac304d3379 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -16,7 +16,10 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor +from transformer_engine.pytorch.tensor.grouped_tensor import ( + GroupedTensor, + GroupedTensorStorage, +) from .base import ( get_dummy_wgrad, quantize_weight, @@ -135,9 +138,9 @@ def _make_grouped_tensor( base_split_offsets: torch.Tensor, last_dim: int, dtype: torch.dtype, - ) -> GroupedTensor: - """Wrap a packed 2D buffer as a varying-first-dimension GroupedTensor.""" - return GroupedTensor( + ) -> GroupedTensorStorage: + """Wrap a packed 2D buffer as a varying-first-dimension GroupedTensorStorage.""" + return GroupedTensorStorage( shape=(data.size(0), last_dim), dtype=dtype, num_tensors=num_gemms, @@ -154,13 +157,13 @@ def _make_grouped_bias( num_gemms: int, out_features: int, dtype: torch.dtype, - ) -> GroupedTensor: + ) -> GroupedTensorStorage: """Pack per-GEMM biases into the grouped GEMM bias format.""" bias_data = torch.stack( [_GroupedLinear._maybe_dequantize(bias, dtype) for bias in biases], dim=0, ).contiguous() - return GroupedTensor( + return GroupedTensorStorage( shape=(num_gemms, out_features), dtype=dtype, num_tensors=num_gemms, diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 8bec796fb5..73e328c9d1 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -784,7 +784,7 @@ def _get_grouped_weight_for_gemm( columnwise_usage: bool, with_quantized_compute: bool, dtype: torch.dtype, - ) -> GroupedTensor: + ) -> GroupedTensorStorage: """Prepare weights for ``general_grouped_gemm_for_grouped_tensor``. Supports MXFP8/BF16/FP16 compute paths. """ @@ -801,7 +801,7 @@ def _get_grouped_weight_for_gemm( weight_parts = weight_param.split_into_quantized_tensors() dequantized = [maybe_dequantize(w, dtype) for w in weight_parts] weight_data = torch.stack(dequantized, dim=0).contiguous() - return GroupedTensor( + return GroupedTensorStorage( shape=(num_groups * self.out_features, self.in_features), dtype=dtype, num_tensors=num_groups, @@ -815,7 +815,7 @@ def _get_grouped_weight_for_gemm( if weight_param.rowwise_data.dtype == dtype: return weight_param weight_data = weight_param.rowwise_data.to(dtype=dtype) - return GroupedTensor( + return GroupedTensorStorage( shape=(num_groups * self.out_features, self.in_features), dtype=dtype, num_tensors=num_groups, @@ -867,8 +867,8 @@ def _get_weight_tensors(self) -> list[torch.nn.Parameter]: def _get_grouped_bias_for_gemm( self, dtype: torch.dtype, - ) -> Optional[torch.Tensor]: - """Build a uniform GroupedTensor of per-group biases for the cublas + ) -> Optional[GroupedTensorStorage]: + """Build a uniform GroupedTensorStorage of per-group biases for the cublas grouped GEMM. Each group expects a (1, out_features) bias vector. Returns ``None`` @@ -889,7 +889,7 @@ def _get_grouped_bias_for_gemm( ] bias_data = torch.stack(bias_list, dim=0).contiguous() - return GroupedTensor( + return GroupedTensorStorage( shape=(num_groups, self.out_features), dtype=dtype, num_tensors=num_groups, @@ -1229,7 +1229,7 @@ def _fuser_forward_grouped_tensor( grouped_x = tex.group_quantize(x, input_quantizer, num_groups, split_sizes) else: # No quantize: wrap the contiguous high-precision buffer. - grouped_x = GroupedTensor( + grouped_x = GroupedTensorStorage( shape=(total_tokens, self.in_features), dtype=dtype, num_tensors=num_groups, @@ -1265,7 +1265,7 @@ def _fuser_forward_grouped_tensor( # Allocate output buffer and wrap as a GroupedTensor view. out_shape = original_shape[:-1] + [self.out_features] out = torch.empty(out_shape, dtype=dtype, device=device) - grouped_out = GroupedTensor( + grouped_out = GroupedTensorStorage( shape=(total_tokens, self.out_features), dtype=dtype, num_tensors=num_groups, @@ -1593,7 +1593,7 @@ def _fuser_backward_grouped_tensor( else: dy_2d = maybe_dequantize(dy_2d, dtype) # Wrap BF16/FP16 buffer as a GroupedTensor for grouped gemm - grouped_dy = GroupedTensor( + grouped_dy = GroupedTensorStorage( shape=(total_tokens, self.out_features), dtype=dtype, num_tensors=num_groups, @@ -1629,7 +1629,7 @@ def _fuser_backward_grouped_tensor( if ctx.input_requires_grad: grad_input_shape = list(grad_output.size())[:-1] + [self.in_features] grad_input = torch.empty(grad_input_shape, dtype=dtype, device=device) - grouped_grad_input = GroupedTensor( + grouped_grad_input = GroupedTensorStorage( shape=(total_tokens, self.in_features), dtype=dtype, num_tensors=num_groups, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index e22652d7b6..f03ccc15b5 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -619,7 +619,7 @@ def fuser_forward( else: fc2_out_buf = fc2_out_buf + token_bias else: - fc2_out_grouped = GroupedTensor( + fc2_out_grouped = GroupedTensorStorage( shape=(in_shape[0], fc2_weight_shape[0]), dtype=dtype, num_tensors=num_groups,