diff --git a/docs/source-fabric/advanced/model_parallel/fsdp.rst b/docs/source-fabric/advanced/model_parallel/fsdp.rst index 6707f9834d698..df37fac01c723 100644 --- a/docs/source-fabric/advanced/model_parallel/fsdp.rst +++ b/docs/source-fabric/advanced/model_parallel/fsdp.rst @@ -358,6 +358,7 @@ The resulting checkpoint folder will have this structure: ├── .metadata ├── __0_0.distcp ├── __1_0.distcp + ... └── meta.pt The “sharded” checkpoint format is the most efficient to save and load in Fabric. @@ -374,7 +375,7 @@ However, if you prefer to have a single consolidated file instead, you can confi **Which checkpoint format should I use?** -- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable - you can’t easily load the checkpoint in raw PyTorch (in the future, Lightning will provide utilities to convert the checkpoint though). +- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable. An extra step is needed to :doc:`convert the sharded checkpoint into a regular checkpoint file <../../guide/checkpoint/distributed_checkpoint>`. - ``state_dict_type="full"``: Use when pre-training small to moderately large models (less than 10B parameters), when fine-tuning, and when portability is required. @@ -400,7 +401,7 @@ You can easily load checkpoints saved by Fabric to resume training: Fabric will automatically recognize whether the provided path contains a checkpoint saved with ``state_dict_type="full"`` or ``state_dict_type="sharded"``. Checkpoints saved with ``state_dict_type="full"`` can be loaded by all strategies, but sharded checkpoints can only be loaded by FSDP. -Read :doc:`the checkpoints guide <../../guide/checkpoint>` to explore more features. +Read :doc:`the checkpoints guide <../../guide/checkpoint/index>` to explore more features. ---- diff --git a/docs/source-fabric/api/fabric_methods.rst b/docs/source-fabric/api/fabric_methods.rst index 05597b10c7b28..d6e7ccbc38e31 100644 --- a/docs/source-fabric/api/fabric_methods.rst +++ b/docs/source-fabric/api/fabric_methods.rst @@ -218,7 +218,7 @@ Fabric will handle the saving part correctly, whether running a single device, m You should pass the model and optimizer objects directly into the dictionary so Fabric can unwrap them and automatically retrieve their *state-dict*. -See also: :doc:`../guide/checkpoint` +See also: :doc:`../guide/checkpoint/index` load @@ -248,7 +248,7 @@ Fabric will handle the loading part correctly, whether running a single device, To load the state of your model or optimizer from a raw PyTorch checkpoint (not saved with Fabric), use :meth:`~lightning.fabric.fabric.Fabric.load_raw` instead. -See also: :doc:`../guide/checkpoint` +See also: :doc:`../guide/checkpoint/index` load_raw @@ -267,7 +267,7 @@ Load the state-dict of a model or optimizer from a raw PyTorch checkpoint not sa # model.load_state_dict(torch.load("path/to/model.pt")) -See also: :doc:`../guide/checkpoint` +See also: :doc:`../guide/checkpoint/index` barrier diff --git a/docs/source-fabric/glossary/index.rst b/docs/source-fabric/glossary/index.rst index 38b0c2b2faa65..298c08f4e2da5 100644 --- a/docs/source-fabric/glossary/index.rst +++ b/docs/source-fabric/glossary/index.rst @@ -2,6 +2,12 @@ Glossary ######## +.. toctree:: + :maxdepth: 1 + :hidden: + + Checkpoint <../guide/checkpoint/index> + .. raw:: html @@ -45,7 +51,7 @@ Glossary .. 
displayitem::
    :header: Checkpoint
-    :button_link: ../guide/checkpoint.html
+    :button_link: ../guide/checkpoint/index.html
    :col_css: col-md-4

.. displayitem::
diff --git a/docs/source-fabric/guide/checkpoint.rst b/docs/source-fabric/guide/checkpoint/checkpoint.rst
similarity index 94%
rename from docs/source-fabric/guide/checkpoint.rst
rename to docs/source-fabric/guide/checkpoint/checkpoint.rst
index d556734dbbdb0..f644ee99c79b3 100644
--- a/docs/source-fabric/guide/checkpoint.rst
+++ b/docs/source-fabric/guide/checkpoint/checkpoint.rst
@@ -45,7 +45,7 @@ To save the state to the filesystem, pass it to the :meth:`~lightning.fabric.fab
This will unwrap your model and optimizer and automatically convert their ``state_dict`` for you.
Fabric and the underlying strategy will decide in which format your checkpoint gets saved.
-For example, ``strategy="ddp"`` saves a single file on rank 0, while ``strategy="fsdp"`` saves multiple files from all ranks.
+For example, ``strategy="ddp"`` saves a single file on rank 0, while ``strategy="fsdp"`` :doc:`saves multiple files from all ranks <distributed_checkpoint>`.

----

@@ -85,7 +85,7 @@ If you want to be in complete control of how states get restored, you can omit p
    optimizer.load_state_dict(full_checkpoint["optimizer"])
    ...

-See also: :doc:`../advanced/model_init`
+See also: :doc:`../../advanced/model_init`


From a raw state-dict file
@@ -195,13 +195,19 @@ Here's an example of using a filter when saving a checkpoint:

Next steps
**********

-Learn from our template how Fabric's checkpoint mechanism can be integrated into a full Trainer:
-
.. raw:: html

    <div class="display-card-container">
        <div class="row">

+.. displayitem::
+    :header: Working with very large models
+    :description: Save and load very large models efficiently with distributed checkpoints
+    :button_link: distributed_checkpoint.html
+    :col_css: col-md-4
+    :height: 150
+    :tag: advanced
+
.. displayitem::
    :header: Trainer Template
    :description: Take our Fabric Trainer template and customize it for your needs
diff --git a/docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst b/docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst
new file mode 100644
index 0000000000000..d78813d527e60
--- /dev/null
+++ b/docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst
@@ -0,0 +1,186 @@
+##########################################
+Saving and Loading Distributed Checkpoints
+##########################################
+
+Generally, the bigger your model is, the longer it takes to save a checkpoint to disk.
+With distributed checkpoints (sometimes called sharded checkpoints), you can save and load the state of your training script with multiple GPUs or nodes more efficiently, avoiding memory issues.
+
+
+----
+
+
+*****************************
+Save a distributed checkpoint
+*****************************
+
+The distributed checkpoint format is the default when you train with the :doc:`FSDP strategy <../../advanced/model_parallel/fsdp>`.
+
+.. code-block:: python
+
+    import lightning as L
+    from lightning.fabric.strategies import FSDPStrategy
+
+    # 1. Select the FSDP strategy
+    strategy = FSDPStrategy(
+        # Default: sharded/distributed checkpoint
+        state_dict_type="sharded",
+        # Full checkpoint (not distributed)
+        # state_dict_type="full",
+    )
+
+    fabric = L.Fabric(devices=2, strategy=strategy, ...)
+    fabric.launch()
+    ...
+    model, optimizer = fabric.setup(model, optimizer)
+
+    # 2. Define model, optimizer, and other training loop state
+    state = {"model": model, "optimizer": optimizer, "iter": iteration}
+
+    # DON'T do this (inefficient):
+    # state = {"model": model.state_dict(), "optimizer": optimizer.state_dict(), ...}
+
+    # 3. Save using Fabric's method
+    fabric.save("path/to/checkpoint/file", state)
+
+    # DON'T do this (inefficient):
+    # torch.save(state, "path/to/checkpoint/file")
+
+With ``state_dict_type="sharded"``, each process/GPU will save its own file into a folder at the given path.
+This reduces memory peaks and speeds up saving to disk.
+The resulting checkpoint folder will have the structure shown in the full example below.
+
+.. collapse:: Full example
+
+    .. 
code-block:: python + + import time + import torch + import torch.nn.functional as F + + import lightning as L + from lightning.fabric.strategies import FSDPStrategy + from lightning.pytorch.demos import Transformer, WikiText2 + + strategy = FSDPStrategy(state_dict_type="sharded") + fabric = L.Fabric(accelerator="cuda", devices=4, strategy=strategy) + fabric.launch() + + with fabric.rank_zero_first(): + dataset = WikiText2() + + # 1B parameters + model = Transformer(vocab_size=dataset.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64) + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + + model, optimizer = fabric.setup(model, optimizer) + + state = {"model": model, "optimizer": optimizer, "iteration": 0} + + for i in range(10): + input, target = fabric.to_device(dataset[i]) + output = model(input.unsqueeze(0), target.unsqueeze(0)) + loss = F.nll_loss(output, target.view(-1)) + fabric.backward(loss) + optimizer.step() + optimizer.zero_grad() + fabric.print(loss.item()) + + fabric.print("Saving checkpoint ...") + t0 = time.time() + fabric.save("my-checkpoint.ckpt", state) + fabric.print(f"Took {time.time() - t0:.2f} seconds.") + + Check the contents of the checkpoint folder: + + .. code-block:: bash + + ls -a my-checkpoint.ckpt/ + + .. code-block:: + + my-checkpoint.ckpt/ + ├── __0_0.distcp + ├── __1_0.distcp + ├── __2_0.distcp + ├── __3_0.distcp + └── meta.pt + + The ``.distcp`` files contain the tensor shards from each process/GPU. You can see that the size of these files + is roughly 1/4 of the total size of the checkpoint since the script distributes the model across 4 GPUs. + + +---- + + +***************************** +Load a distributed checkpoint +***************************** + +You can easily load a distributed checkpoint in Fabric if your script uses :doc:`FSDP <../../advanced/model_parallel/fsdp>`. + +.. code-block:: python + + import lightning as L + from lightning.fabric.strategies import FSDPStrategy + + # 1. Select the FSDP strategy + fabric = L.Fabric(devices=2, strategy=FSDPStrategy(), ...) + fabric.launch() + ... + model, optimizer = fabric.setup(model, optimizer) + + # 2. Define model, optimizer, and other training loop state + state = {"model": model, "optimizer": optimizer, "iter": iteration} + + # 3. Load using Fabric's method + fabric.load("path/to/checkpoint/file", state) + + # DON'T do this (inefficient): + # model.load_state_dict(torch.load("path/to/checkpoint/file")) + +Note that you can load the distributed checkpoint even if the world size has changed, i.e., you are running on a different number of GPUs than when you saved the checkpoint. + +.. collapse:: Full example + + .. code-block:: python + + import torch + + import lightning as L + from lightning.fabric.strategies import FSDPStrategy + from lightning.pytorch.demos import Transformer, WikiText2 + + strategy = FSDPStrategy(state_dict_type="sharded") + fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy) + fabric.launch() + + with fabric.rank_zero_first(): + dataset = WikiText2() + + # 1B parameters + model = Transformer(vocab_size=dataset.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64) + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + + model, optimizer = fabric.setup(model, optimizer) + + state = {"model": model, "optimizer": optimizer, "iteration": 0} + + fabric.print("Loading checkpoint ...") + fabric.load("my-checkpoint.ckpt", state) + + +.. 
important::
+
+    If you want to load a distributed checkpoint into a script that doesn't use FSDP (or Fabric at all), then you will have to :ref:`convert it to a single-file checkpoint first <Convert dist-checkpoint>`.
+
+
+----
+
+
+.. _Convert dist-checkpoint:
+
+********************************
+Convert a distributed checkpoint
+********************************
+
+Coming soon.
diff --git a/docs/source-fabric/guide/checkpoint/index.rst b/docs/source-fabric/guide/checkpoint/index.rst
new file mode 100644
index 0000000000000..a7f0f046c54cf
--- /dev/null
+++ b/docs/source-fabric/guide/checkpoint/index.rst
@@ -0,0 +1,30 @@
+###########
+Checkpoints
+###########
+
+.. raw:: html
+
+    <div class="display-card-container">
+        <div class="row">
+
+.. displayitem::
+    :header: Save and load model progress
+    :description: Efficient saving and loading of model weights, training state, hyperparameters and more.
+    :button_link: checkpoint.html
+    :col_css: col-md-4
+    :height: 150
+    :tag: intermediate
+
+.. displayitem::
+    :header: Working with very large models
+    :description: Save and load very large models efficiently with distributed checkpoints
+    :button_link: distributed_checkpoint.html
+    :col_css: col-md-4
+    :height: 150
+    :tag: advanced
+
+
+.. raw:: html
+
+        </div>
diff --git a/docs/source-fabric/guide/index.rst b/docs/source-fabric/guide/index.rst index a8c961bb50c65..7b13e8eb4bbc7 100644 --- a/docs/source-fabric/guide/index.rst +++ b/docs/source-fabric/guide/index.rst @@ -173,6 +173,14 @@ Advanced Topics :height: 160 :tag: advanced +.. displayitem:: + :header: Save and load very large models + :description: Save and load very large models efficiently with distributed checkpoints + :button_link: checkpoint/distributed_checkpoint.html + :col_css: col-md-4 + :height: 160 + :tag: advanced + .. raw:: html
diff --git a/docs/source-fabric/levels/advanced.rst b/docs/source-fabric/levels/advanced.rst index b8fb45dd7a933..3760acab2e6da 100644 --- a/docs/source-fabric/levels/advanced.rst +++ b/docs/source-fabric/levels/advanced.rst @@ -6,6 +6,7 @@ <../advanced/distributed_communication> <../advanced/multiple_setup> <../advanced/model_parallel/fsdp> + <../guide/checkpoint/distributed_checkpoint> ############### @@ -49,6 +50,14 @@ Advanced skills :height: 170 :tag: advanced +.. displayitem:: + :header: Save and load very large models + :description: Save and load very large models efficiently with distributed checkpoints + :button_link: ../guide/checkpoint/distributed_checkpoint.html + :col_css: col-md-4 + :height: 170 + :tag: advanced + .. raw:: html
diff --git a/docs/source-fabric/levels/intermediate.rst b/docs/source-fabric/levels/intermediate.rst index 2d2037a00a6fa..f21e7c96608ab 100644 --- a/docs/source-fabric/levels/intermediate.rst +++ b/docs/source-fabric/levels/intermediate.rst @@ -5,7 +5,7 @@ <../guide/lightning_module> <../guide/callbacks> <../guide/logging> - <../guide/checkpoint> + <../guide/checkpoint/checkpoint> <../guide/trainer_template> @@ -45,7 +45,7 @@ Intermediate skills .. displayitem:: :header: Save and load model progress :description: Efficient saving and loading of model weights, training state, hyperparameters and more. - :button_link: ../guide/checkpoint.html + :button_link: ../guide/checkpoint/checkpoint.html :col_css: col-md-4 :height: 180 :tag: intermediate diff --git a/docs/source-pytorch/advanced/model_parallel/fsdp.rst b/docs/source-pytorch/advanced/model_parallel/fsdp.rst index b1c312b927172..cd8d2601a704c 100644 --- a/docs/source-pytorch/advanced/model_parallel/fsdp.rst +++ b/docs/source-pytorch/advanced/model_parallel/fsdp.rst @@ -403,7 +403,7 @@ The “sharded” checkpoint format is the most efficient to save and load in Li **Which checkpoint format should I use?** -- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable - you can’t easily load the checkpoint in raw PyTorch (in the future, Lightning will provide utilities to convert the checkpoint though). +- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable. An extra step is needed to convert the sharded checkpoint into a regular checkpoint file. - ``state_dict_type="full"``: Use when pre-training small to moderately large models (less than 10B parameters), when fine-tuning, and when portability is required.
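Both FSDP guides above note that a sharded checkpoint needs an extra conversion step before it can be loaded without FSDP, and the new guide's "Convert a distributed checkpoint" section is still marked "Coming soon." The sketch below shows one way such a conversion could be done with PyTorch's own distributed-checkpoint utilities; it is an assumption-labeled illustration, not the converter the Lightning docs will eventually describe. It relies on ``torch.distributed.checkpoint.format_utils.dcp_to_torch_save`` (available in recent PyTorch releases), and the paths are placeholders.

.. code-block:: python

    # Hedged sketch: consolidate a sharded/distributed checkpoint folder into a single file.
    # Assumes a recent PyTorch release that provides torch.distributed.checkpoint.format_utils;
    # "my-checkpoint.ckpt" and "my-checkpoint.consolidated.pt" are placeholder paths.
    import torch
    from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

    # The folder produced by fabric.save(...) with state_dict_type="sharded"
    dcp_to_torch_save("my-checkpoint.ckpt", "my-checkpoint.consolidated.pt")

    # The consolidated file can then be loaded with plain PyTorch, without FSDP.
    # Note: extra Fabric metadata (e.g., the meta.pt file) is not merged by this utility.
    checkpoint = torch.load("my-checkpoint.consolidated.pt")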