fix: handle non-tensor sample_kwargs in static_input_surface #1
sample_kwargs may contain non-tensor values (None, lists, bools) that tree_flatten passes through into the static input surface. Guard all .requires_grad and .data_ptr() accesses with None/type checks to avoid AttributeError crashes when frameworks pass mixed tensor/non-tensor keyword arguments.

- _run_warmup_backward: guard .requires_grad with 'is not None'
- num_required_grad_sample_args: handle non-tensor with isinstance check
- backward capture inputs: guard .requires_grad with 'is not None'
- Graphed.forward copy loop: guard inputs[i] with 'is not None'

Some frameworks pass captured kwargs as positional args during replay (e.g. Attention.forward hidden_states). The previous strict kwargs_keys validation would reject this. Now we:

1. Remove the strict key-in-user_kwargs validation.
2. Reconstruct the capture-time arg order by checking both user_kwargs (by name) and user_args (by position) in kwargs_keys order.

Also removes the now-redundant flatten_user_args since all args are merged into flatten_user_kwargs in the right order.

    kwarg_values.append(user_pos_args.pop(0))
    # else: key was a default not passed — skip (not a tensor)
    flatten_user_kwargs, _ = _tree_flatten(kwarg_values)
    func_args = tuple(flatten_user_kwargs) + module_params

Blocker: this drops normal replay positional inputs. user_args is only consumed as a fallback source for captured kwargs, and func_args is built from flatten_user_kwargs + module_params, so graphed(x) with no captured kwargs passes only module params, and graphed(x, scale=scale) drops x. This needs to preserve the flattened explicit user args and then append the captured kwarg values in the same order used during capture.
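
One way to read the ordering requested above is sketched below. This is only an illustration, not the code in the PR: build_replay_inputs and num_capture_args (the number of positional args seen at capture time) are hypothetical names, tree flattening is left out, and module_params would still be appended afterwards as in the diff.

```python
def build_replay_inputs(user_args, user_kwargs, kwargs_keys, num_capture_args):
    """Hypothetical helper: rebuild replay inputs in capture-time order.

    Assumes the first num_capture_args positionals are the explicit user
    args from capture; any extra positionals are captured kwargs that the
    caller passed by position.
    """
    positional = list(user_args[:num_capture_args])
    extra = list(user_args[num_capture_args:])

    kwarg_values = []
    for key in kwargs_keys:
        if key in user_kwargs:
            # Caller passed the captured kwarg by name.
            kwarg_values.append(user_kwargs[key])
        elif extra:
            # Caller passed the captured kwarg by position.
            kwarg_values.append(extra.pop(0))
        # Otherwise the key was a default that was not passed; skip it.

    # Explicit positional args first, then captured kwargs in capture order.
    return tuple(positional) + tuple(kwarg_values)


# graphed(x) with no captured kwargs keeps x.
only_positional = build_replay_inputs(("x",), {}, [], 1)
# graphed(x, scale=scale) keeps x and appends scale.
by_name = build_replay_inputs(("x",), {"scale": 2.0}, ["scale"], 1)
# graphed(x, scale) with scale passed by position gives the same order.
by_position = build_replay_inputs(("x", 2.0), {}, ["scale"], 1)
```

The sketch relies on knowing how many positional args were present at capture; without that count, a positional input cannot be told apart from a captured kwarg passed by position.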
      def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> None:
          static_input_surface = per_callable_static_input_surfaces[func_idx]
    -     inputs = tuple(i for i in static_input_surface if i.requires_grad)
    +     inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad)

Blocker: this still crashes for non-None, non-tensor kwargs such as the img_shapes=[[1,64,64]] example in the PR body. i is not None and i.requires_grad will call .requires_grad on a Python list/int/etc. The guard needs to be isinstance(i, torch.Tensor) and i.requires_grad here and in the corresponding backward-capture paths below.
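
A minimal sketch of the difference between the two guards follows. To stay runnable without PyTorch it uses FakeTensor, a hypothetical stand-in for torch.Tensor; in the real code the check would be isinstance(i, torch.Tensor). The surface contents are made up for illustration.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (assumption, not the real class)."""

    def __init__(self, requires_grad=False):
        self.requires_grad = requires_grad


def grad_inputs_none_guard(surface):
    # Guard from the diff: filters out None only.
    return tuple(i for i in surface if i is not None and i.requires_grad)


def grad_inputs_type_guard(surface, tensor_type=FakeTensor):
    # Guard suggested in the review: skip every non-tensor leaf.
    return tuple(i for i in surface if isinstance(i, tensor_type) and i.requires_grad)


x = FakeTensor(requires_grad=True)
mask = FakeTensor(requires_grad=False)
# Mixed surface: tensors, None, and plain Python leaves such as ints and bools.
surface = (x, None, mask, 1, 64, 64, True)

try:
    grad_inputs_none_guard(surface)
    none_guard_error = None
except AttributeError as exc:
    # An int leaf has no .requires_grad attribute.
    none_guard_error = exc

selected = grad_inputs_type_guard(surface)
```

With the type guard, only x is selected; the None-only guard raises AttributeError as soon as it reaches the first int.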
Fix non-tensor sample_kwargs in static_input_surface + positional args in replay

Two commits fixing crashes when frameworks pass mixed tensor/non-tensor keyword arguments and positional args during replay.

Commit 1: fix: handle non-tensor sample_kwargs in static_input_surface

When sample_kwargs contains non-tensor values (e.g. attention_mask=None, img_shapes=[[1,64,64]]), tree_flatten passes them into static_input_surface. All existing .requires_grad and .data_ptr() accesses crash on None.

Fix: 6 guards across 4 locations:

- _run_warmup_backward: i is not None and i.requires_grad
- num_required_grad_sample_args: isinstance(arg, torch.Tensor) and arg.requires_grad
- backward capture inputs: i is not None and i.requires_grad
- Graphed.forward copy loop: inputs[i] is not None

Commit 2: fix: handle positional args in functionalized during graph replay

Frameworks may pass captured kwargs as positional args during replay (e.g. Attention.forward(hidden_states, attention_mask=mask)). The previous strict kwargs_keys validation would reject this.

Fix: Remove the strict key in user_kwargs validation. Reconstruct the capture-time arg order by checking both user_kwargs (by name) and user_args (by position) for each key in kwargs_keys.

Diff
src/te_graph_runtime/graph.py | 28 +++++++++++++++++-----------
1 file changed, 17 insertions(+), 11 deletions(-)
Impact
Frameworks using sample_kwargs with non-tensor values or positional args no longer crash during warmup, capture, or replay.