
Commit 6fef86e

ManfeiBai and yeounoh authored
[Cherry-pick] Add doc images for auto-sharding (#6323) and Update spmd.md with SPMD debug tool (#6358) (#6395)
Co-authored-by: Yeounoh Chung <[email protected]>
1 parent 79fd3d9 commit 6fef86e

File tree

7 files changed: 28 additions and 0 deletions

docs/assets/gpt2_v4_8_mfu_batch.png
docs/assets/llama2_2b_bsz128.png
docs/assets/perf_auto_vs_manual.png
docs/assets/spmd_debug_1.png
docs/assets/spmd_debug_2.png
docs/spmd.md

docs/spmd.md

Lines changed: 28 additions & 0 deletions
@@ -401,3 +401,31 @@ XLA_USE_SPMD=1 python test/spmd/test_train_spmd_imagenet.py --fake_data --batch_
Note that I used a batch size 4 times as large since I am running it on a TPU v4, which has 4 TPU devices attached to it. You should see the throughput become roughly 4x the non-SPMD run.
### SPMD Debugging Tool
We provide a shard placement visualization debug tool for PyTorch/XLA SPMD users on TPU/GPU/CPU, for both single-host and multi-host runs: use `visualize_tensor_sharding` to visualize a sharded tensor, or use `visualize_sharding` to visualize a sharding string. Here are two code examples on a single-host TPU (v4-8) using `visualize_tensor_sharding` or `visualize_sharding`:
- Code snippet using `visualize_tensor_sharding` and its visualization result:
```python
import rich
import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Assumes the SPMD execution mode is enabled, e.g. via XLA_USE_SPMD=1 or xr.use_spmd()

# Here, mesh is a 2x2 mesh with axes 'x' and 'y' (the 4 devices of a TPU v4-8)
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(num_devices)), (2, 2), ('x', 'y'))

t = torch.randn(8, 4, device='xla')
xs.mark_sharding(t, mesh, ('x', 'y'))

# A tensor's sharding can be visualized using the `visualize_tensor_sharding` method
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding
generated_table = visualize_tensor_sharding(t, use_color=False)
```
![alt_text](assets/spmd_debug_1.png "visualize_tensor_sharding example on TPU v4-8 (single-host)")
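The same table can also be rendered with colors. A minimal variant of the call above (assuming the same `t` from the snippet, and a terminal that can display ANSI colors, since the table is drawn with `rich`) would be:

```python
# Same call as above, but with colored output (illustrative only);
# assumes the `t` and imports from the previous snippet.
generated_table = visualize_tensor_sharding(t, use_color=True)
```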
- Code snippet using `visualize_sharding` and its visualization result:
```python
from torch_xla.distributed.spmd.debugging import visualize_sharding

# An HLO sharding string: a [2,2] tiling over devices 0,1,2,3
sharding = '{devices=[2,2]0,1,2,3}'
generated_table = visualize_sharding(sharding, use_color=False)
```
![alt_text](assets/spmd_debug_2.png "visualize_sharding example on TPU v4-8 (single-host)")
You can use these examples on a single TPU/GPU/CPU host and modify them to run on multiple hosts. You can also modify them for the `tiled`, `partial_replication`, and `replicated` sharding styles.
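To illustrate the other sharding styles, here is a minimal sketch, assuming the same 2x2 `mesh` and imports from the first snippet above (the exact tables rendered will depend on your device topology), that visualizes a partially replicated and a fully replicated tensor:

```python
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

# Partial replication: shard dim 0 over the 'x' axis, replicate along 'y'
# (a `None` entry in the partition spec replicates that dimension).
t_partial = torch.randn(8, 4, device='xla')
xs.mark_sharding(t_partial, mesh, ('x', None))
visualize_tensor_sharding(t_partial, use_color=False)

# Fully replicated: no dimension is sharded, every device holds a full copy.
t_replicated = torch.randn(8, 4, device='xla')
xs.mark_sharding(t_replicated, mesh, (None, None))
visualize_tensor_sharding(t_replicated, use_color=False)
```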

0 commit comments
