[PyTorch] Support cudagraph recomputation #2518
Signed-off-by: Robin Zhang <robinz@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR enables cudagraph recomputation support by making two key changes: replacing
The changes align with PyTorch's cudagraph requirements where
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant make_graphed_callables
participant _CheckpointFunction
participant _none_grad_context_wrapper
participant RNGTracker
participant CUDAGraph
User->>make_graphed_callables: Call with forward function
Note over make_graphed_callables: Warmup Phase
make_graphed_callables->>RNGTracker: get_states()
RNGTracker-->>make_graphed_callables: tracker_states
make_graphed_callables->>make_graphed_callables: Check if graph_safe via is_graph_safe_rng_state()
make_graphed_callables->>make_graphed_callables: Get CUDA RNG state (graph_safe=True/False)
make_graphed_callables->>make_graphed_callables: Execute forward pass
alt Training Mode
make_graphed_callables->>_none_grad_context_wrapper: Enter with inputs
Note over _none_grad_context_wrapper: Save original grads<br/>Set input.grad = None
make_graphed_callables->>make_graphed_callables: torch.autograd.backward()
Note over make_graphed_callables: Grads accumulate in input.grad
make_graphed_callables->>make_graphed_callables: Collect grad_inputs from input.grad
_none_grad_context_wrapper->>_none_grad_context_wrapper: Restore original grads
end
Note over make_graphed_callables: Capture Phase
make_graphed_callables->>CUDAGraph: Capture forward graph
CUDAGraph-->>make_graphed_callables: fwd_graph
alt Training Mode
make_graphed_callables->>CUDAGraph: Capture backward graph
make_graphed_callables->>_none_grad_context_wrapper: Enter with inputs
make_graphed_callables->>make_graphed_callables: torch.autograd.backward()
make_graphed_callables->>make_graphed_callables: Collect grad_inputs from input.grad
_none_grad_context_wrapper->>_none_grad_context_wrapper: Restore original grads
CUDAGraph-->>make_graphed_callables: bwd_graph
end
Note over _CheckpointFunction: During Recomputation
_CheckpointFunction->>RNGTracker: get_states()
RNGTracker-->>_CheckpointFunction: tracker_states
_CheckpointFunction->>_CheckpointFunction: Determine graph_safe_rng_state
_CheckpointFunction->>_CheckpointFunction: Get/Set RNG states (graph_safe=True/False)
_CheckpointFunction->>_CheckpointFunction: Recompute forward
_CheckpointFunction->>_CheckpointFunction: Restore RNG states (graph_safe=True/False)
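The "During Recomputation" steps in the diagram (get the RNG state, set it for the recomputed forward, then restore it) can be illustrated with a minimal sketch. This is not the Transformer Engine code: it assumes the default CPU generator as a stand-in, uses `torch.get_rng_state` / `torch.set_rng_state` instead of the tracker and the graph-safe CUDA state handling named in the diagram, and the `forward` function is hypothetical.

```python
import torch
import torch.nn.functional as F


def forward(x):
    # Hypothetical forward pass. Dropout consumes random numbers,
    # so its output depends on the RNG state at call time.
    return F.dropout(x, p=0.5, training=True)


x = torch.ones(8)

# Original forward pass: remember the RNG state it started from.
saved_state = torch.get_rng_state()
out = forward(x)

# Recomputation: stash the current state, replay the saved one,
# recompute the forward pass, then put the current state back.
current_state = torch.get_rng_state()
torch.set_rng_state(saved_state)
recomputed = forward(x)
torch.set_rng_state(current_state)
```

Because the recomputed forward starts from the same RNG state, it should draw the same dropout mask as the original pass, and restoring `current_state` afterwards keeps later random ops unaffected by the replay.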
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Additional Comments (1)
-
transformer_engine/pytorch/graph.py, line 66-77 (link)
logic: potential race condition if `inputs` tuple contains non-tensor items
The code assumes all items in `inputs` are tensors with `.grad` attributes, but the wrapper doesn't validate this. If non-tensor items slip through, accessing `.grad` would raise `AttributeError`.
Are all items in the inputs tuple guaranteed to be tensors at the call sites, or could there be edge cases with non-tensor inputs?
2 files reviewed, 1 comment
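The pattern this comment refers to can be sketched as follows. This is a hedged illustration, not the code under review: `none_grad_context` is a hypothetical stand-in for `_none_grad_context_wrapper`, and the `isinstance` filter is an assumed way to address the non-tensor concern, not something the PR is known to do.

```python
import contextlib

import torch


@contextlib.contextmanager
def none_grad_context(inputs):
    # Hypothetical sketch: only touch items that are tensors, so that
    # non-tensor inputs never have `.grad` accessed on them.
    tensors = [x for x in inputs if isinstance(x, torch.Tensor)]
    saved = [t.grad for t in tensors]  # save original grads
    for t in tensors:
        t.grad = None  # backward will now write fresh grads
    try:
        yield tensors
    finally:
        for t, g in zip(tensors, saved):
            t.grad = g  # restore original grads


x = torch.tensor([1.0, 2.0], requires_grad=True)
x.grad = torch.tensor([10.0, 10.0])  # pre-existing gradient
inputs = (x, 3, None)  # mixed tuple with non-tensor items

with none_grad_context(inputs) as tensors:
    torch.autograd.backward((x * x).sum())
    # Collect grad_inputs from .grad while the context is active.
    grad_inputs = tuple(t.grad for t in tensors)
```

Here `grad_inputs` holds only the gradient produced by this backward pass, and `x.grad` is back to its pre-existing value once the context exits.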
The inputs are from
Hi @ksivaman, could you help review? Thanks!
/te-ci pytorch L0 L1
Additional Comments (1)
-
transformer_engine/pytorch/graph.py, line 71-77 (link)
style: The original gradient values are restored after the context exits. Verify this works correctly when gradients are already present (accumulation scenario).
2 files reviewed, 1 comment
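The accumulation scenario matters because `torch.autograd.grad` and `torch.autograd.backward` treat an existing `.grad` differently. A small sketch, assuming a CPU leaf tensor that already carries a gradient (the values are illustrative only):

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
x.grad = torch.tensor([10.0, 10.0])  # gradient already present

# autograd.grad returns the gradient and leaves x.grad untouched.
(g,) = torch.autograd.grad((x * x).sum(), (x,))
grad_after_autograd_grad = x.grad.clone()

# autograd.backward accumulates the new gradient into x.grad.
torch.autograd.backward((x * x).sum())
grad_after_backward = x.grad.clone()
```

Since `backward` adds to whatever is already stored, reading `.grad` after it only yields the gradient of the current pass if `.grad` was cleared beforehand, which is the reason for saving and restoring the original values around the call.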
Description
Support cudagraph recomputation with two changes:
Replace `autograd.grad` with `autograd.backward` in cudagraph capturing.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: