Dedicated docs page for distributed checkpoints (#19287)
awaelchli authored Jan 16, 2024
1 parent 052c0d5 commit a4ecf8d
Showing 10 changed files with 259 additions and 13 deletions.
5 changes: 3 additions & 2 deletions docs/source-fabric/advanced/model_parallel/fsdp.rst
@@ -358,6 +358,7 @@ The resulting checkpoint folder will have this structure:
├── .metadata
├── __0_0.distcp
├── __1_0.distcp
...
└── meta.pt
The “sharded” checkpoint format is the most efficient to save and load in Fabric.
@@ -374,7 +375,7 @@ However, if you prefer to have a single consolidated file instead, you can confi
**Which checkpoint format should I use?**

- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable - you can’t easily load the checkpoint in raw PyTorch (in the future, Lightning will provide utilities to convert the checkpoint though).
- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable. An extra step is needed to :doc:`convert the sharded checkpoint into a regular checkpoint file <../../guide/checkpoint/distributed_checkpoint>`.
- ``state_dict_type="full"``: Use when pre-training small to moderately large models (less than 10B parameters), when fine-tuning, and when portability is required.


@@ -400,7 +401,7 @@ You can easily load checkpoints saved by Fabric to resume training:
Fabric will automatically recognize whether the provided path contains a checkpoint saved with ``state_dict_type="full"`` or ``state_dict_type="sharded"``.
Checkpoints saved with ``state_dict_type="full"`` can be loaded by all strategies, but sharded checkpoints can only be loaded by FSDP.
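
For instance, resuming looks roughly like this regardless of the format (object names are illustrative; ``fabric.load`` restores the state in place into the given objects):

.. code-block:: python

    # Works for both "full" and "sharded" checkpoints saved by Fabric
    state = {"model": model, "optimizer": optimizer, "iteration": 0}
    fabric.load("path/to/checkpoint", state)
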
Read :doc:`the checkpoints guide <../../guide/checkpoint>` to explore more features.
Read :doc:`the checkpoints guide <../../guide/checkpoint/index>` to explore more features.


----
6 changes: 3 additions & 3 deletions docs/source-fabric/api/fabric_methods.rst
@@ -218,7 +218,7 @@ Fabric will handle the saving part correctly, whether running a single device, m
You should pass the model and optimizer objects directly into the dictionary so Fabric can unwrap them and automatically retrieve their *state-dict*.
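
A minimal sketch of that pattern (object and path names are illustrative):

.. code-block:: python

    # Pass the objects themselves so Fabric can retrieve their state-dict
    state = {"model": model, "optimizer": optimizer, "step": step}
    fabric.save("path/to/checkpoint.ckpt", state)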

See also: :doc:`../guide/checkpoint`
See also: :doc:`../guide/checkpoint/index`


load
@@ -248,7 +248,7 @@ Fabric will handle the loading part correctly, whether running a single device,
To load the state of your model or optimizer from a raw PyTorch checkpoint (not saved with Fabric), use :meth:`~lightning.fabric.fabric.Fabric.load_raw` instead.
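
A minimal sketch contrasting the two calls (object and path names are illustrative):

.. code-block:: python

    # Checkpoint saved with fabric.save(): restored in place into the given objects
    state = {"model": model, "optimizer": optimizer}
    fabric.load("path/to/checkpoint.ckpt", state)

    # Raw PyTorch checkpoint not saved with Fabric: load its state-dict directly
    fabric.load_raw("path/to/model.pt", model)
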
See also: :doc:`../guide/checkpoint`
See also: :doc:`../guide/checkpoint/index`


load_raw
@@ -267,7 +267,7 @@ Load the state-dict of a model or optimizer from a raw PyTorch checkpoint not sa
# model.load_state_dict(torch.load("path/to/model.pt"))
See also: :doc:`../guide/checkpoint`
See also: :doc:`../guide/checkpoint/index`


barrier
8 changes: 7 additions & 1 deletion docs/source-fabric/glossary/index.rst
@@ -2,6 +2,12 @@
Glossary
########

.. toctree::
   :maxdepth: 1
   :hidden:

   Checkpoint <../guide/checkpoint/index>


.. raw:: html

@@ -45,7 +51,7 @@ Glossary

.. displayitem::
:header: Checkpoint
:button_link: ../guide/checkpoint.html
:button_link: ../guide/checkpoint/index.html
:col_css: col-md-4

.. displayitem::
docs/source-fabric/guide/checkpoint/checkpoint.rst
@@ -45,7 +45,7 @@ To save the state to the filesystem, pass it to the :meth:`~lightning.fabric.fab
This will unwrap your model and optimizer and automatically convert their ``state_dict`` for you.
Fabric and the underlying strategy will decide in which format your checkpoint gets saved.
For example, ``strategy="ddp"`` saves a single file on rank 0, while ``strategy="fsdp"`` saves multiple files from all ranks.
For example, ``strategy="ddp"`` saves a single file on rank 0, while ``strategy="fsdp"`` :doc:`saves multiple files from all ranks <distributed_checkpoint>`.


----
@@ -85,7 +85,7 @@ If you want to be in complete control of how states get restored, you can omit p
optimizer.load_state_dict(full_checkpoint["optimizer"])
...
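
For completeness, the full pattern that the snippet above is taken from looks roughly like this (a sketch; calling ``fabric.load`` without a ``state`` argument returns the checkpoint as a dictionary):

.. code-block:: python

    # Restore everything by hand from the returned dictionary
    full_checkpoint = fabric.load("path/to/checkpoint.ckpt")
    model.load_state_dict(full_checkpoint["model"])
    optimizer.load_state_dict(full_checkpoint["optimizer"])
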
See also: :doc:`../advanced/model_init`
See also: :doc:`../../advanced/model_init`


From a raw state-dict file
@@ -195,13 +195,19 @@ Here's an example of using a filter when saving a checkpoint:
Next steps
**********

Learn from our template how Fabric's checkpoint mechanism can be integrated into a full Trainer:

.. raw:: html

<div class="display-card-container">
<div class="row">

.. displayitem::
:header: Working with very large models
:description: Save and load very large models efficiently with distributed checkpoints
:button_link: distributed_checkpoint.html
:col_css: col-md-4
:height: 150
:tag: advanced

.. displayitem::
:header: Trainer Template
:description: Take our Fabric Trainer template and customize it for your needs
186 changes: 186 additions & 0 deletions docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst
@@ -0,0 +1,186 @@
##########################################
Saving and Loading Distributed Checkpoints
##########################################

Generally, the bigger your model is, the longer it takes to save a checkpoint to disk.
With distributed checkpoints (sometimes called sharded checkpoints), you can save and load the state of your training script with multiple GPUs or nodes more efficiently, avoiding memory issues.


----


*****************************
Save a distributed checkpoint
*****************************

The distributed checkpoint format is the default when you train with the :doc:`FSDP strategy <../../advanced/model_parallel/fsdp>`.

.. code-block:: python

    import lightning as L
    from lightning.fabric.strategies import FSDPStrategy

    # 1. Select the FSDP strategy
    strategy = FSDPStrategy(
        # Default: sharded/distributed checkpoint
        state_dict_type="sharded",
        # Full checkpoint (not distributed)
        # state_dict_type="full",
    )

    fabric = L.Fabric(devices=2, strategy=strategy, ...)
    fabric.launch()

    ...

    model, optimizer = fabric.setup(model, optimizer)

    # 2. Define model, optimizer, and other training loop state
    state = {"model": model, "optimizer": optimizer, "iter": iteration}

    # DON'T do this (inefficient):
    # state = {"model": model.state_dict(), "optimizer": optimizer.state_dict(), ...}

    # 3. Save using Fabric's method
    fabric.save("path/to/checkpoint/file", state)

    # DON'T do this (inefficient):
    # torch.save(state, "path/to/checkpoint/file")

With ``state_dict_type="sharded"``, each process/GPU will save its own file into a folder at the given path.
This reduces memory peaks and speeds up saving to disk.
The resulting checkpoint folder will have a structure like this (one ``.distcp`` shard file per saving process):
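
.. code-block::

    path/to/checkpoint/file
    ├── .metadata
    ├── __0_0.distcp
    ├── __1_0.distcp
    ...
    └── meta.pt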

.. collapse:: Full example

    .. code-block:: python

        import time

        import torch
        import torch.nn.functional as F

        import lightning as L
        from lightning.fabric.strategies import FSDPStrategy
        from lightning.pytorch.demos import Transformer, WikiText2

        strategy = FSDPStrategy(state_dict_type="sharded")
        fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
        fabric.launch()

        with fabric.rank_zero_first():
            dataset = WikiText2()

        # 1B parameters
        model = Transformer(vocab_size=dataset.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

        model, optimizer = fabric.setup(model, optimizer)
        state = {"model": model, "optimizer": optimizer, "iteration": 0}

        for i in range(10):
            input, target = fabric.to_device(dataset[i])
            output = model(input.unsqueeze(0), target.unsqueeze(0))
            loss = F.nll_loss(output, target.view(-1))
            fabric.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            fabric.print(loss.item())

        fabric.print("Saving checkpoint ...")
        t0 = time.time()
        fabric.save("my-checkpoint.ckpt", state)
        fabric.print(f"Took {time.time() - t0:.2f} seconds.")

Check the contents of the checkpoint folder:

.. code-block:: bash

    ls -a my-checkpoint.ckpt/

.. code-block::

    my-checkpoint.ckpt/
    ├── __0_0.distcp
    ├── __1_0.distcp
    ├── __2_0.distcp
    ├── __3_0.distcp
    └── meta.pt

The ``.distcp`` files contain the tensor shards from each process/GPU. You can see that the size of these files
is roughly 1/4 of the total size of the checkpoint since the script distributes the model across 4 GPUs.
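
As a quick sanity check, you could list the shard sizes like this (a sketch; the folder name matches the example above):

.. code-block:: python

    from pathlib import Path

    # Print the size of each shard in GB
    for shard in sorted(Path("my-checkpoint.ckpt").glob("*.distcp")):
        print(shard.name, f"{shard.stat().st_size / 1e9:.2f} GB")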


----


*****************************
Load a distributed checkpoint
*****************************

You can easily load a distributed checkpoint in Fabric if your script uses :doc:`FSDP <../../advanced/model_parallel/fsdp>`.

.. code-block:: python

    import lightning as L
    from lightning.fabric.strategies import FSDPStrategy

    # 1. Select the FSDP strategy
    fabric = L.Fabric(devices=2, strategy=FSDPStrategy(), ...)
    fabric.launch()

    ...

    model, optimizer = fabric.setup(model, optimizer)

    # 2. Define model, optimizer, and other training loop state
    state = {"model": model, "optimizer": optimizer, "iter": iteration}

    # 3. Load using Fabric's method
    fabric.load("path/to/checkpoint/file", state)

    # DON'T do this (inefficient):
    # model.load_state_dict(torch.load("path/to/checkpoint/file"))

Note that you can load the distributed checkpoint even if the world size has changed, i.e., you are running on a different number of GPUs than when you saved the checkpoint.

.. collapse:: Full example

    .. code-block:: python

        import torch

        import lightning as L
        from lightning.fabric.strategies import FSDPStrategy
        from lightning.pytorch.demos import Transformer, WikiText2

        strategy = FSDPStrategy(state_dict_type="sharded")
        fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
        fabric.launch()

        with fabric.rank_zero_first():
            dataset = WikiText2()

        # 1B parameters
        model = Transformer(vocab_size=dataset.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

        model, optimizer = fabric.setup(model, optimizer)
        state = {"model": model, "optimizer": optimizer, "iteration": 0}

        fabric.print("Loading checkpoint ...")
        fabric.load("my-checkpoint.ckpt", state)

.. important::

    If you want to load a distributed checkpoint into a script that doesn't use FSDP (or Fabric at all), then you will have to :ref:`convert it to a single-file checkpoint first <Convert dist-checkpoint>`.


----


.. _Convert dist-checkpoint:

********************************
Convert a distributed checkpoint
********************************

Coming soon.
30 changes: 30 additions & 0 deletions docs/source-fabric/guide/checkpoint/index.rst
@@ -0,0 +1,30 @@
###########
Checkpoints
###########

.. raw:: html

<div class="display-card-container">
<div class="row">

.. displayitem::
:header: Save and load model progress
:description: Efficient saving and loading of model weights, training state, hyperparameters and more.
:button_link: checkpoint.html
:col_css: col-md-4
:height: 150
:tag: intermediate

.. displayitem::
:header: Working with very large models
:description: Save and load very large models efficiently with distributed checkpoints
:button_link: distributed_checkpoint.html
:col_css: col-md-4
:height: 150
:tag: advanced


.. raw:: html

</div>
</div>
8 changes: 8 additions & 0 deletions docs/source-fabric/guide/index.rst
@@ -173,6 +173,14 @@ Advanced Topics
:height: 160
:tag: advanced

.. displayitem::
:header: Save and load very large models
:description: Save and load very large models efficiently with distributed checkpoints
:button_link: checkpoint/distributed_checkpoint.html
:col_css: col-md-4
:height: 160
:tag: advanced

.. raw:: html

</div>
9 changes: 9 additions & 0 deletions docs/source-fabric/levels/advanced.rst
@@ -6,6 +6,7 @@
<../advanced/distributed_communication>
<../advanced/multiple_setup>
<../advanced/model_parallel/fsdp>
<../guide/checkpoint/distributed_checkpoint>


###############
@@ -49,6 +50,14 @@ Advanced skills
:height: 170
:tag: advanced

.. displayitem::
:header: Save and load very large models
:description: Save and load very large models efficiently with distributed checkpoints
:button_link: ../guide/checkpoint/distributed_checkpoint.html
:col_css: col-md-4
:height: 170
:tag: advanced

.. raw:: html

</div>
4 changes: 2 additions & 2 deletions docs/source-fabric/levels/intermediate.rst
@@ -5,7 +5,7 @@
<../guide/lightning_module>
<../guide/callbacks>
<../guide/logging>
<../guide/checkpoint>
<../guide/checkpoint/checkpoint>
<../guide/trainer_template>


@@ -45,7 +45,7 @@ Intermediate skills
.. displayitem::
:header: Save and load model progress
:description: Efficient saving and loading of model weights, training state, hyperparameters and more.
:button_link: ../guide/checkpoint.html
:button_link: ../guide/checkpoint/checkpoint.html
:col_css: col-md-4
:height: 180
:tag: intermediate
2 changes: 1 addition & 1 deletion docs/source-pytorch/advanced/model_parallel/fsdp.rst
@@ -403,7 +403,7 @@ The “sharded” checkpoint format is the most efficient to save and load in Li

**Which checkpoint format should I use?**

- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable - you can’t easily load the checkpoint in raw PyTorch (in the future, Lightning will provide utilities to convert the checkpoint though).
- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable. An extra step is needed to convert the sharded checkpoint into a regular checkpoint file.
- ``state_dict_type="full"``: Use when pre-training small to moderately large models (less than 10B parameters), when fine-tuning, and when portability is required.


