Note that I used a batch size 4 times as large, since I am running it on a TPU v4, which has 4 TPU devices attached to it. You should see the throughput become roughly 4x that of the non-SPMD run.
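A minimal sketch of that scaling, assuming `torch_xla.runtime` is used to query the device count (`BASE_BATCH_SIZE` is a hypothetical name, not taken from the original run):

```python
import torch_xla.runtime as xr

BASE_BATCH_SIZE = 32  # hypothetical single-device batch size

# Under SPMD the global batch is split across all addressable devices,
# so scale it by the device count (4 on a TPU v4-8 host, hence the 4x above).
global_batch_size = BASE_BATCH_SIZE * xr.global_runtime_device_count()
```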
### SPMD Debugging Tool
We provide a shard placement visualization debug tool for PyTorch/XLA SPMD users on TPU/GPU/CPU, for both single-host and multi-host setups: you can use `visualize_tensor_sharding` to visualize a sharded tensor, or `visualize_sharding` to visualize a sharding string. Here are two code examples on a TPU single host (v4-8) using `visualize_tensor_sharding` or `visualize_sharding`:
- Code snippet using `visualize_tensor_sharding` and the visualization result:
```python
import torch
import torch_xla.distributed.spmd as xs
import rich
# Here, mesh is a 2x2 mesh with axes 'x' and 'y'
t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))
# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)
```
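- Code snippet using `visualize_sharding` on a sharding string (a minimal sketch, since the original snippet is not shown here; the example spec string and the `use_color=False` argument are assumptions):

```python
from torch_xla.distributed.spmd.debugging import visualize_sharding

# Hypothetical sharding spec string for a tensor tiled over a 2x2 mesh.
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)
```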
You can use these examples on a TPU/GPU/CPU single host and modify them to run on multiple hosts. You can also modify them to use the `tiled`, `partial_replication`, and `replicated` sharding styles, as sketched below.
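For reference, a minimal sketch of the three sharding styles, assuming the same `mesh` and imports as above (the tensor names `t1`/`t2`/`t3` are hypothetical):

```python
t1 = torch.randn(8, 4, device='xla')
t2 = torch.randn(8, 4, device='xla')
t3 = torch.randn(8, 4, device='xla')

# Tiled: every tensor dimension is sharded over a mesh axis.
xs.mark_sharding(t1, mesh, ('x', 'y'))

# Partial replication: dim 0 is sharded over 'x', dim 1 is replicated.
xs.mark_sharding(t2, mesh, ('x', None))

# Replicated: the full tensor is copied to every device.
xs.mark_sharding(t3, mesh, (None, None))

visualize_tensor_sharding(t1)
visualize_tensor_sharding(t2)
visualize_tensor_sharding(t3)
```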