Skip to content

Commit 831ec8d

Browse files
committed
fix: non-iota mesh execution
1 parent 953b869 commit 831ec8d

File tree

4 files changed

+37
-48
lines changed

4 files changed

+37
-48
lines changed

deps/ReactantExtra/API.cpp

+24-35
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ xla::CompileOptions GenerateCompileOptions(int64_t device_id, bool is_sharded,
786786
for (int64_t i = 0; i < num_mesh_ids; ++i) {
787787
int64_t mesh_id = mesh_ids[i];
788788
assert(mesh_id >= 0);
789-
device_assignment(0, mesh_id) = i;
789+
device_assignment(0, i) = mesh_id;
790790
}
791791
options.executable_build_options.set_device_assignment(device_assignment);
792792

@@ -945,31 +945,20 @@ void PrintPjRtBuffer(PjRtBuffer *buffer) {
945945
}
946946

947947
extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len,
948-
PjRtBuffer **op_args, const int64_t *mesh_ids,
949-
int64_t num_mesh_ids, uint8_t *is_arg_donatable,
950-
int num_results, PjRtBuffer **op_results,
951-
uint8_t *futures, FutureType **future_results) {
952-
// Ensure argument_handles is structured as num_mesh_ids x num_args
953-
std::vector<std::vector<PjRtBuffer *>> argument_handles(num_mesh_ids);
954-
int num_args = op_args_len / num_mesh_ids;
948+
PjRtBuffer **op_args, int64_t num_devices,
949+
uint8_t *is_arg_donatable, int num_results,
950+
PjRtBuffer **op_results, uint8_t *futures,
951+
FutureType **future_results) {
952+
// Ensure argument_handles is structured as num_devices x num_args
953+
std::vector<std::vector<PjRtBuffer *>> argument_handles(num_devices);
954+
int num_args = op_args_len / num_devices;
955955

956956
// Distribute arguments across devices
957-
for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
958-
int64_t mesh_id = mesh_ids[device_idx];
959-
960-
// Validate mesh_id
961-
if (mesh_id < 0 || mesh_id >= num_mesh_ids) {
962-
ReactantThrowError(("Invalid mesh_id " + std::to_string(mesh_id) +
963-
" at device_idx " + std::to_string(device_idx))
964-
.c_str());
965-
}
966-
957+
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
967958
argument_handles[device_idx].reserve(num_args);
968959
for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
969-
// Assuming op_args is a flat array of size num_devices * num_args
970-
// where arguments for each device are contiguous
971960
argument_handles[device_idx].push_back(
972-
op_args[mesh_id * num_args + arg_idx]);
961+
op_args[device_idx * num_args + arg_idx]);
973962
}
974963
}
975964

@@ -989,40 +978,40 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len,
989978
argument_handles),
990979
options, returned_futures));
991980

992-
if (results.size() != num_mesh_ids) {
981+
if (results.size() != num_devices) {
993982
ReactantThrowError((" results.size()=" + std::to_string(results.size()) +
994-
" num_mesh_ids=" + std::to_string(num_mesh_ids) + "\n")
983+
" num_devices=" + std::to_string(num_devices) + "\n")
995984
.c_str());
996985
}
997986

998-
for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
999-
int64_t mesh_id = mesh_ids[device_idx];
1000-
if (results[mesh_id].size() != num_results) {
1001-
ReactantThrowError((" results[" + std::to_string(mesh_id) + "].size()=" +
1002-
std::to_string(results[mesh_id].size()) +
1003-
" num_results=" + std::to_string(num_results) + "\n")
1004-
.c_str());
987+
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
988+
// Remove mesh_id lookup since we're using device_idx ordering
989+
if (results[device_idx].size() != num_results) {
990+
ReactantThrowError(
991+
(" results[" + std::to_string(device_idx) +
992+
"].size()=" + std::to_string(results[device_idx].size()) +
993+
" num_results=" + std::to_string(num_results) + "\n")
994+
.c_str());
1005995
}
1006996
}
1007997

