Skip to content

Commit 6da5b9d

Browse files
committed
feat: initial low-level IFRT API
fix: ifrt HloSharding
refactor: split up into IFRT/PJRT
feat: IFRT Client APIs
feat: IFRT Device API
fix: remove global_ordinals
feat: add devices list abstraction
feat: wrap memory and memory kinds
feat: ifrt::HloSharding now working
fix: use new ABI
chore: run formatter
fix: no finalizer
feat: initial draft of IFRT.Array interface (#774)
* feat: initial draft of IFRT.Array interface
* feat: Base.Array to ifrt::Array
* feat: buffer to host
chore: run formatter
fix: bad rebase
feat: more proxy servers
feat: add ConcreteIFRTArray
feat: add ConcreteIFRTNumber
refactor: rename ConcreteRNumber to ConcretePJRTNumber
revert: concreteifrtarray implementation
chore: run formatter
feat: ifrt loaded executable
1 parent 83a2c1d commit 6da5b9d

17 files changed

+744
-105
lines changed

deps/ReactantExtra/API.cpp

+129-34
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@
7676

7777
// IFRT
7878
#include "xla/python/ifrt/array.h"
79+
#include "xla/python/ifrt/basic_device_list.h"
7980
#include "xla/python/ifrt/client.h"
8081
#include "xla/python/ifrt/compiler.h"
8182
#include "xla/python/ifrt/device.h"
8283
#include "xla/python/ifrt/device_list.h"
83-
#include "xla/python/ifrt/basic_device_list.h"
8484
#include "xla/python/ifrt/dtype.h"
8585
#include "xla/python/ifrt/executable.h"
8686
#include "xla/python/ifrt/hlo/hlo_program.h"
@@ -1469,6 +1469,10 @@ ifrt_pjrt_loaded_executable_dtor(xla::ifrt::PjRtLoadedExecutable *exec) {
14691469

14701470
extern "C" void ifrt_array_dtor(HeldIfrtArray *array) { delete array; }
14711471

1472+
extern "C" void ifrt_loaded_executable_dtor(ifrt::LoadedExecutable *exec) {
1473+
delete exec;
1474+
}
1475+
14721476
extern "C" void ifrt_loaded_executable_execute(
14731477
ifrt::LoadedExecutable *exec, int num_args,
14741478
HeldValue<tsl::RCReference<ifrt::Array>> **op_args,
@@ -1538,38 +1542,56 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
15381542

15391543
#pragma region IfRtClient
15401544

1541-
// right now only making it available for TPU
1542-
// in the future, we would like this for CPU and GPU PjRt backends too
15431545
extern "C" ifrt::proxy::GrpcServer *
1544-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1545-
const char *c_address, const char *tpu_path, const char **error) {
1546+
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1547+
const char *c_address, uint8_t asynchronous, int node_id, int num_nodes) {
15461548
std::string address = c_address;
15471549

1548-
// taken from `MakeTPUClient`
1549-
std::string tpu_library_path;
1550-
if (auto path = llvm::sys::Process::GetEnv(kEnvTpuLibraryPath)) {
1551-
tpu_library_path = *path;
1552-
} else if (tpu_path) {
1553-
tpu_library_path = std::string(tpu_path);
1554-
} else {
1555-
*error = "Could not find TPU path";
1556-
return nullptr;
1557-
}
1550+
return MyValueOrThrow(
1551+
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1552+
address,
1553+
[asynchronous, node_id, num_nodes]()
1554+
-> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1555+
auto pjrt_client = std::shared_ptr<PjRtClient>(
1556+
MakeCPUClient(asynchronous, node_id, num_nodes));
1557+
return std::shared_ptr<ifrt::Client>(
1558+
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1559+
}))
1560+
.release();
1561+
}
15581562

1559-
const PJRT_Api *pluginLoad =
1560-
LoadPjrtPlugin("tpu", tpu_library_path.c_str(), error);
1561-
if (pluginLoad == nullptr)
1562-
return nullptr;
1563-
auto tpu_status = InitializePjrtPlugin("tpu", error);
1564-
if (tpu_status)
1565-
return nullptr;
1563+
extern "C" ifrt::proxy::GrpcServer *
1564+
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
1565+
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1566+
double memory_fraction, bool preallocate, const char *platform_name,
1567+
const char **error) {
1568+
return MyValueOrThrow(
1569+
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1570+
std::string(),
1571+
[node_id, num_nodes, allowed_devices, num_allowed_devices,
1572+
memory_fraction, preallocate, platform_name,
1573+
error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1574+
auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient(
1575+
node_id, num_nodes, allowed_devices, num_allowed_devices,
1576+
memory_fraction, preallocate, platform_name, error));
1577+
return std::shared_ptr<ifrt::Client>(
1578+
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1579+
}))
1580+
.release();
1581+
}
1582+
1583+
extern "C" ifrt::proxy::GrpcServer *
1584+
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1585+
const char *c_address, const char *tpu_path, const char **error) {
1586+
std::string address = c_address;
15661587

15671588
return MyValueOrThrow(
15681589
xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
15691590
address,
1570-
[]() -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1571-
auto pjrt_client =
1572-
std::shared_ptr<xla::PjRtClient>(GetCApiClient("TPU"));
1591+
[tpu_path, error]()
1592+
-> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1593+
auto pjrt_client = std::shared_ptr<xla::PjRtClient>(
1594+
MakeTPUClient(tpu_path, error));
15731595
return std::shared_ptr<xla::ifrt::Client>(
15741596
xla::ifrt::PjRtClient::Create(pjrt_client).release());
15751597
}))
@@ -1604,28 +1626,28 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
16041626
nullptr, // callback `on_connection_update`
16051627
};
16061628
return MyValueOrThrow(
1607-
ifrt::proxy::CreateClient(c_proxy_server_address, options))
1629+
ifrt::proxy::CreateClient(proxy_server_address, options))
16081630
.release();
16091631
}
16101632

