Skip to content

Commit f0a5681

Browse files
committed
Insert prims.shape before reshape bsym in update_aliases.py
1 parent 4d3a3c3 commit f0a5681

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

thunder/core/update_aliases.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from functools import reduce, partial
22

3+
from thunder.core.compile_data import using_symbolic_values
34
import thunder.core.prims as prims
45
from thunder.core.proxies import TensorProxy, variableify, unvariableify
56
from thunder.core.pytree import tree_flatten
@@ -76,21 +77,23 @@ def replace_args_with_alias_map(
7677
reshaped_arg = arg
7778
if arg_to_replace.shape != arg.shape:
7879
with tracectx(computation_trace):
79-
reshaped_arg = prims.reshape.meta(arg, arg_to_replace.shape)
80-
arg_to_optional_bsyms[variableify(arg_to_replace)] = prims.reshape.bind(
81-
arg,
82-
arg_to_replace.shape,
83-
output=reshaped_arg,
84-
)
80+
shape = prims.shape.meta(arg_to_replace)
81+
reshaped_arg = prims.reshape.meta(arg, shape)
82+
reshape_bsym = prims.reshape.bind(arg, shape, output=reshaped_arg)
83+
if using_symbolic_values():
84+
shape_bsym = prims.shape.bind(arg_to_replace, output=shape)
85+
arg_to_optional_bsyms[variableify(arg_to_replace)] = (shape_bsym, reshape_bsym)
86+
else:
87+
arg_to_optional_bsyms[variableify(arg_to_replace)] = (reshape_bsym,)
8588
swap_map_for_aliases[variableify(arg_to_replace)] = reshaped_arg
8689
appended_bsyms = {}
8790
for bsym in computation_trace.bound_symbols:
8891
for arg in filter(lambda p: isinstance(p, TensorProxy), bsym.flat_args):
89-
reshape_bsym = arg_to_optional_bsyms.get(variableify(arg))
90-
if reshape_bsym is not None:
91-
if reshape_bsym not in appended_bsyms:
92-
bsyms.append(reshape_bsym)
93-
appended_bsyms[reshape_bsym] = arg
92+
reshape_bsyms = arg_to_optional_bsyms.get(variableify(arg))
93+
if reshape_bsyms is not None:
94+
if reshape_bsyms not in appended_bsyms:
95+
bsyms.extend(reshape_bsyms)
96+
appended_bsyms[reshape_bsyms] = arg
9497
if replaced_args_map := {
9598
x.name: swap_map_for_aliases[variableify(x)].name
9699
for x in filter(lambda p: isinstance(p, TensorProxy), bsym.flat_args)

0 commit comments

Comments
 (0)