Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,48 @@
# Supported devices
_devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")]


def test_basic_operation_cpu_offloading_control(monkeypatch):
    """BasicOperation should expose a public opt-out for activation CPU offload.

    The ``cpu_offload`` module functions are replaced with recorders so the
    test can observe which of them ``maybe_mark_and_start_activation_offload``
    invokes, and with which tensors, in three configurations:

    1. Offload globally enabled, op at its default setting: the tensor is
       started and then marked for offload. A ``None`` argument passed
       alongside it is expected to be absent from the recorded calls.
    2. Offload globally enabled, op opted out via ``disable_cpu_offloading``:
       the tensor is only reported through ``mark_not_offload``.
    3. Op opted back in via ``enable_cpu_offloading`` but offload globally
       disabled: none of the recorded functions is called.
    """
    import transformer_engine.pytorch.cpu_offload as cpu_offload

    calls = []
    tensor = torch.empty(1)
    tensor_id = id(tensor)
    op = te_ops.Identity()

    def make_recorder(tag):
        """Return a stand-in that logs ``(tag, [id(t), ...])`` for each call."""

        def record(*tensors):
            # Record object identities rather than the tensors themselves so
            # the assertions compare plain ints.
            calls.append((tag, [id(t) for t in tensors]))

        return record

    monkeypatch.setattr(cpu_offload, "is_cpu_offload_enabled", lambda: True)
    for func_name, tag in (
        ("start_offload", "start"),
        ("mark_activation_offload", "mark"),
        ("mark_not_offload", "skip"),
    ):
        monkeypatch.setattr(cpu_offload, func_name, make_recorder(tag))

    # Case 1: offload enabled globally and on the op.
    op.maybe_mark_and_start_activation_offload(tensor, None, start=True)
    assert calls == [("start", [tensor_id]), ("mark", [tensor_id])]

    # Case 2: per-op opt-out while offload stays enabled globally.
    calls.clear()
    op.disable_cpu_offloading()
    op.maybe_mark_and_start_activation_offload(tensor, start=True)
    assert calls == [("skip", [tensor_id])]

    # Case 3: op re-enabled, but offload disabled globally.
    calls.clear()
    op.enable_cpu_offloading()
    monkeypatch.setattr(cpu_offload, "is_cpu_offload_enabled", lambda: False)
    op.maybe_mark_and_start_activation_offload(tensor, start=True)
    assert calls == []