1611-
extern "C" ifrt::Client *ifrt_make_cpu_client(uint8_t asynchronous, int node_id,
1612-
int num_nodes) {
1633+
extern "C" ifrt::Client *ifrt_make_pjrt_cpu_client(uint8_t asynchronous,
1634+
int node_id, int num_nodes) {
16131635
return ifrt_pjrt_make_client(
16141636
pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes));
16151637
}
16161638

16171639
extern "C" ifrt::Client *
1618-
ifrt_make_gpu_client(int node_id, int num_nodes, int *allowed_devices,
1619-
int num_allowed_devices, double memory_fraction,
1620-
bool preallocate, const char *platform_name,
1621-
const char **error) {
1640+
ifrt_make_pjrt_gpu_client(int node_id, int num_nodes, int *allowed_devices,
1641+
int num_allowed_devices, double memory_fraction,
1642+
bool preallocate, const char *platform_name,
1643+
const char **error) {
16221644
return ifrt_pjrt_make_client(pjrt_make_gpu_client_shared(
16231645
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
16241646
preallocate, platform_name, error));
16251647
}
16261648

1627-
extern "C" ifrt::Client *ifrt_make_tpu_client(const char *tpu_path,
1628-
const char **error) {
1649+
extern "C" ifrt::Client *ifrt_make_pjrt_tpu_client(const char *tpu_path,
1650+
const char **error) {
16291651
return ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error));
16301652
}
16311653

@@ -1815,4 +1837,77 @@ ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) {
18151837
return cstr_from_string(hlo_sharding->DebugString());
18161838
}
18171839

