Skip to content
Merged
36 changes: 36 additions & 0 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,42 @@
# Supported devices
_devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")]


def test_basic_operation_activation_offloading_policy(monkeypatch):
Comment thread
timmoon10 marked this conversation as resolved.
Outdated
"""BasicOperation should expose a public opt-out for saved activation CPU offload."""
import transformer_engine.pytorch.ops.op as op_module

calls = []
tensor = torch.empty(1)
tensor_id = id(tensor)
op = te_ops.Identity()

monkeypatch.setattr(
op_module,
"mark_activation_offload",
lambda *tensors: calls.append(("mark", [id(t) for t in tensors])),
)
monkeypatch.setattr(
op_module,
"mark_not_offload",
lambda *tensors: calls.append(("skip", [id(t) for t in tensors])),
)
monkeypatch.setattr(op_module, "is_cpu_offload_enabled", lambda: True)

op.mark_for_cpu_offload_if_needed(tensor, None)
assert calls == [("mark", [tensor_id])]

calls.clear()
op.set_activation_offloading(False)
op.mark_for_cpu_offload_if_needed(tensor)
assert calls == [("skip", [tensor_id])]

calls.clear()
op.set_activation_offloading(True)
op.mark_for_cpu_offload_if_needed(tensor)
assert calls == [("mark", [tensor_id])]


