11from functools import reduce , partial
22
3+ from thunder .core .compile_data import using_symbolic_values
34import thunder .core .prims as prims
45from thunder .core .proxies import TensorProxy , variableify , unvariableify
56from thunder .core .pytree import tree_flatten
@@ -55,23 +56,36 @@ def _involves_viewed_args(bsym, viewed):
5556 return any (isinstance (p , TensorProxy ) and variableify (p ) in viewed for p in bsym .flat_proxy_args )
5657
5758
59+ def _can_be_reshaped (arg , arg_to_replace ):
60+ # TODO: Fix this once numel for symbolic values is implemented
61+ if using_symbolic_values ():
62+ arg_numel = arg ._numel ()
63+ arg_to_replace_numel = arg_to_replace ._numel ()
64+ else :
65+ arg_numel = arg .numel
66+ arg_to_replace_numel = arg_to_replace .numel
67+ return arg_numel == arg_to_replace_numel
68+
69+
5870def replace_args_with_alias_map (
5971 computation_trace : Trace ,
6072 alias_tensor_indices : list [list [int ]],
61- ) -> tuple [Trace , dict [ VariableInterface , TensorProxy ]]:
73+ ) -> tuple [Trace , list [ set [ VariableInterface ] ]]:
6274 if not alias_tensor_indices :
63- return computation_trace , {}
75+ return computation_trace , []
6476 bsyms : list [BoundSymbol ] = []
6577 flat_args , _ = tree_flatten ((computation_trace .args , computation_trace .kwargs ))
6678 swap_map_for_aliases : dict [VariableInterface , TensorProxy ] = {}
6779 arg_to_optional_bsyms : dict [VariableInterface , BoundSymbol ] = {}
80+ view_groups = {}
6881 for indices in alias_tensor_indices :
6982 arg = flat_args [indices [0 ]]
7083 for idx in filter (lambda idx : idx < len (flat_args ), indices [1 :]):
7184 arg_to_replace = flat_args [idx ]
72- # Skip aliases with different numel (e.g., complex tensor and its real view)
85+ # Track aliases with different numel (e.g., complex tensor and its real view)
7386 # These share storage but have incompatible element counts
74- if arg .numel != arg_to_replace .numel :
87+ if not _can_be_reshaped (arg , arg_to_replace ):
88+ view_groups .setdefault (variableify (arg ), []).append (variableify (arg_to_replace ))
7589 continue
7690 reshaped_arg = arg
7791 if arg_to_replace .shape != arg .shape :
@@ -111,7 +125,8 @@ def replace_args_with_alias_map(
111125 no_implicit_alias_trace .bound_symbols = bsyms
112126 str_map = {unvariableify (k ).name : v .name for k , v in swap_map_for_aliases .items ()}
113127 no_implicit_alias_trace .set_provenance (TraceProvenance (f"Duplicate alias args using { str_map } " ))
114- return no_implicit_alias_trace , swap_map_for_aliases
128+ view_groups = [{k }.union (set (v )) for k , v in view_groups .items () if len (v ) != 0 ]
129+ return no_implicit_alias_trace , view_groups
115130
116131
117132def insert_alias_updates (computation_trace : Trace , alias_tensor_indices : list [list [int ]]) -> Trace :
@@ -123,10 +138,10 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
123138
124139 # First pass: identify inputs which are views of each other and swap them out with a default,
125140 # reshaping if necessary.
126- computation_trace , _ = replace_args_with_alias_map (computation_trace , alias_tensor_indices )
141+ computation_trace , view_groups = replace_args_with_alias_map (computation_trace , alias_tensor_indices )
127142
128143 # Second pass: identify views, their originals, and operands involved in inplace ops
129- view_groups = []
144+ encountered = set (). union ( * view_groups )
130145 inplace_inputs = set ()
131146 for bsym in computation_trace .bound_symbols :
132147 if _is_inplace_op (bsym ) or _is_view_creation_op (bsym ):
@@ -146,7 +161,6 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
146161 # filter out view groups that don't have any tensors involved in inplace ops
147162 view_groups = [group for group in view_groups if len (group .intersection (inplace_inputs )) != 0 ]
148163 viewed = set (reduce (set .union , view_groups , set ()))
149- encountered = set ()
150164
151165 # Third pass: insert alias updates
152166 for bsym in computation_trace .bound_symbols :
0 commit comments