@@ -478,9 +478,9 @@ function compile_mlir(f, args; client=nothing, kwargs...)
478
478
@ccall MLIR. API. mlir_c. RegisterDialects (ctx:: MLIR.API.MlirContext ):: Cvoid
479
479
480
480
if client != = nothing
481
- backend = XLA. ClientGetPlatformName (client)
481
+ backend = XLA. platform_name (client)
482
482
else
483
- backend = XLA. ClientGetPlatformName (XLA. default_backend[])
483
+ backend = XLA. platform_name (XLA. default_backend[])
484
484
end
485
485
if backend == " CUDA"
486
486
backend = " GPU"
@@ -1076,9 +1076,7 @@ function codegen_flatten!(
1076
1076
1077
1077
if is_sharded
1078
1078
carg = inv_seen_args[arg]
1079
- condensed_op_sharding = Reactant. Sharding. XLA. CondensedOpSharding (
1080
- linear_parameter_shardings[i]
1081
- )
1079
+ device_ids = mesh. sorted_device_ids
1082
1080
if Reactant. Sharding. is_sharded (carg)
1083
1081
# Currently disabling the error since we roundtrip from MHLO to generate
1084
1082
# the shardings
@@ -1090,29 +1088,30 @@ function codegen_flatten!(
1090
1088
1091
1089
push! (flatten_code, :($ usbuf = $ flatcode. data))
1092
1090
for j in 1 : length (mesh)
1093
- sbuf = Symbol (:sbuf_ , i, " _" , j )
1091
+ sbuf = Symbol (:sbuf_ , i, " _" , device_ids[j] )
1094
1092
push! (flatten_names, sbuf)
1095
1093
push! (flatten_code, :($ sbuf = XLA. synced_buffer (getindex ($ usbuf, $ j))))
1096
1094
end
1097
1095
else
1096
+ condensed_op_sharding = convert (
1097
+ Reactant. Sharding. XLA. CondensedOpSharding, linear_parameter_shardings[i]
1098
+ )
1098
1099
push! (flatten_code, :($ usbuf = $ flatcode))
1099
1100
device_to_array_slices = XLA. sharding_to_concrete_array_indices (
1100
1101
condensed_op_sharding, size (carg), mesh
1101
1102
)
1102
- device_ids = vec (mesh)
1103
1103
for j in 1 : length (mesh)
1104
- buf = Symbol ( :buf_ , i, :_ , j)
1105
- device_id = device_ids[j]
1104
+ local_device_id = device_ids[j]
1105
+ buf = Symbol ( :buf_ , i, :_ , local_device_id)
1106
1106
slice = device_to_array_slices[j]
1107
1107
push! (
1108
1108
flatten_code,
1109
1109
:($ buf = XLA. synced_buffer (only ($ usbuf[$ (slice). .. ]. data))),
1110
1110
)
1111
- device_ordinal = XLA. device_ordinal (client, device_id)
1112
- sbuf = Symbol (:sbuf_ , i, :_ , j)
1113
- device = XLA. ClientGetAddressableDevice (client, device_ordinal)
1111
+ sbuf = Symbol (:sbuf_ , i, :_ , local_device_id)
1112
+ device = XLA. get_addressable_device (client, local_device_id)
1114
1113
push! (flatten_names, sbuf)
1115
- push! (flatten_code, :($ sbuf = XLA. CopyBufferToDevice ($ buf, $ device)))
1114
+ push! (flatten_code, :($ sbuf = XLA. copy_buffer_to_device ($ buf, $ device)))
1116
1115
end
1117
1116
end
1118
1117
else
@@ -1308,12 +1307,17 @@ Generate Julia code to call the XLA executable.
1308
1307
- `nresults`: The number of results to expect.
1309
1308
"""
1310
1309
function codegen_xla_call (
1311
- exec, device, flatten_names, donated_args_mask, nresults, is_sharded:: Bool , mesh_ids
1310
+ exec,
1311
+ device,
1312
+ flatten_names,
1313
+ donated_args_mask,
1314
+ nresults,
1315
+ is_sharded:: Bool ,
1316
+ ndevices:: Int ,
1312
1317
)
1313
1318
flatten_buffer_refs = map (n -> :($ n. buffer), flatten_names)
1314
1319
1315
- base_symbol_name =
1316
- is_sharded ? Symbol (:result_buffer_m , length (mesh_ids), :_ ) : :result_buffer_
1320
+ base_symbol_name = is_sharded ? Symbol (:result_buffer_m , ndevices, :_ ) : :result_buffer_
1317
1321
concretized_res_names = Symbol[Symbol (base_symbol_name, i) for i in 1 : nresults]
1318
1322
concretized_res_code = map (enumerate (concretized_res_names)) do (i, varname)
1319
1323
:($ varname = linearized_results[$ i])
@@ -1325,21 +1329,20 @@ function codegen_xla_call(
1325
1329
if is_sharded
1326
1330
quote
1327
1331
GC. @preserve $ (flatten_names... ) begin
1328
- linearized_results = XLA. ExecutableCall (
1332
+ linearized_results = XLA. execute (
1329
1333
$ exec,
1330
- $ (mesh_ids),
1331
1334
($ (flatten_buffer_refs... ),),
1332
1335
$ (Tuple (donated_args_mask)),
1333
1336
Val ($ nresults),
1334
- Val ($ ( length (mesh_ids)) ),
1337
+ Val ($ ndevices ),
1335
1338
)
1336
1339
end
1337
1340
$ (concretized_res_code... )
1338
1341
end
1339
1342
else
1340
1343
quote
1341
1344
GC. @preserve $ (flatten_names... ) begin
1342
- linearized_results = XLA. ExecutableCallSharded (
1345
+ linearized_results = XLA. execute_sharded (
1343
1346
$ exec,
1344
1347
$ (device),
1345
1348
($ (flatten_buffer_refs... ),),
@@ -1393,7 +1396,7 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
1393
1396
if ! allequal (devices_list)
1394
1397
msg = " Expected all arguments to be on the same device, got:\n "
1395
1398
for (i, device) in enumerate (devices_list)
1396
- msg *= " Device $(i) : $(XLA . DeviceToString (device)) \n "
1399
+ msg *= " Device $(i) : $(string (device)) \n "
1397
1400
end
1398
1401
throw (ArgumentError (msg))
1399
1402
end
@@ -1407,17 +1410,13 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
1407
1410
client = XLA. client (device)
1408
1411
else
1409
1412
client = XLA. default_backend[]
1410
- device = XLA. ClientGetAddressableDevice (
1411
- client, XLA. device_ordinal (client, XLA. default_device_idx[])
1412
- )
1413
+ device = XLA. get_addressable_device (client, XLA. default_device_idx[])
1413
1414
end
1414
1415
else
1415
1416
if device != = nothing
1416
1417
@assert client == XLA. client (device) " client ($(client) ) and XLA.client(device) ($(XLA. client (device)) ) must be the same"
1417
1418
else
1418
- device = XLA. ClientGetAddressableDevice (
1419
- client, XLA. device_ordinal (client, XLA. default_device_idx[])
1420
- )
1419
+ device = XLA. get_addressable_device (client, XLA. default_device_idx[])
1421
1420
end
1422
1421
end
1423
1422
@@ -1431,9 +1430,9 @@ function compile_xla(f, args; client=nothing, kwargs...)
1431
1430
@ccall MLIR. API. mlir_c. RegisterDialects (ctx:: MLIR.API.MlirContext ):: Cvoid
1432
1431
1433
1432
if client != = nothing
1434
- backend = XLA. ClientGetPlatformName (client)
1433
+ backend = XLA. platform_name (client)
1435
1434
else
1436
- backend = XLA. ClientGetPlatformName (XLA. default_backend[])
1435
+ backend = XLA. platform_name (XLA. default_backend[])
1437
1436
end
1438
1437
if backend == " CUDA"
1439
1438
backend = " GPU"
@@ -1461,17 +1460,21 @@ function compile_xla(f, args; client=nothing, kwargs...)
1461
1460
)
1462
1461
1463
1462
# compile MLIR module to XLA executable
1464
- device_ids = mlir_fn_res. is_sharded ? vec (mlir_fn_res. sharding_mesh) : Int64[]
1463
+ local_device_ids = if mlir_fn_res. is_sharded
1464
+ collect (Int64, mlir_fn_res. sharding_mesh. sorted_device_ids)
1465
+ else
1466
+ Int64[]
1467
+ end
1465
1468
mlir_fn_res. is_sharded && (device = nothing )
1466
1469
1467
- exec = XLA. Compile (
1470
+ exec = XLA. compile (
1468
1471
client,
1469
1472
device,
1470
1473
mod;
1471
1474
num_outputs= length (mlir_fn_res. linear_results),
1472
1475
num_parameters= length (mlir_fn_res. linear_args),
1473
1476
mlir_fn_res. is_sharded,
1474
- device_ids ,
1477
+ local_device_ids ,
1475
1478
)
1476
1479
1477
1480
return mod, exec, mlir_fn_res, device, client
@@ -1514,7 +1517,7 @@ function compile(f, args; sync=false, kwargs...)
1514
1517
donated_args_mask,
1515
1518
length (linear_results),
1516
1519
mlir_fn_res. is_sharded,
1517
- mlir_fn_res. is_sharded ? vec (mlir_fn_res. sharding_mesh) : Int64[] ,
1520
+ mlir_fn_res. is_sharded ? length (mlir_fn_res. sharding_mesh) : 1 ,
1518
1521
)
1519
1522
1520
1523
linear_result_shard_info = if mlir_fn_res. is_sharded
0 commit comments