1008998
// Handle returned futures
1009999
auto future_val = returned_futures.has_value();
10101000
*futures = future_val;
10111001
if (future_val) {
1012-
if (returned_futures->size() != num_mesh_ids) {
1002+
if (returned_futures->size() != num_devices) {
10131003
ReactantThrowError((" returned_futures->size()=" +
10141004
std::to_string(returned_futures->size()) +
1015-
" num_mesh_ids=" + std::to_string(num_mesh_ids) +
1005+
" num_devices=" + std::to_string(num_devices) +
10161006
"\n")
10171007
.c_str());
10181008
}
10191009
}
10201010

10211011
// Copy results into the output buffers
1022-
for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) {
1023-
int64_t mesh_id = mesh_ids[device_idx];
1012+
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
10241013
for (int result_idx = 0; result_idx < num_results; ++result_idx) {
1025-
int flat_index = mesh_id * num_results + result_idx;
1014+
int flat_index = device_idx * num_results + result_idx;
10261015
op_results[flat_index] = results[device_idx][result_idx].release();
10271016
if (future_val) {
10281017
future_results[flat_index] =

src/Compiler.jl

+10-6
Original file line numberDiff line numberDiff line change
@@ -1307,12 +1307,17 @@ Generate Julia code to call the XLA executable.
13071307
- `nresults`: The number of results to expect.
13081308
"""
13091309
function codegen_xla_call(
1310-
exec, device, flatten_names, donated_args_mask, nresults, is_sharded::Bool, mesh_ids
1310+
exec,
1311+
device,
1312+
flatten_names,
1313+
donated_args_mask,
1314+
nresults,
1315+
is_sharded::Bool,
1316+
ndevices::Int,
13111317
)
13121318
flatten_buffer_refs = map(n -> :($n.buffer), flatten_names)
13131319

1314-
base_symbol_name =
1315-
is_sharded ? Symbol(:result_buffer_m, length(mesh_ids), :_) : :result_buffer_
1320+
base_symbol_name = is_sharded ? Symbol(:result_buffer_m, ndevices, :_) : :result_buffer_
13161321
concretized_res_names = Symbol[Symbol(base_symbol_name, i) for i in 1:nresults]
13171322
concretized_res_code = map(enumerate(concretized_res_names)) do (i, varname)
13181323
:($varname = linearized_results[$i])
@@ -1326,11 +1331,10 @@ function codegen_xla_call(
13261331
GC.@preserve $(flatten_names...) begin
13271332
linearized_results = XLA.execute(
13281333
$exec,
1329-
$(mesh_ids),
13301334
($(flatten_buffer_refs...),),
13311335
$(Tuple(donated_args_mask)),
13321336
Val($nresults),
1333-
Val($(length(mesh_ids))),
1337+
Val($ndevices),
13341338
)
13351339
end
13361340
$(concretized_res_code...)
@@ -1509,7 +1513,7 @@ function compile(f, args; sync=false, kwargs...)
15091513
donated_args_mask,
15101514
length(linear_results),
15111515
mlir_fn_res.is_sharded,
1512-
mlir_fn_res.is_sharded ? vec(mlir_fn_res.sharding_mesh) : Int64[],
1516+
mlir_fn_res.is_sharded ? length(mlir_fn_res.sharding_mesh) : 1,
15131517
)
15141518

15151519
linear_result_shard_info = if mlir_fn_res.is_sharded

src/xla/PJRT/LoadedExecutable.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -241,25 +241,22 @@ end
241241

242242
@inline function XLA.execute(
243243
exec::LoadedExecutable,
244-
mesh_ids::Vector{Int64},
245244
inputs::NTuple{N,Ptr{Cvoid}},
246245
donated_args::NTuple{M,UInt8},
247246
::Val{n_outs},
248247
::Val{K},
249248
) where {N,M,n_outs,K}
250-
@assert length(mesh_ids) == K
251249
outputs = Ref{NTuple{n_outs * K,Ptr{Cvoid}}}()
252250
future_res = Ref{NTuple{n_outs * K,Ptr{Cvoid}}}()
253251
futures = Ref{UInt8}(0)
254252

255253
inputs = Base.RefValue(inputs)
256254
donated_args = Base.RefValue(donated_args)
257-
GC.@preserve inputs donated_args mesh_ids outputs futures future_res begin
255+
GC.@preserve inputs donated_args outputs futures future_res begin
258256
@ccall MLIR.API.mlir_c.XLAExecute(
259257
exec.exec::Ptr{Cvoid},
260258
N::Cint,
261259
inputs::Ptr{Cvoid},
262-
mesh_ids::Ptr{Clong},
263260
K::Clong,
264261
donated_args::Ptr{UInt8},
265262
n_outs::Cint,

test/sharding.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,9 @@ end
104104
mesh = Sharding.Mesh(reshape([4, 6, 0, 2, 7, 3, 1, 5], 4, 2), ("data", "model"))
105105
x = reshape(collect(Float32, 1:16), 4, 4)
106106
x_ra = Reactant.to_rarray(
107-
x; sharding=Sharding.NamedSharding(mesh, ("data", nothing))
107+
x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))
108108
)
109-
# XXX: This needs to be fixed
110-
@test_broken Array(@jit sum(x_ra)) ≈ sum(x)
109+
@test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x)
111110
else
112111
@warn "Not enough addressable devices to run sharding tests"
113112
end

0 commit comments

Comments
 (0)