Skip to content

Commit 9a150f7

Browse files
authored
feat: use a global state to setup pjrt distributed runtime (#780)
fix: devices are not necessarily from 0 to N-1 fix: initialize clients on first use rather than on init fix: make device selection consistent with clients feat: add OMPI cluster detection fix: correctly set kv_store refactor: Distributed setup is not PJRT specific refactor: OMPI detection doesn't need to be in an extension
1 parent 6b760ba commit 9a150f7

12 files changed

+576
-63
lines changed

Project.toml

+3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3131
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
3232
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3333
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
34+
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
3435
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
3536
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
3637
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
@@ -47,6 +48,7 @@ ReactantAbstractFFTsExt = "AbstractFFTs"
4748
ReactantArrayInterfaceExt = "ArrayInterface"
4849
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
4950
ReactantKernelAbstractionsExt = "KernelAbstractions"
51+
ReactantMPIExt = "MPI"
5052
ReactantNNlibExt = "NNlib"
5153
ReactantOffsetArraysExt = "OffsetArrays"
5254
ReactantPythonCallExt = "PythonCall"
@@ -72,6 +74,7 @@ KernelAbstractions = "0.9.30"
7274
LLVM = "9.1"
7375
LLVMOpenMP_jll = "18.1.7"
7476
LinearAlgebra = "1.10"
77+
MPI = "0.20"
7578
NNlib = "0.9.26"
7679
OffsetArrays = "1"
7780
OrderedCollections = "1"

ext/ReactantMPIExt.jl

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
module ReactantMPIExt
2+
3+
using Reactant: Reactant, Distributed
4+
using MPI: MPI
5+
6+
# https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
7+
Distributed.is_env_present(::Distributed.MPIEnvDetector) = MPI.Initialized()
8+
9+
function Distributed.get_coordinator_address(
10+
::Distributed.MPIEnvDetector, timeout_in_seconds::Integer
11+
)
12+
if MPI.Comm_rank(MPI.COMM_WORLD) == 0
13+
hostname = gethostname()
14+
port_id = hash(hostname) % 2^12 + (65535 - 2^12 + 1)
15+
hostname = "$(hostname):$(port_id)"
16+
else
17+
hostname = nothing
18+
end
19+
20+
return MPI.bcast(hostname, MPI.COMM_WORLD; root=0)
21+
end
22+
23+
function Distributed.get_process_count(::Distributed.MPIEnvDetector)
24+
return Int(MPI.Comm_size(MPI.COMM_WORLD))
25+
end
26+
27+
function Distributed.get_process_id(::Distributed.MPIEnvDetector)
28+
return Int(MPI.Comm_rank(MPI.COMM_WORLD))
29+
end
30+
31+
function Distributed.get_local_process_id(::Distributed.MPIEnvDetector)
32+
new_comm = MPI.Comm_split_type(MPI.COMM_WORLD, MPI.COMM_TYPE_SHARED, 0)
33+
return Int(MPI.Comm_rank(new_comm))
34+
end
35+
36+
end

src/Compiler.jl

+8-14
Original file line numberDiff line numberDiff line change
@@ -495,11 +495,8 @@ function compile_mlir(f, args; client=nothing, kwargs...)
495495
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
496496
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
497497

498-
if client !== nothing
499-
backend = XLA.platform_name(client)
500-
else
501-
backend = XLA.platform_name(XLA.default_backend[])
502-
end
498+
backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
499+
503500
if backend == "CUDA"
504501
backend = "GPU"
505502
elseif backend == "CPU"
@@ -1423,7 +1420,7 @@ end
14231420

14241421
function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
14251422
if is_sharded
1426-
client === nothing && (client = XLA.default_backend[])
1423+
client === nothing && (client = XLA.default_backend())
14271424
return client, nothing
14281425
end
14291426

@@ -1449,14 +1446,14 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
14491446
if device !== nothing
14501447
client = XLA.client(device)
14511448
else
1452-
client = XLA.default_backend[]
1453-
device = XLA.get_addressable_device(client, XLA.default_device_idx[])
1449+
client = XLA.default_backend()
1450+
device = XLA.default_device(client)
14541451
end
14551452
else
14561453
if device !== nothing
14571454
@assert client == XLA.client(device) "client ($(client)) and XLA.client(device) ($(XLA.client(device))) must be the same"
14581455
else
1459-
device = XLA.get_addressable_device(client, XLA.default_device_idx[])
1456+
device = XLA.default_device(client)
14601457
end
14611458
end
14621459

@@ -1469,11 +1466,8 @@ function compile_xla(f, args; client=nothing, kwargs...)
14691466
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
14701467
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
14711468

1472-
if client !== nothing
1473-
backend = XLA.platform_name(client)
1474-
else
1475-
backend = XLA.platform_name(XLA.default_backend[])
1476-
end
1469+
backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
1470+
14771471
if backend == "CUDA"
14781472
backend = "GPU"
14791473
elseif backend == "CPU"

src/Devices.jl

+9-13
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,23 @@
11
"""
22
devices(backend::String)
3-
devices(backend::XLA.AbstractClient = XLA.default_backend[])
3+
devices(backend::XLA.AbstractClient = XLA.default_backend())
44
5-
Return a list of devices available on the backend.
5+
Return a list of devices available for the given client.
66
"""
7-
devices(backend::String) = devices(XLA.backends[backend])
7+
devices(backend::String) = devices(XLA.client(backend))
88

9-
function devices(client::XLA.AbstractClient=XLA.default_backend[])
10-
ndevices = XLA.num_devices(client)
11-
return [XLA.get_device(client, i - 1) for i in 1:ndevices]
12-
end
9+
devices(client::XLA.AbstractClient=XLA.default_backend()) = XLA.devices(client)
1310

1411
"""
1512
addressable_devices(backend::String)
16-
addressable_devices(backend::XLA.AbstractClient = XLA.default_backend[])
13+
addressable_devices(backend::XLA.AbstractClient = XLA.default_backend())
1714
18-
Return a list of addressable devices available on the backend.
15+
Return a list of addressable devices available for the given client.
1916
"""
20-
addressable_devices(backend::String) = addressable_devices(XLA.backends[backend])
17+
addressable_devices(backend::String) = addressable_devices(XLA.client(backend))
2118

22-
function addressable_devices(client::XLA.AbstractClient=XLA.default_backend[])
23-
ndevices = XLA.num_addressable_devices(client)
24-
return [XLA.get_addressable_device(client, i - 1) for i in 1:ndevices]
19+
function addressable_devices(client::XLA.AbstractClient=XLA.default_backend())
20+
return XLA.addressable_devices(client)
2521
end
2622

2723
# https://github.com/jax-ml/jax/blob/152099ee0ef31119f16f4c2dac50d84fcb1575ef/jax/_src/hardware_utils.py#L19-L55

src/Distributed.jl

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
module Distributed
2+
3+
using ..Reactant: Reactant
4+
5+
const initialized = Ref(false)
6+
7+
function initialize(;
8+
coordinator_address::Union{Nothing,String}=nothing,
9+
num_processes::Union{Nothing,Integer}=nothing,
10+
process_id::Union{Nothing,Integer}=nothing,
11+
local_device_ids::Union{Nothing,Vector{Int}}=nothing,
12+
initialization_timeout_in_seconds::Integer=300,
13+
kwargs...,
14+
)
15+
@assert !initialized[] "`Distributed.initialize` has already been called"
16+
17+
(coordinator_address, num_processes, process_id, local_device_ids) = auto_detect_unset_distributed_params(;
18+
coordinator_address,
19+
num_processes,
20+
process_id,
21+
local_device_ids,
22+
initialization_timeout_in_seconds,
23+
)
24+
25+
@debug "Detected Reactant distributed params" coordinator_address num_processes process_id local_device_ids
26+
27+
Reactant.XLA.update_global_state!(;
28+
coordinator_address, num_processes, process_id, local_device_ids, kwargs...
29+
)
30+
31+
@debug "New Global State" Reactant.XLA.global_state
32+
33+
initialized[] = true
34+
return nothing
35+
end
36+
37+
abstract type AbstractClusterEnvDetector end
38+
39+
abstract type AbstractOMPIClusterEnvDetector <: AbstractClusterEnvDetector end
40+
41+
struct OpenMPIORTEEnvDetector <: AbstractOMPIClusterEnvDetector end
42+
struct OpenMPIPMIXEnvDetector <: AbstractOMPIClusterEnvDetector end
43+
44+
struct MPIEnvDetector <: AbstractClusterEnvDetector end
45+
46+
# Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/cluster.py
47+
48+
is_env_present(::AbstractClusterEnvDetector) = false
49+
50+
function get_coordinator_address end
51+
function get_process_count end
52+
function get_process_id end
53+
function get_local_process_id end
54+
55+
function auto_detect_unset_distributed_params(;
56+
detector_list=[OpenMPIORTEEnvDetector(), OpenMPIPMIXEnvDetector(), MPIEnvDetector()],
57+
coordinator_address::Union{Nothing,String}=nothing,
58+
num_processes::Union{Nothing,Integer}=nothing,
59+
process_id::Union{Nothing,Integer}=nothing,
60+
local_device_ids::Union{Nothing,Vector{Int}}=nothing,
61+
initialization_timeout_in_seconds::Integer=300,
62+
)
63+
if all(
64+
Base.Fix2(!==, nothing),
65+
(coordinator_address, num_processes, process_id, local_device_ids),
66+
)
67+
return coordinator_address, num_processes, process_id, local_device_ids
68+
end
69+
70+
idx = findfirst(is_env_present, detector_list)
71+
if idx === nothing
72+
error("Couldn't find a functional cluster environment detector. Attempted to use: \
73+
$(detector_list)")
74+
end
75+
76+
detector = detector_list[idx]
77+
78+
@debug "Detected cluster environment" detector
79+
80+
if coordinator_address === nothing
81+
coordinator_address = get_coordinator_address(
82+
detector, initialization_timeout_in_seconds
83+
)
84+
end
85+
86+
if num_processes === nothing
87+
num_processes = get_process_count(detector)
88+
end
89+
90+
if process_id === nothing
91+
process_id = get_process_id(detector)
92+
end
93+
94+
if local_device_ids === nothing
95+
local_device_ids = [get_local_process_id(detector)]
96+
end
97+
98+
return coordinator_address, num_processes, process_id, local_device_ids
99+
end
100+
101+
# OpenMPIORTEEnvDetector & OpenMPIPMIXEnvDetector
102+
# Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/ompi_cluster.py and adapted for latest OpenMPI versions
103+
const _ORTE_URI = "OMPI_MCA_orte_hnp_uri"
104+
const _PMIX_SERVER_URI = (
105+
"PMIX_SERVER_URI2",
106+
"PMIX_SERVER_URI3",
107+
"PMIX_SERVER_URI4",
108+
"PMIX_SERVER_URI41",
109+
"PMIX_SERVER_URI21",
110+
)
111+
const _OMPI_PROCESS_COUNT = "OMPI_COMM_WORLD_SIZE"
112+
const _OMPI_PROCESS_ID = "OMPI_COMM_WORLD_RANK"
113+
const _OMPI_LOCAL_PROCESS_ID = "OMPI_COMM_WORLD_LOCAL_RANK"
114+
115+
is_env_present(::OpenMPIORTEEnvDetector) = haskey(ENV, _ORTE_URI)
116+
is_env_present(::OpenMPIPMIXEnvDetector) = any(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
117+
118+
function get_coordinator_address(::OpenMPIORTEEnvDetector, ::Integer)
119+
orte_uri = ENV[_ORTE_URI]
120+
121+
job_id = parse(Int, split(orte_uri, '.'; limit=2)[1])
122+
port = job_id % 2^12 + (65535 - 2^12 + 1)
123+
124+
launcher_ip_match = match(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)
125+
126+
@assert launcher_ip_match !== nothing "Could not parse coordinator IP address from \
127+
Open MPI environment."
128+
129+
launcher_ip = launcher_ip_match.captures[findfirst(
130+
!isnothing, launcher_ip_match.captures
131+
)]
132+
return "$(launcher_ip):$(port)"
133+
end
134+
135+
function get_coordinator_address(::OpenMPIPMIXEnvDetector, ::Integer)
136+
varname = findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
137+
pmix_uri = ENV[_PMIX_SERVER_URI[varname]]
138+
139+
job_id = parse(Int, split(split(pmix_uri, '-'; limit=3)[3], "@"; limit=2)[1])
140+
port = job_id % 2^12 + (65535 - 2^12 + 1)
141+
142+
launcher_ip_match = match(r"tcp4://(.+?):|tcp6://\[(.+?)\]", pmix_uri)
143+
144+
@assert launcher_ip_match !== nothing "Could not parse coordinator IP address from \
145+
Open MPI environment."
146+
147+
launcher_ip = launcher_ip_match.captures[findfirst(
148+
!isnothing, launcher_ip_match.captures
149+
)]
150+
151+
return "$(launcher_ip):$(port)"
152+
end
153+
154+
get_process_count(::AbstractOMPIClusterEnvDetector) = parse(Int, ENV[_OMPI_PROCESS_COUNT])
155+
156+
get_process_id(::AbstractOMPIClusterEnvDetector) = parse(Int, ENV[_OMPI_PROCESS_ID])
157+
158+
function get_local_process_id(::AbstractOMPIClusterEnvDetector)
159+
return parse(Int, ENV[_OMPI_LOCAL_PROCESS_ID])
160+
end
161+
162+
end

src/Reactant.jl

+6-7
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ include("Devices.jl")
4444
include("Interpreter.jl")
4545
include("Profiler.jl")
4646
include("Types.jl")
47+
include("Distributed.jl")
4748

4849
const with_profiler = Profiler.with_profiler
4950

@@ -220,15 +221,13 @@ end
220221

221222
function __init__()
222223
initialize_ptrs()
223-
return initialize_dialect()
224-
end
225-
226-
function set_default_backend(backend::XLA.AbstractClient)
227-
return XLA.default_backend[] = backend
224+
initialize_dialect()
225+
return nothing
228226
end
229227

230-
function set_default_backend(backend::String)
231-
return set_default_backend(XLA.backends[backend])
228+
function set_default_backend(backend::Union{String,XLA.AbstractClient})
229+
XLA.set_default_backend(backend)
230+
return nothing
232231
end
233232

234233
include("Precompile.jl")

src/Types.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,18 @@ end
125125

126126
function ConcretePJRTArray(
127127
data::Array{T,N};
128-
client::XLA.AbstractClient=XLA.default_backend[],
128+
client::XLA.AbstractClient=XLA.default_backend(),
129129
idx::Union{Int,Nothing}=nothing,
130130
device::Union{Nothing,XLA.AbstractDevice}=nothing,
131131
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
132132
) where {T,N}
133133
if !Sharding.is_sharded(sharding)
134134
if device === nothing
135-
idx = idx === nothing ? XLA.default_device_idx[] : idx
136-
device = XLA.get_addressable_device(client, idx)
135+
if idx === nothing
136+
device = XLA.default_device(client)
137+
else
138+
device = XLA.get_addressable_device(client, idx)
139+
end
137140
else
138141
if idx !== nothing
139142
device_from_idx = XLA.get_addressable_device(client, idx)

0 commit comments

Comments
 (0)