Skip to content

Commit 375120e

Browse files
committed
fix: replication
1 parent 9d0ea12 commit 375120e

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

src/Compiler.jl

+9-4
Original file line numberDiff line numberDiff line change
@@ -282,13 +282,18 @@ function generate_unresharded_ifrt_array(
282282
size_arr,
283283
) where {T}
284284
single_device_arrays = Reactant.XLA.IFRT.disassemble_into_single_device_arrays(
285-
Reactant.XLA.IFRT.replicate_array_to_all_devices(arr, output_sharding, mesh), true
285+
Reactant.XLA.IFRT.replicate_array_to_all_devices(
286+
arr, output_sharding, mesh, size_arr
287+
),
288+
true,
286289
)
287290
devs = Reactant.XLA.device.(single_device_arrays)
288291
idx = findfirst(isequal(target_device), devs)
289-
return ConcreteIFRTArray{T,N}(
290-
Reactant.XLA.IFRT.AsyncArray(single_device_arrays[idx], nothing), size(arr)
291-
)
292+
res_arr = Reactant.XLA.IFRT.AsyncArray(single_device_arrays[idx], nothing)
293+
res_arr_size = reverse(size(res_arr))
294+
@assert size_arr == res_arr_size "Expected size of array to be $(size_arr), but got \
295+
$(res_arr_size)"
296+
return ConcreteIFRTArray{T,N}(res_arr, size_arr)
292297
end
293298

294299
function create_result(tocopy::Array{T,N}, path, args...) where {T,N}

src/xla/IFRT/Array.jl

+8-7
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ function XLA.to_host(buffer::Array, data, reactant_sharding)
162162
client = XLA.client(buffer)
163163
all_devices = XLA.get_device.((client,), reactant_sharding.mesh.device_ids)
164164

165+
# TODO: Test if the below logic for replication works for distributed cases as well
165166
if any(!XLA.is_addressable, all_devices)
166167
@warn "Not all devices are addressable. Currently we only fill in the data for \
167168
addressable devices. Remaining slices of data in `data` are left \
@@ -202,23 +203,23 @@ function disassemble_into_single_device_arrays(array::Array, only_addressable_de
202203
return [Array(unsafe_load(arrays, i)) for i in 1:narrays[]]
203204
end
204205

205-
function replicate_array_to_all_devices(array::Array, sharding, mesh)
206+
function replicate_array_to_all_devices(array::Array, sharding, mesh, size_arr)
206207
is_fully_replicated(XLA.sharding(array)) && return array
207208

208209
hlo_sharding = Reactant.Sharding.HloSharding(
209210
convert(XLA.HloSharding, sharding),
210211
mesh,
211-
ntuple(Returns(1), ndims(array)),
212-
ntuple(Returns(-1), ndims(array)),
212+
ntuple(Returns(1), length(size_arr)),
213+
ntuple(Returns(-1), length(size_arr)),
213214
)
214215
shard_info = Reactant.Sharding.ShardInfo(
215-
hlo_sharding, Reactant.Sharding.sharding_to_array_slices(hlo_sharding, size(array))
216+
hlo_sharding, Reactant.Sharding.sharding_to_array_slices(hlo_sharding, size_arr)
216217
)
217218
sharding_constraint = Reactant.Sharding.NamedSharding(
218-
mesh, ntuple(Returns(nothing), ndims(array))
219+
mesh, ntuple(Returns(nothing), length(size_arr))
219220
)
220-
data = Reactant.ConcreteIFRTArray{eltype(array),ndims(array), typeof(shard_info)}(
221-
AsyncArray(array, nothing), size(array), shard_info
221+
data = Reactant.ConcreteIFRTArray{eltype(array),length(size_arr),typeof(shard_info)}(
222+
AsyncArray(array, nothing), size_arr, shard_info
222223
)
223224

224225
fn(x) = Reactant.Ops.sharding_constraint(x, sharding_constraint)

0 commit comments

Comments
 (0)