@@ -786,7 +786,7 @@ xla::CompileOptions GenerateCompileOptions(int64_t device_id, bool is_sharded,
786
786
for (int64_t i = 0 ; i < num_mesh_ids; ++i) {
787
787
int64_t mesh_id = mesh_ids[i];
788
788
assert (mesh_id >= 0 );
789
- device_assignment (0 , mesh_id ) = i ;
789
+ device_assignment (0 , i ) = mesh_id ;
790
790
}
791
791
options.executable_build_options .set_device_assignment (device_assignment);
792
792
@@ -945,31 +945,20 @@ void PrintPjRtBuffer(PjRtBuffer *buffer) {
945
945
}
946
946
947
947
extern " C" void XLAExecute (xla::PjRtLoadedExecutable *exec, int op_args_len,
948
- PjRtBuffer **op_args, const int64_t *mesh_ids ,
949
- int64_t num_mesh_ids, uint8_t *is_arg_donatable,
950
- int num_results, PjRtBuffer **op_results,
951
- uint8_t *futures, FutureType **future_results) {
952
- // Ensure argument_handles is structured as num_mesh_ids x num_args
953
- std::vector<std::vector<PjRtBuffer *>> argument_handles (num_mesh_ids );
954
- int num_args = op_args_len / num_mesh_ids ;
948
+ PjRtBuffer **op_args, int64_t num_devices ,
949
+ uint8_t *is_arg_donatable, int num_results ,
950
+ PjRtBuffer **op_results, uint8_t *futures ,
951
+ FutureType **future_results) {
952
+ // Ensure argument_handles is structured as num_devices x num_args
953
+ std::vector<std::vector<PjRtBuffer *>> argument_handles (num_devices );
954
+ int num_args = op_args_len / num_devices ;
955
955
956
956
// Distribute arguments across devices
957
- for (int device_idx = 0 ; device_idx < num_mesh_ids; ++device_idx) {
958
- int64_t mesh_id = mesh_ids[device_idx];
959
-
960
- // Validate mesh_id
961
- if (mesh_id < 0 || mesh_id >= num_mesh_ids) {
962
- ReactantThrowError ((" Invalid mesh_id " + std::to_string (mesh_id) +
963
- " at device_idx " + std::to_string (device_idx))
964
- .c_str ());
965
- }
966
-
957
+ for (int device_idx = 0 ; device_idx < num_devices; ++device_idx) {
967
958
argument_handles[device_idx].reserve (num_args);
968
959
for (int arg_idx = 0 ; arg_idx < num_args; ++arg_idx) {
969
- // Assuming op_args is a flat array of size num_devices * num_args
970
- // where arguments for each device are contiguous
971
960
argument_handles[device_idx].push_back (
972
- op_args[mesh_id * num_args + arg_idx]);
961
+ op_args[device_idx * num_args + arg_idx]);
973
962
}
974
963
}
975
964
@@ -989,40 +978,40 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len,
989
978
argument_handles),
990
979
options, returned_futures));
991
980
992
- if (results.size () != num_mesh_ids ) {
981
+ if (results.size () != num_devices ) {
993
982
ReactantThrowError ((" results.size()=" + std::to_string (results.size ()) +
994
- " num_mesh_ids =" + std::to_string (num_mesh_ids ) + " \n " )
983
+ " num_devices =" + std::to_string (num_devices ) + " \n " )
995
984
.c_str ());
996
985
}
997
986
998
- for (int device_idx = 0 ; device_idx < num_mesh_ids; ++device_idx) {
999
- int64_t mesh_id = mesh_ids[device_idx];
1000
- if (results[mesh_id].size () != num_results) {
1001
- ReactantThrowError ((" results[" + std::to_string (mesh_id) + " ].size()=" +
1002
- std::to_string (results[mesh_id].size ()) +
1003
- " num_results=" + std::to_string (num_results) + " \n " )
1004
- .c_str ());
987
+ for (int device_idx = 0 ; device_idx < num_devices; ++device_idx) {
988
+ // results is indexed by device_idx directly, so no mesh_id lookup is needed
989
+ if (results[device_idx].size () != num_results) {
990
+ ReactantThrowError (
991
+ (" results[" + std::to_string (device_idx) +
992
+ " ].size()=" + std::to_string (results[device_idx].size ()) +
993
+ " num_results=" + std::to_string (num_results) + " \n " )
994
+ .c_str ());
1005
995
}
1006
996
}
1007
997
1008
998
// Handle returned futures
1009
999
auto future_val = returned_futures.has_value ();
1010
1000
*futures = future_val;
1011
1001
if (future_val) {
1012
- if (returned_futures->size () != num_mesh_ids ) {
1002
+ if (returned_futures->size () != num_devices ) {
1013
1003
ReactantThrowError ((" returned_futures->size()=" +
1014
1004
std::to_string (returned_futures->size ()) +
1015
- " num_mesh_ids =" + std::to_string (num_mesh_ids ) +
1005
+ " num_devices =" + std::to_string (num_devices ) +
1016
1006
" \n " )
1017
1007
.c_str ());
1018
1008
}
1019
1009
}
1020
1010
1021
1011
// Copy results into the output buffers
1022
- for (int device_idx = 0 ; device_idx < num_mesh_ids; ++device_idx) {
1023
- int64_t mesh_id = mesh_ids[device_idx];
1012
+ for (int device_idx = 0 ; device_idx < num_devices; ++device_idx) {
1024
1013
for (int result_idx = 0 ; result_idx < num_results; ++result_idx) {
1025
- int flat_index = mesh_id * num_results + result_idx;
1014
+ int flat_index = device_idx * num_results + result_idx;
1026
1015
op_results[flat_index] = results[device_idx][result_idx].release ();
1027
1016
if (future_val) {
1028
1017
future_results[flat_index] =
0 commit comments