38 changes: 29 additions & 9 deletions thunder/core/update_aliases.py
@@ -1,5 +1,7 @@
from functools import reduce, partial
import operator

from thunder.core.compile_data import using_symbolic_values
import thunder.core.prims as prims
from thunder.core.proxies import TensorProxy, variableify, unvariableify
from thunder.core.pytree import tree_flatten
@@ -55,23 +57,36 @@ def _involves_viewed_args(bsym, viewed):
return any(isinstance(p, TensorProxy) and variableify(p) in viewed for p in bsym.flat_proxy_args)


def _can_be_reshaped(arg, arg_to_replace):
# TODO: Fix this once numel for symbolic values is implemented
if using_symbolic_values():
arg_numel = reduce(operator.mul, arg._shape, 1)
arg_to_replace_numel = reduce(operator.mul, arg_to_replace._shape, 1)
else:
arg_numel = arg.numel
arg_to_replace_numel = arg_to_replace.numel
return arg_numel == arg_to_replace_numel


def replace_args_with_alias_map(
computation_trace: Trace,
alias_tensor_indices: list[list[int]],
) -> tuple[Trace, dict[VariableInterface, TensorProxy]]:
) -> tuple[Trace, list[set[VariableInterface]]]:
if not alias_tensor_indices:
return computation_trace, {}
bsyms: list[BoundSymbol] = []
flat_args, _ = tree_flatten((computation_trace.args, computation_trace.kwargs))
swap_map_for_aliases: dict[VariableInterface, TensorProxy] = {}
arg_to_optional_bsyms: dict[VariableInterface, BoundSymbol] = {}
view_groups = {}
for indices in alias_tensor_indices:
arg = flat_args[indices[0]]
for idx in filter(lambda idx: idx < len(flat_args), indices[1:]):
arg_to_replace = flat_args[idx]
# Skip aliases with different numel (e.g., complex tensor and its real view)
# These share storage but have incompatible element counts
if arg.numel != arg_to_replace.numel:
if not _can_be_reshaped(arg, arg_to_replace):
view_groups.setdefault(variableify(arg), []).append(variableify(arg_to_replace))
continue
reshaped_arg = arg
if arg_to_replace.shape != arg.shape:
@@ -111,7 +126,8 @@ def replace_args_with_alias_map(
no_implicit_alias_trace.bound_symbols = bsyms
str_map = {unvariableify(k).name: v.name for k, v in swap_map_for_aliases.items()}
no_implicit_alias_trace.set_provenance(TraceProvenance(f"Duplicate alias args using {str_map}"))
return no_implicit_alias_trace, swap_map_for_aliases
view_groups = [{k}.union(set(v)) for k, v in view_groups.items() if len(v) != 0]
return no_implicit_alias_trace, view_groups


def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[list[int]]) -> Trace:
@@ -123,10 +139,10 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li

# First pass: identify inputs which are views of each other and swap them out with a default,
# reshaping if necessary.
computation_trace, _ = replace_args_with_alias_map(computation_trace, alias_tensor_indices)
computation_trace, input_view_groups = replace_args_with_alias_map(computation_trace, alias_tensor_indices)

# Second pass: identify views, their originals, and operands involved in inplace ops
view_groups = []
intermediate_view_groups = []
inplace_inputs = set()
for bsym in computation_trace.bound_symbols:
if _is_inplace_op(bsym) or _is_view_creation_op(bsym):
@@ -136,17 +152,21 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
if _is_inplace_op(bsym):
inplace_inputs.add(in_tensor)
out_tensors = set()
for group in view_groups:
for group in intermediate_view_groups:
@shino16 (Collaborator) commented on Nov 21, 2025:

We should iterate over the groups in input_view_groups too; otherwise the following makes {x, y} and {y, y2} separate view groups.

import torch, thunder

def f(x, y):
    y2 = y.view(-1)
    x.exp_()
    return y2.sin()

jf = thunder.jit(f)

x = torch.randn(3, device="cuda")
x_ref = x.clone().detach()
z = jf(x, x[0])
z_ref = f(x_ref, x_ref[0])

print(thunder.last_traces(jf)[-1])
torch.testing.assert_close(z, z_ref) # AssertionError

A collaborator commented:

I also thought that we should merge overlapping groups to avoid what you pointed out, but I wasn't able to break the overall trace correctness. I guess that is because I tested on CPU and not on CUDA.

The author (Collaborator) replied:

Yes, this is a good point.
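
For illustration only, a minimal sketch of how overlapping view groups could be merged into one group, as discussed above; the helper name and the use of plain sets are assumptions for the example and are not part of this change:

def merge_overlapping_groups(groups: list[set]) -> list[set]:
    # Hypothetical helper: union every pair of groups that share at least one element.
    merged: list[set] = []
    for group in groups:
        group = set(group)
        # Absorb every already-merged group that overlaps with this one.
        overlapping = [g for g in merged if g & group]
        for g in overlapping:
            merged.remove(g)
            group |= g
        merged.append(group)
    return merged

# With this, the {x, y} and {y, y2} groups from the example above would collapse
# into a single group {x, y, y2}:
# merge_overlapping_groups([{"x", "y"}, {"y", "y2"}]) == [{"x", "y", "y2"}]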

if in_tensor in group:
group.update(out_tensors)
break
else:
view_groups.append(out_tensors.union({in_tensor}))
intermediate_view_groups.append(out_tensors.union({in_tensor}))

# filter out view groups that don't have any tensors involved in inplace ops
view_groups = [group for group in view_groups if len(group.intersection(inplace_inputs)) != 0]
input_view_groups = [group for group in input_view_groups if len(group.intersection(inplace_inputs)) != 0]
intermediate_view_groups = [
group for group in intermediate_view_groups if len(group.intersection(inplace_inputs)) != 0
]
view_groups = input_view_groups + intermediate_view_groups
viewed = set(reduce(set.union, view_groups, set()))
encountered = set()
encountered = set(reduce(set.union, input_view_groups, set()))

# Third pass: insert alias updates
for bsym in computation_trace.bound_symbols:
22 changes: 22 additions & 0 deletions thunder/tests/test_update_aliases.py
@@ -517,3 +517,25 @@ def foo(x):
expected_grad = torch.autograd.grad(expected, c, g)
torch.testing.assert_close(actual_grad_fx, expected_grad)
torch.testing.assert_close(actual_grad_jit, expected_grad)


@instantiate(
dtypes=(dtypes.float32,),
)
def test_aliasing_for_viewed_input_of_different_shapes(executor, device, dtype):
def f(x, y, z):
return x + 2, y.add_(z)

a = make_tensor((2, 3), dtype=dtypes.to_torch_dtype(dtype), device=device)
b = a[0, :]
c = a[1, :]
a_ = a.clone().detach()
b_ = a_[0, :]
c_ = a_[1, :]
jfn = executor.make_callable(f)
actual = jfn(a, b, c)
expected = f(a_, b_, c_)
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(a, a_)
torch.testing.assert_close(b, b_)
torch.testing.assert_close(c, c_)