Skip to content

Commit 408e120

Browse files
committed
feat: initial low-level IFRT API
fix: ifrt HloSharding refactor: split up into IFRT/PJRT feat: IFRT Client APIs feat: IFRT Device API fix: remove global_ordinals feat: add devices list abstraction feat: wrap memory and memory kinds feat: ifrt::HloSharding now working fix: use new ABI chore: run formatter fix: no finalizer feat: initial draft of IFRT.Array interface (#774) * feat: initial draft of IFRT.Array interface * feat: Base.Array to ifrt::Array * feat: buffer to host chore: run formatter fix: bad rebase feat: more proxy servers feat: add ConcreteIFRTArray feat: add ConcreteIFRTNumber refactor: rename ConcreteRNumber to ConcretePJRTNumber revert: concreteifrtarray implementation chore: run formatter feat: ifrt loaded executable
1 parent c9dd714 commit 408e120

17 files changed

+743
-105
lines changed

deps/ReactantExtra/API.cpp

+128-34
Original file line numberDiff line numberDiff line change
@@ -1502,6 +1502,10 @@ ifrt_pjrt_loaded_executable_dtor(xla::ifrt::PjRtLoadedExecutable *exec) {
15021502

15031503
extern "C" void ifrt_array_dtor(HeldIfrtArray *array) { delete array; }
15041504

1505+
extern "C" void ifrt_loaded_executable_dtor(ifrt::LoadedExecutable *exec) {
1506+
delete exec;
1507+
}
1508+
15051509
extern "C" void ifrt_loaded_executable_execute(
15061510
ifrt::LoadedExecutable *exec, int num_args,
15071511
HeldValue<tsl::RCReference<ifrt::Array>> **op_args,
@@ -1571,38 +1575,56 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
15711575

15721576
#pragma region IfRtClient
15731577

1574-
// right now only making it available for TPU
1575-
// in the future, we would like this for CPU and GPU PjRt backends too
15761578
extern "C" ifrt::proxy::GrpcServer *
1577-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1578-
const char *c_address, const char *tpu_path, const char **error) {
1579+
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1580+
const char *c_address, uint8_t asynchronous, int node_id, int num_nodes) {
15791581
std::string address = c_address;
15801582

1581-
// taken from `MakeTPUClient`
1582-
std::string tpu_library_path;
1583-
if (auto path = llvm::sys::Process::GetEnv(kEnvTpuLibraryPath)) {
1584-
tpu_library_path = *path;
1585-
} else if (tpu_path) {
1586-
tpu_library_path = std::string(tpu_path);
1587-
} else {
1588-
*error = "Could not find TPU path";
1589-
return nullptr;
1590-
}
1583+
return MyValueOrThrow(
1584+
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1585+
address,
1586+
[asynchronous, node_id, num_nodes]()
1587+
-> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1588+
auto pjrt_client = std::shared_ptr<PjRtClient>(
1589+
MakeCPUClient(asynchronous, node_id, num_nodes));
1590+
return std::shared_ptr<ifrt::Client>(
1591+
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1592+
}))
1593+
.release();
1594+
}
15911595

1592-
const PJRT_Api *pluginLoad =
1593-
LoadPjrtPlugin("tpu", tpu_library_path.c_str(), error);
1594-
if (pluginLoad == nullptr)
1595-
return nullptr;
1596-
auto tpu_status = InitializePjrtPlugin("tpu", error);
1597-
if (tpu_status)
1598-
return nullptr;
1596+
extern "C" ifrt::proxy::GrpcServer *
1597+
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
1598+
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1599+
double memory_fraction, bool preallocate, const char *platform_name,
1600+
const char **error) {
1601+
return MyValueOrThrow(
1602+
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1603+
std::string(),
1604+
[node_id, num_nodes, allowed_devices, num_allowed_devices,
1605+
memory_fraction, preallocate, platform_name,
1606+
error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1607+
auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient(
1608+
node_id, num_nodes, allowed_devices, num_allowed_devices,
1609+
memory_fraction, preallocate, platform_name, error));
1610+
return std::shared_ptr<ifrt::Client>(
1611+
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1612+
}))
1613+
.release();
1614+
}
1615+
1616+
extern "C" ifrt::proxy::GrpcServer *
1617+
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1618+
const char *c_address, const char *tpu_path, const char **error) {
1619+
std::string address = c_address;
15991620

16001621
return MyValueOrThrow(
16011622
xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
16021623
address,
1603-
[]() -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1604-
auto pjrt_client =
1605-
std::shared_ptr<xla::PjRtClient>(GetCApiClient("TPU"));
1624+
[tpu_path, error]()
1625+
-> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1626+
auto pjrt_client = std::shared_ptr<xla::PjRtClient>(
1627+
MakeTPUClient(tpu_path, error));
16061628
return std::shared_ptr<xla::ifrt::Client>(
16071629
xla::ifrt::PjRtClient::Create(pjrt_client).release());
16081630
}))
@@ -1636,28 +1658,27 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
16361658
nullptr, // callback `on_connection_update`
16371659
};
16381660
return MyValueOrThrow(
1639-
ifrt::proxy::CreateClient(c_proxy_server_address, options))
1661+
ifrt::proxy::CreateClient(proxy_server_address, options))
16401662
.release();
16411663
}
16421664

