-
Notifications
You must be signed in to change notification settings - Fork 2
fix: handle non-tensor sample_kwargs in static_input_surface #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -867,7 +867,7 @@ def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unus | |
|
|
||
| def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> None: | ||
| static_input_surface = per_callable_static_input_surfaces[func_idx] | ||
| inputs = tuple(i for i in static_input_surface if i.requires_grad) | ||
| inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad) | ||
| outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad) | ||
| grad_outputs = _make_grad_outputs(outputs) | ||
|
|
||
|
|
@@ -883,10 +883,13 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> | |
| # Filter module params that get None grad from grad_inputs and remove them | ||
| # from static_input_surface. This is to ensure that the backward hooks | ||
| # registered to these params are not wrongly triggered. | ||
| num_required_grad_sample_args = sum(arg.requires_grad for arg in flatten_sample_args[func_idx]) | ||
| num_required_grad_sample_args = sum( | ||
| isinstance(arg, torch.Tensor) and arg.requires_grad | ||
| for arg in flatten_sample_args[func_idx] | ||
| ) | ||
| required_grad_input_idx = [] | ||
| for i, arg in enumerate(static_input_surface): | ||
| if arg.requires_grad: | ||
| if isinstance(arg, torch.Tensor) and arg.requires_grad: | ||
| required_grad_input_idx.append(i) | ||
| module_params_with_grad = [] | ||
| for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx): | ||
|
|
@@ -1108,7 +1111,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> | |
| func, | ||
| static_grad_outputs, | ||
| ) | ||
| inputs = tuple(i for i in static_input_surface if i.requires_grad) | ||
| inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad) | ||
| with _none_grad_context_wrapper(inputs), _graph_context_wrapper( | ||
| bwd_graph, pool=mempool | ||
| ): | ||
|
|
@@ -1232,7 +1235,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> | |
| if is_training: | ||
| func = graph_callables[bwd_idx] | ||
| _call_capture_time_backward_pre_hooks(bwd_idx, func, static_grad_outputs) | ||
| inputs = tuple(i for i in static_input_surface if i.requires_grad) | ||
| inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad) | ||
| with _none_grad_context_wrapper(inputs), _graph_context_wrapper( | ||
| bwd_graph, pool=mempool | ||
| ): | ||
|
|
@@ -1311,6 +1314,7 @@ def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *i | |
| for i in range(len_user_args): | ||
| if ( | ||
| isinstance(static_input_surface[i], torch.Tensor) | ||
| and inputs[i] is not None | ||
| and static_input_surface[i].data_ptr() != inputs[i].data_ptr() | ||
| ): | ||
| static_input_surface[i].copy_(inputs[i]) | ||
|
|
@@ -1416,21 +1420,23 @@ def functionalized(*user_args, **user_kwargs): | |
| user_kwargs.pop("cuda_graph_event") | ||
| else: | ||
| cuda_graph_event = None | ||
| # Check that required kwargs are provided | ||
| for key in kwargs_keys: | ||
| if key not in user_kwargs: | ||
| raise TypeError( | ||
| f"Graphed callable was initialized with kwarg {key} ," | ||
| "but it was not provided in graph replay" | ||
| ) | ||
|
|
||
| # Runs the autograd function with inputs == all inputs to | ||
| # the graph that might require grad (explicit user args + | ||
| # module parameters) | ||
| # Assumes module params didn't change since capture. | ||
| flatten_user_args, _ = _tree_flatten(user_args) | ||
| flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys]) | ||
| func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params | ||
| # Reconstruct the same flattened arg order as capture time. | ||
| # User may pass some recorded kwargs as positional args, so | ||
| # check user_args first (by position), then user_kwargs. | ||
| user_pos_args = list(user_args) | ||
| kwarg_values = [] | ||
| for key in kwargs_keys: | ||
| if key in user_kwargs: | ||
| kwarg_values.append(user_kwargs[key]) | ||
| elif user_pos_args: | ||
| kwarg_values.append(user_pos_args.pop(0)) | ||
| # else: key was a default not passed — skip (not a tensor) | ||
| flatten_user_kwargs, _ = _tree_flatten(kwarg_values) | ||
| func_args = tuple(flatten_user_kwargs) + module_params | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Blocker: this drops normal replay positional inputs. |
||
| out = Graphed.apply( | ||
| skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args | ||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Blocker: this still crashes for non-`None`, non-tensor kwargs such as the `img_shapes=[[1,64,64]]` example in the PR body. `i is not None and i.requires_grad` will call `.requires_grad` on a Python list/int/etc. The guard needs to be `isinstance(i, torch.Tensor) and i.requires_grad` here and in the corresponding backward-capture paths below.