@@ -162,6 +162,7 @@ function XLA.to_host(buffer::Array, data, reactant_sharding)
162
162
client = XLA. client (buffer)
163
163
all_devices = XLA. get_device .((client,), reactant_sharding. mesh. device_ids)
164
164
165
+ # TODO : Test if the below logic for replication works for distributed cases as well
165
166
if any (! XLA. is_addressable, all_devices)
166
167
@warn " Not all devices are addressable. Currently we only fill in the data for \
167
168
addressable devices. Remaining slices of data in `data` are left \
@@ -202,23 +203,23 @@ function disassemble_into_single_device_arrays(array::Array, only_addressable_de
202
203
return [Array (unsafe_load (arrays, i)) for i in 1 : narrays[]]
203
204
end
204
205
205
- function replicate_array_to_all_devices (array:: Array , sharding, mesh)
206
+ function replicate_array_to_all_devices (array:: Array , sharding, mesh, size_arr )
206
207
is_fully_replicated (XLA. sharding (array)) && return array
207
208
208
209
hlo_sharding = Reactant. Sharding. HloSharding (
209
210
convert (XLA. HloSharding, sharding),
210
211
mesh,
211
- ntuple (Returns (1 ), ndims (array )),
212
- ntuple (Returns (- 1 ), ndims (array )),
212
+ ntuple (Returns (1 ), length (size_arr )),
213
+ ntuple (Returns (- 1 ), length (size_arr )),
213
214
)
214
215
shard_info = Reactant. Sharding. ShardInfo (
215
- hlo_sharding, Reactant. Sharding. sharding_to_array_slices (hlo_sharding, size (array) )
216
+ hlo_sharding, Reactant. Sharding. sharding_to_array_slices (hlo_sharding, size_arr )
216
217
)
217
218
sharding_constraint = Reactant. Sharding. NamedSharding (
218
- mesh, ntuple (Returns (nothing ), ndims (array ))
219
+ mesh, ntuple (Returns (nothing ), length (size_arr ))
219
220
)
220
- data = Reactant. ConcreteIFRTArray {eltype(array),ndims(array), typeof(shard_info)} (
221
- AsyncArray (array, nothing ), size (array) , shard_info
221
+ data = Reactant. ConcreteIFRTArray {eltype(array),length(size_arr), typeof(shard_info)} (
222
+ AsyncArray (array, nothing ), size_arr , shard_info
222
223
)
223
224
224
225
fn (x) = Reactant. Ops. sharding_constraint (x, sharding_constraint)
0 commit comments