Handle torch compile #2

Draft
shjwudp wants to merge 9 commits into buptzyb:main from shjwudp:handle_torch_compile

@shjwudp shjwudp commented Jun 29, 2026

No description provided.

shjwudp added 9 commits June 26, 2026 15:41
sample_kwargs may contain non-tensor values (None, lists, bools)
that tree_flatten passes through into the static input surface.
Guard all .requires_grad and .data_ptr() accesses with None/type
checks to avoid AttributeError crashes when frameworks pass
mixed tensor/non-tensor keyword arguments.

- _run_warmup_backward: guard .requires_grad with 'is not None'
- num_required_grad_sample_args: handle non-tensor with isinstance check
- backward capture inputs: guard .requires_grad with 'is not None'
- Graphed.forward copy loop: guard inputs[i] with 'is not None'
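
The guards listed above can be sketched roughly as follows. This is an illustration only, not the code from this PR: `FakeTensor`, `needs_grad`, and `count_required_grad` are hypothetical names, and `FakeTensor` stands in for `torch.Tensor` so the snippet runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Any, Iterable


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (assumption, not the real class)."""
    requires_grad: bool = False

    def data_ptr(self) -> int:
        # A real tensor returns its device address; id() is enough for a sketch.
        return id(self)


def needs_grad(value: Any) -> bool:
    """Return True only for tensor-like values that require grad.

    None, lists, and bools fail the isinstance check, so .requires_grad
    is never touched on them and no AttributeError can be raised.
    """
    return isinstance(value, FakeTensor) and value.requires_grad


def count_required_grad(flat_inputs: Iterable[Any]) -> int:
    """Count the flattened inputs that would take part in backward."""
    return sum(1 for value in flat_inputs if needs_grad(value))


# A flattened input surface that mixes tensors with non-tensor values.
flat_inputs = [FakeTensor(requires_grad=True), None, [1, 2], False, FakeTensor()]
num_grad_inputs = count_required_grad(flat_inputs)
```

Using an `isinstance` check rather than only `is not None` also covers lists and bools, which is why the second bullet calls it out separately.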

Some frameworks pass captured kwargs as positional args during
replay (e.g. hidden_states in Attention.forward). The previous
strict kwargs_keys validation rejected such calls. Now we:

1. Remove the strict key-in-user_kwargs validation.
2. Reconstruct the capture-time arg order by checking both
   user_kwargs (by name) and user_args (by position) in
   kwargs_keys order.

Also removes the now-redundant flatten_user_args since all
args are merged into flatten_user_kwargs in the right order.
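
One plausible reading of the reconstruction step is sketched below. It assumes that a key missing from `user_kwargs` is filled by the next unused positional argument. The function name `reconstruct_capture_order` is hypothetical, and the PR's actual logic may differ.

```python
from typing import Any, Dict, List, Sequence


def reconstruct_capture_order(
    kwargs_keys: Sequence[str],
    user_args: Sequence[Any],
    user_kwargs: Dict[str, Any],
) -> List[Any]:
    """Rebuild the capture-time input order from a replay-time call.

    Walks kwargs_keys in capture order. A key found in user_kwargs is
    taken by name; otherwise the next unused positional arg is taken.
    """
    positional = iter(user_args)
    ordered: List[Any] = []
    for key in kwargs_keys:
        if key in user_kwargs:
            ordered.append(user_kwargs[key])
            continue
        try:
            ordered.append(next(positional))
        except StopIteration:
            raise TypeError(f"missing value for captured kwarg {key!r}") from None
    return ordered


# Captured with hidden_states as a kwarg, replayed with it as a positional arg.
replay_inputs = reconstruct_capture_order(
    ["hidden_states", "attention_mask"],
    ("h",),
    {"attention_mask": "m"},
)
```

Because the result always follows `kwargs_keys` order, a single merged list is enough, which is consistent with dropping a separate flattened list for positional args.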

Warmup iterations and forward/backward graph captures now
share a single torch.cuda.Stream.
This keeps workspace buffers in the same CUDA context,
avoiding re-allocation between warmup and capture.

Both make_graphed_callables and _make_graphed_callables now accept
an optional capture_stream: Optional[torch.cuda.Stream] parameter.
When provided, warmup and all graph captures share that stream instead
of creating their own.  This lets frameworks (e.g. FSDP v2) pass their
shared pool stream to ensure correct serialization.

warmup_outputs and per_fwd_outputs held references to all forward
output tensors from the last warmup iteration.  These references kept
GPU memory pinned across gc.collect() + torch.cuda.empty_cache(), causing
~19.8 GB of extra memory usage during graph capture compared with no warmup.

Explicitly delete both before garbage collection so empty_cache
actually releases the cached blocks.
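
The effect of a lingering reference can be shown with plain Python objects. In this sketch, `Buffer` and `run_warmup` are hypothetical stand-ins for output tensors and the warmup loop, and a weak reference is used to observe whether the object is still alive. With real tensors, `torch.cuda.empty_cache()` can only return blocks once the tensors that own them have been freed in the same way.

```python
import gc
import weakref


class Buffer:
    """Hypothetical stand-in for a forward output tensor."""


def run_warmup():
    """Return the outputs list plus a weak reference for observation."""
    out = Buffer()
    return [out], weakref.ref(out)


warmup_outputs, probe = run_warmup()

gc.collect()
# The list still holds a strong reference, so collection cannot free the buffer.
alive_before = probe() is not None

del warmup_outputs
gc.collect()
# With the last strong reference gone, the buffer has been released.
alive_after = probe() is not None
```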

Warmup now uses a throwaway stream so torch.compile recompilation
and other warmup-side CUDA ops don't contaminate the capture stream.
The capture_stream parameter still controls the stream used for all
graph captures.  This fixes cudaErrorStreamCaptureInvalidated when
torch.compile triggers lazy recompilation during warmup.
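
The split between a throwaway warmup stream and a caller-controlled capture stream might look roughly like this. It is a forward-only sketch under stated assumptions, not the PR's implementation: `warmup_then_capture` and its parameters are hypothetical, while `torch.cuda.Stream`, `torch.cuda.stream`, `torch.cuda.CUDAGraph`, and `torch.cuda.graph` are existing PyTorch APIs. Calling it requires a CUDA-enabled PyTorch build and a GPU.

```python
def warmup_then_capture(fn, sample, capture_stream=None, num_warmup=3):
    """Warm up fn on a throwaway stream, then capture it on capture_stream."""
    # Imported lazily so this module can be loaded on machines without PyTorch.
    import torch

    if capture_stream is None:
        # No shared stream supplied by the caller, so create a private one.
        capture_stream = torch.cuda.Stream()

    # Throwaway stream: lazy recompilation and other warmup-side work runs
    # here, away from the stream that is later put into capture mode.
    warmup_stream = torch.cuda.Stream()
    warmup_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(warmup_stream):
        for _ in range(num_warmup):
            fn(sample)
    torch.cuda.current_stream().wait_stream(warmup_stream)
    torch.cuda.synchronize()

    # Capture on the designated stream only after warmup has fully finished.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=capture_stream):
        static_output = fn(sample)
    return graph, static_output
```

A caller that already owns a shared stream would pass it as `capture_stream`, and later call `graph.replay()` after copying new data into `sample`.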