|
76 | 76 |
|
77 | 77 | // IFRT
|
78 | 78 | #include "xla/python/ifrt/array.h"
|
| 79 | +#include "xla/python/ifrt/basic_device_list.h" |
79 | 80 | #include "xla/python/ifrt/client.h"
|
80 | 81 | #include "xla/python/ifrt/compiler.h"
|
81 | 82 | #include "xla/python/ifrt/device.h"
|
82 | 83 | #include "xla/python/ifrt/device_list.h"
|
83 |
| -#include "xla/python/ifrt/basic_device_list.h" |
84 | 84 | #include "xla/python/ifrt/dtype.h"
|
85 | 85 | #include "xla/python/ifrt/executable.h"
|
86 | 86 | #include "xla/python/ifrt/hlo/hlo_program.h"
|
@@ -1469,6 +1469,10 @@ ifrt_pjrt_loaded_executable_dtor(xla::ifrt::PjRtLoadedExecutable *exec) {
|
1469 | 1469 |
|
1470 | 1470 | extern "C" void ifrt_array_dtor(HeldIfrtArray *array) { delete array; }
|
1471 | 1471 |
|
| 1472 | +extern "C" void ifrt_loaded_executable_dtor(ifrt::LoadedExecutable *exec) { |
| 1473 | + delete exec; |
| 1474 | +} |
| 1475 | + |
1472 | 1476 | extern "C" void ifrt_loaded_executable_execute(
|
1473 | 1477 | ifrt::LoadedExecutable *exec, int num_args,
|
1474 | 1478 | HeldValue<tsl::RCReference<ifrt::Array>> **op_args,
|
@@ -1538,38 +1542,56 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
|
1538 | 1542 |
|
1539 | 1543 | #pragma region IfRtClient
|
1540 | 1544 |
|
1541 |
| -// right now only making it available for TPU |
1542 |
| -// in the future, we would like this for CPU and GPU PjRt backends too |
1543 | 1545 | extern "C" ifrt::proxy::GrpcServer *
|
1544 |
| -ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu( |
1545 |
| - const char *c_address, const char *tpu_path, const char **error) { |
| 1546 | +ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu( |
| 1547 | + const char *c_address, uint8_t asynchronous, int node_id, int num_nodes) { |
1546 | 1548 | std::string address = c_address;
|
1547 | 1549 |
|
1548 |
| - // taken from `MakeTPUClient` |
1549 |
| - std::string tpu_library_path; |
1550 |
| - if (auto path = llvm::sys::Process::GetEnv(kEnvTpuLibraryPath)) { |
1551 |
| - tpu_library_path = *path; |
1552 |
| - } else if (tpu_path) { |
1553 |
| - tpu_library_path = std::string(tpu_path); |
1554 |
| - } else { |
1555 |
| - *error = "Could not find TPU path"; |
1556 |
| - return nullptr; |
1557 |
| - } |
| 1550 | + return MyValueOrThrow( |
| 1551 | + ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory( |
| 1552 | + address, |
| 1553 | + [asynchronous, node_id, num_nodes]() |
| 1554 | + -> absl::StatusOr<std::shared_ptr<ifrt::Client>> { |
| 1555 | + auto pjrt_client = std::shared_ptr<PjRtClient>( |
| 1556 | + MakeCPUClient(asynchronous, node_id, num_nodes)); |
| 1557 | + return std::shared_ptr<ifrt::Client>( |
| 1558 | + xla::ifrt::PjRtClient::Create(pjrt_client).release()); |
| 1559 | + })) |
| 1560 | + .release(); |
| 1561 | +} |
1558 | 1562 |
|
1559 |
| - const PJRT_Api *pluginLoad = |
1560 |
| - LoadPjrtPlugin("tpu", tpu_library_path.c_str(), error); |
1561 |
| - if (pluginLoad == nullptr) |
1562 |
| - return nullptr; |
1563 |
| - auto tpu_status = InitializePjrtPlugin("tpu", error); |
1564 |
| - if (tpu_status) |
1565 |
| - return nullptr; |
| 1563 | +extern "C" ifrt::proxy::GrpcServer * |
| 1564 | +ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu( |
| 1565 | + int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices, |
| 1566 | + double memory_fraction, bool preallocate, const char *platform_name, |
| 1567 | + const char **error) { |
| 1568 | + return MyValueOrThrow( |
| 1569 | + ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory( |
| 1570 | + std::string(), |
| 1571 | + [node_id, num_nodes, allowed_devices, num_allowed_devices, |
| 1572 | + memory_fraction, preallocate, platform_name, |
| 1573 | + error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> { |
| 1574 | + auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient( |
| 1575 | + node_id, num_nodes, allowed_devices, num_allowed_devices, |
| 1576 | + memory_fraction, preallocate, platform_name, error)); |
| 1577 | + return std::shared_ptr<ifrt::Client>( |
| 1578 | + xla::ifrt::PjRtClient::Create(pjrt_client).release()); |
| 1579 | + })) |
| 1580 | + .release(); |
| 1581 | +} |
| 1582 | + |
| 1583 | +extern "C" ifrt::proxy::GrpcServer * |
| 1584 | +ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu( |
| 1585 | + const char *c_address, const char *tpu_path, const char **error) { |
| 1586 | + std::string address = c_address; |
1566 | 1587 |
|
1567 | 1588 | return MyValueOrThrow(
|
1568 | 1589 | xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
|
1569 | 1590 | address,
|
1570 |
| - []() -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> { |
1571 |
| - auto pjrt_client = |
1572 |
| - std::shared_ptr<xla::PjRtClient>(GetCApiClient("TPU")); |
| 1591 | + [tpu_path, error]() |
| 1592 | + -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> { |
| 1593 | + auto pjrt_client = std::shared_ptr<xla::PjRtClient>( |
| 1594 | + MakeTPUClient(tpu_path, error)); |
1573 | 1595 | return std::shared_ptr<xla::ifrt::Client>(
|
1574 | 1596 | xla::ifrt::PjRtClient::Create(pjrt_client).release());
|
1575 | 1597 | }))
|
@@ -1604,28 +1626,28 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
|
1604 | 1626 | nullptr, // callback `on_connection_update`
|
1605 | 1627 | };
|
1606 | 1628 | return MyValueOrThrow(
|
1607 |
| - ifrt::proxy::CreateClient(c_proxy_server_address, options)) |
| 1629 | + ifrt::proxy::CreateClient(proxy_server_address, options)) |
1608 | 1630 | .release();
|
1609 | 1631 | }
|
1610 | 1632 |
|
1611 |
| -extern "C" ifrt::Client *ifrt_make_cpu_client(uint8_t asynchronous, int node_id, |
1612 |
| - int num_nodes) { |
| 1633 | +extern "C" ifrt::Client *ifrt_make_pjrt_cpu_client(uint8_t asynchronous, |
| 1634 | + int node_id, int num_nodes) { |
1613 | 1635 | return ifrt_pjrt_make_client(
|
1614 | 1636 | pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes));
|
1615 | 1637 | }
|
1616 | 1638 |
|
1617 | 1639 | extern "C" ifrt::Client *
|
1618 |
| -ifrt_make_gpu_client(int node_id, int num_nodes, int *allowed_devices, |
1619 |
| - int num_allowed_devices, double memory_fraction, |
1620 |
| - bool preallocate, const char *platform_name, |
1621 |
| - const char **error) { |
| 1640 | +ifrt_make_pjrt_gpu_client(int node_id, int num_nodes, int *allowed_devices, |
| 1641 | + int num_allowed_devices, double memory_fraction, |
| 1642 | + bool preallocate, const char *platform_name, |
| 1643 | + const char **error) { |
1622 | 1644 | return ifrt_pjrt_make_client(pjrt_make_gpu_client_shared(
|
1623 | 1645 | node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
|
1624 | 1646 | preallocate, platform_name, error));
|
1625 | 1647 | }
|
1626 | 1648 |
|
1627 |
| -extern "C" ifrt::Client *ifrt_make_tpu_client(const char *tpu_path, |
1628 |
| - const char **error) { |
| 1649 | +extern "C" ifrt::Client *ifrt_make_pjrt_tpu_client(const char *tpu_path, |
| 1650 | + const char **error) { |
1629 | 1651 | return ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error));
|
1630 | 1652 | }
|
1631 | 1653 |
|
@@ -1815,4 +1837,77 @@ ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) {
|
1815 | 1837 | return cstr_from_string(hlo_sharding->DebugString());
|
1816 | 1838 | }
|
1817 | 1839 |
|
| 1840 | +extern "C" void |
| 1841 | +free_ifrt_sharding(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) { |
| 1842 | + delete sharding; |
| 1843 | +} |
| 1844 | + |
| 1845 | +extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> * |
| 1846 | +ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) { |
| 1847 | + return reactant::capture(std::shared_ptr<ifrt::Sharding>(hlo_sharding)); |
| 1848 | +} |
| 1849 | + |
| 1850 | +extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> * |
| 1851 | +ifrt_sharding_from_hlo_sharding( |
| 1852 | + HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list, |
| 1853 | + ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) { |
| 1854 | + return ifrt_sharding_from_ifrt_hlo_sharding( |
| 1855 | + ifrt_hlo_sharding_from_xla_hlo_sharding(device_list, memory_kind, |
| 1856 | + xla_hlo_sharding)); |
| 1857 | +} |
| 1858 | + |
| 1859 | +extern "C" const char * |
| 1860 | +ifrt_sharding_to_string(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) { |
| 1861 | + return cstr_from_string(sharding->obj()->DebugString()); |
| 1862 | +} |
| 1863 | + |
| 1864 | +#pragma endregion |
| 1865 | + |
| 1866 | +typedef ifrt::Future<> IfRtFutureType; |
| 1867 | + |
| 1868 | +extern "C" void ifrt_free_future(IfRtFutureType *Future) { delete Future; } |
| 1869 | + |
| 1870 | +extern "C" uint8_t ifrt_future_is_ready(IfRtFutureType *Future) { |
| 1871 | + return Future->IsReady(); |
| 1872 | +} |
| 1873 | + |
| 1874 | +extern "C" void ifrt_future_await(IfRtFutureType *Future) { Future->Await(); } |
| 1875 | + |
| 1876 | +#pragma region IfRtArray |
| 1877 | + |
| 1878 | +extern "C" void ifrt_free_array(HeldIfrtArray *array) { delete array; } |
| 1879 | + |
| 1880 | +extern "C" int64_t *ifrt_array_shape(HeldIfrtArray *array) { |
| 1881 | + absl::Span<const long> dims = array->obj()->shape().dims(); |
| 1882 | + int64_t *dims_ptr = new int64_t[dims.size()]; |
| 1883 | + std::copy(dims.begin(), dims.end(), dims_ptr); |
| 1884 | + return dims_ptr; |
| 1885 | +} |
| 1886 | + |
| 1887 | +extern "C" int64_t ifrt_array_ndims(HeldIfrtArray *array) { |
| 1888 | + return array->obj()->shape().dims().size(); |
| 1889 | +} |
| 1890 | + |
| 1891 | +extern "C" ifrt::DType ifrt_array_eltype(HeldIfrtArray *array) { |
| 1892 | + return array->obj()->dtype(); |
| 1893 | +} |
| 1894 | + |
| 1895 | +extern "C" ifrt::Client *ifrt_array_to_client(HeldIfrtArray *array) { |
| 1896 | + return array->obj()->client(); |
| 1897 | +} |
| 1898 | + |
| 1899 | +extern "C" HeldValue<std::shared_ptr<const ifrt::Sharding>> * |
| 1900 | +ifrt_array_to_sharding(HeldIfrtArray *array) { |
| 1901 | + return reactant::capture(array->obj()->shared_ptr_sharding()); |
| 1902 | +} |
| 1903 | + |
| 1904 | +extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array, |
| 1905 | + void *data) { |
| 1906 | + std::optional<absl::Span<const int64_t>> byte_strides; |
| 1907 | + auto future = array->obj()->CopyToHostBuffer( |
| 1908 | + data, byte_strides, static_cast<ifrt::ArrayCopySemantics>(0)); |
| 1909 | + future.Await(); |
| 1910 | + return; |
| 1911 | +} |
| 1912 | + |
1818 | 1913 | #pragma endregion
|
0 commit comments