Commit 2a5711e

feat: initial IFRT integration (#764)
* feat: Low-Level XLA.IFRT integration
* refactor: rework how OpSharding works
* feat: generate_device_list
* feat: add placeholder code to simplify future sharding logic
* fixup
* fix: store results as HloSharding
* docs: fix duplicate docs
* feat: compile with logical device ids
* fix: use correct global device ids
* feat: use a global state to setup pjrt distributed runtime
* fix: devices are not necessarily from 0 to N-1
* fix: initialize clients on first use rather than on init
* fix: make device selection consistent with clients
* feat: add OMPI cluster detection
* fix: correctly set kv_store
* refactor: Distributed setup is not PJRT specific
* refactor: OMPI detection doesn't need to be in an extension
* feat: initial low-level IFRT API
* fix: ifrt HloSharding
* refactor: split up into IFRT/PJRT
* feat: IFRT Client APIs
* feat: IFRT Device API
* fix: remove global_ordinals
* feat: add devices list abstraction
* feat: wrap memory and memory kinds
* feat: ifrt::HloSharding now working
* fix: use new ABI
* chore: run formatter
* fix: no finalizer
* feat: initial draft of IFRT.Array interface (#774)
  * feat: initial draft of IFRT.Array interface
  * feat: Base.Array to ifrt::Array
  * feat: buffer to host
* chore: run formatter
* fix: bad rebase
* feat: more proxy servers
* feat: add ConcreteIFRTArray
* feat: add ConcreteIFRTNumber
* refactor: rename ConcreteRNumber to ConcretePJRTNumber
* revert: concreteifrtarray implementation
* chore: run formatter
* feat: ifrt loaded executable
* feat: construct IFRT clients with distributed options
* refactor: remove BasicDevicesList
* fix: use global device ids
* feat: sharding annotations across nodes now working
* fix: Array construction from SingleShards
* feat: support to_host for distributed cases
* feat: add Gloo/MPI collectives for distributed CPU client
* feat: low level compile API
* feat: low-level IFRT compile + execute working
* test: low level IFRT tests
* docs: add warning on distributed cases
1 parent bcb0282 commit 2a5711e

25 files changed: +1187 / -198 lines

src/Compiler.jl (+2, -7)

@@ -508,13 +508,6 @@ function compile_mlir(f, args; client=nothing, kwargs...)
 
     mlir_fn_res = compile_mlir!(mod, f, args; backend, kwargs...)
 
-    client, _ = __resolve_device_and_client(
-        client,
-        mlir_fn_res.seen_args,
-        mlir_fn_res.linear_args,
-        mlir_fn_res.is_sharded,
-    )
-
     # Attach a name, and partitioning attributes to the module
     __add_mhlo_attributes_and_name!(
         mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas
@@ -1509,6 +1502,8 @@ function compile_xla(f, args; client=nothing, kwargs...)
         num_parameters=length(mlir_fn_res.linear_args),
         mlir_fn_res.is_sharded,
         global_device_ids,
+        mlir_fn_res.num_replicas,
+        mlir_fn_res.num_partitions,
     )
 
     return mod, exec, mlir_fn_res, device, client
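A note on the Julia syntax used in these hunks: after the `;`, `mlir_fn_res.num_replicas` is keyword-argument shorthand for `num_replicas=mlir_fn_res.num_replicas`. A minimal, self-contained illustration (the names below are invented for the example, not part of Reactant):

    # Keyword shorthand: `f(; x.a)` forwards property `a` of `x` as the keyword `a`.
    describe(; num_replicas, num_partitions) = "replicas=$num_replicas, partitions=$num_partitions"
    res = (num_replicas=4, num_partitions=2)          # stand-in for `mlir_fn_res`
    describe(; res.num_replicas, res.num_partitions)  # "replicas=4, partitions=2"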

src/Distributed.jl (+11, -11)

@@ -8,24 +8,24 @@ function initialize(;
     coordinator_address::Union{Nothing,String}=nothing,
     num_processes::Union{Nothing,Integer}=nothing,
     process_id::Union{Nothing,Integer}=nothing,
-    local_device_ids::Union{Nothing,Vector{Int}}=nothing,
+    local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing,
     initialization_timeout_in_seconds::Integer=300,
     kwargs...,
 )
     @assert !initialized[] "`Distributed.initialize` has already been called"
 
-    (coordinator_address, num_processes, process_id, local_device_ids) = auto_detect_unset_distributed_params(;
+    (coordinator_address, num_processes, process_id, local_gpu_device_ids) = auto_detect_unset_distributed_params(;
        coordinator_address,
        num_processes,
        process_id,
-        local_device_ids,
+        local_gpu_device_ids,
        initialization_timeout_in_seconds,
    )
 
-    @debug "Detected Reactant distributed params" coordinator_address num_processes process_id local_device_ids
+    @debug "Detected Reactant distributed params" coordinator_address num_processes process_id local_gpu_device_ids
 
     Reactant.XLA.update_global_state!(;
-        coordinator_address, num_processes, process_id, local_device_ids, kwargs...
+        coordinator_address, num_processes, process_id, local_gpu_device_ids, kwargs...
     )
 
     @debug "New Global State" Reactant.XLA.global_state
@@ -57,14 +57,14 @@ function auto_detect_unset_distributed_params(;
     coordinator_address::Union{Nothing,String}=nothing,
     num_processes::Union{Nothing,Integer}=nothing,
     process_id::Union{Nothing,Integer}=nothing,
-    local_device_ids::Union{Nothing,Vector{Int}}=nothing,
+    local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing,
     initialization_timeout_in_seconds::Integer=300,
 )
     if all(
         Base.Fix2(!==, nothing),
-        (coordinator_address, num_processes, process_id, local_device_ids),
+        (coordinator_address, num_processes, process_id, local_gpu_device_ids),
     )
-        return coordinator_address, num_processes, process_id, local_device_ids
+        return coordinator_address, num_processes, process_id, local_gpu_device_ids
     end
 
     idx = findfirst(is_env_present, detector_list)
@@ -91,11 +91,11 @@ function auto_detect_unset_distributed_params(;
         process_id = get_process_id(detector)
     end
 
-    if local_device_ids === nothing
-        local_device_ids = [get_local_process_id(detector)]
+    if local_gpu_device_ids === nothing
+        local_gpu_device_ids = [get_local_process_id(detector)]
     end
 
-    return coordinator_address, num_processes, process_id, local_device_ids
+    return coordinator_address, num_processes, process_id, local_gpu_device_ids
 end
 
 # OpenMPIORTEEnvDetector & OpenMPIPMIXEnvDetector
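The rename above is user-visible: `initialize` now takes `local_gpu_device_ids` instead of `local_device_ids`. A hedged sketch of a two-process launch under the new keyword; the addresses and ids are placeholders, and when the job runs under a detected launcher (e.g. OMPI) all of these keywords may simply be left unset and auto-detected:

    using Reactant

    # Call once per process, before any XLA client is created.
    Reactant.Distributed.initialize(;
        coordinator_address="10.0.0.1:12345",  # process 0 hosts the coordinator
        num_processes=2,
        process_id=0,                          # unique per process, 0-based
        local_gpu_device_ids=[0, 1],           # GPUs owned by this process (formerly `local_device_ids`)
    )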

src/Types.jl (+2, -2)

@@ -135,11 +135,11 @@ function ConcretePJRTArray(
         if idx === nothing
             device = XLA.default_device(client)
         else
-            device = XLA.get_addressable_device(client, idx)
+            device = XLA.get_device(client, idx)
         end
     else
         if idx !== nothing
-            device_from_idx = XLA.get_addressable_device(client, idx)
+            device_from_idx = XLA.get_device(client, idx)
             @assert device_from_idx == device "If both `idx` and `device` are \
                 specified, `idx` must match `device`"
         end
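A sketch of the device-selection rule this hunk changes, assuming `client` and `idx` are as in the surrounding constructor. After this commit the index is resolved with `XLA.get_device`, which (per the commit message's "use correct global device ids") presumably indexes the client's full device list rather than only the addressable subset; the helper name below is hypothetical:

    # Hypothetical helper mirroring the branch above; not part of the package API.
    pick_device(client, idx) =
        idx === nothing ? XLA.default_device(client) : XLA.get_device(client, idx)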

src/xla/Buffer.jl (+54)

@@ -5,6 +5,15 @@ function buffer_on_cpu end
 function to_host end
 function unsafe_buffer_pointer end
 function copy_buffer_to_device end
+function sharding end
+
+Base.convert(::Type{Array}, buffer::AbstractBuffer) = convert(Array{eltype(buffer)}, buffer)
+
+function Base.convert(::Type{<:Array{T}}, buffer::AbstractBuffer) where {T}
+    arr = zeros(T, reverse(size(buffer))...)
+    XLA.to_host(buffer, arr)
+    return arr
+end
 
 @inline function client(
     buffers::Union{Array{<:AbstractBuffer},NTuple{<:Any,AbstractBuffer}}
@@ -19,3 +28,48 @@ end
 )
     return map(synced_buffer, buffers)
 end
+
+function Base.show(io::IO, mime::MIME"text/plain", buffer::B) where {B<:AbstractBuffer}
+    print(io, "$(B) storing ")
+    show(io, mime, convert(Array, buffer))
+    return nothing
+end
+
+# Async Buffers
+abstract type AbstractAsyncBuffer <: AbstractBuffer end
+
+Base.isempty(buffer::AbstractAsyncBuffer) = buffer.buffer.buffer == C_NULL
+
+function Base.convert(T::Type{Array}, buffer::AbstractAsyncBuffer)
+    XLA.await(buffer)
+    return convert(T, buffer.buffer)
+end
+
+function Base.convert(T::Type{<:Array{T1}}, buffer::AbstractAsyncBuffer) where {T1}
+    XLA.await(buffer)
+    return convert(T, buffer.buffer)
+end
+
+for op in (:(Base.ndims), :(Base.size), :(Base.eltype), :device, :client, :sharding)
+    @eval $op(buffer::AbstractAsyncBuffer) = $op(buffer.buffer)
+end
+
+function XLA.synced_buffer(buffer::AbstractAsyncBuffer)
+    XLA.await(buffer)
+    return buffer.buffer
+end
+
+function XLA.await(buffer::AbstractAsyncBuffer)
+    buffer.future === nothing && return nothing
+    future = buffer.future
+    buffer.future = nothing
+    XLA.await(future)
+    return nothing
+end
+
+function XLA.is_ready(buffer::AbstractAsyncBuffer)
+    buffer.future === nothing && return true
+    return XLA.is_ready(buffer.future)
+end
+
+XLA.buffer_on_cpu(buffer::AbstractAsyncBuffer) = XLA.buffer_on_cpu(buffer.buffer)
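A hedged usage sketch of the new generic conversions; how a concrete `buf::XLA.AbstractBuffer` is obtained (for example from an existing device array) is an assumption and not part of this diff:

    host   = convert(Array, buf)          # copy back to a host Array, eltype taken from the buffer
    host32 = convert(Array{Float32}, buf) # element-type-specific conversion
    # For an AbstractAsyncBuffer both methods first call XLA.await, so the copy
    # happens only after the producing computation has finished; the new
    # Base.show method reuses the same conversion to print buffer contents.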

src/xla/Client.jl (-52)

@@ -14,55 +14,3 @@ function get_addressable_device end
 function platform_name end
 
 default_device(client::AbstractClient) = first(addressable_devices(client))
-
-# Clients for Different Backends
-function CPUClient(cfunc, node_id=0, num_nodes=1; asynchronous=true)
-    f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
-    client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
-    LLVMclopts("-nvptx-fma-level=1")
-    return client
-end
-
-function GPUClient(
-    cfunc,
-    node_id=0,
-    num_nodes=1,
-    platform="gpu";
-    allowed_devices::Union{Nothing,Vector{Int}}=nothing,
-    distributed_runtime_client::Union{Nothing,DistributedRuntimeClient}=nothing,
-)
-    f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
-    refstr = Ref{Cstring}()
-
-    num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices)
-    allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices
-    distributed_runtime_client =
-        distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client
-
-    client = ccall(
-        f,
-        Ptr{Cvoid},
-        (Cint, Cint, Ptr{Cvoid}, Cint, Cdouble, Bool, Cstring, Ptr{Cstring}, Ptr{Cvoid}),
-        node_id,
-        num_nodes,
-        allowed_devices,
-        num_allowed_devices,
-        XLA_REACTANT_GPU_MEM_FRACTION[],
-        false,
-        platform,
-        refstr,
-        distributed_runtime_client,
-    )
-    client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
-    LLVMclopts("-nvptx-fma-level=1")
-    return client
-end
-
-function TPUClient(cfunc, tpu_path::String)
-    f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
-    refstr = Ref{Cstring}()
-    client = ccall(f, Ptr{Cvoid}, (Cstring, Ptr{Cstring}), tpu_path, refstr)
-    client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
-    LLVMclopts("-nvptx-fma-level=1")
-    return client
-end

src/xla/Device.jl (+2, -6)

@@ -10,19 +10,15 @@ function get_local_device_id end
 function device_kind end
 function default_memory end
 function memories end
+function is_addressable end
 
 """
     device_ordinal(device::Device)
-    device_ordinal(client::XLA.AbstractClient, local_device_id::Int)
 
-Given the device or local device id, return the corresponding global device ordinal in the client.
+Given the device, return the corresponding global device ordinal in the client.
 """
 function device_ordinal end
 
-function device_ordinal(client::AbstractClient, local_device_id::Integer)
-    return device_ordinal(get_addressable_device(client, local_device_id))
-end
-
 function Base.string(device::AbstractDevice)
     client = XLA.client(device)
     pname = XLA.platform_name(client)
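A short hedged sketch of the surviving API after this cleanup, assuming `client` is any XLA client (`default_device` comes from src/xla/Client.jl above); the single-argument form of `is_addressable` is an assumption, since this hunk only declares the generic function:

    dev = XLA.default_device(client)
    XLA.device_ordinal(dev)   # global ordinal of `dev` within its client
    XLA.is_addressable(dev)   # new predicate declared in this hunk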

src/xla/Distributed.jl (+3, -3)

@@ -106,7 +106,7 @@ end
 @kwdef mutable struct State
     process_id::Int = 0
     num_processes::Int = 1
-    local_device_ids::Union{Nothing,Vector{Int}} = nothing
+    local_gpu_device_ids::Union{Nothing,Vector{Int}} = nothing
     service::Union{Nothing,DistributedRuntimeService} = nothing
     client::Union{Nothing,DistributedRuntimeClient} = nothing
     coordinator_address::Union{Nothing,String} = nothing
@@ -129,7 +129,7 @@ function update!(
     coordinator_address::String,
     num_processes::Int,
     process_id::Int,
-    local_device_ids::Vector{Int},
+    local_gpu_device_ids::Vector{Int},
     coordinator_bind_address::Union{Nothing,String}=nothing,
     cluster_register_timeout_in_minutes::Integer=60,
     rpc_timeout_in_seconds::Integer=120,
@@ -141,7 +141,7 @@
     @assert 0 ≤ process_id < num_processes
 
     state.coordinator_address = coordinator_address
-    state.local_device_ids = local_device_ids
+    state.local_gpu_device_ids = local_gpu_device_ids
     state.process_id = process_id
     state.num_processes = num_processes
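To show where the renamed field ends up, a hedged sketch using the `Reactant.XLA.update_global_state!` entry point seen in src/Distributed.jl above (normally invoked for you by `Reactant.Distributed.initialize`); all values are placeholders:

    Reactant.XLA.update_global_state!(;
        coordinator_address="10.0.0.1:12345",
        num_processes=2,
        process_id=1,
        local_gpu_device_ids=[0],
    )
    Reactant.XLA.global_state.local_gpu_device_ids  # == [0], stored on the renamed field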
