Skip to content

Commit 1285870

Browse files
committed
feat: construct IFRT clients with distributed options
1 parent 17675c2 commit 1285870

File tree

5 files changed

+355
-183
lines changed

5 files changed

+355
-183
lines changed

deps/ReactantExtra/API.cpp

+135-94
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181

8282
// IFRT
8383
#include "xla/python/ifrt/array.h"
84+
#include "xla/python/ifrt/attribute_map.h"
8485
#include "xla/python/ifrt/basic_device_list.h"
8586
#include "xla/python/ifrt/client.h"
8687
#include "xla/python/ifrt/compiler.h"
@@ -99,7 +100,6 @@
99100
#include "xla/python/ifrt/topology.h"
100101
#include "xla/python/ifrt/tuple.h"
101102
#include "xla/python/ifrt/value.h"
102-
#include "xla/python/ifrt/attribute_map.h"
103103

104104
// IFRT - PJRT
105105
#include "xla/python/pjrt_ifrt/pjrt_array.h"
@@ -343,14 +343,12 @@ extern "C" void ProfilerServerStop(tsl::profiler::ProfilerServer *server) {
343343
delete server;
344344
}
345345

346-
/// Creates a CPU PjRt client via `GetTfrtCpuClient`.
///
/// @param asynchronous any non-zero value enables the `asynchronous` option.
/// @param node_id      stored as the `process_id` option.
/// @return raw pointer released from the owning result; the caller takes
///         ownership. Creation failures go through `MyValueOrThrow`.
extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id) {
  CpuClientOptions cpu_options;
  cpu_options.process_id = node_id;
  cpu_options.asynchronous = (asynchronous != 0);

  return MyValueOrThrow(GetTfrtCpuClient(cpu_options)).release();
}
@@ -1272,28 +1270,6 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
12721270
return wrap(entryFn);
12731271
}
12741272

1275-
extern "C" HeldPjRtClient *
1276-
pjrt_make_cpu_client_shared(uint8_t asynchronous, int node_id, int num_nodes) {
1277-
PjRtClient *client = MakeCPUClient(asynchronous, node_id, num_nodes);
1278-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1279-
}
1280-
1281-
extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
1282-
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1283-
double memory_fraction, bool preallocate, const char *platform_name,
1284-
const char **error, void *distributed_runtime_client) {
1285-
PjRtClient *client = MakeGPUClient(
1286-
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1287-
preallocate, platform_name, error, distributed_runtime_client);
1288-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1289-
}
1290-
1291-
extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path,
1292-
const char **error) {
1293-
PjRtClient *client = MakeTPUClient(tpu_path, error);
1294-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1295-
}
1296-
12971273
/// Destroys a `HeldPjRtClient` handle with `delete`; a null pointer is a no-op.
extern "C" void pjrt_client_dtor(HeldPjRtClient *client) {
  delete client;
}
12981274

12991275
extern "C" int pjrt_client_num_devices(HeldPjRtClient *client) {
@@ -1370,11 +1346,6 @@ extern "C" HeldPjRtClient *pjrt_buffer_get_client(HeldPjRtBuffer *buffer) {
13701346
std::shared_ptr<PjRtClient>(buffer->ptr()->client()));
13711347
}
13721348

1373-
extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client) {
1374-
xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()};
1375-
return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release();
1376-
}
1377-
13781349
/// Destroys an `ifrt::Client` with `delete`; a null pointer is a no-op.
extern "C" void ifrt_client_dtor(ifrt::Client *client) {
  delete client;
}
13791350

13801351
// generic version, but IFRT-PjRt backend only supports SingleDeviceSharding
@@ -1576,59 +1547,65 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
15761547

15771548
#pragma region IfRtClient
15781549

1579-
extern "C" ifrt::proxy::GrpcServer *
1580-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1581-
const char *c_address, uint8_t asynchronous, int node_id, int num_nodes) {
1582-
std::string address = c_address;
1583-
1584-
return MyValueOrThrow(
1585-
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1586-
address,
1587-
[asynchronous, node_id, num_nodes]()
1588-
-> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1589-
auto pjrt_client = std::shared_ptr<PjRtClient>(
1590-
MakeCPUClient(asynchronous, node_id, num_nodes));
1591-
return std::shared_ptr<ifrt::Client>(
1592-
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1593-
}))
1594-
.release();
1595-
}
1596-
1597-
extern "C" ifrt::proxy::GrpcServer *
1598-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
1599-
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1600-
double memory_fraction, bool preallocate, const char *platform_name,
1601-
const char **error) {
1602-
return MyValueOrThrow(
1603-
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1604-
std::string(),
1605-
[node_id, num_nodes, allowed_devices, num_allowed_devices,
1606-
memory_fraction, preallocate, platform_name,
1607-
error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1608-
auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient(
1609-
node_id, num_nodes, allowed_devices, num_allowed_devices,
1610-
memory_fraction, preallocate, platform_name, error));
1611-
return std::shared_ptr<ifrt::Client>(
1612-
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1613-
}))
1614-
.release();
1615-
}
1550+
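// XXX: The IFRT proxy GrpcServer factory helpers below are disabled (commented
// out). Bring them back once they are updated to the correct API.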
1551+
// extern "C" ifrt::proxy::GrpcServer *
1552+
// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1553+
// const char *c_address, uint8_t asynchronous, int node_id) {
1554+
// std::string address = c_address;
1555+
1556+
// return MyValueOrThrow(
1557+
// ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1558+
// address,
1559+
// [asynchronous,
1560+
// node_id]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>>
1561+
// {
1562+
// auto pjrt_client = std::shared_ptr<PjRtClient>(
1563+
// MakeCPUClient(asynchronous, node_id));
1564+
// return std::shared_ptr<ifrt::Client>(
1565+
// xla::ifrt::PjRtClient::Create(pjrt_client).release());
1566+
// }))
1567+
// .release();
1568+
// }
16161569

1617-
extern "C" ifrt::proxy::GrpcServer *
1618-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1619-
const char *c_address, const char *tpu_path, const char **error) {
1620-
std::string address = c_address;
1570+
// extern "C" ifrt::proxy::GrpcServer *
1571+
// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
1572+
// int node_id, int num_nodes, int *allowed_devices, int
1573+
// num_allowed_devices, double memory_fraction, bool preallocate, const char
1574+
// *platform_name, const char **error) {
1575+
// return MyValueOrThrow(
1576+
// ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1577+
// std::string(),
1578+
// [node_id, num_nodes, allowed_devices, num_allowed_devices,
1579+
// memory_fraction, preallocate, platform_name,
1580+
// error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1581+
// auto pjrt_client =
1582+
// std::shared_ptr<PjRtClient>(MakeGPUClient(
1583+
// node_id, num_nodes, allowed_devices,
1584+
// num_allowed_devices, memory_fraction, preallocate,
1585+
// platform_name, error));
1586+
// return std::shared_ptr<ifrt::Client>(
1587+
// xla::ifrt::PjRtClient::Create(pjrt_client).release());
1588+
// }))
1589+
// .release();
1590+
// }
16211591

1622-
return MyValueOrThrow(
1623-
xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1624-
address,
1625-
[](xla::ifrt::AttributeMap initialization_data) -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1626-
auto pjrt_client =
1627-
std::shared_ptr<xla::PjRtClient>(GetCApiClient("TPU"));
1628-
return xla::ifrt::PjRtClient::Create(std::move(pjrt_client));
1629-
}))
1630-
.release();
1631-
}
1592+
// extern "C" ifrt::proxy::GrpcServer *
1593+
// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1594+
// const char *c_address, const char *tpu_path, const char **error) {
1595+
// std::string address = c_address;
1596+
//
1597+
// return MyValueOrThrow(
1598+
// xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1599+
// address,
1600+
// [](xla::ifrt::AttributeMap initialization_data) ->
1601+
// absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1602+
// auto pjrt_client =
1603+
// std::shared_ptr<xla::PjRtClient>(GetCApiClient("TPU"));
1604+
// return
1605+
// xla::ifrt::PjRtClient::Create(std::move(pjrt_client));
1606+
// }))
1607+
// .release();
1608+
// }
16321609

16331610
extern "C" void ifrt_proxy_grpc_server_dtor(ifrt::proxy::GrpcServer *server) {
16341611
delete server;
@@ -1661,24 +1638,88 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
16611638
.release();
16621639
}
16631640

1664-
extern "C" ifrt::Client *ifrt_make_pjrt_cpu_client(uint8_t asynchronous,
1665-
int node_id, int num_nodes) {
1666-
return ifrt_pjrt_make_client(
1667-
pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes));
1641+
/// Builds an IFRT client on top of an existing PjRt client.
///
/// @param pjrt_client handle whose shared PjRt client is copied into the IFRT
///        creation options. The handle itself is not freed here.
/// @param node_id stored as `process_id` in the creation options.
/// @param num_nodes stored as `num_processes`. When greater than 1, a
///        key-value store is built from `distributed_runtime_client`.
/// @param distributed_runtime_client opaque pointer to a
///        `HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>>`; must be
///        non-null when `num_nodes` > 1 and is ignored otherwise.
/// @param error out-parameter that receives a static message when argument
///        validation fails; may be null, in which case no message is written.
/// @param key_prefix prefix forwarded to `GetDistributedKeyValueStore`.
/// @return the newly created client (released to the caller), or nullptr when
///         argument validation fails. Failures reported by
///         `xla::ifrt::PjRtClient::Create` go through `MyValueOrThrow`.
extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client,
                                               int node_id, int num_nodes,
                                               void *distributed_runtime_client,
                                               const char **error,
                                               std::string key_prefix) {
  // Guard the handle before dereferencing it below.
  if (pjrt_client == nullptr) {
    if (error != nullptr)
      *error = "`pjrt_client` must be non-null";
    return nullptr;
  }

  ifrt::PjRtClient::CreateOptions options;
  options.pjrt_client = pjrt_client->obj();

  if (num_nodes > 1) {
    if (distributed_runtime_client == nullptr) {
      // Only write the message when the caller supplied somewhere to put it.
      if (error != nullptr)
        *error =
            "`distributed_runtime_client` must be non-null if `num_nodes` > 1";
      return nullptr;
    }
    auto typed_distributed_runtime_client = static_cast<
        HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>(
        distributed_runtime_client);
    options.kv_store = GetDistributedKeyValueStore(
        typed_distributed_runtime_client->obj(), key_prefix);
  }

  options.process_id = node_id;
  options.num_processes = num_nodes;

  return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release();
}
1667+
1668+
/// Creates a CPU PjRt client with `MakeCPUClient` and returns it as a
/// shared-ownership handle produced by `reactant::capture`.
///
/// @param asynchronous forwarded unchanged to `MakeCPUClient`.
/// @param node_id      forwarded unchanged to `MakeCPUClient`.
extern "C" HeldPjRtClient *pjrt_make_cpu_client_shared(uint8_t asynchronous,
                                                       int node_id) {
  std::shared_ptr<PjRtClient> shared_client(
      MakeCPUClient(asynchronous, node_id));
  return reactant::capture(shared_client);
}
1673+
1674+
/// Creates a CPU PjRt client and wraps it in an IFRT client.
///
/// @param asynchronous forwarded to `pjrt_make_cpu_client_shared`.
/// @param node_id      process id of this node.
/// @param num_nodes    total number of processes.
/// @param distributed_runtime_client opaque distributed-runtime handle,
///        forwarded to `ifrt_pjrt_make_client`.
/// @param error out-parameter forwarded to `ifrt_pjrt_make_client`.
/// @return the IFRT client, or nullptr if no PjRt client handle was produced
///         or `ifrt_pjrt_make_client` returned nullptr.
extern "C" ifrt::Client *
ifrt_make_pjrt_cpu_client(uint8_t asynchronous, int node_id, int num_nodes,
                          void *distributed_runtime_client,
                          const char **error) {
  // The handle is only a temporary carrier: `ifrt_pjrt_make_client` copies the
  // shared_ptr out of it via `obj()`. Own it here so the wrapper is released on
  // every path (success, nullptr return, or exception) instead of leaking.
  std::unique_ptr<HeldPjRtClient> pjrt_client(
      pjrt_make_cpu_client_shared(asynchronous, node_id));
  if (pjrt_client == nullptr)
    return nullptr;
  return ifrt_pjrt_make_client(pjrt_client.get(), node_id, num_nodes,
                               distributed_runtime_client, error, "cpu");
}
1685+
1686+
/// Creates a GPU PjRt client with `MakeGPUClient` and returns it as a
/// shared-ownership handle produced by `reactant::capture`.
///
/// All arguments are forwarded unchanged to `MakeGPUClient`.
///
/// @return the handle, or nullptr when `MakeGPUClient` yields no client
///         (presumably with `*error` set by `MakeGPUClient` - confirm there).
extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
    double memory_fraction, bool preallocate, const char *platform_name,
    const char **error, void *distributed_runtime_client) {
  PjRtClient *client = MakeGPUClient(
      node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
      preallocate, platform_name, error, distributed_runtime_client);
  // Propagate failure as nullptr: wrapping a null client would hand callers a
  // non-null handle and defeat their null checks.
  if (client == nullptr)
    return nullptr;
  return reactant::capture(std::shared_ptr<PjRtClient>(client));
}
16691695

16701696
/// Creates a GPU PjRt client and wraps it in an IFRT client.
///
/// The GPU-specific arguments are forwarded to `pjrt_make_gpu_client_shared`;
/// `node_id`, `num_nodes`, `distributed_runtime_client` and `error` are also
/// forwarded to `ifrt_pjrt_make_client`.
///
/// @return the IFRT client, or nullptr if no usable PjRt client was produced
///         or `ifrt_pjrt_make_client` returned nullptr.
extern "C" ifrt::Client *ifrt_make_pjrt_gpu_client(
    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
    double memory_fraction, bool preallocate, const char *platform_name,
    const char **error, void *distributed_runtime_client) {
  // The handle is only a temporary carrier: `ifrt_pjrt_make_client` copies the
  // shared_ptr out of it via `obj()`. Own it here so the wrapper is released on
  // every path instead of leaking.
  std::unique_ptr<HeldPjRtClient> pjrt_client(pjrt_make_gpu_client_shared(
      node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
      preallocate, platform_name, error, distributed_runtime_client));
  // Also reject a handle that wraps a null client, so a failed PjRt creation
  // is reported as nullptr rather than passed on to IFRT.
  if (pjrt_client == nullptr || pjrt_client->obj() == nullptr)
    return nullptr;
  return ifrt_pjrt_make_client(pjrt_client.get(), node_id, num_nodes,
                               distributed_runtime_client, error, "gpu");
}
16781708

1679-
extern "C" ifrt::Client *ifrt_make_pjrt_tpu_client(const char *tpu_path,
1680-
const char **error) {
1681-
return ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error));
1709+
/// Creates a TPU PjRt client with `MakeTPUClient` and returns it as a
/// shared-ownership handle produced by `reactant::capture`.
///
/// @param tpu_path forwarded unchanged to `MakeTPUClient`.
/// @param error    out-parameter forwarded unchanged to `MakeTPUClient`.
/// @return the handle, or nullptr when `MakeTPUClient` yields no client
///         (presumably with `*error` set by `MakeTPUClient` - confirm there).
extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path,
                                                       const char **error) {
  PjRtClient *client = MakeTPUClient(tpu_path, error);
  // Propagate failure as nullptr: wrapping a null client would hand callers a
  // non-null handle and defeat their null checks.
  if (client == nullptr)
    return nullptr;
  return reactant::capture(std::shared_ptr<PjRtClient>(client));
}
1714+
1715+
/// Creates a TPU PjRt client and wraps it in an IFRT client.
///
/// @param tpu_path forwarded to `pjrt_make_tpu_client_shared`.
/// @param error    out-parameter forwarded to both the PjRt and IFRT creation
///                 steps.
/// @param node_id  process id of this node.
/// @param num_nodes total number of processes.
/// @param distributed_runtime_client opaque distributed-runtime handle,
///        forwarded to `ifrt_pjrt_make_client`.
/// @return the IFRT client, or nullptr if no usable PjRt client was produced
///         or `ifrt_pjrt_make_client` returned nullptr.
extern "C" ifrt::Client *
ifrt_make_pjrt_tpu_client(const char *tpu_path, const char **error, int node_id,
                          int num_nodes, void *distributed_runtime_client) {
  // The handle is only a temporary carrier: `ifrt_pjrt_make_client` copies the
  // shared_ptr out of it via `obj()`. Own it here so the wrapper is released on
  // every path instead of leaking.
  std::unique_ptr<HeldPjRtClient> pjrt_client(
      pjrt_make_tpu_client_shared(tpu_path, error));
  // Also reject a handle that wraps a null client, so a failed PjRt creation
  // is reported as nullptr rather than passed on to IFRT.
  if (pjrt_client == nullptr || pjrt_client->obj() == nullptr)
    return nullptr;
  return ifrt_pjrt_make_client(pjrt_client.get(), node_id, num_nodes,
                               distributed_runtime_client, error, "tpu");
}
16831724

16841725
/// Destroys an `ifrt::Client` with `delete`; a null pointer is a no-op.
extern "C" void ifrt_FreeClient(ifrt::Client *client) {
  delete client;
}
@@ -1942,7 +1983,7 @@ extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
19421983

19431984
#pragma endregion
19441985

1945-
#pragma region PjRtDistributed
1986+
#pragma region xla::Distributed
19461987

19471988
extern "C" HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *
19481989
GetDistributedRuntimeClient(char *c_address, int32_t node_id,

src/xla/Client.jl

-52
Original file line numberDiff line numberDiff line change
@@ -14,55 +14,3 @@ function get_addressable_device end
1414
function platform_name end
1515

1616
default_device(client::AbstractClient) = first(addressable_devices(client))
17-
18-
# Clients for Different Backends
19-
function CPUClient(cfunc, node_id=0, num_nodes=1; asynchronous=true)
20-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
21-
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
22-
LLVMclopts("-nvptx-fma-level=1")
23-
return client
24-
end
25-
26-
function GPUClient(
27-
cfunc,
28-
node_id=0,
29-
num_nodes=1,
30-
platform="gpu";
31-
allowed_devices::Union{Nothing,Vector{Int}}=nothing,
32-
distributed_runtime_client::Union{Nothing,DistributedRuntimeClient}=nothing,
33-
)
34-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
35-
refstr = Ref{Cstring}()
36-
37-
num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices)
38-
allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices
39-
distributed_runtime_client =
40-
distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client
41-
42-
client = ccall(
43-
f,
44-
Ptr{Cvoid},
45-
(Cint, Cint, Ptr{Cvoid}, Cint, Cdouble, Bool, Cstring, Ptr{Cstring}, Ptr{Cvoid}),
46-
node_id,
47-
num_nodes,
48-
allowed_devices,
49-
num_allowed_devices,
50-
XLA_REACTANT_GPU_MEM_FRACTION[],
51-
false,
52-
platform,
53-
refstr,
54-
distributed_runtime_client,
55-
)
56-
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
57-
LLVMclopts("-nvptx-fma-level=1")
58-
return client
59-
end
60-
61-
function TPUClient(cfunc, tpu_path::String)
62-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
63-
refstr = Ref{Cstring}()
64-
client = ccall(f, Ptr{Cvoid}, (Cstring, Ptr{Cstring}), tpu_path, refstr)
65-
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
66-
LLVMclopts("-nvptx-fma-level=1")
67-
return client
68-
end

0 commit comments

Comments
 (0)