Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions src/te_graph_runtime/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unus

def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> None:
static_input_surface = per_callable_static_input_surfaces[func_idx]
inputs = tuple(i for i in static_input_surface if i.requires_grad)
inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blocker: this still crashes for non-None, non-tensor kwargs such as the img_shapes=[[1,64,64]] example in the PR body. i is not None and i.requires_grad will call .requires_grad on a Python list/int/etc. The guard needs to be isinstance(i, torch.Tensor) and i.requires_grad here and in the corresponding backward-capture paths below.

outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad)
grad_outputs = _make_grad_outputs(outputs)

Expand All @@ -883,10 +883,13 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
# Filter module params that get None grad from grad_inputs and remove them
# from static_input_surface. This is to ensure that the backward hooks
# registered to these params are not wrongly triggered.
num_required_grad_sample_args = sum(arg.requires_grad for arg in flatten_sample_args[func_idx])
num_required_grad_sample_args = sum(
isinstance(arg, torch.Tensor) and arg.requires_grad
for arg in flatten_sample_args[func_idx]
)
required_grad_input_idx = []
for i, arg in enumerate(static_input_surface):
if arg.requires_grad:
if isinstance(arg, torch.Tensor) and arg.requires_grad:
required_grad_input_idx.append(i)
module_params_with_grad = []
for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx):
Expand Down Expand Up @@ -1108,7 +1111,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
func,
static_grad_outputs,
)
inputs = tuple(i for i in static_input_surface if i.requires_grad)
inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
):
Expand Down Expand Up @@ -1232,7 +1235,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) ->
if is_training:
func = graph_callables[bwd_idx]
_call_capture_time_backward_pre_hooks(bwd_idx, func, static_grad_outputs)
inputs = tuple(i for i in static_input_surface if i.requires_grad)
inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad)
with _none_grad_context_wrapper(inputs), _graph_context_wrapper(
bwd_graph, pool=mempool
):
Expand Down Expand Up @@ -1311,6 +1314,7 @@ def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *i
for i in range(len_user_args):
if (
isinstance(static_input_surface[i], torch.Tensor)
and inputs[i] is not None
and static_input_surface[i].data_ptr() != inputs[i].data_ptr()
):
static_input_surface[i].copy_(inputs[i])
Expand Down Expand Up @@ -1416,21 +1420,23 @@ def functionalized(*user_args, **user_kwargs):
user_kwargs.pop("cuda_graph_event")
else:
cuda_graph_event = None
# Check that required kwargs are provided
for key in kwargs_keys:
if key not in user_kwargs:
raise TypeError(
f"Graphed callable was initialized with kwarg {key} ,"
"but it was not provided in graph replay"
)

# Runs the autograd function with inputs == all inputs to
# the graph that might require grad (explicit user args +
# module parameters)
# Assumes module params didn't change since capture.
flatten_user_args, _ = _tree_flatten(user_args)
flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys])
func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params
# Reconstruct the same flattened arg order as capture time.
# User may pass some recorded kwargs as positional args, so
# check user_args first (by position), then user_kwargs.
user_pos_args = list(user_args)
kwarg_values = []
for key in kwargs_keys:
if key in user_kwargs:
kwarg_values.append(user_kwargs[key])
elif user_pos_args:
kwarg_values.append(user_pos_args.pop(0))
# else: key was a default not passed — skip (not a tensor)
flatten_user_kwargs, _ = _tree_flatten(kwarg_values)
func_args = tuple(flatten_user_kwargs) + module_params

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blocker: this drops normal replay positional inputs. user_args is only consumed as a fallback source for captured kwargs, and func_args is built from flatten_user_kwargs + module_params, so graphed(x) with no captured kwargs passes only module params, and graphed(x, scale=scale) drops x. This needs to preserve the flattened explicit user args and then append the captured kwarg values in the same order used during capture.

out = Graphed.apply(
skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args
)
Expand Down