Skip to content

Commit dad792d

Browse files
committed
refactor: OMPI detection doesn't need to be in an extension
1 parent e979b11 commit dad792d

File tree

2 files changed

+62
-76
lines changed

2 files changed

+62
-76
lines changed

ext/ReactantMPIExt.jl

+1-76
Original file line numberDiff line numberDiff line change
@@ -3,82 +3,7 @@ module ReactantMPIExt
33
using Reactant: Reactant, Distributed
44
using MPI: MPI
55

6-
# Code taken from:
7-
# 1. https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
8-
# 2. https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/ompi_cluster.py
9-
10-
# Based on ompi_cluster
11-
# Names of the environment variables consulted by the Open MPI detectors below.

# Read by get_process_count, get_process_id and get_local_process_id respectively.
const _OMPI_PROCESS_COUNT = "OMPI_COMM_WORLD_SIZE"
const _OMPI_PROCESS_ID = "OMPI_COMM_WORLD_RANK"
const _OMPI_LOCAL_PROCESS_ID = "OMPI_COMM_WORLD_LOCAL_RANK"

# Variable holding the URI used by the ORTE detector.
const _ORTE_URI = "OMPI_MCA_orte_hnp_uri"

# Candidate variables for the PMIx detector; they are probed in this order and the
# first one found in ENV is used.
const _PMIX_SERVER_URI = (
    "PMIX_SERVER_URI2", "PMIX_SERVER_URI3", "PMIX_SERVER_URI4",
    "PMIX_SERVER_URI41", "PMIX_SERVER_URI21",
)
22-
23-
# The ORTE detector applies exactly when the variable named by `_ORTE_URI` is set.
function Distributed.is_env_present(::Distributed.OpenMPIORTEEnvDetector)
    return haskey(ENV, _ORTE_URI)
end
24-
25-
# The PMIx detector applies when at least one of the `_PMIX_SERVER_URI` variables is set.
function Distributed.is_env_present(::Distributed.OpenMPIPMIXEnvDetector)
    return any(name -> haskey(ENV, name), _PMIX_SERVER_URI)
end
28-
29-
# Build the coordinator address ("ip:port") for an Open MPI (ORTE) launch from the
# environment variable named by `_ORTE_URI`.
#
# `timeout_in_seconds` is accepted but not used by this method.
#
# Throws `KeyError` when the variable is unset, an error from `parse` when the text before
# the first '.' is not an integer, and `AssertionError` when no launcher IP can be found.
function Distributed.get_coordinator_address(
    ::Distributed.OpenMPIORTEEnvDetector, timeout_in_seconds::Integer
)
    orte_uri = ENV[_ORTE_URI]

    # The job id is the integer before the first '.'. For a non-negative job id the
    # resulting port lies in 61440:65535 (the top 2^12 ports).
    job_id = parse(Int, first(split(orte_uri, '.'; limit=2)))
    port = job_id % 2^12 + (65535 - 2^12 + 1)

    launcher_ip_match = match(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)

    # Thrown explicitly instead of via `@assert`: this validates an external value and must
    # not vanish if assertions are ever compiled out. The exception type is kept as-is.
    if launcher_ip_match === nothing
        throw(
            AssertionError(
                "Could not parse coordinator IP address from Open MPI environment."
            ),
        )
    end

    # Exactly one of the two alternatives (IPv4 / bracketed IPv6) captured something.
    launcher_ip = something(launcher_ip_match.captures...)
    return "$(launcher_ip):$(port)"
