
Quickstart: using JIT-compiled functions, jax.vmap does not provide significant speedup over naive implementation #25703

Closed
@zsamboki

Description

In the Auto-vectorization section of the Quickstart, we are presented with three versions of batched matrix-vector multiplication:

  1. A naively batched version, with an explicit for loop,
  2. A manually batched version, using matrix multiplication, and
  3. An auto-vectorized version, using jax.vmap.

In the code cells of the Quickstart, we apply jax.jit to versions 2 and 3 but not to version 1. I measured all three versions on both CPU and GPU, with and without JIT compilation, and got the following results: on both CPU and GPU, every JIT-compiled version is much faster than every non-JIT-compiled version. Moreover, I see no clear difference among the speeds of the JIT-compiled versions themselves.
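For what it's worth, inspecting the traced jaxprs suggests why the JIT-compiled variants end up so close: jax.vmap turns the per-example function into a single batched dot_general, which is essentially the computation the manual version writes out by hand. A minimal sketch, reusing the shapes from the code below:

import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.key(1701))
mat = jax.random.normal(key1, (150, 100))
batched_x = jax.random.normal(key2, (10, 100))

def apply_matrix(x):
    return jnp.dot(mat, x)

# vmap traces apply_matrix into one batched dot_general over all 10 rows ...
print(jax.make_jaxpr(jax.vmap(apply_matrix))(batched_x))
# ... which matches the manual version up to a transpose:
print(jax.make_jaxpr(lambda bx: jnp.dot(bx, mat.T))(batched_x))

So after jax.jit, XLA is handed essentially the same program either way, and similar timings are what one would expect.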

It would be nice to have a minimal example where auto-vectorization makes code significantly faster, even with JIT compilation.

Code:

import jax
import jax.numpy as jnp
from timeit import timeit
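# Note: with number=1000, timeit returns the *total* seconds for 1000 calls,
# which happens to coincide with the mean time per call in milliseconds.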

def test():
    key = jax.random.key(1701)

    key1, key2 = jax.random.split(key)
    mat = jax.random.normal(key1, (150, 100))
    batched_x = jax.random.normal(key2, (10, 100))

    def apply_matrix(x):
        return jnp.dot(mat, x)

    def naively_batched_apply_matrix(v_batched):
        return jnp.stack([apply_matrix(v) for v in v_batched])

    naively_batched_apply_matrix(batched_x)  # warm-up run

    print(
        f"Naively batched: {timeit(lambda: naively_batched_apply_matrix(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    def batched_apply_matrix(batched_x):
        return jnp.dot(batched_x, mat.T)

    batched_apply_matrix(batched_x)

    print(
        f"Manually batched: {timeit(lambda: batched_apply_matrix(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    def vmap_batched_apply_matrix(batched_x):
        return jax.vmap(apply_matrix)(batched_x)

    vmap_batched_apply_matrix(batched_x)

    print(
        f"Auto-vectorized with vmap: {timeit(lambda: vmap_batched_apply_matrix(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    @jax.jit
    def naively_batched_apply_matrix_jit(v_batched):
        return jnp.stack([apply_matrix(v) for v in v_batched])  # loop is unrolled at trace time

    naively_batched_apply_matrix_jit(batched_x)  # warm-up: compiles before timing

    print(
        f"JIT compiled naively batched: {timeit(lambda: naively_batched_apply_matrix_jit(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    @jax.jit
    def batched_apply_matrix_jit(batched_x):
        return jnp.dot(batched_x, mat.T)

    batched_apply_matrix_jit(batched_x)

    print(
        f"JIT compiled manually batched: {timeit(lambda: batched_apply_matrix_jit(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    @jax.jit
    def vmap_batched_apply_matrix_jit(batched_x):
        return jax.vmap(apply_matrix)(batched_x)

    vmap_batched_apply_matrix_jit(batched_x)

    print(
        f"JIT compiled auto-vectorized with vmap: {timeit(lambda: vmap_batched_apply_matrix_jit(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

if __name__ == "__main__":
    # https://stackoverflow.com/a/74590238
    gpu_device = jax.devices('gpu')[0]
    cpu_device = jax.devices('cpu')[0]

    print("On CPU\n------")
    with jax.default_device(cpu_device):
        test()

    print("\nOn GPU\n------")
    with jax.default_device(gpu_device):
        test()

Output:

On CPU
------
Naively batched: 0.2156 ms
Manually batched: 0.0445 ms
Auto-vectorized with vmap: 0.1096 ms
JIT compiled naively batched: 0.0132 ms
JIT compiled manually batched: 0.0091 ms
JIT compiled auto-vectorized with vmap: 0.0124 ms

On GPU
------
Naively batched: 0.6071 ms
Manually batched: 0.0593 ms
Auto-vectorized with vmap: 0.2160 ms
JIT compiled naively batched: 0.0268 ms
JIT compiled manually batched: 0.0315 ms
JIT compiled auto-vectorized with vmap: 0.0346 ms
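
As for a minimal example where vmap visibly wins even under jit: one hypothetical direction (my own suggestion, not from the Quickstart) is to grow the batch dimension. Since jax.jit unrolls the naive Python loop at trace time, a batch of, say, 4096 rows compiles to thousands of separate dot ops, whereas the vmap version remains a single dot_general; at that scale the compile time alone differs dramatically, and a runtime gap may show as well. A sketch, with the batch size 4096 and number=100 being my own choices:

import jax
import jax.numpy as jnp
from timeit import timeit

key1, key2 = jax.random.split(jax.random.key(1701))
mat = jax.random.normal(key1, (150, 100))
big_batch = jax.random.normal(key2, (4096, 100))  # hypothetical larger batch

def apply_matrix(x):
    return jnp.dot(mat, x)

@jax.jit
def naive_jit(v_batched):
    # jit unrolls this Python loop into 4096 separate dots at trace time
    return jnp.stack([apply_matrix(v) for v in v_batched])

@jax.jit
def vmap_jit(v_batched):
    # vmap keeps the whole batch as a single dot_general
    return jax.vmap(apply_matrix)(v_batched)

naive_jit(big_batch).block_until_ready()  # warm-up: compiling the unrolled loop is slow
vmap_jit(big_batch).block_until_ready()   # warm-up: compiles almost instantly

print(f"JIT naive loop: {timeit(lambda: naive_jit(big_batch).block_until_ready(), number=100):.4f} s total")
print(f"JIT vmap:       {timeit(lambda: vmap_jit(big_batch).block_until_ready(), number=100):.4f} s total")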

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.2.1
python: 3.13.1 | packaged by conda-forge | (main, Dec  5 2024, 21:23:54) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 4070 Laptop GPU-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='StrikeFreedom', release='6.8.0-51-generic', version='#52-Ubuntu SMP PREEMPT_DYNAMIC Thu Dec  5 13:09:44 UTC 2024', machine='x86_64')


$ nvidia-smi
Tue Dec 31 10:38:31 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4070 ...    Off |   00000000:01:00.0  On |                  N/A |
| N/A   42C    P3             15W /   45W |    6834MiB /   8188MiB |     20%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      3043      G   /usr/lib/xorg/Xorg                            260MiB |
|    0   N/A  N/A      3383      G   /usr/bin/gnome-shell                           45MiB |
|    0   N/A  N/A      6471      G   ...irefox/5437/usr/lib/firefox/firefox        301MiB |
|    0   N/A  N/A     10054      G   /usr/bin/transmission-gtk                       8MiB |
|    0   N/A  N/A     18500      G   ...erProcess --variations-seed-version         29MiB |
|    0   N/A  N/A     23218      C   python                                       6094MiB |
+-----------------------------------------------------------------------------------------+

Labels

question (Questions for the JAX team)