1643-
extern "C" ifrt::Client *ifrt_make_cpu_client(uint8_t asynchronous, int node_id,
1644-
int num_nodes) {
1665+
extern "C" ifrt::Client *ifrt_make_pjrt_cpu_client(uint8_t asynchronous,
1666+
int node_id, int num_nodes) {
16451667
return ifrt_pjrt_make_client(
16461668
pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes));
16471669
}
16481670

1649-
extern "C" ifrt::Client *
1650-
ifrt_make_gpu_client(int node_id, int num_nodes, int *allowed_devices,
1651-
int num_allowed_devices, double memory_fraction,
1652-
bool preallocate, const char *platform_name,
1653-
const char **error, void *distributed_runtime_client) {
1671+
extern "C" ifrt::Client *ifrt_make_pjrt_gpu_client(
1672+
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1673+
double memory_fraction, bool preallocate, const char *platform_name,
1674+
const char **error, void *distributed_runtime_client) {
16541675
return ifrt_pjrt_make_client(pjrt_make_gpu_client_shared(
16551676
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
16561677
preallocate, platform_name, error, distributed_runtime_client));
16571678
}
16581679

1659-
extern "C" ifrt::Client *ifrt_make_tpu_client(const char *tpu_path,
1660-
const char **error) {
1680+
extern "C" ifrt::Client *ifrt_make_pjrt_tpu_client(const char *tpu_path,
1681+
const char **error) {
16611682
return ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error));
16621683
}
16631684

@@ -1847,6 +1868,79 @@ ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) {
18471868
return cstr_from_string(hlo_sharding->DebugString());
18481869
}
18491870

1871+
extern "C" void
1872+
free_ifrt_sharding(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1873+
delete sharding;
1874+
}
1875+
1876+
extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1877+
ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) {
1878+
return reactant::capture(std::shared_ptr<ifrt::Sharding>(hlo_sharding));
1879+
}
1880+
1881+
extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1882+
ifrt_sharding_from_hlo_sharding(
1883+
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1884+
ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1885+
return ifrt_sharding_from_ifrt_hlo_sharding(
1886+
ifrt_hlo_sharding_from_xla_hlo_sharding(device_list, memory_kind,
1887+
xla_hlo_sharding));
1888+
}
1889+
1890+
extern "C" const char *
1891+
ifrt_sharding_to_string(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1892+
return cstr_from_string(sharding->obj()->DebugString());
1893+
}
1894+
1895+
#pragma endregion
1896+
1897+
typedef ifrt::Future<> IfRtFutureType;
1898+
1899+
extern "C" void ifrt_free_future(IfRtFutureType *Future) { delete Future; }
1900+
1901+
extern "C" uint8_t ifrt_future_is_ready(IfRtFutureType *Future) {
1902+
return Future->IsReady();
1903+
}
1904+
1905+
extern "C" void ifrt_future_await(IfRtFutureType *Future) { Future->Await(); }
1906+
1907+
#pragma region IfRtArray
1908+
1909+
extern "C" void ifrt_free_array(HeldIfrtArray *array) { delete array; }
1910+
1911+
extern "C" int64_t *ifrt_array_shape(HeldIfrtArray *array) {
1912+
absl::Span<const long> dims = array->obj()->shape().dims();
1913+
int64_t *dims_ptr = new int64_t[dims.size()];
1914+
std::copy(dims.begin(), dims.end(), dims_ptr);
1915+
return dims_ptr;
1916+
}
1917+
1918+
extern "C" int64_t ifrt_array_ndims(HeldIfrtArray *array) {
1919+
return array->obj()->shape().dims().size();
1920+
}
1921+
1922+
extern "C" ifrt::DType ifrt_array_eltype(HeldIfrtArray *array) {
1923+
return array->obj()->dtype();
1924+
}
1925+
1926+
extern "C" ifrt::Client *ifrt_array_to_client(HeldIfrtArray *array) {
1927+
return array->obj()->client();
1928+
}
1929+
1930+
extern "C" HeldValue<std::shared_ptr<const ifrt::Sharding>> *
1931+
ifrt_array_to_sharding(HeldIfrtArray *array) {
1932+
return reactant::capture(array->obj()->shared_ptr_sharding());
1933+
}
1934+
1935+
extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
1936+
void *data) {
1937+
std::optional<absl::Span<const int64_t>> byte_strides;
1938+
auto future = array->obj()->CopyToHostBuffer(
1939+
data, byte_strides, static_cast<ifrt::ArrayCopySemantics>(0));
1940+
future.Await();
1941+
return;
1942+
}
1943+
18501944
#pragma endregion
18511945

