@@ -55,6 +55,11 @@ def _involves_viewed_args(bsym, viewed):
5555 return any (isinstance (p , TensorProxy ) and variableify (p ) in viewed for p in bsym .flat_proxy_args )
5656
5757
58+ def _static_numel (tensor : TensorProxy ) -> int | None :
59+ numel = getattr (tensor , "_numel" , None )
60+ return numel if isinstance (numel , int ) else None
61+
62+
5863def replace_args_with_alias_map (
5964 computation_trace : Trace ,
6065 alias_tensor_indices : list [list [int ]],
@@ -69,9 +74,11 @@ def replace_args_with_alias_map(
6974 arg = flat_args [indices [0 ]]
7075 for idx in filter (lambda idx : idx < len (flat_args ), indices [1 :]):
7176 arg_to_replace = flat_args [idx ]
72- # Skip aliases with different numel (e.g., complex tensor and its real view)
77+ # Skip aliases with different statically-known numel (e.g., complex tensor and its real view)
7378 # These share storage but have incompatible element counts
74- if arg .numel != arg_to_replace .numel :
79+ arg_numel = _static_numel (arg )
80+ arg_to_replace_numel = _static_numel (arg_to_replace )
81+ if arg_numel is not None and arg_to_replace_numel is not None and arg_numel != arg_to_replace_numel :
7582 continue
7683 reshaped_arg = arg
7784 if arg_to_replace .shape != arg .shape :
0 commit comments