Quickstart: using JIT-compiled functions, jax.vmap does not provide significant speedup over naive implementation #25703
Labels: question (Questions for the JAX team)
Description
In the Auto-vectorization section of the Quickstart, we are presented with 3 versions of batched matrix-vector multiplication:

1. a naive `for` loop,
2. manual batching,
3. `jax.vmap`.

In the code cells of the Quickstart, we apply `jax.jit` to versions 2 and 3 but not to version 1. I measured the speed of all 3 versions, both on CPU and GPU, with and without JIT compilation, and ended up with the following results: on both CPU and GPU, all JIT-compiled versions are much faster than all non-JIT-compiled versions. Moreover, I do not see clear differences between the speeds of the JIT-compiled versions.

It would be nice to have a minimal example where auto-vectorization makes the code significantly faster, even with JIT compilation.
Code:
Output:
System info (python version, jaxlib version, accelerator, etc.)