Skip to content

Commit f468e60

Browse files
committed
feat: IFRT Device API
1 parent 2f840fa commit f468e60

File tree

5 files changed

+189
-39
lines changed

5 files changed

+189
-39
lines changed

deps/ReactantExtra/API.cpp

+79-30
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,10 @@ extern "C" const char *ClientGetPlatformName(PjRtClient *client) {
421421
return cstr_from_string(client->platform_name());
422422
}
423423

424+
extern "C" const char *DeviceGetKind(PjRtDevice *device) {
425+
return cstr_from_string(device->device_kind());
426+
}
427+
424428
// To keep in sync with JLAllocatorStats in src/XLA.jl
425429
struct JLAllocatorStats {
426430
int64_t num_allocs;
@@ -1258,36 +1262,6 @@ reactant_release_pjrtbuffer(HeldValue<std::shared_ptr<PjRtBuffer>> *buffer) {
12581262
delete buffer;
12591263
}
12601264

1261-
extern "C" ifrt::Client *
1262-
ifrt_pjrt_MakeClient(HeldValue<std::shared_ptr<PjRtClient>> *pjrt_client) {
1263-
xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()};
1264-
return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release();
1265-
}
1266-
1267-
extern "C" ifrt::Client *MakeCPUIfrtClient(uint8_t asynchronous, int node_id,
1268-
int num_nodes) {
1269-
return ifrt_pjrt_MakeClient(reactant_hold_pjrtclient(
1270-
MakeCPUClient(asynchronous, node_id, num_nodes)));
1271-
}
1272-
1273-
extern "C" ifrt::Client *
1274-
MakeGPUIfrtClient(int node_id, int num_nodes, int *allowed_devices,
1275-
int num_allowed_devices, double memory_fraction,
1276-
bool preallocate, const char *platform_name,
1277-
const char **error) {
1278-
return ifrt_pjrt_MakeClient(reactant_hold_pjrtclient(
1279-
MakeGPUClient(node_id, num_nodes, allowed_devices, num_allowed_devices,
1280-
memory_fraction, preallocate, platform_name, error)));
1281-
}
1282-
1283-
extern "C" ifrt::Client *MakeTPUIfrtClient(const char *tpu_path,
1284-
const char **error) {
1285-
return ifrt_pjrt_MakeClient(
1286-
reactant_hold_pjrtclient(MakeTPUClient(tpu_path, error)));
1287-
}
1288-
1289-
extern "C" void ifrt_FreeClient(ifrt::Client *client) { delete client; }
1290-
12911265
extern "C" xla::ifrt::LoadedExecutable *
12921266
ifrt_ClientCompile(ifrt::PjRtClient *client, MlirModule cmod, int64_t device_id,
12931267
bool is_sharded, const int64_t *mesh_ids,
@@ -1399,6 +1373,8 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
13991373
delete hlo_module;
14001374
}
14011375

1376+
#pragma region IfRtClient
1377+
14021378
// right now only making it available for TPU
14031379
// in the future, we would like this for CPU and GPU PjRt backends too
14041380
extern "C" ifrt::proxy::GrpcServer *
@@ -1469,6 +1445,79 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
14691445
.release();
14701446
}
14711447

