@@ -1503,6 +1503,10 @@ ifrt_pjrt_loaded_executable_dtor(xla::ifrt::PjRtLoadedExecutable *exec) {
1503
1503
1504
1504
extern " C" void ifrt_array_dtor (HeldIfrtArray *array) { delete array; }
1505
1505
1506
+ extern " C" void ifrt_loaded_executable_dtor (ifrt::LoadedExecutable *exec) {
1507
+ delete exec;
1508
+ }
1509
+
1506
1510
extern " C" void ifrt_loaded_executable_execute (
1507
1511
ifrt::LoadedExecutable *exec, int num_args,
1508
1512
HeldValue<tsl::RCReference<ifrt::Array>> **op_args,
@@ -1572,31 +1576,48 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
1572
1576
1573
1577
#pragma region IfRtClient
1574
1578
1575
- // right now only making it available for TPU
1576
- // in the future, we would like this for CPU and GPU PjRt backends too
1577
1579
extern " C" ifrt::proxy::GrpcServer *
1578
- ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu (
1579
- const char *c_address, const char *tpu_path, const char **error ) {
1580
+ ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu (
1581
+ const char *c_address, uint8_t asynchronous, int node_id, int num_nodes ) {
1580
1582
std::string address = c_address;
1581
1583
1582
- // taken from `MakeTPUClient`
1583
- std::string tpu_library_path;
1584
- if (auto path = llvm::sys::Process::GetEnv (kEnvTpuLibraryPath )) {
1585
- tpu_library_path = *path;
1586
- } else if (tpu_path) {
1587
- tpu_library_path = std::string (tpu_path);
1588
- } else {
1589
- *error = " Could not find TPU path" ;
1590
- return nullptr ;
1591
- }
1584
+ return MyValueOrThrow (
1585
+ ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory (
1586
+ address,
1587
+ [asynchronous, node_id, num_nodes]()
1588
+ -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1589
+ auto pjrt_client = std::shared_ptr<PjRtClient>(
1590
+ MakeCPUClient (asynchronous, node_id, num_nodes));
1591
+ return std::shared_ptr<ifrt::Client>(
1592
+ xla::ifrt::PjRtClient::Create (pjrt_client).release ());
1593
+ }))
1594
+ .release ();
1595
+ }
1592
1596
1593
- const PJRT_Api *pluginLoad =
1594
- LoadPjrtPlugin (" tpu" , tpu_library_path.c_str (), error);
1595
- if (pluginLoad == nullptr )
1596
- return nullptr ;
1597
- auto tpu_status = InitializePjrtPlugin (" tpu" , error);
1598
- if (tpu_status)
1599
- return nullptr ;
1597
+ extern " C" ifrt::proxy::GrpcServer *
1598
+ ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu (
1599
+ int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1600
+ double memory_fraction, bool preallocate, const char *platform_name,
1601
+ const char **error) {
1602
+ return MyValueOrThrow (
1603
+ ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory (
1604
+ std::string (),
1605
+ [node_id, num_nodes, allowed_devices, num_allowed_devices,
1606
+ memory_fraction, preallocate, platform_name,
1607
+ error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1608
+ auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient (
1609
+ node_id, num_nodes, allowed_devices, num_allowed_devices,
1610
+ memory_fraction, preallocate, platform_name, error));
1611
+ return std::shared_ptr<ifrt::Client>(
1612
+ xla::ifrt::PjRtClient::Create (pjrt_client).release ());
1613
+ }))
1614
+ .release ();
1615
+ }
1616
+
1617
+ extern " C" ifrt::proxy::GrpcServer *
1618
+ ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu (
1619
+ const char *c_address, const char *tpu_path, const char **error) {
1620
+ std::string address = c_address;
1600
1621
1601
1622
return MyValueOrThrow (
1602
1623
xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory (
@@ -1636,28 +1657,27 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
1636
1657
nullptr , // callback `on_connection_update`
1637
1658
};
1638
1659
return MyValueOrThrow (
1639
- ifrt::proxy::CreateClient (c_proxy_server_address , options))
1660
+ ifrt::proxy::CreateClient (proxy_server_address , options))
1640
1661
.release ();
1641
1662
}
1642
1663
1643
- extern " C" ifrt::Client *ifrt_make_cpu_client (uint8_t asynchronous, int node_id ,
1644
- int num_nodes) {
1664
+ extern " C" ifrt::Client *ifrt_make_pjrt_cpu_client (uint8_t asynchronous,
1665
+ int node_id, int num_nodes) {
1645
1666
return ifrt_pjrt_make_client (
1646
1667
pjrt_make_cpu_client_shared (asynchronous, node_id, num_nodes));
1647
1668
}
1648
1669
1649
- extern " C" ifrt::Client *
1650
- ifrt_make_gpu_client (int node_id, int num_nodes, int *allowed_devices,
1651
- int num_allowed_devices, double memory_fraction,
1652
- bool preallocate, const char *platform_name,
1653
- const char **error, void *distributed_runtime_client) {
1670
+ extern " C" ifrt::Client *ifrt_make_pjrt_gpu_client (
1671
+ int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1672
+ double memory_fraction, bool preallocate, const char *platform_name,
1673
+ const char **error, void *distributed_runtime_client) {
1654
1674
return ifrt_pjrt_make_client (pjrt_make_gpu_client_shared (
1655
1675
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1656
1676
preallocate, platform_name, error, distributed_runtime_client));
1657
1677
}
1658
1678
1659
- extern " C" ifrt::Client *ifrt_make_tpu_client (const char *tpu_path,
1660
- const char **error) {
1679
+ extern " C" ifrt::Client *ifrt_make_pjrt_tpu_client (const char *tpu_path,
1680
+ const char **error) {
1661
1681
return ifrt_pjrt_make_client (pjrt_make_tpu_client_shared (tpu_path, error));
1662
1682
}
1663
1683
@@ -1847,6 +1867,79 @@ ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) {
1847
1867
return cstr_from_string (hlo_sharding->DebugString ());
1848
1868
}
1849
1869
1870
+ extern " C" void
1871
+ free_ifrt_sharding (HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1872
+ delete sharding;
1873
+ }
1874
+
1875
+ extern " C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1876
+ ifrt_sharding_from_ifrt_hlo_sharding (ifrt::HloSharding *hlo_sharding) {
1877
+ return reactant::capture (std::shared_ptr<ifrt::Sharding>(hlo_sharding));
1878
+ }
1879
+
1880
+ extern " C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1881
+ ifrt_sharding_from_hlo_sharding (
1882
+ HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1883
+ ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1884
+ return ifrt_sharding_from_ifrt_hlo_sharding (
1885
+ ifrt_hlo_sharding_from_xla_hlo_sharding (device_list, memory_kind,
1886
+ xla_hlo_sharding));
1887
+ }
1888
+
1889
+ extern " C" const char *
1890
+ ifrt_sharding_to_string (HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1891
+ return cstr_from_string (sharding->obj ()->DebugString ());
1892
+ }
1893
+
1894
+ #pragma endregion
1895
+
1896
+ typedef ifrt::Future<> IfRtFutureType;
1897
+
1898
+ extern " C" void ifrt_free_future (IfRtFutureType *Future) { delete Future; }
1899
+
1900
+ extern " C" uint8_t ifrt_future_is_ready (IfRtFutureType *Future) {
1901
+ return Future->IsReady ();
1902
+ }
1903
+
1904
+ extern " C" void ifrt_future_await (IfRtFutureType *Future) { Future->Await (); }
1905
+
1906
+ #pragma region IfRtArray
1907
+
1908
+ extern " C" void ifrt_free_array (HeldIfrtArray *array) { delete array; }
1909
+
1910
+ extern " C" int64_t *ifrt_array_shape (HeldIfrtArray *array) {
1911
+ absl::Span<const long > dims = array->obj ()->shape ().dims ();
1912
+ int64_t *dims_ptr = new int64_t [dims.size ()];
1913
+ std::copy (dims.begin (), dims.end (), dims_ptr);
1914
+ return dims_ptr;
1915
+ }
1916
+
1917
+ extern " C" int64_t ifrt_array_ndims (HeldIfrtArray *array) {
1918
+ return array->obj ()->shape ().dims ().size ();
1919
+ }
1920
+
1921
+ extern " C" ifrt::DType ifrt_array_eltype (HeldIfrtArray *array) {
1922
+ return array->obj ()->dtype ();
1923
+ }
1924
+
1925
+ extern " C" ifrt::Client *ifrt_array_to_client (HeldIfrtArray *array) {
1926
+ return array->obj ()->client ();
1927
+ }
1928
+
1929
+ extern " C" HeldValue<std::shared_ptr<const ifrt::Sharding>> *
1930
+ ifrt_array_to_sharding (HeldIfrtArray *array) {
1931
+ return reactant::capture (array->obj ()->shared_ptr_sharding ());
1932
+ }
1933
+
1934
+ extern " C" void ifrt_array_copy_to_host_buffer (HeldIfrtArray *array,
1935
+ void *data) {
1936
+ std::optional<absl::Span<const int64_t >> byte_strides;
1937
+ auto future = array->obj ()->CopyToHostBuffer (
1938
+ data, byte_strides, static_cast <ifrt::ArrayCopySemantics>(0 ));
1939
+ future.Await ();
1940
+ return ;
1941
+ }
1942
+
1850
1943
#pragma endregion
1851
1944
1852
1945
#pragma region PjRtDistributed
0 commit comments