
Quickstart: using JIT-compiled functions, jax.vmap does not provide significant speedup over naive implementation #25703

Open
zsamboki opened this issue Dec 31, 2024 · 1 comment
Labels: question (Questions for the JAX team)
@zsamboki

Description

In the Auto-vectorization section of the Quickstart, we are presented with three versions of batched matrix-vector multiplication:

  1. A naively batched version, with an explicit for loop,
  2. A manually batched version, using matrix multiplication, and
  3. An auto-vectorized version, using jax.vmap.

In the code cells of the Quickstart, jax.jit is applied to versions 2 and 3 but not to version 1. I measured the speed of all three versions, on both CPU and GPU, with and without JIT compilation, and got the following results: on both CPU and GPU, every JIT-compiled version is much faster than every non-JIT-compiled version. Moreover, I see no clear difference in speed between the JIT-compiled versions.

It would be nice to have a minimal example where auto-vectorization makes code significantly faster, even with JIT compilation; one kind of candidate is sketched below.
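A minimal sketch of such a candidate: per-example gradients, where there is no obvious one-line manually batched rewrite analogous to jnp.dot(batched_x, mat.T). The function and shapes here are made up for illustration, not taken from the Quickstart:

import jax
import jax.numpy as jnp

def loss(w, x):
    # Toy scalar loss for a single example x with shared weights w.
    return jnp.sum(jnp.tanh(x @ w))

w = jnp.ones((100,))
xs = jnp.ones((10, 100))

# Per-example gradients w.r.t. w: vmap over the examples, not over w.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))(w, xs)
print(per_example_grads.shape)  # (10, 100)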

Code:

import jax
import jax.numpy as jnp
from timeit import timeit

def test():
    key = jax.random.key(1701)

    key1, key2 = jax.random.split(key)
    mat = jax.random.normal(key1, (150, 100))
    batched_x = jax.random.normal(key2, (10, 100))

    def apply_matrix(x):
        return jnp.dot(mat, x)

    def naively_batched_apply_matrix(v_batched):
        return jnp.stack([apply_matrix(v) for v in v_batched])

    # Warm-up run so the timings below exclude one-time dispatch costs.
    naively_batched_apply_matrix(batched_x)

    # timeit returns total seconds over number=1000 runs, which is
    # numerically equal to the mean time per run in milliseconds.
    print(
        f"Naively batched: {timeit(lambda: naively_batched_apply_matrix(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    def batched_apply_matrix(batched_x):
        return jnp.dot(batched_x, mat.T)

    # Warm-up run.
    batched_apply_matrix(batched_x)

    print(
        f"Manually batched: {timeit(lambda: batched_apply_matrix(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    def vmap_batched_apply_matrix(batched_x):
        return jax.vmap(apply_matrix)(batched_x)

    # Warm-up run.
    vmap_batched_apply_matrix(batched_x)

    print(
        f"Auto-vectorized with vmap: {timeit(lambda: vmap_batched_apply_matrix(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    @jax.jit
    def naively_batched_apply_matrix_jit(v_batched):
        return jnp.stack([apply_matrix(v) for v in v_batched])

    # First call triggers JIT compilation; run it outside the timed region.
    naively_batched_apply_matrix_jit(batched_x)

    print(
        f"JIT compiled naively batched: {timeit(lambda: naively_batched_apply_matrix_jit(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    @jax.jit
    def batched_apply_matrix_jit(batched_x):
        return jnp.dot(batched_x, mat.T)

    # Compile outside the timed region.
    batched_apply_matrix_jit(batched_x)

    print(
        f"JIT compiled manually batched: {timeit(lambda: batched_apply_matrix_jit(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

    @jax.jit
    def vmap_batched_apply_matrix_jit(batched_x):
        return jax.vmap(apply_matrix)(batched_x)

    # Compile outside the timed region.
    vmap_batched_apply_matrix_jit(batched_x)

    print(
        f"JIT compiled auto-vectorized with vmap: {timeit(lambda: vmap_batched_apply_matrix_jit(batched_x).block_until_ready(), number=1000):.4f} ms"
    )

if __name__ == "__main__":
    # https://stackoverflow.com/a/74590238
    gpu_device = jax.devices('gpu')[0]
    cpu_device = jax.devices('cpu')[0]

    print("On CPU\n------")
    with jax.default_device(cpu_device):
        test()

    print("\nOn GPU\n------")
    with jax.default_device(gpu_device):
        test()

Output:

On CPU
------
Naively batched: 0.2156 ms
Manually batched: 0.0445 ms
Auto-vectorized with vmap: 0.1096 ms
JIT compiled naively batched: 0.0132 ms
JIT compiled manually batched: 0.0091 ms
JIT compiled auto-vectorized with vmap: 0.0124 ms

On GPU
------
Naively batched: 0.6071 ms
Manually batched: 0.0593 ms
Auto-vectorized with vmap: 0.2160 ms
JIT compiled naively batched: 0.0268 ms
JIT compiled manually batched: 0.0315 ms
JIT compiled auto-vectorized with vmap: 0.0346 ms

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.2.1
python: 3.13.1 | packaged by conda-forge | (main, Dec  5 2024, 21:23:54) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 4070 Laptop GPU-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='StrikeFreedom', release='6.8.0-51-generic', version='#52-Ubuntu SMP PREEMPT_DYNAMIC Thu Dec  5 13:09:44 UTC 2024', machine='x86_64')


$ nvidia-smi
Tue Dec 31 10:38:31 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4070 ...    Off |   00000000:01:00.0  On |                  N/A |
| N/A   42C    P3             15W /   45W |    6834MiB /   8188MiB |     20%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      3043      G   /usr/lib/xorg/Xorg                            260MiB |
|    0   N/A  N/A      3383      G   /usr/bin/gnome-shell                           45MiB |
|    0   N/A  N/A      6471      G   ...irefox/5437/usr/lib/firefox/firefox        301MiB |
|    0   N/A  N/A     10054      G   /usr/bin/transmission-gtk                       8MiB |
|    0   N/A  N/A     18500      G   ...erProcess --variations-seed-version         29MiB |
|    0   N/A  N/A     23218      C   python                                       6094MiB |
+-----------------------------------------------------------------------------------------+
zsamboki added the bug (Something isn't working) label Dec 31, 2024
@jakevdp (Collaborator) commented Dec 31, 2024

Interesting exploration, thanks for sharing! However, runtime isn't the only consideration: if you benchmark JIT compilation time, I suspect you'll find the naively batched version to be much more costly, and that this compilation cost grows super-linearly with the batch size.
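A minimal sketch of how one could measure that compilation cost, using jax.jit's ahead-of-time lowering API (jax.jit(fn).lower(x).compile()); the shapes mirror the issue's example, and the batch sizes are arbitrary:

import time
import jax
import jax.numpy as jnp

mat = jnp.ones((150, 100))

def naive(v_batched):
    # Under jit the Python loop unrolls into one dot per row, so the traced
    # program (and hence compile time) grows with the batch size.
    return jnp.stack([jnp.dot(mat, v) for v in v_batched])

def vmapped(v_batched):
    # A single batched operation whose traced size is independent of the
    # batch size.
    return jax.vmap(lambda v: jnp.dot(mat, v))(v_batched)

def compile_seconds(fn, batch_size):
    x = jnp.ones((batch_size, 100))
    start = time.perf_counter()
    jax.jit(fn).lower(x).compile()  # compile without executing
    return time.perf_counter() - start

for n in (10, 100, 1000):
    print(f"batch={n}: naive {compile_seconds(naive, n):.3f}s, "
          f"vmap {compile_seconds(vmapped, n):.3f}s")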

jakevdp added the question (Questions for the JAX team) label and removed the bug label Dec 31, 2024
jakevdp self-assigned this Dec 31, 2024