Skip to content

Commit 17675c2

Browse files
committed
feat: initial low-level IFRT API
fix: ifrt HloSharding refactor: split up into IFRT/PJRT feat: IFRT Client APIs feat: IFRT Device API fix: remove global_ordinals feat: add devices list abstraction feat: wrap memory and memory kinds feat: ifrt::HloSharding now working fix: use new ABI chore: run formatter fix: no finalizer feat: initial draft of IFRT.Array interface (#774) * feat: initial draft of IFRT.Array interface * feat: Base.Array to ifrt::Array * feat: buffer to host chore: run formatter fix: bad rebase feat: more proxy servers feat: add ConcreteIFRTArray feat: add ConcreteIFRTNumber refactor: rename ConcreteRNumber to ConcretePJRTNumber revert: concreteifrtarray implementation chore: run formatter feat: ifrt loaded executable
1 parent 59a6689 commit 17675c2

17 files changed

+739
-102
lines changed

deps/ReactantExtra/API.cpp

+124-31
Original file line numberDiff line numberDiff line change
@@ -1503,6 +1503,10 @@ ifrt_pjrt_loaded_executable_dtor(xla::ifrt::PjRtLoadedExecutable *exec) {
15031503

15041504
// Deletes the `HeldIfrtArray` wrapper pointed to by `array`.
// Passing a null pointer is a no-op (plain `delete` semantics).
extern "C" void ifrt_array_dtor(HeldIfrtArray *array) {
  delete array;
}
15051505

1506+
// Deletes the loaded executable pointed to by `exec`.
// Passing a null pointer is a no-op (plain `delete` semantics).
extern "C" void
ifrt_loaded_executable_dtor(ifrt::LoadedExecutable *exec) {
  delete exec;
}
1509+
15061510
extern "C" void ifrt_loaded_executable_execute(
15071511
ifrt::LoadedExecutable *exec, int num_args,
15081512
HeldValue<tsl::RCReference<ifrt::Array>> **op_args,
@@ -1572,31 +1576,48 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
15721576

15731577
#pragma region IfRtClient
15741578

1575-
// right now only making it available for TPU
1576-
// in the future, we would like this for CPU and GPU PjRt backends too
15771579
extern "C" ifrt::proxy::GrpcServer *
1578-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1579-
const char *c_address, const char *tpu_path, const char **error) {
1580+
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1581+
const char *c_address, uint8_t asynchronous, int node_id, int num_nodes) {
15801582
std::string address = c_address;
15811583

1582-
// taken from `MakeTPUClient`
1583-
std::string tpu_library_path;
1584-
if (auto path = llvm::sys::Process::GetEnv(kEnvTpuLibraryPath)) {
1585-
tpu_library_path = *path;
1586-
} else if (tpu_path) {
1587-
tpu_library_path = std::string(tpu_path);
1588-
} else {
1589-
*error = "Could not find TPU path";
1590-
return nullptr;
1591-
}
1584+
return MyValueOrThrow(
1585+
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1586+
address,
1587+
[asynchronous, node_id, num_nodes]()
1588+
-> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1589+
auto pjrt_client = std::shared_ptr<PjRtClient>(
1590+
MakeCPUClient(asynchronous, node_id, num_nodes));
1591+
return std::shared_ptr<ifrt::Client>(
1592+
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1593+
}))
1594+
.release();
1595+
}
15921596

1593-
const PJRT_Api *pluginLoad =
1594-
LoadPjrtPlugin("tpu", tpu_library_path.c_str(), error);
1595-
if (pluginLoad == nullptr)
1596-
return nullptr;
1597-
auto tpu_status = InitializePjrtPlugin("tpu", error);
1598-
if (tpu_status)
1599-
return nullptr;
1597+
// Creates an IFRT proxy gRPC server whose client factory builds a GPU client.
//
// The factory calls `MakeGPUClient` with the arguments given here and wraps
// the resulting PjRt client in an `xla::ifrt::PjRtClient`. The server is
// created with an empty address string. The returned raw pointer has been
// released from the owner produced by `MyValueOrThrow`, so the caller is
// responsible for destroying it.
//
// NOTE(review): the factory copies the raw pointers `allowed_devices`,
// `platform_name` and `error`, not what they point to. If the server invokes
// the factory after this function returns, those pointers must still be
// valid -- confirm against the callers.
// NOTE(review): unlike the CPU/TPU variants, no address parameter is accepted
// here -- confirm that an empty address is intended.
extern "C" ifrt::proxy::GrpcServer *
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
    double memory_fraction, bool preallocate, const char *platform_name,
    const char **error) {
  auto server = MyValueOrThrow(
      ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
          std::string(),
          [node_id, num_nodes, allowed_devices, num_allowed_devices,
           memory_fraction, preallocate, platform_name,
           error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
            // Build the PjRt GPU client first, then adapt it to IFRT.
            std::shared_ptr<PjRtClient> gpu_client(MakeGPUClient(
                node_id, num_nodes, allowed_devices, num_allowed_devices,
                memory_fraction, preallocate, platform_name, error));
            return std::shared_ptr<ifrt::Client>(
                xla::ifrt::PjRtClient::Create(gpu_client).release());
          }));
  return server.release();
}
1616+
1617+
extern "C" ifrt::proxy::GrpcServer *
1618+
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1619+
const char *c_address, const char *tpu_path, const char **error) {
1620+
std::string address = c_address;
16001621

16011622
return MyValueOrThrow(
16021623
xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
@@ -1636,28 +1657,27 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
16361657
nullptr, // callback `on_connection_update`
16371658
};
16381659
return MyValueOrThrow(
1639-
ifrt::proxy::CreateClient(c_proxy_server_address, options))
1660+
ifrt::proxy::CreateClient(proxy_server_address, options))
16401661
.release();
16411662
}
16421663