1448+
extern "C" ifrt::Client *
1449+
ifrt_pjrt_MakeClient(HeldValue<std::shared_ptr<PjRtClient>> *pjrt_client) {
1450+
xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()};
1451+
return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release();
1452+
}
1453+
1454+
extern "C" ifrt::Client *MakeCPUIfrtClient(uint8_t asynchronous, int node_id,
1455+
int num_nodes) {
1456+
return ifrt_pjrt_MakeClient(reactant_hold_pjrtclient(
1457+
MakeCPUClient(asynchronous, node_id, num_nodes)));
1458+
}
1459+
1460+
extern "C" ifrt::Client *
1461+
MakeGPUIfrtClient(int node_id, int num_nodes, int *allowed_devices,
1462+
int num_allowed_devices, double memory_fraction,
1463+
bool preallocate, const char *platform_name,
1464+
const char **error) {
1465+
return ifrt_pjrt_MakeClient(reactant_hold_pjrtclient(
1466+
MakeGPUClient(node_id, num_nodes, allowed_devices, num_allowed_devices,
1467+
memory_fraction, preallocate, platform_name, error)));
1468+
}
1469+
1470+
extern "C" ifrt::Client *MakeTPUIfrtClient(const char *tpu_path,
1471+
const char **error) {
1472+
return ifrt_pjrt_MakeClient(
1473+
reactant_hold_pjrtclient(MakeTPUClient(tpu_path, error)));
1474+
}
1475+
1476+
extern "C" void ifrt_FreeClient(ifrt::Client *client) { delete client; }
1477+
1478+
extern "C" int ifrt_ClientNumDevices(ifrt::Client *client) {
1479+
return client->device_count();
1480+
}
1481+
1482+
extern "C" int ifrt_ClientNumAddressableDevices(ifrt::Client *client) {
1483+
return client->addressable_device_count();
1484+
}
1485+
1486+
extern "C" int ifrt_ClientProcessIndex(ifrt::Client *client) {
1487+
return client->process_index();
1488+
}
1489+
1490+
extern "C" const char *ifrt_ClientGetPlatformName(ifrt::Client *client) {
1491+
return cstr_from_string(client->platform_name());
1492+
}
1493+
1494+
extern "C" ifrt::Device *ifrt_ClientGetDevice(ifrt::Client *client, int idx) {
1495+
return MyValueOrThrow(client->LookupDevice(ifrt::DeviceId(idx)));
1496+
}
1497+
1498+
extern "C" ifrt::Device *ifrt_ClientGetAddressableDevice(ifrt::Client *client,
1499+
int idx) {
1500+
return MyValueOrThrow(client->LookupAddressableDevice(idx));
1501+
}
1502+
1503+
#pragma endregion
1504+
1505+
#pragma region IfRtDevice
1506+
1507+
extern "C" int64_t ifrt_DeviceGetGlobalDeviceId(ifrt::Device *device) {
1508+
return device->Id().value();
1509+
}
1510+
1511+
extern "C" const char *ifrt_DeviceGetKind(ifrt::Device *device) {
1512+
return cstr_from_string(device->Kind());
1513+
}
1514+
1515+
extern "C" ifrt::Client *ifrt_DeviceToClient(ifrt::Device *device) {
1516+
return device->client();
1517+
}
1518+
1519+
#pragma endregion
1520+
14721521
#pragma region HloSharding
14731522

