Commit ea30e50
feat: Low-Level XLA.IFRT integration
refactor: rework how OpSharding works
feat: generate_device_list
feat: add placeholder code to simplify future sharding logic
fixup
fix: store results as HloSharding
docs: fix duplicate docs
feat: compile with logical device ids
fix: use correct global device ids
feat: use a global state to setup pjrt distributed runtime
fix: devices are not necessarily from 0 to N-1
fix: initialize clients on first use rather than on init
fix: make device selection consistent with clients
feat: add OMPI cluster detection
fix: correctly set kv_store
refactor: Distributed setup is not PJRT specific
refactor: OMPI detection doesn't need to be in an extension
feat: initial low-level IFRT API
fix: ifrt HloSharding
refactor: split up into IFRT/PJRT
feat: IFRT Client APIs
feat: IFRT Device API
fix: remove global_ordinals
feat: add devices list abstraction
feat: wrap memory and memory kinds
feat: ifrt::HloSharding now working
fix: use new ABI
chore: run formatter
fix: no finalizer
feat: initial draft of IFRT.Array interface (#774)
  * feat: initial draft of IFRT.Array interface
  * feat: Base.Array to ifrt::Array
  * feat: buffer to host
chore: run formatter
fix: bad rebase
feat: more proxy servers
feat: add ConcreteIFRTArray
feat: add ConcreteIFRTNumber
refactor: rename ConcreteRNumber to ConcretePJRTNumber
revert: concreteifrtarray implementation
chore: run formatter
feat: ifrt loaded executable
feat: construct IFRT clients with distributed options
refactor: remove BasicDevicesList
fix: use global device ids
feat: sharding annotations across nodes now working
fix: Array construction from SingleShards
feat: support to_host for distributed cases
feat: add Gloo/MPI collectives for distributed CPU client
feat: low level compile API
feat: low-level IFRT compile + execute working
1 parent 4aaae3c

35 files changed (+2159 -594 lines)

Project.toml (+4 -1)

@@ -31,6 +31,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
+MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
@@ -47,6 +48,7 @@ ReactantAbstractFFTsExt = "AbstractFFTs"
 ReactantArrayInterfaceExt = "ArrayInterface"
 ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
 ReactantKernelAbstractionsExt = "KernelAbstractions"
+ReactantMPIExt = "MPI"
 ReactantNNlibExt = "NNlib"
 ReactantOffsetArraysExt = "OffsetArrays"
 ReactantPythonCallExt = "PythonCall"
@@ -72,6 +74,7 @@ KernelAbstractions = "0.9.30"
 LLVM = "9.1"
 LLVMOpenMP_jll = "18.1.7"
 LinearAlgebra = "1.10"
+MPI = "0.20"
 NNlib = "0.9.26"
 OffsetArrays = "1"
 OrderedCollections = "1"
@@ -81,7 +84,7 @@ PythonCall = "0.9"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.5"
-Reactant_jll = "0.0.71"
+Reactant_jll = "0.0.72"
 Scratch = "1.2"
 Sockets = "1.10"
 SpecialFunctions = "2.4"
ext/ReactantMPIExt.jl (+36)

@@ -0,0 +1,36 @@
+module ReactantMPIExt
+
+using Reactant: Reactant, Distributed
+using MPI: MPI
+
+# https://github.com/jax-ml/jax/blob/b0117366686ab084d38ad2657d9a2ae3a581ca7e/jax/_src/clusters/mpi4py_cluster.py
+Distributed.is_env_present(::Distributed.MPIEnvDetector) = MPI.Initialized()
+
+function Distributed.get_coordinator_address(
+    ::Distributed.MPIEnvDetector, timeout_in_seconds::Integer
+)
+    if MPI.Comm_rank(MPI.COMM_WORLD) == 0
+        hostname = gethostname()
+        port_id = hash(hostname) % 2^12 + (65535 - 2^12 + 1)
+        hostname = "$(hostname):$(port_id)"
+    else
+        hostname = nothing
+    end
+
+    return MPI.bcast(hostname, MPI.COMM_WORLD; root=0)
+end
+
+function Distributed.get_process_count(::Distributed.MPIEnvDetector)
+    return Int(MPI.Comm_size(MPI.COMM_WORLD))
+end
+
+function Distributed.get_process_id(::Distributed.MPIEnvDetector)
+    return Int(MPI.Comm_rank(MPI.COMM_WORLD))
+end
+
+function Distributed.get_local_process_id(::Distributed.MPIEnvDetector)
+    new_comm = MPI.Comm_split_type(MPI.COMM_WORLD, MPI.COMM_TYPE_SHARED, 0)
+    return Int(MPI.Comm_rank(new_comm))
+end
+
+end
src/Compiler.jl (+38 -40)

@@ -74,10 +74,12 @@ function create_result(
     return Expr(:new, T, elems...)
 end
 
-function __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
-    device_to_array_slices, partition_spec = path_to_shard_info[path]
+function __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh, N::Integer)
+    device_to_array_slices, hlo_sharding = path_to_shard_info[path]
     delete!(path_to_shard_info, path)
-    sharding = Reactant.Sharding.NamedSharding(sharding_mesh, partition_spec)
+    sharding = Reactant.Sharding.HloSharding(
+        hlo_sharding, sharding_mesh, ntuple(Returns(true), N), ntuple(Returns(-1), N)
+    )
     return Reactant.Sharding.ShardInfo(sharding, device_to_array_slices)
 end
 
@@ -88,7 +90,9 @@ function create_result(
         restore = result_stores[path]
         delete!(result_stores, path)
         if path_to_shard_info !== nothing # restore sharding
-            sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+            sharding = __reconstruct_shardinfo(
+                path, path_to_shard_info, sharding_mesh, ndims(tocopy)
+            )
             return :(ConcreteRNumber{$T,length($(restore)),$(typeof(sharding))}(
                 ($(restore)...,), $sharding
             ))
@@ -98,7 +102,9 @@ function create_result(
     end
 
     if path_to_shard_info !== nothing # restore sharding
-        sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+        sharding = __reconstruct_shardinfo(
+            path, path_to_shard_info, sharding_mesh, ndims(tocopy)
+        )
         return :(ConcreteRNumber{$T,length($(tocopy.data)),$(typeof(sharding))}(
             ($(tocopy.data...,)), $sharding
         ))
@@ -114,7 +120,9 @@ function create_result(
         restore = result_stores[path]
         delete!(result_stores, path)
         if path_to_shard_info !== nothing # restore sharding
-            sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+            sharding = __reconstruct_shardinfo(
+                path, path_to_shard_info, sharding_mesh, ndims(tocopy)
+            )
             return :(ConcreteRArray{$T,$N,length($(restore)),$(typeof(sharding))}(
                 ($(restore)...,), $(tocopy.shape), $sharding
             ))
@@ -124,7 +132,9 @@ function create_result(
     end
 
     if path_to_shard_info !== nothing # restore sharding
-        sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+        sharding = __reconstruct_shardinfo(
+            path, path_to_shard_info, sharding_mesh, ndims(tocopy)
+        )
         return :(ConcreteRArray{$T,$N,length($(tocopy.data)),$(typeof(sharding))}(
             ($(tocopy.data)...,), $(tocopy.shape), $sharding
         ))
@@ -477,11 +487,8 @@ function compile_mlir(f, args; client=nothing, kwargs...)
     context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
     @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
 
-    if client !== nothing
-        backend = XLA.platform_name(client)
-    else
-        backend = XLA.platform_name(XLA.default_backend[])
-    end
+    backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
+
     if backend == "CUDA"
         backend = "GPU"
     elseif backend == "CPU"
@@ -493,13 +500,6 @@ function compile_mlir(f, args; client=nothing, kwargs...)
 
     mlir_fn_res = compile_mlir!(mod, f, args; backend, kwargs...)
 
-    client, _ = __resolve_device_and_client(
-        client,
-        mlir_fn_res.seen_args,
-        mlir_fn_res.linear_args,
-        mlir_fn_res.is_sharded,
-    )
-
     # Attach a name, and partitioning attributes to the module
     __add_mhlo_attributes_and_name!(
         mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas
@@ -1079,7 +1079,6 @@ function codegen_flatten!(
 
     if is_sharded
         carg = inv_seen_args[arg]
-        device_ids = mesh.sorted_device_ids
         if Reactant.Sharding.is_sharded(carg)
             # Currently disabling the error since we roundtrip from MHLO to generate
             # the shardings
@@ -1091,7 +1090,7 @@ function codegen_flatten!(
 
             push!(flatten_code, :($usbuf = $flatcode.data))
             for j in 1:length(mesh)
-                sbuf = Symbol(:sbuf_, i, "_", device_ids[j])
+                sbuf = Symbol(:sbuf_, i, "_", mesh.device_ids[j])
                 push!(flatten_names, sbuf)
                 push!(flatten_code, :($sbuf = XLA.synced_buffer(getindex($usbuf, $j))))
             end
@@ -1101,18 +1100,18 @@ function codegen_flatten!(
             )
             push!(flatten_code, :($usbuf = $flatcode))
             device_to_array_slices = XLA.sharding_to_concrete_array_indices(
-                condensed_op_sharding, size(carg), mesh
+                condensed_op_sharding, size(carg), mesh.device_ids
             )
             for j in 1:length(mesh)
-                local_device_id = device_ids[j]
-                buf = Symbol(:buf_, i, :_, local_device_id)
+                device_id = mesh.device_ids[j]
+                buf = Symbol(:buf_, i, :_, device_id)
                 slice = device_to_array_slices[j]
                 push!(
                     flatten_code,
                     :($buf = XLA.synced_buffer(only($usbuf[$(slice)...].data))),
                 )
-                sbuf = Symbol(:sbuf_, i, :_, local_device_id)
-                device = XLA.get_addressable_device(client, local_device_id)
+                sbuf = Symbol(:s, buf)
+                device = XLA.get_device(client, device_id)
                 push!(flatten_names, sbuf)
                 push!(flatten_code, :($sbuf = XLA.copy_buffer_to_device($buf, $device)))
             end
@@ -1386,7 +1385,7 @@ end
 
 function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
     if is_sharded
-        client === nothing && (client = XLA.default_backend[])
+        client === nothing && (client = XLA.default_backend())
         return client, nothing
     end
 
@@ -1412,14 +1411,14 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
         if device !== nothing
             client = XLA.client(device)
         else
-            client = XLA.default_backend[]
-            device = XLA.get_addressable_device(client, XLA.default_device_idx[])
+            client = XLA.default_backend()
+            device = XLA.default_device(client)
         end
     else
         if device !== nothing
            @assert client == XLA.client(device) "client ($(client)) and XLA.client(device) ($(XLA.client(device))) must be the same"
         else
-            device = XLA.get_addressable_device(client, XLA.default_device_idx[])
+            device = XLA.default_device(client)
         end
     end
 
@@ -1432,11 +1431,8 @@ function compile_xla(f, args; client=nothing, kwargs...)
     context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
     @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
 
-    if client !== nothing
-        backend = XLA.platform_name(client)
-    else
-        backend = XLA.platform_name(XLA.default_backend[])
-    end
+    backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
+
     if backend == "CUDA"
         backend = "GPU"
     elseif backend == "CPU"
@@ -1463,8 +1459,8 @@ function compile_xla(f, args; client=nothing, kwargs...)
     )
 
     # compile MLIR module to XLA executable
-    local_device_ids = if mlir_fn_res.is_sharded
-        collect(Int64, mlir_fn_res.sharding_mesh.sorted_device_ids)
+    global_device_ids = if mlir_fn_res.is_sharded
+        collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
     else
         Int64[]
     end
@@ -1477,7 +1473,9 @@ function compile_xla(f, args; client=nothing, kwargs...)
         num_outputs=length(mlir_fn_res.linear_results),
        num_parameters=length(mlir_fn_res.linear_args),
         mlir_fn_res.is_sharded,
-        local_device_ids,
+        global_device_ids,
+        mlir_fn_res.num_replicas,
+        mlir_fn_res.num_partitions,
     )
 
     return mod, exec, mlir_fn_res, device, client
@@ -1525,10 +1523,10 @@ function compile(f, args; sync=false, kwargs...)
 
     linear_result_shard_info = if mlir_fn_res.is_sharded
         output_shardings = XLA.get_output_shardings(exec)
-        XLA.compute_array_indices_and_partition_spec.(
+        XLA.compute_array_indices_and_hlo_sharding.(
            output_shardings,
            size.(mlir_fn_res.linear_results),
-            (mlir_fn_res.sharding_mesh,),
+            (mlir_fn_res.sharding_mesh.logical_device_ids,),
        )
    else
        ntuple(Returns(nothing), length(linear_results))
src/Devices.jl (+9 -13)

@@ -1,27 +1,23 @@
 """
     devices(backend::String)
-    devices(backend::XLA.AbstractClient = XLA.default_backend[])
+    devices(backend::XLA.AbstractClient = XLA.default_backend())
 
-Return a list of devices available on the backend.
+Return a list of devices available for the given client.
 """
-devices(backend::String) = devices(XLA.backends[backend])
+devices(backend::String) = devices(XLA.client(backend))
 
-function devices(client::XLA.AbstractClient=XLA.default_backend[])
-    ndevices = XLA.num_devices(client)
-    return [XLA.get_device(client, i - 1) for i in 1:ndevices]
-end
+devices(client::XLA.AbstractClient=XLA.default_backend()) = XLA.devices(client)
 
 """
     addressable_devices(backend::String)
-    addressable_devices(backend::XLA.AbstractClient = XLA.default_backend[])
+    addressable_devices(backend::XLA.AbstractClient = XLA.default_backend())
 
-Return a list of addressable devices available on the backend.
+Return a list of addressable devices available for the given client.
 """
-addressable_devices(backend::String) = addressable_devices(XLA.backends[backend])
+addressable_devices(backend::String) = addressable_devices(XLA.client(backend))
 
-function addressable_devices(client::XLA.AbstractClient=XLA.default_backend[])
-    ndevices = XLA.num_addressable_devices(client)
-    return [XLA.get_addressable_device(client, i - 1) for i in 1:ndevices]
+function addressable_devices(client::XLA.AbstractClient=XLA.default_backend())
+    return XLA.addressable_devices(client)
 end
 
 # https://github.com/jax-ml/jax/blob/152099ee0ef31119f16f4c2dac50d84fcb1575ef/jax/_src/hardware_utils.py#L19-L55