Skip to content

Commit ca8f3da

Browse files
committed
feat: construct IFRT clients with distributed options
1 parent 408e120 commit ca8f3da

File tree

5 files changed

+351
-184
lines changed

5 files changed

+351
-184
lines changed

deps/ReactantExtra/API.cpp

+131-95
Original file line numberDiff line numberDiff line change
@@ -342,14 +342,12 @@ extern "C" void ProfilerServerStop(tsl::profiler::ProfilerServer *server) {
342342
delete server;
343343
}
344344

345-
extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id,
346-
int num_nodes) {
345+
extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id) {
347346
CpuClientOptions options;
348-
// options.kv_store = "etcd";
347+
349348
options.process_id = node_id;
350-
// options.num_nodes = num_nodes;
351-
// options.collectives = num_nodes;
352349
options.asynchronous = asynchronous != 0;
350+
353351
auto client = MyValueOrThrow(GetTfrtCpuClient(options));
354352
return client.release();
355353
}
@@ -1271,28 +1269,6 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
12711269
return wrap(entryFn);
12721270
}
12731271

1274-
extern "C" HeldPjRtClient *
1275-
pjrt_make_cpu_client_shared(uint8_t asynchronous, int node_id, int num_nodes) {
1276-
PjRtClient *client = MakeCPUClient(asynchronous, node_id, num_nodes);
1277-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1278-
}
1279-
1280-
extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
1281-
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1282-
double memory_fraction, bool preallocate, const char *platform_name,
1283-
const char **error, void *distributed_runtime_client) {
1284-
PjRtClient *client = MakeGPUClient(
1285-
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1286-
preallocate, platform_name, error, distributed_runtime_client);
1287-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1288-
}
1289-
1290-
extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path,
1291-
const char **error) {
1292-
PjRtClient *client = MakeTPUClient(tpu_path, error);
1293-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1294-
}
1295-
12961272
extern "C" void pjrt_client_dtor(HeldPjRtClient *client) { delete client; }
12971273

12981274
extern "C" int pjrt_client_num_devices(HeldPjRtClient *client) {
@@ -1369,11 +1345,6 @@ extern "C" HeldPjRtClient *pjrt_buffer_get_client(HeldPjRtBuffer *buffer) {
13691345
std::shared_ptr<PjRtClient>(buffer->ptr()->client()));
13701346
}
13711347

1372-
extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client) {
1373-
xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()};
1374-
return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release();
1375-
}
1376-
13771348
extern "C" void ifrt_client_dtor(ifrt::Client *client) { delete client; }
13781349

13791350
// generic version, but IFRT-PjRt backend only supports SingleDeviceSharding
@@ -1575,61 +1546,62 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
15751546

15761547
#pragma region IfRtClient
15771548

1578-
extern "C" ifrt::proxy::GrpcServer *
1579-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1580-
const char *c_address, uint8_t asynchronous, int node_id, int num_nodes) {
1581-
std::string address = c_address;
1582-
1583-
return MyValueOrThrow(
1584-
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1585-
address,
1586-
[asynchronous, node_id, num_nodes]()
1587-
-> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1588-
auto pjrt_client = std::shared_ptr<PjRtClient>(
1589-
MakeCPUClient(asynchronous, node_id, num_nodes));
1590-
return std::shared_ptr<ifrt::Client>(
1591-
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1592-
}))
1593-
.release();
1594-
}
1595-
1596-
extern "C" ifrt::proxy::GrpcServer *
1597-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
1598-
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1599-
double memory_fraction, bool preallocate, const char *platform_name,
1600-
const char **error) {
1601-
return MyValueOrThrow(
1602-
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1603-
std::string(),
1604-
[node_id, num_nodes, allowed_devices, num_allowed_devices,
1605-
memory_fraction, preallocate, platform_name,
1606-
error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1607-
auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient(
1608-
node_id, num_nodes, allowed_devices, num_allowed_devices,
1609-
memory_fraction, preallocate, platform_name, error));
1610-
return std::shared_ptr<ifrt::Client>(
1611-
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1612-
}))
1613-
.release();
1614-
}
1549+
// XXX: Bring back with the correct API
1550+
// extern "C" ifrt::proxy::GrpcServer *
1551+
// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1552+
// const char *c_address, uint8_t asynchronous, int node_id) {
1553+
// std::string address = c_address;
1554+
1555+
// return MyValueOrThrow(
1556+
// ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1557+
// address,
1558+
// [asynchronous,
1559+
// node_id]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1560+
// auto pjrt_client = std::shared_ptr<PjRtClient>(
1561+
// MakeCPUClient(asynchronous, node_id));
1562+
// return std::shared_ptr<ifrt::Client>(
1563+
// xla::ifrt::PjRtClient::Create(pjrt_client).release());
1564+
// }))
1565+
// .release();
1566+
// }
16151567

1616-
extern "C" ifrt::proxy::GrpcServer *
1617-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1618-
const char *c_address, const char *tpu_path, const char **error) {
1619-
std::string address = c_address;
1568+
// extern "C" ifrt::proxy::GrpcServer *
1569+
// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
1570+
// int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1571+
// double memory_fraction, bool preallocate, const char *platform_name,
1572+
// const char **error) {
1573+
// return MyValueOrThrow(
1574+
// ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1575+
// std::string(),
1576+
// [node_id, num_nodes, allowed_devices, num_allowed_devices,
1577+
// memory_fraction, preallocate, platform_name,
1578+
// error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1579+
// auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient(
1580+
// node_id, num_nodes, allowed_devices, num_allowed_devices,
1581+
// memory_fraction, preallocate, platform_name, error));
1582+
// return std::shared_ptr<ifrt::Client>(
1583+
// xla::ifrt::PjRtClient::Create(pjrt_client).release());
1584+
// }))
1585+
// .release();
1586+
// }
16201587

1621-
return MyValueOrThrow(
1622-
xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1623-
address,
1624-
[tpu_path, error]()
1625-
-> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1626-
auto pjrt_client = std::shared_ptr<xla::PjRtClient>(
1627-
MakeTPUClient(tpu_path, error));
1628-
return std::shared_ptr<xla::ifrt::Client>(
1629-
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1630-
}))
1631-
.release();
1632-
}
1588+
// extern "C" ifrt::proxy::GrpcServer *
1589+
// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1590+
// const char *c_address, const char *tpu_path, const char **error) {
1591+
// std::string address = c_address;
1592+
1593+
// return MyValueOrThrow(
1594+
// xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1595+
// address,
1596+
// [tpu_path, error]()
1597+
// -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1598+
// auto pjrt_client = std::shared_ptr<xla::PjRtClient>(
1599+
// MakeTPUClient(tpu_path, error));
1600+
// return std::shared_ptr<xla::ifrt::Client>(
1601+
// xla::ifrt::PjRtClient::Create(pjrt_client).release());
1602+
// }))
1603+
// .release();
1604+
// }
16331605

16341606
extern "C" void ifrt_proxy_grpc_server_dtor(ifrt::proxy::GrpcServer *server) {
16351607
delete server;
@@ -1662,24 +1634,88 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
16621634
.release();
16631635
}
16641636

1665-
extern "C" ifrt::Client *ifrt_make_pjrt_cpu_client(uint8_t asynchronous,
1666-
int node_id, int num_nodes) {
1667-
return ifrt_pjrt_make_client(
1668-
pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes));
1637+
extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client,
1638+
int node_id, int num_nodes,
1639+
void *distributed_runtime_client,
1640+
const char **error,
1641+
std::string key_prefix) {
1642+
ifrt::PjRtClient::CreateOptions options;
1643+
options.pjrt_client = pjrt_client->obj();
1644+
1645+
if (num_nodes > 1) {
1646+
if (distributed_runtime_client == nullptr) {
1647+
*error =
1648+
"`distributed_runtime_client` must be non-null if `num_nodes` > 1";
1649+
return nullptr;
1650+
}
1651+
auto typed_distributed_runtime_client = static_cast<
1652+
HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>(
1653+
distributed_runtime_client);
1654+
options.kv_store = GetDistributedKeyValueStore(
1655+
typed_distributed_runtime_client->obj(), key_prefix);
1656+
}
1657+
1658+
options.process_id = node_id;
1659+
options.num_processes = num_nodes;
1660+
1661+
return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release();
1662+
}
1663+
1664+
extern "C" HeldPjRtClient *pjrt_make_cpu_client_shared(uint8_t asynchronous,
1665+
int node_id) {
1666+
PjRtClient *client = MakeCPUClient(asynchronous, node_id);
1667+
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1668+
}
1669+
1670+
extern "C" ifrt::Client *
1671+
ifrt_make_pjrt_cpu_client(uint8_t asynchronous, int node_id, int num_nodes,
1672+
void *distributed_runtime_client,
1673+
const char **error) {
1674+
HeldPjRtClient *pjrt_client =
1675+
pjrt_make_cpu_client_shared(asynchronous, node_id);
1676+
if (pjrt_client == nullptr)
1677+
return nullptr;
1678+
return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes,
1679+
distributed_runtime_client, error, "cpu");
1680+
}
1681+
1682+
extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
1683+
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1684+
double memory_fraction, bool preallocate, const char *platform_name,
1685+
const char **error, void *distributed_runtime_client) {
1686+
PjRtClient *client = MakeGPUClient(
1687+
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1688+
preallocate, platform_name, error, distributed_runtime_client);
1689+
return reactant::capture(std::shared_ptr<PjRtClient>(client));
16691690
}
16701691

16711692
extern "C" ifrt::Client *ifrt_make_pjrt_gpu_client(
16721693
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
16731694
double memory_fraction, bool preallocate, const char *platform_name,
16741695
const char **error, void *distributed_runtime_client) {
1675-
return ifrt_pjrt_make_client(pjrt_make_gpu_client_shared(
1696+
HeldPjRtClient *pjrt_client = pjrt_make_gpu_client_shared(
16761697
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1677-
preallocate, platform_name, error, distributed_runtime_client));
1698+
preallocate, platform_name, error, distributed_runtime_client);
1699+
if (pjrt_client == nullptr)
1700+
return nullptr;
1701+
return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes,
1702+
distributed_runtime_client, error, "gpu");
1703+
}
1704+
1705+
extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path,
1706+
const char **error) {
1707+
PjRtClient *client = MakeTPUClient(tpu_path, error);
1708+
return reactant::capture(std::shared_ptr<PjRtClient>(client));
16781709
}
16791710

1680-
extern "C" ifrt::Client *ifrt_make_pjrt_tpu_client(const char *tpu_path,
1681-
const char **error) {
1682-
return ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error));
1711+
extern "C" ifrt::Client *
1712+
ifrt_make_pjrt_tpu_client(const char *tpu_path, const char **error, int node_id,
1713+
int num_nodes, void *distributed_runtime_client) {
1714+
HeldPjRtClient *pjrt_client = pjrt_make_tpu_client_shared(tpu_path, error);
1715+
if (pjrt_client == nullptr)
1716+
return nullptr;
1717+
return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes,
1718+
distributed_runtime_client, error, "tpu");
16831719
}
16841720

16851721
extern "C" void ifrt_FreeClient(ifrt::Client *client) { delete client; }
@@ -1943,7 +1979,7 @@ extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
19431979

19441980
#pragma endregion
19451981

1946-
#pragma region PjRtDistributed
1982+
#pragma region xla::Distributed
19471983

19481984
extern "C" HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *
19491985
GetDistributedRuntimeClient(char *c_address, int32_t node_id,

src/xla/Client.jl

-52
Original file line numberDiff line numberDiff line change
@@ -14,55 +14,3 @@ function get_addressable_device end
1414
function platform_name end
1515

1616
default_device(client::AbstractClient) = first(addressable_devices(client))
17-
18-
# Clients for Different Backends
19-
function CPUClient(cfunc, node_id=0, num_nodes=1; asynchronous=true)
20-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
21-
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
22-
LLVMclopts("-nvptx-fma-level=1")
23-
return client
24-
end
25-
26-
function GPUClient(
27-
cfunc,
28-
node_id=0,
29-
num_nodes=1,
30-
platform="gpu";
31-
allowed_devices::Union{Nothing,Vector{Int}}=nothing,
32-
distributed_runtime_client::Union{Nothing,DistributedRuntimeClient}=nothing,
33-
)
34-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
35-
refstr = Ref{Cstring}()
36-
37-
num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices)
38-
allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices
39-
distributed_runtime_client =
40-
distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client
41-
42-
client = ccall(
43-
f,
44-
Ptr{Cvoid},
45-
(Cint, Cint, Ptr{Cvoid}, Cint, Cdouble, Bool, Cstring, Ptr{Cstring}, Ptr{Cvoid}),
46-
node_id,
47-
num_nodes,
48-
allowed_devices,
49-
num_allowed_devices,
50-
XLA_REACTANT_GPU_MEM_FRACTION[],
51-
false,
52-
platform,
53-
refstr,
54-
distributed_runtime_client,
55-
)
56-
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
57-
LLVMclopts("-nvptx-fma-level=1")
58-
return client
59-
end
60-
61-
function TPUClient(cfunc, tpu_path::String)
62-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
63-
refstr = Ref{Cstring}()
64-
client = ccall(f, Ptr{Cvoid}, (Cstring, Ptr{Cstring}), tpu_path, refstr)
65-
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
66-
LLVMclopts("-nvptx-fma-level=1")
67-
return client
68-
end

0 commit comments

Comments
 (0)