1643-
extern "C" ifrt::Client *ifrt_make_cpu_client(uint8_t asynchronous, int node_id,
1644-
int num_nodes) {
1664+
// Creates a CPU PjRt client via `pjrt_make_cpu_client_shared` and wraps it
// into an IFRT client with `ifrt_pjrt_make_client`.
extern "C" ifrt::Client *ifrt_make_pjrt_cpu_client(uint8_t asynchronous,
                                                   int node_id, int num_nodes) {
  ifrt::Client *client = ifrt_pjrt_make_client(
      pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes));
  return client;
}
16481669

1649-
extern "C" ifrt::Client *
1650-
ifrt_make_gpu_client(int node_id, int num_nodes, int *allowed_devices,
1651-
int num_allowed_devices, double memory_fraction,
1652-
bool preallocate, const char *platform_name,
1653-
const char **error, void *distributed_runtime_client) {
1670+
// Creates a GPU PjRt client via `pjrt_make_gpu_client_shared`, forwarding all
// arguments unchanged, and wraps it into an IFRT client with
// `ifrt_pjrt_make_client`.
extern "C" ifrt::Client *ifrt_make_pjrt_gpu_client(
    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
    double memory_fraction, bool preallocate, const char *platform_name,
    const char **error, void *distributed_runtime_client) {
  ifrt::Client *client = ifrt_pjrt_make_client(pjrt_make_gpu_client_shared(
      node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
      preallocate, platform_name, error, distributed_runtime_client));
  return client;
}
16581678

1659-
extern "C" ifrt::Client *ifrt_make_tpu_client(const char *tpu_path,
1660-
const char **error) {
1679+
// Creates a TPU PjRt client via `pjrt_make_tpu_client_shared` and wraps it
// into an IFRT client with `ifrt_pjrt_make_client`.
extern "C" ifrt::Client *ifrt_make_pjrt_tpu_client(const char *tpu_path,
                                                   const char **error) {
  ifrt::Client *client =
      ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error));
  return client;
}
16631683

@@ -1847,6 +1867,79 @@ ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) {
18471867
return cstr_from_string(hlo_sharding->DebugString());
18481868
}
18491869