18521946
#pragma region PjRtDistributed

src/xla/Buffer.jl

+54
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@ function buffer_on_cpu end
55
function to_host end
66
function unsafe_buffer_pointer end
77
function copy_buffer_to_device end
8+
function sharding end
9+
10+
Base.convert(::Type{Array}, buffer::AbstractBuffer) = convert(Array{eltype(buffer)}, buffer)
11+
12+
function Base.convert(::Type{<:Array{T}}, buffer::AbstractBuffer) where {T}
13+
arr = zeros(T, reverse(size(buffer))...)
14+
XLA.to_host(buffer, arr)
15+
return arr
16+
end
817

918
@inline function client(
1019
buffers::Union{Array{<:AbstractBuffer},NTuple{<:Any,AbstractBuffer}}
@@ -19,3 +28,48 @@ end
1928
)
2029
return map(synced_buffer, buffers)
2130
end
31+
32+
function Base.show(io::IO, mime::MIME"text/plain", buffer::B) where {B<:AbstractBuffer}
33+
print(io, "$(B) storing ")
34+
show(io, mime, convert(Array, buffer))
35+
return nothing
36+
end
37+
38+
# Async Buffers
39+
abstract type AbstractAsyncBuffer <: AbstractBuffer end
40+
41+
Base.isempty(buffer::AbstractAsyncBuffer) = buffer.buffer.buffer == C_NULL
42+
43+
function Base.convert(T::Type{Array}, buffer::AbstractAsyncBuffer)
44+
XLA.await(buffer)
45+
return convert(T, buffer.buffer)
46+
end
47+
48+
function Base.convert(T::Type{<:Array{T1}}, buffer::AbstractAsyncBuffer) where {T1}
49+
XLA.await(buffer)
50+
return convert(T, buffer.buffer)
51+
end
52+
53+
for op in (:(Base.ndims), :(Base.size), :(Base.eltype), :device, :client, :sharding)
54+
@eval $op(buffer::AbstractAsyncBuffer) = $op(buffer.buffer)
55+
end
56+
57+
function XLA.synced_buffer(buffer::AbstractAsyncBuffer)
58+
XLA.await(buffer)
59+
return buffer.buffer
60+
end
61+
62+
function XLA.await(buffer::AbstractAsyncBuffer)
63+
buffer.future === nothing && return nothing
64+
future = buffer.future
65+
buffer.future = nothing
66+
XLA.await(future)
67+
return nothing
68+
end
69+
70+
function XLA.is_ready(buffer::AbstractAsyncBuffer)
71+
buffer.future === nothing && return true
72+
return XLA.is_ready(buffer.future)
73+
end
74+
75+
XLA.buffer_on_cpu(buffer::AbstractAsyncBuffer) = XLA.buffer_on_cpu(buffer.buffer)

