Skip to content

Commit 1a2c7d9

Browse files
shino16beverlylytle
andcommitted
Add tests
Co-authored-by: beverlylytle <[email protected]>
1 parent f0a5681 commit 1a2c7d9

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

thunder/tests/test_update_aliases.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,9 @@ def g(a, b):
328328

329329
@instantiate(
330330
dtypes=NOTHING,
331+
decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
331332
)
332-
def test_aliased_input(executor, device, dtype):
333+
def test_aliased_input(executor, device, dtype, cache):
333334
def f(x, y, z):
334335
return y.exp_().add(x) + z.exp()
335336

@@ -339,7 +340,7 @@ def f(x, y, z):
339340
a_ = a.clone().detach()
340341
b_ = b.clone().detach()
341342
c_ = c.clone().detach()
342-
jfn = executor.make_callable(f)
343+
jfn = executor.make_callable(f, cache=cache)
343344
actual = jfn(a, b, c)
344345
expected = f(a_, b_, c_)
345346
torch.testing.assert_close(actual, expected)
@@ -350,8 +351,9 @@ def f(x, y, z):
350351

351352
@instantiate(
352353
dtypes=NOTHING,
354+
decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
353355
)
354-
def test_write_to_intermediate_result(executor, device, dtype):
356+
def test_write_to_intermediate_result(executor, device, dtype, cache):
355357
if executor == nvFuserExecutor:
356358
pytest.xfail("nvFuser does not support writing to intermediate results")
357359

@@ -361,7 +363,7 @@ def fn(x):
361363
return y
362364

363365
a = make_tensor((2, 3), dtype=torch.float32, device=device)
364-
jfn = executor.make_callable(fn, skip_inplace_alias_updates=True)
366+
jfn = executor.make_callable(fn, cache=cache)
365367
actual = jfn(a)
366368
expected = fn(a)
367369
torch.testing.assert_close(actual, expected)
@@ -517,3 +519,26 @@ def foo(x):
517519
expected_grad = torch.autograd.grad(expected, c, g)
518520
torch.testing.assert_close(actual_grad_fx, expected_grad)
519521
torch.testing.assert_close(actual_grad_jit, expected_grad)
522+
523+
524+
@instantiate(
525+
dtypes=(dtypes.float32,),
526+
decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
527+
)
528+
def test_aliasing_for_viewed_input_of_different_shapes(executor, device, dtype, cache):
529+
def f(x, y, z):
530+
return x + 2, y.add_(z)
531+
532+
a = make_tensor((2, 3), dtype=dtypes.to_torch_dtype(dtype), device=device)
533+
b = a[0, :]
534+
c = a[1, :]
535+
a_ = a.clone().detach()
536+
b_ = a_[0, :]
537+
c_ = a_[1, :]
538+
jfn = executor.make_callable(f, cache=cache)
539+
actual = jfn(a, b, c)
540+
expected = f(a_, b_, c_)
541+
torch.testing.assert_close(actual, expected)
542+
torch.testing.assert_close(a, a_)
543+
torch.testing.assert_close(b, b_)
544+
torch.testing.assert_close(c, c_)

0 commit comments

Comments
 (0)