Compile guide for Trainer #19531
Merged
10 commits:

- e86c461 add guide (awaelchli)
- 29cb4ca update (awaelchli)
- fe63041 delete (awaelchli)
- dce40c3 speed (awaelchli)
- 5b8dad0 update (awaelchli)
- 172add0 Merge branch 'master' into docs/compile (awaelchli)
- e2d54e7 integrate (awaelchli)
- e358820 Merge branch 'master' into docs/compile (awaelchli)
- b4562c2 Apply suggestions from code review (awaelchli)
- ff63886 Port Luca's fixes to Fabric (awaelchli)
#################################
Speed up models by compiling them
#################################

Compiling your LightningModule can result in significant speedups, especially on the latest generations of GPUs.
This guide shows you how to apply ``torch.compile`` correctly in your code.

.. note::

    This requires PyTorch >= 2.0.

----


*******************************************
Apply torch.compile to your LightningModule
*******************************************

Compiling a LightningModule is as simple as adding one line of code, calling :func:`torch.compile`:

.. code-block:: python

    import torch
    import lightning as L

    # Define the model
    model = MyLightningModule()

    # Compile the model
    model = torch.compile(model)

    # Run with the Trainer
    trainer = L.Trainer()
    trainer.fit(model)

.. important::

    You should compile the model **before** calling ``trainer.fit()`` as shown above for optimal integration with the features in the Trainer.

The newly added call to ``torch.compile()`` by itself doesn't do much. It just wraps the model in a "compiled model".
The actual optimization will start when calling the ``forward()`` method for the first time:

.. code-block:: python

    # 1st execution compiles the model (slow)
    output = model(input)

    # All future executions will be fast (for inputs of the same size)
    output = model(input)
    output = model(input)
    ...

**When you pass the LightningModule to the Trainer, it will automatically also compile the** ``*_step()`` **methods.**

When measuring the speed of a compiled model and comparing it to a regular model, it is important to
always exclude the first call to ``forward()``/``*_step()`` from your measurements, since it includes the compilation time.

.. collapse:: Full example with benchmark

    Below is an example that measures the speedup you get when compiling the InceptionV3 from TorchVision.

    .. code-block:: python

        import statistics
        import torch
        import torchvision.models as models
        import lightning as L
        from torch.utils.data import DataLoader


        class MyLightningModule(L.LightningModule):
            def __init__(self):
                super().__init__()
                self.model = models.inception_v3()

            def training_step(self, batch):
                return self.model(batch).logits.sum()

            def train_dataloader(self):
                return DataLoader([torch.randn(3, 512, 512) for _ in range(256)], batch_size=16)

            def configure_optimizers(self):
                return torch.optim.SGD(self.parameters(), lr=0.01)


        class Benchmark(L.Callback):
            """A callback that measures the median execution time between the start and end of a batch."""

            def __init__(self):
                self.start = torch.cuda.Event(enable_timing=True)
                self.end = torch.cuda.Event(enable_timing=True)
                self.times = []

            def median_time(self):
                return statistics.median(self.times)

            def on_train_batch_start(self, trainer, *args, **kwargs):
                self.start.record()

            def on_train_batch_end(self, trainer, *args, **kwargs):
                # Exclude the first iteration to let the model warm up
                if trainer.global_step > 1:
                    self.end.record()
                    torch.cuda.synchronize()
                    self.times.append(self.start.elapsed_time(self.end) / 1000)


        model = MyLightningModule()

        # Compile!
        compiled_model = torch.compile(model)

        # Measure the median iteration time with the uncompiled model
        benchmark = Benchmark()
        trainer = L.Trainer(accelerator="cuda", devices=1, max_steps=10, callbacks=[benchmark])
        trainer.fit(model)
        eager_time = benchmark.median_time()

        # Measure the median iteration time with the compiled model
        benchmark = Benchmark()
        trainer = L.Trainer(accelerator="cuda", devices=1, max_steps=10, callbacks=[benchmark])
        trainer.fit(compiled_model)
        compile_time = benchmark.median_time()

        # Compare the speedup for the compiled execution
        speedup = eager_time / compile_time
        print(f"Eager median time: {eager_time:.4f} seconds")
        print(f"Compile median time: {compile_time:.4f} seconds")
        print(f"Speedup: {speedup:.1f}x")

    On an NVIDIA A100 SXM4 40GB with PyTorch 2.2.0, CUDA 12.1, we get the following speedup:

    .. code-block:: text

        Eager median time: 0.0863 seconds
        Compile median time: 0.0709 seconds
        Speedup: 1.2x

----


******************
Avoid graph breaks
******************

When ``torch.compile`` looks at the code in your model's ``forward()`` or ``*_step()`` method, it will try to compile as much of the code as possible.
If there are regions in the code that it doesn't understand, it will introduce a so-called "graph break" that essentially splits the code into optimized and unoptimized parts.
Graph breaks aren't a deal breaker, since the optimized parts should still run faster.
But if you want to get the most out of ``torch.compile``, you might want to invest in rewriting the problematic sections of the code that produce the breaks.

You can check whether your model produces graph breaks by calling ``torch.compile`` with ``fullgraph=True``:

.. code-block:: python

    # Force an error if there is a graph break in the model
    model = torch.compile(model, fullgraph=True)

Be aware that the error messages produced here are often quite cryptic, so you will likely have to do some `troubleshooting <https://pytorch.org/docs/stable/torch.compiler_troubleshooting.html>`_ to fully optimize your model.
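
If you want a report of the breaks without raising an error, recent PyTorch 2.x releases also provide ``torch._dynamo.explain`` (note the underscore: it is a private API whose calling convention and output have changed between releases). A minimal sketch, where the model and input are placeholders:

.. code-block:: python

    import torch
    import torch._dynamo

    model = MyLightningModule()
    x = torch.randn(4, 3, 512, 512)

    # Trace the model once and print the number of captured graphs,
    # the number of graph breaks, and the reason for each break
    explanation = torch._dynamo.explain(model)(x)
    print(explanation)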
----


*******************
Avoid recompilation
*******************

As mentioned before, the compilation of the model happens the first time you call ``forward()`` or the first time the Trainer calls the ``*_step()`` methods.
At this point, PyTorch will inspect the input tensor(s) and optimize the compiled code for the particular shape, data type and other properties the input has.
If the shape of the input remains the same across all calls, PyTorch will reuse the compiled code it generated and you will get the best speedup.
However, if these properties change across subsequent calls to ``forward()``/``*_step()``, PyTorch will be forced to recompile the model for the new shapes, and this will significantly slow down your training if it happens on every iteration.
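
To confirm whether recompilation is what you are seeing, you can ask PyTorch to log every recompilation event together with its reason. A minimal sketch, assuming a recent PyTorch 2.x release where the ``recompiles`` logging option is available (setting the ``TORCH_LOGS="recompiles"`` environment variable has the same effect):

.. code-block:: python

    import torch._logging

    # Emit a log message with the reason every time compiled code gets recompiled
    torch._logging.set_logs(recompiles=True)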

**When your training suddenly becomes slow, it's probably because PyTorch is recompiling the model!**
Here are some common scenarios when this can happen:

- You are using a dataset with different inputs or shapes for validation than for training, causing a recompilation whenever the Trainer switches between training and validation.
- Your dataset size is not divisible by the batch size, and the dataloader has ``drop_last=False`` (the default).
  The last batch in your training loop will be smaller and trigger a recompilation (see the sketch after this list).
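
A minimal sketch of how to avoid the second scenario, where ``dataset`` is a placeholder for your own data: dropping the last, smaller batch keeps the input shape identical across all iterations.

.. code-block:: python

    from torch.utils.data import DataLoader

    # Discard the final, smaller batch so every batch has exactly the same shape
    train_loader = DataLoader(dataset, batch_size=16, drop_last=True)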

Ideally, you should try to make the input shape(s) to ``forward()`` static.
However, when this is not possible, you can request that PyTorch compiles the code by taking into account possible changes to the input shapes.

.. code-block:: python

    # On PyTorch < 2.2
    model = torch.compile(model, dynamic=True)

A model compiled with ``dynamic=True`` will typically be slower than a model compiled with static shapes, but it will avoid the extreme cost of recompilation every iteration.
On PyTorch 2.2 and later, ``torch.compile`` will detect dynamism automatically and you should no longer need to set this.
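
When only specific dimensions vary, for example the batch dimension, you can also hint this to the compiler on a per-tensor basis instead of making everything dynamic. A minimal sketch using ``torch._dynamo.mark_dynamic``, a private API in PyTorch 2.x that may change between releases (``batch`` is a placeholder input):

.. code-block:: python

    import torch
    import torch._dynamo

    compiled_model = torch.compile(model)

    # Mark dimension 0 (the batch size) of the input as dynamic so the compiler
    # generates shape-generic code for it up front instead of recompiling later
    torch._dynamo.mark_dynamic(batch, 0)
    output = compiled_model(batch)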

----


***********************************
Experiment with compilation options
***********************************

There are optional settings that, depending on your model, can give additional speedups.

**CUDA Graphs:** By enabling CUDA Graphs, CUDA will record all computations in a graph and replay it every time forward and backward is called.
The requirement is that your model must be static, i.e., the input shape must not change and your model must execute the same operations every time.
Enabling CUDA Graphs often results in a significant speedup, but sometimes also increases the memory usage of your model.

.. code-block:: python

    # Enable CUDA Graphs
    compiled_model = torch.compile(model, mode="reduce-overhead")

    # This does the same
    compiled_model = torch.compile(model, options={"triton.cudagraphs": True})

|

**Shape padding:** The specific shape/size of the tensors involved in the computation of your model (input, activations, weights, gradients, etc.) can have an impact on the performance.
With shape padding enabled, ``torch.compile`` can extend the tensors by padding to a size that gives a better memory alignment.
Naturally, the tradeoff here is that it will consume a bit more memory.

.. code-block:: python

    # Default is False
    compiled_model = torch.compile(model, options={"shape_padding": True})

You can find a full list of compile options in the `PyTorch documentation <https://pytorch.org/docs/stable/generated/torch.compile.html>`_.

----

**************************************
A note about torch.compile in practice
**************************************

In practice, you will find that ``torch.compile`` often doesn't work well and can even be counter-productive.
Compilation may fail with cryptic error messages that are impossible to debug without help from the PyTorch team.
It is also not uncommon that ``torch.compile`` will produce a significantly *slower* model or one with much higher memory usage.
On top of that, the compilation phase itself can be incredibly slow, taking several minutes to finish.
For these reasons, we recommend that you don't waste too much time trying to apply ``torch.compile`` during development, and rather evaluate its effectiveness toward the end when you are about to launch long-running, expensive experiments.
Always compare the speed and memory usage of the compiled model against the original model!
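
A minimal sketch of such a memory comparison on a CUDA device, using PyTorch's built-in peak-memory counters (``model`` and ``batch`` are placeholders, and the module's ``forward()`` is assumed to return a tensor):

.. code-block:: python

    import torch

    def peak_memory_gb(model, batch, steps=3):
        """Run a few forward/backward passes and return the peak CUDA memory in GB."""
        torch.cuda.reset_peak_memory_stats()
        for _ in range(steps):
            # The first compiled iteration also includes compilation overhead
            model(batch).sum().backward()
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated() / 1e9

    print(f"Eager peak memory:    {peak_memory_gb(model, batch):.2f} GB")
    print(f"Compiled peak memory: {peak_memory_gb(torch.compile(model), batch):.2f} GB")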

----


***********
Limitations
***********

There are a few limitations you should be aware of when using ``torch.compile`` in conjunction with the Trainer:

* ``torch.compile`` currently does not get reapplied over DDP/FSDP, meaning distributed operations can't benefit from speed-ups at the moment.
  This limitation will be lifted in the future.

* In some cases, using ``self.log()`` in your LightningModule will cause compilation errors.
  Until addressed, you can work around these issues by applying ``torch.compile`` to the submodule(s) of your LightningModule rather than to the entire LightningModule at once.

  .. code-block:: python

      import torch
      import lightning as L


      class MyLightningModule(L.LightningModule):
          def __init__(self):
              super().__init__()
              self.model = MySubModule()
              self.model = torch.compile(self.model)
              ...

  |