1840+
extern "C" void
1841+
free_ifrt_sharding(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1842+
delete sharding;
1843+
}
1844+
1845+
extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1846+
ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) {
1847+
return reactant::capture(std::shared_ptr<ifrt::Sharding>(hlo_sharding));
1848+
}
1849+
1850+
extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1851+
ifrt_sharding_from_hlo_sharding(
1852+
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1853+
ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1854+
return ifrt_sharding_from_ifrt_hlo_sharding(
1855+
ifrt_hlo_sharding_from_xla_hlo_sharding(device_list, memory_kind,
1856+
xla_hlo_sharding));
1857+
}
1858+
1859+
extern "C" const char *
1860+
ifrt_sharding_to_string(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1861+
return cstr_from_string(sharding->obj()->DebugString());
1862+
}
1863+
1864+
#pragma endregion
1865+
1866+
typedef ifrt::Future<> IfRtFutureType;
1867+
1868+
extern "C" void ifrt_free_future(IfRtFutureType *Future) { delete Future; }
1869+
1870+
extern "C" uint8_t ifrt_future_is_ready(IfRtFutureType *Future) {
1871+
return Future->IsReady();
1872+
}
1873+
1874+
extern "C" void ifrt_future_await(IfRtFutureType *Future) { Future->Await(); }
1875+
1876+
#pragma region IfRtArray
1877+
1878+
extern "C" void ifrt_free_array(HeldIfrtArray *array) { delete array; }
1879+
1880+
extern "C" int64_t *ifrt_array_shape(HeldIfrtArray *array) {
1881+
absl::Span<const long> dims = array->obj()->shape().dims();
1882+
int64_t *dims_ptr = new int64_t[dims.size()];
1883+
std::copy(dims.begin(), dims.end(), dims_ptr);
1884+
return dims_ptr;
1885+
}
1886+
1887+
extern "C" int64_t ifrt_array_ndims(HeldIfrtArray *array) {
1888+
return array->obj()->shape().dims().size();
1889+
}
1890+
1891+
extern "C" ifrt::DType ifrt_array_eltype(HeldIfrtArray *array) {
1892+
return array->obj()->dtype();
1893+
}
1894+
1895+
extern "C" ifrt::Client *ifrt_array_to_client(HeldIfrtArray *array) {
1896+
return array->obj()->client();
1897+
}
1898+
1899+
extern "C" HeldValue<std::shared_ptr<const ifrt::Sharding>> *
1900+
ifrt_array_to_sharding(HeldIfrtArray *array) {
1901+
return reactant::capture(array->obj()->shared_ptr_sharding());
1902+
}
1903+
1904+
extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
1905+
void *data) {
1906+
std::optional<absl::Span<const int64_t>> byte_strides;
1907+
auto future = array->obj()->CopyToHostBuffer(
1908+
data, byte_strides, static_cast<ifrt::ArrayCopySemantics>(0));
1909+
future.Await();
1910+
return;
1911+
}
1912+
18181913
#pragma endregion

src/xla/Buffer.jl

+54
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@ function buffer_on_cpu end
55
function to_host end
66
function unsafe_buffer_pointer end
77
function copy_buffer_to_device end
8+
function sharding end
9+
10+
Base.convert(::Type{Array}, buffer::AbstractBuffer) = convert(Array{eltype(buffer)}, buffer)
11+
12+
function Base.convert(::Type{<:Array{T}}, buffer::AbstractBuffer) where {T}
13+
arr = zeros(T, reverse(size(buffer))...)
14+
XLA.to_host(buffer, arr)
15+
return arr
16+
end
817

918
@inline function client(
1019
buffers::Union{Array{<:AbstractBuffer},NTuple{<:Any,AbstractBuffer}}
@@ -19,3 +28,48 @@ end
1928
)
2029
return map(synced_buffer, buffers)
2130
end
31+
32+
function Base.show(io::IO, mime::MIME"text/plain", buffer::B) where {B<:AbstractBuffer}
33+
print(io, "$(B) storing ")
34+
show(io, mime, convert(Array, buffer))
35+
return nothing
36+
end
37+
38+
# Async Buffers
39+
abstract type AbstractAsyncBuffer <: AbstractBuffer end
40+
41+
Base.isempty(buffer::AbstractAsyncBuffer) = buffer.buffer.buffer == C_NULL
42+
43+
function Base.convert(T::Type{Array}, buffer::AbstractAsyncBuffer)
44+
XLA.await(buffer)
45+
return convert(T, buffer.buffer)
46+
end
47+
48+
function Base.convert(T::Type{<:Array{T1}}, buffer::AbstractAsyncBuffer) where {T1}
49+
XLA.await(buffer)
50+
return convert(T, buffer.buffer)
51+
end
52+
53+
for op in (:(Base.ndims), :(Base.size), :(Base.eltype), :device, :client, :sharding)
54+
@eval $op(buffer::AbstractAsyncBuffer) = $op(buffer.buffer)
55+
end
56+
57+
function XLA.synced_buffer(buffer::AbstractAsyncBuffer)
58+
XLA.await(buffer)
59+
return buffer.buffer
60+
end
61+
62+
function XLA.await(buffer::AbstractAsyncBuffer)
63+
buffer.future === nothing && return nothing
64+
future = buffer.future
65+
buffer.future = nothing
66+
XLA.await(future)
67+
return nothing
68+
end
69+
70+
function XLA.is_ready(buffer::AbstractAsyncBuffer)
71+
buffer.future === nothing && return true
72+
return XLA.is_ready(buffer.future)
73+
end
74+
75+
XLA.buffer_on_cpu(buffer::AbstractAsyncBuffer) = XLA.buffer_on_cpu(buffer.buffer)

0 commit comments

Comments
 (0)