print(f"Eager median time: {eager_time:.4f} seconds")
print(f"Compile median time: {compile_time:.4f} seconds")
print(f"Speedup: {speedup:.1f}x")
On an NVIDIA A100 with PyTorch 2.1.2 and CUDA 12.1, we get the following speedup:

.. code-block:: text

    Eager median time: 0.0151 seconds
    Compile median time: 0.0056 seconds
    Speedup: 2.7x
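The numbers above are medians over several timed runs of the eager and the compiled model. As a rough illustration of how such a measurement can be set up (the model, input, and timing loop below are placeholders, not the benchmark code that produced the numbers above):

.. code-block:: python

    import statistics
    import time

    import torch

    # Placeholder model and input -- substitute your own.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU()).cuda()
    x = torch.randn(64, 1024, device="cuda")
    compiled_model = torch.compile(model)

    def median_time(fn, warmup=3, runs=10):
        # Warm up first; the initial call to the compiled model also triggers compilation.
        for _ in range(warmup):
            fn(x)
        times = []
        for _ in range(runs):
            torch.cuda.synchronize()
            start = time.perf_counter()
            fn(x)
            torch.cuda.synchronize()
            times.append(time.perf_counter() - start)
        return statistics.median(times)

    eager_time = median_time(model)
    compile_time = median_time(compiled_model)
    speedup = eager_time / compile_time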
Avoid graph breaks
------------------
When ``torch.compile`` looks at the code in your model's ``forward()`` method, it will try to compile as much of the code as possible.
If there are regions in the code that it doesn't understand, it will introduce a so-called "graph break" that essentially splits the code into optimized and unoptimized parts.
Graph breaks aren't a deal breaker, since the optimized parts should still run faster.
But if you want to get the most out of ``torch.compile``, you might want to invest in rewriting the problematic sections of the code that produce the breaks.
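As an illustration (a sketch, not code from this tutorial), Python-level control flow that depends on a tensor value is a common cause of graph breaks, and it can often be rewritten with tensor ops so that the whole ``forward()`` stays in a single graph:

.. code-block:: python

    import torch

    class WithGraphBreak(torch.nn.Module):
        def forward(self, x):
            # Branching in Python on a tensor value typically forces a graph break,
            # because the data-dependent branch cannot be captured in the traced graph.
            if x.sum().item() > 0:
                return x * 2
            return x - 1

    class WithoutGraphBreak(torch.nn.Module):
        def forward(self, x):
            # The same logic expressed with tensor ops stays inside one graph.
            return torch.where(x.sum() > 0, x * 2, x - 1)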
You can check whether your model produces graph breaks by calling ``torch.compile`` with ``fullgraph=True``.
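For example (a minimal, self-contained sketch; ``toy_fn`` and its input are placeholders), a function with a data-dependent branch is expected to fail to compile under ``fullgraph=True`` instead of silently falling back to eager execution:

.. code-block:: python

    import torch

    def toy_fn(x):
        # Python branch on a tensor value -- a typical source of graph breaks.
        if x.sum().item() > 0:
            return x * 2
        return x - 1

    compiled_fn = torch.compile(toy_fn, fullgraph=True)
    # Expected to raise an error pointing at the unsupported branch,
    # rather than inserting a graph break and continuing.
    compiled_fn(torch.randn(8))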
The error messages produced here are often quite cryptic.