@@ -3,82 +3,7 @@ module ReactantMPIExt
3
3
using Reactant: Reactant, Distributed
4
4
using MPI: MPI
5
5
6
- # Code taken from:
7
- # 1. https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
8
- # 2. https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/ompi_cluster.py
9
-
10
- # Based on ompi_cluster
11
- const _ORTE_URI = " OMPI_MCA_orte_hnp_uri"
12
- const _PMIX_SERVER_URI = (
13
- " PMIX_SERVER_URI2" ,
14
- " PMIX_SERVER_URI3" ,
15
- " PMIX_SERVER_URI4" ,
16
- " PMIX_SERVER_URI41" ,
17
- " PMIX_SERVER_URI21" ,
18
- )
19
- const _OMPI_PROCESS_COUNT = " OMPI_COMM_WORLD_SIZE"
20
- const _OMPI_PROCESS_ID = " OMPI_COMM_WORLD_RANK"
21
- const _OMPI_LOCAL_PROCESS_ID = " OMPI_COMM_WORLD_LOCAL_RANK"
22
-
23
- Distributed. is_env_present (:: Distributed.OpenMPIORTEEnvDetector ) = haskey (ENV , _ORTE_URI)
24
-
25
- function Distributed. is_env_present (:: Distributed.OpenMPIPMIXEnvDetector )
26
- return any (Base. Fix1 (haskey, ENV ), _PMIX_SERVER_URI)
27
- end
28
-
29
- function Distributed. get_coordinator_address (
30
- :: Distributed.OpenMPIORTEEnvDetector , timeout_in_seconds:: Integer
31
- )
32
- orte_uri = ENV [_ORTE_URI]
33
-
34
- job_id = parse (Int, split (orte_uri, ' .' ; limit= 2 )[1 ])
35
- port = job_id % 2 ^ 12 + (65535 - 2 ^ 12 + 1 )
36
-
37
- launcher_ip_match = match (r" tcp://(.+?)[,:]|tcp6://\[ (.+?)[,\] ]" , orte_uri)
38
-
39
- @assert launcher_ip_match != = nothing " Could not parse coordinator IP address from \
40
- Open MPI environment."
41
-
42
- launcher_ip = launcher_ip_match. captures[findfirst (
43
- ! isnothing, launcher_ip_match. captures
44
- )]
45
- return " $(launcher_ip) :$(port) "
46
- end
47
-
48
- function Distributed. get_coordinator_address (
49
- :: Distributed.OpenMPIPMIXEnvDetector , timeout_in_seconds:: Integer
50
- )
51
- varname = findfirst (Base. Fix1 (haskey, ENV ), _PMIX_SERVER_URI)
52
- pmix_uri = ENV [_PMIX_SERVER_URI[varname]]
53
-
54
- job_id = parse (Int, split (split (pmix_uri, ' -' ; limit= 3 )[3 ], " @" ; limit= 2 )[1 ])
55
- port = job_id % 2 ^ 12 + (65535 - 2 ^ 12 + 1 )
56
-
57
- launcher_ip_match = match (r" tcp4://(.+?):|tcp6://\[ (.+?)\] " , pmix_uri)
58
-
59
- @assert launcher_ip_match != = nothing " Could not parse coordinator IP address from \
60
- Open MPI environment."
61
-
62
- launcher_ip = launcher_ip_match. captures[findfirst (
63
- ! isnothing, launcher_ip_match. captures
64
- )]
65
-
66
- return " $(launcher_ip) :$(port) "
67
- end
68
-
69
- function Distributed. get_process_count (:: Distributed.AbstractOMPIClusterEnvDetector )
70
- return parse (Int, ENV [_OMPI_PROCESS_COUNT])
71
- end
72
-
73
- function Distributed. get_process_id (:: Distributed.AbstractOMPIClusterEnvDetector )
74
- return parse (Int, ENV [_OMPI_PROCESS_ID])
75
- end
76
-
77
- function Distributed. get_local_process_id (:: Distributed.AbstractOMPIClusterEnvDetector )
78
- return parse (Int, ENV [_OMPI_LOCAL_PROCESS_ID])
79
- end
80
-
81
- # Based on mpi4py
6
+ # https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
82
7
Distributed. is_env_present (:: Distributed.MPIEnvDetector ) = MPI. Initialized ()
83
8
84
9
function Distributed. get_coordinator_address (
0 commit comments