1870+
extern "C" void
1871+
free_ifrt_sharding(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1872+
delete sharding;
1873+
}
1874+
1875+
// Wraps `hlo_sharding` in a `std::shared_ptr<ifrt::Sharding>` and returns a
// holder for it created by `reactant::capture`.
//
// The shared_ptr is constructed directly from the raw pointer, so it takes
// ownership of `hlo_sharding`; the caller must not delete that pointer
// separately afterwards.
extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) {
  return reactant::capture(
      std::shared_ptr<ifrt::Sharding>(hlo_sharding));
}
1879+
1880+
// Builds an `ifrt::HloSharding` from an XLA HLO sharding (through
// `ifrt_hlo_sharding_from_xla_hlo_sharding`) and returns it wrapped as a
// generic `ifrt::Sharding` holder (through
// `ifrt_sharding_from_ifrt_hlo_sharding`).
extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
ifrt_sharding_from_hlo_sharding(
    HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
    ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
  ifrt::HloSharding *ifrt_hlo_sharding =
      ifrt_hlo_sharding_from_xla_hlo_sharding(device_list, memory_kind,
                                              xla_hlo_sharding);
  return ifrt_sharding_from_ifrt_hlo_sharding(ifrt_hlo_sharding);
}
1888+
1889+
extern "C" const char *
1890+
ifrt_sharding_to_string(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1891+
return cstr_from_string(sharding->obj()->DebugString());
1892+
}
1893+
1894+
#pragma endregion
1895+
1896+
// Future type exposed through the `ifrt_future_*` C entry points.
using IfRtFutureType = ifrt::Future<>;
1897+
1898+
// Deletes the future pointed to by `Future`.
// Passing a null pointer is a no-op (plain `delete` semantics).
extern "C" void ifrt_free_future(IfRtFutureType *Future) {
  delete Future;
}
1899+
1900+
// Reports the result of `Future->IsReady()` as a byte: 1 when ready, 0
// otherwise.
extern "C" uint8_t ifrt_future_is_ready(IfRtFutureType *Future) {
  return Future->IsReady() ? 1 : 0;
}
1903+
1904+
// Calls `Await()` on the future.
// NOTE(review): whatever `Await()` returns is discarded here, so a failure
// reported by the future would not reach the caller -- confirm this is
// intended.
extern "C" void ifrt_future_await(IfRtFutureType *Future) {
  Future->Await();
}
1905+
1906+
#pragma region IfRtArray
1907+
1908+
// Deletes the `HeldIfrtArray` wrapper pointed to by `array`.
// Passing a null pointer is a no-op (plain `delete` semantics).
extern "C" void ifrt_free_array(HeldIfrtArray *array) {
  delete array;
}
1909+
1910+
// Returns a freshly allocated copy of the array's dimensions.
//
// The result is allocated with `new int64_t[...]` and holds
// `ifrt_array_ndims(array)` entries. Ownership passes to the caller; nothing
// in this function releases it.
extern "C" int64_t *ifrt_array_shape(HeldIfrtArray *array) {
  // Deduce the span type instead of spelling it as `absl::Span<const long>`:
  // `int64_t` is `long long` rather than `long` on some platforms, where the
  // hard-coded element type would not match what `dims()` returns.
  const auto dims = array->obj()->shape().dims();
  int64_t *dims_ptr = new int64_t[dims.size()];
  std::copy(dims.begin(), dims.end(), dims_ptr);
  return dims_ptr;
}
1916+
1917+
// Returns the number of dimensions in the array's shape.
extern "C" int64_t ifrt_array_ndims(HeldIfrtArray *array) {
  const auto rank = array->obj()->shape().dims().size();
  return static_cast<int64_t>(rank);
}
1920+
1921+
// Returns the array's element type as reported by `dtype()`.
extern "C" ifrt::DType ifrt_array_eltype(HeldIfrtArray *array) {
  const auto &held_array = array->obj();
  return held_array->dtype();
}
1924+
1925+
// Returns the client pointer reported by the array's `client()` accessor.
extern "C" ifrt::Client *ifrt_array_to_client(HeldIfrtArray *array) {
  const auto &held_array = array->obj();
  return held_array->client();
}
1928+
1929+
// Returns a new holder (created by `reactant::capture`) around the array's
// `shared_ptr_sharding()`.
extern "C" HeldValue<std::shared_ptr<const ifrt::Sharding>> *
ifrt_array_to_sharding(HeldIfrtArray *array) {
  const auto &held_array = array->obj();
  return reactant::capture(held_array->shared_ptr_sharding());
}
1933+
1934+
// Copies the array into the host memory at `data` and blocks until the copy
// future has been awaited.
//
// No byte strides are supplied (the optional is left empty), and the copy is
// requested with `kAlwaysCopy` semantics (enumerator value 0, which is what
// the previous `static_cast<ifrt::ArrayCopySemantics>(0)` selected).
//
// NOTE(review): whatever `Await()` returns is discarded, so a failed copy is
// not reported to the caller -- confirm whether it should be surfaced.
extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
                                               void *data) {
  std::optional<absl::Span<const int64_t>> byte_strides;
  auto future = array->obj()->CopyToHostBuffer(
      data, byte_strides, ifrt::ArrayCopySemantics::kAlwaysCopy);
  future.Await();
}
1942+
18501943
#pragma endregion
18511944

