Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,41 @@
# Supported devices
_devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")]


def test_basic_operation_activation_offloading_policy(monkeypatch):
"""BasicOperation should expose a public opt-out for saved activation CPU offload."""
import transformer_engine.pytorch.cpu_offload as cpu_offload

calls = []
tensor = torch.empty(1)
tensor_id = id(tensor)
op = te_ops.Identity()

monkeypatch.setattr(
cpu_offload,
"mark_activation_offload",
lambda *tensors: calls.append(("mark", [id(t) for t in tensors])),
)
monkeypatch.setattr(
cpu_offload,
"mark_not_offload",
lambda *tensors: calls.append(("skip", [id(t) for t in tensors])),
)
Comment thread
lhb8125 marked this conversation as resolved.

op.maybe_mark_activation_offload(tensor, None)
assert calls == [("mark", [tensor_id])]

calls.clear()
op.set_activation_offloading(False)
op.maybe_mark_activation_offload(tensor)
assert calls == [("skip", [tensor_id])]

calls.clear()
op.set_activation_offloading(True)
op.maybe_mark_activation_offload(tensor)
assert calls == [("mark", [tensor_id])]


# Supported quantization recipes
_quantization_list: list[Optional[str]] = [None]
if fp8_available:
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/ops/basic/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import transformer_engine_torch as tex
from ...constants import DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
Expand Down Expand Up @@ -115,7 +115,7 @@ def op_forward(
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.maybe_mark_activation_offload(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -415,7 +415,7 @@ def fuser_forward(
ctx = basic_op_ctxs[0]
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.maybe_mark_activation_offload(x)
ctx.input_requires_grad = True
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.dtype = dtype
Expand Down
10 changes: 5 additions & 5 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch

from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...distributed import (
CudaRNGStatesTracker,
gather_along_first_dim,
Expand Down Expand Up @@ -1049,11 +1049,11 @@ def op_forward(
else:
saved_input = x_local
saved_weight = w
# No special CPU offloading logic is needed for weights. saved_weight is
# either self.weight (nn.Parameter, auto-excluded from offload) or a
# workspace freshly created each forward pass.
if is_cpu_offload_enabled():
# No special CPU offloading logic is needed for weights. saved_weight is
# either self.weight (nn.Parameter, auto-excluded from offload) or a
# workspace freshly created each forward pass.
mark_activation_offload(saved_input)
self.maybe_mark_activation_offload(saved_input)
ctx.save_for_backward(saved_input, saved_weight)
ctx.with_quantized_compute = with_quantized_compute and backward_override is None
ctx.backward_override = backward_override
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch
import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...tensor import Quantizer
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_autocast_dtype, maybe_dequantize
Expand Down Expand Up @@ -72,7 +72,7 @@ def op_forward(
# Save context for backward
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(mask)
self.maybe_mark_activation_offload(mask)
ctx.save_for_backward(mask)
ctx.impl = impl
ctx.dropout_probability = self.dropout_probability
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/ops/basic/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...cpu_offload import is_cpu_offload_enabled, start_offload
from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe
from ...quantized_tensor import QuantizedTensorStorage
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
Expand Down Expand Up @@ -1039,12 +1039,12 @@ def fuser_forward_save_ctx(
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
grouped_x = saved[offset]
if grouped_x is not None:
mark_activation_offload(grouped_x)
self.maybe_mark_activation_offload(grouped_x)
else:
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None]
if live_xs:
mark_activation_offload(*live_xs)
self.maybe_mark_activation_offload(*live_xs)

ctx.save_for_backward(*tensors_to_save[0])

Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/l2normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled
from ...torch_version import torch_version
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...jit import (
l2normalization_fused,
l2normalization_fwd_fused,
Expand Down Expand Up @@ -103,7 +103,7 @@ def op_forward(
# Save state for backward pass
if requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rsqrt_norm)
self.maybe_mark_activation_offload(x, rsqrt_norm)
ctx.save_for_backward(x, rsqrt_norm)

return y
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
Expand Down Expand Up @@ -217,7 +217,7 @@ def op_forward(
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, means, rstdevs)
self.maybe_mark_activation_offload(x, means, rstdevs)
ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype

Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
Expand Down Expand Up @@ -198,7 +198,7 @@ def op_forward(
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rstdevs)
self.maybe_mark_activation_offload(x, rstdevs)
ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype

Expand Down
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/ops/basic/swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import torch

import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled
from ...constants import DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
Expand Down Expand Up @@ -128,7 +128,7 @@ def op_forward(
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
self.maybe_mark_activation_offload(input_)
ctx.save_for_backward(input_)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -312,7 +312,7 @@ def op_forward(
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.maybe_mark_activation_offload(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -463,7 +463,7 @@ def fuser_forward(
ctx = basic_op_ctxs[0]
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
self.maybe_mark_activation_offload(input_)
ctx.input_requires_grad = True
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.dtype = dtype
Expand Down
9 changes: 5 additions & 4 deletions transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch

import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...cpu_offload import is_cpu_offload_enabled, start_offload
from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor
from ...quantization import Recipe
from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer
Expand Down Expand Up @@ -170,7 +170,7 @@ def fuser_forward(
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
# Get basic operations
fc1_op, _, fc2_op = self.basic_ops
fc1_op, activation_op, fc2_op = self.basic_ops
fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs

# Tensor properties
Expand Down Expand Up @@ -726,7 +726,6 @@ def fuser_forward(
# Save state for backward pass
if requires_grad:
mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x)
activation_op = self.basic_ops[1]
cpu_offloading = is_cpu_offload_enabled()
activation_is_srelu = isinstance(activation_op, ScaledSReLU)
activation_recompute_in_mlp = bool(
Expand Down Expand Up @@ -754,7 +753,9 @@ def fuser_forward(
t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None
]
start_offload(*activation_tensors)
mark_activation_offload(*activation_tensors)
fc1_op.maybe_mark_activation_offload(grouped_fc1_x)
activation_op.maybe_mark_activation_offload(activation_in)
fc2_op.maybe_mark_activation_offload(saved_grouped_fc2_x)

# FC1 saved-tensor layout.
# [split_sizes, base_split_offsets, split_points,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import BasicLinear, Bias
Expand Down Expand Up @@ -130,7 +130,7 @@ def fuser_forward(
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.maybe_mark_activation_offload(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, Bias
Expand Down Expand Up @@ -127,7 +127,7 @@ def fuser_forward(
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.maybe_mark_activation_offload(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, ConstantScale
Expand Down Expand Up @@ -108,7 +108,7 @@ def fuser_forward(
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.maybe_mark_activation_offload(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...distributed import get_distributed_world_size
from ...quantization import FP8GlobalStateManager
from ...module.base import (
Expand Down Expand Up @@ -355,7 +355,7 @@ def fuser_forward(
# Save state for backward pass
if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op.maybe_mark_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
Expand Down
26 changes: 26 additions & 0 deletions transformer_engine/pytorch/ops/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch

from transformer_engine.common.recipe import Recipe
from ..cpu_offload import mark_activation_offload, mark_not_offload
from ..quantization import (
FP8GlobalStateManager,
QuantizerRole,
Expand Down Expand Up @@ -189,11 +190,36 @@ def __init__(self) -> None:
# Objects for quantization
self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
self._quantizers: Optional[dict[str, list[Quantizer]]] = None
self.activation_offloading: bool = True

@property
def is_fused_op(self) -> bool:
return False

def set_activation_offloading(self, enabled: bool) -> None:
"""Enable or disable activation CPU offloading for tensors saved by this op.

CPU offloading is controlled by the surrounding offload context. This setting only
opts this operation's saved activation tensors in or out of that context.
"""
self.activation_offloading = enabled

def maybe_mark_activation_offload(self, *tensors: Any) -> None:
"""Mark saved activation tensors for CPU offloading in an active offload context.

If activation offloading has been disabled for this op, mark the tensors so the
active offload context skips them.
"""
tensors = tuple(tensor for tensor in tensors if tensor is not None)
if not tensors:
return

if not self.activation_offloading:
mark_not_offload(*tensors)
return

mark_activation_offload(*tensors)

def num_quantizers(
self,
mode: str, # pylint: disable=unused-argument
Expand Down
Loading