From f640775bfcbd963998756669b8dccc3862edea84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 13 Feb 2025 02:13:49 +0100 Subject: [PATCH 1/4] Add initial IFRT Julia bindings This reverts commit c76a9002b0c36a29b553e8bbb5083d2d41319b23. --- src/xla/Buffer.jl | 22 ++++++++++++--- src/xla/Client.jl | 18 ++++++++++-- src/xla/IFRT/Array.jl | 38 +++++++++++++++++++++++++ src/xla/IFRT/Client.jl | 31 +++++++++++++++++++++ src/xla/IFRT/IFRT.jl | 13 +++++++++ src/xla/IFRT/LoadedExecutable.jl | 48 ++++++++++++++++++++++++++++++++ src/xla/XLA.jl | 2 ++ 7 files changed, 166 insertions(+), 6 deletions(-) create mode 100644 src/xla/IFRT/Array.jl create mode 100644 src/xla/IFRT/Client.jl create mode 100644 src/xla/IFRT/IFRT.jl create mode 100644 src/xla/IFRT/LoadedExecutable.jl diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl index 52a0718655..52035de10a 100644 --- a/src/xla/Buffer.jl +++ b/src/xla/Buffer.jl @@ -1,16 +1,30 @@ # Buffer @inline function free_buffer(buffer) - sbuffer = buffer.buffer - if sbuffer != C_NULL - @ccall MLIR.API.mlir_c.PjRtBufferFree(sbuffer::Ptr{Cvoid})::Cvoid + if buffer.holded == C_NULL + if buffer.buffer != C_NULL + @ccall MLIR.API.mlir_c.PjRtBufferFree(buffer.buffer::Ptr{Cvoid})::Cvoid + end + else + @ccall MLIR.API.mlir_c.reactant_release_pjrtbuffer(buffer.holded::Ptr{Cvoid})::Cvoid end end mutable struct Buffer buffer::Ptr{Cvoid} + holded::Ptr{Cvoid} function Buffer(buffer::Ptr{Cvoid}) - return finalizer(free_buffer, new(buffer)) + return finalizer(free_buffer, new(buffer, C_NULL)) + end +end + +function hold!(buffer::Buffer) + if buffer.holded == C_NULL + sbuffer = buffer.buffer + buffer.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtbuffer( + sbuffer::Ptr{Cvoid} + )::Ptr{Cvoid} end + return buffer end function Base.ndims(buffer::Buffer) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index a9ac15e318..862cb3e876 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -1,12 +1,13 @@ mutable struct Client client::Ptr{Cvoid} global_ordinals::Vector{Cint} + holded::Ptr{Cvoid} function Client(client::Ptr{Cvoid}) @assert client != C_NULL global_ordinals = Cint[] - client = new(client, global_ordinals) + client = new(client, global_ordinals, C_NULL) # https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L127 devices = [ @@ -29,7 +30,20 @@ end Base.:(==)(a::Client, b::Client) = a.client == b.client @inline function free_client(client::Client) - @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid + if client.holded == C_NULL + @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid + else + @ccall MLIR.API.mlir_c.reactant_release_pjrtclient(client.holded::Ptr{Cvoid})::Cvoid + end +end + +function hold!(client::Client) + if client.holded == C_NULL + client.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtclient( + client.client::Ptr{Cvoid} + )::Ptr{Cvoid} + end + return client end function ClientNumDevices(client::Client) diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl new file mode 100644 index 0000000000..ef2ffa409f --- /dev/null +++ b/src/xla/IFRT/Array.jl @@ -0,0 +1,38 @@ +@cenum ArrayCopySemantics::UInt32 begin + AlwaysCopy = 0 + ReuseInput = 1 + DonateInput = 2 +end + +# currently, only supports IFRT-PjRt +mutable struct Array + ptr::Ptr{Cvoid} + + function Array(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_array, new(ptr)) + end +end + +function free_array(array) + @ccall MLIR.API.mlir_c.reactant_release_ifrt_array(array.ptr::Ptr{Cvoid})::Cvoid +end + +function Array(client::Client, buffer::XLA.Buffer) + hold!(buffer) + GC.@preserve client buffer begin + return Array( + @ccall MLIR.API.mlir_c.ifrt_pjrt_ArrayFromHostBuffer( + client.ptr::Ptr{Cvoid}, buffer.holded::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function CopyArrayToHostBuffer(array::Array, data) + GC.@preserve array data begin + @ccall MLIR.API.mlir_c.ifrt_CopyArrayToHostBuffer( + array.ptr::Ptr{Cvoid}, data::Ptr{Cvoid}, AlwaysCopy::Cuint + )::Cvoid + end +end diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl new file mode 100644 index 0000000000..d75d1c5fcf --- /dev/null +++ b/src/xla/IFRT/Client.jl @@ -0,0 +1,31 @@ +# currently, only supports IFRT-PjRt +mutable struct Client + ptr::Ptr{Cvoid} + + function Client(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_client, new(ptr)) + end +end + +function Client(pjrt_client::XLA.Client) + # it needs a `std::shared_ptr` + hold!(pjrt_client) + return Client( + @ccall MLIR.API.mlir_c.ifrt_pjrt_MakeClient( + pjrt_client.holded::Ptr{Cvoid} + )::Ptr{Cvoid} + ) +end + +function free_client(client) + @ccall MLIR.API.mlir_c.ifrt_FreeClient(client.ptr::Ptr{Cvoid})::Cvoid +end + +function compile(client::Client, code::MLIR.IR.Module) + return LoadedExecutable( + @ccall MLIR.API.mlir_c.ifrt_ClientCompile( + client.ptr::Ptr{Cvoid}, code.module_::MLIR.API.MlirModule + )::Ptr{Cvoid} + ) +end diff --git a/src/xla/IFRT/IFRT.jl b/src/xla/IFRT/IFRT.jl new file mode 100644 index 0000000000..def0304cd6 --- /dev/null +++ b/src/xla/IFRT/IFRT.jl @@ -0,0 +1,13 @@ +module IFRT + +using CEnum + +import ..XLA +import .XLA: hold! +import ..MLIR + +include("LoadedExecutable.jl") +include("Client.jl") +include("Array.jl") + +end diff --git a/src/xla/IFRT/LoadedExecutable.jl b/src/xla/IFRT/LoadedExecutable.jl new file mode 100644 index 0000000000..f376947042 --- /dev/null +++ b/src/xla/IFRT/LoadedExecutable.jl @@ -0,0 +1,48 @@ +# currently, only supports IFRT-PjRt +mutable struct LoadedExecutable + ptr::Ptr{Cvoid} + + function LoadedExecutable(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_exec, new(ptr)) + end +end + +@inline function free_exec(exec) + @ccall MLIR.API.mlir_c.ifrt_pjrt_FreeLoadedExecutable(exec.ptr::Ptr{Cvoid})::Cvoid +end + +function execute( + exec::LoadedExecutable, + args::NTuple{N,Ptr{Cvoid}}, + donated_mask::NTuple{N,UInt8}, + ::Val{n_results}, +) where {N,n_results} + results = Ref{NTuple{n_results,Ptr{Cvoid}}}() + has_future = Ref{UInt8}() + status = Ref{NTuple{1,Ptr{Cvoid}}}() # unused right now + + args = Base.RefValue(args) + donated_mask = Base.RefValue(donated_mask) + + GC.@preserve exec args donated_mask results has_future status begin + @ccall MLIR.API.mlir_c.ifrt_Execute( + exec.ptr::Ptr{Cvoid}, + N::Cint, + args::Ptr{Cvoid}, + donated_mask::Ptr{Cvoid}, + n_results::Cint, + Base.unsafe_convert(Ptr{Cvoid}, results)::Ptr{Cvoid}, + has_future::Ptr{Cvoid}, + status::Ptr{Cvoid}, + )::Cvoid + end + + @assert has_future[] == true + + results = results[] + + return ntuple(Val(n_results)) do i + return Array(results[i]) + end +end diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 2fe4cc9f46..e4ca11f3a6 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -96,4 +96,6 @@ function __init__() return nothing end +include("IFRT/IFRT.jl") + end From 272e44810252168d9527992bace8c626a515b014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 13 Feb 2025 08:40:36 +0100 Subject: [PATCH 2/4] Refactor `holded` field in `Buffer` to `HeldBuffer` type --- src/xla/Buffer.jl | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl index 52035de10a..1512d431df 100644 --- a/src/xla/Buffer.jl +++ b/src/xla/Buffer.jl @@ -1,28 +1,42 @@ # Buffer -@inline function free_buffer(buffer) - if buffer.holded == C_NULL - if buffer.buffer != C_NULL - @ccall MLIR.API.mlir_c.PjRtBufferFree(buffer.buffer::Ptr{Cvoid})::Cvoid - end - else - @ccall MLIR.API.mlir_c.reactant_release_pjrtbuffer(buffer.holded::Ptr{Cvoid})::Cvoid +mutable struct HeldBuffer + ptr::Ptr{Cvoid} + + function HeldBuffer(ptr::Ptr{Cvoid}) + return finalizer(release_buffer, new(ptr)) end end +@inline function release_buffer(held_buffer::HeldBuffer) + @ccall MLIR.API.mlir_c.reactant_release_pjrtbuffer( + held_buffer.ptr::Ptr{Cvoid} + )::Cvoid +end + mutable struct Buffer buffer::Ptr{Cvoid} - holded::Ptr{Cvoid} + held::Union{Nothing,HeldBuffer} + function Buffer(buffer::Ptr{Cvoid}) - return finalizer(free_buffer, new(buffer, C_NULL)) + return finalizer(free_buffer, new(buffer, nothing)) end end +@inline function free_buffer(buffer) + if buffer.holded == C_NULL && buffer.buffer != C_NULL + @ccall MLIR.API.mlir_c.PjRtBufferFree(buffer.buffer::Ptr{Cvoid})::Cvoid + end + # else + # @ccall MLIR.API.mlir_c.reactant_release_pjrtbuffer(buffer.holded::Ptr{Cvoid})::Cvoid + # end +end + function hold!(buffer::Buffer) if buffer.holded == C_NULL sbuffer = buffer.buffer - buffer.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtbuffer( - sbuffer::Ptr{Cvoid} - )::Ptr{Cvoid} + buffer.holded = HeldBuffer( + @ccall MLIR.API.mlir_c.reactant_hold_pjrtbuffer(sbuffer::Ptr{Cvoid})::Ptr{Cvoid} + ) end return buffer end From 15006b2ce3394dd460674d4c34ad53b55162a4cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 13 Feb 2025 08:41:21 +0100 Subject: [PATCH 3/4] Fix `free_buffer` --- src/xla/Buffer.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl index 1512d431df..4f1cadcf29 100644 --- a/src/xla/Buffer.jl +++ b/src/xla/Buffer.jl @@ -23,12 +23,9 @@ mutable struct Buffer end @inline function free_buffer(buffer) - if buffer.holded == C_NULL && buffer.buffer != C_NULL + if isnothing(buffer.holded) && buffer.buffer != C_NULL @ccall MLIR.API.mlir_c.PjRtBufferFree(buffer.buffer::Ptr{Cvoid})::Cvoid end - # else - # @ccall MLIR.API.mlir_c.reactant_release_pjrtbuffer(buffer.holded::Ptr{Cvoid})::Cvoid - # end end function hold!(buffer::Buffer) From 9a8f9a701e37208072fba7caceacf397197f7af3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 13 Feb 2025 08:41:26 +0100 Subject: [PATCH 4/4] Format code --- src/xla/Buffer.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl index 4f1cadcf29..7e1c335b00 100644 --- a/src/xla/Buffer.jl +++ b/src/xla/Buffer.jl @@ -8,9 +8,7 @@ mutable struct HeldBuffer end @inline function release_buffer(held_buffer::HeldBuffer) - @ccall MLIR.API.mlir_c.reactant_release_pjrtbuffer( - held_buffer.ptr::Ptr{Cvoid} - )::Cvoid + @ccall MLIR.API.mlir_c.reactant_release_pjrtbuffer(held_buffer.ptr::Ptr{Cvoid})::Cvoid end mutable struct Buffer