-
Notifications
You must be signed in to change notification settings - Fork 108
Handle aliasing of viewed input tensors of varying shapes #2760
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
caf8058
6fc9aec
e0c385b
9d24da4
52703b3
0f0592b
5cdb81f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
| from functools import reduce, partial | ||
| import operator | ||
|
|
||
| from thunder.core.compile_data import using_symbolic_values | ||
| import thunder.core.prims as prims | ||
| from thunder.core.proxies import TensorProxy, variableify, unvariableify | ||
| from thunder.core.pytree import tree_flatten | ||
|
|
@@ -55,23 +57,36 @@ def _involves_viewed_args(bsym, viewed): | |
| return any(isinstance(p, TensorProxy) and variableify(p) in viewed for p in bsym.flat_proxy_args) | ||
|
|
||
|
|
||
| def _can_be_reshaped(arg, arg_to_replace): | ||
| # TODO: Fix this once numel for symbolic values is implemented | ||
| if using_symbolic_values(): | ||
| arg_numel = reduce(operator.mul, arg._shape, 1) | ||
| arg_to_replace_numel = reduce(operator.mul, arg_to_replace._shape, 1) | ||
| else: | ||
| arg_numel = arg.numel | ||
| arg_to_replace_numel = arg_to_replace.numel | ||
| return arg_numel == arg_to_replace_numel | ||
|
|
||
|
|
||
| def replace_args_with_alias_map( | ||
| computation_trace: Trace, | ||
| alias_tensor_indices: list[list[int]], | ||
| ) -> tuple[Trace, dict[VariableInterface, TensorProxy]]: | ||
| ) -> tuple[Trace, list[set[VariableInterface]]]: | ||
| if not alias_tensor_indices: | ||
| return computation_trace, {} | ||
beverlylytle marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| bsyms: list[BoundSymbol] = [] | ||
| flat_args, _ = tree_flatten((computation_trace.args, computation_trace.kwargs)) | ||
| swap_map_for_aliases: dict[VariableInterface, TensorProxy] = {} | ||
| arg_to_optional_bsyms: dict[VariableInterface, BoundSymbol] = {} | ||
| view_groups = {} | ||
| for indices in alias_tensor_indices: | ||
| arg = flat_args[indices[0]] | ||
| for idx in filter(lambda idx: idx < len(flat_args), indices[1:]): | ||
| arg_to_replace = flat_args[idx] | ||
| # Skip aliases with different numel (e.g., complex tensor and its real view) | ||
beverlylytle marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # These share storage but have incompatible element counts | ||
| if arg.numel != arg_to_replace.numel: | ||
| if not _can_be_reshaped(arg, arg_to_replace): | ||
beverlylytle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| view_groups.setdefault(variableify(arg), []).append(variableify(arg_to_replace)) | ||
| continue | ||
| reshaped_arg = arg | ||
| if arg_to_replace.shape != arg.shape: | ||
|
|
@@ -111,7 +126,8 @@ def replace_args_with_alias_map( | |
| no_implicit_alias_trace.bound_symbols = bsyms | ||
| str_map = {unvariableify(k).name: v.name for k, v in swap_map_for_aliases.items()} | ||
| no_implicit_alias_trace.set_provenance(TraceProvenance(f"Duplicate alias args using {str_map}")) | ||
| return no_implicit_alias_trace, swap_map_for_aliases | ||
| view_groups = [{k}.union(set(v)) for k, v in view_groups.items() if len(v) != 0] | ||
| return no_implicit_alias_trace, view_groups | ||
|
|
||
|
|
||
| def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]]) -> Trace: | ||
|
|
@@ -123,10 +139,10 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li | |
|
|
||
| # First pass: identify inputs which are views of each other and swap them out with a default, | ||
| # reshaping if necessary. | ||
| computation_trace, _ = replace_args_with_alias_map(computation_trace, alias_tensor_indices) | ||
| computation_trace, input_view_groups = replace_args_with_alias_map(computation_trace, alias_tensor_indices) | ||
|
|
||
| # Second pass: identify views, their originals, and operands involved in inplace ops | ||
| view_groups = [] | ||
| intermediate_view_groups = [] | ||
| inplace_inputs = set() | ||
| for bsym in computation_trace.bound_symbols: | ||
| if _is_inplace_op(bsym) or _is_view_creation_op(bsym): | ||
|
|
@@ -136,17 +152,21 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li | |
| if _is_inplace_op(bsym): | ||
| inplace_inputs.add(in_tensor) | ||
| out_tensors = set() | ||
| for group in view_groups: | ||
| for group in intermediate_view_groups: | ||
|
||
| if in_tensor in group: | ||
| group.update(out_tensors) | ||
| break | ||
| else: | ||
| view_groups.append(out_tensors.union({in_tensor})) | ||
| intermediate_view_groups.append(out_tensors.union({in_tensor})) | ||
|
|
||
| # filter out view groups that don't have any tensors involved in inplace ops | ||
| view_groups = [group for group in view_groups if len(group.intersection(inplace_inputs)) != 0] | ||
| input_view_groups = [group for group in input_view_groups if len(group.intersection(inplace_inputs)) != 0] | ||
| intermediate_view_groups = [ | ||
| group for group in intermediate_view_groups if len(group.intersection(inplace_inputs)) != 0 | ||
| ] | ||
| view_groups = input_view_groups + intermediate_view_groups | ||
| viewed = set(reduce(set.union, view_groups, set())) | ||
| encountered = set() | ||
| encountered = set(reduce(set.union, input_view_groups, set())) | ||
|
|
||
| # Third pass: insert alias updates | ||
| for bsym in computation_trace.bound_symbols: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.