# Supported quantization recipes
_quantization_list: list[Optional[str]] = [None]
if fp8_available:
Expand Down
7 changes: 2 additions & 5 deletions transformer_engine/pytorch/ops/basic/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import transformer_engine_torch as tex
from ...constants import DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
Expand Down Expand Up @@ -114,8 +113,7 @@ def op_forward(

# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.mark_for_cpu_offload_if_needed(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -414,8 +412,7 @@ def fuser_forward(

ctx = basic_op_ctxs[0]
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.mark_for_cpu_offload_if_needed(x)
ctx.input_requires_grad = True
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.dtype = dtype
Expand Down
13 changes: 7 additions & 6 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torch

from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import (
CudaRNGStatesTracker,
gather_along_first_dim,
Expand Down Expand Up @@ -1049,11 +1048,13 @@ def op_forward(
else:
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
# No special CPU offloading logic is needed for weights. saved_weight is
# either self.weight (nn.Parameter, auto-excluded from offload) or a
# workspace freshly created each forward pass.
mark_activation_offload(saved_input)

# Activation CPU offloading
# Note: No special CPU offloading logic is needed for weights. saved_weight is
# either self.weight (nn.Parameter, auto-excluded from offload) or a
# workspace freshly created each forward pass.
self.mark_for_cpu_offload_if_needed(saved_input)

ctx.save_for_backward(saved_input, saved_weight)
ctx.with_quantized_compute = with_quantized_compute and backward_override is None
ctx.backward_override = backward_override
Expand Down
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/ops/basic/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import torch
import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Quantizer
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_autocast_dtype, maybe_dequantize
Expand Down Expand Up @@ -71,8 +70,7 @@ def op_forward(

# Save context for backward
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(mask)
self.mark_for_cpu_offload_if_needed(mask)
ctx.save_for_backward(mask)
ctx.impl = impl
ctx.dropout_probability = self.dropout_probability
Expand Down
23 changes: 9 additions & 14 deletions transformer_engine/pytorch/ops/basic/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...cpu_offload import is_cpu_offload_enabled, start_offload
from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe
from ...quantized_tensor import QuantizedTensorStorage
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
Expand Down Expand Up @@ -1032,19 +1032,14 @@ def fuser_forward_save_ctx(
# Note: No special logic is needed for weights. They are
# either nn.Parameter (auto-excluded from offload) or are
# temporary workspaces freshly created in each forward pass.
if is_cpu_offload_enabled():
saved = tensors_to_save[0]
offset = 4 if self._scale_bias else 3
if use_grouped_tensor_path:
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
grouped_x = saved[offset]
if grouped_x is not None:
mark_activation_offload(grouped_x)
else:
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None]
if live_xs:
mark_activation_offload(*live_xs)
saved = tensors_to_save[0]
offset = 4 if self._scale_bias else 3
if use_grouped_tensor_path:
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
self.mark_for_cpu_offload_if_needed(saved[offset])
else:
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
self.mark_for_cpu_offload_if_needed(saved[offset : offset + self.num_groups])
Comment thread
timmoon10 marked this conversation as resolved.

ctx.save_for_backward(*tensors_to_save[0])

Expand Down
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/ops/basic/l2normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch

from ...torch_version import torch_version
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...jit import (
l2normalization_fused,
l2normalization_fwd_fused,
Expand Down Expand Up @@ -102,8 +101,7 @@ def op_forward(

# Save state for backward pass
if requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rsqrt_norm)
self.mark_for_cpu_offload_if_needed(x, rsqrt_norm)
ctx.save_for_backward(x, rsqrt_norm)

return y
Expand Down
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/ops/basic/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
Expand Down Expand Up @@ -216,8 +215,7 @@ def op_forward(

# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, means, rstdevs)
self.mark_for_cpu_offload_if_needed(x, means, rstdevs)
ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype

Expand Down
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/ops/basic/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
Expand Down Expand Up @@ -197,8 +196,7 @@ def op_forward(

# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rstdevs)
self.mark_for_cpu_offload_if_needed(x, rstdevs)
ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype

Expand Down
10 changes: 3 additions & 7 deletions transformer_engine/pytorch/ops/basic/swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import transformer_engine_torch as tex
from ...constants import DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
Expand Down Expand Up @@ -127,8 +126,7 @@ def op_forward(

# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
self.mark_for_cpu_offload_if_needed(input_)
ctx.save_for_backward(input_)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -311,8 +309,7 @@ def op_forward(

# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.mark_for_cpu_offload_if_needed(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -462,8 +459,7 @@ def fuser_forward(
# Save state for backward pass
ctx = basic_op_ctxs[0]
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
self.mark_for_cpu_offload_if_needed(input_)
ctx.input_requires_grad = True
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.dtype = dtype
Expand Down
9 changes: 5 additions & 4 deletions transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch

import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...cpu_offload import is_cpu_offload_enabled, start_offload
from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor
from ...quantization import Recipe
from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer
Expand Down Expand Up @@ -170,7 +170,7 @@ def fuser_forward(
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
# Get basic operations
fc1_op, _, fc2_op = self.basic_ops
fc1_op, activation_op, fc2_op = self.basic_ops
fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs

# Tensor properties
Expand Down Expand Up @@ -726,7 +726,6 @@ def fuser_forward(
# Save state for backward pass
if requires_grad:
mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x)
activation_op = self.basic_ops[1]
cpu_offloading = is_cpu_offload_enabled()
activation_is_srelu = isinstance(activation_op, ScaledSReLU)
activation_recompute_in_mlp = bool(
Expand Down Expand Up @@ -754,7 +753,9 @@ def fuser_forward(
t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None
]
start_offload(*activation_tensors)
mark_activation_offload(*activation_tensors)
fc1_op.mark_for_cpu_offload_if_needed(grouped_fc1_x)
activation_op.mark_for_cpu_offload_if_needed(activation_in)
fc2_op.mark_for_cpu_offload_if_needed(saved_grouped_fc2_x)

# FC1 saved-tensor layout.
# [split_sizes, base_split_offsets, split_points,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import BasicLinear, Bias
Expand Down Expand Up @@ -129,8 +128,7 @@ def fuser_forward(
else:
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.mark_for_cpu_offload_if_needed(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, Bias
Expand Down Expand Up @@ -126,8 +125,7 @@ def fuser_forward(
else:
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.mark_for_cpu_offload_if_needed(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, ConstantScale
Expand Down Expand Up @@ -107,8 +106,7 @@ def fuser_forward(
else:
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.mark_for_cpu_offload_if_needed(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import get_distributed_world_size
from ...quantization import FP8GlobalStateManager
from ...module.base import (
Expand Down Expand Up @@ -354,8 +353,7 @@ def fuser_forward(

# Save state for backward pass
if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op.mark_for_cpu_offload_if_needed(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
Expand Down
Loading
Loading