Skip to content

Commit ad69949

Browse files
committed
refactor: remove BasicDevicesList
1 parent 1285870 commit ad69949

File tree

3 files changed

+28
-78
lines changed

3 files changed

+28
-78
lines changed

deps/ReactantExtra/API.cpp

+13-28
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@
8282
// IFRT
8383
#include "xla/python/ifrt/array.h"
8484
#include "xla/python/ifrt/attribute_map.h"
85-
#include "xla/python/ifrt/basic_device_list.h"
8685
#include "xla/python/ifrt/client.h"
8786
#include "xla/python/ifrt/compiler.h"
8887
#include "xla/python/ifrt/device.h"
@@ -1800,26 +1799,10 @@ extern "C" ifrt::Client *ifrt_DeviceToClient(ifrt::Device *device) {
18001799
return device->client();
18011800
}
18021801

1803-
extern "C" HeldValue<tsl::RCReference<ifrt::DeviceList>> *
1804-
ifrt_CreateBasicDeviceListFromDevices(ifrt::Device **device_list,
1805-
int32_t num_devices) {
1802+
tsl::RCReference<ifrt::DeviceList> ifrt_CreateDeviceListFromDevices(
1803+
ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices) {
18061804
absl::Span<ifrt::Device *const> devices(device_list, num_devices);
1807-
return reactant::capture(ifrt::BasicDeviceList::Create(devices));
1808-
}
1809-
1810-
extern "C" const char *ifrt_BasicDeviceListToString(
1811-
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list) {
1812-
return cstr_from_string(device_list->obj()->DebugString());
1813-
}
1814-
1815-
extern "C" int ifrt_BasicDeviceListSize(
1816-
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list) {
1817-
return device_list->obj()->size();
1818-
}
1819-
1820-
extern "C" ifrt::Device *const ifrt_BasicDeviceListGetDevice(
1821-
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list, int32_t index) {
1822-
return device_list->obj()->devices()[index];
1805+
return client->MakeDeviceList(devices);
18231806
}
18241807

18251808
extern "C" ifrt::Memory *ifrt_DeviceGetDefaultMemory(ifrt::Device *device) {
@@ -1889,10 +1872,11 @@ hlo_sharding_to_string(const xla::HloSharding *hlo_sharding) {
18891872
}
18901873

18911874
extern "C" ifrt::HloSharding *ifrt_hlo_sharding_from_xla_hlo_sharding(
1892-
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1875+
ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices,
18931876
ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1894-
return ifrt::HloSharding::Create(device_list->obj(), *memory_kind,
1895-
*xla_hlo_sharding)
1877+
return ifrt::HloSharding::Create(
1878+
ifrt_CreateDeviceListFromDevices(client, device_list, num_devices),
1879+
*memory_kind, *xla_hlo_sharding)
18961880
.release();
18971881
}
18981882

@@ -1919,12 +1903,13 @@ ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) {
19191903
}
19201904

19211905
extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1922-
ifrt_sharding_from_hlo_sharding(
1923-
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1924-
ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1906+
ifrt_sharding_from_hlo_sharding(ifrt::Client *client,
1907+
ifrt::Device **device_list, int32_t num_devices,
1908+
ifrt::MemoryKind *memory_kind,
1909+
xla::HloSharding *xla_hlo_sharding) {
19251910
return ifrt_sharding_from_ifrt_hlo_sharding(
1926-
ifrt_hlo_sharding_from_xla_hlo_sharding(device_list, memory_kind,
1927-
xla_hlo_sharding));
1911+
ifrt_hlo_sharding_from_xla_hlo_sharding(client, device_list, num_devices,
1912+
memory_kind, xla_hlo_sharding));
19281913
}
19291914

19301915
extern "C" const char *

src/xla/IFRT/Device.jl

+6-46
Original file line numberDiff line numberDiff line change
@@ -55,52 +55,6 @@ function XLA.memories(device::Device)
5555
return memories
5656
end
5757

58-
# Device List
59-
## TODO: This is semi-deprecated in openxla. At some point we want to just replace this with
60-
## a simple vector of devices
61-
struct BasicDeviceList <: AbstractVector{Device}
62-
ptr::Ptr{Cvoid}
63-
64-
function BasicDeviceList(devices::AbstractVector{Device})
65-
GC.@preserve devices begin
66-
ptr = @ccall MLIR.API.mlir_c.ifrt_CreateBasicDeviceListFromDevices(
67-
[d.device for d in devices]::Ptr{Ptr{Cvoid}}, length(devices)::Int32
68-
)::Ptr{Cvoid}
69-
end
70-
return new(ptr)
71-
end
72-
end
73-
74-
function Base.getindex(device_list::BasicDeviceList, index::Integer)
75-
if !(1 index length(device_list))
76-
throw(BoundsError(device_list, index))
77-
end
78-
GC.@preserve device_list begin
79-
device_ptr = @ccall MLIR.API.mlir_c.ifrt_BasicDeviceListGetDevice(
80-
device_list.ptr::Ptr{Cvoid}, (index - 1)::Int32
81-
)::Ptr{Cvoid}
82-
end
83-
return Device(device_ptr)
84-
end
85-
86-
function Base.size(device_list::BasicDeviceList)
87-
GC.@preserve device_list begin
88-
len = @ccall MLIR.API.mlir_c.ifrt_BasicDeviceListSize(
89-
device_list.ptr::Ptr{Cvoid}
90-
)::Int32
91-
end
92-
return (len,)
93-
end
94-
95-
function Base.string(device_list::BasicDeviceList)
96-
GC.@preserve device_list begin
97-
str = @ccall MLIR.API.mlir_c.ifrt_BasicDeviceListToString(
98-
device_list.ptr::Ptr{Cvoid}
99-
)::Cstring
100-
end
101-
return XLA.unsafe_string_and_free(str)
102-
end
103-
10458
function XLA.default_memory(device_list::AbstractVector{Device})
10559
default_memories = XLA.default_memory.(device_list)
10660
default_memory_kinds = convert.(MemoryKind, default_memories)
@@ -109,3 +63,9 @@ function XLA.default_memory(device_list::AbstractVector{Device})
10963
end
11064
return first(default_memories)
11165
end
66+
67+
function XLA.client(device_list::AbstractVector{Device})
68+
clients = XLA.client.(device_list)
69+
@assert allequal(clients) "All devices must have the same client"
70+
return first(clients)
71+
end

src/xla/IFRT/Sharding.jl

+9-4
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@ function free_hlo_sharding(hlo_sharding::HloSharding)
1212
@ccall MLIR.API.mlir_c.free_ifrt_hlo_sharding(hlo_sharding.ptr::Ptr{Cvoid})::Cvoid
1313
end
1414

15-
function HloSharding(device_list::BasicDeviceList, xla_hlo_sharding::XLA.HloSharding)
15+
function HloSharding(
16+
device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding
17+
)
1618
default_memory_kind = convert(MemoryKind, XLA.default_memory(device_list))
17-
GC.@preserve device_list default_memory_kind xla_hlo_sharding begin
19+
client = XLA.client(device_list)
20+
GC.@preserve device_list default_memory_kind xla_hlo_sharding client begin
1821
return HloSharding(
1922
@ccall MLIR.API.mlir_c.ifrt_hlo_sharding_from_xla_hlo_sharding(
20-
device_list.ptr::Ptr{Cvoid},
23+
client.client::Ptr{Cvoid},
24+
[d.device for d in device_list]::Ptr{Ptr{Cvoid}},
25+
length(device_list)::Int32,
2126
default_memory_kind.ptr::Ptr{Cvoid},
2227
xla_hlo_sharding.ptr::Ptr{Cvoid},
2328
)::Ptr{Cvoid}
@@ -50,7 +55,7 @@ mutable struct Sharding
5055
end
5156
end
5257

53-
function Sharding(device_list::BasicDeviceList, xla_hlo_sharding::XLA.HloSharding)
58+
function Sharding(device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding)
5459
return convert(Sharding, HloSharding(device_list, xla_hlo_sharding))
5560
end
5661

0 commit comments

Comments
 (0)