Skip to content

Commit 5b449d1

Browse files
committed
fix: Array construction from SingleShards
1 parent 7e45011 commit 5b449d1

File tree

8 files changed

+201
-35
lines changed

8 files changed

+201
-35
lines changed

deps/ReactantExtra/API.cpp

+39-3
Original file line numberDiff line numberDiff line change
@@ -1273,9 +1273,9 @@ extern "C" HeldIfrtArray *ifrt_client_make_single_shard_array_from_host_buffer(
12731273

12741274
// all arrays are assumed to have same DType
12751275
extern "C" HeldIfrtArray *ifrt_client_assemble_array_from_single_shards(
1276-
ifrt::Client *client, int ndims, const int64_t *c_shape,
1277-
HeldValue<std::shared_ptr<const ifrt::Sharding>> *sharding, int narrays,
1278-
HeldIfrtArray **c_arrays, int c_semantics) {
1276+
ifrt::Client *client, int32_t ndims, const int64_t *c_shape,
1277+
HeldValue<std::shared_ptr<const ifrt::Sharding>> *sharding, int32_t narrays,
1278+
HeldIfrtArray **c_arrays, int32_t c_semantics) {
12791279
auto shape = ifrt::Shape(absl::Span<const int64_t>(c_shape, ndims));
12801280
std::vector<tsl::RCReference<ifrt::Array>> arrays;
12811281
for (int i = 0; i < narrays; i++) {
@@ -1692,6 +1692,10 @@ extern "C" ifrt::Client *ifrt_DeviceToClient(ifrt::Device *device) {
16921692
return device->client();
16931693
}
16941694

1695+
// C entry point reporting whether `device` is addressable; it forwards to
// ifrt::Device::IsAddressable() and returns that result unchanged.
extern "C" bool ifrt_DeviceIsAddressable(ifrt::Device *device) {
  const bool addressable = device->IsAddressable();
  return addressable;
}
1698+
16951699
tsl::RCReference<ifrt::DeviceList> ifrt_CreateDeviceListFromDevices(
16961700
ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices) {
16971701
absl::Span<ifrt::Device *const> devices(device_list, num_devices);
@@ -1916,6 +1920,14 @@ ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) {
19161920
return cstr_from_string(hlo_sharding->DebugString());
19171921
}
19181922

1923+
// Returns a newly heap-allocated copy of the ifrt::HloSharding held by
// `sharding`. Reports "Expected a HloSharding" through ReactantThrowError when
// the held sharding is null or is not an ifrt::HloSharding.
// The caller owns the returned pointer.
extern "C" ifrt::HloSharding *ifrt_sharding_to_ifrt_hlo_sharding(
    HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
  const ifrt::Sharding *val = sharding->obj().get();
  // A single dyn_cast both tests and converts; guard the null case explicitly
  // because dyn_cast must not be handed a null pointer.
  const ifrt::HloSharding *hlo_sharding =
      val ? llvm::dyn_cast<const ifrt::HloSharding>(val) : nullptr;
  if (hlo_sharding == nullptr) {
    ReactantThrowError("Expected a HloSharding");
    // NOTE(review): ReactantThrowError presumably does not return; this keeps
    // us from dereferencing a null pointer if it ever does.
    return nullptr;
  }
  return new ifrt::HloSharding(*hlo_sharding);
}
1930+
19191931
extern "C" void
19201932
free_ifrt_sharding(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
19211933
delete sharding;
@@ -1936,11 +1948,35 @@ ifrt_sharding_from_hlo_sharding(ifrt::Client *client,
19361948
memory_kind, xla_hlo_sharding));
19371949
}
19381950

1951+
// Returns true when the sharding held by `sharding` is, per llvm::isa, an
// ifrt::SingleDeviceSharding.
extern "C" bool ifrt_sharding_is_single_device_sharding(
    HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
  auto *raw_sharding = sharding->obj().get();
  const bool is_single_device =
      llvm::isa<const ifrt::SingleDeviceSharding>(raw_sharding);
  return is_single_device;
}
1955+
1956+
// Returns the result of IsFullyReplicated() on the sharding held by
// `sharding`.
extern "C" bool ifrt_sharding_is_fully_replicated(
    HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
  const bool fully_replicated = sharding->obj()->IsFullyReplicated();
  return fully_replicated;
}
1960+
19391961
extern "C" const char *
19401962
ifrt_sharding_to_string(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
19411963
return cstr_from_string(sharding->obj()->DebugString());
19421964
}
19431965

1966+
// Returns the number of devices in the held sharding's device list, narrowed
// to int32_t for the C interface.
extern "C" int32_t ifrt_sharding_devices_size(
    HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
  const auto num_devices = sharding->obj()->devices()->size();
  return static_cast<int32_t>(num_devices);
}
1970+
1971+
// Copies the device pointers of the held sharding's device list, in order,
// into the caller-provided `devices` buffer.
// No bounds are checked here: the caller must provide room for at least
// ifrt_sharding_devices_size(sharding) entries.
extern "C" void ifrt_sharding_to_device_list(
    HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding,
    ifrt::Device **devices) {
  auto device_list = sharding->obj()->devices()->devices();
  // Use an unsigned index so the comparison with size() is not a
  // signed/unsigned mix.
  for (size_t i = 0; i < device_list.size(); i++) {
    devices[i] = device_list[i];
  }
}
1979+
19441980
#pragma endregion
19451981

19461982
typedef ifrt::Future<> IfRtFutureType;

src/Compiler.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1107,7 +1107,7 @@ function codegen_flatten!(
11071107
)
11081108
push!(flatten_code, :($usbuf = $flatcode))
11091109
device_to_array_slices = XLA.sharding_to_concrete_array_indices(
1110-
condensed_op_sharding, size(carg), mesh
1110+
condensed_op_sharding, size(carg), mesh.device_ids
11111111
)
11121112
for j in 1:length(mesh)
11131113
device_id = mesh.device_ids[j]

src/Sharding.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ function (sharding::HloSharding)(
280280
condensed_op_sharding = convert(XLA.CondensedOpSharding, sharding.hlo_sharding)
281281

282282
device_to_array_slices = XLA.sharding_to_concrete_array_indices(
283-
condensed_op_sharding, size(x), sharding.mesh
283+
condensed_op_sharding, size(x), sharding.mesh.logical_device_ids
284284
)
285285

286286
data = ntuple(length(sharding.mesh)) do i

src/xla/Device.jl

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ function get_local_device_id end
1010
function device_kind end
1111
function default_memory end
1212
function memories end
13+
function is_addressable end
1314

1415
"""
1516
device_ordinal(device::Device)

src/xla/IFRT/Array.jl

+61-17
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,91 @@ mutable struct Array <: XLA.AbstractBuffer
22
buffer::Ptr{Cvoid}
33

44
function Array(buffer::Ptr{Cvoid})
5-
return finalizer(free_ifrt_array, new(buffer))
5+
# return finalizer(free_ifrt_array, new(buffer))
6+
return new(buffer)
67
end
78
end
89

9-
function Array(client::Client, array::Base.Array{T,N}, device::Device) where {T,N}
10+
"""
    Array(client::Client, array::Base.Array{T,N}, device::Device[, memory_kind])

Create an IFRT `Array` on `device` from the host `array` through
`ifrt_client_make_single_shard_array_from_host_buffer`.

The dimensions are handed to the C API in reversed order relative to `size(array)`, and
the copy-semantics argument is `0` (`kAlwaysCopy`). `memory_kind` defaults to the string
form of the device's default memory.
"""
function Array(
    client::Client,
    array::Base.Array{T,N},
    device::Device,
    memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))),
) where {T,N}
    reversed_dims = collect(Int64, reverse(size(array)))
    raw_buffer = GC.@preserve array reversed_dims begin
        @ccall MLIR.API.mlir_c.ifrt_client_make_single_shard_array_from_host_buffer(
            client.client::Ptr{Cvoid},
            array::Ptr{T},
            XLA.primitive_type(T)::UInt64,
            N::Csize_t,
            reversed_dims::Ptr{Int64},
            0::Cint, # kAlwaysCopy
            device.device::Ptr{Cvoid},
            string(memory_kind)::Cstring,
        )::Ptr{Cvoid}
    end
    return Array(raw_buffer)
end
2531

26-
function Array(client::Client, array::Base.Array{T,N}, sharding::HloSharding) where {T,N}
27-
return Array(client, array, convert(Sharding, sharding))
28-
end
29-
30-
function Array(client::Client, array::Base.Array{T,N}, sharding::Sharding) where {T,N}
32+
"""
    Array(client::Client, array::Base.Array{T,N}, sharding::Sharding, logical_device_ids)

Create an IFRT `Array` from the host `array` laid out according to `sharding`.

If `sharding` is a single-device sharding or is fully replicated, the whole host buffer
is passed to `ifrt_client_make_array_from_host_buffer`. Otherwise one single-device array
is built for each addressable device of the sharding, holding that device's slice of
`array` (slices come from `XLA.sharding_to_concrete_array_indices` with
`logical_device_ids`), and the pieces are combined by
`ifrt_client_assemble_array_from_single_shards`.
"""
function Array(
    client::Client, array::Base.Array{T,N}, sharding::Sharding, logical_device_ids
) where {T,N}
    # Dimensions are passed in reversed order relative to `size(array)`; the same vector
    # serves both C calls below.
    sizear = collect(Int64, reverse(size(array)))

    if is_single_device_sharding(sharding) || is_fully_replicated(sharding)
        # NOTE(review): the element type is passed as `Cint` here but as `UInt64` in the
        # single-device constructor — confirm against the C signatures.
        buffer = GC.@preserve array sizear begin
            @ccall MLIR.API.mlir_c.ifrt_client_make_array_from_host_buffer(
                client.client::Ptr{Cvoid},
                array::Ptr{T},
                XLA.primitive_type(T)::Cint,
                N::Csize_t,
                sizear::Ptr{Int64},
                sharding.ptr::Ptr{Cvoid},
                0::Cint, # kAlwaysCopy
            )::Ptr{Cvoid}
        end
        return Array(buffer)
    end

    all_devices = XLA.devices(sharding)
    array_slices = XLA.sharding_to_concrete_array_indices(
        convert(XLA.HloSharding, sharding), size(array), logical_device_ids
    )
    # Only the raw pointers of the per-shard arrays are kept.
    # NOTE(review): with `kDonateInput` the assembled array presumably takes over the
    # shards — confirm, in particular if a finalizer is ever attached to `Array`.
    arrays_list = [
        Array(client, array[slice...], device).buffer for
        (device, slice) in zip(all_devices, array_slices) if XLA.is_addressable(device)
    ]

    buffer = GC.@preserve client arrays_list sizear sharding begin
        @ccall MLIR.API.mlir_c.ifrt_client_assemble_array_from_single_shards(
            client.client::Ptr{Cvoid},
            Int32(length(sizear))::Int32,
            sizear::Ptr{Int64},
            sharding.ptr::Ptr{Cvoid},
            Int32(length(arrays_list))::Int32,
            arrays_list::Ptr{Ptr{Cvoid}},
            2::Cint, # kDonateInput
        )::Ptr{Cvoid}
    end

    return Array(buffer)
end
4576

77+
"""
    Array(client::Client, array::Base.Array{T,N}, sharding)

Create an IFRT `Array` from `array` using a `Reactant.Sharding.AbstractSharding`.

`sharding` is converted to a `Reactant.Sharding.HloSharding` when it is not one already.
The devices for `mesh.device_ids` are looked up on `client`, an IFRT `Sharding` is built
from them and the HLO sharding, and construction is delegated to the
`Array(client, array, ifrt_sharding, mesh.logical_device_ids)` method.

Throws an `ArgumentError` if `sharding` is not a `Reactant.Sharding.AbstractSharding`.
"""
function Array(client::Client, array::Base.Array{T,N}, sharding) where {T,N}
    # Explicit check rather than `@assert`, which may be disabled.
    if !(sharding isa Reactant.Sharding.AbstractSharding)
        throw(
            ArgumentError(
                "Expected a Reactant.Sharding.AbstractSharding, got $(typeof(sharding))"
            ),
        )
    end
    if !(sharding isa Reactant.Sharding.HloSharding)
        sharding = convert(Reactant.Sharding.HloSharding, sharding)
    end

    (; hlo_sharding, mesh) = sharding
    devices = XLA.get_device.((client,), mesh.device_ids)
    ifrt_sharding = Sharding([devices...], hlo_sharding)

    return Array(client, array, ifrt_sharding, mesh.logical_device_ids)
end
89+
4690
@inline function free_ifrt_array(buffer::Array)
4791
sbuffer = buffer.buffer
4892
if sbuffer != C_NULL

src/xla/IFRT/Device.jl

+8
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,11 @@ function XLA.client(device_list::AbstractVector{Device})
6363
@assert allequal(clients) "All devices must have the same client"
6464
return first(clients)
6565
end
66+
67+
"""
    XLA.is_addressable(device::Device)

Return the `Bool` reported by `ifrt_DeviceIsAddressable` for `device`.
"""
function XLA.is_addressable(device::Device)
    addressable = GC.@preserve device begin
        @ccall MLIR.API.mlir_c.ifrt_DeviceIsAddressable(device.device::Ptr{Cvoid})::Bool
    end
    return addressable
end

src/xla/IFRT/Sharding.jl

+57-1
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,21 @@ function free_hlo_sharding(hlo_sharding::HloSharding)
1313
@ccall MLIR.API.mlir_c.free_ifrt_hlo_sharding(hlo_sharding.ptr::Ptr{Cvoid})::Cvoid
1414
end
1515

16+
"""
    convert(::Type{XLA.HloSharding}, sharding::HloSharding)

Wrap the pointer returned by `ifrt_hlo_sharding_to_xla_hlo_sharding` for the IFRT
`sharding` in an `XLA.HloSharding`.
"""
function Base.convert(::Type{XLA.HloSharding}, sharding::HloSharding)
    xla_ptr = GC.@preserve sharding begin
        @ccall MLIR.API.mlir_c.ifrt_hlo_sharding_to_xla_hlo_sharding(
            sharding.ptr::Ptr{Cvoid}
        )::Ptr{Cvoid}
    end
    return XLA.HloSharding(xla_ptr)
end
25+
1626
"""
    HloSharding(device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding)

Build an `HloSharding` over `device_list`, taking the memory kind from the default memory
of the addressable devices in `device_list`.
"""
function HloSharding(
    device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding
)
    # Only addressable devices are consulted for the default memory.
    local_devices = filter(XLA.is_addressable, device_list)
    memory_kind = convert(MemoryKind, XLA.default_memory(local_devices))
    return HloSharding(device_list, xla_hlo_sharding, memory_kind)
end
2233

@@ -89,6 +100,21 @@ function free_sharding(sharding::Sharding)
89100
@ccall MLIR.API.mlir_c.free_ifrt_sharding(sharding.ptr::Ptr{Cvoid})::Cvoid
90101
end
91102

103+
"""
    XLA.devices(sharding::Sharding)

Return a vector of `Device`s wrapping the device pointers of `sharding`.

The device count is queried with `ifrt_sharding_devices_size`, then a buffer of that
length is filled by `ifrt_sharding_to_device_list`.
"""
function XLA.devices(sharding::Sharding)
    ndevices = GC.@preserve sharding begin
        @ccall MLIR.API.mlir_c.ifrt_sharding_devices_size(
            sharding.ptr::Ptr{Cvoid}
        )::Int32
    end
    # A plain vector keeps the buffer type independent of the runtime device count.
    device_ptrs = Vector{Ptr{Cvoid}}(undef, ndevices)
    GC.@preserve sharding device_ptrs begin
        @ccall MLIR.API.mlir_c.ifrt_sharding_to_device_list(
            sharding.ptr::Ptr{Cvoid}, device_ptrs::Ptr{Ptr{Cvoid}}
        )::Cvoid
    end
    return [Device(ptr) for ptr in device_ptrs]
end
117+
92118
function Base.convert(::Type{Sharding}, hlo_sharding::HloSharding)
93119
GC.@preserve hlo_sharding begin
94120
return Sharding(
@@ -99,6 +125,20 @@ function Base.convert(::Type{Sharding}, hlo_sharding::HloSharding)
99125
end
100126
end
101127

128+
"""
    convert(::Type{HloSharding}, sharding::Sharding)

Wrap the pointer returned by `ifrt_sharding_to_ifrt_hlo_sharding` for `sharding` in an
IFRT `HloSharding`.
"""
function Base.convert(::Type{HloSharding}, sharding::Sharding)
    hlo_ptr = GC.@preserve sharding begin
        @ccall MLIR.API.mlir_c.ifrt_sharding_to_ifrt_hlo_sharding(
            sharding.ptr::Ptr{Cvoid}
        )::Ptr{Cvoid}
    end
    return HloSharding(hlo_ptr)
end
137+
138+
"""
    convert(::Type{XLA.HloSharding}, sharding::Sharding)

Convert `sharding` to an `XLA.HloSharding` in two steps: first to an IFRT `HloSharding`,
then from that to `XLA.HloSharding`.
"""
function Base.convert(::Type{XLA.HloSharding}, sharding::Sharding)
    ifrt_hlo_sharding = convert(HloSharding, sharding)
    return convert(XLA.HloSharding, ifrt_hlo_sharding)
end
141+
102142
function Base.string(sharding::Sharding)
103143
GC.@preserve sharding begin
104144
str = @ccall MLIR.API.mlir_c.ifrt_sharding_to_string(
@@ -108,6 +148,22 @@ function Base.string(sharding::Sharding)
108148
return XLA.unsafe_string_and_free(str)
109149
end
110150

151+
"""
    is_fully_replicated(sharding::Sharding)

Return the `Bool` reported by `ifrt_sharding_is_fully_replicated` for `sharding`.
"""
function is_fully_replicated(sharding::Sharding)
    replicated = GC.@preserve sharding begin
        @ccall MLIR.API.mlir_c.ifrt_sharding_is_fully_replicated(
            sharding.ptr::Ptr{Cvoid}
        )::Bool
    end
    return replicated
end
158+
159+
"""
    is_single_device_sharding(sharding::Sharding)

Return the `Bool` reported by `ifrt_sharding_is_single_device_sharding` for `sharding`.
"""
function is_single_device_sharding(sharding::Sharding)
    single_device = GC.@preserve sharding begin
        @ccall MLIR.API.mlir_c.ifrt_sharding_is_single_device_sharding(
            sharding.ptr::Ptr{Cvoid}
        )::Bool
    end
    return single_device
end
166+
111167
function Base.show(io::IO, ::MIME"text/plain", sharding::Sharding)
112168
print(io, "XLA.IFRT.Sharding(\"", string(sharding), "\")")
113169
return nothing

0 commit comments

Comments
 (0)