Handle torch compile#2
Draft
shjwudp wants to merge 9 commits into
Draft
Conversation
sample_kwargs may contain non-tensor values (None, lists, bools) that tree_flatten passes through into the static input surface. Guard all .requires_grad and .data_ptr() accesses with None/type checks to avoid AttributeError crashes when frameworks pass mixed tensor/non-tensor keyword arguments. - _run_warmup_backward: guard .requires_grad with 'is not None' - num_required_grad_sample_args: handle non-tensor with isinstance check - backward capture inputs: guard .requires_grad with 'is not None' - Graphed.forward copy loop: guard inputs[i] with 'is not None'
Some frameworks pass captured kwargs as positional args during replay (e.g. Attention.forward hidden_states). The previous strict kwargs_keys validation would reject this. Now we: 1. Remove the strict key-in-user_kwargs validation. 2. Reconstruct the capture-time arg order by checking both user_kwargs (by name) and user_args (by position) in kwargs_keys order. Also removes the now-redundant flatten_user_args since all args are merged into flatten_user_kwargs in the right order.
Warmup now uses the same torch.cuda.Stream for both warmup iterations and forward/backward graph captures. This keeps workspace buffers in the same CUDA context, avoiding re-allocation between warmup and capture.
Both make_graphed_callables and _make_graphed_callables now accept an optional capture_stream: Optional[torch.cuda.Stream] parameter. When provided, warmup and all graph captures share that stream instead of creating their own. This lets frameworks (e.g. FSDP v2) pass their shared pool stream to ensure correct serialization.
warmup_outputs and per_fwd_outputs hold references to all forward output tensors from the last warmup iteration. These pinned GPU memory across gc.collect() + torch.cuda.empty_cache(), causing ~19.8 GB extra memory usage during graph capture vs. no-warmup. Explicitly delete both before garbage collection so empty_cache actually releases the cached blocks.
This reverts commit d0e484b.
Warmup now uses a throwaway stream so torch.compile recompilation and other warmup-side CUDA ops don't contaminate the capture stream. capture_stream parameter still controls the stream used for all graph captures. This fixes cudaErrorStreamCaptureInvalidated when torch.compile triggers lazy recompilation during warmup.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.