82
82
// IFRT
83
83
#include " xla/python/ifrt/array.h"
84
84
#include " xla/python/ifrt/attribute_map.h"
85
- #include " xla/python/ifrt/basic_device_list.h"
86
85
#include " xla/python/ifrt/client.h"
87
86
#include " xla/python/ifrt/compiler.h"
88
87
#include " xla/python/ifrt/device.h"
@@ -1800,26 +1799,10 @@ extern "C" ifrt::Client *ifrt_DeviceToClient(ifrt::Device *device) {
1800
1799
return device->client ();
1801
1800
}
1802
1801
1803
- extern " C" HeldValue<tsl::RCReference<ifrt::DeviceList>> *
1804
- ifrt_CreateBasicDeviceListFromDevices (ifrt::Device **device_list,
1805
- int32_t num_devices) {
1802
+ tsl::RCReference<ifrt::DeviceList> ifrt_CreateDeviceListFromDevices (
1803
+ ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices) {
1806
1804
absl::Span<ifrt::Device *const > devices (device_list, num_devices);
1807
- return reactant::capture (ifrt::BasicDeviceList::Create (devices));
1808
- }
1809
-
1810
- extern " C" const char *ifrt_BasicDeviceListToString (
1811
- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list) {
1812
- return cstr_from_string (device_list->obj ()->DebugString ());
1813
- }
1814
-
1815
- extern " C" int ifrt_BasicDeviceListSize (
1816
- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list) {
1817
- return device_list->obj ()->size ();
1818
- }
1819
-
1820
- extern " C" ifrt::Device *const ifrt_BasicDeviceListGetDevice (
1821
- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list, int32_t index) {
1822
- return device_list->obj ()->devices ()[index ];
1805
+ return client->MakeDeviceList (devices);
1823
1806
}
1824
1807
1825
1808
extern " C" ifrt::Memory *ifrt_DeviceGetDefaultMemory (ifrt::Device *device) {
@@ -1889,10 +1872,11 @@ hlo_sharding_to_string(const xla::HloSharding *hlo_sharding) {
1889
1872
}
1890
1873
1891
1874
extern " C" ifrt::HloSharding *ifrt_hlo_sharding_from_xla_hlo_sharding (
1892
- HeldValue<tsl::RCReference< ifrt::DeviceList>> * device_list,
1875
+ ifrt::Client *client, ifrt::Device ** device_list, int32_t num_devices ,
1893
1876
ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1894
- return ifrt::HloSharding::Create (device_list->obj (), *memory_kind,
1895
- *xla_hlo_sharding)
1877
+ return ifrt::HloSharding::Create (
1878
+ ifrt_CreateDeviceListFromDevices (client, device_list, num_devices),
1879
+ *memory_kind, *xla_hlo_sharding)
1896
1880
.release ();
1897
1881
}
1898
1882
@@ -1919,12 +1903,13 @@ ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) {
1919
1903
}
1920
1904
1921
1905
extern " C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1922
- ifrt_sharding_from_hlo_sharding (
1923
- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1924
- ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1906
+ ifrt_sharding_from_hlo_sharding (ifrt::Client *client,
1907
+ ifrt::Device **device_list, int32_t num_devices,
1908
+ ifrt::MemoryKind *memory_kind,
1909
+ xla::HloSharding *xla_hlo_sharding) {
1925
1910
return ifrt_sharding_from_ifrt_hlo_sharding (
1926
- ifrt_hlo_sharding_from_xla_hlo_sharding (device_list, memory_kind ,
1927
- xla_hlo_sharding));
1911
+ ifrt_hlo_sharding_from_xla_hlo_sharding (client, device_list, num_devices ,
1912
+ memory_kind, xla_hlo_sharding));
1928
1913
}
1929
1914
1930
1915
extern " C" const char *
0 commit comments