Dedicated docs page for distributed checkpoints (#19287)
Showing 10 changed files with 259 additions and 13 deletions.
docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst (186 changes: 186 additions & 0 deletions)

@@ -0,0 +1,186 @@
##########################################
Saving and Loading Distributed Checkpoints
##########################################

Generally, the bigger your model is, the longer it takes to save a checkpoint to disk.
With distributed checkpoints (sometimes called sharded checkpoints), you can save and load the state of your training script with multiple GPUs or nodes more efficiently, avoiding memory issues.


----


*****************************
Save a distributed checkpoint
*****************************

The distributed checkpoint format is the default when you train with the :doc:`FSDP strategy <../../advanced/model_parallel/fsdp>`.

.. code-block:: python

    import lightning as L
    from lightning.fabric.strategies import FSDPStrategy

    # 1. Select the FSDP strategy
    strategy = FSDPStrategy(
        # Default: sharded/distributed checkpoint
        state_dict_type="sharded",
        # Full checkpoint (not distributed)
        # state_dict_type="full",
    )

    fabric = L.Fabric(devices=2, strategy=strategy, ...)
    fabric.launch()

    ...

    model, optimizer = fabric.setup(model, optimizer)

    # 2. Define model, optimizer, and other training loop state
    state = {"model": model, "optimizer": optimizer, "iter": iteration}

    # DON'T do this (inefficient):
    # state = {"model": model.state_dict(), "optimizer": optimizer.state_dict(), ...}

    # 3. Save using Fabric's method
    fabric.save("path/to/checkpoint/file", state)

    # DON'T do this (inefficient):
    # torch.save(state, "path/to/checkpoint/file")

With ``state_dict_type="sharded"``, each process/GPU will save its own file into a folder at the given path.
This reduces memory peaks and speeds up the saving to disk.
The resulting checkpoint folder will have a structure like the following.
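
Sketched here for the two-device snippet above (the shard file names follow the pattern shown in the full example below and may vary with the PyTorch version):

.. code-block:: text

    path/to/checkpoint/file/
    ├── __0_0.distcp
    ├── __1_0.distcp
    └── meta.pt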

.. collapse:: Full example

    .. code-block:: python

        import time

        import torch
        import torch.nn.functional as F

        import lightning as L
        from lightning.fabric.strategies import FSDPStrategy
        from lightning.pytorch.demos import Transformer, WikiText2

        strategy = FSDPStrategy(state_dict_type="sharded")
        fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
        fabric.launch()

        with fabric.rank_zero_first():
            dataset = WikiText2()

        # 1B parameters
        model = Transformer(vocab_size=dataset.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

        model, optimizer = fabric.setup(model, optimizer)
        state = {"model": model, "optimizer": optimizer, "iteration": 0}

        for i in range(10):
            input, target = fabric.to_device(dataset[i])
            output = model(input.unsqueeze(0), target.unsqueeze(0))
            loss = F.nll_loss(output, target.view(-1))
            fabric.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            fabric.print(loss.item())

        fabric.print("Saving checkpoint ...")
        t0 = time.time()
        fabric.save("my-checkpoint.ckpt", state)
        fabric.print(f"Took {time.time() - t0:.2f} seconds.")

Check the contents of the checkpoint folder:

.. code-block:: bash

    ls -a my-checkpoint.ckpt/

.. code-block::

    my-checkpoint.ckpt/
    ├── __0_0.distcp
    ├── __1_0.distcp
    ├── __2_0.distcp
    ├── __3_0.distcp
    └── meta.pt

The ``.distcp`` files contain the tensor shards from each process/GPU. You can see that the size of these files
is roughly 1/4 of the total size of the checkpoint, since the script distributes the model across 4 GPUs.
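
One quick way to verify this is to look at the per-file sizes (assuming a Unix-like shell):

.. code-block:: bash

    du -h my-checkpoint.ckpt/*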


----


*****************************
Load a distributed checkpoint
*****************************

You can easily load a distributed checkpoint in Fabric if your script uses :doc:`FSDP <../../advanced/model_parallel/fsdp>`.

.. code-block:: python

    import lightning as L
    from lightning.fabric.strategies import FSDPStrategy

    # 1. Select the FSDP strategy
    fabric = L.Fabric(devices=2, strategy=FSDPStrategy(), ...)
    fabric.launch()

    ...

    model, optimizer = fabric.setup(model, optimizer)

    # 2. Define model, optimizer, and other training loop state
    state = {"model": model, "optimizer": optimizer, "iter": iteration}

    # 3. Load using Fabric's method
    fabric.load("path/to/checkpoint/file", state)

    # DON'T do this (inefficient):
    # model.load_state_dict(torch.load("path/to/checkpoint/file"))
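
``fabric.load`` restores the objects in ``state`` in place, so a plain value such as ``state["iter"]`` comes back with its saved value. A minimal sketch of resuming the loop from that counter (``max_iters`` and ``train_iteration`` are hypothetical placeholders for your own training loop):

.. code-block:: python

    fabric.load("path/to/checkpoint/file", state)

    # Continue training from where the checkpoint left off
    for it in range(state["iter"], max_iters):
        train_iteration(model, optimizer, it)  # your own training-step logic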

Note that you can load the distributed checkpoint even if the world size has changed, i.e., you are running on a different number of GPUs than when you saved the checkpoint (the full example below loads a checkpoint saved on 4 GPUs onto 2 GPUs).

.. collapse:: Full example

    .. code-block:: python

        import torch

        import lightning as L
        from lightning.fabric.strategies import FSDPStrategy
        from lightning.pytorch.demos import Transformer, WikiText2

        strategy = FSDPStrategy(state_dict_type="sharded")
        fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
        fabric.launch()

        with fabric.rank_zero_first():
            dataset = WikiText2()

        # 1B parameters
        model = Transformer(vocab_size=dataset.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

        model, optimizer = fabric.setup(model, optimizer)
        state = {"model": model, "optimizer": optimizer, "iteration": 0}

        fabric.print("Loading checkpoint ...")
        fabric.load("my-checkpoint.ckpt", state)

.. important::

    If you want to load a distributed checkpoint into a script that doesn't use FSDP (or Fabric at all), then you will have to :ref:`convert it to a single-file checkpoint first <Convert dist-checkpoint>`.


----


.. _Convert dist-checkpoint:

********************************
Convert a distributed checkpoint
********************************

Coming soon.
@@ -0,0 +1,30 @@
###########
Checkpoints
###########

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
    :header: Save and load model progress
    :description: Efficient saving and loading of model weights, training state, hyperparameters and more.
    :button_link: checkpoint.html
    :col_css: col-md-4
    :height: 150
    :tag: intermediate

.. displayitem::
    :header: Working with very large models
    :description: Save and load very large models efficiently with distributed checkpoints.
    :button_link: distributed_checkpoint.html
    :col_css: col-md-4
    :height: 150
    :tag: advanced


.. raw:: html

    </div>
    </div>