diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index b65d6944e0..8c2d0f6324 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -1043,12 +1043,13 @@ def full( fd: FusionDefinition, lc_to_nv_map: dict, ) -> Any: + nv_shape = getnv(shape, fd, lc_to_nv_map, inline_number=True) nv_fill_value = getnv(fill_value, fd, lc_to_nv_map) nvdtype = lcdtype_to_nvdtype(dtype) _select_device(fd, device) - return fd.ops.full(shape, nv_fill_value, nvdtype) + return fd.ops.full(nv_shape, nv_fill_value, nvdtype) register_supported(PrimIDs.FULL, full, _full_check) diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index c68a70c559..4d341cfcb5 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -942,6 +942,27 @@ def embedding_fn(inputs): torch.testing.assert_close(out, expected_out) +@instantiate( + executors=(nvFuserExecutor,), + dtypes=NOTHING, +) +def test_full_symbolic_values(executor, device: str, dtype: dtypes.dtype): + def foo(a): + # TODO: 'device=device' doesn't work for "symbolic values" cache policy + # See issue: https://github.com/Lightning-AI/lightning-thunder/issues/1710 + return torch.full(a.shape, 0, device="cuda", dtype=dtype) + + jfoo = thunder.jit(foo, cache="symbolic values") + + for shape in ((2, 3), (3, 2)): + a = torch.randn(shape, device=device) + actual = jfoo(a) + expected = foo(a) + torch.testing.assert_close(actual, expected) + + assert thunder.cache_misses(jfoo) == 1 + + @instantiate( executors=(nvFuserExecutor,), dtypes=NOTHING,