Skip to content

Commit 053a6f2

Browse files
authored
[FSDPv2] Add a user guide (#6408)
Summary: This diff adds a user guide for FSDPv2.
1 parent 6fef86e commit 053a6f2

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

docs/fsdpv2.md

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Fully Sharded Data Parallel via SPMD
2+
3+
Fully Sharded Data Parallel via SPMD or FSDPv2 is an utility that re-epxresses the famous FSDP algorithm in SPMD. [This](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/spmd_fully_sharded_data_parallel.py) is
4+
an experimental feature that aiming to offer a familiar interface for users to enjoy all the benefits that SPMD brings into
5+
the table. The design doc is [here](https://github.com/pytorch/xla/issues/6379).
6+
7+
Please review the [SPMD user guide](./spmd.md) before proceeding.
8+
9+
Example usage:
10+
```python3
11+
import torch
12+
import torch_xla.core.xla_model as xm
13+
import torch_xla.distributed.spmd as xs
14+
from torch_xla.experimental.spmd_fully_sharded_data_parallel import SpmdFullyShardedDataParallel as FSDPv2
15+
16+
# Define the mesh following common SPMD practice
17+
num_devices = xr.global_runtime_device_count()
18+
mesh_shape = (num_devices, 1)
19+
device_ids = np.array(range(num_devices))
20+
# To be noted, the mesh must have an axis named 'fsdp', which the weights and activations will be sharded on.
21+
mesh = Mesh(device_ids, mesh_shape, ('fsdp', 'model'))
22+
23+
# Shard the input, and assume x is a 2D tensor.
24+
x = xs.mark_sharding(x, mesh, ('fsdp', None))
25+
26+
# As normal FSDP, but an extra mesh is needed.
27+
model = FSDPv2(my_module, mesh)
28+
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
29+
output = model(x, y)
30+
loss = output.sum()
31+
loss.backward()
32+
optim.step()
33+
```
34+
It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters. The autowrapping
35+
feature will come in the future releases.
36+
37+
## Sharding output
38+
39+
To ensure the XLA compiler correctly implements the FSDP algorithm, we need to shard both weights and activations. This means sharding the output of the forward method. Since the forward function output can vary, we offer shard_output to shard activations in cases where your module output doesn't fall into one of these categories:
40+
1. A single tensor
41+
2. A tuple of tensors where the 0th element is the activation.
42+
43+
Example usage:
44+
```python3
45+
def shard_output(output, mesh):
46+
xs.mark_sharding(output.logits, mesh, ('fsdp', None, None))
47+
48+
model = FSDPv2(my_module, mesh, shard_output)
49+
```
50+
51+
## Gradient checkpointing
52+
53+
Currently, gradient checkpointing needs to be applied to the module before the FSDP wrapper. Otherwise, recursively loop into children modules will end up with infinite loop. We will fix this issue in the future releases.
54+
55+
Example usage:
56+
```python3
57+
from torch_xla.distributed.fsdp import checkpoint_module
58+
59+
model = FSDPv2(checkpoint_module(my_module), mesh)
60+
```
61+
62+
## HuggingFace Llama 2 Example
63+
We have a fork of HF Llama 2 to demonstrate a potential integration [here](https://github.com/huggingface/transformers/compare/main...pytorch-tpu:transformers:llama2-spmd-fsdp).

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,4 @@ test
8787
.. mdinclude:: ../ddp.md
8888
.. mdinclude:: ../gpu.md
8989
.. mdinclude:: ../spmd.md
90+
.. mdinclude:: ../fsdpv2.md

0 commit comments

Comments
 (0)