Skip to content

Commit d05dac5

Browse files
committed
Skip update_alias numel check if shape is dynamic
1 parent ade597e commit d05dac5

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

thunder/core/update_aliases.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ def _involves_viewed_args(bsym, viewed):
5555
return any(isinstance(p, TensorProxy) and variableify(p) in viewed for p in bsym.flat_proxy_args)
5656

5757

58+
def _static_numel(tensor: TensorProxy) -> int | None:
59+
numel = getattr(tensor, "_numel", None)
60+
return numel if isinstance(numel, int) else None
61+
62+
5863
def replace_args_with_alias_map(
5964
computation_trace: Trace,
6065
alias_tensor_indices: list[list[int]],
@@ -69,9 +74,11 @@ def replace_args_with_alias_map(
6974
arg = flat_args[indices[0]]
7075
for idx in filter(lambda idx: idx < len(flat_args), indices[1:]):
7176
arg_to_replace = flat_args[idx]
72-
# Skip aliases with different numel (e.g., complex tensor and its real view)
77+
# Skip aliases with different statically-known numel (e.g., complex tensor and its real view)
7378
# These share storage but have incompatible element counts
74-
if arg.numel != arg_to_replace.numel:
79+
arg_numel = _static_numel(arg)
80+
arg_to_replace_numel = _static_numel(arg_to_replace)
81+
if arg_numel is not None and arg_to_replace_numel is not None and arg_numel != arg_to_replace_numel:
7582
continue
7683
reshaped_arg = arg
7784
if arg_to_replace.shape != arg.shape:

0 commit comments

Comments
 (0)