@@ -328,8 +328,9 @@ def g(a, b):
328328
329329@instantiate (
330330 dtypes = NOTHING ,
331+ decorators = (pytest .mark .parametrize ("cache" , ("constant values" , "symbolic values" )),),
331332)
332- def test_aliased_input (executor , device , dtype ):
333+ def test_aliased_input (executor , device , dtype , cache ):
333334 def f (x , y , z ):
334335 return y .exp_ ().add (x ) + z .exp ()
335336
@@ -339,7 +340,7 @@ def f(x, y, z):
339340 a_ = a .clone ().detach ()
340341 b_ = b .clone ().detach ()
341342 c_ = c .clone ().detach ()
342- jfn = executor .make_callable (f )
343+ jfn = executor .make_callable (f , cache = cache )
343344 actual = jfn (a , b , c )
344345 expected = f (a_ , b_ , c_ )
345346 torch .testing .assert_close (actual , expected )
@@ -350,8 +351,9 @@ def f(x, y, z):
350351
351352@instantiate (
352353 dtypes = NOTHING ,
354+ decorators = (pytest .mark .parametrize ("cache" , ("constant values" , "symbolic values" )),),
353355)
354- def test_write_to_intermediate_result (executor , device , dtype ):
356+ def test_write_to_intermediate_result (executor , device , dtype , cache ):
355357 if executor == nvFuserExecutor :
356358 pytest .xfail ("nvFuser does not support writing to intermediate results" )
357359
@@ -361,7 +363,7 @@ def fn(x):
361363 return y
362364
363365 a = make_tensor ((2 , 3 ), dtype = torch .float32 , device = device )
364- jfn = executor .make_callable (fn , skip_inplace_alias_updates = True )
366+ jfn = executor .make_callable (fn , cache = cache )
365367 actual = jfn (a )
366368 expected = fn (a )
367369 torch .testing .assert_close (actual , expected )
@@ -517,3 +519,26 @@ def foo(x):
517519 expected_grad = torch .autograd .grad (expected , c , g )
518520 torch .testing .assert_close (actual_grad_fx , expected_grad )
519521 torch .testing .assert_close (actual_grad_jit , expected_grad )
522+
523+
524+ @instantiate (
525+ dtypes = (dtypes .float32 ,),
526+ decorators = (pytest .mark .parametrize ("cache" , ("constant values" , "symbolic values" )),),
527+ )
528+ def test_aliasing_for_viewed_input_of_different_shapes (executor , device , dtype , cache ):
529+ def f (x , y , z ):
530+ return x + 2 , y .add_ (z )
531+
532+ a = make_tensor ((2 , 3 ), dtype = dtypes .to_torch_dtype (dtype ), device = device )
533+ b = a [0 , :]
534+ c = a [1 , :]
535+ a_ = a .clone ().detach ()
536+ b_ = a_ [0 , :]
537+ c_ = a_ [1 , :]
538+ jfn = executor .make_callable (f , cache = cache )
539+ actual = jfn (a , b , c )
540+ expected = f (a_ , b_ , c_ )
541+ torch .testing .assert_close (actual , expected )
542+ torch .testing .assert_close (a , a_ )
543+ torch .testing .assert_close (b , b_ )
544+ torch .testing .assert_close (c , c_ )
0 commit comments