@@ -21,6 +21,7 @@ function LLVMclopts(opts...)
21
21
):: Cvoid
22
22
end
23
23
24
+ include (" Distributed.jl" )
24
25
include (" Client.jl" )
25
26
include (" Device.jl" )
26
27
include (" Sharding.jl" )
@@ -53,6 +54,7 @@ function Base.setproperty!(bs::BackendState, sym::Symbol, val)
53
54
end
54
55
55
56
# Module-wide BackendState instance; `client` and `default_backend` below read its
# `clients` and `default_client` fields.
const global_backend_state = BackendState ()
57
# Module-wide State instance; elsewhere in this file its `process_id`, `num_processes`,
# `local_device_ids` and `client` fields are read when the CPU/GPU clients are built.
+ const global_state = State ()
56
58
57
59
"""
    client(backend::String)

Look up `backend` in `global_backend_state.clients` and return the stored entry.
"""
function client(backend::String)
    registry = global_backend_state.clients
    return registry[backend]
end
58
60
"""
    default_backend()

Return whatever is currently stored in `global_backend_state.default_client`.
"""
function default_backend()
    return global_backend_state.default_client
end
@@ -68,6 +70,13 @@ function set_default_backend(backend::String)
68
70
return nothing
69
71
end
70
72
73
"""
    update_global_state!(args...; kwargs...)

Forward all positional and keyword arguments to `update!` on `global_state`, then
re-run `initialize_default_clients!` on `global_backend_state`.

Always returns `nothing`.
"""
function update_global_state!(args...; kwargs...)
    update!(global_state, args...; kwargs...)

    # The clients are built from values held in `global_state`, so they have to be
    # re-initialized once that state has changed.
    initialize_default_clients!(global_backend_state)

    return nothing
end
79
+
71
80
function __init__ ()
72
81
# This must be the very first thing initialized (otherwise we can't throw errors)
73
82
errptr = cglobal ((:ReactantThrowError , MLIR. API. mlir_c), Ptr{Ptr{Cvoid}})
@@ -90,16 +99,17 @@ function __init__()
90
99
@debug " XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[]
91
100
end
92
101
102
+ if haskey (ENV , " REACTANT_VISIBLE_GPU_DEVICES" )
103
+ global_state. local_device_ids =
104
+ parse .(Int, split (ENV [" REACTANT_VISIBLE_GPU_DEVICES" ], " ," ))
105
+ @debug " REACTANT_VISIBLE_GPU_DEVICES: " global_state. local_device_ids
106
+ end
107
+
93
108
@ccall MLIR. API. mlir_c. RegisterEnzymeXLACPUHandler ():: Cvoid
94
109
@ccall MLIR. API. mlir_c. RegisterEnzymeXLAGPUHandler ():: Cvoid
95
110
return nothing
96
111
end
97
112
98
- function initialize_default_clients ()
99
- initialize_default_clients! (global_backend_state)
100
- return nothing
101
- end
102
-
103
113
function initialize_default_clients! (state:: BackendState )
104
114
was_initialized = state. initialized
105
115
state. initialized = true
@@ -109,7 +119,7 @@ function initialize_default_clients!(state::BackendState)
109
119
XLA. free_client (state. clients[" cpu" ])
110
120
XLA. PJRT. cpu_client_count[] -= 1
111
121
end
112
- cpu = PJRT. CPUClient (PJRT . global_state. process_id, PJRT . global_state. num_processes)
122
+ cpu = PJRT. CPUClient (global_state. process_id, global_state. num_processes)
113
123
state. clients[" cpu" ] = cpu
114
124
state. default_client = cpu
115
125
@@ -142,9 +152,9 @@ function initialize_default_clients!(state::BackendState)
142
152
else
143
153
if ! Reactant. precompiling ()
144
154
try
145
- distributed_runtime_client = if PJRT . global_state. num_processes > 1
146
- @assert PJRT . global_state. client != = nothing
147
- PJRT . global_state. client
155
+ distributed_runtime_client = if global_state. num_processes > 1
156
+ @assert global_state. client != = nothing
157
+ global_state. client
148
158
else
149
159
nothing
150
160
end
@@ -154,9 +164,9 @@ function initialize_default_clients!(state::BackendState)
154
164
XLA. PJRT. gpu_client_count[] -= 1
155
165
end
156
166
gpu = PJRT. GPUClient (
157
- PJRT . global_state. process_id,
158
- PJRT . global_state. num_processes;
159
- allowed_devices= PJRT . global_state. local_device_ids,
167
+ global_state. process_id,
168
+ global_state. num_processes;
169
+ allowed_devices= global_state. local_device_ids,
160
170
distributed_runtime_client,
161
171
)
162
172
state. clients[" gpu" ] = gpu
0 commit comments