18521945
#pragma region PjRtDistributed

src/xla/Buffer.jl

+54
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@ function buffer_on_cpu end
55
# Generic functions declared without methods; concrete buffer types add methods to them.
function to_host end
66
function unsafe_buffer_pointer end
77
function copy_buffer_to_device end
8+
function sharding end
9+
10+
# Convert a buffer to a host `Array`, taking the element type from `eltype(buffer)`.
function Base.convert(::Type{Array}, buffer::AbstractBuffer)
    return convert(Array{eltype(buffer)}, buffer)
end
11+
12+
# Convert a buffer to a host `Array` with element type `T`.
# A zero-filled array whose dimensions are `size(buffer)` reversed is allocated
# and then filled through `XLA.to_host`.
function Base.convert(::Type{<:Array{T}}, buffer::AbstractBuffer) where {T}
    host_dims = reverse(size(buffer))
    host = zeros(T, host_dims...)
    XLA.to_host(buffer, host)
    return host
end
817

918
@inline function client(
1019
buffers::Union{Array{<:AbstractBuffer},NTuple{<:Any,AbstractBuffer}}
@@ -19,3 +28,48 @@ end
1928
)
2029
return map(synced_buffer, buffers)
2130
end
31+
32+
# Text display: prints the concrete buffer type followed by the buffer contents
# converted to a host `Array`.
function Base.show(io::IO, mime::MIME"text/plain", buffer::B) where {B<:AbstractBuffer}
    print(io, "$(B) storing ")
    host = convert(Array, buffer)
    show(io, mime, host)
    return nothing
end
37+
38+
# Async Buffers
39+
# Supertype for buffers that pair an underlying buffer with an optional pending future.
# The methods defined on this type read the subtype's `buffer` and `future` fields.
abstract type AbstractAsyncBuffer <: AbstractBuffer end
40+
41+
# An async buffer is empty when the handle of its wrapped buffer is the null pointer.
function Base.isempty(buffer::AbstractAsyncBuffer)
    handle = buffer.buffer.buffer
    return handle == C_NULL
end
42+
43+
# Wait for the pending future (if any), then convert the wrapped buffer.
function Base.convert(T::Type{Array}, buffer::AbstractAsyncBuffer)
    XLA.await(buffer)
    inner = buffer.buffer
    return convert(T, inner)
end
47+
48+
# Wait for the pending future (if any), then convert the wrapped buffer to the
# requested array type.
function Base.convert(T::Type{<:Array{T1}}, buffer::AbstractAsyncBuffer) where {T1}
    XLA.await(buffer)
    inner = buffer.buffer
    return convert(T, inner)
end
52+
53+
# Forward these accessors from an async buffer to the buffer it wraps.
for accessor in (:(Base.ndims), :(Base.size), :(Base.eltype), :device, :client, :sharding)
    @eval function $accessor(buffer::AbstractAsyncBuffer)
        return $accessor(buffer.buffer)
    end
end
56+
57+
# Wait for the pending future (if any) and hand back the wrapped buffer.
function XLA.synced_buffer(buffer::AbstractAsyncBuffer)
    XLA.await(buffer)
    inner = buffer.buffer
    return inner
