Skip to content

Commit 044e878

Browse files
committed
fix: use global device ids
1 parent b44c2a3 commit 044e878

File tree

5 files changed

+15
-21
lines changed

5 files changed

+15
-21
lines changed

Diff for: src/Compiler.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -1110,15 +1110,15 @@ function codegen_flatten!(
11101110
condensed_op_sharding, size(carg), mesh
11111111
)
11121112
for j in 1:length(mesh)
1113-
local_device_id = mesh.device_ids[j]
1114-
buf = Symbol(:buf_, i, :_, local_device_id)
1113+
device_id = mesh.device_ids[j]
1114+
buf = Symbol(:buf_, i, :_, device_id)
11151115
slice = device_to_array_slices[j]
11161116
push!(
11171117
flatten_code,
11181118
:($buf = XLA.synced_buffer(only($usbuf[$(slice)...].data))),
11191119
)
1120-
sbuf = Symbol(:sbuf_, i, :_, local_device_id)
1121-
device = XLA.get_addressable_device(client, local_device_id)
1120+
sbuf = Symbol(:s, buf)
1121+
device = XLA.get_device(client, device_id)
11221122
push!(flatten_names, sbuf)
11231123
push!(flatten_code, :($sbuf = XLA.copy_buffer_to_device($buf, $device)))
11241124
end
@@ -1466,7 +1466,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
14661466
)
14671467

14681468
# compile MLIR module to XLA executable
1469-
local_device_ids = if mlir_fn_res.is_sharded
1469+
global_device_ids = if mlir_fn_res.is_sharded
14701470
collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
14711471
else
14721472
Int64[]
@@ -1480,7 +1480,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
14801480
num_outputs=length(mlir_fn_res.linear_results),
14811481
num_parameters=length(mlir_fn_res.linear_args),
14821482
mlir_fn_res.is_sharded,
1483-
local_device_ids,
1483+
global_device_ids,
14841484
)
14851485

14861486
return mod, exec, mlir_fn_res, device, client

Diff for: src/Sharding.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ struct Mesh{D,ND}
2828
shape::Dims{D}
2929
axis_names::NTuple{D,Symbol}
3030

31-
function Mesh(devices::AbstractArray{XLA.AbstractDevice}, axis_names)
32-
return Mesh(XLA.get_local_device_id.(devices), axis_names)
31+
function Mesh(devices::AbstractArray{<:XLA.AbstractDevice}, axis_names)
32+
return Mesh(XLA.device_ordinal.(devices), axis_names)
3333
end
3434

3535
function Mesh(
36-
devices::NTuple{D,XLA.AbstractDevice}, shape::Dims{D}, axis_names
36+
devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names
3737
) where {D}
38-
return Mesh(XLA.get_local_device_id.(devices), shape, axis_names)
38+
return Mesh(XLA.device_ordinal.(devices), shape, axis_names)
3939
end
4040

4141
function Mesh(
@@ -287,7 +287,7 @@ function (sharding::HloSharding)(
287287
XLA.PJRT.AsyncBuffer(
288288
client,
289289
x[device_to_array_slices[i]...],
290-
XLA.get_addressable_device(client, sharding.mesh.device_ids[i]),
290+
XLA.get_device(client, sharding.mesh.device_ids[i]),
291291
)
292292
end
293293

Diff for: src/Types.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,11 @@ function ConcreteRArray(
123123
if idx === nothing
124124
device = XLA.default_device(client)
125125
else
126-
device = XLA.get_addressable_device(client, idx)
126+
device = XLA.get_device(client, idx)
127127
end
128128
else
129129
if idx !== nothing
130-
device_from_idx = XLA.get_addressable_device(client, idx)
130+
device_from_idx = XLA.get_device(client, idx)
131131
@assert device_from_idx == device "If both `idx` and `device` are \
132132
specified, `idx` must match `device`"
133133
end

Diff for: src/xla/Device.jl

+1-6
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,11 @@ function memories end
1313

1414
"""
1515
device_ordinal(device::Device)
16-
device_ordinal(client::XLA.AbstractClient, local_device_id::Int)
1716
18-
Given the device or local device id, return the corresponding global device ordinal in the client.
17+
Given the device, return the corresponding global device ordinal in the client.
1918
"""
2019
function device_ordinal end
2120

22-
function device_ordinal(client::AbstractClient, local_device_id::Integer)
23-
return device_ordinal(get_addressable_device(client, local_device_id))
24-
end
25-
2621
function Base.string(device::AbstractDevice)
2722
client = XLA.client(device)
2823
pname = XLA.platform_name(client)

Diff for: src/xla/PJRT/LoadedExecutable.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,11 @@ function XLA.compile(
7575
device::Union{Device,Nothing},
7676
mod::MLIR.IR.Module;
7777
is_sharded::Bool=false,
78-
local_device_ids::Vector{Int64}=Int64[],
78+
global_device_ids::Vector{Int64}=Int64[],
7979
num_outputs::Int64,
8080
num_parameters::Int64,
8181
)
8282
device_id = is_sharded ? Int64(-1) : Int64(XLA.device_ordinal(device))
83-
global_device_ids = Int64.(XLA.device_ordinal.((client,), local_device_ids))
8483
GC.@preserve client mod begin
8584
exec = @ccall MLIR.API.mlir_c.ClientCompile(
8685
client.client::Ptr{Cvoid},

0 commit comments

Comments
 (0)