Commit 7862ca4

example
1 parent a0f4144 commit 7862ca4

1 file changed: +75 -8 lines changed


docs/source-fabric/guide/compile.rst

Lines changed: 75 additions & 8 deletions
@@ -3,9 +3,10 @@ Compile
 #######
 
 Compiling your PyTorch model can result in significant speedups, especially on the latest hardware such as NVIDIA GPUs.
-This guide shows you how to apply `torch.compile` correctly in your code.
+This guide shows you how to apply ``torch.compile`` correctly in your code.
 
 .. note::
+
     This requires PyTorch >= 2.0.
 
 
@@ -22,10 +23,12 @@ Compiling a model in a script together with Fabric is as simple as adding one li
 
     import torch
     import lightning as L
-    from lightning.pytorch.demos import Transformer
 
+    # Set up Fabric
     fabric = L.Fabric(devices=1)
-    model = Transformer(128)
+
+    # Define the model
+    model = ...
 
     # Compile the model
     model = torch.compile(model)
@@ -38,15 +41,78 @@ Compiling a model in a script together with Fabric is as simple as adding one li
 
 You should compile the model **before** calling ``fabric.setup()`` as shown above for an optimal integration with features in Fabric.
 
-The newly added call to ``torch.compile()`` by itself doesn't do much yet. It just wraps the model in a "compiled model".
+The newly added call to ``torch.compile()`` by itself doesn't do much. It just wraps the model in a "compiled model".
 The actual optimization will start when calling ``forward()`` on the model for the first time:
 
 .. code-block:: python
 
-    input = torch.randint(0, 128, (4, 256), device=fabric.device)
-    target = torch.randint(0, 128, (4, 256), device=fabric.device)
+    # 1st execution compiles the model (slow)
+    output = model(input)
+
+    # All future executions will be fast
+    output = model(input)
+    output = model(input)
+    ...
+
+This is important to know when you measure the speed of a compiled model and compare it to a regular model.
+You should always *exclude* the first call to ``forward()`` from your measurements, since it includes the compilation time.
+
+.. collapse:: Full example with benchmark
+
+    Below is an example that measures the speedup you get by compiling a DenseNet vision model.
+
+    .. code-block:: python
+
+        import statistics
+        import torch
+        import torchvision.models as models
+        import lightning as L
+
+
+        @torch.no_grad()
+        def benchmark(model, input, num_iters=10):
+            """Runs the model on the input several times and returns the median execution time."""
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            times = []
+            for _ in range(num_iters):
+                start.record()
+                model(input)
+                end.record()
+                torch.cuda.synchronize()
+                times.append(start.elapsed_time(end) / 1000)
+            return statistics.median(times)
 
-    output = model(input, target) # compiles when `forward()` runs for the first time
+
+        fabric = L.Fabric(accelerator="cuda", devices=1)
+
+        model = models.densenet121() #.to(torch.float32)
+        input = torch.randn(16, 3, 128, 128, device=fabric.device)
+
+        compiled_model = torch.compile(model, mode="reduce-overhead")
+        model = fabric.setup(model)
+        compiled_model = fabric.setup(compiled_model)
+
+        # warmup the compiled model before we benchmark
+        compiled_model(input)
+
+        # Run multiple forward passes and time them
+        eager_time = benchmark(model, input)
+        compile_time = benchmark(compiled_model, input)
+
+        # Compare the speedup for the compiled execution
+        speedup = eager_time / compile_time
+        print(f"Eager median time: {eager_time:.4f} seconds")
+        print(f"Compile median time: {compile_time:.4f} seconds")
+        print(f"Speedup: {speedup:.1f}x")
+
+    On an NVIDIA A100 with PyTorch 2.1.2, CUDA 12.1, we get the following speedup:
+
+    .. code-block:: text
+
+        Eager median time: 0.0151 seconds
+        Compile median time: 0.0056 seconds
+        Speedup: 2.7x
 
 
 ----
@@ -58,7 +124,7 @@ Avoid graph breaks
 
 When ``torch.compile`` looks at the code in your model's ``forward()`` method, it will try to compile as much of the code as possible.
 If there are regions in the code that it doesn't understand, it will introduce a so-called "graph break" that essentially splits the code into optimized and unoptimized parts.
-Graph breaks aren't a deal breaker, since the optimized parts will still run faster.
+Graph breaks aren't a deal breaker, since the optimized parts should still run faster.
 But if you want to get the most out of ``torch.compile``, you might want to invest in rewriting the problematic sections of the code that produce the breaks.
 
 You can check whether your model produces graph breaks by calling ``torch.compile`` with ``fullgraph=True``:
@@ -70,6 +136,7 @@ You can check whether your model produces graph breaks by calling ``torch.compil
 
 The error messages produced here are often quite cryptic.
 
+
 ----
 
 

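The warm-up advice in the added documentation (exclude the first ``forward()`` call from timing) can be illustrated without a GPU. The following is a minimal sketch and not part of the commit: `time_calls` and `LazyStandIn` are hypothetical names, and `LazyStandIn` only simulates a one-time compilation cost with `time.sleep`. It does not call `torch.compile`.

```python
import statistics
import time


def time_calls(fn, num_iters=5):
    """Return (first_call_seconds, median_seconds_of_the_remaining_calls)."""
    durations = []
    for _ in range(num_iters + 1):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    # Report the first call separately: it includes the one-time setup cost
    return durations[0], statistics.median(durations[1:])


class LazyStandIn:
    """Hypothetical stand-in for a compiled model: slow on the first call only."""

    def __init__(self):
        self.ready = False

    def __call__(self):
        if not self.ready:
            time.sleep(0.05)  # assumption: simulates the one-time compilation cost
            self.ready = True
        return sum(range(1000))  # simulates the fast steady-state work


first, steady = time_calls(LazyStandIn())
print(f"First call: {first:.4f} s, median of later calls: {steady:.4f} s")
```

With a real compiled model you would pass a callable such as `lambda: model(input)`. On a GPU you would also need to synchronize before reading the clock, as the `benchmark` function in the diff does with CUDA events.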