Skip to content

Commit 0b41302

Browse files
committed
Use setter for activation offload policy
Signed-off-by: hongbinl <hongbinl@nvidia.com>
1 parent b93a9af commit 0b41302

2 files changed

Lines changed: 6 additions & 10 deletions

File tree

tests/pytorch/test_fusible_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,12 @@ def test_basic_operation_activation_offloading_policy(monkeypatch):
101101
assert calls == [("start", [tensor_id]), ("mark", [tensor_id])]
102102

103103
calls.clear()
104-
op.disable_activation_offloading()
104+
op.set_activation_offloading(False)
105105
op.maybe_mark_and_start_activation_offload(tensor, start=True)
106106
assert calls == [("skip", [tensor_id])]
107107

108108
calls.clear()
109-
op.enable_activation_offloading()
109+
op.set_activation_offloading(True)
110110
op.maybe_mark_and_start_activation_offload(tensor, start=True, mark=False)
111111
assert calls == [("start", [tensor_id])]
112112

transformer_engine/pytorch/ops/op.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,17 +195,13 @@ def __init__(self) -> None:
195195
def is_fused_op(self) -> bool:
196196
return False
197197

198-
def disable_activation_offloading(self, disabled: bool = True) -> None:
199-
"""Disable activation CPU offloading for tensors saved by this op.
198+
def set_activation_offloading(self, enabled: bool) -> None:
199+
"""Enable or disable activation CPU offloading for tensors saved by this op.
200200
201201
CPU offloading is controlled by the surrounding offload context. This setting only
202-
opts this operation's saved activation tensors out of that context.
202+
opts this operation's saved activation tensors in or out of that context.
203203
"""
204-
self.activation_offloading = not disabled
205-
206-
def enable_activation_offloading(self) -> None:
207-
"""Re-enable activation CPU offloading for tensors saved by this op."""
208-
self.disable_activation_offloading(False)
204+
self.activation_offloading = enabled
209205

210206
def maybe_mark_and_start_activation_offload(
211207
self,

0 commit comments

Comments
 (0)