Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add initial IFRT Julia bindings #738

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions src/xla/Buffer.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,39 @@
# Buffer
@inline function free_buffer(buffer)
sbuffer = buffer.buffer
if sbuffer != C_NULL
@ccall MLIR.API.mlir_c.PjRtBufferFree(sbuffer::Ptr{Cvoid})::Cvoid
mutable struct HeldBuffer
ptr::Ptr{Cvoid}

function HeldBuffer(ptr::Ptr{Cvoid})
return finalizer(release_buffer, new(ptr))
end
end

@inline function release_buffer(held_buffer::HeldBuffer)
@ccall MLIR.API.mlir_c.reactant_release_pjrtbuffer(held_buffer.ptr::Ptr{Cvoid})::Cvoid
end

mutable struct Buffer
buffer::Ptr{Cvoid}
held::Union{Nothing,HeldBuffer}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this?

In essence I feel like we should consider buffer and heldbuffer as two totally different API's. Of course under the hood technically one can be converted to the other, but we shouldn't merge them unless needed. This will make it easier to transition the rest of the code from PJRT -> IFRT (and also the inner buffers having a union will make them slow)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, I agree with your view that it would be cleaner and more performant. but I'm worried about destruction of the objects and having "use after free" bugs. In this transition period, we will have both Buffer and HeldBuffer.

if Buffer takes ownership of it, then the underlying xla::PjRtBuffer will be freed when Buffer is GCed. if we have a HeldBuffer around, it will be broken because the ptr to which the shared_ptr tries to point will be already freed.

consider the opposite: HeldBuffer takes ownership of it. the same problem applies: HeldBuffer can be GCed before Buffer and Buffer will be broken because the pointer will already be freed if you try to use it again.

here is the dependency graph: both Buffer and HeldBuffer have references to the same xla::PjRtBuffer object.

flowchart TB
    subgraph Julia
    Buffer
    HeldBuffer
    end
    Buffer --> PjRtBuffer
    HeldBuffer --> shared_ptr
    shared_ptr --> PjRtBuffer
Loading

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO in order to avoid these issues, the first implementation was best:

  • it avoids double-free and use-after-free issues
  • both HeldBuffer and Buffer are separate types which makes incremental transition easier
  • no overheads

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should never have a heldbuffer which points to the same data as a regular buffer. Each should be considered an exclusive owner of its underyling data.

And analagously we should avoid making a held buffer out of an existing pjrt buffer

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that way there would never be any use after free issues

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that means that we would be duplicating data, because PjRtBuffer holds raw data, which will be costly

mmmm or we will need to replicate stuff for HeldBuffer. lemme try this weekend


function Buffer(buffer::Ptr{Cvoid})
return finalizer(free_buffer, new(buffer))
return finalizer(free_buffer, new(buffer, nothing))
end
end

@inline function free_buffer(buffer)
if isnothing(buffer.holded) && buffer.buffer != C_NULL
@ccall MLIR.API.mlir_c.PjRtBufferFree(buffer.buffer::Ptr{Cvoid})::Cvoid
end
end

function hold!(buffer::Buffer)
if buffer.holded == C_NULL
sbuffer = buffer.buffer
buffer.holded = HeldBuffer(
@ccall MLIR.API.mlir_c.reactant_hold_pjrtbuffer(sbuffer::Ptr{Cvoid})::Ptr{Cvoid}
)
end
return buffer
end

function Base.ndims(buffer::Buffer)
Expand Down
18 changes: 16 additions & 2 deletions src/xla/Client.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
mutable struct Client
client::Ptr{Cvoid}
global_ordinals::Vector{Cint}
holded::Ptr{Cvoid}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment here above having a separation

ideally we have
abstract struct Client

struct PJRTClient <: Client
   ...
end
struct IFRTClient <: Client
   ...
end


function Client(client::Ptr{Cvoid})
@assert client != C_NULL
global_ordinals = Cint[]

client = new(client, global_ordinals)
client = new(client, global_ordinals, C_NULL)

