Errors on in-place ops on tensor aliases unresolved by proxy substitution #2766

Description

@shino16

Repro:

import torch
import thunder

def f(a, b):
    return a.exp_() * b.tanh_()

def g(a, _):
    b = a.view(5, 5)  # b is a full view of a
    return a.exp_() * b.tanh_()

def h(a, _):
    b = a[0, 0]  # b is a 0-d view into a
    return a.exp_() * b.tanh_()

for fn in [f, g, h]:
    jf = thunder.jit(fn)
    x = torch.randn(5, 5, device='cuda')
    x_ = x.detach().clone()
    out = jf(x, x[0, 0])    # the second argument aliases the first
    out_ = fn(x_, x_[0, 0])

    torch.testing.assert_close(out, out_)
    # AssertionError on f, g, and h

Found in #2760 (comment) and #2760 (comment) by @beverlylytle.

Trace of f after the update_aliases.py pass:

# Constructed by Update aliases for in-place ops
import thunder
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[5, 5]"
  # b: "cuda:0 f32[]"
  (t6,) = prims.update_aliases((a,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:4: 	    return a.exp_() * b.tanh_()
  t1 = ltorch.exp_(t6)  # t1: "cuda:0 f32[5, 5]"
    # t0 = ltorch.exp(t6)  # t0: "cuda:0 f32[5, 5]"
      # t0 = prims.exp(t6)  # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t6, grad_enabled=True)  # t1: "cuda:0 f32[5, 5]"
  (t7,) = prims.update_aliases((b,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:4: 	    return a.exp_() * b.tanh_()
  t3 = ltorch.tanh_(t7)  # t3: "cuda:0 f32[]"
    # t2 = ltorch.tanh(t7)  # t2: "cuda:0 f32[]"
      # t2 = prims.tanh(t7)  # t2: "cuda:0 f32[]"
    # t3 = prims.copy_(t2, t7, grad_enabled=True)  # t3: "cuda:0 f32[]"
  t5 = ltorch.mul(t1, t3)  # t5: "cuda:0 f32[5, 5]"
    # t4 = prims.broadcast_in_dim(t3, (5, 5), ())  # t4: "cuda:0 f32[5, 5]"
    # t5 = prims.mul(t1, t4)  # t5: "cuda:0 f32[5, 5]"
  return {'output': (t5,), 'flat_args': [t1, t3]}

Trace after fusion:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[5, 5]"
  # b: "cuda:0 f32[]"
  (t10,) = update_aliases((a,))
  del a
  [t1] = nvFusion0(t10)
    # t0 = prims.exp(t10)  # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t10, grad_enabled=True)  # t1: "cuda:0 f32[5, 5]"
  del t10
  (t11,) = update_aliases((b,))
  del b
  [t3, t5] = nvFusion1(t11, t1)
    # t2 = prims.tanh(t11)  # t2: "cuda:0 f32[]"
    # t3 = prims.copy_(t2, t11, grad_enabled=True)  # t3: "cuda:0 f32[]"
    # t4 = prims.broadcast_in_dim(t3, (5, 5), ())  # t4: "cuda:0 f32[5, 5]"
    # t5 = prims.mul(t1, t4)  # t5: "cuda:0 f32[5, 5]"
  del t11
  return {'output': (t5,), 'flat_args': [t1, t3]}

The problem here is that nvFusion1 does not know that t11 and t1 share memory.
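
For reference, here is a plain-PyTorch sketch of the aliasing that the fusion cannot see (using the f case, where the second argument is a 0-d view of the first):

import torch

a = torch.randn(5, 5)
b = a[0, 0]     # b is a 0-d view into a's storage

t1 = a.exp_()   # writes exp(a) through a's storage; t1 aliases a
t3 = b.tanh_()  # overwrites element (0, 0) of that same storage
t5 = t1 * t3    # the multiply must see the tanh_ write at (0, 0)

Because nvFusion1 receives t1 and t11 as unrelated inputs, it may read t1 without observing the copy_ into t11 and use a stale value at element (0, 0).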

To make sure that t3 = prims.copy_(t2, t11, grad_enabled=True) completes before t5 = prims.mul(t1, t4), we could insert a prims.update_aliases call right before prims.mul. Because prims.update_aliases is unfusible, the copy would be forced into an earlier fusion region and the bugs above would be fixed, as sketched below.
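
As a rough sketch (the proxy names t8, t9, and t10 are hypothetical), the end of the first trace would then become:

  (t8, t9) = prims.update_aliases((t1, t3))  # unfusible barrier: the in-place copies above must complete first
  t10 = ltorch.mul(t8, t9)  # t10: "cuda:0 f32[5, 5]"
  return {'output': (t10,), 'flat_args': [t1, t3]}

After fusion, the multiply would land in a separate region that runs only after the in-place copies have been written back.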

Such solutions create additional fusion breaks, so we want to minimize the use of prims.update_aliases. Ideally, we would make prims.update_aliases a fusible op and let nvFuser handle memory aliases within its combined region.
