81
81
82
82
// IFRT
83
83
#include " xla/python/ifrt/array.h"
84
- #include " xla/python/ifrt/basic_device_list.h"
85
84
#include " xla/python/ifrt/client.h"
86
85
#include " xla/python/ifrt/compiler.h"
87
86
#include " xla/python/ifrt/device.h"
@@ -1799,26 +1798,10 @@ extern "C" ifrt::Client *ifrt_DeviceToClient(ifrt::Device *device) {
1799
1798
return device->client ();
1800
1799
}
1801
1800
1802
- extern " C" HeldValue<tsl::RCReference<ifrt::DeviceList>> *
1803
- ifrt_CreateBasicDeviceListFromDevices (ifrt::Device **device_list,
1804
- int32_t num_devices) {
1801
+ tsl::RCReference<ifrt::DeviceList> ifrt_CreateDeviceListFromDevices (
1802
+ ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices) {
1805
1803
absl::Span<ifrt::Device *const > devices (device_list, num_devices);
1806
- return reactant::capture (ifrt::BasicDeviceList::Create (devices));
1807
- }
1808
-
1809
- extern " C" const char *ifrt_BasicDeviceListToString (
1810
- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list) {
1811
- return cstr_from_string (device_list->obj ()->DebugString ());
1812
- }
1813
-
1814
- extern " C" int ifrt_BasicDeviceListSize (
1815
- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list) {
1816
- return device_list->obj ()->size ();
1817
- }
1818
-
1819
- extern " C" ifrt::Device *const ifrt_BasicDeviceListGetDevice (
1820
- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list, int32_t index) {
1821
- return device_list->obj ()->devices ()[index ];
1804
+ return client->MakeDeviceList (devices);
1822
1805
}
1823
1806
1824
1807
extern " C" ifrt::Memory *ifrt_DeviceGetDefaultMemory (ifrt::Device *device) {
@@ -1888,10 +1871,11 @@ hlo_sharding_to_string(const xla::HloSharding *hlo_sharding) {
1888
1871
}
1889
1872
1890
1873
extern " C" ifrt::HloSharding *ifrt_hlo_sharding_from_xla_hlo_sharding (
1891
- HeldValue<tsl::RCReference< ifrt::DeviceList>> * device_list,
1874
+ ifrt::Client *client, ifrt::Device ** device_list, int32_t num_devices ,
1892
1875
ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1893
- return ifrt::HloSharding::Create (device_list->obj (), *memory_kind,
1894
- *xla_hlo_sharding)
1876
+ return ifrt::HloSharding::Create (
1877
+ ifrt_CreateDeviceListFromDevices (client, device_list, num_devices),
1878
+ *memory_kind, *xla_hlo_sharding)
1895
1879
.release ();
1896
1880
}
1897
1881
@@ -1918,12 +1902,13 @@ ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) {
1918
1902
}
1919
1903
1920
1904
extern " C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1921
- ifrt_sharding_from_hlo_sharding (
1922
- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1923
- ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1905
+ ifrt_sharding_from_hlo_sharding (ifrt::Client *client,
1906
+ ifrt::Device **device_list, int32_t num_devices,
1907
+ ifrt::MemoryKind *memory_kind,
1908
+ xla::HloSharding *xla_hlo_sharding) {
1924
1909
return ifrt_sharding_from_ifrt_hlo_sharding (
1925
- ifrt_hlo_sharding_from_xla_hlo_sharding (device_list, memory_kind ,
1926
- xla_hlo_sharding));
1910
+ ifrt_hlo_sharding_from_xla_hlo_sharding (client, device_list, num_devices ,
1911
+ memory_kind, xla_hlo_sharding));
1927
1912
}
1928
1913
1929
1914
extern " C" const char *
0 commit comments