end
61+
62+
# Wait on the buffer's pending future, if there is one.
# The `future` field is cleared before waiting, so the stored future is gone
# even if the wait itself throws.
function XLA.await(buffer::AbstractAsyncBuffer)
    pending = buffer.future
    if pending !== nothing
        buffer.future = nothing
        XLA.await(pending)
    end
    return nothing
end
69+
70+
# A buffer with no pending future is ready; otherwise defer to the future.
function XLA.is_ready(buffer::AbstractAsyncBuffer)
    return buffer.future === nothing || XLA.is_ready(buffer.future)
end
74+
75+
# Forward to the wrapped buffer (without waiting on the future).
function XLA.buffer_on_cpu(buffer::AbstractAsyncBuffer)
    return XLA.buffer_on_cpu(buffer.buffer)
end

src/xla/IFRT/Array.jl

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Wrapper around an opaque IFRT array handle.
# The inner constructor registers `free_ifrt_array` as a finalizer on every
# instance it creates.
mutable struct Array <: XLA.AbstractBuffer
    buffer::Ptr{Cvoid}

    function Array(buffer::Ptr{Cvoid})
        array = new(buffer)
        finalizer(free_ifrt_array, array)
        return array
    end
end
8+
9+
# Create an IFRT array on a single device from a host array.
# The dimensions are passed reversed relative to `size(array)`, and the memory
# kind string is derived from `XLA.default_memory(device)`.
# NOTE(review): the element type is passed as `UInt64` here but as `Cint` in the
# sharding-based constructor -- confirm against the C signatures.
function Array(client::Client, array::Base.Array{T,N}, device::Device) where {T,N}
    reversed_dims = collect(Int64, reverse(size(array)))
    # Keep both host arrays alive while their raw pointers are in use.
    handle = GC.@preserve array reversed_dims begin
        @ccall MLIR.API.mlir_c.ifrt_client_make_single_shard_array_from_host_buffer(
            client.client::Ptr{Cvoid},
            pointer(array)::Ptr{T},
            XLA.primitive_type(T)::UInt64,
            N::Csize_t,
            pointer(reversed_dims)::Ptr{Int64},
            0::Cint, # kAlwaysCopy
            device.device::Ptr{Cvoid},
            string(convert(MemoryKind, XLA.default_memory(device)))::Cstring,
        )::Ptr{Cvoid}
    end
    return Array(handle)
end
25+
26+
# Convert the `HloSharding` to a `Sharding` and delegate to the constructor
# that accepts a `Sharding`.
function Array(client::Client, array::Base.Array{T,N}, sharding::HloSharding) where {T,N}
    generic_sharding = convert(Sharding, sharding)
    return Array(client, array, generic_sharding)
end
29+
30+
# Create an IFRT array from a host array using the given sharding.
# The dimensions are passed reversed relative to `size(array)`.
function Array(client::Client, array::Base.Array{T,N}, sharding::Sharding) where {T,N}
    reversed_dims = collect(Int64, reverse(size(array)))
    # Keep both host arrays alive while their raw pointers are in use.
    handle = GC.@preserve array reversed_dims begin
        @ccall MLIR.API.mlir_c.ifrt_client_make_array_from_host_buffer(
            client.client::Ptr{Cvoid},
            pointer(array)::Ptr{T},
            XLA.primitive_type(T)::Cint,
            N::Csize_t,
            pointer(reversed_dims)::Ptr{Int64},
            sharding.ptr::Ptr{Cvoid},
            0::Cint, # kAlwaysCopy
        )::Ptr{Cvoid}
    end
    return Array(handle)
end
45+
46+
# Finalizer: releases the native array handle unless it is the null pointer.
@inline function free_ifrt_array(buffer::Array)
    handle = buffer.buffer
    handle == C_NULL && return nothing
    @ccall MLIR.API.mlir_c.ifrt_free_array(handle::Ptr{Cvoid})::Cvoid
    return nothing
end
52+
53+
# Number of dimensions as reported by the native `ifrt_array_ndims`.
function Base.ndims(buffer::Array)
    return GC.@preserve buffer begin
        @ccall MLIR.API.mlir_c.ifrt_array_ndims(buffer.buffer::Ptr{Cvoid})::Int64
    end
