-
Notifications
You must be signed in to change notification settings - Fork 0
Let jax load it automatically #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
This way its easier to override by hand
Nice! |
yes, that would be a great idea... |
# Import the cluster to register it automatically | ||
from .mpitrampoline_cluster import MPITrampolineLocalCluster | ||
|
||
__version__ = "0.1.0" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
__version__ = "0.1.0" | |
__version__ = "0.1.1" |
Actually I think this doesn't work. jaxlib tries to dlopen the mpi library before the plugin is loaded (but doesn't error until you call MPI_INIT, ignoring values set later) seems like I was fooled by it just using gloo |
is this because jax calls initialise() on the plugins too late? |
yes, it fails at a very early stage when jaxlib (libjax_common.dylib) is imported: MPITRAMPOLINE_VERBOSE=1 MPITRAMPOLINE_LIB=/dev/null python3 -c 'print("importing jax", flush=True); import jax; print("done importing jax", flush=True)' If you dont set MPITRAMPOLINE_LIB here still silently errors, but only prints the message that its not set when MPI_Init is called later, see https://github.com/eschnett/MPItrampoline/blob/67292e8b1ac40aa5bd6d0a5dab669da32405a2d7/src/mpi.c#L618-L638 this is exectued automatically when the jaxlib dylib is opened because of It can be skipped though by setting MPITRAMPOLINE_DELAY_INIT=1 could set it as default when compiling jaxlib here and then explicitly initializing it with I would assume they might be willing to merge such a change, I can see if I can open a PR in xla. |
if "JAX_CPU_COLLECTIVES_IMPLEMENTATION" not in os.environ.keys(): | ||
os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi" | ||
print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
has no effect, would need to change it to jax.config.update('jax_cpu_collectives_implementation', 'mpi')
One way to do it would be to patch it in xla, so that it initializes automatically when you call mpi_init. |
no import needed.
uses the same mechanism as e.g. the gpu plugin.
See
https://github.com/jax-ml/jax/blob/13248ce80f448e5a60c2c91dd6a544aae0a81423/jax/_src/xla_bridge.py#L424
Example:
could try to upstream the MPITrampolineLocalCluster to jax as MPICHCluster or whatever