Skip to content

Commit c89dca8

Browse files
committed
refactor: Distributed setup is not PJRT specific
1 parent 9bb572b commit c89dca8

File tree

5 files changed

+28
-36
lines changed

5 files changed

+28
-36
lines changed

src/Distributed.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ function initialize(;
2424

2525
@debug "Detected Reactant distributed params" coordinator_address num_processes process_id local_device_ids
2626

27-
Reactant.XLA.PJRT.update_global_state!(;
27+
Reactant.XLA.update_global_state!(;
2828
coordinator_address, num_processes, process_id, local_device_ids, kwargs...
2929
)
3030

31-
@debug "New Global State" Reactant.XLA.PJRT.global_state
31+
@debug "New Global State" Reactant.XLA.global_state
3232

3333
initialized[] = true
3434
return nothing

src/xla/Client.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ function GPUClient(
2828
node_id=0,
2929
num_nodes=1,
3030
platform="gpu";
31-
allowed_devices=nothing,
32-
distributed_runtime_client=nothing,
31+
allowed_devices::Union{Nothing,Vector{Int}}=nothing,
32+
distributed_runtime_client::Union{Nothing,DistributedRuntimeClient}=nothing,
3333
)
3434
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
3535
refstr = Ref{Cstring}()

src/xla/PJRT/Distributed.jl src/xla/Distributed.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ function update!(
175175
if length(proxy_vars) > 0
176176
vars = join(proxy_vars, ", ")
177177
@warn "Reactant detected proxy variable(s) in the environment as distributed \
178-
setup: $(vars). On some systems, this may cause a hang of \
179-
`XLA.PJRT.update!` and you may need to unset the proxy variables."
178+
setup: $(vars). On some systems, this may cause a hang of `XLA.update!` and \
179+
you may need to unset the proxy variables."
180180
end
181181

182182
@assert state.client === nothing "`Reactant.Distributed.initialize` should only be \

src/xla/PJRT/PJRT.jl

-18
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,11 @@ using Reactant_jll: Reactant_jll
66

77
using Libdl: Libdl
88

9-
include("Distributed.jl")
109
include("Client.jl")
1110
include("Device.jl")
1211
include("Future.jl")
1312
include("Buffer.jl")
1413
include("AsyncBuffer.jl")
1514
include("LoadedExecutable.jl")
1615

17-
const global_state = State()
18-
19-
function __init__()
20-
if haskey(ENV, "REACTANT_VISIBLE_GPU_DEVICES")
21-
global_state.local_device_ids =
22-
parse.(Int, split(ENV["REACTANT_VISIBLE_GPU_DEVICES"], ","))
23-
end
24-
return nothing
25-
end
26-
27-
function update_global_state!(args...; kwargs...)
28-
update!(global_state, args...; kwargs...)
29-
# We need to update the clients based on the new state
30-
XLA.initialize_default_clients()
31-
return nothing
32-
end
33-
3416
end

src/xla/XLA.jl

+22-12
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ function LLVMclopts(opts...)
2121
)::Cvoid
2222
end
2323

24+
include("Distributed.jl")
2425
include("Client.jl")
2526
include("Device.jl")
2627
include("Sharding.jl")
@@ -53,6 +54,7 @@ function Base.setproperty!(bs::BackendState, sym::Symbol, val)
5354
end
5455

5556
const global_backend_state = BackendState()
57+
const global_state = State()
5658

5759
client(backend::String) = global_backend_state.clients[backend]
5860
default_backend() = global_backend_state.default_client
@@ -68,6 +70,13 @@ function set_default_backend(backend::String)
6870
return nothing
6971
end
7072

73+
function update_global_state!(args...; kwargs...)
74+
update!(global_state, args...; kwargs...)
75+
# We need to update the clients based on the new state
76+
initialize_default_clients!(global_backend_state)
77+
return nothing
78+
end
79+
7180
function __init__()
7281
# This must be the very first thing initialized (otherwise we can't throw errors)
7382
errptr = cglobal((:ReactantThrowError, MLIR.API.mlir_c), Ptr{Ptr{Cvoid}})
@@ -90,16 +99,17 @@ function __init__()
9099
@debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[]
91100
end
92101

102+
if haskey(ENV, "REACTANT_VISIBLE_GPU_DEVICES")
103+
global_state.local_device_ids =
104+
parse.(Int, split(ENV["REACTANT_VISIBLE_GPU_DEVICES"], ","))
105+
@debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_device_ids
106+
end
107+
93108
@ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid
94109
@ccall MLIR.API.mlir_c.RegisterEnzymeXLAGPUHandler()::Cvoid
95110
return nothing
96111
end
97112

98-
function initialize_default_clients()
99-
initialize_default_clients!(global_backend_state)
100-
return nothing
101-
end
102-
103113
function initialize_default_clients!(state::BackendState)
104114
was_initialized = state.initialized
105115
state.initialized = true
@@ -109,7 +119,7 @@ function initialize_default_clients!(state::BackendState)
109119
XLA.free_client(state.clients["cpu"])
110120
XLA.PJRT.cpu_client_count[] -= 1
111121
end
112-
cpu = PJRT.CPUClient(PJRT.global_state.process_id, PJRT.global_state.num_processes)
122+
cpu = PJRT.CPUClient(global_state.process_id, global_state.num_processes)
113123
state.clients["cpu"] = cpu
114124
state.default_client = cpu
115125

@@ -142,9 +152,9 @@ function initialize_default_clients!(state::BackendState)
142152
else
143153
if !Reactant.precompiling()
144154
try
145-
distributed_runtime_client = if PJRT.global_state.num_processes > 1
146-
@assert PJRT.global_state.client !== nothing
147-
PJRT.global_state.client
155+
distributed_runtime_client = if global_state.num_processes > 1
156+
@assert global_state.client !== nothing
157+
global_state.client
148158
else
149159
nothing
150160
end
@@ -154,9 +164,9 @@ function initialize_default_clients!(state::BackendState)
154164
XLA.PJRT.gpu_client_count[] -= 1
155165
end
156166
gpu = PJRT.GPUClient(
157-
PJRT.global_state.process_id,
158-
PJRT.global_state.num_processes;
159-
allowed_devices=PJRT.global_state.local_device_ids,
167+
global_state.process_id,
168+
global_state.num_processes;
169+
allowed_devices=global_state.local_device_ids,
160170
distributed_runtime_client,
161171
)
162172
state.clients["gpu"] = gpu

0 commit comments

Comments
 (0)