
Commit 879f8bf

Set TensorProxy shape's history to None if tensor's history is unavailable (#2755)
Co-authored-by: Masato Shinokawa <[email protected]>
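Summary: previously, under symbolic values caching, tensorproxy always built a LOAD_ATTR provenance record from copy.copy(history) for the tensor's shape, and each dimension's IntegerProxy received a BINARY_SUBSCR record rooted in it, even when history was None (plausibly the case for tensors first seen inside traced code, such as the local tensor passed to DTensor.from_local). With this change, each dimension's history is simply None when the tensor's history is unavailable, and a distributed test covers DTensor.from_local under cache="symbolic values".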

2 files changed: +39 −3 lines

thunder/core/proxies.py

Lines changed: 11 additions & 2 deletions
@@ -2011,12 +2011,21 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
     _thunder_fsdp_padding_size = getattr(t, "_thunder_fsdp_padding_size", None)
     # For parameters, shapes should be static.
     if using_symbolic_values() and not isinstance(t, torch.nn.Parameter):
-        shape_attr = ProvenanceRecord(PseudoInst.LOAD_ATTR, inputs=[copy.copy(history), wrap_const("shape").provenance])
+        if history is not None:
+            shape_pr = ProvenanceRecord(
+                PseudoInst.LOAD_ATTR, inputs=[copy.copy(history), wrap_const("shape").provenance]
+            )
+            dim_pr = lambda idx: ProvenanceRecord(
+                PseudoInst.BINARY_SUBSCR, inputs=[shape_pr, wrap_const(idx).provenance]
+            )
+        else:
+            dim_pr = lambda idx: None
+
         shape = tuple(
             IntegerProxy(
                 None,
                 s,
-                history=ProvenanceRecord(PseudoInst.BINARY_SUBSCR, inputs=[shape_attr, wrap_const(idx).provenance]),
+                history=dim_pr(idx),
                 constraint=CONSTRAINT.CONSTRAINABLE,
             )
             for idx, s in enumerate(t.shape)
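To make the intent of the guard explicit: per-dimension provenance now comes from a small factory that degrades to None when the parent history is missing, instead of rooting LOAD_ATTR/BINARY_SUBSCR records in a nonexistent parent. A minimal standalone sketch of that pattern (the Provenance class and make_dim_provenance are illustrative stand-ins, not thunder's real API):

from typing import Callable, Optional

# Hypothetical stand-in for thunder's ProvenanceRecord, for illustration only.
class Provenance:
    def __init__(self, inst: str, inputs: list):
        self.inst = inst
        self.inputs = inputs

def make_dim_provenance(history: Optional[Provenance]) -> Callable[[int], Optional[Provenance]]:
    # No history to derive from: every dimension's provenance is None,
    # rather than a record hanging off a missing parent.
    if history is None:
        return lambda idx: None
    # History available: shape is an attribute lookup on the parent,
    # and each dimension is a subscript into that shape.
    shape_pr = Provenance("LOAD_ATTR", [history, "shape"])
    return lambda idx: Provenance("BINARY_SUBSCR", [shape_pr, idx])

dims = [make_dim_provenance(None)(i) for i in range(3)]
assert dims == [None, None, None]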

thunder/tests/distributed/test_dtensor.py

Lines changed: 28 additions & 1 deletion
@@ -13,7 +13,7 @@
 import thunder
 
 from thunder.tests.distributed.helper import DistributedParallelTestCase
-from torch.distributed._tensor import DeviceMesh, distribute_tensor
+from torch.distributed.tensor import DTensor, DeviceMesh, distribute_tensor
 from torch.distributed.tensor.placement_types import Shard, Replicate
 from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter
 from torch.distributed.tensor.parallel import (
@@ -462,6 +462,33 @@ def test_dtensor_opinfo(self, op: OpInfo, executor):
 
         assert tested_sample_count > 0, f"test_dtensor_opinfo:No samples tested for {op.name} with {executor} executor"
 
+    def test_dtensor_from_local_symbolic_values(self):
+        num_devices = self.world_size
+        mesh = DeviceMesh("cuda", list(range(num_devices)))
+
+        dim_size = 8
+        local_tensor = torch.randn(dim_size, dim_size, device="cuda")
+
+        def fn(x):
+            return DTensor.from_local(x, mesh, [Shard(0)])
+
+        tjit = thunder.jit(fn, cache="symbolic values")
+
+        actual = tjit(local_tensor)
+        expected = DTensor.from_local(local_tensor, mesh, [Shard(0)])
+
+        torch.testing.assert_close(actual, expected)
+        assert thunder.cache_misses(tjit) == 1
+        assert thunder.cache_hits(tjit) == 0
+
+        dim_size = 16
+        local_tensor = torch.randn(dim_size, dim_size, device="cuda")
+        actual = tjit(local_tensor)
+        expected = DTensor.from_local(local_tensor, mesh, [Shard(0)])
+        torch.testing.assert_close(actual, expected)
+        assert thunder.cache_misses(tjit) == 1
+        assert thunder.cache_hits(tjit) == 1
+
 
 common_utils.instantiate_parametrized_tests(DTensorTest)
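The cache assertions are the point of running two shapes: under cache="symbolic values", the first call traces the function (one miss), and a call with a different shape reuses the same symbolic trace (one hit). A minimal single-device sketch of the same counters, assuming only the behavior the test itself relies on:

import torch
import thunder

def f(x):
    return x * 2

jf = thunder.jit(f, cache="symbolic values")

jf(torch.randn(4, 4))    # first shape: trace is built -> one cache miss
jf(torch.randn(16, 16))  # new shape, same symbolic trace -> one cache hit

assert thunder.cache_misses(jf) == 1
assert thunder.cache_hits(jf) == 1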
