def test_basic_operation_activation_offloading_policy(monkeypatch):
    """BasicOperation should expose a public opt-out for saved activation CPU offload.

    The two marker functions that the op module looks up are replaced with
    recorders, so the test observes which marker is chosen and which tensors
    reach it without needing an active offload context.

    Covered cases:
      * default state marks tensors for offload and drops ``None`` entries,
      * calls with no live tensors invoke neither marker,
      * opting out routes tensors to the "skip" marker (still dropping ``None``),
      * opting back in restores the default behavior.
    """
    import transformer_engine.pytorch.ops.op as op_module

    calls = []

    def _recorder(tag):
        """Build a stand-in marker that logs ``tag`` and the ids of its tensors."""

        def _record(*tensors):
            calls.append((tag, [id(t) for t in tensors]))

        return _record

    monkeypatch.setattr(op_module, "mark_activation_offload", _recorder("mark"))
    monkeypatch.setattr(op_module, "mark_not_offload", _recorder("skip"))

    tensor = torch.empty(1)
    tensor_id = id(tensor)
    op = te_ops.Identity()

    # Default: opted in, and None placeholders never reach the marker
    assert op.activation_offloading is True
    op.maybe_mark_activation_offload(tensor, None)
    assert calls == [("mark", [tensor_id])]

    # Nothing to mark: neither marker may be invoked
    calls.clear()
    op.maybe_mark_activation_offload()
    op.maybe_mark_activation_offload(None)
    assert calls == []

    # Opt out: tensors are explicitly marked as not offloaded
    calls.clear()
    op.set_activation_offloading(False)
    op.maybe_mark_activation_offload(None, tensor)
    assert calls == [("skip", [tensor_id])]

    # Opt out with nothing to mark is still a no-op
    calls.clear()
    op.maybe_mark_activation_offload(None)
    assert calls == []

    # Opt back in: default behavior is restored
    calls.clear()
    op.set_activation_offloading(True)
    op.maybe_mark_activation_offload(tensor)
    assert calls == [("mark", [tensor_id])]
