|
1 | 1 | from functools import reduce, partial |
2 | 2 |
|
| 3 | +from thunder.core.compile_data import using_symbolic_values |
3 | 4 | import thunder.core.prims as prims |
4 | 5 | from thunder.core.proxies import TensorProxy, variableify, unvariableify |
5 | 6 | from thunder.core.pytree import tree_flatten |
@@ -76,21 +77,23 @@ def replace_args_with_alias_map( |
76 | 77 | reshaped_arg = arg |
77 | 78 | if arg_to_replace.shape != arg.shape: |
78 | 79 | with tracectx(computation_trace): |
79 | | - reshaped_arg = prims.reshape.meta(arg, arg_to_replace.shape) |
80 | | - arg_to_optional_bsyms[variableify(arg_to_replace)] = prims.reshape.bind( |
81 | | - arg, |
82 | | - arg_to_replace.shape, |
83 | | - output=reshaped_arg, |
84 | | - ) |
| 80 | + shape = prims.shape.meta(arg_to_replace) |
| 81 | + reshaped_arg = prims.reshape.meta(arg, shape) |
| 82 | + reshape_bsym = prims.reshape.bind(arg, shape, output=reshaped_arg) |
| 83 | + if using_symbolic_values(): |
| 84 | + shape_bsym = prims.shape.bind(arg_to_replace, output=shape) |
| 85 | + arg_to_optional_bsyms[variableify(arg_to_replace)] = (shape_bsym, reshape_bsym) |
| 86 | + else: |
| 87 | + arg_to_optional_bsyms[variableify(arg_to_replace)] = (reshape_bsym,) |
85 | 88 | swap_map_for_aliases[variableify(arg_to_replace)] = reshaped_arg |
86 | 89 | appended_bsyms = {} |
87 | 90 | for bsym in computation_trace.bound_symbols: |
88 | 91 | for arg in filter(lambda p: isinstance(p, TensorProxy), bsym.flat_args): |
89 | | - reshape_bsym = arg_to_optional_bsyms.get(variableify(arg)) |
90 | | - if reshape_bsym is not None: |
91 | | - if reshape_bsym not in appended_bsyms: |
92 | | - bsyms.append(reshape_bsym) |
93 | | - appended_bsyms[reshape_bsym] = arg |
| 92 | + reshape_bsyms = arg_to_optional_bsyms.get(variableify(arg)) |
| 93 | + if reshape_bsyms is not None: |
| 94 | + if reshape_bsyms not in appended_bsyms: |
| 95 | + bsyms.extend(reshape_bsyms) |
| 96 | + appended_bsyms[reshape_bsyms] = arg |
94 | 97 | if replaced_args_map := { |
95 | 98 | x.name: swap_map_for_aliases[variableify(x)].name |
96 | 99 | for x in filter(lambda p: isinstance(p, TensorProxy), bsym.flat_args) |
|
0 commit comments