Skip to content

Commit 0a37253

Browse files
authored
refactor: move PJRT into a specific module (#771)
* feat: JLL changes for IFRT Shardings * chore: run clang-format * chore: run formatter * fix: remove old APIs * refactor: move PJRT into a specific module * fix: store buffers in global sorted id ordering * fix: finalizer * fix: prevent double free * chore: bump jll * fix: avoid finalizer for now
1 parent afa90a4 commit 0a37253

32 files changed

+879
-733
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ PythonCall = "0.9"
8181
Random = "1.10"
8282
Random123 = "1.7"
8383
ReactantCore = "0.1.5"
84-
Reactant_jll = "0.0.70"
84+
Reactant_jll = "0.0.71"
8585
Scratch = "1.2"
8686
Sockets = "1.10"
8787
SpecialFunctions = "2.4"

ext/ReactantCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1155,7 +1155,7 @@ end
11551155
@static if !Sys.isapple()
11561156
Reactant.PrecompileTools.@setup_workload begin
11571157
Reactant.initialize_dialect()
1158-
client = Reactant.XLA.CPUClient(; checkcount=false)
1158+
client = Reactant.XLA.PJRT.CPUClient(; checkcount=false)
11591159
Reactant.PrecompileTools.@compile_workload begin
11601160
@static if Reactant.precompilation_supported() && VERSION != v"1.11.3"
11611161
function square_kernel!(x)

src/Compiler.jl

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -478,9 +478,9 @@ function compile_mlir(f, args; client=nothing, kwargs...)
478478
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
479479

480480
if client !== nothing
481-
backend = XLA.ClientGetPlatformName(client)
481+
backend = XLA.platform_name(client)
482482
else
483-
backend = XLA.ClientGetPlatformName(XLA.default_backend[])
483+
backend = XLA.platform_name(XLA.default_backend[])
484484
end
485485
if backend == "CUDA"
486486
backend = "GPU"
@@ -1076,9 +1076,7 @@ function codegen_flatten!(
10761076

10771077
if is_sharded
10781078
carg = inv_seen_args[arg]
1079-
condensed_op_sharding = Reactant.Sharding.XLA.CondensedOpSharding(
1080-
linear_parameter_shardings[i]
1081-
)
1079+
device_ids = mesh.sorted_device_ids
10821080
if Reactant.Sharding.is_sharded(carg)
10831081
# Currently disabling the error since we roundtrip from MHLO to generate
10841082
# the shardings
@@ -1090,29 +1088,30 @@ function codegen_flatten!(
10901088

10911089
push!(flatten_code, :($usbuf = $flatcode.data))
10921090
for j in 1:length(mesh)
1093-
sbuf = Symbol(:sbuf_, i, "_", j)
1091+
sbuf = Symbol(:sbuf_, i, "_", device_ids[j])
10941092
push!(flatten_names, sbuf)
10951093
push!(flatten_code, :($sbuf = XLA.synced_buffer(getindex($usbuf, $j))))
10961094
end
10971095
else
1096+
condensed_op_sharding = convert(
1097+
Reactant.Sharding.XLA.CondensedOpSharding, linear_parameter_shardings[i]
1098+
)
10981099
push!(flatten_code, :($usbuf = $flatcode))
10991100
device_to_array_slices = XLA.sharding_to_concrete_array_indices(
11001101
condensed_op_sharding, size(carg), mesh
11011102
)
1102-
device_ids = vec(mesh)
11031103
for j in 1:length(mesh)
1104-
buf = Symbol(:buf_, i, :_, j)
1105-
device_id = device_ids[j]
1104+
local_device_id = device_ids[j]
1105+
buf = Symbol(:buf_, i, :_, local_device_id)
11061106
slice = device_to_array_slices[j]
11071107
push!(
11081108
flatten_code,
11091109
:($buf = XLA.synced_buffer(only($usbuf[$(slice)...].data))),
11101110
)
1111-
device_ordinal = XLA.device_ordinal(client, device_id)
1112-
sbuf = Symbol(:sbuf_, i, :_, j)
1113-
device = XLA.ClientGetAddressableDevice(client, device_ordinal)
1111+
sbuf = Symbol(:sbuf_, i, :_, local_device_id)
1112+
device = XLA.get_addressable_device(client, local_device_id)
11141113
push!(flatten_names, sbuf)
1115-
push!(flatten_code, :($sbuf = XLA.CopyBufferToDevice($buf, $device)))
1114+
push!(flatten_code, :($sbuf = XLA.copy_buffer_to_device($buf, $device)))
11161115
end
11171116
end
11181117
else
@@ -1308,12 +1307,17 @@ Generate Julia code to call the XLA executable.
13081307
- `nresults`: The number of results to expect.
13091308
"""
13101309
function codegen_xla_call(
1311-
exec, device, flatten_names, donated_args_mask, nresults, is_sharded::Bool, mesh_ids
1310+
exec,
1311+
device,
1312+
flatten_names,
1313+
donated_args_mask,
1314+
nresults,
1315+
is_sharded::Bool,
1316+
ndevices::Int,
13121317
)
13131318
flatten_buffer_refs = map(n -> :($n.buffer), flatten_names)
13141319

1315-
base_symbol_name =
1316-
is_sharded ? Symbol(:result_buffer_m, length(mesh_ids), :_) : :result_buffer_
1320+
base_symbol_name = is_sharded ? Symbol(:result_buffer_m, ndevices, :_) : :result_buffer_
13171321
concretized_res_names = Symbol[Symbol(base_symbol_name, i) for i in 1:nresults]
13181322
concretized_res_code = map(enumerate(concretized_res_names)) do (i, varname)
13191323
:($varname = linearized_results[$i])
@@ -1325,21 +1329,20 @@ function codegen_xla_call(
13251329
if is_sharded
13261330
quote
13271331
GC.@preserve $(flatten_names...) begin
1328-
linearized_results = XLA.ExecutableCall(
1332+
linearized_results = XLA.execute(
13291333
$exec,
1330-
$(mesh_ids),
13311334
($(flatten_buffer_refs...),),
13321335
$(Tuple(donated_args_mask)),
13331336
Val($nresults),
1334-
Val($(length(mesh_ids))),
1337+
Val($ndevices),
13351338
)
13361339
end
13371340
$(concretized_res_code...)
13381341
end
13391342
else
13401343
quote
13411344
GC.@preserve $(flatten_names...) begin
1342-
linearized_results = XLA.ExecutableCallSharded(
1345+
linearized_results = XLA.execute_sharded(
13431346
$exec,
13441347
$(device),
13451348
($(flatten_buffer_refs...),),
@@ -1393,7 +1396,7 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
13931396
if !allequal(devices_list)
13941397
msg = "Expected all arguments to be on the same device, got:\n"
13951398
for (i, device) in enumerate(devices_list)
1396-
msg *= " Device $(i): $(XLA.DeviceToString(device))\n"
1399+
msg *= " Device $(i): $(string(device))\n"
13971400
end
13981401
throw(ArgumentError(msg))
13991402
end
@@ -1407,17 +1410,13 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
14071410
client = XLA.client(device)
14081411
else
14091412
client = XLA.default_backend[]
1410-
device = XLA.ClientGetAddressableDevice(
1411-
client, XLA.device_ordinal(client, XLA.default_device_idx[])
1412-
)
1413+
device = XLA.get_addressable_device(client, XLA.default_device_idx[])
14131414
end
14141415
else
14151416
if device !== nothing
14161417
@assert client == XLA.client(device) "client ($(client)) and XLA.client(device) ($(XLA.client(device))) must be the same"
14171418
else
1418-
device = XLA.ClientGetAddressableDevice(
1419-
client, XLA.device_ordinal(client, XLA.default_device_idx[])
1420-
)
1419+
device = XLA.get_addressable_device(client, XLA.default_device_idx[])
14211420
end
14221421
end
14231422

@@ -1431,9 +1430,9 @@ function compile_xla(f, args; client=nothing, kwargs...)
14311430
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
14321431

14331432
if client !== nothing
1434-
backend = XLA.ClientGetPlatformName(client)
1433+
backend = XLA.platform_name(client)
14351434
else
1436-
backend = XLA.ClientGetPlatformName(XLA.default_backend[])
1435+
backend = XLA.platform_name(XLA.default_backend[])
14371436
end
14381437
if backend == "CUDA"
14391438
backend = "GPU"
@@ -1461,17 +1460,21 @@ function compile_xla(f, args; client=nothing, kwargs...)
14611460
)
14621461

14631462
# compile MLIR module to XLA executable
1464-
device_ids = mlir_fn_res.is_sharded ? vec(mlir_fn_res.sharding_mesh) : Int64[]
1463+
local_device_ids = if mlir_fn_res.is_sharded
1464+
collect(Int64, mlir_fn_res.sharding_mesh.sorted_device_ids)
1465+
else
1466+
Int64[]
1467+
end
14651468
mlir_fn_res.is_sharded && (device = nothing)
14661469

1467-
exec = XLA.Compile(
1470+
exec = XLA.compile(
14681471
client,
14691472
device,
14701473
mod;
14711474
num_outputs=length(mlir_fn_res.linear_results),
14721475
num_parameters=length(mlir_fn_res.linear_args),
14731476
mlir_fn_res.is_sharded,
1474-
device_ids,
1477+
local_device_ids,
14751478
)
14761479

14771480
return mod, exec, mlir_fn_res, device, client
@@ -1514,7 +1517,7 @@ function compile(f, args; sync=false, kwargs...)
15141517
donated_args_mask,
15151518
length(linear_results),
15161519
mlir_fn_res.is_sharded,
1517-
mlir_fn_res.is_sharded ? vec(mlir_fn_res.sharding_mesh) : Int64[],
1520+
mlir_fn_res.is_sharded ? length(mlir_fn_res.sharding_mesh) : 1,
15181521
)
15191522

15201523
linear_result_shard_info = if mlir_fn_res.is_sharded

src/ConcreteRArray.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,13 @@ Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x
3737

3838
Base.size(x::ConcreteRArray) = x.shape
3939

40-
function Base.isempty(x::Union{ConcreteRArray,ConcreteRNumber})
41-
return any(==(XLA.AsyncEmptyBuffer), x.data)
42-
end
40+
Base.isempty(x::Union{ConcreteRArray,ConcreteRNumber}) = any(isempty, x.data)
4341
Base.isempty(x::WrappedConcreteRArray) = isempty(ancestor(x))
4442

4543
function Base.convert(::Type{<:Array}, X::ConcreteRArray{T,N}) where {T,N}
46-
data = Array{T,N}(undef, size(X)...)
47-
4844
if Sharding.is_sharded(X)
45+
data = Array{T,N}(undef, size(X)...)
46+
4947
completed = Set{eltype(X.sharding.device_to_array_slices)}()
5048
for idx in 1:length(X.data)
5149
slice = X.sharding.device_to_array_slices[idx]
@@ -56,14 +54,14 @@ function Base.convert(::Type{<:Array}, X::ConcreteRArray{T,N}) where {T,N}
5654
end
5755
data[slice...] = convert(Array{T}, X.data[idx])
5856
end
57+
58+
return data
5959
else
6060
buf = XLA.synced_buffer(only(X.data))
61-
GC.@preserve data buf begin
62-
XLA.BufferToHost(buf, pointer(data))
61+
GC.@preserve buf begin
62+
return convert(Array{T}, buf)
6363
end
6464
end
65-
66-
return data
6765
end
6866
function Base.convert(::Type{<:Array}, X::WrappedConcreteRArray)
6967
fn = compile(TracedUtils.materialize_traced_array, (X,))
@@ -82,7 +80,7 @@ function to_number(X::ConcreteRScalar{T}) where {T}
8280
XLA.await(X)
8381
buf = get_buffer(X; no_error_for_scalar=true)
8482
GC.@preserve data buf begin
85-
XLA.BufferToHost(buf, data)
83+
XLA.to_host(buf, data)
8684
end
8785
return data[]
8886
end
@@ -184,7 +182,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
184182
if buffer_on_cpu(a) && !Sharding.is_sharded(a)
185183
buf = get_buffer(a)
186184
GC.@preserve buf begin
187-
ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf))
185+
ptr = Base.unsafe_convert(Ptr{T}, XLA.unsafe_buffer_pointer(buf))
188186
start = 0
189187
for i in 1:N
190188
start *= size(a, N - i + 1)
@@ -211,7 +209,7 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N
211209
if buffer_on_cpu(a) && !Sharding.is_sharded(a)
212210
buf = get_buffer(a)
213211
GC.@preserve buf begin
214-
ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf))
212+
ptr = Base.unsafe_convert(Ptr{T}, XLA.unsafe_buffer_pointer(buf))
215213
start = 0
216214
for i in 1:N
217215
start *= size(a, N - i + 1)
@@ -303,9 +301,7 @@ end
303301
(f::CallMapReduce)(A) = Base.mapreduce(f.f, f.op, A; f.dims, f.init)
304302

305303
buffer_on_cpu(::Any) = true
306-
function buffer_on_cpu(x::ConcreteRArray)
307-
return all(XLA.BufferOnCPU Base.Fix2(getproperty, :buffer), x.data)
308-
end
304+
buffer_on_cpu(x::ConcreteRArray) = all(XLA.buffer_on_cpu, x.data)
309305

310306
function Ops.constant(x::ConcreteRArray; kwargs...)
311307
return Ops.constant(Base.convert(Array, x); kwargs...)
@@ -328,7 +324,7 @@ function Base.fill!(a::ConcreteRArray{T,N}, val) where {T,N}
328324
if buffer_on_cpu(a) && !Sharding.is_sharded(a)
329325
buf = get_buffer(a)
330326
GC.@preserve buf begin
331-
ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf))
327+
ptr = Base.unsafe_convert(Ptr{T}, XLA.unsafe_buffer_pointer(buf))
332328
for start in 1:length(a)
333329
unsafe_store!(ptr, val, start)
334330
end

src/Devices.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
"""
22
devices(backend::String)
3-
devices(backend::XLA.Client = XLA.default_backend[])
3+
devices(backend::XLA.AbstractClient = XLA.default_backend[])
44
55
Return a list of devices available on the backend.
66
"""
77
devices(backend::String) = devices(XLA.backends[backend])
88

9-
function devices(client::XLA.Client=XLA.default_backend[])
10-
ndevices = XLA.ClientNumDevices(client)
11-
return [XLA.ClientGetDevice(client, i - 1) for i in 1:ndevices]
9+
function devices(client::XLA.AbstractClient=XLA.default_backend[])
10+
ndevices = XLA.num_devices(client)
11+
return [XLA.get_device(client, i - 1) for i in 1:ndevices]
1212
end
1313

1414
"""
1515
addressable_devices(backend::String)
16-
addressable_devices(backend::XLA.Client = XLA.default_backend[])
16+
addressable_devices(backend::XLA.AbstractClient = XLA.default_backend[])
1717
1818
Return a list of addressable devices available on the backend.
1919
"""
2020
addressable_devices(backend::String) = addressable_devices(XLA.backends[backend])
2121

22-
function addressable_devices(client::XLA.Client=XLA.default_backend[])
23-
ndevices = XLA.ClientNumAddressableDevices(client)
24-
return [XLA.ClientGetAddressableDevice(client, i - 1) for i in 1:ndevices]
22+
function addressable_devices(client::XLA.AbstractClient=XLA.default_backend[])
23+
ndevices = XLA.num_addressable_devices(client)
24+
return [XLA.get_addressable_device(client, i - 1) for i in 1:ndevices]
2525
end
2626

2727
# https://github.com/jax-ml/jax/blob/152099ee0ef31119f16f4c2dac50d84fcb1575ef/jax/_src/hardware_utils.py#L19-L55

src/Precompile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959

6060
@setup_workload begin
6161
initialize_dialect()
62-
client = XLA.CPUClient(; checkcount=false)
62+
client = XLA.PJRT.CPUClient(; checkcount=false)
6363
@compile_workload begin
6464
@static if precompilation_supported()
6565
x = ConcreteRNumber(2.0; client)

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ function __init__()
221221
return initialize_dialect()
222222
end
223223

224-
function set_default_backend(backend::XLA.Client)
224+
function set_default_backend(backend::XLA.AbstractClient)
225225
return XLA.default_backend[] = backend
226226
end
227227

0 commit comments

Comments
 (0)