OperationContext @@ -115,7 +115,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x) + self.maybe_mark_activation_offload(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -415,7 +415,7 @@ def fuser_forward( ctx = basic_op_ctxs[0] if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x) + self.maybe_mark_activation_offload(x) ctx.input_requires_grad = True ctx.extra_input_requires_grad = extra_input.requires_grad ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 6b17d66fcd..9e241a5c12 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -13,7 +13,7 @@ import torch from ...cpp_extensions import general_gemm -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled from ...distributed import ( CudaRNGStatesTracker, gather_along_first_dim, @@ -1049,11 +1049,11 @@ def op_forward( else: saved_input = x_local saved_weight = w + # No special CPU offloading logic is needed for weights. saved_weight is + # either self.weight (nn.Parameter, auto-excluded from offload) or a + # workspace freshly created each forward pass. if is_cpu_offload_enabled(): - # No special CPU offloading logic is needed for weights. saved_weight is - # either self.weight (nn.Parameter, auto-excluded from offload) or a - # workspace freshly created each forward pass. 
- mark_activation_offload(saved_input) + self.maybe_mark_activation_offload(saved_input) ctx.save_for_backward(saved_input, saved_weight) ctx.with_quantized_compute = with_quantized_compute and backward_override is None ctx.backward_override = backward_override diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index 8850604aad..519d29ede2 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -9,7 +9,7 @@ import torch import transformer_engine_torch as tex -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled from ...tensor import Quantizer from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_autocast_dtype, maybe_dequantize @@ -72,7 +72,7 @@ def op_forward( # Save context for backward if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(mask) + self.maybe_mark_activation_offload(mask) ctx.save_for_backward(mask) ctx.impl = impl ctx.dropout_probability = self.dropout_probability diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index 73e328c9d1..c7c1f33b5e 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -23,7 +23,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload +from ...cpu_offload import is_cpu_offload_enabled, start_offload from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe from ...quantized_tensor import QuantizedTensorStorage from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer @@ -1039,12 +1039,12 @@ def fuser_forward_save_ctx( # Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights] grouped_x = 
saved[offset] if grouped_x is not None: - mark_activation_offload(grouped_x) + self.maybe_mark_activation_offload(grouped_x) else: # Layout: [split_sizes, None, None, (scales?), *xs, *ws] live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None] if live_xs: - mark_activation_offload(*live_xs) + self.maybe_mark_activation_offload(*live_xs) ctx.save_for_backward(*tensors_to_save[0]) diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py index be155c9356..eee1794439 100644 --- a/transformer_engine/pytorch/ops/basic/l2normalization.py +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -10,8 +10,8 @@ import torch +from ...cpu_offload import is_cpu_offload_enabled from ...torch_version import torch_version -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...jit import ( l2normalization_fused, l2normalization_fwd_fused, @@ -103,7 +103,7 @@ def op_forward( # Save state for backward pass if requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x, rsqrt_norm) + self.maybe_mark_activation_offload(x, rsqrt_norm) ctx.save_for_backward(x, rsqrt_norm) return y diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 3fda5145c6..c1a8255132 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -14,7 +14,7 @@ from transformer_engine_torch import layernorm_bwd, layernorm_fwd from ...constants import TE_DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled from ...export import is_in_onnx_export_mode from ...tensor import Quantizer from ...utils import ( @@ -217,7 +217,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x, means, rstdevs) + 
self.maybe_mark_activation_offload(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 1d8d8be971..a9079b6159 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -14,7 +14,7 @@ from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from ...constants import TE_DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled from ...export import is_in_onnx_export_mode from ...tensor import Quantizer from ...utils import ( @@ -198,7 +198,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(x, rstdevs) + self.maybe_mark_activation_offload(x, rstdevs) ctx.save_for_backward(x, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/swiglu.py b/transformer_engine/pytorch/ops/basic/swiglu.py index 02f330ede3..72c0286fff 100644 --- a/transformer_engine/pytorch/ops/basic/swiglu.py +++ b/transformer_engine/pytorch/ops/basic/swiglu.py @@ -11,8 +11,8 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled from ...constants import DType -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext @@ -128,7 +128,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(input_) + self.maybe_mark_activation_offload(input_) ctx.save_for_backward(input_) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -312,7 +312,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: if 
is_cpu_offload_enabled(): - mark_activation_offload(x) + self.maybe_mark_activation_offload(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer @@ -463,7 +463,7 @@ def fuser_forward( ctx = basic_op_ctxs[0] if ctx.requires_grad: if is_cpu_offload_enabled(): - mark_activation_offload(input_) + self.maybe_mark_activation_offload(input_) ctx.input_requires_grad = True ctx.extra_input_requires_grad = extra_input.requires_grad ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index f03ccc15b5..fa7d56a78d 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -13,7 +13,7 @@ import torch import transformer_engine_torch as tex -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload +from ...cpu_offload import is_cpu_offload_enabled, start_offload from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor from ...quantization import Recipe from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer @@ -170,7 +170,7 @@ def fuser_forward( basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: # Get basic operations - fc1_op, _, fc2_op = self.basic_ops + fc1_op, activation_op, fc2_op = self.basic_ops fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs # Tensor properties @@ -726,7 +726,6 @@ def fuser_forward( # Save state for backward pass if requires_grad: mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x) - activation_op = self.basic_ops[1] cpu_offloading = is_cpu_offload_enabled() activation_is_srelu = isinstance(activation_op, ScaledSReLU) activation_recompute_in_mlp = bool( @@ -754,7 +753,9 @@ def fuser_forward( t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None ] 
start_offload(*activation_tensors) - mark_activation_offload(*activation_tensors) + fc1_op.maybe_mark_activation_offload(grouped_fc1_x) + activation_op.maybe_mark_activation_offload(activation_in) + fc2_op.maybe_mark_activation_offload(saved_grouped_fc2_x) # FC1 saved-tensor layout. # [split_sizes, base_split_offsets, split_points, diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 8df929f799..8e5c5f841b 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -10,7 +10,7 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import BasicLinear, Bias @@ -130,7 +130,7 @@ def fuser_forward( saved_input = x_local saved_weight = w if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) + linear_op.maybe_mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 5376a7d264..d2e085e41b 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -10,7 +10,7 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, Bias @@ -127,7 +127,7 @@ def fuser_forward( saved_input = x_local saved_weight 
= w if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) + linear_op.maybe_mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index abeb39adfa..8852a60e92 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -10,7 +10,7 @@ import torch -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, ConstantScale @@ -108,7 +108,7 @@ def fuser_forward( saved_input = x_local saved_weight = w if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) + linear_op.maybe_mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( with_quantized_compute and backward_override is None diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 3a8ff5438d..07d66a3b3a 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -12,7 +12,7 @@ from transformer_engine_torch import CommOverlapType from ...cpp_extensions import general_gemm -from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...cpu_offload import is_cpu_offload_enabled from ...distributed import get_distributed_world_size from ...quantization import FP8GlobalStateManager from ...module.base import ( @@ -355,7 +355,7 @@ def 
def set_activation_offloading(self, enabled: bool) -> None:
    """Opt this op's saved activation tensors in or out of CPU offloading.

    The flag is only stored here; it is consulted by
    ``maybe_mark_activation_offload`` when tensors are about to be saved.
    Whether offloading happens at all is decided by the surrounding offload
    context, not by this setting.

    Parameters
    ----------
    enabled:
        ``True`` to let saved activations be marked for offload, ``False``
        to have them marked as excluded.
    """
    self.activation_offloading = enabled


def maybe_mark_activation_offload(self, *tensors: Any) -> None:
    """Tag saved activation tensors according to this op's offloading flag.

    ``None`` entries are ignored. When no tensor remains, nothing is marked.
    Otherwise all remaining tensors are handed, in one call, to
    ``mark_activation_offload`` if ``self.activation_offloading`` is truthy,
    or to ``mark_not_offload`` if it is not.

    Parameters
    ----------
    *tensors:
        Tensors about to be saved for the backward pass; may contain ``None``.
    """
    live = [candidate for candidate in tensors if candidate is not None]
    if live:
        # Only the selected marker name is looked up, exactly one marker runs
        marker = mark_activation_offload if self.activation_offloading else mark_not_offload
        marker(*live)