@@ -74,10 +74,12 @@ function create_result(
74
74
return Expr (:new , T, elems... )
75
75
end
76
76
77
- function __reconstruct_shardinfo (path, path_to_shard_info, sharding_mesh)
78
- device_to_array_slices, partition_spec = path_to_shard_info[path]
77
+ function __reconstruct_shardinfo (path, path_to_shard_info, sharding_mesh, N :: Integer )
78
+ device_to_array_slices, hlo_sharding = path_to_shard_info[path]
79
79
delete! (path_to_shard_info, path)
80
- sharding = Reactant. Sharding. NamedSharding (sharding_mesh, partition_spec)
80
+ sharding = Reactant. Sharding. HloSharding (
81
+ hlo_sharding, sharding_mesh, ntuple (Returns (true ), N), ntuple (Returns (- 1 ), N)
82
+ )
81
83
return Reactant. Sharding. ShardInfo (sharding, device_to_array_slices)
82
84
end
83
85
@@ -88,7 +90,9 @@ function create_result(
88
90
restore = result_stores[path]
89
91
delete! (result_stores, path)
90
92
if path_to_shard_info != = nothing # restore sharding
91
- sharding = __reconstruct_shardinfo (path, path_to_shard_info, sharding_mesh)
93
+ sharding = __reconstruct_shardinfo (
94
+ path, path_to_shard_info, sharding_mesh, ndims (tocopy)
95
+ )
92
96
return :(ConcreteRNumber {$T,length($(restore)),$(typeof(sharding))} (
93
97
($ (restore). .. ,), $ sharding
94
98
))
@@ -98,7 +102,9 @@ function create_result(
98
102
end
99
103
100
104
if path_to_shard_info != = nothing # restore sharding
101
- sharding = __reconstruct_shardinfo (path, path_to_shard_info, sharding_mesh)
105
+ sharding = __reconstruct_shardinfo (
106
+ path, path_to_shard_info, sharding_mesh, ndims (tocopy)
107
+ )
102
108
return :(ConcreteRNumber {$T,length($(tocopy.data)),$(typeof(sharding))} (
103
109
($ (tocopy. data... ,)), $ sharding
104
110
))
@@ -114,7 +120,9 @@ function create_result(
114
120
restore = result_stores[path]
115
121
delete! (result_stores, path)
116
122
if path_to_shard_info != = nothing # restore sharding
117
- sharding = __reconstruct_shardinfo (path, path_to_shard_info, sharding_mesh)
123
+ sharding = __reconstruct_shardinfo (
124
+ path, path_to_shard_info, sharding_mesh, ndims (tocopy)
125
+ )
118
126
return :(ConcreteRArray {$T,$N,length($(restore)),$(typeof(sharding))} (
119
127
($ (restore). .. ,), $ (tocopy. shape), $ sharding
120
128
))
@@ -124,7 +132,9 @@ function create_result(
124
132
end
125
133
126
134
if path_to_shard_info != = nothing # restore sharding
127
- sharding = __reconstruct_shardinfo (path, path_to_shard_info, sharding_mesh)
135
+ sharding = __reconstruct_shardinfo (
136
+ path, path_to_shard_info, sharding_mesh, ndims (tocopy)
137
+ )
128
138
return :(ConcreteRArray {$T,$N,length($(tocopy.data)),$(typeof(sharding))} (
129
139
($ (tocopy. data). .. ,), $ (tocopy. shape), $ sharding
130
140
))
@@ -477,11 +487,8 @@ function compile_mlir(f, args; client=nothing, kwargs...)
477
487
context_gc_vector[ctx] = Vector {TracedRArray} (undef, 0 )
478
488
@ccall MLIR. API. mlir_c. RegisterDialects (ctx:: MLIR.API.MlirContext ):: Cvoid
479
489
480
- if client != = nothing
481
- backend = XLA. platform_name (client)
482
- else
483
- backend = XLA. platform_name (XLA. default_backend[])
484
- end
490
+ backend = XLA. platform_name (client != = nothing ? client : XLA. default_backend ())
491
+
485
492
if backend == " CUDA"
486
493
backend = " GPU"
487
494
elseif backend == " CPU"
@@ -493,13 +500,6 @@ function compile_mlir(f, args; client=nothing, kwargs...)
493
500
494
501
mlir_fn_res = compile_mlir! (mod, f, args; backend, kwargs... )
495
502
496
- client, _ = __resolve_device_and_client (
497
- client,
498
- mlir_fn_res. seen_args,
499
- mlir_fn_res. linear_args,
500
- mlir_fn_res. is_sharded,
501
- )
502
-
503
503
# Attach a name, and partitioning attributes to the module
504
504
__add_mhlo_attributes_and_name! (
505
505
mod, f; mlir_fn_res. num_partitions, mlir_fn_res. num_replicas
@@ -1079,7 +1079,6 @@ function codegen_flatten!(
1079
1079
1080
1080
if is_sharded
1081
1081
carg = inv_seen_args[arg]
1082
- device_ids = mesh. sorted_device_ids
1083
1082
if Reactant. Sharding. is_sharded (carg)
1084
1083
# Currently disabling the error since we roundtrip from MHLO to generate
1085
1084
# the shardings
@@ -1091,7 +1090,7 @@ function codegen_flatten!(
1091
1090
1092
1091
push! (flatten_code, :($ usbuf = $ flatcode. data))
1093
1092
for j in 1 : length (mesh)
1094
- sbuf = Symbol (:sbuf_ , i, " _" , device_ids[j])
1093
+ sbuf = Symbol (:sbuf_ , i, " _" , mesh . device_ids[j])
1095
1094
push! (flatten_names, sbuf)
1096
1095
push! (flatten_code, :($ sbuf = XLA. synced_buffer (getindex ($ usbuf, $ j))))
1097
1096
end
@@ -1101,18 +1100,18 @@ function codegen_flatten!(
1101
1100
)
1102
1101
push! (flatten_code, :($ usbuf = $ flatcode))
1103
1102
device_to_array_slices = XLA. sharding_to_concrete_array_indices (
1104
- condensed_op_sharding, size (carg), mesh
1103
+ condensed_op_sharding, size (carg), mesh. device_ids
1105
1104
)
1106
1105
for j in 1 : length (mesh)
1107
- local_device_id = device_ids[j]
1108
- buf = Symbol (:buf_ , i, :_ , local_device_id )
1106
+ device_id = mesh . device_ids[j]
1107
+ buf = Symbol (:buf_ , i, :_ , device_id )
1109
1108
slice = device_to_array_slices[j]
1110
1109
push! (
1111
1110
flatten_code,
1112
1111
:($ buf = XLA. synced_buffer (only ($ usbuf[$ (slice). .. ]. data))),
1113
1112
)
1114
- sbuf = Symbol (:sbuf_ , i, :_ , local_device_id )
1115
- device = XLA. get_addressable_device (client, local_device_id )
1113
+ sbuf = Symbol (:s , buf )
1114
+ device = XLA. get_device (client, device_id )
1116
1115
push! (flatten_names, sbuf)
1117
1116
push! (flatten_code, :($ sbuf = XLA. copy_buffer_to_device ($ buf, $ device)))
1118
1117
end
@@ -1386,7 +1385,7 @@ end
1386
1385
1387
1386
function __resolve_device_and_client (client, seen_args, linear_args, is_sharded)
1388
1387
if is_sharded
1389
- client === nothing && (client = XLA. default_backend[] )
1388
+ client === nothing && (client = XLA. default_backend () )
1390
1389
return client, nothing
1391
1390
end
1392
1391
@@ -1412,14 +1411,14 @@ function __resolve_device_and_client(client, seen_args, linear_args, is_sharded)
1412
1411
if device != = nothing
1413
1412
client = XLA. client (device)
1414
1413
else
1415
- client = XLA. default_backend[]
1416
- device = XLA. get_addressable_device (client, XLA . default_device_idx[] )
1414
+ client = XLA. default_backend ()
1415
+ device = XLA. default_device (client)
1417
1416
end
1418
1417
else
1419
1418
if device != = nothing
1420
1419
@assert client == XLA. client (device) " client ($(client) ) and XLA.client(device) ($(XLA. client (device)) ) must be the same"
1421
1420
else
1422
- device = XLA. get_addressable_device (client, XLA . default_device_idx[] )
1421
+ device = XLA. default_device (client)
1423
1422
end
1424
1423
end
1425
1424
@@ -1432,11 +1431,8 @@ function compile_xla(f, args; client=nothing, kwargs...)
1432
1431
context_gc_vector[ctx] = Vector {TracedRArray} (undef, 0 )
1433
1432
@ccall MLIR. API. mlir_c. RegisterDialects (ctx:: MLIR.API.MlirContext ):: Cvoid
1434
1433
1435
- if client != = nothing
1436
- backend = XLA. platform_name (client)
1437
- else
1438
- backend = XLA. platform_name (XLA. default_backend[])
1439
- end
1434
+ backend = XLA. platform_name (client != = nothing ? client : XLA. default_backend ())
1435
+
1440
1436
if backend == " CUDA"
1441
1437
backend = " GPU"
1442
1438
elseif backend == " CPU"
@@ -1463,8 +1459,8 @@ function compile_xla(f, args; client=nothing, kwargs...)
1463
1459
)
1464
1460
1465
1461
# compile MLIR module to XLA executable
1466
- local_device_ids = if mlir_fn_res. is_sharded
1467
- collect (Int64, mlir_fn_res. sharding_mesh. sorted_device_ids )
1462
+ global_device_ids = if mlir_fn_res. is_sharded
1463
+ collect (Int64, mlir_fn_res. sharding_mesh. device_ids )
1468
1464
else
1469
1465
Int64[]
1470
1466
end
@@ -1477,7 +1473,9 @@ function compile_xla(f, args; client=nothing, kwargs...)
1477
1473
num_outputs= length (mlir_fn_res. linear_results),
1478
1474
num_parameters= length (mlir_fn_res. linear_args),
1479
1475
mlir_fn_res. is_sharded,
1480
- local_device_ids,
1476
+ global_device_ids,
1477
+ mlir_fn_res. num_replicas,
1478
+ mlir_fn_res. num_partitions,
1481
1479
)
1482
1480
1483
1481
return mod, exec, mlir_fn_res, device, client
@@ -1525,10 +1523,10 @@ function compile(f, args; sync=false, kwargs...)
1525
1523
1526
1524
linear_result_shard_info = if mlir_fn_res. is_sharded
1527
1525
output_shardings = XLA. get_output_shardings (exec)
1528
- XLA. compute_array_indices_and_partition_spec .(
1526
+ XLA. compute_array_indices_and_hlo_sharding .(
1529
1527
output_shardings,
1530
1528
size .(mlir_fn_res. linear_results),
1531
- (mlir_fn_res. sharding_mesh,),
1529
+ (mlir_fn_res. sharding_mesh. logical_device_ids ,),
1532
1530
)
1533
1531
else
1534
1532
ntuple (Returns (nothing ), length (linear_results))
0 commit comments