[PyTorch] Add op-level activation offload opt-out API #3108

Open
lhb8125 wants to merge 8 commits into NVIDIA:main from lhb8125:codex/te-op-offload-control

@lhb8125 lhb8125 commented Jun 9, 2026

Contributor

Summary

Follow-up to #3047.

This PR adds an op-level activation offload policy for saved activation tensors so downstream fused grouped MLP users can opt individual TE ops out of activation offloading without changing the surrounding CPU offload context.

  • add BasicOperation.set_activation_offloading(enabled)
  • route activation offload marking through BasicOperation.maybe_mark_activation_offload
  • use mark_not_offload for ops whose saved activations are opted out
  • keep global CPU offload context checks and start_offload calls at their original call sites
  • add unit coverage for the per-op mark/opt-out behavior
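
The bullets above can be pictured with a minimal, self-contained sketch. It is not the code from this PR: the two module-level helpers are stand-ins for the real `cpu_offload` functions, the class is a toy rather than TE's `BasicOperation`, and skipping `None` entries is an assumption about how the None-safety is achieved.

```python
# Toy stand-ins for cpu_offload.mark_activation_offload / mark_not_offload.
# They only record calls so the routing decision can be observed.
calls = []


def mark_activation_offload(*tensors):
    calls.append(("offload", tensors))


def mark_not_offload(*tensors):
    calls.append(("keep", tensors))


class BasicOperation:
    """Toy op that carries a per-op activation offload policy."""

    def __init__(self):
        # Offloading stays enabled unless the user opts this op out.
        self.activation_offloading = True

    def set_activation_offloading(self, enabled: bool) -> None:
        self.activation_offloading = bool(enabled)

    def maybe_mark_activation_offload(self, *tensors) -> None:
        # Assumption: None entries are skipped before marking.
        present = tuple(t for t in tensors if t is not None)
        if not present:
            return
        if self.activation_offloading:
            mark_activation_offload(*present)
        else:
            mark_not_offload(*present)
```

In this sketch a downstream caller would opt a single op out with `op.set_activation_offloading(False)` and leave the surrounding CPU offload context untouched; the global-context check and `start_offload` stay at the call site, as the summary describes.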

Testing

  • git diff --check
  • python3.12 -m py_compile on the changed TE op files and tests/pytorch/test_fusible_ops.py

Signed-off-by: hongbinl <hongbinl@nvidia.com>
@lhb8125 lhb8125 requested a review from timmoon10 as a code owner June 9, 2026 12:27
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 9, 2026
greptile-apps Bot commented Jun 9, 2026

Contributor

Greptile Summary

This PR adds a per-op activation offload opt-out API to BasicOperation, allowing downstream fused grouped MLP users to selectively exclude individual ops from CPU activation offloading without changing the surrounding context. A new set_activation_offloading(enabled) setter and maybe_mark_activation_offload(*tensors) helper are added to BasicOperation; all 15 changed files replace the direct mark_activation_offload(...) calls with self.maybe_mark_activation_offload(...) or <op>.maybe_mark_activation_offload(...).

  • op.py adds activation_offloading: bool = True state and the two new methods; when opted out, mark_not_offload is called instead of the (V2-no-op) mark_activation_offload, correctly setting _TE_do_not_offload on component tensors before they reach the fuser's save_for_backward hook.
  • forward_grouped_mlp.py replaces the single bulk mark_activation_offload(*activation_tensors) call with three per-op calls, letting each of fc1_op, activation_op, and fc2_op independently opt their saved tensor out of offloading.
  • The new test correctly patches transformer_engine.pytorch.ops.op.mark_activation_offload and mark_not_offload (the names imported into op.py's namespace), so the mock intercepts all calls from maybe_mark_activation_offload.
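
The point about patching the names bound in `op.py`'s namespace is a general property of `from ... import ...` and can be shown without TE. The sketch below builds two throwaway modules (`demo_cpu_offload` and `demo_op`, both hypothetical) and compares patching the defining module with patching the importing one.

```python
import types
from unittest import mock

# Hypothetical module that defines the helper (stands in for cpu_offload).
lib = types.ModuleType("demo_cpu_offload")
lib.mark_activation_offload = lambda *tensors: "real"

# Hypothetical module that imports the helper by name (stands in for op.py).
op_module = types.ModuleType("demo_op")
# Same effect as `from demo_cpu_offload import mark_activation_offload`.
op_module.mark_activation_offload = lib.mark_activation_offload
# Define a function whose globals are op_module's namespace.
exec(
    "def maybe_mark(*tensors):\n"
    "    return mark_activation_offload(*tensors)\n",
    op_module.__dict__,
)

# Patching the defining module does not affect the already-bound name.
with mock.patch.object(lib, "mark_activation_offload") as wrong:
    result_wrong = op_module.maybe_mark("t")
    wrong_calls = wrong.call_count

# Patching the importing module's name intercepts the call.
with mock.patch.object(op_module, "mark_activation_offload") as right:
    op_module.maybe_mark("t")
    right_calls = right.call_count
```

Here `result_wrong` is still the real helper's return value and `wrong_calls` is zero, while `right_calls` is one.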

Confidence Score: 5/5

The change is safe to merge. All production call sites remain guarded by is_cpu_offload_enabled(), the V2 _TE_do_not_offload flag is set before tensors reach the fuser's save_for_backward hook, and the V1 path correctly receives offload=False through mark_not_offload.

The core logic in op.py is correct and None-safe. The fused-op wiring in forward_grouped_mlp.py correctly delegates per-op marking before tensors are collected by the fuser. The only finding is a minor inefficiency where opted-out tensors still receive a start_offload CUDA event that is immediately discarded; this does not affect correctness or memory safety.

transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py — the start_offload call could be made to exclude opted-out tensors, but the current behavior is correct.
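
For illustration only, the optional refinement mentioned above could look roughly like the following. This is not part of the PR, and the review thread later settles on leaving `start_offload()` unchanged; `select_for_start_offload` is a hypothetical helper, and the ops are plain namespace objects rather than TE ops.

```python
from types import SimpleNamespace


def select_for_start_offload(op_tensor_pairs):
    """Keep tensors whose owning op still has activation offloading enabled."""
    return [
        tensor
        for op, tensor in op_tensor_pairs
        if tensor is not None and getattr(op, "activation_offloading", True)
    ]


# Stand-in ops: fc1 is opted out, the other two keep the default.
fc1_op = SimpleNamespace(activation_offloading=False)
activation_op = SimpleNamespace(activation_offloading=True)
fc2_op = SimpleNamespace(activation_offloading=True)

pairs = [(fc1_op, "fc1_x"), (activation_op, "act_in"), (fc2_op, "fc2_x")]
selected = select_for_start_offload(pairs)
```

Only the tensors in `selected` would then be passed on to `start_offload`, so opted-out tensors would not get a CUDA event recorded for them.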

Important Files Changed

Filename | Overview
--- | ---
transformer_engine/pytorch/ops/op.py | Adds activation_offloading flag, set_activation_offloading(), and maybe_mark_activation_offload() to BasicOperation; logic is correct and None-safe
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Replaces bulk mark_activation_offload call with per-op maybe_mark_activation_offload; start_offload still receives all tensors including opted-out ones, recording unnecessary CUDA events
tests/pytorch/test_fusible_ops.py | New test correctly patches op_module.mark_activation_offload and op_module.mark_not_offload (the right namespace) and covers default, opted-out, and re-enabled states
transformer_engine/pytorch/ops/basic/grouped_linear.py | Replaces mark_activation_offload calls in _save_ctx with self.maybe_mark_activation_offload; consistent with other basic ops

Sequence Diagram

sequenceDiagram
    participant User
    participant BasicOperation
    participant FusedOp as FusedOperation
    participant Fuser as OperationFuser
    participant CPUOffload as cpu_offload

    User->>BasicOperation: set_activation_offloading(False)
    Note over BasicOperation: activation_offloading = False

    rect rgb(240, 248, 255)
        Note over FusedOp: Forward pass
        FusedOp->>CPUOffload: start_offload(fc1_x, act_in, fc2_x)
        CPUOffload-->>FusedOp: records CUDA events on tensors
        FusedOp->>BasicOperation: fc1_op.maybe_mark_activation_offload(fc1_x)
        alt "activation_offloading == True"
            BasicOperation->>CPUOffload: mark_activation_offload(fc1_x)
        else "activation_offloading == False"
            BasicOperation->>CPUOffload: mark_not_offload(fc1_x)
            CPUOffload-->>BasicOperation: "sets _TE_do_not_offload=True on components"
        end
        FusedOp->>Fuser: "ctx.to_save = tensors"
    end

    rect rgb(255, 248, 240)
        Note over Fuser: Fuser collects and saves
        Fuser->>CPUOffload: "prepare_for_saving(*to_save)"
        CPUOffload-->>Fuser: decomposed component tensors
        Fuser->>Fuser: "func_ctx.save_for_backward(*components)"
        Fuser->>CPUOffload: push_tensor(fc1_component)
        CPUOffload-->>Fuser: "_TE_do_not_offload=True, return tensor (not offloaded)"
        Fuser->>CPUOffload: push_tensor(act_component)
        CPUOffload-->>Fuser: offloaded, return index
    end
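
The last part of the diagram, where a flagged component is returned as-is and an unflagged one is offloaded, can be sketched as below. Only the attribute name `_TE_do_not_offload` and the names `mark_not_offload` and `push_tensor` come from the summary; the classes and their behavior are simplified stand-ins, not the actual `cpu_offload` implementation.

```python
class FakeTensor:
    """Minimal stand-in for a component tensor."""

    def __init__(self, name):
        self.name = name


def mark_not_offload(*tensors):
    # Flag each component so the offload handler skips it.
    for tensor in tensors:
        tensor._TE_do_not_offload = True


class ToyOffloadHandler:
    """Hypothetical handler mirroring the push_tensor step in the diagram."""

    def __init__(self):
        self.offloaded = []

    def push_tensor(self, tensor):
        if getattr(tensor, "_TE_do_not_offload", False):
            return tensor  # kept on device, returned unchanged
        self.offloaded.append(tensor)
        return len(self.offloaded) - 1  # index of the offloaded copy


fc1_component = FakeTensor("fc1_x")
act_component = FakeTensor("act_in")
mark_not_offload(fc1_component)

handler = ToyOffloadHandler()
kept = handler.push_tensor(fc1_component)
index = handler.push_tensor(act_component)
```

The flag has to be set before the fuser's save hook runs, which is why the per-op marking happens in the forward pass ahead of `ctx.to_save`.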

Reviews (5): Last reviewed commit: "Patch activation offload test bound symb..."

Comment thread transformer_engine/pytorch/ops/basic/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/ops/op.py Outdated
Signed-off-by: hongbinl <hongbinl@nvidia.com>
@lhb8125 lhb8125 changed the title [PyTorch] Add op-level CPU offload opt-out API [PyTorch] Add op-level activation offload opt-out API Jun 9, 2026
lhb8125 added 3 commits June 9, 2026 05:53
Signed-off-by: hongbinl <hongbinl@nvidia.com>
Signed-off-by: hongbinl <hongbinl@nvidia.com>
Signed-off-by: hongbinl <hongbinl@nvidia.com>
Comment thread transformer_engine/pytorch/ops/op.py Outdated
Comment on lines +218 to +222
from ..cpu_offload import (  # pylint: disable=import-outside-toplevel
    mark_activation_offload,
    mark_not_offload,
    start_offload,
)

Member

Why can't you import it once (e.g. at the top of this file)? There is a non-zero CPU overhead from importing the already-imported module.

Contributor Author

Nice catch!

Comment thread transformer_engine/pytorch/ops/op.py Outdated
Comment on lines +234 to +235
if mark:
    mark_activation_offload(*tensors)

Member

This would potentially mark the tensors multiple times, since all call sites just leave the default value here. Why combine these two behaviors in one function rather than having two separate functions?

Contributor Author

Good advice! We should leave start_offload() as it is.

@ptrendx ptrendx left a comment

Member

A general question about the motivation for this feature: I believe we already have a heuristic that skips offloading tensors that are too small. Is that not enough?

CC @pggPL

Signed-off-by: hongbinl <hongbinl@nvidia.com>
lhb8125 commented Jun 9, 2026

Contributor Author

A general question about the motivation for this feature: I believe we already have a heuristic that skips offloading tensors that are too small. Is that not enough?

CC @pggPL

This is to support selective offloading for Nemotron model training. With the fused MLP, MCore doesn't know which tensor comes from expert_fc1, moe_act, or expert_fc2, so we need to expose an API to set the offloading strategy manually for each op. cc @timmoon10

Signed-off-by: hongbinl <hongbinl@nvidia.com>
Comment thread tests/pytorch/test_fusible_ops.py
Signed-off-by: hongbinl <hongbinl@nvidia.com>
ptrendx commented Jun 9, 2026

Member

Ok, but does it actually matter to you that a specific tensor gets offloaded rather than getting the right amount of data to be offloaded?

@xrennvidia

Collaborator

Ok, but does it actually matter to you that a specific tensor gets offloaded rather than getting the right amount of data to be offloaded?

Different activation tensors hold different amounts of data, and selectively offloading activations is how we control the amount of data offloaded. If we offload too much, offloading latency can be exposed; if we offload too little, we get OOM. I think the essence of fine-grained offloading is to let users control the offloading of each tensor separately, so that we can make better perf tradeoffs.

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

3 participants