
Commit 79fd3d9

[backport] Add documentation around CheckpointManager (#6378)
1 parent c0e522c commit 79fd3d9


docs/spmd.md

Lines changed: 50 additions & 0 deletions
@@ -270,6 +270,56 @@ dist_cp.load_state_dict(
model.load_state_dict(state_dict["model"])
```

#### CheckpointManager
The experimental [CheckpointManager](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/distributed_checkpoint/manager.py#L40)
interface provides a higher-level API over the `torch.distributed.checkpoint`
functions to enable a few key features:

- **Managed checkpoints**: Each checkpoint taken by the `CheckpointManager` is
  identified by the step at which it was taken. All tracked steps are accessible
  through the `CheckpointManager.all_steps` method, and any tracked step can be
  restored using `CheckpointManager.restore`.
- **Asynchronous checkpointing**: Checkpoints taken through the
  `CheckpointManager.save_async` API are written to persistent storage
  asynchronously to unblock training for the duration of the checkpoint. The
  input sharded state_dict is first moved to CPU before the checkpoint is
  dispatched to a background thread.
- **Auto-checkpointing on preemption**: On Cloud TPU, preemptions can be detected
  and a checkpoint taken before the process is terminated. To use this feature,
  ensure your TPU is provisioned through a QueuedResource with
  [Autocheckpointing enabled](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/queued-resources/create#--autocheckpoint-enabled),
  and ensure the `chkpt_on_preemption` parameter is set when constructing the
  CheckpointManager (this option is enabled by default).
- **FSSpec Support**: `CheckpointManager` uses an fsspec storage backend to enable
  checkpointing directly to any fsspec-compatible filesystem, including GCS.

Example usage of the `CheckpointManager` is below:

```python
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

# Create a CheckpointManager to checkpoint every 10 steps into GCS.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 10)

# Select a checkpoint to restore from, and restore if applicable
tracked_steps = chkpt_mgr.all_steps()
if tracked_steps:
    # Choose the highest step
    best_step = max(tracked_steps)
    state_dict = {'model': model.state_dict()}
    chkpt_mgr.restore(best_step, state_dict)
    model.load_state_dict(state_dict['model'])

# Call `save` or `save_async` every step within the train loop. These methods
# return True when a checkpoint is taken.
for step, data in enumerate(dataloader):
    ...
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    if chkpt_mgr.save_async(step, state_dict):
        print(f'Checkpoint taken at step {step}')
```
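
For completeness, a minimal sketch of the synchronous `save` path is below, with
auto-checkpointing on preemption explicitly enabled. The positional path and
save-interval arguments mirror the example above; passing `chkpt_on_preemption`
as a constructor keyword is an assumption based on the parameter name mentioned
earlier (the option is enabled by default, so this is only illustrative).

```python
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

# Checkpoint every 100 steps; keep auto-checkpointing on preemption enabled.
# NOTE: passing `chkpt_on_preemption` as a keyword here is an assumption; the
# option is enabled by default.
chkpt_mgr = CheckpointManager('gs://my-bucket/my-experiment', 100,
                              chkpt_on_preemption=True)

for step, data in enumerate(dataloader):
    ...
    # Unlike `save_async`, `save` blocks training until the checkpoint has been
    # written. Both methods return True when a checkpoint is taken at this step.
    state_dict = {'model': model.state_dict(), 'optim': optim.state_dict()}
    if chkpt_mgr.save(step, state_dict):
        print(f'Synchronous checkpoint taken at step {step}')
```
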
### Virtual Device Optimization

PyTorch/XLA normally transfers tensor data asynchronously from host to device once the tensor is defined. This is to overlap the data transfer with the graph tracing time. However, because GSPMD allows the user to modify the tensor sharding _after_ the tensor has been defined, we need an optimization to prevent unnecessary transfers of tensor data back and forth between host and device. We introduce Virtual Device Optimization, a technique that places the tensor data on a virtual device, SPMD:0, first, before uploading it to the physical devices once all the sharding decisions are finalized. All tensor data in SPMD mode is placed on the virtual device SPMD:0. The virtual device is exposed to the user as an XLA device XLA:0, with the actual shards on physical devices such as TPU:0, TPU:1, etc.
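
As a short, hedged illustration of the behavior described above (the module path for the sharding APIs is assumed here; see the sharding examples earlier in this guide for the canonical imports), a tensor can be created on the XLA device first and annotated with a sharding afterwards, without the data having to round-trip back to the host:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs  # assumed module path for Mesh/mark_sharding

xr.use_spmd()

# The tensor is defined first. In SPMD mode its data is staged on the virtual
# SPMD:0 device rather than being eagerly uploaded to a physical device.
t = torch.randn(8, 4).to(xm.xla_device())

# The sharding is decided after the tensor already exists. Because the data is
# still on the virtual device, no host <-> device transfer is needed to apply it.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))
xs.mark_sharding(t, mesh, ('data', 'model'))
```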