end
47-
48-
# Build the coordinator address ("ip:port") for an Open MPI (PMIx) launch from the first
# `_PMIX_SERVER_URI` variable that is set in the environment.
#
# `timeout_in_seconds` is accepted but not used by this method.
#
# Throws an `ErrorException` when none of the candidate variables is set, an error from
# indexing/`parse` when the URI does not have the expected shape, and `AssertionError`
# when no launcher IP can be found.
function Distributed.get_coordinator_address(
    ::Distributed.OpenMPIPMIXEnvDetector, timeout_in_seconds::Integer
)
    # `findfirst` yields an index into the tuple, or `nothing` when no variable is set.
    uri_index = findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
    if uri_index === nothing
        error(
            "None of the PMIx server URI environment variables " *
            "($(join(_PMIX_SERVER_URI, ", "))) is set.",
        )
    end
    pmix_uri = ENV[_PMIX_SERVER_URI[uri_index]]

    # The job id is taken from the third '-'-separated field, up to the first '@'.
    # NOTE(review): a hyphen earlier in the URI would shift the fields — confirm the format.
    job_field = split(pmix_uri, '-'; limit=3)[3]
    job_id = parse(Int, split(job_field, "@"; limit=2)[1])
    # For a non-negative job id the port lies in 61440:65535 (the top 2^12 ports).
    port = job_id % 2^12 + (65535 - 2^12 + 1)

    launcher_ip_match = match(r"tcp4://(.+?):|tcp6://\[(.+?)\]", pmix_uri)

    # Thrown explicitly instead of via `@assert`: this validates an external value and must
    # not vanish if assertions are ever compiled out. The exception type is kept as-is.
    if launcher_ip_match === nothing
        throw(
            AssertionError(
                "Could not parse coordinator IP address from Open MPI environment."
            ),
        )
    end

    # Exactly one of the two alternatives (IPv4 / bracketed IPv6) captured something.
    launcher_ip = something(launcher_ip_match.captures...)

    return "$(launcher_ip):$(port)"
end
68-
69-
# Parse the value of the `_OMPI_PROCESS_COUNT` environment variable as an `Int`.
function Distributed.get_process_count(::Distributed.AbstractOMPIClusterEnvDetector)
    raw_count = ENV[_OMPI_PROCESS_COUNT]
    return parse(Int, raw_count)
end
72-
73-
# Parse the value of the `_OMPI_PROCESS_ID` environment variable as an `Int`.
function Distributed.get_process_id(::Distributed.AbstractOMPIClusterEnvDetector)
    raw_id = ENV[_OMPI_PROCESS_ID]
    return parse(Int, raw_id)
end
76-
77-
# Parse the value of the `_OMPI_LOCAL_PROCESS_ID` environment variable as an `Int`.
function Distributed.get_local_process_id(::Distributed.AbstractOMPIClusterEnvDetector)
    raw_local_id = ENV[_OMPI_LOCAL_PROCESS_ID]
    return parse(Int, raw_local_id)
end
80-
81-
# Based on mpi4py
6+
# https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
827
# The generic MPI detector applies exactly when `MPI.Initialized()` reports true.
function Distributed.is_env_present(::Distributed.MPIEnvDetector)
    return MPI.Initialized()
end
838

