@@ -1502,6 +1502,10 @@ ifrt_pjrt_loaded_executable_dtor(xla::ifrt::PjRtLoadedExecutable *exec) {
1502
1502
1503
1503
extern " C" void ifrt_array_dtor (HeldIfrtArray *array) { delete array; }
1504
1504
1505
+ extern " C" void ifrt_loaded_executable_dtor (ifrt::LoadedExecutable *exec) {
1506
+ delete exec;
1507
+ }
1508
+
1505
1509
extern " C" void ifrt_loaded_executable_execute (
1506
1510
ifrt::LoadedExecutable *exec, int num_args,
1507
1511
HeldValue<tsl::RCReference<ifrt::Array>> **op_args,
@@ -1571,38 +1575,56 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
1571
1575
1572
1576
#pragma region IfRtClient
1573
1577
1574
- // right now only making it available for TPU
1575
- // in the future, we would like this for CPU and GPU PjRt backends too
1576
1578
extern " C" ifrt::proxy::GrpcServer *
1577
- ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu (
1578
- const char *c_address, const char *tpu_path, const char **error ) {
1579
+ ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu (
1580
+ const char *c_address, uint8_t asynchronous, int node_id, int num_nodes ) {
1579
1581
std::string address = c_address;
1580
1582
1581
- // taken from `MakeTPUClient`
1582
- std::string tpu_library_path;
1583
- if (auto path = llvm::sys::Process::GetEnv (kEnvTpuLibraryPath )) {
1584
- tpu_library_path = *path;
1585
- } else if (tpu_path) {
1586
- tpu_library_path = std::string (tpu_path);
1587
- } else {
1588
- *error = " Could not find TPU path" ;
1589
- return nullptr ;
1590
- }
1583
+ return MyValueOrThrow (
1584
+ ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory (
1585
+ address,
1586
+ [asynchronous, node_id, num_nodes]()
1587
+ -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1588
+ auto pjrt_client = std::shared_ptr<PjRtClient>(
1589
+ MakeCPUClient (asynchronous, node_id, num_nodes));
1590
+ return std::shared_ptr<ifrt::Client>(
1591
+ xla::ifrt::PjRtClient::Create (pjrt_client).release ());
1592
+ }))
1593
+ .release ();
1594
+ }
1591
1595
1592
- const PJRT_Api *pluginLoad =
1593
- LoadPjrtPlugin (" tpu" , tpu_library_path.c_str (), error);
1594
- if (pluginLoad == nullptr )
1595
- return nullptr ;
1596
- auto tpu_status = InitializePjrtPlugin (" tpu" , error);
1597
- if (tpu_status)
1598
- return nullptr ;
1596
+ extern " C" ifrt::proxy::GrpcServer *
1597
+ ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu (
1598
+ int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1599
+ double memory_fraction, bool preallocate, const char *platform_name,
1600
+ const char **error) {
1601
+ return MyValueOrThrow (
1602
+ ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory (
1603
+ std::string (),
1604
+ [node_id, num_nodes, allowed_devices, num_allowed_devices,
1605
+ memory_fraction, preallocate, platform_name,
1606
+ error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1607
+ auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient (
1608
+ node_id, num_nodes, allowed_devices, num_allowed_devices,
1609
+ memory_fraction, preallocate, platform_name, error));
1610
+ return std::shared_ptr<ifrt::Client>(
1611
+ xla::ifrt::PjRtClient::Create (pjrt_client).release ());
1612
+ }))
1613
+ .release ();
1614
+ }
1615
+
1616
+ extern " C" ifrt::proxy::GrpcServer *
1617
+ ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu (
1618
+ const char *c_address, const char *tpu_path, const char **error) {
1619
+ std::string address = c_address;
1599
1620
1600
1621
return MyValueOrThrow (
1601
1622
xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory (
1602
1623
address,
1603
- []() -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1604
- auto pjrt_client =
1605
- std::shared_ptr<xla::PjRtClient>(GetCApiClient (" TPU" ));
1624
+ [tpu_path, error]()
1625
+ -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1626
+ auto pjrt_client = std::shared_ptr<xla::PjRtClient>(
1627
+ MakeTPUClient (tpu_path, error));
1606
1628
return std::shared_ptr<xla::ifrt::Client>(
1607
1629
xla::ifrt::PjRtClient::Create (pjrt_client).release ());
1608
1630
}))
@@ -1636,28 +1658,27 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
1636
1658
nullptr , // callback `on_connection_update`
1637
1659
};
1638
1660
return MyValueOrThrow (
1639
- ifrt::proxy::CreateClient (c_proxy_server_address , options))
1661
+ ifrt::proxy::CreateClient (proxy_server_address , options))
1640
1662
.release ();
1641
1663
}
1642
1664
1643
- extern " C" ifrt::Client *ifrt_make_cpu_client (uint8_t asynchronous, int node_id ,
1644
- int num_nodes) {
1665
+ extern " C" ifrt::Client *ifrt_make_pjrt_cpu_client (uint8_t asynchronous,
1666
+ int node_id, int num_nodes) {
1645
1667
return ifrt_pjrt_make_client (
1646
1668
pjrt_make_cpu_client_shared (asynchronous, node_id, num_nodes));
1647
1669
}
1648
1670
1649
- extern " C" ifrt::Client *
1650
- ifrt_make_gpu_client (int node_id, int num_nodes, int *allowed_devices,
1651
- int num_allowed_devices, double memory_fraction,
1652
- bool preallocate, const char *platform_name,
1653
- const char **error, void *distributed_runtime_client) {
1671
+ extern " C" ifrt::Client *ifrt_make_pjrt_gpu_client (
1672
+ int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1673
+ double memory_fraction, bool preallocate, const char *platform_name,
1674
+ const char **error, void *distributed_runtime_client) {
1654
1675
return ifrt_pjrt_make_client (pjrt_make_gpu_client_shared (
1655
1676
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1656
1677
preallocate, platform_name, error, distributed_runtime_client));
1657
1678
}
1658
1679
1659
- extern " C" ifrt::Client *ifrt_make_tpu_client (const char *tpu_path,
1660
- const char **error) {
1680
+ extern " C" ifrt::Client *ifrt_make_pjrt_tpu_client (const char *tpu_path,
1681
+ const char **error) {
1661
1682
return ifrt_pjrt_make_client (pjrt_make_tpu_client_shared (tpu_path, error));
1662
1683
}
1663
1684
@@ -1847,6 +1868,79 @@ ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) {
1847
1868
return cstr_from_string (hlo_sharding->DebugString ());
1848
1869
}
1849
1870
1871
+ extern " C" void
1872
+ free_ifrt_sharding (HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1873
+ delete sharding;
1874
+ }
1875
+
1876
+ extern " C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1877
+ ifrt_sharding_from_ifrt_hlo_sharding (ifrt::HloSharding *hlo_sharding) {
1878
+ return reactant::capture (std::shared_ptr<ifrt::Sharding>(hlo_sharding));
1879
+ }
1880
+
1881
+ extern " C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1882
+ ifrt_sharding_from_hlo_sharding (
1883
+ HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1884
+ ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1885
+ return ifrt_sharding_from_ifrt_hlo_sharding (
1886
+ ifrt_hlo_sharding_from_xla_hlo_sharding (device_list, memory_kind,
1887
+ xla_hlo_sharding));
1888
+ }
1889
+
1890
+ extern " C" const char *
1891
+ ifrt_sharding_to_string (HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1892
+ return cstr_from_string (sharding->obj ()->DebugString ());
1893
+ }
1894
+
1895
+ #pragma endregion
1896
+
1897
+ typedef ifrt::Future<> IfRtFutureType;
1898
+
1899
+ extern " C" void ifrt_free_future (IfRtFutureType *Future) { delete Future; }
1900
+
1901
+ extern " C" uint8_t ifrt_future_is_ready (IfRtFutureType *Future) {
1902
+ return Future->IsReady ();
1903
+ }
1904
+
1905
+ extern " C" void ifrt_future_await (IfRtFutureType *Future) { Future->Await (); }
1906
+
1907
+ #pragma region IfRtArray
1908
+
1909
+ extern " C" void ifrt_free_array (HeldIfrtArray *array) { delete array; }
1910
+
1911
+ extern " C" int64_t *ifrt_array_shape (HeldIfrtArray *array) {
1912
+ absl::Span<const long > dims = array->obj ()->shape ().dims ();
1913
+ int64_t *dims_ptr = new int64_t [dims.size ()];
1914
+ std::copy (dims.begin (), dims.end (), dims_ptr);
1915
+ return dims_ptr;
1916
+ }
1917
+
1918
+ extern " C" int64_t ifrt_array_ndims (HeldIfrtArray *array) {
1919
+ return array->obj ()->shape ().dims ().size ();
1920
+ }
1921
+
1922
+ extern " C" ifrt::DType ifrt_array_eltype (HeldIfrtArray *array) {
1923
+ return array->obj ()->dtype ();
1924
+ }
1925
+
1926
+ extern " C" ifrt::Client *ifrt_array_to_client (HeldIfrtArray *array) {
1927
+ return array->obj ()->client ();
1928
+ }
1929
+
1930
+ extern " C" HeldValue<std::shared_ptr<const ifrt::Sharding>> *
1931
+ ifrt_array_to_sharding (HeldIfrtArray *array) {
1932
+ return reactant::capture (array->obj ()->shared_ptr_sharding ());
1933
+ }
1934
+
1935
+ extern " C" void ifrt_array_copy_to_host_buffer (HeldIfrtArray *array,
1936
+ void *data) {
1937
+ std::optional<absl::Span<const int64_t >> byte_strides;
1938
+ auto future = array->obj ()->CopyToHostBuffer (
1939
+ data, byte_strides, static_cast <ifrt::ArrayCopySemantics>(0 ));
1940
+ future.Await ();
1941
+ return ;
1942
+ }
1943
+
1850
1944
#pragma endregion
1851
1945
1852
1946
#pragma region PjRtDistributed
0 commit comments