# Supported quantization recipes
_quantization_list: list[Optional[str]] = [None]
if fp8_available:
Expand Down
7 changes: 2 additions & 5 deletions transformer_engine/pytorch/ops/basic/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import transformer_engine_torch as tex
from ...constants import DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
Expand Down Expand Up @@ -114,8 +113,7 @@ def op_forward(

# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.maybe_mark_and_start_activation_offload(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -414,8 +412,7 @@ def fuser_forward(

ctx = basic_op_ctxs[0]
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.maybe_mark_and_start_activation_offload(x)
ctx.input_requires_grad = True
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.dtype = dtype
Expand Down
10 changes: 4 additions & 6 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torch

from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import (
CudaRNGStatesTracker,
gather_along_first_dim,
Expand Down Expand Up @@ -1049,11 +1048,10 @@ def op_forward(
else:
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
# No special CPU offloading logic is needed for weights. saved_weight is
# either self.weight (nn.Parameter, auto-excluded from offload) or a
# workspace freshly created each forward pass.
mark_activation_offload(saved_input)
# No special CPU offloading logic is needed for weights. saved_weight is
# either self.weight (nn.Parameter, auto-excluded from offload) or a
# workspace freshly created each forward pass.
self.maybe_mark_and_start_activation_offload(saved_input)
ctx.save_for_backward(saved_input, saved_weight)
ctx.with_quantized_compute = with_quantized_compute and backward_override is None
ctx.backward_override = backward_override
Expand Down
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/ops/basic/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import torch
import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Quantizer
from ...tensor.storage.float8_tensor_storage import Float8TensorStorage
from .._common import maybe_autocast_dtype, maybe_dequantize
Expand Down Expand Up @@ -71,8 +70,7 @@ def op_forward(

# Save context for backward
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(mask)
self.maybe_mark_and_start_activation_offload(mask)
ctx.save_for_backward(mask)
ctx.impl = impl
ctx.dropout_probability = self.dropout_probability
Expand Down
33 changes: 13 additions & 20 deletions transformer_engine/pytorch/ops/basic/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe
from ...quantized_tensor import QuantizedTensorStorage
from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer
Expand Down Expand Up @@ -1032,19 +1031,16 @@ def fuser_forward_save_ctx(
# Note: No special logic is needed for weights. They are
# either nn.Parameter (auto-excluded from offload) or are
# temporary workspaces freshly created in each forward pass.
if is_cpu_offload_enabled():
saved = tensors_to_save[0]
offset = 4 if self._scale_bias else 3
if use_grouped_tensor_path:
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
grouped_x = saved[offset]
if grouped_x is not None:
mark_activation_offload(grouped_x)
else:
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None]
if live_xs:
mark_activation_offload(*live_xs)
saved = tensors_to_save[0]
offset = 4 if self._scale_bias else 3
if use_grouped_tensor_path:
# Layout: [split_sizes, base_split_offsets, split_points, (scales?), grouped_x, *weights]
grouped_x = saved[offset]
self.maybe_mark_and_start_activation_offload(grouped_x)
else:
# Layout: [split_sizes, None, None, (scales?), *xs, *ws]
live_xs = [t for t in saved[offset : offset + self.num_groups] if t is not None]
self.maybe_mark_and_start_activation_offload(*live_xs)
Comment thread
lhb8125 marked this conversation as resolved.
Outdated

ctx.save_for_backward(*tensors_to_save[0])

Expand Down Expand Up @@ -1130,10 +1126,8 @@ def _fuser_forward_split_quantize(
xs = tex.split_quantize(x, split_sizes_int, input_quantizers)
else:
xs = torch.split(x, split_sizes_int)
if is_cpu_offload_enabled():
live_xs = [t for t in xs if t is not None]
if live_xs:
start_offload(*live_xs)
live_xs = [t for t in xs if t is not None]
self.maybe_mark_and_start_activation_offload(*live_xs, start=True)

# Allocate output tensor
in_shape = list(input_.size())
Expand Down Expand Up @@ -1239,8 +1233,7 @@ def _fuser_forward_grouped_tensor(
tensor_offsets=base_split_offsets * self.in_features,
)

if is_cpu_offload_enabled() and grouped_x is not None:
start_offload(grouped_x)
self.maybe_mark_and_start_activation_offload(grouped_x, start=True)

# Build the weight GroupedTensor / list.
if self.single_grouped_weight:
Expand Down
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/ops/basic/l2normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch

from ...torch_version import torch_version
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...jit import (
l2normalization_fused,
l2normalization_fwd_fused,
Expand Down Expand Up @@ -102,8 +101,7 @@ def op_forward(

# Save state for backward pass
if requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rsqrt_norm)
self.maybe_mark_and_start_activation_offload(x, rsqrt_norm)
ctx.save_for_backward(x, rsqrt_norm)

return y
Expand Down
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/ops/basic/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
Expand Down Expand Up @@ -216,8 +215,7 @@ def op_forward(

# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, means, rstdevs)
self.maybe_mark_and_start_activation_offload(x, means, rstdevs)
ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype

Expand Down
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/ops/basic/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...constants import TE_DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...export import is_in_onnx_export_mode
from ...tensor import Quantizer
from ...utils import (
Expand Down Expand Up @@ -197,8 +196,7 @@ def op_forward(

# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x, rstdevs)
self.maybe_mark_and_start_activation_offload(x, rstdevs)
ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype

Expand Down
10 changes: 3 additions & 7 deletions transformer_engine/pytorch/ops/basic/swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import transformer_engine_torch as tex
from ...constants import DType
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
Expand Down Expand Up @@ -127,8 +126,7 @@ def op_forward(

# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
self.maybe_mark_and_start_activation_offload(input_)
ctx.save_for_backward(input_)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -311,8 +309,7 @@ def op_forward(

# Save state for backward pass
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x)
self.maybe_mark_and_start_activation_offload(x)
ctx.save_for_backward(x)
ctx.dtype = dtype
ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer
Expand Down Expand Up @@ -462,8 +459,7 @@ def fuser_forward(
# Save state for backward pass
ctx = basic_op_ctxs[0]
if ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(input_)
self.maybe_mark_and_start_activation_offload(input_)
ctx.input_requires_grad = True
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.dtype = dtype
Expand Down
14 changes: 4 additions & 10 deletions transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torch

import transformer_engine_torch as tex
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload, start_offload
from ...cpp_extensions import general_gemm, general_grouped_gemm_for_grouped_tensor
from ...quantization import Recipe
from ...tensor import NVFP4Quantizer, NVFP4Tensor, Quantizer
Expand Down Expand Up @@ -170,7 +169,7 @@ def fuser_forward(
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
# Get basic operations
fc1_op, _, fc2_op = self.basic_ops
fc1_op, activation_op, fc2_op = self.basic_ops
fc1_ctx, activation_ctx, fc2_ctx = basic_op_ctxs

# Tensor properties
Expand Down Expand Up @@ -726,8 +725,6 @@ def fuser_forward(
# Save state for backward pass
if requires_grad:
mark_grouped_tensor(grouped_fc1_x, activation_in, scales, grouped_fc2_x)
activation_op = self.basic_ops[1]
cpu_offloading = is_cpu_offload_enabled()
activation_is_srelu = isinstance(activation_op, ScaledSReLU)
activation_recompute_in_mlp = bool(
getattr(activation_op, "activation_recompute_in_mlp", False)
Expand All @@ -749,12 +746,9 @@ def fuser_forward(
grouped_fc_x.rowwise_data = None
grouped_fc_x.scale_inv = None

if cpu_offloading:
activation_tensors = [
t for t in (grouped_fc1_x, activation_in, saved_grouped_fc2_x) if t is not None
]
start_offload(*activation_tensors)
mark_activation_offload(*activation_tensors)
fc1_op.maybe_mark_and_start_activation_offload(grouped_fc1_x, start=True)
activation_op.maybe_mark_and_start_activation_offload(activation_in, start=True)
fc2_op.maybe_mark_and_start_activation_offload(saved_grouped_fc2_x, start=True)

# FC1 saved-tensor layout.
# [split_sizes, base_split_offsets, split_points,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import BasicLinear, Bias
Expand Down Expand Up @@ -129,8 +128,7 @@ def fuser_forward(
else:
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.maybe_mark_and_start_activation_offload(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, Bias
Expand Down Expand Up @@ -126,8 +125,7 @@ def fuser_forward(
else:
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.maybe_mark_and_start_activation_offload(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import torch

from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...quantization import FP8GlobalStateManager
from ...tensor import Quantizer
from ..basic import AddExtraInput, BasicLinear, ConstantScale
Expand Down Expand Up @@ -107,8 +106,7 @@ def fuser_forward(
else:
saved_input = x_local
saved_weight = w
if is_cpu_offload_enabled():
mark_activation_offload(saved_input)
linear_op.maybe_mark_and_start_activation_offload(saved_input)
linear_op_ctx.save_for_backward(saved_input, saved_weight)
linear_op_ctx.with_quantized_compute = (
with_quantized_compute and backward_override is None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ...distributed import get_distributed_world_size
from ...quantization import FP8GlobalStateManager
from ...module.base import (
Expand Down Expand Up @@ -354,8 +353,7 @@ def fuser_forward(

# Save state for backward pass
if linear_op_ctx.requires_grad:
if is_cpu_offload_enabled():
mark_activation_offload(x_local)
linear_op.maybe_mark_and_start_activation_offload(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
Expand Down
Loading
Loading