Skip to content
Open
40 changes: 40 additions & 0 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,46 @@
# Supported devices
# NOTE(review): the CUDA device is listed unconditionally; presumably every
# test that consumes this list expects a GPU to be present — confirm.
_devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")]


def test_basic_operation_activation_offloading_policy(monkeypatch):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Codex put this in such a weird and unmotivated place. No taste at all.

"""BasicOperation should expose a public opt-out for saved activation CPU offload."""
import transformer_engine.pytorch.cpu_offload as cpu_offload

# Each patched hook appends a (tag, [tensor ids]) entry here, so the order
# and arguments of the hook calls can be asserted exactly.
calls = []
tensor = torch.empty(1)
tensor_id = id(tensor)
op = te_ops.Identity()

# Swap the three cpu_offload hooks for recorders. Tags: "start" for
# start_offload, "mark" for mark_activation_offload, "skip" for
# mark_not_offload. Object ids are recorded rather than the tensors themselves.
monkeypatch.setattr(
cpu_offload,
"start_offload",
lambda *tensors: calls.append(("start", [id(t) for t in tensors])),
)
monkeypatch.setattr(
cpu_offload,
"mark_activation_offload",
lambda *tensors: calls.append(("mark", [id(t) for t in tensors])),
)
monkeypatch.setattr(
cpu_offload,
"mark_not_offload",
lambda *tensors: calls.append(("skip", [id(t) for t in tensors])),
)
Comment thread
lhb8125 marked this conversation as resolved.
Outdated

# Case 1: policy left at its default, start=True. Expect the start hook and
# then the mark hook, each receiving only the real tensor; the None argument
# must not show up in the recorded ids.
op.maybe_mark_and_start_activation_offload(tensor, None, start=True)
assert calls == [("start", [tensor_id]), ("mark", [tensor_id])]

# Case 2: after opting out via set_activation_offloading(False), the only
# hook expected to be reached is mark_not_offload, even with start=True.
calls.clear()
op.set_activation_offloading(False)
op.maybe_mark_and_start_activation_offload(tensor, start=True)
assert calls == [("skip", [tensor_id])]

# Case 3: offloading re-enabled, but mark=False. Expect the start hook only.
calls.clear()
op.set_activation_offloading(True)
op.maybe_mark_and_start_activation_offload(tensor, start=True, mark=False)
assert calls == [("start", [tensor_id])]


