diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl index 52a0718655..7e1c335b00 100644 --- a/src/xla/Buffer.jl +++ b/src/xla/Buffer.jl @@ -1,16 +1,39 @@ # Buffer -@inline function free_buffer(buffer) - sbuffer = buffer.buffer - if sbuffer != C_NULL - @ccall MLIR.API.mlir_c.PjRtBufferFree(sbuffer::Ptr{Cvoid})::Cvoid +mutable struct HeldBuffer + ptr::Ptr{Cvoid} + + function HeldBuffer(ptr::Ptr{Cvoid}) + return finalizer(release_buffer, new(ptr)) end end +@inline function release_buffer(held_buffer::HeldBuffer) + @ccall MLIR.API.mlir_c.reactant_release_pjrtbuffer(held_buffer.ptr::Ptr{Cvoid})::Cvoid +end + mutable struct Buffer buffer::Ptr{Cvoid} + held::Union{Nothing,HeldBuffer} + function Buffer(buffer::Ptr{Cvoid}) - return finalizer(free_buffer, new(buffer)) + return finalizer(free_buffer, new(buffer, nothing)) + end +end + +@inline function free_buffer(buffer) + if isnothing(buffer.holded) && buffer.buffer != C_NULL + @ccall MLIR.API.mlir_c.PjRtBufferFree(buffer.buffer::Ptr{Cvoid})::Cvoid + end +end + +function hold!(buffer::Buffer) + if buffer.holded == C_NULL + sbuffer = buffer.buffer + buffer.holded = HeldBuffer( + @ccall MLIR.API.mlir_c.reactant_hold_pjrtbuffer(sbuffer::Ptr{Cvoid})::Ptr{Cvoid} + ) end + return buffer end function Base.ndims(buffer::Buffer) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index a9ac15e318..862cb3e876 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -1,12 +1,13 @@ mutable struct Client client::Ptr{Cvoid} global_ordinals::Vector{Cint} + holded::Ptr{Cvoid} function Client(client::Ptr{Cvoid}) @assert client != C_NULL global_ordinals = Cint[] - client = new(client, global_ordinals) + client = new(client, global_ordinals, C_NULL) # https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L127 devices = [ @@ -29,7 +30,20 @@ end Base.:(==)(a::Client, b::Client) = a.client == b.client @inline function free_client(client::Client) - @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid + if client.holded == C_NULL + @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid + else + @ccall MLIR.API.mlir_c.reactant_release_pjrtclient(client.holded::Ptr{Cvoid})::Cvoid + end +end + +function hold!(client::Client) + if client.holded == C_NULL + client.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtclient( + client.client::Ptr{Cvoid} + )::Ptr{Cvoid} + end + return client end function ClientNumDevices(client::Client) diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl new file mode 100644 index 0000000000..ef2ffa409f --- /dev/null +++ b/src/xla/IFRT/Array.jl @@ -0,0 +1,38 @@ +@cenum ArrayCopySemantics::UInt32 begin + AlwaysCopy = 0 + ReuseInput = 1 + DonateInput = 2 +end + +# currently, only supports IFRT-PjRt +mutable struct Array + ptr::Ptr{Cvoid} + + function Array(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_array, new(ptr)) + end +end + +function free_array(array) + @ccall MLIR.API.mlir_c.reactant_release_ifrt_array(array.ptr::Ptr{Cvoid})::Cvoid +end + +function Array(client::Client, buffer::XLA.Buffer) + hold!(buffer) + GC.@preserve client buffer begin + return Array( + @ccall MLIR.API.mlir_c.ifrt_pjrt_ArrayFromHostBuffer( + client.ptr::Ptr{Cvoid}, buffer.holded::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function CopyArrayToHostBuffer(array::Array, data) + GC.@preserve array data begin + @ccall MLIR.API.mlir_c.ifrt_CopyArrayToHostBuffer( + array.ptr::Ptr{Cvoid}, data::Ptr{Cvoid}, AlwaysCopy::Cuint + )::Cvoid + end +end diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl new file mode 100644 index 0000000000..d75d1c5fcf --- /dev/null +++ b/src/xla/IFRT/Client.jl @@ -0,0 +1,31 @@ +# currently, only supports IFRT-PjRt +mutable struct Client + ptr::Ptr{Cvoid} + + function Client(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_client, new(ptr)) + end +end + +function Client(pjrt_client::XLA.Client) + # it needs a `std::shared_ptr` + hold!(pjrt_client) + return Client( + @ccall MLIR.API.mlir_c.ifrt_pjrt_MakeClient( + pjrt_client.holded::Ptr{Cvoid} + )::Ptr{Cvoid} + ) +end + +function free_client(client) + @ccall MLIR.API.mlir_c.ifrt_FreeClient(client.ptr::Ptr{Cvoid})::Cvoid +end + +function compile(client::Client, code::MLIR.IR.Module) + return LoadedExecutable( + @ccall MLIR.API.mlir_c.ifrt_ClientCompile( + client.ptr::Ptr{Cvoid}, code.module_::MLIR.API.MlirModule + )::Ptr{Cvoid} + ) +end diff --git a/src/xla/IFRT/IFRT.jl b/src/xla/IFRT/IFRT.jl new file mode 100644 index 0000000000..def0304cd6 --- /dev/null +++ b/src/xla/IFRT/IFRT.jl @@ -0,0 +1,13 @@ +module IFRT + +using CEnum + +import ..XLA +import .XLA: hold! +import ..MLIR + +include("LoadedExecutable.jl") +include("Client.jl") +include("Array.jl") + +end diff --git a/src/xla/IFRT/LoadedExecutable.jl b/src/xla/IFRT/LoadedExecutable.jl new file mode 100644 index 0000000000..f376947042 --- /dev/null +++ b/src/xla/IFRT/LoadedExecutable.jl @@ -0,0 +1,48 @@ +# currently, only supports IFRT-PjRt +mutable struct LoadedExecutable + ptr::Ptr{Cvoid} + + function LoadedExecutable(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + return finalizer(free_exec, new(ptr)) + end +end + +@inline function free_exec(exec) + @ccall MLIR.API.mlir_c.ifrt_pjrt_FreeLoadedExecutable(exec.ptr::Ptr{Cvoid})::Cvoid +end + +function execute( + exec::LoadedExecutable, + args::NTuple{N,Ptr{Cvoid}}, + donated_mask::NTuple{N,UInt8}, + ::Val{n_results}, +) where {N,n_results} + results = Ref{NTuple{n_results,Ptr{Cvoid}}}() + has_future = Ref{UInt8}() + status = Ref{NTuple{1,Ptr{Cvoid}}}() # unused right now + + args = Base.RefValue(args) + donated_mask = Base.RefValue(donated_mask) + + GC.@preserve exec args donated_mask results has_future status begin + @ccall MLIR.API.mlir_c.ifrt_Execute( + exec.ptr::Ptr{Cvoid}, + N::Cint, + args::Ptr{Cvoid}, + donated_mask::Ptr{Cvoid}, + n_results::Cint, + Base.unsafe_convert(Ptr{Cvoid}, results)::Ptr{Cvoid}, + has_future::Ptr{Cvoid}, + status::Ptr{Cvoid}, + )::Cvoid + end + + @assert has_future[] == true + + results = results[] + + return ntuple(Val(n_results)) do i + return Array(results[i]) + end +end diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 2fe4cc9f46..e4ca11f3a6 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -96,4 +96,6 @@ function __init__() return nothing end +include("IFRT/IFRT.jl") + end