849
function Distributed.get_coordinator_address(

src/Distributed.jl

+61
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,65 @@ function auto_detect_unset_distributed_params(;
9898
return coordinator_address, num_processes, process_id, local_device_ids
9999
end
100100

101+
# OpenMPIORTEEnvDetector & OpenMPIPMIXEnvDetector
102+
# Based on https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/ompi_cluster.py and adapted for latest OpenMPI versions
103+
# Names of the environment variables consulted by the Open MPI detectors below.

# Read by get_process_count, get_process_id and get_local_process_id respectively.
const _OMPI_PROCESS_COUNT = "OMPI_COMM_WORLD_SIZE"
const _OMPI_PROCESS_ID = "OMPI_COMM_WORLD_RANK"
const _OMPI_LOCAL_PROCESS_ID = "OMPI_COMM_WORLD_LOCAL_RANK"

# Variable holding the URI used by the ORTE detector.
const _ORTE_URI = "OMPI_MCA_orte_hnp_uri"

# Candidate variables for the PMIx detector; they are probed in this order and the
# first one found in ENV is used.
const _PMIX_SERVER_URI = (
    "PMIX_SERVER_URI2", "PMIX_SERVER_URI3", "PMIX_SERVER_URI4",
    "PMIX_SERVER_URI41", "PMIX_SERVER_URI21",
)
114+
115+
# The ORTE detector applies exactly when the variable named by `_ORTE_URI` is set.
function is_env_present(::OpenMPIORTEEnvDetector)
    return haskey(ENV, _ORTE_URI)
end
116+
# The PMIx detector applies when at least one of the `_PMIX_SERVER_URI` variables is set.
function is_env_present(::OpenMPIPMIXEnvDetector)
    return any(name -> haskey(ENV, name), _PMIX_SERVER_URI)
end
117+
118+
# Build the coordinator address ("ip:port") for an Open MPI (ORTE) launch from the
# environment variable named by `_ORTE_URI`.
#
# The second positional argument is accepted but not used by this method.
#
# Throws `KeyError` when the variable is unset, an error from `parse` when the text before
# the first '.' is not an integer, and `AssertionError` when no launcher IP can be found.
function get_coordinator_address(::OpenMPIORTEEnvDetector, ::Integer)
    orte_uri = ENV[_ORTE_URI]

    # The job id is the integer before the first '.'. For a non-negative job id the
    # resulting port lies in 61440:65535 (the top 2^12 ports).
    job_id = parse(Int, first(split(orte_uri, '.'; limit=2)))
    port = job_id % 2^12 + (65535 - 2^12 + 1)

    launcher_ip_match = match(r"tcp://(.+?)[,:]|tcp6://\[(.+?)[,\]]", orte_uri)

    # Thrown explicitly instead of via `@assert`: this validates an external value and must
    # not vanish if assertions are ever compiled out. The exception type is kept as-is.
    if launcher_ip_match === nothing
        throw(
            AssertionError(
                "Could not parse coordinator IP address from Open MPI environment."
            ),
        )
    end

    # Exactly one of the two alternatives (IPv4 / bracketed IPv6) captured something.
    launcher_ip = something(launcher_ip_match.captures...)
    return "$(launcher_ip):$(port)"
end
134+
135+
# Build the coordinator address ("ip:port") for an Open MPI (PMIx) launch from the first
# `_PMIX_SERVER_URI` variable that is set in the environment.
#
# The second positional argument is accepted but not used by this method.
#
# Throws an `ErrorException` when none of the candidate variables is set, an error from
# indexing/`parse` when the URI does not have the expected shape, and `AssertionError`
# when no launcher IP can be found.
function get_coordinator_address(::OpenMPIPMIXEnvDetector, ::Integer)
    # `findfirst` yields an index into the tuple, or `nothing` when no variable is set.
    uri_index = findfirst(Base.Fix1(haskey, ENV), _PMIX_SERVER_URI)
    if uri_index === nothing
        error(
            "None of the PMIx server URI environment variables " *
            "($(join(_PMIX_SERVER_URI, ", "))) is set.",
        )
    end
    pmix_uri = ENV[_PMIX_SERVER_URI[uri_index]]

    # The job id is taken from the third '-'-separated field, up to the first '@'.
    # NOTE(review): a hyphen earlier in the URI would shift the fields — confirm the format.
    job_field = split(pmix_uri, '-'; limit=3)[3]
    job_id = parse(Int, split(job_field, "@"; limit=2)[1])
    # For a non-negative job id the port lies in 61440:65535 (the top 2^12 ports).
    port = job_id % 2^12 + (65535 - 2^12 + 1)

    launcher_ip_match = match(r"tcp4://(.+?):|tcp6://\[(.+?)\]", pmix_uri)

    # Thrown explicitly instead of via `@assert`: this validates an external value and must
    # not vanish if assertions are ever compiled out. The exception type is kept as-is.
    if launcher_ip_match === nothing
        throw(
            AssertionError(
                "Could not parse coordinator IP address from Open MPI environment."
            ),
        )
    end

    # Exactly one of the two alternatives (IPv4 / bracketed IPv6) captured something.
    launcher_ip = something(launcher_ip_match.captures...)

    return "$(launcher_ip):$(port)"
end
153+
154+
# Parse the value of the `_OMPI_PROCESS_COUNT` environment variable as an `Int`.
function get_process_count(::AbstractOMPIClusterEnvDetector)
    raw_count = ENV[_OMPI_PROCESS_COUNT]
    return parse(Int, raw_count)
end
155+
156+
# Parse the value of the `_OMPI_PROCESS_ID` environment variable as an `Int`.
function get_process_id(::AbstractOMPIClusterEnvDetector)
    raw_id = ENV[_OMPI_PROCESS_ID]
    return parse(Int, raw_id)
end
157+
158+
# Parse the value of the `_OMPI_LOCAL_PROCESS_ID` environment variable as an `Int`.
function get_local_process_id(::AbstractOMPIClusterEnvDetector)
    raw_local_id = ENV[_OMPI_LOCAL_PROCESS_ID]
    return parse(Int, raw_local_id)
end
161+
101162
end

0 commit comments

Comments
 (0)