update_aliases misses case when function's input views itself #2756

@beverlylytle

The following:

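import torch
import thunder

def f(x, y):
    # Both operations are in-place; y aliases the first two elements of x.
    x.add_(1)
    y.mul_(2)
    return x, y

jf = thunder.jit(f)
x = torch.randn(5)
y = x[:2]  # y is a view of x
jf(x, y)
print(thunder.last_traces(jf))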

yields

def computation(x, y):
  # x: "cpu f32[5]"
  # y: "cpu f32[2]"
  (t12,) = update_aliases((x,))
  del x

  # /home/blytle/scratch/symb_arth.py:67:           x.add_(1)
  t13 = torch.add(t12, 1, alpha=1)  # t13: "cpu f32[5]"
    # t13 = ltorch.add(t12, 1, alpha=1)  # t13: "cpu f32[5]"
      # t13 = prims.add(t12, 1.0)  # t13: "cpu f32[5]"
  t14 = copy_(t13, t12, grad_enabled=True)  # t14: "cpu f32[5]"
  del t13, t12
  (t15,) = update_aliases((y,))
  del y

  # /home/blytle/scratch/symb_arth.py:68:           y.mul_(2)
  t16 = torch.mul(t15, 2)  # t16: "cpu f32[2]"
    # t16 = ltorch.mul(t15, 2)  # t16: "cpu f32[2]"
      # t16 = prims.mul(t15, 2.0)  # t16: "cpu f32[2]"
  t17 = copy_(t16, t15, grad_enabled=True)  # t17: "cpu f32[2]"
  del t16, t15
  return (t14, t17)

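Since y is a view of x, we would expect a single update_aliases call on the group (x, y) rather than a separate call for each tensor. Without this grouping, the trace records no dependency between the x.add_ chain (t12 through t14) and the y.mul_ chain (t15 through t17), so a topological reordering of the in-place operations can silently produce wrong results.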
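To illustrate why the order of the two in-place operations matters, here is a minimal pure-Python sketch. It does not use thunder or torch: a plain list stands in for the storage of x, and y = x[:2] is modelled as the first two slots of that same list. The helper names (add_one_to_x, double_y, run) are made up for this illustration.

```python
# A list plays the role of x's storage; y = x[:2] aliases its first two slots.

def add_one_to_x(buf):
    # Stand-in for x.add_(1): touches every element of the shared buffer.
    for i in range(len(buf)):
        buf[i] += 1

def double_y(buf):
    # Stand-in for y.mul_(2): touches only the two elements that alias x[:2].
    for i in range(2):
        buf[i] *= 2

def run(ops):
    # Apply the operations in the given order to a fresh buffer.
    buf = [1.0, 2.0, 3.0, 4.0, 5.0]
    for op in ops:
        op(buf)
    return buf

program_order = run([add_one_to_x, double_y])
swapped_order = run([double_y, add_one_to_x])
print(program_order)  # [4.0, 6.0, 4.0, 5.0, 6.0]
print(swapped_order)  # [3.0, 5.0, 4.0, 5.0, 6.0]
```

The two orders disagree on exactly the aliased elements, which is the kind of silent discrepancy a reordering could introduce when the trace does not tie the two chains together.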

The function replace_args_with_alias_map, which is supposed to replace occurrences of the view input with the base tensor, ignores the case where a view has a different number of elements than its base. One possible solution would be for replace_args_with_alias_map to collect the view groups formed by the input tensors and return them to insert_alias_updates, where they would serve as the initial value of view_groups.
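As a rough sketch of the "collect the view groups coming from the input tensors" idea, the following groups inputs by a storage key. It is not thunder code: collect_input_view_groups and the tuple records are hypothetical, and the real replace_args_with_alias_map may need to track more than this.

```python
from collections import defaultdict

def collect_input_view_groups(named_inputs, storage_key):
    """Group input names whose tensors report the same storage key.

    Hypothetical helper for illustration; not part of thunder.
    """
    by_storage = defaultdict(list)
    for name, tensor in named_inputs:
        by_storage[storage_key(tensor)].append(name)
    # Only groups with more than one member need a joint alias update.
    return [names for names in by_storage.values() if len(names) > 1]

# Stand-in records instead of real tensors: (storage id, offset, numel).
inputs = [
    ("x", ("storage0", 0, 5)),  # base tensor with 5 elements
    ("y", ("storage0", 0, 2)),  # view of x with only 2 elements
    ("z", ("storage1", 0, 3)),  # unrelated tensor
]
groups = collect_input_view_groups(inputs, storage_key=lambda t: t[0])
print(groups)  # [['x', 'y']]
```

Grouping by storage rather than by element count puts x and y together even though their sizes differ. With real PyTorch tensors the key could plausibly be something like t.untyped_storage().data_ptr(), though that is an assumption about how one might wire it up. The grouping is conservative: two views that share storage but do not overlap would still land in the same group.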
