Description
In the Auto-vectorization section of the Quickstart, we are presented with 3 versions of batched matrix-vector multiplication:
- A naively batched version, using an explicit for loop,
- A manually batched version, using a single matrix multiplication, and
- An auto-vectorized version, using jax.vmap.
In the code cells of the Quickstart, jax.jit is applied to versions 2 and 3 but not to version 1. I measured the speed of all 3 versions, on both CPU and GPU, with and without JIT compilation. The results: on both CPU and GPU, every JIT-compiled version is much faster than every non-JIT-compiled version, but I see no clear differences between the speeds of the three JIT-compiled versions.
It would be nice to have a minimal example where auto-vectorization makes the code significantly faster, even with JIT compilation.
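For concreteness, here is a rough, unbenchmarked sketch of the kind of example I have in mind: a per-example function that is not trivially rewritten as a single matrix operation, with a batch large enough that the unrolled Python loop should hurt even under jax.jit. The per_example function and all shapes below are my own assumptions, not taken from the Quickstart:

import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.key(0))
xs = jax.random.normal(key1, (512, 256))    # batch of inputs (assumed shape)
refs = jax.random.normal(key2, (128, 256))  # reference points (assumed shape)

def per_example(x):
    # squared distance of one input to every reference point, then a soft minimum
    d = jnp.sum((refs - x) ** 2, axis=-1)
    return -jax.nn.logsumexp(-d)

@jax.jit
def looped(xs):
    # naive batching: the Python loop is unrolled into 512 copies of the op graph
    return jnp.stack([per_example(x) for x in xs])

@jax.jit
def vmapped(xs):
    # auto-vectorized batching: a single batched op graph
    return jax.vmap(per_example)(xs)

# warm up both (trigger compilation), then time with block_until_ready() as below
looped(xs).block_until_ready()
vmapped(xs).block_until_ready()

Whether the gap here is actually large enough to be instructive would of course need to be measured.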
Code:
import jax
import jax.numpy as jnp
from timeit import timeit


def test():
    key = jax.random.key(1701)
    key1, key2 = jax.random.split(key)
    mat = jax.random.normal(key1, (150, 100))
    batched_x = jax.random.normal(key2, (10, 100))

    def apply_matrix(x):
        return jnp.dot(mat, x)

    def naively_batched_apply_matrix(v_batched):
        return jnp.stack([apply_matrix(v) for v in v_batched])

    naively_batched_apply_matrix(batched_x)
    print(
        f"Naively batched: {timeit(lambda: naively_batched_apply_matrix(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    def batched_apply_matrix(batched_x):
        return jnp.dot(batched_x, mat.T)

    batched_apply_matrix(batched_x)
    print(
        f"Manually batched: {timeit(lambda: batched_apply_matrix(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    def vmap_batched_apply_matrix(batched_x):
        return jax.vmap(apply_matrix)(batched_x)

    vmap_batched_apply_matrix(batched_x)
    print(
        f"Auto-vectorized with vmap: {timeit(lambda: vmap_batched_apply_matrix(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    @jax.jit
    def naively_batched_apply_matrix_jit(v_batched):
        return jnp.stack([apply_matrix(v) for v in v_batched])

    naively_batched_apply_matrix_jit(batched_x)
    print(
        f"JIT compiled naively batched: {timeit(lambda: naively_batched_apply_matrix_jit(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    @jax.jit
    def batched_apply_matrix_jit(batched_x):
        return jnp.dot(batched_x, mat.T)

    batched_apply_matrix_jit(batched_x)
    print(
        f"JIT compiled manually batched: {timeit(lambda: batched_apply_matrix_jit(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    @jax.jit
    def vmap_batched_apply_matrix_jit(batched_x):
        return jax.vmap(apply_matrix)(batched_x)

    vmap_batched_apply_matrix_jit(batched_x)
    print(
        f"JIT compiled auto-vectorized with vmap: {timeit(lambda: vmap_batched_apply_matrix_jit(batched_x).block_until_ready(), number=1000):.4f} ms"
    )


if __name__ == "__main__":
    # https://stackoverflow.com/a/74590238
    gpu_device = jax.devices('gpu')[0]
    cpu_device = jax.devices('cpu')[0]

    print("On CPU\n------")
    with jax.default_device(cpu_device):
        test()

    print("\nOn GPU\n------")
    with jax.default_device(gpu_device):
        test()
Output:
On CPU
------
Naively batched: 0.2156 ms
Manually batched: 0.0445 ms
Auto-vectorized with vmap: 0.1096 ms
JIT compiled naively batched: 0.0132 ms
JIT compiled manually batched: 0.0091 ms
JIT compiled auto-vectorized with vmap: 0.0124 ms
On GPU
------
Naively batched: 0.6071 ms
Manually batched: 0.0593 ms
Auto-vectorized with vmap: 0.2160 ms
JIT compiled naively batched: 0.0268 ms
JIT compiled manually batched: 0.0315 ms
JIT compiled auto-vectorized with vmap: 0.0346 ms
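One explanation I have not ruled out is that, at these sizes, per-call overhead dominates and hides any difference between the JIT-compiled versions. A self-contained sketch of that check, reusing the same timing pattern with larger (arbitrarily chosen) shapes:

import jax
import jax.numpy as jnp
from timeit import timeit

key1, key2 = jax.random.split(jax.random.key(0))
mat = jax.random.normal(key1, (1500, 1000))       # assumed larger matrix
batched_x = jax.random.normal(key2, (256, 1000))  # assumed larger batch

@jax.jit
def naive(v_batched):
    return jnp.stack([jnp.dot(mat, v) for v in v_batched])

@jax.jit
def vmapped(v_batched):
    return jax.vmap(lambda v: jnp.dot(mat, v))(v_batched)

naive(batched_x).block_until_ready()    # warm-up / compile
vmapped(batched_x).block_until_ready()
print(timeit(lambda: naive(batched_x).block_until_ready(), number=100))
print(timeit(lambda: vmapped(batched_x).block_until_ready(), number=100))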
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.38
jaxlib: 0.4.38
numpy: 2.2.1
python: 3.13.1 | packaged by conda-forge | (main, Dec 5 2024, 21:23:54) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 4070 Laptop GPU-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='StrikeFreedom', release='6.8.0-51-generic', version='#52-Ubuntu SMP PREEMPT_DYNAMIC Thu Dec 5 13:09:44 UTC 2024', machine='x86_64')
$ nvidia-smi
Tue Dec 31 10:38:31 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf           Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                          |                        |               MIG M. |
|==========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4070 ...     Off |   00000000:01:00.0  On |                  N/A |
| N/A   42C    P3              15W / 45W   |      6834MiB / 8188MiB |     20%      Default |
|                                          |                        |                  N/A |
+------------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                               |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory  |
|        ID   ID                                                               Usage       |
|===========================================================================================|
|    0   N/A  N/A      3043      G   /usr/lib/xorg/Xorg                            260MiB  |
|    0   N/A  N/A      3383      G   /usr/bin/gnome-shell                           45MiB  |
|    0   N/A  N/A      6471      G   ...irefox/5437/usr/lib/firefox/firefox        301MiB  |
|    0   N/A  N/A     10054      G   /usr/bin/transmission-gtk                       8MiB  |
|    0   N/A  N/A     18500      G   ...erProcess --variations-seed-version         29MiB  |
|    0   N/A  N/A     23218      C   python                                       6094MiB  |
+-----------------------------------------------------------------------------------------+