end
58+
59+
# Dimensions of the array as a tuple of `Int64`, read from the pointer returned
# by the native `ifrt_array_shape`.
# NOTE(review): `ifrt_array_shape` allocates its result with `new[]` and nothing
# here releases it, so each call appears to leak that small buffer -- confirm
# and add a matching free entry point if so.
function Base.size(buffer::Array)
    dims_ptr = GC.@preserve buffer begin
        @ccall MLIR.API.mlir_c.ifrt_array_shape(buffer.buffer::Ptr{Cvoid})::Ptr{Int64}
    end
    rank = ndims(buffer)
    return ntuple(i -> unsafe_load(dims_ptr, i), rank)
end
65+
66+
# Element type: the native dtype code mapped to a Julia type by `XLA.julia_type`.
function Base.eltype(buffer::Array)
    GC.@preserve buffer begin
        dtype_code = @ccall MLIR.API.mlir_c.ifrt_array_eltype(
            buffer.buffer::Ptr{Cvoid}
        )::Cint
        return XLA.julia_type(dtype_code)
    end
end
73+
74+
# Always throws: a single device is not defined for an IFRT array.
function XLA.device(::Array)
    error("IFRT.Array can be sharded/replicated across multiple devices. Hence, \
           `XLA.device` is not defined.")
end
78+
79+
# Wrap the client pointer reported by the native `ifrt_array_to_client` in a `Client`.
function XLA.client(buffer::Array)
    GC.@preserve buffer begin
        client_ptr = @ccall MLIR.API.mlir_c.ifrt_array_to_client(
            buffer.buffer::Ptr{Cvoid}
        )::Ptr{Cvoid}
        return Client(client_ptr)
    end
end
88+
89+
# A plain IFRT array carries no pending future, so it is returned unchanged.
function XLA.synced_buffer(buffer::Array)
    return buffer
end
90+
91+
# Always throws: this query is not supported for IFRT arrays.
function XLA.buffer_on_cpu(::Array)
    error("IFRT.Array does not support `XLA.buffer_on_cpu`")
end
94+
95+
# Copy the array contents into `data` through the native
# `ifrt_array_copy_to_host_buffer`. `data` is passed as a raw pointer; both
# arguments are kept alive for the duration of the call.
function XLA.to_host(buffer::Array, data)
    GC.@preserve buffer data begin
        @ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer(
            buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid}
        )::Cvoid
    end
    return nothing
end
103+
104+
# Always throws: raw buffer pointers are not supported for IFRT arrays.
function XLA.unsafe_buffer_pointer(::Array)
    error("IFRT.Array does not support `XLA.unsafe_buffer_pointer`")
end
107+
108+
# Always throws: device-to-device copies are not supported for IFRT arrays.
function XLA.copy_buffer_to_device(::Array, ::Device)
    error("IFRT.Array does not support `XLA.copy_buffer_to_device`")
end
111+
112+
# Wrap the pointer returned by the native `ifrt_array_to_sharding` in a `Sharding`.
function XLA.sharding(buffer::Array)
    GC.@preserve buffer begin
        sharding_ptr = @ccall MLIR.API.mlir_c.ifrt_array_to_sharding(
            buffer.buffer::Ptr{Cvoid}
        )::Ptr{Cvoid}
        return Sharding(sharding_ptr)
    end
end

src/xla/IFRT/AsyncArray.jl

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# An IFRT `Array` paired with an optional `Future`; `future` is `nothing` when
# there is nothing pending.
mutable struct AsyncArray <: XLA.AbstractAsyncBuffer
2+
buffer::Array
3+
future::Union{Future,Nothing}
4+
end
5+
6+
# Placeholder instance wrapping a null array handle with no pending future.
const AsyncEmptyArray = AsyncArray(Array(C_NULL), nothing)
7+
8+
# Convenience constructor: build the wrapped `Array` from the given arguments
# and pair it with no pending future.
function AsyncArray(args...; kwargs...)
    inner = Array(args...; kwargs...)
    return AsyncArray(inner, nothing)
end

0 commit comments

Comments
 (0)