Andrej Karpathy has the [nanochat](https://github.com/karpathy/nanochat) project with the description "The best ChatGPT that $100 can buy". He evolved a model architecture and training setup that reaches the performance of GPT-2 while costing 600 times less than the original OpenAI run from 2019. This is an inspiring example, showing that pretraining experiments can now be available even to individuals without corporate/university backing. Andrej's run took ~3 hours on 8xH100, costing $73.
-I decided to investigate how much LLM pretraining research can be done using the latest TPUs on a tight personal budget - without paying more than we pay for our coding assistants. Google Colab Pro+ has a $50 / month plan that provides 600 credits. These credits can be used to rent GPU/TPU kernels. Supported TPUs are v5e and v6e. Their price in Colab credits is roughly the same[^price], while v6e (Trillium) packs 2x more HBM, and has 4.7x quicker matmuls. We only consider v6e (Trillium) below, but the provided notebook supports v5e as well[^free_v5e].
+I decided to investigate how much LLM pretraining research can be done using the latest TPUs on a tight personal budget - without paying more than we pay for our coding assistants. Google Colab Pro+ has a $50 / month plan that provides 600 credits. These credits can be used to rent GPU/TPU kernels. Supported TPUs are v5e and v6e. Their price in Colab credits is roughly the same[^price], while v6e packs 2x more HBM, and has 4.7x quicker matmuls. We only consider v6e below, but the provided notebook supports v5e as well[^free_v5e].
[^price]: v5e costs 3.14 credits per hour, v6e - 3.71.
[^free_v5e]: The free Colab plan allows using v5e, but not v6e. The free quota is enough for a few short training runs. v5e has less HBM, but enough for ~100M models.
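
The budget implied by these numbers can be worked out in a few lines. This is a rough sketch that assumes all 600 monthly credits go to a single TPU type and that nothing else consumes credits; the function name `monthly_budget` is made up for illustration.

```python
# Plan size and credit prices are taken from the text and the [^price] footnote.
MONTHLY_CREDITS = 600        # Colab Pro+ credits per month
MONTHLY_PRICE_USD = 50.0     # Colab Pro+ price per month
CREDITS_PER_HOUR = {"v5e": 3.14, "v6e": 3.71}


def monthly_budget(tpu: str) -> tuple[float, float]:
    """Return (TPU hours per month, USD per TPU hour) for one TPU type.

    Assumption: every credit is spent on this TPU type.
    """
    hours = MONTHLY_CREDITS / CREDITS_PER_HOUR[tpu]
    usd_per_hour = MONTHLY_PRICE_USD / hours
    return hours, usd_per_hour


for tpu in CREDITS_PER_HOUR:
    hours, usd = monthly_budget(tpu)
    print(f"{tpu}: {hours:.0f} hours/month at about ${usd:.2f}/hour")
```

Under these assumptions the plan buys roughly 160 hours of v6e time per month.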
## Back of the envelope calculations
-Here are the v6e (Trillium) performance specs from the [Google Cloud docs](https://docs.cloud.google.com/tpu/docs/v6e).
+Here are the TPU v6e performance specs from the [Google Cloud docs](https://docs.cloud.google.com/tpu/docs/v6e).
| Specification | Values |
|--------|-------|
@@ -125,6 +125,11 @@ Nevertheless first versions of the training only reached 25% MXU usage. I pushed
When I was waiting for a plane, I came up with the following idea: let Claude Code (via Claude Code Web) build [a Colab notebook with a thorough set of TPU performance tests](https://github.com/vorushin/tpuchat/blob/master/05_tpu_perf.ipynb), building the transformer block by block and measuring the MFU of different parts, in different sizes and in various combinations. Start from pure matmuls, then implement and profile individual components, then a single layer, multiple layers, the forward and backward pass, and the optimizer implementation, with each phase independently runnable. Even though the first implementation had a lot of issues, it helped me start seeing MXU usage north of 50%, and I was eventually able to dissect the slow parts and replace them with faster implementations.
+<figcaption>This is what one of the cells at the beginning of the notebook looks like. Seeing high MXU usage was a big relief.</figcaption>
+</figure>
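
For readers who want to reproduce an MFU number by hand, here is a minimal sketch of the arithmetic. It is not the notebook's code: the function names are hypothetical, the wall time is a made-up measurement, and the 918 TFLOPs bf16 peak for v6e is an assumption that should be checked against the spec table.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs of an (m, k) @ (k, n) matmul: one multiply and one add per term."""
    return 2 * m * k * n


def mfu_percent(flops: float, wall_ms: float, peak_tflops: float) -> float:
    """Model FLOPs utilization: achieved TFLOPs as a percentage of peak."""
    achieved_tflops = flops / (wall_ms * 1e-3) / 1e12
    return 100.0 * achieved_tflops / peak_tflops


# Assumption: bf16 peak of a single v6e chip, in TFLOPs.
PEAK_TFLOPS_V6E = 918.0

# Hypothetical measurement: an 8192 x 8192 x 8192 matmul taking 2 ms.
flops = matmul_flops(8192, 8192, 8192)
print(f"MFU: {mfu_percent(flops, wall_ms=2.0, peak_tflops=PEAK_TFLOPS_V6E):.1f}%")
```

The same two functions apply to a full training step once its FLOP count is known; only the FLOP formula changes.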
+
Here are selected results from the benchmark, building up from atoms to the full training step:
| Benchmark | Wall ms | MFU% | Takeaway |
@@ -150,12 +155,16 @@ Here is a short list of things that were important:
* Attention head dimensions were 128; they have to be at least 256 for the TPU v6e, since it multiplies matrices in 256x256 blocks.
* The vanilla attention implementation is slow-ish even at 2k context length; splash (sparse + flash) attention is the fastest.
* Manual implementation of AdamW was compiled into many different XLA programs because of for-loops over parameter leaves; switching to `optax.adamw()` gained ~10pp MXU[^optax].
-* Batch size with the maximum MXU usage was slower than we wanted for the training stability: adding gradient accumulation (using 16 microbatches of size 4) pushed the MXU usage over 50%.
-* Chunked LM head computation helps to reduce HBM usage - otherwise we see multi-GB tensors in the XProf.
+* Batch size with the maximum MXU usage was slower than I wanted for the training stability: adding gradient accumulation (using 16 microbatches of size 4) pushed the MXU usage over 50%.
+* Chunked LM head computation helped to reduce HBM usage - otherwise I saw multi-GB tensors in the XProf.
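
The gradient accumulation item in the list above can be illustrated with a toy example. This is a plain-Python sketch of the idea only, not the notebook's JAX implementation: the scalar model `y = w * x`, the synthetic data, and both function names are assumptions made for illustration. It shows that averaging the gradients of 16 microbatches of size 4 gives the same gradient as one batch of 64 when the loss is a mean over examples.

```python
import random

MICROBATCH_SIZE = 4
NUM_MICROBATCHES = 16  # effective batch size: 4 * 16 = 64


def grad_mse(w, xs, ys):
    """d/dw of mean((w * x - y)^2) over one batch of a toy scalar model."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, microbatch_size):
    """Average per-microbatch gradients instead of using one large batch."""
    total, count = 0.0, 0
    for start in range(0, len(xs), microbatch_size):
        stop = start + microbatch_size
        total += grad_mse(w, xs[start:stop], ys[start:stop])
        count += 1
    return total / count


random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(MICROBATCH_SIZE * NUM_MICROBATCHES)]
ys = [3.0 * x + random.gauss(0.0, 0.1) for x in xs]

w = 0.5
full = grad_mse(w, xs, ys)
accum = accumulated_grad(w, xs, ys, MICROBATCH_SIZE)
print(f"full-batch grad: {full:.6f}, accumulated grad: {accum:.6f}")
```

The equality holds because all microbatches have the same size; with uneven microbatches the average would need to be weighted by microbatch length.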
And in general: splitting the problem into smaller pieces and analyzing them separately speeds up the performance debugging enormously. Another important superpower: looking at XProf and finding where the MXU is idle and why.
-[^optax]: A Python for loop over parameter leaves inside @jax.jit traced as 58 separate XLA programs that couldn't be fused. optax.adamw uses jax.tree.map internally.
+I highly recommend opening <a href="https://colab.research.google.com/github/vorushin/tpuchat/blob/master/05_tpu_perf.ipynb?flush_caches=true">this notebook</a>, clicking through the cells one by one, and seeing if the results match your expectations[^asimov].
+
+[^asimov]: "The most exciting phrase to hear in science, the one that heralds new discoveries, is not 'Eureka!' but 'That's funny...'" - Isaac Asimov.
+
+[^optax]: Do not loop over parameter leaves, use jax.tree.map. It's JAX 101, but CC didn't consider this when porting from nanochat.
### TPU v5e
@@ -169,7 +178,7 @@ Our baseline model is small enough to fit into 16 GB of TPU v5e HBM.
It shows MXU usage of 80.6% when run on TPU v5e vs 51.4% on v6e. The older generation of TPUs has lower arithmetic intensity[^arithmetic_intensity] and is therefore much easier to saturate. The throughput is 3x lower though (TPU v5e does 4.7x fewer matmuls per second). This also means that we will have to be creative in saturating newer generations of accelerators[^tpu_generations].
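
The "3x" figure follows from the numbers in the paragraph above. The sketch below assumes that MXU usage approximates the fraction of peak matmul throughput actually achieved; the function name is made up for illustration.

```python
# Numbers quoted in the text.
PEAK_RATIO_V6E_OVER_V5E = 4.7   # v6e peak matmul throughput relative to v5e
MXU_USAGE_V5E = 0.806
MXU_USAGE_V6E = 0.514


def achieved_speedup(peak_ratio: float, usage_new: float, usage_old: float) -> float:
    """Ratio of achieved throughput, assuming usage is a fraction of peak."""
    return peak_ratio * usage_new / usage_old


speedup = achieved_speedup(PEAK_RATIO_V6E_OVER_V5E, MXU_USAGE_V6E, MXU_USAGE_V5E)
print(f"v6e is about {speedup:.2f}x faster than v5e on this model")
```

So the lower utilization on v6e gives back part of the 4.7x peak advantage, leaving roughly a 3x gain in practice.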
[^arithmetic_intensity]: FLOPs / HBM throughput. [All About Rooflines](https://jax-ml.github.io/scaling-book/roofline/) from the TPU book is a great read.
-[^tpu_generations]: bf16 arithmetic intensity across TPU generations: v5e → v6e grew from 246 to 574 FLOPs/byte (compute 4.7x, bandwidth 2x), while v5p → [Ironwood](https://cloud.google.com/tpu/docs/tpu7x) (v7) grew from 166 to 313 (compute 5x, bandwidth 2.7x).
+[^tpu_generations]: bf16 arithmetic intensity across TPU generations: v5e → v6e grew from 246 to 574 FLOPs/byte (compute 4.7x, bandwidth 2x), while v5p → [v7](https://cloud.google.com/tpu/docs/tpu7x) grew from 166 to 313 (compute 5x, bandwidth 2.7x).
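
The footnote's numbers can be sanity-checked with the definition of arithmetic intensity (FLOPs / HBM throughput): when compute grows by one factor and bandwidth by another, intensity scales by their ratio. This is a small sketch using only the rounded figures quoted above, so it reproduces the stated values to within a few percent rather than exactly; the function names are hypothetical.

```python
def scaled_intensity(old_intensity: float, compute_ratio: float,
                     bandwidth_ratio: float) -> float:
    """New FLOPs/byte after compute and HBM bandwidth each grow by a factor."""
    return old_intensity * compute_ratio / bandwidth_ratio


def is_compute_bound(kernel_flops_per_byte: float, chip_intensity: float) -> bool:
    """Roofline rule of thumb: a kernel can saturate the matmul units only if
    it performs at least as many FLOPs per HBM byte as the chip's intensity."""
    return kernel_flops_per_byte >= chip_intensity


v6e = scaled_intensity(246, compute_ratio=4.7, bandwidth_ratio=2.0)
v7 = scaled_intensity(166, compute_ratio=5.0, bandwidth_ratio=2.7)
print(f"v5e -> v6e: {v6e:.0f} FLOPs/byte (text quotes 574)")
print(f"v5p -> v7:  {v7:.0f} FLOPs/byte (text quotes 313)")

# A hypothetical kernel at 300 FLOPs/byte saturates v5e but not v6e.
print(is_compute_bound(300, 246), is_compute_bound(300, 574))
```

The last line illustrates why the same model shows higher MXU usage on v5e: a kernel that clears the older chip's threshold can still be memory bound on the newer one.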