Repro:

import torch, thunder

def f(a, b):
    return a.exp_() * b.tanh_()

def g(a, _):
    b = a.view(5, 5)
    return a.exp_() * b.tanh_()

def h(a, _):
    b = a[0, 0]
    return a.exp_() * b.tanh_()

for fn in [f, g, h]:
    jf = thunder.jit(fn)
    x = torch.randn(5, 5, device='cuda')
    x_ = x.detach().clone()
    out = jf(x, x[0, 0])
    out_ = fn(x_, x_[0, 0])
    torch.testing.assert_close(out, out_)
    # AssertionError on f, g and h

Found in #2760 (comment) and #2760 (comment) by @beverlylytle.
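In all three functions a and b share storage: in f the caller passes x[0, 0], which is a view into x, while g and h recreate b inside the function as a view of a. An in-place op on one tensor therefore changes the other. A minimal eager-mode sketch of the aliasing behind f (CPU is enough here; the variable names are only illustrative):

import torch

x = torch.randn(5, 5)
a, b = x, x[0, 0]                    # b is a view into x's storage
assert b.data_ptr() == a.data_ptr()  # element [0, 0] sits at the start of the same buffer
a.exp_()                             # also updates b, since they share memory
b.tanh_()                            # also writes back into a[0, 0]
out = a * b
# At [0, 0] both factors already hold the tanh'd value, so out[0, 0] == b * b.
assert torch.allclose(out[0, 0], b * b)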
Trace of f after update_aliases.py:

# Constructed by Update aliases for in-place ops
import thunder
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[5, 5]"
  # b: "cuda:0 f32[]"
  (t6,) = prims.update_aliases((a,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:4: return a.exp_() * b.tanh_()
  t1 = ltorch.exp_(t6)  # t1: "cuda:0 f32[5, 5]"
    # t0 = ltorch.exp(t6)  # t0: "cuda:0 f32[5, 5]"
      # t0 = prims.exp(t6)  # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t6, grad_enabled=True)  # t1: "cuda:0 f32[5, 5]"
  (t7,) = prims.update_aliases((b,))

  # /opt/pytorch/lightning-thunder/tmp/main.py:4: return a.exp_() * b.tanh_()
  t3 = ltorch.tanh_(t7)  # t3: "cuda:0 f32[]"
    # t2 = ltorch.tanh(t7)  # t2: "cuda:0 f32[]"
      # t2 = prims.tanh(t7)  # t2: "cuda:0 f32[]"
    # t3 = prims.copy_(t2, t7, grad_enabled=True)  # t3: "cuda:0 f32[]"
  t5 = ltorch.mul(t1, t3)  # t5: "cuda:0 f32[5, 5]"
    # t4 = prims.broadcast_in_dim(t3, (5, 5), ())  # t4: "cuda:0 f32[5, 5]"
    # t5 = prims.mul(t1, t4)  # t5: "cuda:0 f32[5, 5]"
  return {'output': (t5,), 'flat_args': [t1, t3]}

Trace after fusion:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cuda:0 f32[5, 5]"
  # b: "cuda:0 f32[]"
  (t10,) = update_aliases((a,))
  del a
  [t1] = nvFusion0(t10)
    # t0 = prims.exp(t10)  # t0: "cuda:0 f32[5, 5]"
    # t1 = prims.copy_(t0, t10, grad_enabled=True)  # t1: "cuda:0 f32[5, 5]"
  del t10
  (t11,) = update_aliases((b,))
  del b
  [t3, t5] = nvFusion1(t11, t1)
    # t2 = prims.tanh(t11)  # t2: "cuda:0 f32[]"
    # t3 = prims.copy_(t2, t11, grad_enabled=True)  # t3: "cuda:0 f32[]"
    # t4 = prims.broadcast_in_dim(t3, (5, 5), ())  # t4: "cuda:0 f32[5, 5]"
    # t5 = prims.mul(t1, t4)  # t5: "cuda:0 f32[5, 5]"
  del t11
  return {'output': (t5,), 'flat_args': [t1, t3]}

The problem here is that nvFusion1 does not know that t11 and t1 share memory.
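Because nvFusion1 treats t11 and t1 as unrelated inputs, the multiply inside the fusion can read the stale exp'd value of a[0, 0] instead of the tanh'd value written through the copy into t11. A small numeric sketch of the element that goes wrong, assuming the fusion reads t1 as it was at fusion entry:

import math

x00 = 0.3                                          # any sample value for x[0, 0]
eager = math.tanh(math.exp(x00)) ** 2              # eager: both factors see the tanh_ write
fused = math.exp(x00) * math.tanh(math.exp(x00))   # assumed fused behavior: mul reads the pre-tanh_ value of t1[0, 0]
print(eager, fused)                                # the two disagree, hence the AssertionError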
To make sure that t3 = prims.copy_(t2, t11, grad_enabled=True) completes before t5 = prims.mul(t1, t4), we could insert a prims.update_aliases call before prims.mul. Because prims.update_aliases is unfusible, this forces the mul into a separate fusion that runs after the copy, which fixes all three failures.
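For illustration, the end of the fused trace could then look roughly like the hand-written sketch below; t12 and nvFusion2 are names invented here, not actual compiler output:

  (t11,) = update_aliases((b,))
  [t3] = nvFusion1(t11)
    # t2 = prims.tanh(t11)  # t2: "cuda:0 f32[]"
    # t3 = prims.copy_(t2, t11, grad_enabled=True)  # t3: "cuda:0 f32[]"
  (t12,) = update_aliases((t1,))  # hypothetical: refresh t1 after the aliased write to t11
  [t5] = nvFusion2(t12, t3)
    # t4 = prims.broadcast_in_dim(t3, (5, 5), ())  # t4: "cuda:0 f32[5, 5]"
    # t5 = prims.mul(t12, t4)  # t5: "cuda:0 f32[5, 5]"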
Solutions like this create additional fusion breaks, so we want to minimize the use of prims.update_aliases. Ideally, we would like to make prims.update_aliases a fusible op and let nvFuser handle memory aliases inside its fused region.