src/xla/IFRT/Array.jl

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
mutable struct Array <: XLA.AbstractBuffer
2+
buffer::Ptr{Cvoid}
3+
4+
function Array(buffer::Ptr{Cvoid})
5+
return finalizer(free_ifrt_array, new(buffer))
6+
end
7+
end
8+
9+
function Array(client::Client, array::Base.Array{T,N}, device::Device) where {T,N}
10+
sizear = collect(Int64, reverse(size(array)))
11+
buffer = GC.@preserve array sizear begin
12+
@ccall MLIR.API.mlir_c.ifrt_client_make_single_shard_array_from_host_buffer(
13+
client.client::Ptr{Cvoid},
14+
pointer(array)::Ptr{T},
15+
XLA.primitive_type(T)::UInt64,
16+
N::Csize_t,
17+
pointer(sizear)::Ptr{Int64},
18+
0::Cint, # kAlwaysCopy
19+
device.device::Ptr{Cvoid},
20+
string(convert(MemoryKind, XLA.default_memory(device)))::Cstring,
21+
)::Ptr{Cvoid}
22+
end
23+
return Array(buffer)
24+
end
25+
26+
function Array(client::Client, array::Base.Array{T,N}, sharding::HloSharding) where {T,N}
27+
return Array(client, array, convert(Sharding, sharding))
28+
end
29+
30+
function Array(client::Client, array::Base.Array{T,N}, sharding::Sharding) where {T,N}
31+
sizear = collect(Int64, reverse(size(array)))
32+
buffer = GC.@preserve array sizear begin
33+
@ccall MLIR.API.mlir_c.ifrt_client_make_array_from_host_buffer(
34+
client.client::Ptr{Cvoid},
35+
pointer(array)::Ptr{T},
36+
XLA.primitive_type(T)::Cint,
37+
N::Csize_t,
38+
pointer(sizear)::Ptr{Int64},
39+
sharding.ptr::Ptr{Cvoid},
40+
0::Cint, # kAlwaysCopy
41+
)::Ptr{Cvoid}
42+
end
43+
return Array(buffer)
44+
end
45+
46+
@inline function free_ifrt_array(buffer::Array)
47+
sbuffer = buffer.buffer
48+
if sbuffer != C_NULL
49+
@ccall MLIR.API.mlir_c.ifrt_free_array(sbuffer::Ptr{Cvoid})::Cvoid
50+
end
51+
end
52+
53+
function Base.ndims(buffer::Array)
54+
GC.@preserve buffer begin
55+
return @ccall MLIR.API.mlir_c.ifrt_array_ndims(buffer.buffer::Ptr{Cvoid})::Int64
56+
end
57+
end
58+
59+
function Base.size(buffer::Array)
60+
GC.@preserve buffer begin
61+
sz = @ccall MLIR.API.mlir_c.ifrt_array_shape(buffer.buffer::Ptr{Cvoid})::Ptr{Int64}
62+
end
63+
return Tuple(unsafe_wrap(Base.Array, sz, ndims(buffer)))
64+
end
65+
66+
function Base.eltype(buffer::Array)
67+
GC.@preserve buffer begin
68+
return XLA.julia_type(
69+
@ccall MLIR.API.mlir_c.ifrt_array_eltype(buffer.buffer::Ptr{Cvoid})::Cint
70+
)
71+
end
72+
end
73+
74+
function XLA.device(::Array)
75+
return error("IFRT.Array can be sharded/replicated across multiple devices. Hence, \
76+
`XLA.device` is not defined.")
77+
end
78+
79+
function XLA.client(buffer::Array)
80+
GC.@preserve buffer begin
81+
return Client(
82+
@ccall MLIR.API.mlir_c.ifrt_array_to_client(
83+
buffer.buffer::Ptr{Cvoid}
84+
)::Ptr{Cvoid}
85+
)
86+
end
87+
end
88+
89+
XLA.synced_buffer(buffer::Array) = buffer
90+
91+
function XLA.buffer_on_cpu(::Array)
92+
return error("IFRT.Array does not support `XLA.buffer_on_cpu`")
93+
end
94+
95+
function XLA.to_host(buffer::Array, data)
96+
GC.@preserve buffer data begin
97+
@ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer(
98+
buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid}
99+
)::Cvoid
100+
end
101+
return nothing
102+
end
103+
104+
function XLA.unsafe_buffer_pointer(::Array)
105+
return error("IFRT.Array does not support `XLA.unsafe_buffer_pointer`")
106+
end
107+
108+
function XLA.copy_buffer_to_device(::Array, ::Device)
109+
return error("IFRT.Array does not support `XLA.copy_buffer_to_device`")
110+
end
111+
112+
function XLA.sharding(buffer::Array)
113+
GC.@preserve buffer begin
114+
return Sharding(
115+
@ccall MLIR.API.mlir_c.ifrt_array_to_sharding(
116+
buffer.buffer::Ptr{Cvoid}
117+
)::Ptr{Cvoid}
118+
)
119+
end
120+
end

0 commit comments

Comments
 (0)