# https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L127
devices = [
Expand All @@ -29,7 +30,20 @@ end
Base.:(==)(a::Client, b::Client) = a.client == b.client

@inline function free_client(client::Client)
@ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
if client.holded == C_NULL
@ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
else
@ccall MLIR.API.mlir_c.reactant_release_pjrtclient(client.holded::Ptr{Cvoid})::Cvoid
end
end

function hold!(client::Client)
if client.holded == C_NULL
client.holded = @ccall MLIR.API.mlir_c.reactant_hold_pjrtclient(
client.client::Ptr{Cvoid}
)::Ptr{Cvoid}
end
return client
end

function ClientNumDevices(client::Client)
Expand Down
38 changes: 38 additions & 0 deletions src/xla/IFRT/Array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
@cenum ArrayCopySemantics::UInt32 begin
AlwaysCopy = 0
ReuseInput = 1
DonateInput = 2
end

# currently, only supports IFRT-PjRt
mutable struct Array
ptr::Ptr{Cvoid}

function Array(ptr::Ptr{Cvoid})
@assert ptr != C_NULL
return finalizer(free_array, new(ptr))
end
end

function free_array(array)
@ccall MLIR.API.mlir_c.reactant_release_ifrt_array(array.ptr::Ptr{Cvoid})::Cvoid
end

function Array(client::Client, buffer::XLA.Buffer)
hold!(buffer)
GC.@preserve client buffer begin
return Array(
@ccall MLIR.API.mlir_c.ifrt_pjrt_ArrayFromHostBuffer(
client.ptr::Ptr{Cvoid}, buffer.holded::Ptr{Cvoid}
)::Ptr{Cvoid}
)
end
end

function CopyArrayToHostBuffer(array::Array, data)
GC.@preserve array data begin
@ccall MLIR.API.mlir_c.ifrt_CopyArrayToHostBuffer(
array.ptr::Ptr{Cvoid}, data::Ptr{Cvoid}, AlwaysCopy::Cuint
)::Cvoid
end
end
31 changes: 31 additions & 0 deletions src/xla/IFRT/Client.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# currently, only supports IFRT-PjRt
mutable struct Client
ptr::Ptr{Cvoid}

function Client(ptr::Ptr{Cvoid})
@assert ptr != C_NULL
return finalizer(free_client, new(ptr))
end
end

function Client(pjrt_client::XLA.Client)
# it needs a `std::shared_ptr<xla::PjRtClient>`
hold!(pjrt_client)
return Client(
@ccall MLIR.API.mlir_c.ifrt_pjrt_MakeClient(
pjrt_client.holded::Ptr{Cvoid}
)::Ptr{Cvoid}
)
end

function free_client(client)
@ccall MLIR.API.mlir_c.ifrt_FreeClient(client.ptr::Ptr{Cvoid})::Cvoid
end

function compile(client::Client, code::MLIR.IR.Module)
return LoadedExecutable(
@ccall MLIR.API.mlir_c.ifrt_ClientCompile(
client.ptr::Ptr{Cvoid}, code.module_::MLIR.API.MlirModule
)::Ptr{Cvoid}
)
end
13 changes: 13 additions & 0 deletions src/xla/IFRT/IFRT.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module IFRT

using CEnum

import ..XLA
import .XLA: hold!
import ..MLIR

include("LoadedExecutable.jl")
include("Client.jl")
include("Array.jl")

end
48 changes: 48 additions & 0 deletions src/xla/IFRT/LoadedExecutable.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# currently, only supports IFRT-PjRt
mutable struct LoadedExecutable
ptr::Ptr{Cvoid}

function LoadedExecutable(ptr::Ptr{Cvoid})
@assert ptr != C_NULL
return finalizer(free_exec, new(ptr))
end
end

@inline function free_exec(exec)
@ccall MLIR.API.mlir_c.ifrt_pjrt_FreeLoadedExecutable(exec.ptr::Ptr{Cvoid})::Cvoid
end

function execute(
exec::LoadedExecutable,
args::NTuple{N,Ptr{Cvoid}},
donated_mask::NTuple{N,UInt8},
::Val{n_results},
) where {N,n_results}
results = Ref{NTuple{n_results,Ptr{Cvoid}}}()
has_future = Ref{UInt8}()
status = Ref{NTuple{1,Ptr{Cvoid}}}() # unused right now

args = Base.RefValue(args)
donated_mask = Base.RefValue(donated_mask)

GC.@preserve exec args donated_mask results has_future status begin
@ccall MLIR.API.mlir_c.ifrt_Execute(
exec.ptr::Ptr{Cvoid},
N::Cint,
args::Ptr{Cvoid},
donated_mask::Ptr{Cvoid},
n_results::Cint,
Base.unsafe_convert(Ptr{Cvoid}, results)::Ptr{Cvoid},
has_future::Ptr{Cvoid},
status::Ptr{Cvoid},
)::Cvoid
end

@assert has_future[] == true

results = results[]

return ntuple(Val(n_results)) do i
return Array(results[i])
end
end
2 changes: 2 additions & 0 deletions src/xla/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,6 @@ function __init__()
return nothing
end

include("IFRT/IFRT.jl")

end
Loading