Skip to content

Commit 141b468

Browse files
committed
refactor: remove BasicDevicesList
1 parent 253aff1 commit 141b468

File tree

3 files changed

+28
-78
lines changed

3 files changed

+28
-78
lines changed

deps/ReactantExtra/API.cpp

+13-28
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@
8181

8282
// IFRT
8383
#include "xla/python/ifrt/array.h"
84-
#include "xla/python/ifrt/basic_device_list.h"
8584
#include "xla/python/ifrt/client.h"
8685
#include "xla/python/ifrt/compiler.h"
8786
#include "xla/python/ifrt/device.h"
@@ -1799,26 +1798,10 @@ extern "C" ifrt::Client *ifrt_DeviceToClient(ifrt::Device *device) {
17991798
return device->client();
18001799
}
18011800

1802-
extern "C" HeldValue<tsl::RCReference<ifrt::DeviceList>> *
1803-
ifrt_CreateBasicDeviceListFromDevices(ifrt::Device **device_list,
1804-
int32_t num_devices) {
1801+
tsl::RCReference<ifrt::DeviceList> ifrt_CreateDeviceListFromDevices(
1802+
ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices) {
18051803
absl::Span<ifrt::Device *const> devices(device_list, num_devices);
1806-
return reactant::capture(ifrt::BasicDeviceList::Create(devices));
1807-
}
1808-
1809-
extern "C" const char *ifrt_BasicDeviceListToString(
1810-
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list) {
1811-
return cstr_from_string(device_list->obj()->DebugString());
1812-
}
1813-
1814-
extern "C" int ifrt_BasicDeviceListSize(
1815-
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list) {
1816-
return device_list->obj()->size();
1817-
}
1818-
1819-
extern "C" ifrt::Device *const ifrt_BasicDeviceListGetDevice(
1820-
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list, int32_t index) {
1821-
return device_list->obj()->devices()[index];
1804+
return client->MakeDeviceList(devices);
18221805
}
18231806

18241807
extern "C" ifrt::Memory *ifrt_DeviceGetDefaultMemory(ifrt::Device *device) {
@@ -1888,10 +1871,11 @@ hlo_sharding_to_string(const xla::HloSharding *hlo_sharding) {
18881871
}
18891872

18901873
extern "C" ifrt::HloSharding *ifrt_hlo_sharding_from_xla_hlo_sharding(
1891-
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1874+
ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices,
18921875
ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1893-
return ifrt::HloSharding::Create(device_list->obj(), *memory_kind,
1894-
*xla_hlo_sharding)
1876+
return ifrt::HloSharding::Create(
1877+
ifrt_CreateDeviceListFromDevices(client, device_list, num_devices),
1878+
*memory_kind, *xla_hlo_sharding)
18951879
.release();
18961880
}
18971881

@@ -1918,12 +1902,13 @@ ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) {
19181902
}
19191903

19201904
extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1921-
ifrt_sharding_from_hlo_sharding(
1922-
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1923-
ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1905+
ifrt_sharding_from_hlo_sharding(ifrt::Client *client,
1906+
ifrt::Device **device_list, int32_t num_devices,
1907+
ifrt::MemoryKind *memory_kind,
1908+
xla::HloSharding *xla_hlo_sharding) {
19241909
return ifrt_sharding_from_ifrt_hlo_sharding(
1925-
ifrt_hlo_sharding_from_xla_hlo_sharding(device_list, memory_kind,
1926-
xla_hlo_sharding));
1910+
ifrt_hlo_sharding_from_xla_hlo_sharding(client, device_list, num_devices,
1911+
memory_kind, xla_hlo_sharding));
19271912
}
19281913

19291914
extern "C" const char *

src/xla/IFRT/Device.jl

+6-46
Original file line numberDiff line numberDiff line change
@@ -55,52 +55,6 @@ function XLA.memories(device::Device)
5555
return memories
5656
end
5757

58-
# Device List
59-
## TODO: This is semi-deprecated in openxla. At some point we want to just replace this with
60-
## a simple vector of devices
61-
struct BasicDeviceList <: AbstractVector{Device}
62-
ptr::Ptr{Cvoid}
63-
64-
function BasicDeviceList(devices::AbstractVector{Device})
65-
GC.@preserve devices begin
66-
ptr = @ccall MLIR.API.mlir_c.ifrt_CreateBasicDeviceListFromDevices(
67-
[d.device for d in devices]::Ptr{Ptr{Cvoid}}, length(devices)::Int32
68-
)::Ptr{Cvoid}
69-
end
70-
return new(ptr)
71-
end
72-
end
73-
74-
function Base.getindex(device_list::BasicDeviceList, index::Integer)
75-
if !(1 index length(device_list))
76-
throw(BoundsError(device_list, index))
77-
end
78-
GC.@preserve device_list begin
79-
device_ptr = @ccall MLIR.API.mlir_c.ifrt_BasicDeviceListGetDevice(
80-
device_list.ptr::Ptr{Cvoid}, (index - 1)::Int32
81-
)::Ptr{Cvoid}
82-
end
83-
return Device(device_ptr)
84-
end
85-
86-
function Base.size(device_list::BasicDeviceList)
87-
GC.@preserve device_list begin
88-
len = @ccall MLIR.API.mlir_c.ifrt_BasicDeviceListSize(
89-
device_list.ptr::Ptr{Cvoid}
90-
)::Int32
91-
end
92-
return (len,)
93-
end
94-
95-
function Base.string(device_list::BasicDeviceList)
96-
GC.@preserve device_list begin
97-
str = @ccall MLIR.API.mlir_c.ifrt_BasicDeviceListToString(
98-
device_list.ptr::Ptr{Cvoid}
99-
)::Cstring
100-
end
101-
return XLA.unsafe_string_and_free(str)
102-
end
103-
10458
function XLA.default_memory(device_list::AbstractVector{Device})
10559
default_memories = XLA.default_memory.(device_list)
10660
default_memory_kinds = convert.(MemoryKind, default_memories)
@@ -109,3 +63,9 @@ function XLA.default_memory(device_list::AbstractVector{Device})
10963
end
11064
return first(default_memories)
11165
end
66+
67+
function XLA.client(device_list::AbstractVector{Device})
68+
clients = XLA.client.(device_list)
69+
@assert allequal(clients) "All devices must have the same client"
70+
return first(clients)
71+
end

src/xla/IFRT/Sharding.jl

+9-4
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,17 @@ function free_hlo_sharding(hlo_sharding::HloSharding)
1212
@ccall MLIR.API.mlir_c.free_ifrt_hlo_sharding(hlo_sharding.ptr::Ptr{Cvoid})::Cvoid
1313
end
1414

15-
function HloSharding(device_list::BasicDeviceList, xla_hlo_sharding::XLA.HloSharding)
15+
function HloSharding(
16+
device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding
17+
)
1618
default_memory_kind = convert(MemoryKind, XLA.default_memory(device_list))
17-
GC.@preserve device_list default_memory_kind xla_hlo_sharding begin
19+
client = XLA.client(device_list)
20+
GC.@preserve device_list default_memory_kind xla_hlo_sharding client begin
1821
return HloSharding(
1922
@ccall MLIR.API.mlir_c.ifrt_hlo_sharding_from_xla_hlo_sharding(
20-
device_list.ptr::Ptr{Cvoid},
23+
client.client::Ptr{Cvoid},
24+
[d.device for d in device_list]::Ptr{Ptr{Cvoid}},
25+
length(device_list)::Int32,
2126
default_memory_kind.ptr::Ptr{Cvoid},
2227
xla_hlo_sharding.ptr::Ptr{Cvoid},
2328
)::Ptr{Cvoid}
@@ -50,7 +55,7 @@ mutable struct Sharding
5055
end
5156
end
5257

53-
function Sharding(device_list::BasicDeviceList, xla_hlo_sharding::XLA.HloSharding)
58+
function Sharding(device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding)
5459
return convert(Sharding, HloSharding(device_list, xla_hlo_sharding))
5560
end
5661

0 commit comments

Comments
 (0)