14741523
extern "C" void free_op_sharding(xla::OpSharding *op_sharding) {

src/xla/Device.jl

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
abstract type AbstractDevice end
22

3-
function Base.show(io::IO, ::MIME"text/plain", device::D) where {D <: AbstractDevice}
4-
print(io, "$(parentmodule(D)).Device($(device.device), name=\"$(string(device))\")")
3+
function Base.show(io::IO, ::MIME"text/plain", device::D) where {D<:AbstractDevice}
4+
print(io, "$(parentmodule(D)).Device($(device.device), \"$(string(device))\")")
55
return nothing
66
end
77

88
function device end
99
function get_local_device_id end
10+
function device_kind end
1011

1112
"""
1213
device_ordinal(client::XLA.AbstractClient, device::Device)
@@ -15,3 +16,9 @@ function get_local_device_id end
1516
Given the device or local device id, return the corresponding global device ordinal in the client.
1617
"""
1718
function device_ordinal end
19+
20+
function Base.string(device::AbstractDevice)
21+
client = XLA.client(device)
22+
pname = XLA.platform_name(client)
23+
return "$(uppercase(pname)):$(device_ordinal(client, device)) $(device_kind(device))"
24+
end

src/xla/IFRT/Client.jl

+58-2
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,67 @@ mutable struct Client <: XLA.AbstractClient
33

44
function Client(client::Ptr{Cvoid})
55
@assert client != C_NULL
6-
# TODO: Finalizer
7-
return new(client)
6+
return finalizer(XLA.free_client, new(client))
87
end
98
end
109

10+
function XLA.free_client(client::Client)
11+
@ccall MLIR.API.mlir_c.ifrt_FreeClient(client.client::Ptr{Cvoid})::Cvoid
12+
end
13+
14+
function XLA.num_devices(client::Client)
15+
GC.@preserve client begin
16+
return @ccall MLIR.API.mlir_c.ifrt_ClientNumDevices(client.client::Ptr{Cvoid})::Cint
17+
end
18+
end
19+
20+
function XLA.num_addressable_devices(client::Client)
21+
GC.@preserve client begin
22+
return @ccall MLIR.API.mlir_c.ifrt_ClientNumAddressableDevices(
23+
client.client::Ptr{Cvoid}
24+
)::Cint
25+
end
26+
end
27+
28+
function XLA.process_index(client::Client)
29+
GC.@preserve client begin
30+
return @ccall MLIR.API.mlir_c.ifrt_ClientProcessIndex(
31+
client.client::Ptr{Cvoid}
32+
)::Cint
33+
end
34+
end
35+
36+
function XLA.get_device(client::Client, idx)
37+
GC.@preserve client begin
38+
return Device(
39+
@ccall MLIR.API.mlir_c.ifrt_ClientGetDevice(
40+
client.client::Ptr{Cvoid}, idx::Cint
41+
)::Ptr{Cvoid}
42+
)
43+
end
44+
end
45+
46+
function XLA.get_addressable_device(client::Client, idx)
47+
GC.@preserve client begin
48+
return Device(
49+
@ccall MLIR.API.mlir_c.ifrt_ClientGetAddressableDevice(
50+
client.client::Ptr{Cvoid}, idx::Cint
51+
)::Ptr{Cvoid}
52+
)
53+
end
54+
end
55+
56+
function XLA.platform_name(client::Client)
57+
GC.@preserve client begin
58+
str = @ccall MLIR.API.mlir_c.ifrt_ClientGetPlatformName(
59+
client.client::Ptr{Cvoid}
60+
)::Cstring
61+
end
62+
str_jl = unsafe_string(str)
63+
@ccall free(str::Cstring)::Cvoid
64+
return str_jl
65+
end
66+
1167
# Different Backends
1268
const cpu_client_count = Ref(0)
1369
const gpu_client_count = Ref(0)

src/xla/IFRT/Device.jl

+36
Original file line numberDiff line numberDiff line change
@@ -1 +1,37 @@
1+
struct Device <: XLA.AbstractDevice
2+
device::Ptr{Cvoid}
3+
end
14

5+
function XLA.client(device::Device)
6+
GC.@preserve device begin
7+
return Client(
8+
@ccall MLIR.API.mlir_c.ifrt_DeviceToClient(
9+
device.device::Ptr{Cvoid}
10+
)::Ptr{Cvoid}
11+
)
12+
end
13+
end
14+
15+
function XLA.device_ordinal(::Client, device::Device)
16+
GC.@preserve device begin
17+
return @ccall MLIR.API.mlir_c.ifrt_DeviceGetGlobalDeviceId(
18+
device.device::Ptr{Cvoid}
19+
)::Int64
20+
end
21+
end
22+
function XLA.device_ordinal(client::Client, device_id::Integer)
23+
return XLA.device_ordinal(client, XLA.get_addressable_device(client, device_id))
24+
end
25+
26+
function XLA.device_kind(device::Device)
27+
GC.@preserve device begin
28+
str = @ccall MLIR.API.mlir_c.ifrt_DeviceGetKind(device.device::Ptr{Cvoid})::Cstring
29+
end
30+
str_jl = unsafe_string(str)
31+
@ccall free(str::Cstring)::Cvoid
32+
return str_jl
33+
end
34+
35+
function XLA.get_local_device_id(::Device)
36+
return error("Not implemented for ifrt devices")
37+
end

src/xla/PJRT/Device.jl

+7-5
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,20 @@ function XLA.client(device::Device)
1010
end
1111
end
1212

13-
# TODO: Can be defined on the AbstractDevice?
1413
function XLA.device_ordinal(client::Client, device::Device)
1514
return XLA.device_ordinal(client, XLA.get_local_device_id(device))
1615
end
1716
function XLA.device_ordinal(client::Client, local_device_id::Integer)
1817
return client.global_ordinals[local_device_id + 1]
1918
end
2019

21-
function Base.string(device::Device)
22-
client = XLA.client(device)
23-
platform_name = XLA.platform_name(client)
24-
return "$(uppercase(platform_name)):$(XLA.device_ordinal(client, device))"
20+
function XLA.device_kind(device::Device)
21+
GC.@preserve device begin
22+
str = @ccall MLIR.API.mlir_c.DeviceGetKind(device.device::Ptr{Cvoid})::Cstring
23+
end
24+
str_jl = unsafe_string(str)
25+
@ccall free(str::Cstring)::Cvoid
26+
return str_jl
2527
end
2628

2729
function XLA.get_local_device_id(device::Device)

0 commit comments

Comments
 (0)