Skip to content

Commit f9233b7

Browse files
committed
fix: use global device ids
1 parent ad69949 commit f9233b7

File tree

5 files changed

+16
-22
lines changed

5 files changed

+16
-22
lines changed

src/Compiler.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -1101,15 +1101,15 @@ function codegen_flatten!(
11011101
condensed_op_sharding, size(carg), mesh
11021102
)
11031103
for j in 1:length(mesh)
1104-
local_device_id = device_ids[j]
1105-
buf = Symbol(:buf_, i, :_, local_device_id)
1104+
device_id = device_ids[j]
1105+
buf = Symbol(:buf_, i, :_, device_id)
11061106
slice = device_to_array_slices[j]
11071107
push!(
11081108
flatten_code,
11091109
:($buf = XLA.synced_buffer(only($usbuf[$(slice)...].data))),
11101110
)
1111-
sbuf = Symbol(:sbuf_, i, :_, local_device_id)
1112-
device = XLA.get_addressable_device(client, local_device_id)
1111+
sbuf = Symbol(:s, buf)
1112+
device = XLA.get_device(client, device_id)
11131113
push!(flatten_names, sbuf)
11141114
push!(flatten_code, :($sbuf = XLA.copy_buffer_to_device($buf, $device)))
11151115
end
@@ -1457,7 +1457,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
14571457
)
14581458

14591459
# compile MLIR module to XLA executable
1460-
local_device_ids = if mlir_fn_res.is_sharded
1460+
global_device_ids = if mlir_fn_res.is_sharded
14611461
collect(Int64, mlir_fn_res.sharding_mesh.sorted_device_ids)
14621462
else
14631463
Int64[]
@@ -1471,7 +1471,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
14711471
num_outputs=length(mlir_fn_res.linear_results),
14721472
num_parameters=length(mlir_fn_res.linear_args),
14731473
mlir_fn_res.is_sharded,
1474-
local_device_ids,
1474+
global_device_ids,
14751475
)
14761476

14771477
return mod, exec, mlir_fn_res, device, client

src/Sharding.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ struct Mesh{D,ND}
88
shape::Dims{D}
99
axis_names::NTuple{D,Symbol}
1010

11-
function Mesh(devices::AbstractArray{XLA.AbstractDevice}, axis_names)
12-
return Mesh(XLA.get_local_device_id.(devices), axis_names)
11+
function Mesh(devices::AbstractArray{<:XLA.AbstractDevice}, axis_names)
12+
return Mesh(XLA.device_ordinal.(devices), axis_names)
1313
end
1414

1515
function Mesh(
16-
devices::NTuple{D,XLA.AbstractDevice}, shape::Dims{D}, axis_names
16+
devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names
1717
) where {D}
18-
return Mesh(XLA.get_local_device_id.(devices), shape, axis_names)
18+
return Mesh(XLA.device_ordinal.(devices), shape, axis_names)
1919
end
2020

2121
function Mesh(
@@ -114,7 +114,7 @@ function (sharding::NamedSharding)(
114114
XLA.PJRT.AsyncBuffer(
115115
client,
116116
x[device_to_array_slices[i]...],
117-
XLA.get_addressable_device(client, mesh.sorted_device_ids[i]),
117+
XLA.get_device(client, mesh.sorted_device_ids[i]),
118118
)
119119
end
120120

@@ -199,7 +199,7 @@ function (sharding::LazySharding)(
199199
client::XLA.PJRT.Client, ::Nothing, x::Union{AbstractArray,Number}
200200
)
201201
data = XLA.PJRT.AsyncBuffer(
202-
client, x, XLA.get_addressable_device(client, vec(sharding.sharding.mesh)[1])
202+
client, x, XLA.get_device(client, vec(sharding.sharding.mesh)[1])
203203
)
204204

205205
return (data,), ShardInfo(sharding, (ntuple(i -> 1:size(x, i), ndims(x)),))

src/Types.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,11 @@ function ConcreteRArray(
123123
if idx === nothing
124124
device = XLA.default_device(client)
125125
else
126-
device = XLA.get_addressable_device(client, idx)
126+
device = XLA.get_device(client, idx)
127127
end
128128
else
129129
if idx !== nothing
130-
device_from_idx = XLA.get_addressable_device(client, idx)
130+
device_from_idx = XLA.get_device(client, idx)
131131
@assert device_from_idx == device "If both `idx` and `device` are \
132132
specified, `idx` must match `device`"
133133
end

src/xla/Device.jl

+1-6
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,11 @@ function memories end
1313

1414
"""
1515
device_ordinal(device::Device)
16-
device_ordinal(client::XLA.AbstractClient, local_device_id::Int)
1716
18-
Given the device or local device id, return the corresponding global device ordinal in the client.
17+
Given the device, return the corresponding global device ordinal in the client.
1918
"""
2019
function device_ordinal end
2120

22-
function device_ordinal(client::AbstractClient, local_device_id::Integer)
23-
return device_ordinal(get_addressable_device(client, local_device_id))
24-
end
25-
2621
function Base.string(device::AbstractDevice)
2722
client = XLA.client(device)
2823
pname = XLA.platform_name(client)

src/xla/PJRT/LoadedExecutable.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,11 @@ function XLA.compile(
8181
device::Union{Device,Nothing},
8282
mod::MLIR.IR.Module;
8383
is_sharded::Bool=false,
84-
local_device_ids::Vector{Int64}=Int64[],
84+
global_device_ids::Vector{Int64}=Int64[],
8585
num_outputs::Int64,
8686
num_parameters::Int64,
8787
)
8888
device_id = is_sharded ? Int64(-1) : Int64(XLA.device_ordinal(device))
89-
global_device_ids = Int64.(XLA.device_ordinal.((client,), local_device_ids))
9089
GC.@preserve client mod begin
9190
exec = @ccall MLIR.API.mlir_c.ClientCompile(
9291
client.client::Ptr{Cvoid},

0 commit comments

Comments
 (0)