
Commit a0f4144

Merge branch 'master' into docs/compile
2 parents: a248677 + d02009a


53 files changed, +1747 -306 lines

docs/source-fabric/advanced/model_parallel/fsdp.rst

Lines changed: 6 additions & 5 deletions
@@ -12,9 +12,9 @@ Even a single H100 GPU with 80 GB of VRAM (the biggest today) is not enough to t
 The memory consumption for training is generally made up of
 
 1. the model parameters,
-2. the layer activations (forward) and
-3. the gradients (backward).
-4. the optimizer states (e.g., Adam has two additional exponential averages per parameter),
+2. the layer activations (forward),
+3. the gradients (backward) and
+4. the optimizer states (e.g., Adam has two additional exponential averages per parameter).
 
 |
 

@@ -358,6 +358,7 @@ The resulting checkpoint folder will have this structure:
     ├── .metadata
     ├── __0_0.distcp
     ├── __1_0.distcp
+    ...
     └── meta.pt
 
 The “sharded” checkpoint format is the most efficient to save and load in Fabric.

@@ -374,7 +375,7 @@ However, if you prefer to have a single consolidated file instead, you can confi
 
 **Which checkpoint format should I use?**
 
-- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable - you can’t easily load the checkpoint in raw PyTorch (in the future, Lightning will provide utilities to convert the checkpoint though).
+- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable. An extra step is needed to :doc:`convert the sharded checkpoint into a regular checkpoint file <../../guide/checkpoint/distributed_checkpoint>`.
 - ``state_dict_type="full"``: Use when pre-training small to moderately large models (less than 10B parameters), when fine-tuning, and when portability is required.
 
 

@@ -400,7 +401,7 @@ You can easily load checkpoints saved by Fabric to resume training:
 
 Fabric will automatically recognize whether the provided path contains a checkpoint saved with ``state_dict_type="full"`` or ``state_dict_type="sharded"``.
 Checkpoints saved with ``state_dict_type="full"`` can be loaded by all strategies, but sharded checkpoints can only be loaded by FSDP.
-Read :doc:`the checkpoints guide <../../guide/checkpoint>` to explore more features.
+Read :doc:`the checkpoints guide <../../guide/checkpoint/index>` to explore more features.
 
 
 ----
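
As a rough back-of-envelope check of the memory breakdown listed in the first hunk above (the 7B-parameter size, float32 precision, and plain Adam are assumed purely for illustration; activations are excluded because they depend on batch size and architecture):

    # Illustrative estimate of the non-activation training memory, per the list above
    # (assumed: 7e9 parameters, float32, plain Adam)
    num_params = 7e9
    bytes_per_value = 4  # float32

    weights = num_params * bytes_per_value          # 1. model parameters
    gradients = num_params * bytes_per_value        # 3. gradients
    adam_states = 2 * num_params * bytes_per_value  # 4. two exponential averages per parameter

    total_gb = (weights + gradients + adam_states) / 1e9
    print(f"~{total_gb:.0f} GB before activations")  # ~112 GB, already above a single 80 GB GPU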

docs/source-fabric/api/fabric_methods.rst

Lines changed: 3 additions & 3 deletions
@@ -218,7 +218,7 @@ Fabric will handle the saving part correctly, whether running a single device, m
 
 You should pass the model and optimizer objects directly into the dictionary so Fabric can unwrap them and automatically retrieve their *state-dict*.
 
-See also: :doc:`../guide/checkpoint`
+See also: :doc:`../guide/checkpoint/index`
 
 
 load

@@ -248,7 +248,7 @@ Fabric will handle the loading part correctly, whether running a single device,
 
 
 To load the state of your model or optimizer from a raw PyTorch checkpoint (not saved with Fabric), use :meth:`~lightning.fabric.fabric.Fabric.load_raw` instead.
-See also: :doc:`../guide/checkpoint`
+See also: :doc:`../guide/checkpoint/index`
 
 
 load_raw

@@ -267,7 +267,7 @@ Load the state-dict of a model or optimizer from a raw PyTorch checkpoint not sa
     # model.load_state_dict(torch.load("path/to/model.pt"))
 
 
-See also: :doc:`../guide/checkpoint`
+See also: :doc:`../guide/checkpoint/index`
 
 
 barrier
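
The hunks above only change cross-references, but the surrounding guidance is easy to miss in diff form: pass the wrapped objects (not their state-dicts) to ``fabric.save``, and use ``fabric.load_raw`` for checkpoints that were not written by Fabric. A minimal sketch, with a placeholder model, optimizer, and checkpoint path:

    import torch
    import lightning as L

    fabric = L.Fabric()
    fabric.launch()

    model = torch.nn.Linear(32, 2)                            # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    # Pass the objects themselves; Fabric unwraps them and collects their state-dicts
    fabric.save("checkpoint.ckpt", {"model": model, "optimizer": optimizer})

    # For a raw PyTorch checkpoint (not written by fabric.save), use load_raw instead
    fabric.load_raw("path/to/model.pt", model)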

docs/source-fabric/glossary/index.rst

Lines changed: 7 additions & 1 deletion
@@ -2,6 +2,12 @@
 Glossary
 ########
 
+.. toctree::
+    :maxdepth: 1
+    :hidden:
+
+    Checkpoint <../guide/checkpoint/index>
+
 
 .. raw:: html
 

@@ -45,7 +51,7 @@ Glossary
 
 .. displayitem::
     :header: Checkpoint
-    :button_link: ../guide/checkpoint.html
+    :button_link: ../guide/checkpoint/index.html
     :col_css: col-md-4
 
 .. displayitem::

docs/source-fabric/guide/checkpoint.rst renamed to docs/source-fabric/guide/checkpoint/checkpoint.rst

Lines changed: 10 additions & 4 deletions
@@ -45,7 +45,7 @@ To save the state to the filesystem, pass it to the :meth:`~lightning.fabric.fab
 
 This will unwrap your model and optimizer and automatically convert their ``state_dict`` for you.
 Fabric and the underlying strategy will decide in which format your checkpoint gets saved.
-For example, ``strategy="ddp"`` saves a single file on rank 0, while ``strategy="fsdp"`` saves multiple files from all ranks.
+For example, ``strategy="ddp"`` saves a single file on rank 0, while ``strategy="fsdp"`` :doc:`saves multiple files from all ranks <distributed_checkpoint>`.
 
 
 ----

@@ -85,7 +85,7 @@ If you want to be in complete control of how states get restored, you can omit p
     optimizer.load_state_dict(full_checkpoint["optimizer"])
     ...
 
-See also: :doc:`../advanced/model_init`
+See also: :doc:`../../advanced/model_init`
 
 
 From a raw state-dict file

@@ -195,13 +195,19 @@ Here's an example of using a filter when saving a checkpoint:
 Next steps
 **********
 
-Learn from our template how Fabrics checkpoint mechanism can be integrated into a full Trainer:
-
 .. raw:: html
 
     <div class="display-card-container">
     <div class="row">
 
+.. displayitem::
+    :header: Working with very large models
+    :description: Save and load very large models efficiently with distributed checkpoints
+    :button_link: distributed_checkpoint.html
+    :col_css: col-md-4
+    :height: 150
+    :tag: advanced
+
 .. displayitem::
     :header: Trainer Template
     :description: Take our Fabric Trainer template and customize it for your needs
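
The second hunk above touches the manual-restore example; for context, a minimal sketch of that pattern, with a placeholder model and path and assuming the checkpoint was written by ``fabric.save``:

    import torch
    import lightning as L

    fabric = L.Fabric()
    model = torch.nn.Linear(32, 2)                  # placeholder model
    optimizer = torch.optim.Adam(model.parameters())

    # Omitting the `state` argument returns the full checkpoint dict,
    # so each piece can be restored by hand
    full_checkpoint = fabric.load("path/to/checkpoint.ckpt")
    model.load_state_dict(full_checkpoint["model"])
    optimizer.load_state_dict(full_checkpoint["optimizer"])
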
docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst (new file)

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+##########################################
+Saving and Loading Distributed Checkpoints
+##########################################
+
+Generally, the bigger your model is, the longer it takes to save a checkpoint to disk.
+With distributed checkpoints (sometimes called sharded checkpoints), you can save and load the state of your training script with multiple GPUs or nodes more efficiently, avoiding memory issues.
+
+
+----
+
+
+*****************************
+Save a distributed checkpoint
+*****************************
+
+The distributed checkpoint format is the default when you train with the :doc:`FSDP strategy <../../advanced/model_parallel/fsdp>`.
+
+.. code-block:: python
+
+    import lightning as L
+    from lightning.fabric.strategies import FSDPStrategy
+
+    # 1. Select the FSDP strategy
+    strategy = FSDPStrategy(
+        # Default: sharded/distributed checkpoint
+        state_dict_type="sharded",
+        # Full checkpoint (not distributed)
+        # state_dict_type="full",
+    )
+
+    fabric = L.Fabric(devices=2, strategy=strategy, ...)
+    fabric.launch()
+    ...
+    model, optimizer = fabric.setup(model, optimizer)
+
+    # 2. Define model, optimizer, and other training loop state
+    state = {"model": model, "optimizer": optimizer, "iter": iteration}
+
+    # DON'T do this (inefficient):
+    # state = {"model": model.state_dict(), "optimizer": optimizer.state_dict(), ...}
+
+    # 3. Save using Fabric's method
+    fabric.save("path/to/checkpoint/file", state)
+
+    # DON'T do this (inefficient):
+    # torch.save("path/to/checkpoint/file", state)
+
+With ``state_dict_type="sharded"``, each process/GPU will save its own file into a folder at the given path.
+This reduces memory peaks and speeds up the saving to disk.
+
+.. collapse:: Full example
+
+    .. code-block:: python
+
+        import time
+        import torch
+        import torch.nn.functional as F
+
+        import lightning as L
+        from lightning.fabric.strategies import FSDPStrategy
+        from lightning.pytorch.demos import Transformer, WikiText2
+
+        strategy = FSDPStrategy(state_dict_type="sharded")
+        fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy)
+        fabric.launch()
+
+        with fabric.rank_zero_first():
+            dataset = WikiText2()
+
+        # 1B parameters
+        model = Transformer(vocab_size=dataset.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64)
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+
+        model, optimizer = fabric.setup(model, optimizer)
+
+        state = {"model": model, "optimizer": optimizer, "iteration": 0}
+
+        for i in range(10):
+            input, target = fabric.to_device(dataset[i])
+            output = model(input.unsqueeze(0), target.unsqueeze(0))
+            loss = F.nll_loss(output, target.view(-1))
+            fabric.backward(loss)
+            optimizer.step()
+            optimizer.zero_grad()
+            fabric.print(loss.item())
+
+        fabric.print("Saving checkpoint ...")
+        t0 = time.time()
+        fabric.save("my-checkpoint.ckpt", state)
+        fabric.print(f"Took {time.time() - t0:.2f} seconds.")
+
+Check the contents of the checkpoint folder:
+
+.. code-block:: bash
+
+    ls -a my-checkpoint.ckpt/
+
+.. code-block::
+
+    my-checkpoint.ckpt/
+    ├── __0_0.distcp
+    ├── __1_0.distcp
+    ├── __2_0.distcp
+    ├── __3_0.distcp
+    ├── .metadata
+    └── meta.pt
+
+The ``.distcp`` files contain the tensor shards from each process/GPU. You can see that the size of these files
+is roughly 1/4 of the total size of the checkpoint since the script distributes the model across 4 GPUs.
+
+
+----
+
+
+*****************************
+Load a distributed checkpoint
+*****************************
+
+You can easily load a distributed checkpoint in Fabric if your script uses :doc:`FSDP <../../advanced/model_parallel/fsdp>`.
+
+.. code-block:: python
+
+    import lightning as L
+    from lightning.fabric.strategies import FSDPStrategy
+
+    # 1. Select the FSDP strategy
+    fabric = L.Fabric(devices=2, strategy=FSDPStrategy(), ...)
+    fabric.launch()
+    ...
+    model, optimizer = fabric.setup(model, optimizer)
+
+    # 2. Define model, optimizer, and other training loop state
+    state = {"model": model, "optimizer": optimizer, "iter": iteration}
+
+    # 3. Load using Fabric's method
+    fabric.load("path/to/checkpoint/file", state)
+
+    # DON'T do this (inefficient):
+    # model.load_state_dict(torch.load("path/to/checkpoint/file"))
+
+Note that you can load the distributed checkpoint even if the world size has changed, i.e., you are running on a different number of GPUs than when you saved the checkpoint.
+
+.. collapse:: Full example
+
+    .. code-block:: python
+
+        import torch
+
+        import lightning as L
+        from lightning.fabric.strategies import FSDPStrategy
+        from lightning.pytorch.demos import Transformer, WikiText2
+
+        strategy = FSDPStrategy(state_dict_type="sharded")
+        fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
+        fabric.launch()
+
+        with fabric.rank_zero_first():
+            dataset = WikiText2()
+
+        # 1B parameters
+        model = Transformer(vocab_size=dataset.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64)
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+
+        model, optimizer = fabric.setup(model, optimizer)
+
+        state = {"model": model, "optimizer": optimizer, "iteration": 0}
+
+        fabric.print("Loading checkpoint ...")
+        fabric.load("my-checkpoint.ckpt", state)
+
+
+.. important::
+
+    If you want to load a distributed checkpoint into a script that doesn't use FSDP (or Fabric at all), then you will have to :ref:`convert it to a single-file checkpoint first <Convert dist-checkpoint>`.
+
+
+----
+
+
+.. _Convert dist-checkpoint:
+
+********************************
+Convert a distributed checkpoint
+********************************
+
+Coming soon.
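
A possible interim route, not part of this commit and assuming PyTorch >= 2.2, is PyTorch's own distributed-checkpoint conversion helper; it consolidates only the ``.distcp`` tensor shards, so Fabric-specific extras such as ``meta.pt`` are not merged:

    # Hypothetical interim conversion, assuming PyTorch >= 2.2; this is NOT the
    # Lightning utility the docs above refer to. It consolidates the .distcp shards
    # into a single file that torch.load can read.
    from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

    dcp_to_torch_save("my-checkpoint.ckpt", "my-checkpoint.consolidated.pt")
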
docs/source-fabric/guide/checkpoint/index.rst (new file)

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+###########
+Checkpoints
+###########
+
+.. raw:: html
+
+    <div class="display-card-container">
+    <div class="row">
+
+.. displayitem::
+    :header: Save and load model progress
+    :description: Efficient saving and loading of model weights, training state, hyperparameters and more.
+    :button_link: checkpoint.html
+    :col_css: col-md-4
+    :height: 150
+    :tag: intermediate
+
+.. displayitem::
+    :header: Working with very large models
+    :description: Save and load very large models efficiently with distributed checkpoints
+    :button_link: distributed_checkpoint.html
+    :col_css: col-md-4
+    :height: 150
+    :tag: advanced
+
+
+.. raw:: html
+
+    </div>
+    </div>
docs/source-fabric/guide/index.rst

Lines changed: 8 additions & 0 deletions
@@ -173,6 +173,14 @@ Advanced Topics
     :height: 160
     :tag: advanced
 
+.. displayitem::
+    :header: Save and load very large models
+    :description: Save and load very large models efficiently with distributed checkpoints
+    :button_link: checkpoint/distributed_checkpoint.html
+    :col_css: col-md-4
+    :height: 160
+    :tag: advanced
+
 .. raw:: html
 
     </div>

docs/source-fabric/levels/advanced.rst

Lines changed: 9 additions & 0 deletions
@@ -6,6 +6,7 @@
     <../advanced/distributed_communication>
     <../advanced/multiple_setup>
     <../advanced/model_parallel/fsdp>
+    <../guide/checkpoint/distributed_checkpoint>
 
 
 ###############

@@ -49,6 +50,14 @@ Advanced skills
     :height: 170
     :tag: advanced
 
+.. displayitem::
+    :header: Save and load very large models
+    :description: Save and load very large models efficiently with distributed checkpoints
+    :button_link: ../guide/checkpoint/distributed_checkpoint.html
+    :col_css: col-md-4
+    :height: 170
+    :tag: advanced
+
 .. raw:: html
 
     </div>