# Supported quantization recipes
_quantization_list: list[Optional[str]] = [None]
if fp8_available:
Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/ops/basic/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import transformer_engine_torch as tex
from ...constants import DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
Expand Down Expand Up @@ -115,7 +115,7 @@ def op_forward(
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.maybe_mark_and_start_activation_offload(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -415,7 +415,7 @@ def fuser_forward(
ctx = basic_op_ctxs[0]
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.maybe_mark_and_start_activation_offload(x)
ctx.input_requires_grad = True
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.dtype = dtype
Expand Down
10 changes: 5 additions & 5 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch

from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...distributed import (
CudaRNGStatesTracker,
gather_along_first_dim,
Expand Down Expand Up @@ -1049,11 +1049,11 @@ def op_forward(
else:
saved_input = x_local
saved_weight = w
# No special CPU offloading logic is needed for weights. saved_weight is
# either self.weight (nn.Parameter, auto-excluded from offload) or a
# workspace freshly created each forward pass.
if is_cpu_offload_enabled():
# No special CPU offloading logic is needed for weights. saved_weight is
# either self.weight (nn.Parameter, auto-excluded from offload) or a
# workspace freshly created each forward pass.
mark_activation_offload(saved_input)
self.maybe_mark_and_start_activation_offload(saved_input)
ctx.save_for_backward(saved_input, saved_weight)
ctx.with_quantized_compute = with_quantized_compute and backward_override is None
ctx.backward_override = backward_override
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch
import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...tensor import Quantizer
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_autocast_dtype, maybe_dequantize
Expand Down Expand Up @@ -72,7 +72,7 @@ def op_forward(
# Save context for backward
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(mask)
self.maybe_mark_and_start_activation_offload(mask)
ctx.save_for_backward(mask)
ctx.impl = impl
ctx.dropout_probability = self.dropout_probability
Expand Down
10 changes: 5 additions & 5 deletions transformer_engine/pytorch/ops/basic/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import transformer_engine_torch as tex
from ...constants import DType
from ...cpp_extensions import general_grouped_gemm, general_grouped_gemm_for_grouped_tensor
from ...cpu_offload import is_cpu_offload_enabled
from ...distributed import CudaRNGStatesTracker
from ...module._common import WeightGradStore
from ...module.base import (
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe
from ...quantized_tensor import QuantizedTensorStorage
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
Expand Down Expand Up @@ -1039,12 +1039,12 @@ def fuser_forward_save_ctx(
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
grouped_x = saved[offset]
if grouped_x is not None:
mark_activation_offload(grouped_x)
self.maybe_mark_and_start_activation_offload(grouped_x)
else:
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None]
if live_xs:
mark_activation_offload(*live_xs)
self.maybe_mark_and_start_activation_offload(*live_xs)

ctx.save_for_backward(*tensors_to_save[0])

Expand Down Expand Up @@ -1133,7 +1133,7 @@ def _fuser_forward_split_quantize(
if is_cpu_offload_enabled():
live_xs = [t for t in xs if t is not None]
if live_xs:
start_offload(*live_xs)
self.maybe_mark_and_start_activation_offload(*live_xs, start=True, mark=False)

# Allocate output tensor
in_shape = list(input_.size())
Expand Down Expand Up @@ -1240,7 +1240,7 @@ def _fuser_forward_grouped_tensor(
)

if is_cpu_offload_enabled() and grouped_x is not None:
start_offload(grouped_x)
self.maybe_mark_and_start_activation_offload(grouped_x, start=True, mark=False)

# Build the weight GroupedTensor / list.
if self.single_grouped_weight:
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/l2normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled
from ...torch_version import torch_version
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...jit import (
l2normalization_fused,
l2normalization_fwd_fused,
Expand Down Expand Up @@ -103,7 +103,7 @@ def op_forward(
# Save state for backward pass
if requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rsqrt_norm)
self.maybe_mark_and_start_activation_offload(x, rsqrt_norm)
ctx.save_for_backward(x, rsqrt_norm)

return y
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
Expand Down Expand Up @@ -217,7 +217,7 @@ def op_forward(
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, means, rstdevs)
self.maybe_mark_and_start_activation_offload(x, means, rstdevs)
ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype

Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/basic/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
Expand Down Expand Up @@ -198,7 +198,7 @@ def op_forward(
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rstdevs)
self.maybe_mark_and_start_activation_offload(x, rstdevs)
ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype

Expand Down
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/ops/basic/swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import torch

import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled
from ...constants import DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
Expand Down Expand Up @@ -128,7 +128,7 @@ def op_forward(
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
self.maybe_mark_and_start_activation_offload(input_)
ctx.save_for_backward(input_)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -312,7 +312,7 @@ def op_forward(
# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.maybe_mark_and_start_activation_offload(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -463,7 +463,7 @@ def fuser_forward(
ctx = basic_op_ctxs[0]
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
self.maybe_mark_and_start_activation_offload(input_)
ctx.input_requires_grad = True
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.dtype = dtype
Expand Down
13 changes: 5 additions & 8 deletions transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import torch

import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor
from ...cpu_offload import is_cpu_offload_enabled
from ...quantization import Recipe
from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer
from ...utils import (
Expand Down Expand Up @@ -170,7 +170,7 @@ def fuser_forward(
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
# Get basic operations
fc1_op, _, fc2_op = self.basic_ops
fc1_op, activation_op, fc2_op = self.basic_ops
fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs

# Tensor properties
Expand Down Expand Up @@ -726,7 +726,6 @@ def fuser_forward(
# Save state for backward pass
if requires_grad:
mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x)
activation_op = self.basic_ops[1]
cpu_offloading = is_cpu_offload_enabled()
activation_is_srelu = isinstance(activation_op, ScaledSReLU)
activation_recompute_in_mlp = bool(
Expand All @@ -750,11 +749,9 @@ def fuser_forward(
grouped_fc_x.scale_inv = None

if cpu_offloading:
activation_tensors = [
t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None
]
start_offload(*activation_tensors)
mark_activation_offload(*activation_tensors)
fc1_op.maybe_mark_and_start_activation_offload(grouped_fc1_x, start=True)
activation_op.maybe_mark_and_start_activation_offload(activation_in, start=True)
fc2_op.maybe_mark_and_start_activation_offload(saved_grouped_fc2_x, start=True)

# FC1 saved-tensor layout.
# [split_sizes, base_split_offsets, split_points,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import BasicLinear, Bias
Expand Down Expand Up @@ -130,7 +130,7 @@ def fuser_forward(
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.maybe_mark_and_start_activation_offload(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, Bias
Expand Down Expand Up @@ -127,7 +127,7 @@ def fuser_forward(
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.maybe_mark_and_start_activation_offload(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, ConstantScale
Expand Down Expand Up @@ -108,7 +108,7 @@ def fuser_forward(
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.maybe_mark_and_start_activation_offload(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...cpu_offload import is_cpu_offload_enabled
from ...distributed import get_distributed_world_size
from ...quantization import FP8GlobalStateManager
from ...module.base import (
Expand Down Expand Up @@ -355,7 +355,7 @@ def fuser_forward(
# Save state for backward pass
if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op.maybe_mark_and_start_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
Expand Down
Loading
Loading