Skip to content

Conversation

inailuig
Copy link
Collaborator

@inailuig inailuig commented Jul 25, 2025

no import needed.

uses the same mechanism as e.g. the gpu plugin.

See
https://github.com/jax-ml/jax/blob/13248ce80f448e5a60c2c91dd6a544aae0a81423/jax/_src/xla_bridge.py#L424

Example:

import os
import jax

print("Setup initialize", flush=True)

jax.distributed.initialize(coordinator_address="127.0.0.1:50000", process_id =int(os.environ['PMI_RANK']), num_processes=int(os.environ['PMI_SIZE']))

print(f"{jax.process_index()}/{jax.process_count()} :", jax.local_devices())
print(f"{jax.process_index()}/{jax.process_count()} :", jax.devices())

x = jax.numpy.ones(
    (jax.device_count(),),
    device=jax.sharding.NamedSharding(
        jax.sharding.Mesh(jax.devices(), "i"), jax.sharding.PartitionSpec("i")
    ),
)

print(f"{jax.process_index()}/{jax.process_count()} :", x.sum())

could try to upstream the MPITrampolineLocalCluster to jax as MPICHCluster or whatever

@inailuig inailuig requested a review from PhilipVinc July 30, 2025 13:12
@PhilipVinc
Copy link
Member

Nice!
Could we do the same to register the cluster as well, so no import needed also for automatic setup?

@PhilipVinc
Copy link
Member

could try to upstream the MPITrampolineLocalCluster to jax as MPICHCluster or whatever

yes, that would be a great idea...
But the current implementation won't really work for anything more than 1 node run..
(Someone should also upstream an openmpi5 version, tho...)

# Import the cluster to register it automatically
from .mpitrampoline_cluster import MPITrampolineLocalCluster

__version__ = "0.1.0"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
__version__ = "0.1.0"
__version__ = "0.1.1"

@inailuig
Copy link
Collaborator Author

Actually I think this doesn't work.

jaxlib tries to dlopen the mpi library before the plugin is loaded (but doesn't error until you call MPI_INIT, ignoring values set later)

seems like I was fooled by it just using gloo

@inailuig inailuig closed this Jul 30, 2025
@PhilipVinc
Copy link
Member

is this because jax calls initialise() on the plugins too late?

@inailuig
Copy link
Collaborator Author

inailuig commented Jul 31, 2025

yes, it fails at a very early stage when jaxlib (libjax_common.dylib) is imported:

MPITRAMPOLINE_VERBOSE=1 MPITRAMPOLINE_LIB=/dev/null python3 -c 'print("importing jax", flush=True); import jax; print("done importing jax", flush=True)'

If you dont set MPITRAMPOLINE_LIB here still silently errors, but only prints the message that its not set when MPI_Init is called later, see https://github.com/eschnett/MPItrampoline/blob/67292e8b1ac40aa5bd6d0a5dab669da32405a2d7/src/mpi.c#L618-L638
It ignores MPITRAMPOLINE_LIB set later on :/

this is exectued automatically when the jaxlib dylib is opened because of
https://github.com/eschnett/MPItrampoline/blob/67292e8b1ac40aa5bd6d0a5dab669da32405a2d7/src/mpi.c#L735-L740

It can be skipped though by setting MPITRAMPOLINE_DELAY_INIT=1

could set it as default when compiling jaxlib here
https://github.com/openxla/xla/blob/86097f9f5d705ecad9afb139c12e02c8643743e4/third_party/mpitrampoline/mpitrampoline.BUILD#L40

and then explicitly initializing it with mpitrampoline_init() here:
https://github.com/openxla/xla/blob/86097f9f5d705ecad9afb139c12e02c8643743e4/xla/backends/cpu/collectives/mpi_collectives.cc#L38

I would assume they might be willing to merge such a change, I can see if I can open a PR in xla.

@inailuig inailuig reopened this Jul 31, 2025
Comment on lines +15 to +17
if "JAX_CPU_COLLECTIVES_IMPLEMENTATION" not in os.environ.keys():
os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi"
print("mpibackend4jax: Set JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

has no effect, would need to change it to jax.config.update('jax_cpu_collectives_implementation', 'mpi')

@inailuig
Copy link
Collaborator Author

I would assume they might be willing to merge such a change, I can see if I can open a PR in xla.

yes, it fails at a very early stage when jaxlib (libjax_common.dylib) is imported:

MPITRAMPOLINE_VERBOSE=1 MPITRAMPOLINE_LIB=/dev/null python3 -c 'print("importing jax", flush=True); import jax; print("done importing jax", flush=True)'

If you dont set MPITRAMPOLINE_LIB here still silently errors, but only prints the message that its not set when MPI_Init is called later, see https://github.com/eschnett/MPItrampoline/blob/67292e8b1ac40aa5bd6d0a5dab669da32405a2d7/src/mpi.c#L618-L638 It ignores MPITRAMPOLINE_LIB set later on :/

this is exectued automatically when the jaxlib dylib is opened because of https://github.com/eschnett/MPItrampoline/blob/67292e8b1ac40aa5bd6d0a5dab669da32405a2d7/src/mpi.c#L735-L740

It can be skipped though by setting MPITRAMPOLINE_DELAY_INIT=1

could set it as default when compiling jaxlib here https://github.com/openxla/xla/blob/86097f9f5d705ecad9afb139c12e02c8643743e4/third_party/mpitrampoline/mpitrampoline.BUILD#L40

and then explicitly initializing it with mpitrampoline_init() here: https://github.com/openxla/xla/blob/86097f9f5d705ecad9afb139c12e02c8643743e4/xla/backends/cpu/collectives/mpi_collectives.cc#L38

I would assume they might be willing to merge such a change, I can see if I can open a PR in xla.

One way to do it would be to patch it in xla, so that it initializes automatically when you call mpi_init.

inailuig/xla@3f46cbb

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants