@@ -342,14 +342,12 @@ extern "C" void ProfilerServerStop(tsl::profiler::ProfilerServer *server) {
342
342
delete server;
343
343
}
344
344
345
- extern " C" PjRtClient *MakeCPUClient (uint8_t asynchronous, int node_id,
346
- int num_nodes) {
345
+ extern " C" PjRtClient *MakeCPUClient (uint8_t asynchronous, int node_id) {
347
346
CpuClientOptions options;
348
- // options.kv_store = "etcd";
347
+
349
348
options.process_id = node_id;
350
- // options.num_nodes = num_nodes;
351
- // options.collectives = num_nodes;
352
349
options.asynchronous = asynchronous != 0 ;
350
+
353
351
auto client = MyValueOrThrow (GetTfrtCpuClient (options));
354
352
return client.release ();
355
353
}
@@ -1271,28 +1269,6 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
1271
1269
return wrap (entryFn);
1272
1270
}
1273
1271
1274
- extern " C" HeldPjRtClient *
1275
- pjrt_make_cpu_client_shared (uint8_t asynchronous, int node_id, int num_nodes) {
1276
- PjRtClient *client = MakeCPUClient (asynchronous, node_id, num_nodes);
1277
- return reactant::capture (std::shared_ptr<PjRtClient>(client));
1278
- }
1279
-
1280
- extern " C" HeldPjRtClient *pjrt_make_gpu_client_shared (
1281
- int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1282
- double memory_fraction, bool preallocate, const char *platform_name,
1283
- const char **error, void *distributed_runtime_client) {
1284
- PjRtClient *client = MakeGPUClient (
1285
- node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1286
- preallocate, platform_name, error, distributed_runtime_client);
1287
- return reactant::capture (std::shared_ptr<PjRtClient>(client));
1288
- }
1289
-
1290
- extern " C" HeldPjRtClient *pjrt_make_tpu_client_shared (const char *tpu_path,
1291
- const char **error) {
1292
- PjRtClient *client = MakeTPUClient (tpu_path, error);
1293
- return reactant::capture (std::shared_ptr<PjRtClient>(client));
1294
- }
1295
-
1296
1272
extern " C" void pjrt_client_dtor (HeldPjRtClient *client) { delete client; }
1297
1273
1298
1274
extern " C" int pjrt_client_num_devices (HeldPjRtClient *client) {
@@ -1369,11 +1345,6 @@ extern "C" HeldPjRtClient *pjrt_buffer_get_client(HeldPjRtBuffer *buffer) {
1369
1345
std::shared_ptr<PjRtClient>(buffer->ptr ()->client ()));
1370
1346
}
1371
1347
1372
- extern " C" ifrt::Client *ifrt_pjrt_make_client (HeldPjRtClient *pjrt_client) {
1373
- xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj ()};
1374
- return MyValueOrThrow (xla::ifrt::PjRtClient::Create (options)).release ();
1375
- }
1376
-
1377
1348
extern " C" void ifrt_client_dtor (ifrt::Client *client) { delete client; }
1378
1349
1379
1350
// generic version, but IFRT-PjRt backend only supports SingleDeviceSharding
@@ -1575,61 +1546,62 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
1575
1546
1576
1547
#pragma region IfRtClient
1577
1548
1578
- extern " C" ifrt::proxy::GrpcServer *
1579
- ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu (
1580
- const char *c_address, uint8_t asynchronous, int node_id, int num_nodes) {
1581
- std::string address = c_address;
1582
-
1583
- return MyValueOrThrow (
1584
- ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory (
1585
- address,
1586
- [asynchronous, node_id, num_nodes]()
1587
- -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1588
- auto pjrt_client = std::shared_ptr<PjRtClient>(
1589
- MakeCPUClient (asynchronous, node_id, num_nodes));
1590
- return std::shared_ptr<ifrt::Client>(
1591
- xla::ifrt::PjRtClient::Create (pjrt_client).release ());
1592
- }))
1593
- .release ();
1594
- }
1595
-
1596
- extern " C" ifrt::proxy::GrpcServer *
1597
- ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu (
1598
- int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1599
- double memory_fraction, bool preallocate, const char *platform_name,
1600
- const char **error) {
1601
- return MyValueOrThrow (
1602
- ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory (
1603
- std::string (),
1604
- [node_id, num_nodes, allowed_devices, num_allowed_devices,
1605
- memory_fraction, preallocate, platform_name,
1606
- error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1607
- auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient (
1608
- node_id, num_nodes, allowed_devices, num_allowed_devices,
1609
- memory_fraction, preallocate, platform_name, error));
1610
- return std::shared_ptr<ifrt::Client>(
1611
- xla::ifrt::PjRtClient::Create (pjrt_client).release ());
1612
- }))
1613
- .release ();
1614
- }
1549
+ // XXX: Bring back with the correct API
1550
+ // extern "C" ifrt::proxy::GrpcServer *
1551
+ // ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1552
+ // const char *c_address, uint8_t asynchronous, int node_id) {
1553
+ // std::string address = c_address;
1554
+
1555
+ // return MyValueOrThrow(
1556
+ // ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1557
+ // address,
1558
+ // [asynchronous,
1559
+ // node_id]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1560
+ // auto pjrt_client = std::shared_ptr<PjRtClient>(
1561
+ // MakeCPUClient(asynchronous, node_id));
1562
+ // return std::shared_ptr<ifrt::Client>(
1563
+ // xla::ifrt::PjRtClient::Create(pjrt_client).release());
1564
+ // }))
1565
+ // .release();
1566
+ // }
1615
1567
1616
- extern " C" ifrt::proxy::GrpcServer *
1617
- ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu (
1618
- const char *c_address, const char *tpu_path, const char **error) {
1619
- std::string address = c_address;
1568
+ // extern "C" ifrt::proxy::GrpcServer *
1569
+ // ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
1570
+ // int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1571
+ // double memory_fraction, bool preallocate, const char *platform_name,
1572
+ // const char **error) {
1573
+ // return MyValueOrThrow(
1574
+ // ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1575
+ // std::string(),
1576
+ // [node_id, num_nodes, allowed_devices, num_allowed_devices,
1577
+ // memory_fraction, preallocate, platform_name,
1578
+ // error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1579
+ // auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient(
1580
+ // node_id, num_nodes, allowed_devices, num_allowed_devices,
1581
+ // memory_fraction, preallocate, platform_name, error));
1582
+ // return std::shared_ptr<ifrt::Client>(
1583
+ // xla::ifrt::PjRtClient::Create(pjrt_client).release());
1584
+ // }))
1585
+ // .release();
1586
+ // }
1620
1587
1621
- return MyValueOrThrow (
1622
- xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory (
1623
- address,
1624
- [tpu_path, error]()
1625
- -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1626
- auto pjrt_client = std::shared_ptr<xla::PjRtClient>(
1627
- MakeTPUClient (tpu_path, error));
1628
- return std::shared_ptr<xla::ifrt::Client>(
1629
- xla::ifrt::PjRtClient::Create (pjrt_client).release ());
1630
- }))
1631
- .release ();
1632
- }
1588
+ // extern "C" ifrt::proxy::GrpcServer *
1589
+ // ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1590
+ // const char *c_address, const char *tpu_path, const char **error) {
1591
+ // std::string address = c_address;
1592
+
1593
+ // return MyValueOrThrow(
1594
+ // xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1595
+ // address,
1596
+ // [tpu_path, error]()
1597
+ // -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1598
+ // auto pjrt_client = std::shared_ptr<xla::PjRtClient>(
1599
+ // MakeTPUClient(tpu_path, error));
1600
+ // return std::shared_ptr<xla::ifrt::Client>(
1601
+ // xla::ifrt::PjRtClient::Create(pjrt_client).release());
1602
+ // }))
1603
+ // .release();
1604
+ // }
1633
1605
1634
1606
extern " C" void ifrt_proxy_grpc_server_dtor (ifrt::proxy::GrpcServer *server) {
1635
1607
delete server;
@@ -1662,24 +1634,88 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
1662
1634
.release ();
1663
1635
}
1664
1636
1665
- extern " C" ifrt::Client *ifrt_make_pjrt_cpu_client (uint8_t asynchronous,
1666
- int node_id, int num_nodes) {
1667
- return ifrt_pjrt_make_client (
1668
- pjrt_make_cpu_client_shared (asynchronous, node_id, num_nodes));
1637
+ extern " C" ifrt::Client *ifrt_pjrt_make_client (HeldPjRtClient *pjrt_client,
1638
+ int node_id, int num_nodes,
1639
+ void *distributed_runtime_client,
1640
+ const char **error,
1641
+ std::string key_prefix) {
1642
+ ifrt::PjRtClient::CreateOptions options;
1643
+ options.pjrt_client = pjrt_client->obj ();
1644
+
1645
+ if (num_nodes > 1 ) {
1646
+ if (distributed_runtime_client == nullptr ) {
1647
+ *error =
1648
+ " `distributed_runtime_client` must be non-null if `num_nodes` > 1" ;
1649
+ return nullptr ;
1650
+ }
1651
+ auto typed_distributed_runtime_client = static_cast <
1652
+ HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>(
1653
+ distributed_runtime_client);
1654
+ options.kv_store = GetDistributedKeyValueStore (
1655
+ typed_distributed_runtime_client->obj (), key_prefix);
1656
+ }
1657
+
1658
+ options.process_id = node_id;
1659
+ options.num_processes = num_nodes;
1660
+
1661
+ return MyValueOrThrow (xla::ifrt::PjRtClient::Create (options)).release ();
1662
+ }
1663
+
1664
+ extern " C" HeldPjRtClient *pjrt_make_cpu_client_shared (uint8_t asynchronous,
1665
+ int node_id) {
1666
+ PjRtClient *client = MakeCPUClient (asynchronous, node_id);
1667
+ return reactant::capture (std::shared_ptr<PjRtClient>(client));
1668
+ }
1669
+
1670
+ extern " C" ifrt::Client *
1671
+ ifrt_make_pjrt_cpu_client (uint8_t asynchronous, int node_id, int num_nodes,
1672
+ void *distributed_runtime_client,
1673
+ const char **error) {
1674
+ HeldPjRtClient *pjrt_client =
1675
+ pjrt_make_cpu_client_shared (asynchronous, node_id);
1676
+ if (pjrt_client == nullptr )
1677
+ return nullptr ;
1678
+ return ifrt_pjrt_make_client (pjrt_client, node_id, num_nodes,
1679
+ distributed_runtime_client, error, " cpu" );
1680
+ }
1681
+
1682
+ extern " C" HeldPjRtClient *pjrt_make_gpu_client_shared (
1683
+ int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1684
+ double memory_fraction, bool preallocate, const char *platform_name,
1685
+ const char **error, void *distributed_runtime_client) {
1686
+ PjRtClient *client = MakeGPUClient (
1687
+ node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1688
+ preallocate, platform_name, error, distributed_runtime_client);
1689
+ return reactant::capture (std::shared_ptr<PjRtClient>(client));
1669
1690
}
1670
1691
1671
1692
extern " C" ifrt::Client *ifrt_make_pjrt_gpu_client (
1672
1693
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1673
1694
double memory_fraction, bool preallocate, const char *platform_name,
1674
1695
const char **error, void *distributed_runtime_client) {
1675
- return ifrt_pjrt_make_client ( pjrt_make_gpu_client_shared (
1696
+ HeldPjRtClient *pjrt_client = pjrt_make_gpu_client_shared (
1676
1697
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1677
- preallocate, platform_name, error, distributed_runtime_client));
1698
+ preallocate, platform_name, error, distributed_runtime_client);
1699
+ if (pjrt_client == nullptr )
1700
+ return nullptr ;
1701
+ return ifrt_pjrt_make_client (pjrt_client, node_id, num_nodes,
1702
+ distributed_runtime_client, error, " gpu" );
1703
+ }
1704
+
1705
+ extern " C" HeldPjRtClient *pjrt_make_tpu_client_shared (const char *tpu_path,
1706
+ const char **error) {
1707
+ PjRtClient *client = MakeTPUClient (tpu_path, error);
1708
+ return reactant::capture (std::shared_ptr<PjRtClient>(client));
1678
1709
}
1679
1710
1680
- extern " C" ifrt::Client *ifrt_make_pjrt_tpu_client (const char *tpu_path,
1681
- const char **error) {
1682
- return ifrt_pjrt_make_client (pjrt_make_tpu_client_shared (tpu_path, error));
1711
+ extern " C" ifrt::Client *
1712
+ ifrt_make_pjrt_tpu_client (const char *tpu_path, const char **error, int node_id,
1713
+ int num_nodes, void *distributed_runtime_client) {
1714
+ HeldPjRtClient *pjrt_client = pjrt_make_tpu_client_shared (tpu_path, error);
1715
+ if (pjrt_client == nullptr )
1716
+ return nullptr ;
1717
+ return ifrt_pjrt_make_client (pjrt_client, node_id, num_nodes,
1718
+ distributed_runtime_client, error, " tpu" );
1683
1719
}
1684
1720
1685
1721
extern " C" void ifrt_FreeClient (ifrt::Client *client) { delete client; }
@@ -1943,7 +1979,7 @@ extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
1943
1979
1944
1980
#pragma endregion
1945
1981
1946
- #pragma region PjRtDistributed
1982
+ #pragma region xla::Distributed
1947
1983
1948
1984
extern " C" HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *
1949
1985
GetDistributedRuntimeClient (char *c_address, int32_t node_id,
0 commit comments