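The following program, in which `y` is a view of `x`:

```python
import torch
import thunder

def f(x, y):
    x.add_(1)
    y.mul_(2)
    return x, y

jf = thunder.jit(f)
x = torch.randn(5)
y = x[:2]  # y is a view of x with fewer elements than its base
jf(x, y)
print(thunder.last_traces(jf)[-1])  # final execution trace
```

yields the following trace: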
```python
def computation(x, y):
  # x: "cpu f32[5]"
  # y: "cpu f32[2]"
  (t12,) = update_aliases((x,))
  del x
  # /home/blytle/scratch/symb_arth.py:67: x.add_(1)
  t13 = torch.add(t12, 1, alpha=1)  # t13: "cpu f32[5]"
    # t13 = ltorch.add(t12, 1, alpha=1)  # t13: "cpu f32[5]"
      # t13 = prims.add(t12, 1.0)  # t13: "cpu f32[5]"
  t14 = copy_(t13, t12, grad_enabled=True)  # t14: "cpu f32[5]"
  del t13, t12
  (t15,) = update_aliases((y,))
  del y
  # /home/blytle/scratch/symb_arth.py:68: y.mul_(2)
  t16 = torch.mul(t15, 2)  # t16: "cpu f32[2]"
    # t16 = ltorch.mul(t15, 2)  # t16: "cpu f32[2]"
      # t16 = prims.mul(t15, 2.0)  # t16: "cpu f32[2]"
  t17 = copy_(t16, t15, grad_enabled=True)  # t17: "cpu f32[2]"
  del t16, t15
  return (t14, t17)
```
Since `y` is a view of `x`, we would expect `update_aliases` to be called once on the group `(x, y)` rather than on each tensor separately. Without this grouping, nothing in the trace ties the two in-place operations to each other, so a topological reordering pass is free to swap them and can silently produce wrong results.
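To illustrate why the ordering matters, the sketch below uses plain eager PyTorch (no thunder) and runs the same two in-place operations in both orders. `f_reordered` is a hypothetical stand-in for what a reordering pass could produce if it treated `x` and `y` as independent; when `y` aliases `x[:2]`, the two orders give different results.

```python
import torch

def f(x, y):
    x.add_(1)
    y.mul_(2)
    return x, y

def f_reordered(x, y):
    # Hypothetical reordering: same ops, swapped.
    # Only equivalent to f when x and y do not alias.
    y.mul_(2)
    x.add_(1)
    return x, y

base = torch.arange(5, dtype=torch.float32)  # [0, 1, 2, 3, 4]

x1 = base.clone()
f(x1, x1[:2])            # add then mul on the shared prefix
x2 = base.clone()
f_reordered(x2, x2[:2])  # mul then add on the shared prefix

print(x1)  # first two elements are (v + 1) * 2 -> [2, 4, 3, 4, 5]
print(x2)  # first two elements are v * 2 + 1   -> [1, 3, 3, 4, 5]
```

Only the elements covered by the view differ, which is why the error is silent: shapes and dtypes still match.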
The function `replace_args_with_alias_map`, which is supposed to replace occurrences of a view input with its base tensor, ignores the case where the view has a different number of elements than the base. One possible solution would be for `replace_args_with_alias_map` to collect the view groups formed by the input tensors and return them to `insert_alias_updates`, where they would serve as the initial value of `view_groups`.
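A minimal sketch of the first step of that proposal, under stated assumptions: `group_inputs_by_storage` is a hypothetical helper, not a thunder API, that groups tensor inputs by their underlying storage rather than by shape or element count, so `x` and its smaller view `x[:2]` end up in the same group. How thunder actually represents `view_groups` inside `insert_alias_updates` is not shown here; the list-of-index-lists form below is an assumption.

```python
from collections import defaultdict

import torch

def group_inputs_by_storage(args):
    """Hypothetical helper: group tensor args that share a storage.

    Returns a list of index lists, one per group of two or more tensors.
    Element count and shape are ignored on purpose, so a view with fewer
    elements than its base (e.g. x[:2]) is grouped with that base.
    """
    groups = defaultdict(list)
    for i, a in enumerate(args):
        if not isinstance(a, torch.Tensor):
            continue
        storage = a.untyped_storage()
        # Empty storages may all report data_ptr() == 0; skip them so
        # unrelated empty tensors are not grouped together.
        if storage.nbytes() == 0:
            continue
        # A view and its base share one storage, hence one data_ptr.
        groups[storage.data_ptr()].append(i)
    return [idxs for idxs in groups.values() if len(idxs) > 1]

x = torch.randn(5)
y = x[:2]           # view of x with a different number of elements
z = torch.randn(3)  # unrelated tensor
print(group_inputs_by_storage((x, y, z)))  # [[0, 1]]
```

Grouping by storage is conservative: non-overlapping views such as `x[:2]` and `x[3:]` also land in one group, which gives up some reordering freedom but never correctness.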