From ad90b781193a1cfd9b9e594a44c38d50faadb35a Mon Sep 17 00:00:00 2001 From: jianbinc Date: Fri, 26 Jun 2026 15:41:27 +0800 Subject: [PATCH 1/2] fix: handle non-tensor sample_kwargs in static_input_surface sample_kwargs may contain non-tensor values (None, lists, bools) that tree_flatten passes through into the static input surface. Guard all .requires_grad and .data_ptr() accesses with None/type checks to avoid AttributeError crashes when frameworks pass mixed tensor/non-tensor keyword arguments. - _run_warmup_backward: guard .requires_grad with 'is not None' - num_required_grad_sample_args: handle non-tensor with isinstance check - backward capture inputs: guard .requires_grad with 'is not None' - Graphed.forward copy loop: guard inputs[i] with 'is not None' --- src/te_graph_runtime/graph.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/te_graph_runtime/graph.py b/src/te_graph_runtime/graph.py index 569d248..a1d9e7f 100644 --- a/src/te_graph_runtime/graph.py +++ b/src/te_graph_runtime/graph.py @@ -867,7 +867,7 @@ def hook_fn(module, inputs, outputs, func_idx=func_idx): # pylint: disable=unus def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> None: static_input_surface = per_callable_static_input_surfaces[func_idx] - inputs = tuple(i for i in static_input_surface if i.requires_grad) + inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad) outputs_requiring_grad = tuple(o for o in outputs if o is not None and o.requires_grad) grad_outputs = _make_grad_outputs(outputs) @@ -883,10 +883,13 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> # Filter module params that get None grad from grad_inputs and remove them # from static_input_surface. This is to ensure that the backward hooks # registered to these params are not wrongly triggered. 
- num_required_grad_sample_args = sum(arg.requires_grad for arg in flatten_sample_args[func_idx]) + num_required_grad_sample_args = sum( + isinstance(arg, torch.Tensor) and arg.requires_grad + for arg in flatten_sample_args[func_idx] + ) required_grad_input_idx = [] for i, arg in enumerate(static_input_surface): - if arg.requires_grad: + if isinstance(arg, torch.Tensor) and arg.requires_grad: required_grad_input_idx.append(i) module_params_with_grad = [] for grad_inputs_idx, inputs_idx in enumerate(required_grad_input_idx): @@ -1108,7 +1111,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> func, static_grad_outputs, ) - inputs = tuple(i for i in static_input_surface if i.requires_grad) + inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad) with _none_grad_context_wrapper(inputs), _graph_context_wrapper( bwd_graph, pool=mempool ): @@ -1232,7 +1235,7 @@ def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> if is_training: func = graph_callables[bwd_idx] _call_capture_time_backward_pre_hooks(bwd_idx, func, static_grad_outputs) - inputs = tuple(i for i in static_input_surface if i.requires_grad) + inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad) with _none_grad_context_wrapper(inputs), _graph_context_wrapper( bwd_graph, pool=mempool ): @@ -1311,6 +1314,7 @@ def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *i for i in range(len_user_args): if ( isinstance(static_input_surface[i], torch.Tensor) + and inputs[i] is not None and static_input_surface[i].data_ptr() != inputs[i].data_ptr() ): static_input_surface[i].copy_(inputs[i]) From bc6902232d4f728b14ead3414c037fe8a1943706 Mon Sep 17 00:00:00 2001 From: jianbinc Date: Fri, 26 Jun 2026 16:11:03 +0800 Subject: [PATCH 2/2] fix: handle positional args in functionalized during graph replay Some frameworks pass captured kwargs as positional args during replay (e.g. 
Attention.forward hidden_states). The previous strict kwargs_keys validation would reject this. Now we: 1. Remove the strict key-in-user_kwargs validation. 2. Reconstruct the capture-time arg order by checking both user_kwargs (by name) and user_args (by position) in kwargs_keys order. Also removes the now-redundant flatten_user_args since all args are merged into flatten_user_kwargs in the right order. --- src/te_graph_runtime/graph.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/te_graph_runtime/graph.py b/src/te_graph_runtime/graph.py index a1d9e7f..c2b93ab 100644 --- a/src/te_graph_runtime/graph.py +++ b/src/te_graph_runtime/graph.py @@ -1420,21 +1420,23 @@ def functionalized(*user_args, **user_kwargs): user_kwargs.pop("cuda_graph_event") else: cuda_graph_event = None - # Check that required kwargs are provided - for key in kwargs_keys: - if key not in user_kwargs: - raise TypeError( - f"Graphed callable was initialized with kwarg {key} ," - "but it was not provided in graph replay" - ) - # Runs the autograd function with inputs == all inputs to # the graph that might require grad (explicit user args + # module parameters) # Assumes module params didn't change since capture. - flatten_user_args, _ = _tree_flatten(user_args) - flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys]) - func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params + # Reconstruct the same flattened arg order as capture time. + # User may pass some recorded kwargs as positional args, so + # check user_kwargs first (by name), then fall back to user_args (by position). 
+ user_pos_args = list(user_args) + kwarg_values = [] + for key in kwargs_keys: + if key in user_kwargs: + kwarg_values.append(user_kwargs[key]) + elif user_pos_args: + kwarg_values.append(user_pos_args.pop(0)) + # else: key was a default not passed — skip (not a tensor) + flatten_user_kwargs, _ = _tree_flatten(kwarg_values) + func_args = tuple(flatten_user_kwargs) + module_params out = Graphed.apply( skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args )