Skip to content

Commit ec21d73

Browse files
Handle aliasing of viewed input tensors of varying shapes (#2760)
Co-authored-by: Masato Shinokawa <[email protected]>
1 parent 5599c15 commit ec21d73

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

thunder/core/update_aliases.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from functools import reduce, partial
22

3+
from thunder.core.compile_data import using_symbolic_values
34
import thunder.core.prims as prims
45
from thunder.core.proxies import TensorProxy, variableify, unvariableify
56
from thunder.core.pytree import tree_flatten
@@ -55,23 +56,36 @@ def _involves_viewed_args(bsym, viewed):
5556
return any(isinstance(p, TensorProxy) and variableify(p) in viewed for p in bsym.flat_proxy_args)
5657

5758

59+
def _can_be_reshaped(arg, arg_to_replace):
60+
# TODO: Fix this once numel for symbolic values is implemented
61+
if using_symbolic_values():
62+
arg_numel = arg._numel()
63+
arg_to_replace_numel = arg_to_replace._numel()
64+
else:
65+
arg_numel = arg.numel
66+
arg_to_replace_numel = arg_to_replace.numel
67+
return arg_numel == arg_to_replace_numel
68+
69+
5870
def replace_args_with_alias_map(
5971
computation_trace: Trace,
6072
alias_tensor_indices: list[list[int]],
61-
) -> tuple[Trace, dict[VariableInterface, TensorProxy]]:
73+
) -> tuple[Trace, list[set[VariableInterface]]]:
6274
if not alias_tensor_indices:
63-
return computation_trace, {}
75+
return computation_trace, []
6476
bsyms: list[BoundSymbol] = []
6577
flat_args, _ = tree_flatten((computation_trace.args, computation_trace.kwargs))
6678
swap_map_for_aliases: dict[VariableInterface, TensorProxy] = {}
6779
arg_to_optional_bsyms: dict[VariableInterface, BoundSymbol] = {}
80+
view_groups = {}
6881
for indices in alias_tensor_indices:
6982
arg = flat_args[indices[0]]
7083
for idx in filter(lambda idx: idx < len(flat_args), indices[1:]):
7184
arg_to_replace = flat_args[idx]
72-
# Skip aliases with different numel (e.g., complex tensor and its real view)
85+
# Track aliases with different numel (e.g., complex tensor and its real view)
7386
# These share storage but have incompatible element counts
74-
if arg.numel != arg_to_replace.numel:
87+
if not _can_be_reshaped(arg, arg_to_replace):
88+
view_groups.setdefault(variableify(arg), []).append(variableify(arg_to_replace))
7589
continue
7690
reshaped_arg = arg
7791
if arg_to_replace.shape != arg.shape:
@@ -111,7 +125,8 @@ def replace_args_with_alias_map(
111125
no_implicit_alias_trace.bound_symbols = bsyms
112126
str_map = {unvariableify(k).name: v.name for k, v in swap_map_for_aliases.items()}
113127
no_implicit_alias_trace.set_provenance(TraceProvenance(f"Duplicate alias args using {str_map}"))
114-
return no_implicit_alias_trace, swap_map_for_aliases
128+
view_groups = [{k}.union(set(v)) for k, v in view_groups.items() if len(v) != 0]
129+
return no_implicit_alias_trace, view_groups
115130

116131

117132
def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]]) -> Trace:
@@ -123,10 +138,10 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
123138

124139
# First pass: identify inputs which are views of each other and swap them out with a default,
125140
# reshaping if necessary.
126-
computation_trace, _ = replace_args_with_alias_map(computation_trace, alias_tensor_indices)
141+
computation_trace, view_groups = replace_args_with_alias_map(computation_trace, alias_tensor_indices)
127142

128143
# Second pass: identify views, their originals, and operands involved in inplace ops
129-
view_groups = []
144+
encountered = set().union(*view_groups)
130145
inplace_inputs = set()
131146
for bsym in computation_trace.bound_symbols:
132147
if _is_inplace_op(bsym) or _is_view_creation_op(bsym):
@@ -146,7 +161,6 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
146161
# filter out view groups that don't have any tensors involved in inplace ops
147162
view_groups = [group for group in view_groups if len(group.intersection(inplace_inputs)) != 0]
148163
viewed = set(reduce(set.union, view_groups, set()))
149-
encountered = set()
150164

151165
# Third pass: insert alias updates
152166
for bsym in computation_trace.bound_symbols:

thunder/tests/test_update_aliases.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,3 +517,25 @@ def foo(x):
517517
expected_grad = torch.autograd.grad(expected, c, g)
518518
torch.testing.assert_close(actual_grad_fx, expected_grad)
519519
torch.testing.assert_close(actual_grad_jit, expected_grad)
520+
521+
522+
@instantiate(
523+
dtypes=(dtypes.float32,),
524+
)
525+
def test_aliasing_for_viewed_input_of_different_shapes(executor, device, dtype):
526+
def f(x, y, z):
527+
return x + 2, y.add_(z)
528+
529+
a = make_tensor((2, 3), dtype=dtypes.to_torch_dtype(dtype), device=device)
530+
b = a[0, :]
531+
c = a[1, :]
532+
a_ = a.clone().detach()
533+
b_ = a_[0, :]
534+
c_ = a_[1, :]
535+
jfn = executor.make_callable(f)
536+
actual = jfn(a, b, c)
537+
expected = f(a_, b_, c_)
538+
torch.testing.assert_close(actual, expected)
539+
torch.testing.assert_close(a, a_)
540+
torch.testing.assert_close(b, b_)
541+
torch.testing.assert_close(c, c_)

0 commit comments

Comments
 (0)