|
81 | 81 |
|
82 | 82 | // IFRT
|
83 | 83 | #include "xla/python/ifrt/array.h"
|
| 84 | +#include "xla/python/ifrt/attribute_map.h" |
84 | 85 | #include "xla/python/ifrt/basic_device_list.h"
|
85 | 86 | #include "xla/python/ifrt/client.h"
|
86 | 87 | #include "xla/python/ifrt/compiler.h"
|
|
99 | 100 | #include "xla/python/ifrt/topology.h"
|
100 | 101 | #include "xla/python/ifrt/tuple.h"
|
101 | 102 | #include "xla/python/ifrt/value.h"
|
102 |
| -#include "xla/python/ifrt/attribute_map.h" |
103 | 103 |
|
104 | 104 | // IFRT - PJRT
|
105 | 105 | #include "xla/python/pjrt_ifrt/pjrt_array.h"
|
@@ -343,14 +343,12 @@ extern "C" void ProfilerServerStop(tsl::profiler::ProfilerServer *server) {
|
343 | 343 | delete server;
|
344 | 344 | }
|
345 | 345 |
|
346 |
| -extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id, |
347 |
| - int num_nodes) { |
| 346 | +extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id) { |
348 | 347 | CpuClientOptions options;
|
349 |
| - // options.kv_store = "etcd"; |
| 348 | + |
350 | 349 | options.process_id = node_id;
|
351 |
| - // options.num_nodes = num_nodes; |
352 |
| - // options.collectives = num_nodes; |
353 | 350 | options.asynchronous = asynchronous != 0;
|
| 351 | + |
354 | 352 | auto client = MyValueOrThrow(GetTfrtCpuClient(options));
|
355 | 353 | return client.release();
|
356 | 354 | }
|
@@ -1272,28 +1270,6 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
|
1272 | 1270 | return wrap(entryFn);
|
1273 | 1271 | }
|
1274 | 1272 |
|
1275 |
| -extern "C" HeldPjRtClient * |
1276 |
| -pjrt_make_cpu_client_shared(uint8_t asynchronous, int node_id, int num_nodes) { |
1277 |
| - PjRtClient *client = MakeCPUClient(asynchronous, node_id, num_nodes); |
1278 |
| - return reactant::capture(std::shared_ptr<PjRtClient>(client)); |
1279 |
| -} |
1280 |
| - |
1281 |
| -extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared( |
1282 |
| - int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices, |
1283 |
| - double memory_fraction, bool preallocate, const char *platform_name, |
1284 |
| - const char **error, void *distributed_runtime_client) { |
1285 |
| - PjRtClient *client = MakeGPUClient( |
1286 |
| - node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction, |
1287 |
| - preallocate, platform_name, error, distributed_runtime_client); |
1288 |
| - return reactant::capture(std::shared_ptr<PjRtClient>(client)); |
1289 |
| -} |
1290 |
| - |
1291 |
| -extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path, |
1292 |
| - const char **error) { |
1293 |
| - PjRtClient *client = MakeTPUClient(tpu_path, error); |
1294 |
| - return reactant::capture(std::shared_ptr<PjRtClient>(client)); |
1295 |
| -} |
1296 |
| - |
1297 | 1273 | extern "C" void pjrt_client_dtor(HeldPjRtClient *client) { delete client; }
|
1298 | 1274 |
|
1299 | 1275 | extern "C" int pjrt_client_num_devices(HeldPjRtClient *client) {
|
@@ -1370,11 +1346,6 @@ extern "C" HeldPjRtClient *pjrt_buffer_get_client(HeldPjRtBuffer *buffer) {
|
1370 | 1346 | std::shared_ptr<PjRtClient>(buffer->ptr()->client()));
|
1371 | 1347 | }
|
1372 | 1348 |
|
1373 |
| -extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client) { |
1374 |
| - xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()}; |
1375 |
| - return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release(); |
1376 |
| -} |
1377 |
| - |
1378 | 1349 | extern "C" void ifrt_client_dtor(ifrt::Client *client) { delete client; }
|
1379 | 1350 |
|
1380 | 1351 | // generic version, but IFRT-PjRt backend only supports SingleDeviceSharding
|
@@ -1576,59 +1547,65 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
|
1576 | 1547 |
|
1577 | 1548 | #pragma region IfRtClient
|
1578 | 1549 |
|
1579 |
| -extern "C" ifrt::proxy::GrpcServer * |
1580 |
| -ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu( |
1581 |
| - const char *c_address, uint8_t asynchronous, int node_id, int num_nodes) { |
1582 |
| - std::string address = c_address; |
1583 |
| - |
1584 |
| - return MyValueOrThrow( |
1585 |
| - ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory( |
1586 |
| - address, |
1587 |
| - [asynchronous, node_id, num_nodes]() |
1588 |
| - -> absl::StatusOr<std::shared_ptr<ifrt::Client>> { |
1589 |
| - auto pjrt_client = std::shared_ptr<PjRtClient>( |
1590 |
| - MakeCPUClient(asynchronous, node_id, num_nodes)); |
1591 |
| - return std::shared_ptr<ifrt::Client>( |
1592 |
| - xla::ifrt::PjRtClient::Create(pjrt_client).release()); |
1593 |
| - })) |
1594 |
| - .release(); |
1595 |
| -} |
1596 |
| - |
1597 |
| -extern "C" ifrt::proxy::GrpcServer * |
1598 |
| -ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu( |
1599 |
| - int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices, |
1600 |
| - double memory_fraction, bool preallocate, const char *platform_name, |
1601 |
| - const char **error) { |
1602 |
| - return MyValueOrThrow( |
1603 |
| - ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory( |
1604 |
| - std::string(), |
1605 |
| - [node_id, num_nodes, allowed_devices, num_allowed_devices, |
1606 |
| - memory_fraction, preallocate, platform_name, |
1607 |
| - error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> { |
1608 |
| - auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient( |
1609 |
| - node_id, num_nodes, allowed_devices, num_allowed_devices, |
1610 |
| - memory_fraction, preallocate, platform_name, error)); |
1611 |
| - return std::shared_ptr<ifrt::Client>( |
1612 |
| - xla::ifrt::PjRtClient::Create(pjrt_client).release()); |
1613 |
| - })) |
1614 |
| - .release(); |
1615 |
| -} |
| 1550 | +// XXX: Bring back with the correct API |
| 1551 | +// extern "C" ifrt::proxy::GrpcServer * |
| 1552 | +// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu( |
| 1553 | +// const char *c_address, uint8_t asynchronous, int node_id) { |
| 1554 | +// std::string address = c_address; |
| 1555 | + |
| 1556 | +// return MyValueOrThrow( |
| 1557 | +// ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory( |
| 1558 | +// address, |
| 1559 | +// [asynchronous, |
| 1560 | +// node_id]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> |
| 1561 | +// { |
| 1562 | +// auto pjrt_client = std::shared_ptr<PjRtClient>( |
| 1563 | +// MakeCPUClient(asynchronous, node_id)); |
| 1564 | +// return std::shared_ptr<ifrt::Client>( |
| 1565 | +// xla::ifrt::PjRtClient::Create(pjrt_client).release()); |
| 1566 | +// })) |
| 1567 | +// .release(); |
| 1568 | +// } |
1616 | 1569 |
|
1617 |
| -extern "C" ifrt::proxy::GrpcServer * |
1618 |
| -ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu( |
1619 |
| - const char *c_address, const char *tpu_path, const char **error) { |
1620 |
| - std::string address = c_address; |
| 1570 | +// extern "C" ifrt::proxy::GrpcServer * |
| 1571 | +// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu( |
| 1572 | +// int node_id, int num_nodes, int *allowed_devices, int |
| 1573 | +// num_allowed_devices, double memory_fraction, bool preallocate, const char |
| 1574 | +// *platform_name, const char **error) { |
| 1575 | +// return MyValueOrThrow( |
| 1576 | +// ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory( |
| 1577 | +// std::string(), |
| 1578 | +// [node_id, num_nodes, allowed_devices, num_allowed_devices, |
| 1579 | +// memory_fraction, preallocate, platform_name, |
| 1580 | +// error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> { |
| 1581 | +// auto pjrt_client = |
| 1582 | +// std::shared_ptr<PjRtClient>(MakeGPUClient( |
| 1583 | +// node_id, num_nodes, allowed_devices, |
| 1584 | +// num_allowed_devices, memory_fraction, preallocate, |
| 1585 | +// platform_name, error)); |
| 1586 | +// return std::shared_ptr<ifrt::Client>( |
| 1587 | +// xla::ifrt::PjRtClient::Create(pjrt_client).release()); |
| 1588 | +// })) |
| 1589 | +// .release(); |
| 1590 | +// } |
1621 | 1591 |
|
1622 |
| - return MyValueOrThrow( |
1623 |
| - xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory( |
1624 |
| - address, |
1625 |
| - [](xla::ifrt::AttributeMap initialization_data) -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> { |
1626 |
| - auto pjrt_client = |
1627 |
| - std::shared_ptr<xla::PjRtClient>(GetCApiClient("TPU")); |
1628 |
| - return xla::ifrt::PjRtClient::Create(std::move(pjrt_client)); |
1629 |
| - })) |
1630 |
| - .release(); |
1631 |
| -} |
| 1592 | +// extern "C" ifrt::proxy::GrpcServer * |
| 1593 | +// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu( |
| 1594 | +// const char *c_address, const char *tpu_path, const char **error) { |
| 1595 | +// std::string address = c_address; |
| 1596 | +// |
| 1597 | +// return MyValueOrThrow( |
| 1598 | +// xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory( |
| 1599 | +// address, |
| 1600 | +// [](xla::ifrt::AttributeMap initialization_data) -> |
| 1601 | +// absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> { |
| 1602 | +// auto pjrt_client = |
| 1603 | +// std::shared_ptr<xla::PjRtClient>(GetCApiClient("TPU")); |
| 1604 | +// return |
| 1605 | +// xla::ifrt::PjRtClient::Create(std::move(pjrt_client)); |
| 1606 | +// })) |
| 1607 | +// .release(); |
| 1608 | +// } |
1632 | 1609 |
|
1633 | 1610 | extern "C" void ifrt_proxy_grpc_server_dtor(ifrt::proxy::GrpcServer *server) {
|
1634 | 1611 | delete server;
|
@@ -1661,24 +1638,88 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
|
1661 | 1638 | .release();
|
1662 | 1639 | }
|
1663 | 1640 |
|
1664 |
| -extern "C" ifrt::Client *ifrt_make_pjrt_cpu_client(uint8_t asynchronous, |
1665 |
| - int node_id, int num_nodes) { |
1666 |
| - return ifrt_pjrt_make_client( |
1667 |
| - pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes)); |
| 1641 | +extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client, |
| 1642 | + int node_id, int num_nodes, |
| 1643 | + void *distributed_runtime_client, |
| 1644 | + const char **error, |
| 1645 | + std::string key_prefix) { |
| 1646 | + ifrt::PjRtClient::CreateOptions options; |
| 1647 | + options.pjrt_client = pjrt_client->obj(); |
| 1648 | + |
| 1649 | + if (num_nodes > 1) { |
| 1650 | + if (distributed_runtime_client == nullptr) { |
| 1651 | + *error = |
| 1652 | + "`distributed_runtime_client` must be non-null if `num_nodes` > 1"; |
| 1653 | + return nullptr; |
| 1654 | + } |
| 1655 | + auto typed_distributed_runtime_client = static_cast< |
| 1656 | + HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>( |
| 1657 | + distributed_runtime_client); |
| 1658 | + options.kv_store = GetDistributedKeyValueStore( |
| 1659 | + typed_distributed_runtime_client->obj(), key_prefix); |
| 1660 | + } |
| 1661 | + |
| 1662 | + options.process_id = node_id; |
| 1663 | + options.num_processes = num_nodes; |
| 1664 | + |
| 1665 | + return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release(); |
| 1666 | +} |
| 1667 | + |
| 1668 | +extern "C" HeldPjRtClient *pjrt_make_cpu_client_shared(uint8_t asynchronous, |
| 1669 | + int node_id) { |
| 1670 | + PjRtClient *client = MakeCPUClient(asynchronous, node_id); |
| 1671 | + return reactant::capture(std::shared_ptr<PjRtClient>(client)); |
| 1672 | +} |
| 1673 | + |
| 1674 | +extern "C" ifrt::Client * |
| 1675 | +ifrt_make_pjrt_cpu_client(uint8_t asynchronous, int node_id, int num_nodes, |
| 1676 | + void *distributed_runtime_client, |
| 1677 | + const char **error) { |
| 1678 | + HeldPjRtClient *pjrt_client = |
| 1679 | + pjrt_make_cpu_client_shared(asynchronous, node_id); |
| 1680 | + if (pjrt_client == nullptr) |
| 1681 | + return nullptr; |
| 1682 | + return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes, |
| 1683 | + distributed_runtime_client, error, "cpu"); |
| 1684 | +} |
| 1685 | + |
| 1686 | +extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared( |
| 1687 | + int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices, |
| 1688 | + double memory_fraction, bool preallocate, const char *platform_name, |
| 1689 | + const char **error, void *distributed_runtime_client) { |
| 1690 | + PjRtClient *client = MakeGPUClient( |
| 1691 | + node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction, |
| 1692 | + preallocate, platform_name, error, distributed_runtime_client); |
| 1693 | + return reactant::capture(std::shared_ptr<PjRtClient>(client)); |
1668 | 1694 | }
|
1669 | 1695 |
|
1670 | 1696 | extern "C" ifrt::Client *ifrt_make_pjrt_gpu_client(
|
1671 | 1697 | int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
|
1672 | 1698 | double memory_fraction, bool preallocate, const char *platform_name,
|
1673 | 1699 | const char **error, void *distributed_runtime_client) {
|
1674 |
| - return ifrt_pjrt_make_client(pjrt_make_gpu_client_shared( |
| 1700 | + HeldPjRtClient *pjrt_client = pjrt_make_gpu_client_shared( |
1675 | 1701 | node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
|
1676 |
| - preallocate, platform_name, error, distributed_runtime_client)); |
| 1702 | + preallocate, platform_name, error, distributed_runtime_client); |
| 1703 | + if (pjrt_client == nullptr) |
| 1704 | + return nullptr; |
| 1705 | + return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes, |
| 1706 | + distributed_runtime_client, error, "gpu"); |
1677 | 1707 | }
|
1678 | 1708 |
|
1679 |
| -extern "C" ifrt::Client *ifrt_make_pjrt_tpu_client(const char *tpu_path, |
1680 |
| - const char **error) { |
1681 |
| - return ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error)); |
| 1709 | +extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path, |
| 1710 | + const char **error) { |
| 1711 | + PjRtClient *client = MakeTPUClient(tpu_path, error); |
| 1712 | + return reactant::capture(std::shared_ptr<PjRtClient>(client)); |
| 1713 | +} |
| 1714 | + |
| 1715 | +extern "C" ifrt::Client * |
| 1716 | +ifrt_make_pjrt_tpu_client(const char *tpu_path, const char **error, int node_id, |
| 1717 | + int num_nodes, void *distributed_runtime_client) { |
| 1718 | + HeldPjRtClient *pjrt_client = pjrt_make_tpu_client_shared(tpu_path, error); |
| 1719 | + if (pjrt_client == nullptr) |
| 1720 | + return nullptr; |
| 1721 | + return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes, |
| 1722 | + distributed_runtime_client, error, "tpu"); |
1682 | 1723 | }
|
1683 | 1724 |
|
1684 | 1725 | extern "C" void ifrt_FreeClient(ifrt::Client *client) { delete client; }
|
@@ -1942,7 +1983,7 @@ extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
|
1942 | 1983 |
|
1943 | 1984 | #pragma endregion
|
1944 | 1985 |
|
1945 |
| -#pragma region PjRtDistributed |
| 1986 | +#pragma region xla::Distributed |
1946 | 1987 |
|
1947 | 1988 | extern "C" HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *
|
1948 | 1989 | GetDistributedRuntimeClient(char *c_address, int32_t node_id,
|
|
0 commit comments