
Commit 4d0eac3

Edits and fixes in JAX intro lecture (#101)
* misc
* misc
* misc
1 parent 5fd7f38 commit 4d0eac3

1 file changed: lectures/jax_intro.md (+79 / -16 lines)
@@ -94,7 +94,7 @@ from jax.numpy import linalg
 ```

 ```{code-cell} ipython3
-linalg.solve(B, A)
+linalg.inv(B) # Inverse of identity is identity
 ```

 ```{code-cell} ipython3
@@ -104,7 +104,7 @@ linalg.eigh(B) # Computes eigenvalues and eigenvectors
 ### Differences


-One difference between NumPy and JAX is that, when running on a GPU, JAX uses 32 bit floats by default.
+One difference between NumPy and JAX is that JAX currently uses 32 bit floats by default.

 This is standard for GPU computing and can lead to significant speed gains with small loss of precision.
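As a quick aside, here is a minimal check of this default (a sketch assuming a standard JAX install; the `jax_enable_x64` flag is the documented way to opt in to 64 bit floats):

```{code-cell} ipython3
import jax
import jax.numpy as jnp

print(jnp.ones(3).dtype)   # float32 by default

# Opt in to 64 bit floats, then restore the default
jax.config.update("jax_enable_x64", True)
print(jnp.ones(3).dtype)   # float64
jax.config.update("jax_enable_x64", False)
```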

@@ -260,19 +260,22 @@ One point to remember is that JAX expects tuples to describe array shapes, even
 random.normal(key, (5, ))
 ```

-## JIT Compilation


-The JAX JIT compiler accelerates logic within functions by fusing linear
-algebra operations into a single, highly optimized kernel that the host can
+## JIT compilation
+
+The JAX just-in-time (JIT) compiler accelerates logic within functions by fusing linear
+algebra operations into a single optimized kernel that the host can
 launch on the GPU / TPU (or CPU if no accelerator is detected).

+### A first example

-Consider the following pure Python function.
+To see the JIT compiler in action, consider the following function.

 ```{code-cell} ipython3
-def f(x, p=1000):
-    return sum((k*x for k in range(p)))
+def f(x):
+    a = 3*x + jnp.sin(x) + jnp.cos(x**2) - jnp.cos(2*x) - x**2 * 0.4 * x**1.5
+    return jnp.sum(a)
 ```

 Let's build an array to call the function on.
@@ -291,18 +294,65 @@ How long does the function take to execute?
 ```{note}
 Here, in order to measure actual speed, we use the `block_until_ready()` method
 to hold the interpreter until the results of the computation are returned from
-the device.
+the device. This is necessary because JAX uses asynchronous dispatch, which
+allows the Python interpreter to run ahead of GPU computations.
+
+```
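As an aside, the effect of asynchronous dispatch can be seen directly (a small sketch assuming the `f` and `x` defined above): timing the call without `block_until_ready()` measures only how long it takes to dispatch the work to the device.

```{code-cell} ipython3
%time f(x)                      # returns almost immediately; only dispatch is timed
%time f(x).block_until_ready()  # waits until the device has finished computing
```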
+
+The code doesn't run as fast as we might hope, given that it's running on a GPU.

-This is necessary because JAX uses asynchronous dispatch, which allows the
-Python interpreter to run ahead of GPU computations.
+But if we run it a second time it becomes much faster:

+```{code-cell} ipython3
+%time f(x).block_until_ready()
 ```

-This code is not particularly fast.
+This is because the built-in functions like `jnp.cos` are JIT-compiled and the
+first run includes compile time.
+
+Why would JAX want to JIT-compile built-in functions like `jnp.cos` instead of
+just providing pre-compiled versions, like NumPy?

-While it is run on the GPU (since `x` is a JAX array), each vector `k * x` has to be instantiated before the final sum is computed.
+The reason is that the JIT compiler can specialize on the *size* of the array
+being used, which is helpful for parallelization.

-If we JIT-compile the function with JAX, then the operations are fused and no intermediate arrays are created.
+For example, in running the code above, the JIT compiler produced a version of `jnp.cos` that is
+specialized to floating point arrays of size `n = 50_000_000`.
+
+We can check this by calling `f` with a new array of different size.
+
+```{code-cell} ipython3
+m = 50_000_001
+y = jnp.ones(m)
+```
+
+```{code-cell} ipython3
+%time f(y).block_until_ready()
+```
+
+Notice that the execution time increases, because now new versions of
+the built-ins like `jnp.cos` are being compiled, specialized to the new array
+size.
+
+If we run again, the code is dispatched to the correct compiled version and we
+get faster execution.
+
+```{code-cell} ipython3
+%time f(y).block_until_ready()
+```
+
+The compiled versions for the previous array size are still available in memory
+too, and the following call is dispatched to the correct compiled code.
+
+```{code-cell} ipython3
+%time f(x).block_until_ready()
+```
+
+
+
+### Compiling the outer function
+
+We can do even better if we manually JIT-compile the outer function.

 ```{code-cell} ipython3
 f_jit = jax.jit(f) # target for JIT compilation
@@ -320,7 +370,20 @@ And now let's time it.
 %time f_jit(x).block_until_ready()
 ```

-Note the large speed gain.
+Note the speed gain.
+
+This is because the array operations are fused and no intermediate arrays are created.
+
+
+Incidentally, a more common syntax when targeting a function for the JIT
+compiler is
+
+```{code-cell} ipython3
+@jax.jit
+def f(x):
+    a = 3*x + jnp.sin(x) + jnp.cos(x**2) - jnp.cos(2*x) - x**2 * 0.4 * x**1.5
+    return jnp.sum(a)
+```
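As a further aside, one way to see the operations that the compiler fuses is `jax.make_jaxpr`, which prints the traced program for a function (a sketch using a hypothetical un-jitted copy `f_plain` of the function above):

```{code-cell} ipython3
def f_plain(x):   # hypothetical un-jitted copy of f, for inspection only
    a = 3*x + jnp.sin(x) + jnp.cos(x**2) - jnp.cos(2*x) - x**2 * 0.4 * x**1.5
    return jnp.sum(a)

# The jaxpr lists the elementwise operations that XLA fuses into one kernel
jax.make_jaxpr(f_plain)(jnp.ones(3))
```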


 ## Functional Programming
@@ -377,7 +440,7 @@ f(x)
 Changing the dimension of the input triggers a fresh compilation of the function, at which time the change in the value of `a` takes effect:

 ```{code-cell} ipython3
-x = np.ones(3)
+x = jnp.ones(3)
 ```
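For readers without the surrounding lecture context, the caching behaviour described here can be sketched in a self-contained way (the function `g` and the global `a` below are hypothetical stand-ins for the lecture's example):

```{code-cell} ipython3
a = 1.0

@jax.jit
def g(x):
    return a * x   # the global a is baked in as a constant when g is traced

g(jnp.ones(2))     # first call: traced and compiled with a = 1.0
a = 42.0
g(jnp.ones(2))     # same input shape: cached version, still uses a = 1.0
g(jnp.ones(3))     # new shape: fresh compilation, the new value of a takes effect
```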

 ```{code-cell} ipython3