|
13 | 13 | import thunder |
14 | 14 |
|
15 | 15 | from thunder.tests.distributed.helper import DistributedParallelTestCase |
16 | | -from torch.distributed._tensor import DeviceMesh, distribute_tensor |
| 16 | +from torch.distributed.tensor import DTensor, DeviceMesh, distribute_tensor |
17 | 17 | from torch.distributed.tensor.placement_types import Shard, Replicate |
18 | 18 | from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter |
19 | 19 | from torch.distributed.tensor.parallel import ( |
@@ -462,6 +462,33 @@ def test_dtensor_opinfo(self, op: OpInfo, executor): |
462 | 462 |
|
463 | 463 | assert tested_sample_count > 0, f"test_dtensor_opinfo:No samples tested for {op.name} with {executor} executor" |
464 | 464 |
|
| 465 | + def test_dtensor_from_local_symbolic_values(self): |
| 466 | + num_devices = self.world_size |
| 467 | + mesh = DeviceMesh("cuda", list(range(num_devices))) |
| 468 | + |
| 469 | + dim_size = 8 |
| 470 | + local_tensor = torch.randn(dim_size, dim_size, device="cuda") |
| 471 | + |
| 472 | + def fn(x): |
| 473 | + return DTensor.from_local(x, mesh, [Shard(0)]) |
| 474 | + |
| 475 | + tjit = thunder.jit(fn, cache="symbolic values") |
| 476 | + |
| 477 | + actual = tjit(local_tensor) |
| 478 | + expected = DTensor.from_local(local_tensor, mesh, [Shard(0)]) |
| 479 | + |
| 480 | + torch.testing.assert_close(actual, expected) |
| 481 | + assert thunder.cache_misses(tjit) == 1 |
| 482 | + assert thunder.cache_hits(tjit) == 0 |
| 483 | + |
| 484 | + dim_size = 16 |
| 485 | + local_tensor = torch.randn(dim_size, dim_size, device="cuda") |
| 486 | + actual = tjit(local_tensor) |
| 487 | + expected = DTensor.from_local(local_tensor, mesh, [Shard(0)]) |
| 488 | + torch.testing.assert_close(actual, expected) |
| 489 | + assert thunder.cache_misses(tjit) == 1 |
| 490 | + assert thunder.cache_hits(tjit) == 1 |
| 491 | + |
465 | 492 |
|
# Expand @parametrize-decorated methods on DTensorTest into concrete test methods.
common_utils.instantiate_parametrized_tests(DTensorTest)
467 | 494 |
|
|
0 commit comments