diff --git a/src/Compiler.jl b/src/Compiler.jl index 69df7c356f..86b6c940b2 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -508,13 +508,6 @@ function compile_mlir(f, args; client=nothing, kwargs...) mlir_fn_res = compile_mlir!(mod, f, args; backend, kwargs...) - client, _ = __resolve_device_and_client( - client, - mlir_fn_res.seen_args, - mlir_fn_res.linear_args, - mlir_fn_res.is_sharded, - ) - # Attach a name, and partitioning attributes to the module __add_mhlo_attributes_and_name!( mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas @@ -1509,6 +1502,8 @@ function compile_xla(f, args; client=nothing, kwargs...) num_parameters=length(mlir_fn_res.linear_args), mlir_fn_res.is_sharded, global_device_ids, + mlir_fn_res.num_replicas, + mlir_fn_res.num_partitions, ) return mod, exec, mlir_fn_res, device, client diff --git a/src/Distributed.jl b/src/Distributed.jl index e419347d77..5e7c498777 100644 --- a/src/Distributed.jl +++ b/src/Distributed.jl @@ -8,24 +8,24 @@ function initialize(; coordinator_address::Union{Nothing,String}=nothing, num_processes::Union{Nothing,Integer}=nothing, process_id::Union{Nothing,Integer}=nothing, - local_device_ids::Union{Nothing,Vector{Int}}=nothing, + local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing, initialization_timeout_in_seconds::Integer=300, kwargs..., ) @assert !initialized[] "`Distributed.initialize` has already been called" - (coordinator_address, num_processes, process_id, local_device_ids) = auto_detect_unset_distributed_params(; + (coordinator_address, num_processes, process_id, local_gpu_device_ids) = auto_detect_unset_distributed_params(; coordinator_address, num_processes, process_id, - local_device_ids, + local_gpu_device_ids, initialization_timeout_in_seconds, ) - @debug "Detected Reactant distributed params" coordinator_address num_processes process_id local_device_ids + @debug "Detected Reactant distributed params" coordinator_address num_processes process_id local_gpu_device_ids Reactant.XLA.update_global_state!(; - coordinator_address, num_processes, process_id, local_device_ids, kwargs... + coordinator_address, num_processes, process_id, local_gpu_device_ids, kwargs... 
) @debug "New Global State" Reactant.XLA.global_state @@ -57,14 +57,14 @@ function auto_detect_unset_distributed_params(; coordinator_address::Union{Nothing,String}=nothing, num_processes::Union{Nothing,Integer}=nothing, process_id::Union{Nothing,Integer}=nothing, - local_device_ids::Union{Nothing,Vector{Int}}=nothing, + local_gpu_device_ids::Union{Nothing,Vector{Int}}=nothing, initialization_timeout_in_seconds::Integer=300, ) if all( Base.Fix2(!==, nothing), - (coordinator_address, num_processes, process_id, local_device_ids), + (coordinator_address, num_processes, process_id, local_gpu_device_ids), ) - return coordinator_address, num_processes, process_id, local_device_ids + return coordinator_address, num_processes, process_id, local_gpu_device_ids end idx = findfirst(is_env_present, detector_list) @@ -91,11 +91,11 @@ function auto_detect_unset_distributed_params(; process_id = get_process_id(detector) end - if local_device_ids === nothing - local_device_ids = [get_local_process_id(detector)] + if local_gpu_device_ids === nothing + local_gpu_device_ids = [get_local_process_id(detector)] end - return coordinator_address, num_processes, process_id, local_device_ids + return coordinator_address, num_processes, process_id, local_gpu_device_ids end # OpenMPIORTEEnvDetector & OpenMPIPMIXEnvDetector diff --git a/src/Types.jl b/src/Types.jl index 694f83c6ec..7035bae3de 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -135,11 +135,11 @@ function ConcretePJRTArray( if idx === nothing device = XLA.default_device(client) else - device = XLA.get_addressable_device(client, idx) + device = XLA.get_device(client, idx) end else if idx !== nothing - device_from_idx = XLA.get_addressable_device(client, idx) + device_from_idx = XLA.get_device(client, idx) @assert device_from_idx == device "If both `idx` and `device` are \ specified, `idx` must match `device`" end diff --git a/src/xla/Buffer.jl b/src/xla/Buffer.jl index 8e3c3aac08..091faf7895 100644 --- a/src/xla/Buffer.jl +++ b/src/xla/Buffer.jl @@ -5,6 +5,15 @@ function buffer_on_cpu end function to_host end function unsafe_buffer_pointer end function copy_buffer_to_device end +function sharding end + +Base.convert(::Type{Array}, buffer::AbstractBuffer) = convert(Array{eltype(buffer)}, buffer) + +function Base.convert(::Type{<:Array{T}}, buffer::AbstractBuffer) where {T} + arr = zeros(T, reverse(size(buffer))...) 
+ XLA.to_host(buffer, arr) + return arr +end @inline function client( buffers::Union{Array{<:AbstractBuffer},NTuple{<:Any,AbstractBuffer}} @@ -19,3 +28,48 @@ end ) return map(synced_buffer, buffers) end + +function Base.show(io::IO, mime::MIME"text/plain", buffer::B) where {B<:AbstractBuffer} + print(io, "$(B) storing ") + show(io, mime, convert(Array, buffer)) + return nothing +end + +# Async Buffers +abstract type AbstractAsyncBuffer <: AbstractBuffer end + +Base.isempty(buffer::AbstractAsyncBuffer) = buffer.buffer.buffer == C_NULL + +function Base.convert(T::Type{Array}, buffer::AbstractAsyncBuffer) + XLA.await(buffer) + return convert(T, buffer.buffer) +end + +function Base.convert(T::Type{<:Array{T1}}, buffer::AbstractAsyncBuffer) where {T1} + XLA.await(buffer) + return convert(T, buffer.buffer) +end + +for op in (:(Base.ndims), :(Base.size), :(Base.eltype), :device, :client, :sharding) + @eval $op(buffer::AbstractAsyncBuffer) = $op(buffer.buffer) +end + +function XLA.synced_buffer(buffer::AbstractAsyncBuffer) + XLA.await(buffer) + return buffer.buffer +end + +function XLA.await(buffer::AbstractAsyncBuffer) + buffer.future === nothing && return nothing + future = buffer.future + buffer.future = nothing + XLA.await(future) + return nothing +end + +function XLA.is_ready(buffer::AbstractAsyncBuffer) + buffer.future === nothing && return true + return XLA.is_ready(buffer.future) +end + +XLA.buffer_on_cpu(buffer::AbstractAsyncBuffer) = XLA.buffer_on_cpu(buffer.buffer) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index 7a30909e4e..ccf715c1ba 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -14,55 +14,3 @@ function get_addressable_device end function platform_name end default_device(client::AbstractClient) = first(addressable_devices(client)) - -# Clients for Different Backends -function CPUClient(cfunc, node_id=0, num_nodes=1; asynchronous=true) - f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc)) - client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes) - LLVMclopts("-nvptx-fma-level=1") - return client -end - -function GPUClient( - cfunc, - node_id=0, - num_nodes=1, - platform="gpu"; - allowed_devices::Union{Nothing,Vector{Int}}=nothing, - distributed_runtime_client::Union{Nothing,DistributedRuntimeClient}=nothing, -) - f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc)) - refstr = Ref{Cstring}() - - num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices) - allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices - distributed_runtime_client = - distributed_runtime_client === nothing ? 
C_NULL : distributed_runtime_client.client - - client = ccall( - f, - Ptr{Cvoid}, - (Cint, Cint, Ptr{Cvoid}, Cint, Cdouble, Bool, Cstring, Ptr{Cstring}, Ptr{Cvoid}), - node_id, - num_nodes, - allowed_devices, - num_allowed_devices, - XLA_REACTANT_GPU_MEM_FRACTION[], - false, - platform, - refstr, - distributed_runtime_client, - ) - client == C_NULL && throw(AssertionError(unsafe_string(refstr[]))) - LLVMclopts("-nvptx-fma-level=1") - return client -end - -function TPUClient(cfunc, tpu_path::String) - f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc)) - refstr = Ref{Cstring}() - client = ccall(f, Ptr{Cvoid}, (Cstring, Ptr{Cstring}), tpu_path, refstr) - client == C_NULL && throw(AssertionError(unsafe_string(refstr[]))) - LLVMclopts("-nvptx-fma-level=1") - return client -end diff --git a/src/xla/Device.jl b/src/xla/Device.jl index 84c4cc7326..f4b27eaa30 100644 --- a/src/xla/Device.jl +++ b/src/xla/Device.jl @@ -10,19 +10,15 @@ function get_local_device_id end function device_kind end function default_memory end function memories end +function is_addressable end """ device_ordinal(device::Device) - device_ordinal(client::XLA.AbstractClient, local_device_id::Int) -Given the device or local device id, return the corresponding global device ordinal in the client. +Given the device, return the corresponding global device ordinal in the client. """ function device_ordinal end -function device_ordinal(client::AbstractClient, local_device_id::Integer) - return device_ordinal(get_addressable_device(client, local_device_id)) -end - function Base.string(device::AbstractDevice) client = XLA.client(device) pname = XLA.platform_name(client) diff --git a/src/xla/Distributed.jl b/src/xla/Distributed.jl index f9e89807de..791d3cdc11 100644 --- a/src/xla/Distributed.jl +++ b/src/xla/Distributed.jl @@ -106,7 +106,7 @@ end @kwdef mutable struct State process_id::Int = 0 num_processes::Int = 1 - local_device_ids::Union{Nothing,Vector{Int}} = nothing + local_gpu_device_ids::Union{Nothing,Vector{Int}} = nothing service::Union{Nothing,DistributedRuntimeService} = nothing client::Union{Nothing,DistributedRuntimeClient} = nothing coordinator_address::Union{Nothing,String} = nothing @@ -129,7 +129,7 @@ function update!( coordinator_address::String, num_processes::Int, process_id::Int, - local_device_ids::Vector{Int}, + local_gpu_device_ids::Vector{Int}, coordinator_bind_address::Union{Nothing,String}=nothing, cluster_register_timeout_in_minutes::Integer=60, rpc_timeout_in_seconds::Integer=120, @@ -141,7 +141,7 @@ function update!( @assert 0 ≤ process_id < num_processes state.coordinator_address = coordinator_address - state.local_device_ids = local_device_ids + state.local_gpu_device_ids = local_gpu_device_ids state.process_id = process_id state.num_processes = num_processes diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl new file mode 100644 index 0000000000..dd0e596db8 --- /dev/null +++ b/src/xla/IFRT/Array.jl @@ -0,0 +1,212 @@ +mutable struct Array <: XLA.AbstractBuffer + buffer::Ptr{Cvoid} + + function Array(buffer::Ptr{Cvoid}) + # return finalizer(free_ifrt_array, new(buffer)) + return new(buffer) + end +end + +function Array( + client::Client, + array::Base.Array{T,N}, + device::Device=XLA.default_device(client), + memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))), +) where {T,N} + sizear = collect(Int64, reverse(size(array))) + buffer = GC.@preserve array sizear begin + @ccall MLIR.API.mlir_c.ifrt_client_make_single_shard_array_from_host_buffer( + 
client.client::Ptr{Cvoid}, + array::Ptr{T}, + XLA.primitive_type(T)::UInt64, + N::Csize_t, + sizear::Ptr{Int64}, + 0::Cint, # kAlwaysCopy + device.device::Ptr{Cvoid}, + string(memory_kind)::Cstring, + )::Ptr{Cvoid} + end + return Array(buffer) +end + +function Array(client::Client, array::Base.Array{T,N}, sharding::Sharding) where {T,N} + sizear = collect(Int64, reverse(size(array))) + + if is_single_device_sharding(sharding) || is_fully_replicated(sharding) + buffer = GC.@preserve array sizear begin + @ccall MLIR.API.mlir_c.ifrt_client_make_array_from_host_buffer( + client.client::Ptr{Cvoid}, + array::Ptr{T}, + XLA.primitive_type(T)::Cint, + N::Csize_t, + sizear::Ptr{Int64}, + sharding.ptr::Ptr{Cvoid}, + 0::Cint, # kAlwaysCopy + )::Ptr{Cvoid} + end + return Array(buffer) + end + + all_devices = XLA.devices(sharding) + array_slices = XLA.sharding_to_concrete_array_indices( + convert(XLA.HloSharding, sharding), + size(array), + collect(Int64, 0:(length(all_devices) - 1)), + ) + array_shape = collect(Int64, reverse(size(array))) + arrays_list = [ + Array(client, array[slice...], device).buffer for + (device, slice) in zip(all_devices, array_slices) if XLA.is_addressable(device) + ] + + buffer = GC.@preserve client arrays_list array_shape sharding begin + @ccall MLIR.API.mlir_c.ifrt_client_assemble_array_from_single_shards( + client.client::Ptr{Cvoid}, + Int32(length(array_shape))::Int32, + array_shape::Ptr{Int64}, + sharding.ptr::Ptr{Cvoid}, + Int32(length(arrays_list))::Int32, + arrays_list::Ptr{Ptr{Cvoid}}, + 2::Cint, # kDonateInput + )::Ptr{Cvoid} + end + + return Array(buffer) +end + +function Array(client::Client, array::Base.Array{T,N}, sharding) where {T,N} + @assert sharding isa Reactant.Sharding.AbstractSharding + if !(sharding isa Reactant.Sharding.HloSharding) + sharding = convert(Reactant.Sharding.HloSharding, sharding) + end + + (; hlo_sharding, mesh) = sharding + devices = XLA.get_device.((client,), mesh.device_ids) + ifrt_sharding = Sharding([devices...], hlo_sharding) + + return Array(client, array, ifrt_sharding) +end + +@inline function free_ifrt_array(buffer::Array) + sbuffer = buffer.buffer + if sbuffer != C_NULL + @ccall MLIR.API.mlir_c.ifrt_free_array(sbuffer::Ptr{Cvoid})::Cvoid + end +end + +function Base.ndims(buffer::Array) + GC.@preserve buffer begin + return @ccall MLIR.API.mlir_c.ifrt_array_ndims(buffer.buffer::Ptr{Cvoid})::Int64 + end +end + +function Base.size(buffer::Array) + GC.@preserve buffer begin + sz = @ccall MLIR.API.mlir_c.ifrt_array_shape(buffer.buffer::Ptr{Cvoid})::Ptr{Int64} + end + return Tuple(unsafe_wrap(Base.Array, sz, ndims(buffer))) +end + +function Base.eltype(buffer::Array) + GC.@preserve buffer begin + return XLA.julia_type( + @ccall MLIR.API.mlir_c.ifrt_array_eltype(buffer.buffer::Ptr{Cvoid})::Cint + ) + end +end + +function XLA.device(::Array) + return error("IFRT.Array can be sharded/replicated across multiple devices. 
Hence, \ + `XLA.device` is not defined.") +end + +function XLA.client(buffer::Array) + GC.@preserve buffer begin + return Client( + @ccall MLIR.API.mlir_c.ifrt_array_to_client( + buffer.buffer::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +XLA.synced_buffer(buffer::Array) = buffer + +function XLA.buffer_on_cpu(::Array) + return error("IFRT.Array does not support `XLA.buffer_on_cpu`") +end + +function XLA.to_host(buffer::Array, data) + sharding = XLA.sharding(buffer) + all_devices = XLA.devices(sharding) + + if length(all_devices) == 1 + GC.@preserve buffer data begin + @ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer( + buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid} + )::Cvoid + end + return nothing + end + + if any(!is_addressable, all_devices) + @warn "Not all devices are addressable. Currently we only fill in the data for \ + addressable devices. Remaining slices of data in `data` are left \ + untouched." + end + + # While some client implementations might support directly copying to host, but we + # avoid the complexity of supporting that for now. + single_device_arrays = disassemble_into_single_device_arrays(buffer, true) + + array_slices = XLA.sharding_to_concrete_array_indices( + convert(XLA.HloSharding, sharding), + size(data), + collect(Int64, 0:(length(all_devices) - 1)), + ) + array_slices = [ + slice for + (device, slice) in zip(all_devices, array_slices) if XLA.is_addressable(device) + ] + + @assert length(array_slices) == length(single_device_arrays) + + for (slice, arr) in zip(array_slices, single_device_arrays) + data_slice = data[slice...] + XLA.to_host(arr, data_slice) + data[slice...] .= data_slice + end + return nothing +end + +function disassemble_into_single_device_arrays(array::Array, only_addressable_devices::Bool) + c_single_device_shard_semantics = Int32(!only_addressable_devices) + narrays = Ref{Int32}(0) + arrays = GC.@preserve array begin + @ccall MLIR.API.mlir_c.ifrt_array_disassemble_into_single_device_arrays( + array.buffer::Ptr{Cvoid}, + Int32(0)::Int32, + c_single_device_shard_semantics::Int32, + narrays::Ptr{Int32}, + )::Ptr{Ptr{Cvoid}} + end + return [Array(unsafe_load(arrays, i)) for i in 1:narrays[]] +end + +function XLA.unsafe_buffer_pointer(::Array) + return error("IFRT.Array does not support `XLA.unsafe_buffer_pointer`") +end + +function XLA.copy_buffer_to_device(::Array, ::Device) + return error("IFRT.Array does not support `XLA.copy_buffer_to_device`") +end + +function XLA.sharding(buffer::Array) + GC.@preserve buffer begin + return Sharding( + @ccall MLIR.API.mlir_c.ifrt_array_to_sharding( + buffer.buffer::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end diff --git a/src/xla/IFRT/AsyncArray.jl b/src/xla/IFRT/AsyncArray.jl new file mode 100644 index 0000000000..b049289b13 --- /dev/null +++ b/src/xla/IFRT/AsyncArray.jl @@ -0,0 +1,8 @@ +mutable struct AsyncArray <: XLA.AbstractAsyncBuffer + buffer::Array + future::Union{Future,Nothing} +end + +const AsyncEmptyArray = AsyncArray(Array(C_NULL), nothing) + +AsyncArray(args...; kwargs...) 
= AsyncArray(Array(args...; kwargs...), nothing) diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl new file mode 100644 index 0000000000..356a151401 --- /dev/null +++ b/src/xla/IFRT/Client.jl @@ -0,0 +1,192 @@ +mutable struct Client <: XLA.AbstractClient + client::Ptr{Cvoid} + + function Client(client::Ptr{Cvoid}; skip_check::Bool=false) + skip_check || (@assert client != C_NULL) + return new(client) + end +end + +function XLA.free_client(client::Client) + GC.@preserve client begin + @ccall MLIR.API.mlir_c.ifrt_FreeClient(client.client::Ptr{Cvoid})::Cvoid + end +end + +function XLA.num_devices(client::Client) + GC.@preserve client begin + return @ccall MLIR.API.mlir_c.ifrt_client_device_count( + client.client::Ptr{Cvoid} + )::Cint + end +end + +function XLA.num_addressable_devices(client::Client) + GC.@preserve client begin + return @ccall MLIR.API.mlir_c.ifrt_client_addressable_device_count( + client.client::Ptr{Cvoid} + )::Cint + end +end + +function XLA.process_index(client::Client) + GC.@preserve client begin + return @ccall MLIR.API.mlir_c.ifrt_ClientProcessIndex( + client.client::Ptr{Cvoid} + )::Cint + end +end + +function XLA.get_device(client::Client, idx) + GC.@preserve client begin + return Device( + @ccall MLIR.API.mlir_c.ifrt_client_lookup_device( + client.client::Ptr{Cvoid}, idx::Cint + )::Ptr{Cvoid} + ) + end +end + +function XLA.get_addressable_device(client::Client, idx) + GC.@preserve client begin + return Device( + @ccall MLIR.API.mlir_c.ifrt_client_lookup_addressable_device( + client.client::Ptr{Cvoid}, idx::Cint + )::Ptr{Cvoid} + ) + end +end + +function XLA.platform_name(client::Client) + GC.@preserve client begin + str = @ccall MLIR.API.mlir_c.ifrt_ClientGetPlatformName( + client.client::Ptr{Cvoid} + )::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +function XLA.devices(client::Client) + ndevices = Int(XLA.num_devices(client)) + devices = Ref{NTuple{ndevices,Ptr{Cvoid}}}() + GC.@preserve client devices begin + @ccall MLIR.API.mlir_c.ifrt_client_devices( + client.client::Ptr{Cvoid}, devices::Ptr{Ptr{Cvoid}} + )::Cvoid + end + return [Device(device) for device in devices[]] +end + +function XLA.addressable_devices(client::Client) + naddressable_devices = Int(XLA.num_addressable_devices(client)) + addressable_devices = Ref{NTuple{naddressable_devices,Ptr{Cvoid}}}() + GC.@preserve client addressable_devices begin + @ccall MLIR.API.mlir_c.ifrt_client_addressable_devices( + client.client::Ptr{Cvoid}, addressable_devices::Ptr{Ptr{Cvoid}} + )::Cvoid + end + return [Device(device) for device in addressable_devices[]] +end + +# Different Backends +const cpu_client_count = Ref(0) +const gpu_client_count = Ref(0) +const tpu_client_count = Ref(0) + +for (backend, counter) in ( + (:CPUClient, :cpu_client_count), + (:GPUClient, :gpu_client_count), + (:TPUClient, :tpu_client_count), +) + main_fn = Symbol(:MakeIFRTPJRT, backend) + @eval function $(backend)(args...; checkcount::Bool=true, kwargs...) + if checkcount + @assert $(counter)[] == 0 + end + client, refstr = $(main_fn)(args...; kwargs...) 
+ client == C_NULL && throw(AssertionError(unsafe_string(refstr[]))) + XLA.LLVMclopts("-nvptx-fma-level=1") + if checkcount + # Only increment the counter if we successfully created a client + $(counter)[] += 1 + end + return Client(client) + end +end + +function MakeIFRTPJRTCPUClient(; + node_id::Integer=0, + num_nodes::Integer=1, + asynchronous::Bool=true, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + refstr = Ref{Cstring}() + distributed_runtime_client = + distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client + + GC.@preserve refstr distributed_runtime_client begin + client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_cpu_client( + asynchronous::UInt8, + node_id::Cint, + num_nodes::Cint, + distributed_runtime_client::Ptr{Cvoid}, + refstr::Ptr{Cstring}, + )::Ptr{Cvoid} + end + + return client, refstr +end + +function MakeIFRTPJRTGPUClient(; + node_id::Integer=0, + num_nodes::Integer=1, + platform::String="gpu", + allowed_devices::Union{Nothing,Vector{Int}}=nothing, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + refstr = Ref{Cstring}() + + num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices) + allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices + distributed_runtime_client = + distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client + + GC.@preserve refstr allowed_devices distributed_runtime_client begin + client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_gpu_client( + node_id::Cint, + num_nodes::Cint, + allowed_devices::Ptr{Cvoid}, + num_allowed_devices::Cint, + XLA.XLA_REACTANT_GPU_MEM_FRACTION[]::Cdouble, + XLA.XLA_REACTANT_GPU_PREALLOCATE[]::Bool, + platform::Cstring, + refstr::Ptr{Cstring}, + distributed_runtime_client::Ptr{Cvoid}, + )::Ptr{Cvoid} + end + + return client, refstr +end + +function MakeIFRTPJRTTPUClient(; + tpu_path::String, + node_id::Integer=0, + num_nodes::Integer=1, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + refstr = Ref{Cstring}() + distributed_runtime_client = + distributed_runtime_client === nothing ? 
C_NULL : distributed_runtime_client.client + + GC.@preserve refstr distributed_runtime_client begin + client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_tpu_client( + tpu_path::Cstring, + refstr::Ptr{Cstring}, + node_id::Cint, + num_nodes::Cint, + distributed_runtime_client::Ptr{Cvoid}, + )::Ptr{Cvoid} + end + + return client, refstr +end diff --git a/src/xla/IFRT/Device.jl b/src/xla/IFRT/Device.jl new file mode 100644 index 0000000000..e550fd3013 --- /dev/null +++ b/src/xla/IFRT/Device.jl @@ -0,0 +1,74 @@ +struct Device <: XLA.AbstractDevice + device::Ptr{Cvoid} +end + +function XLA.client(device::Device) + GC.@preserve device begin + return Client( + @ccall MLIR.API.mlir_c.ifrt_DeviceToClient( + device.device::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function XLA.device_ordinal(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.ifrt_DeviceGetGlobalDeviceId( + device.device::Ptr{Cvoid} + )::Int64 + end +end + +function XLA.device_kind(device::Device) + GC.@preserve device begin + str = @ccall MLIR.API.mlir_c.ifrt_DeviceGetKind(device.device::Ptr{Cvoid})::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +function XLA.get_local_device_id(::Device) + return error("Not implemented for ifrt devices") +end + +function XLA.default_memory(device::Device) + GC.@preserve device begin + return Memory( + @ccall MLIR.API.mlir_c.ifrt_DeviceGetDefaultMemory( + device.device::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function XLA.memories(device::Device) + memories_size = Ref{Int32}(0) + GC.@preserve device memories_size begin + ptr = @ccall MLIR.API.mlir_c.ifrt_DeviceGetMemories( + device.device::Ptr{Cvoid}, memories_size::Ptr{Int32} + )::Ptr{Ptr{Cvoid}} + end + return [Memory(unsafe_load(ptr, i)) for i in 1:memories_size[]] +end + +# TODO: https://github.com/openxla/xla/blob/ad0814d221883609f784e57dd26914b17f92fbbc/xla/python/ifrt/sharding.cc#L60 +function XLA.default_memory(device_list::AbstractVector{Device}) + default_memories = XLA.default_memory.(device_list) + default_memory_kinds = convert.(MemoryKind, default_memories) + @assert allequal(default_memory_kinds) "All devices must have the same default memory" + return first(default_memories) +end + +function XLA.client(device_list::AbstractVector{Device}) + clients = XLA.client.(device_list) + @assert allequal(clients) "All devices must have the same client" + return first(clients) +end + +function XLA.is_addressable(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.ifrt_DeviceIsAddressable( + device.device::Ptr{Cvoid} + )::Bool + end +end diff --git a/src/xla/IFRT/Future.jl b/src/xla/IFRT/Future.jl new file mode 100644 index 0000000000..bbb8128dbe --- /dev/null +++ b/src/xla/IFRT/Future.jl @@ -0,0 +1,27 @@ +mutable struct Future <: XLA.AbstractFuture + future::Ptr{Cvoid} + + function Future(future::Ptr{Cvoid}) + @assert future != C_NULL + return finalizer(free_future, new(future)) + end +end + +@inline function free_future(future::Future) + @ccall MLIR.API.mlir_c.ifrt_free_future(future.future::Ptr{Cvoid})::Cvoid +end + +function XLA.is_ready(future::Future) + GC.@preserve future begin + return (@ccall MLIR.API.mlir_c.ifrt_future_is_ready( + future.future::Ptr{Cvoid} + )::UInt8) != 0 + end +end + +@inline function XLA.await(future::Future)::Nothing + GC.@preserve future begin + @ccall MLIR.API.mlir_c.ifrt_future_await(future.future::Ptr{Cvoid})::Cvoid + end + return nothing +end diff --git a/src/xla/IFRT/IFRT.jl b/src/xla/IFRT/IFRT.jl new file mode 100644 index 
0000000000..cdca7a520c --- /dev/null +++ b/src/xla/IFRT/IFRT.jl @@ -0,0 +1,15 @@ +module IFRT + +using ..Reactant: Reactant, MLIR +using ..XLA: XLA + +include("Client.jl") +include("Device.jl") +include("Memory.jl") +include("Future.jl") +include("Sharding.jl") +include("Array.jl") +include("AsyncArray.jl") +include("LoadedExecutable.jl") + +end diff --git a/src/xla/IFRT/LoadedExecutable.jl b/src/xla/IFRT/LoadedExecutable.jl new file mode 100644 index 0000000000..d126119d70 --- /dev/null +++ b/src/xla/IFRT/LoadedExecutable.jl @@ -0,0 +1,133 @@ +mutable struct LoadedExecutable <: XLA.AbstractLoadedExecutable + exec::Ptr{Cvoid} + num_outputs::Int64 + num_parameters::Int64 + is_sharded::Bool + num_replicas::Int64 + num_partitions::Int64 + + function LoadedExecutable(exec::Ptr{Cvoid}, args...) + @assert exec != C_NULL + return finalizer(free_exec, new(exec, args...)) + end +end + +function free_exec(exec::LoadedExecutable) + GC.@preserve exec begin + @ccall MLIR.API.mlir_c.ifrt_loaded_executable_dtor(exec.exec::Ptr{Cvoid})::Cvoid + end +end + +function XLA.client(exec::LoadedExecutable) + GC.@preserve exec begin + return Client( + @ccall MLIR.API.mlir_c.ifrt_loaded_executable_client( + exec.exec::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +XLA.num_partitions(exec::LoadedExecutable) = exec.num_partitions +XLA.num_replicas(exec::LoadedExecutable) = exec.num_replicas +XLA.num_devices(exec::LoadedExecutable) = XLA.num_replicas(exec) * XLA.num_partitions(exec) + +for (jlop, xlaop, field) in ( + (:get_output_shardings, :ifrt_loaded_executable_get_output_shardings, :num_outputs), + ( + :get_parameter_shardings, + :ifrt_loaded_executable_get_parameter_shardings, + :num_parameters, + ), +) + @eval function XLA.$(jlop)(exec::LoadedExecutable) + exec.is_sharded || return XLA.OpSharding[] + + op_shardings = Ref{NTuple{exec.$(field),Ptr{Cvoid}}}() + + GC.@preserve exec op_shardings begin + @ccall MLIR.API.mlir_c.$(xlaop)( + exec.exec::Ptr{Cvoid}, op_shardings::Ptr{Ptr{Cvoid}}, exec.$(field)::Cint + )::Cvoid + end + + return [XLA.OpSharding(op_sharding) for op_sharding in op_shardings[]] + end +end + +function XLA.get_hlo_modules(exec::LoadedExecutable) + # If we had compiled with MPMD then we would need all the partitions to get hlo_modules + # but if we used SPMD we get only 1 module. To be safe we allocate for all the modules + # and use the ones assigned to by XLA + hlo_modules = Ref{NTuple{Int64(XLA.num_partitions(exec)),Ptr{Cvoid}}}() + nmodules = Ref{Int32}(0) + GC.@preserve exec hlo_modules begin + @ccall MLIR.API.mlir_c.ifrt_loaded_executable_get_hlo_modules( + exec.exec::Ptr{Cvoid}, hlo_modules::Ptr{Ptr{Cvoid}}, nmodules::Ptr{Int32} + )::Cvoid + end + return map(XLA.HloModule, hlo_modules[][1:Int(nmodules[])]) +end + +function XLA.compile( + client::Client, + device::Union{Device,Nothing}, + mod::MLIR.IR.Module; + is_sharded::Bool=false, + global_device_ids::Vector{Int64}=Int64[], + num_outputs::Int64, + num_parameters::Int64, + num_replicas::Int64, + num_partitions::Int64, +) + device_id = is_sharded ? 
Int64(-1) : Int64(XLA.device_ordinal(device)) + GC.@preserve client mod begin + exec = @ccall MLIR.API.mlir_c.ifrt_compile( + client.client::Ptr{Cvoid}, + mod.module_::MLIR.API.MlirModule, + device_id::Clong, + is_sharded::Bool, + global_device_ids::Ptr{Clong}, + length(global_device_ids)::Clong, + XLA.CUDA_DATA_DIR[]::Cstring, + )::Ptr{Cvoid} + end + return LoadedExecutable( + exec, num_outputs, num_parameters, is_sharded, num_replicas, num_partitions + ) +end + +@inline function XLA.execute( + exec::LoadedExecutable, + inputs::NTuple{N,Ptr{Cvoid}}, + donated_args::NTuple{M,UInt8}, + ::Val{n_outs}, +) where {N,M,n_outs} + outputs = Ref{NTuple{n_outs,Ptr{Cvoid}}}() + future_res = Ref{Ptr{Cvoid}}() + futures = Ref{UInt8}(0) + + inputs = Base.RefValue(inputs) + donated_args = Base.RefValue(donated_args) + GC.@preserve exec outputs future_res futures begin + @ccall MLIR.API.mlir_c.ifrt_loaded_executable_execute( + exec.exec::Ptr{Cvoid}, + N::Cint, + inputs::Ptr{Ptr{Cvoid}}, + donated_args::Ptr{UInt8}, + n_outs::Cint, + Base.unsafe_convert(Ptr{Ptr{Cvoid}}, outputs)::Ptr{Ptr{Cvoid}}, + futures::Ptr{UInt8}, + future_res::Ptr{Ptr{Cvoid}}, + )::Cvoid + end + + outputs = outputs[] + future = futures[] != 0 + future && (future_res[] = future_res[]) + + return ntuple(n_outs) do i + Base.@_inline_meta + AsyncArray(Array(outputs[i]), future ? Future(future_res[]) : nothing) + end +end diff --git a/src/xla/IFRT/Memory.jl b/src/xla/IFRT/Memory.jl new file mode 100644 index 0000000000..b2f9575667 --- /dev/null +++ b/src/xla/IFRT/Memory.jl @@ -0,0 +1,56 @@ +mutable struct Memory <: XLA.AbstractMemory + ptr::Ptr{Cvoid} +end + +function Base.show(io::IO, memory::Memory) + GC.@preserve memory begin + str = @ccall MLIR.API.mlir_c.ifrt_MemoryToString(memory.ptr::Ptr{Cvoid})::Cstring + end + print(io, "XLA.IFRT.Memory(\"", XLA.unsafe_string_and_free(str), "\")") + return nothing +end + +mutable struct MemoryKind <: XLA.AbstractMemoryKind + ptr::Ptr{Cvoid} +end + +function MemoryKind(str::AbstractString) + str = string(str) + GC.@preserve str begin + return MemoryKind( + @ccall MLIR.API.mlir_c.ifrt_memory_kind_from_string(str::Cstring)::Ptr{Cvoid} + ) + end +end + +function Base.convert(::Type{MemoryKind}, memory::Memory) + GC.@preserve memory begin + return MemoryKind( + @ccall MLIR.API.mlir_c.ifrt_MemoryGetMemoryKind( + memory.ptr::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function Base.:(==)(a::MemoryKind, b::MemoryKind) + GC.@preserve a b begin + return @ccall MLIR.API.mlir_c.ifrt_MemoryKindsAreEqual( + a.ptr::Ptr{Cvoid}, b.ptr::Ptr{Cvoid} + )::Bool + end +end + +function Base.string(memory_kind::MemoryKind) + GC.@preserve memory_kind begin + str = @ccall MLIR.API.mlir_c.ifrt_MemoryKindToString( + memory_kind.ptr::Ptr{Cvoid} + )::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +function Base.show(io::IO, memory_kind::MemoryKind) + print(io, "XLA.IFRT.MemoryKind(\"", string(memory_kind), "\")") + return nothing +end diff --git a/src/xla/IFRT/Sharding.jl b/src/xla/IFRT/Sharding.jl new file mode 100644 index 0000000000..01d9780ed3 --- /dev/null +++ b/src/xla/IFRT/Sharding.jl @@ -0,0 +1,170 @@ +# xla::ifrt::HloSharding (distinct from xla::HloSharding) +mutable struct HloSharding + ptr::Ptr{Cvoid} + + function HloSharding(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + # return finalizer(free_hlo_sharding, new(ptr)) + return new(ptr) + end +end + +function free_hlo_sharding(hlo_sharding::HloSharding) + @ccall MLIR.API.mlir_c.free_ifrt_hlo_sharding(hlo_sharding.ptr::Ptr{Cvoid})::Cvoid +end + +function 
Base.convert(::Type{XLA.HloSharding}, sharding::HloSharding) + GC.@preserve sharding begin + return XLA.HloSharding( + @ccall MLIR.API.mlir_c.ifrt_hlo_sharding_to_xla_hlo_sharding( + sharding.ptr::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function HloSharding( + device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding +) + addressable_devices = filter(XLA.is_addressable, device_list) + default_memory_kind = convert(MemoryKind, XLA.default_memory(addressable_devices)) + return HloSharding(device_list, xla_hlo_sharding, default_memory_kind) +end + +function HloSharding( + device_list::AbstractVector{<:Device}, + xla_hlo_sharding::XLA.HloSharding, + memoy_kind::AbstractString, +) + return HloSharding(device_list, xla_hlo_sharding, MemoryKind(memoy_kind)) +end + +function HloSharding( + device_list::AbstractVector{<:Device}, + xla_hlo_sharding::XLA.HloSharding, + memory_kind::MemoryKind, +) + client = XLA.client(device_list) + GC.@preserve device_list memory_kind xla_hlo_sharding client begin + return HloSharding( + @ccall MLIR.API.mlir_c.ifrt_hlo_sharding_from_xla_hlo_sharding( + client.client::Ptr{Cvoid}, + [d.device for d in device_list]::Ptr{Ptr{Cvoid}}, + length(device_list)::Int32, + memory_kind.ptr::Ptr{Cvoid}, + xla_hlo_sharding.ptr::Ptr{Cvoid}, + )::Ptr{Cvoid} + ) + end +end + +function Base.string(hlo_sharding::HloSharding) + GC.@preserve hlo_sharding begin + str = @ccall MLIR.API.mlir_c.ifrt_hlo_sharding_to_string( + hlo_sharding.ptr::Ptr{Cvoid} + )::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +function Base.show(io::IO, ::MIME"text/plain", hlo_sharding::HloSharding) + print(io, "XLA.IFRT.HloSharding(\"", string(hlo_sharding), "\")") + return nothing +end + +# HloSharding is more specific than Sharding. But Sharding is a neater way to deal with +# most of the IFRT APIs. 
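As a rough illustration of how the two wrappers described in the comment above relate in practice (a sketch, not lines from this patch): it assumes `using Reactant.XLA: XLA, IFRT`, an `IFRT.Client` named `client`, and an `xla_hlo_sharding::XLA.HloSharding` obtained elsewhere, for example from the `hlo_sharding` field of a `Reactant.Sharding.HloSharding` as in the `IFRT.Array` constructor earlier in this patch:

    devices = XLA.devices(client)                         # Vector{IFRT.Device} known to the client
    ifrt_sharding = IFRT.Sharding(devices, xla_hlo_sharding)   # built via IFRT.HloSharding internally
    ifrt_hlo = convert(IFRT.HloSharding, ifrt_sharding)        # back to the more specific wrapper
    xla_hlo  = convert(XLA.HloSharding, ifrt_sharding)         # and down to the plain XLA sharding
    IFRT.is_fully_replicated(ifrt_sharding), XLA.devices(ifrt_sharding)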
+mutable struct Sharding + ptr::Ptr{Cvoid} + + function Sharding(ptr::Ptr{Cvoid}) + @assert ptr != C_NULL + # return finalizer(free_sharding, new(ptr)) + return new(ptr) + end +end + +function Sharding(device_list::AbstractVector{<:Device}, xla_hlo_sharding::XLA.HloSharding) + return convert(Sharding, HloSharding(device_list, xla_hlo_sharding)) +end + +function Sharding( + device_list::AbstractVector{<:Device}, + xla_hlo_sharding::XLA.HloSharding, + memoy_kind::Union{AbstractString,MemoryKind}, +) + return convert(Sharding, HloSharding(device_list, xla_hlo_sharding, memoy_kind)) +end + +function free_sharding(sharding::Sharding) + @ccall MLIR.API.mlir_c.free_ifrt_sharding(sharding.ptr::Ptr{Cvoid})::Cvoid +end + +function XLA.devices(sharding::Sharding) + GC.@preserve sharding begin + ndevices = @ccall MLIR.API.mlir_c.ifrt_sharding_devices_size( + sharding.ptr::Ptr{Cvoid} + )::Int32 + end + devices = Ref{NTuple{Int64(ndevices),Ptr{Cvoid}}}() + GC.@preserve sharding devices begin + @ccall MLIR.API.mlir_c.ifrt_sharding_to_device_list( + sharding.ptr::Ptr{Cvoid}, devices::Ptr{Ptr{Cvoid}} + )::Cvoid + end + return [Device(device) for device in devices[]] +end + +function Base.convert(::Type{Sharding}, hlo_sharding::HloSharding) + GC.@preserve hlo_sharding begin + return Sharding( + @ccall MLIR.API.mlir_c.ifrt_sharding_from_ifrt_hlo_sharding( + hlo_sharding.ptr::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function Base.convert(::Type{HloSharding}, sharding::Sharding) + GC.@preserve sharding begin + return HloSharding( + @ccall MLIR.API.mlir_c.ifrt_sharding_to_ifrt_hlo_sharding( + sharding.ptr::Ptr{Cvoid} + )::Ptr{Cvoid} + ) + end +end + +function Base.convert(::Type{XLA.HloSharding}, sharding::Sharding) + return convert(XLA.HloSharding, convert(HloSharding, sharding)) +end + +function Base.string(sharding::Sharding) + GC.@preserve sharding begin + str = @ccall MLIR.API.mlir_c.ifrt_sharding_to_string( + sharding.ptr::Ptr{Cvoid} + )::Cstring + end + return XLA.unsafe_string_and_free(str) +end + +function is_fully_replicated(sharding::Sharding) + GC.@preserve sharding begin + return @ccall MLIR.API.mlir_c.ifrt_sharding_is_fully_replicated( + sharding.ptr::Ptr{Cvoid} + )::Bool + end +end + +function is_single_device_sharding(sharding::Sharding) + GC.@preserve sharding begin + return @ccall MLIR.API.mlir_c.ifrt_sharding_is_single_device_sharding( + sharding.ptr::Ptr{Cvoid} + )::Bool + end +end + +function Base.show(io::IO, ::MIME"text/plain", sharding::Sharding) + print(io, "XLA.IFRT.Sharding(\"", string(sharding), "\")") + return nothing +end diff --git a/src/xla/LoadedExecutable.jl b/src/xla/LoadedExecutable.jl index 8f79a004a7..55efd21b66 100644 --- a/src/xla/LoadedExecutable.jl +++ b/src/xla/LoadedExecutable.jl @@ -2,6 +2,7 @@ abstract type AbstractLoadedExecutable end function num_replicas end function num_partitions end +function num_devices end function get_hlo_modules end function get_output_shardings end function get_parameter_shardings end diff --git a/src/xla/PJRT/AsyncBuffer.jl b/src/xla/PJRT/AsyncBuffer.jl index 613edcdaf5..9cec91eba8 100644 --- a/src/xla/PJRT/AsyncBuffer.jl +++ b/src/xla/PJRT/AsyncBuffer.jl @@ -1,44 +1,8 @@ -mutable struct AsyncBuffer <: XLA.AbstractBuffer +mutable struct AsyncBuffer <: XLA.AbstractAsyncBuffer buffer::Buffer future::Union{Future,Nothing} end const AsyncEmptyBuffer = AsyncBuffer(Buffer(C_NULL), nothing) -function AsyncBuffer(client::Client, array::Array{T,N}, device::Device) where {T,N} - return AsyncBuffer(Buffer(client, array, device), nothing) -end - 
-Base.isempty(buffer::AsyncBuffer) = buffer == AsyncEmptyBuffer - -function Base.convert(::Type{<:Array{T}}, buffer::AsyncBuffer) where {T} - XLA.await(buffer) - return convert(Array{T}, buffer.buffer) -end - -for op in (:(Base.ndims), :(Base.size), :device, :client) - @eval $op(buffer::AsyncBuffer) = $op(buffer.buffer) -end - -function XLA.synced_buffer(buffer::AsyncBuffer) - XLA.await(buffer) - return buffer.buffer -end - -function XLA.await(buffer::AsyncBuffer) - buffer.future === nothing && return nothing - future = buffer.future - buffer.future = nothing - XLA.await(future) - return nothing -end - -function XLA.is_ready(buffer::AsyncBuffer) - buffer.future === nothing && return true - return XLA.is_ready(buffer.future) -end - -XLA.buffer_on_cpu(buffer::AsyncBuffer) = XLA.buffer_on_cpu(buffer.buffer) - -XLA.client(buffer::AsyncBuffer) = XLA.client(buffer.buffer) -XLA.device(buffer::AsyncBuffer) = XLA.device(buffer.buffer) +AsyncBuffer(args...; kwargs...) = AsyncBuffer(Buffer(args...; kwargs...), nothing) diff --git a/src/xla/PJRT/Buffer.jl b/src/xla/PJRT/Buffer.jl index 2f45c19645..3a78830c8b 100644 --- a/src/xla/PJRT/Buffer.jl +++ b/src/xla/PJRT/Buffer.jl @@ -38,7 +38,15 @@ function Base.size(buffer::Buffer) GC.@preserve buffer begin sz = @ccall MLIR.API.mlir_c.BufferShape(buffer.buffer::Ptr{Cvoid})::Ptr{Int64} end - return [unsafe_load(sz, i) for i in 1:ndims(buffer)] + return Tuple(unsafe_wrap(Array, sz, ndims(buffer))) +end + +function Base.eltype(buffer::Buffer) + GC.@preserve buffer begin + return XLA.julia_type( + @ccall MLIR.API.mlir_c.BufferPrimitiveType(buffer.buffer::Ptr{Cvoid})::Cint + ) + end end function XLA.device(buffer::Buffer) @@ -65,12 +73,6 @@ function XLA.buffer_on_cpu(buffer::Buffer) end end -function Base.convert(::Type{<:Array{T}}, buffer::Buffer) where {T} - arr = zeros(T, reverse(size(buffer))...) - XLA.to_host(buffer, arr) - return arr -end - function XLA.to_host(buffer::Buffer, data) GC.@preserve buffer begin @ccall MLIR.API.mlir_c.BufferToHost( @@ -94,3 +96,5 @@ function XLA.copy_buffer_to_device(buffer::Buffer, dev::Device) ) end end + +XLA.sharding(::Buffer) = error("PJRT Buffers are not sharded.") diff --git a/src/xla/PJRT/Client.jl b/src/xla/PJRT/Client.jl index fd3f8f04b9..8842dc02f0 100644 --- a/src/xla/PJRT/Client.jl +++ b/src/xla/PJRT/Client.jl @@ -89,16 +89,18 @@ const cpu_client_count = Ref(0) const gpu_client_count = Ref(0) const tpu_client_count = Ref(0) -for (backend, fname, counter) in ( - (:CPUClient, "MakeCPUClient", :cpu_client_count), - (:GPUClient, "MakeGPUClient", :gpu_client_count), - (:TPUClient, "MakeTPUClient", :tpu_client_count), +for (backend, counter) in ( + (:CPUClient, :cpu_client_count), + (:GPUClient, :gpu_client_count), + (:TPUClient, :tpu_client_count), ) + main_fn = Symbol(:Make, backend) @eval function $(backend)(args...; checkcount::Bool=true, kwargs...) 
if checkcount @assert $(counter)[] == 0 end - client = Client(XLA.$(backend)($(fname), args...; kwargs...)) + client = Client($(main_fn)(args...; kwargs...)) + XLA.LLVMclopts("-nvptx-fma-level=1") if checkcount # Only increment the counter if we successfully created a client $(counter)[] += 1 @@ -106,3 +108,72 @@ for (backend, fname, counter) in ( return client end end + +function MakeCPUClient(; + node_id::Integer=0, + num_nodes::Integer=1, + asynchronous::Bool=true, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + @assert num_nodes == 1 "`PJRT.MakeCPUClient` does not support num_nodes > 1" + @assert distributed_runtime_client === nothing "`PJRT.MakeCPUClient` does not support \ + distributed_runtime_client" + + return @ccall MLIR.API.mlir_c.MakeCPUClient( + asynchronous::UInt8, node_id::Cint + )::Ptr{Cvoid} +end + +function MakeGPUClient(; + node_id::Integer=0, + num_nodes::Integer=1, + platform::String="gpu", + allowed_devices::Union{Nothing,Vector{Int}}=nothing, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + refstr = Ref{Cstring}() + + num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices) + allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices + distributed_runtime_client = + distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client + + GC.@preserve refstr allowed_devices distributed_runtime_client begin + client = @ccall MLIR.API.mlir_c.MakeGPUClient( + node_id::Cint, + num_nodes::Cint, + allowed_devices::Ptr{Cvoid}, + num_allowed_devices::Cint, + XLA.XLA_REACTANT_GPU_MEM_FRACTION[]::Cdouble, + XLA.XLA_REACTANT_GPU_PREALLOCATE[]::Bool, + platform::Cstring, + refstr::Ptr{Cstring}, + distributed_runtime_client::Ptr{Cvoid}, + )::Ptr{Cvoid} + end + + client == C_NULL && throw(AssertionError(unsafe_string(refstr[]))) + return client +end + +function MakeTPUClient(; + tpu_path::String, + node_id::Integer=0, + num_nodes::Integer=1, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + @assert node_id == 0 "`PJRT.MakeTPUClient` does not support node_id" + @assert num_nodes == 1 "`PJRT.MakeTPUClient` does not support num_nodes > 1" + @assert distributed_runtime_client === nothing "`PJRT.MakeTPUClient` does not support \ + distributed_runtime_client" + + refstr = Ref{Cstring}() + GC.@preserve refstr begin + client = @ccall MLIR.API.mlir_c.MakeTPUClient( + tpu_path::Cstring, refstr::Ptr{Cstring} + )::Ptr{Cvoid} + end + + client == C_NULL && throw(AssertionError(unsafe_string(refstr[]))) + return client +end diff --git a/src/xla/PJRT/LoadedExecutable.jl b/src/xla/PJRT/LoadedExecutable.jl index 7dd6ece26d..15d3dbc0d5 100644 --- a/src/xla/PJRT/LoadedExecutable.jl +++ b/src/xla/PJRT/LoadedExecutable.jl @@ -3,12 +3,12 @@ mutable struct LoadedExecutable <: XLA.AbstractLoadedExecutable num_outputs::Int64 num_parameters::Int64 is_sharded::Bool + num_replicas::Int64 + num_partitions::Int64 - function LoadedExecutable( - exec::Ptr{Cvoid}, num_outputs::Int64, num_parameters::Int64, is_sharded::Bool - ) + function LoadedExecutable(exec::Ptr{Cvoid}, args...) 
@assert exec != C_NULL - return finalizer(free_exec, new(exec, num_outputs, num_parameters, is_sharded)) + return finalizer(free_exec, new(exec, args...)) end end @@ -26,16 +26,9 @@ function XLA.client(exec::LoadedExecutable) end end -for (jlop, xlaop) in ( - (:num_replicas, :PjRtLoadedExecutableNumReplicas), - (:num_partitions, :PjRtLoadedExecutableNumPartitions), -) - @eval function XLA.$(jlop)(exec::LoadedExecutable) - GC.@preserve exec begin - return @ccall MLIR.API.mlir_c.$(xlaop)(exec.exec::Ptr{Cvoid})::Cint - end - end -end +XLA.num_partitions(exec::LoadedExecutable) = exec.num_partitions +XLA.num_replicas(exec::LoadedExecutable) = exec.num_replicas +XLA.num_devices(exec::LoadedExecutable) = XLA.num_replicas(exec) * XLA.num_partitions(exec) for (jlop, xlaop, field) in ( (:get_output_shardings, :PjRtLoadedExecutableGetOuputShardings, :num_outputs), @@ -78,6 +71,8 @@ function XLA.compile( global_device_ids::Vector{Int64}=Int64[], num_outputs::Int64, num_parameters::Int64, + num_replicas::Int64, + num_partitions::Int64, ) device_id = is_sharded ? Int64(-1) : Int64(XLA.device_ordinal(device)) GC.@preserve client mod begin @@ -91,7 +86,9 @@ function XLA.compile( XLA.CUDA_DATA_DIR[]::Cstring, )::Ptr{Cvoid} end - return LoadedExecutable(exec, num_outputs, num_parameters, is_sharded) + return LoadedExecutable( + exec, num_outputs, num_parameters, is_sharded, num_replicas, num_partitions + ) end function execute_ir(N, M, n_outs, fn, with_device::Bool, nmesh_ids::Int64) diff --git a/src/xla/Utils.jl b/src/xla/Utils.jl index 1844fa7785..1bedcc88bb 100644 --- a/src/xla/Utils.jl +++ b/src/xla/Utils.jl @@ -13,37 +13,40 @@ function reactant_err(msg::Cstring)::Cvoid end # https://github.com/openxla/xla/blob/4bfb5c82a427151d6fe5acad8ebe12cee403036a/xla/xla_data.proto#L29 -@inline primitive_type(::Type{Bool}) = 1 - -@inline primitive_type(::Type{Int8}) = 2 -@inline primitive_type(::Type{UInt8}) = 6 - -@inline primitive_type(::Type{Int16}) = 3 -@inline primitive_type(::Type{UInt16}) = 7 - -@inline primitive_type(::Type{Int32}) = 4 -@inline primitive_type(::Type{UInt32}) = 8 - -@inline primitive_type(::Type{Int64}) = 5 -@inline primitive_type(::Type{UInt64}) = 9 - -@inline primitive_type(::Type{Float16}) = 10 -@inline primitive_type(::Type{Float32}) = 11 - -@inline primitive_type(::Type{Reactant.F8E5M2}) = 19 -@inline primitive_type(::Type{Reactant.F8E4M3FN}) = 20 -@inline primitive_type(::Type{Reactant.F8E4M3B11FNUZ}) = 23 -@inline primitive_type(::Type{Reactant.F8E5M2FNUZ}) = 24 -@inline primitive_type(::Type{Reactant.F8E4M3FNUZ}) = 25 +primitive_types_list = [ + (1, Bool), + (2, Int8), + (6, UInt8), + (3, Int16), + (7, UInt16), + (4, Int32), + (8, UInt32), + (5, Int64), + (9, UInt64), + (10, Float16), + (11, Float32), + (19, Reactant.F8E5M2), + (20, Reactant.F8E4M3FN), + (23, Reactant.F8E4M3B11FNUZ), + (24, Reactant.F8E5M2FNUZ), + (25, Reactant.F8E4M3FNUZ), + (12, Float64), + (15, Complex{Float32}), + (18, Complex{Float64}), +] @static if isdefined(Core, :BFloat16) - @inline primitive_type(::Type{Core.BFloat16}) = 16 + push!(primitive_types_list, (16, Core.BFloat16)) end -@inline primitive_type(::Type{Float64}) = 12 +for (int_val, jl_type) in primitive_types_list + @eval begin + @inline primitive_type(::Type{$(jl_type)}) = $(int_val) + @inline julia_type(::Val{$(int_val)}) = $(jl_type) + end +end -@inline primitive_type(::Type{Complex{Float32}}) = 15 -@inline primitive_type(::Type{Complex{Float64}}) = 18 +@inline julia_type(@nospecialize(x::Integer)) = julia_type(Val(Int64(x))) function 
unsafe_string_and_free(str::Cstring, args...) str_jl = unsafe_string(str, args...) diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 39a556a3ed..97aacded34 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -35,25 +35,27 @@ include("Memory.jl") include("PJRT/PJRT.jl") -@kwdef mutable struct BackendState +include("IFRT/IFRT.jl") + +@kwdef mutable struct PJRTBackendState initialized::Bool = false clients::Dict{String,PJRT.Client} = Dict{String,PJRT.Client}() default_client::PJRT.Client = PJRT.Client(C_NULL; skip_check=true) end -function Base.getproperty(bs::BackendState, sym::Symbol) +function Base.getproperty(bs::PJRTBackendState, sym::Symbol) (sym === :initialized || bs.initialized) && return getfield(bs, sym) - initialize_default_clients!(bs) + initialize_default_pjrt_clients!(bs) return getfield(bs, sym) end -function Base.setproperty!(bs::BackendState, sym::Symbol, val) +function Base.setproperty!(bs::PJRTBackendState, sym::Symbol, val) (sym === :initialized || bs.initialized) && return setfield!(bs, sym, val) - initialize_default_clients!(bs) + initialize_default_pjrt_clients!(bs) return setfield!(bs, sym, val) end -const global_backend_state = BackendState() +const global_backend_state = PJRTBackendState() const global_state = State() client(backend::String) = global_backend_state.clients[backend] @@ -72,8 +74,13 @@ end function update_global_state!(args...; kwargs...) update!(global_state, args...; kwargs...) - # We need to update the clients based on the new state - initialize_default_clients!(global_backend_state) + # We conditionally initialize for now, since a lot of options that are set are not + # necessarily supported by PJRT. This makes testing for IFRT quite hard. + # Once we move to IFRT completely, we can remove this. + if global_backend_state.initialized + # We need to update the clients based on the new state + initialize_default_pjrt_clients!(global_backend_state) + end return nothing end @@ -100,9 +107,9 @@ function __init__() end if haskey(ENV, "REACTANT_VISIBLE_GPU_DEVICES") - global_state.local_device_ids = + global_state.local_gpu_device_ids = parse.(Int, split(ENV["REACTANT_VISIBLE_GPU_DEVICES"], ",")) - @debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_device_ids + @debug "REACTANT_VISIBLE_GPU_DEVICES: " global_state.local_gpu_device_ids end @ccall MLIR.API.mlir_c.RegisterEnzymeXLACPUHandler()::Cvoid @@ -110,16 +117,27 @@ function __init__() return nothing end -function initialize_default_clients!(state::BackendState) +function initialize_default_pjrt_clients!(state::PJRTBackendState) was_initialized = state.initialized state.initialized = true + distributed_runtime_client = if global_state.num_processes > 1 + @assert global_state.client !== nothing + global_state.client + else + nothing + end + common_kwargs = (; + node_id=global_state.process_id, + num_nodes=global_state.num_processes, + distributed_runtime_client, + ) # CPU if was_initialized && haskey(state.clients, "cpu") XLA.free_client(state.clients["cpu"]) XLA.PJRT.cpu_client_count[] -= 1 end - cpu = PJRT.CPUClient(global_state.process_id, global_state.num_processes) + cpu = PJRT.CPUClient(; common_kwargs..., asynchronous=true) state.clients["cpu"] = cpu state.default_client = cpu @@ -142,8 +160,9 @@ function initialize_default_clients!(state::BackendState) XLA.free_client(state.clients["tpu"]) XLA.PJRT.tpu_client_count[] -= 1 end - # XXX: process_id? num_processes? 
- tpu = PJRT.TPUClient(dataset_dir * "/libtpu.so") + tpu = PJRT.TPUClient(; + tpu_path=dataset_dir * "/libtpu.so", common_kwargs... + ) state.clients["tpu"] = tpu state.default_client = tpu catch e @@ -152,22 +171,12 @@ function initialize_default_clients!(state::BackendState) else if !Reactant.precompiling() try - distributed_runtime_client = if global_state.num_processes > 1 - @assert global_state.client !== nothing - global_state.client - else - nothing - end - if was_initialized && haskey(state.clients, "gpu") XLA.free_client(state.clients["gpu"]) XLA.PJRT.gpu_client_count[] -= 1 end - gpu = PJRT.GPUClient( - global_state.process_id, - global_state.num_processes; - allowed_devices=global_state.local_device_ids, - distributed_runtime_client, + gpu = PJRT.GPUClient(; + common_kwargs..., allowed_devices=global_state.local_gpu_device_ids ) state.clients["gpu"] = gpu state.default_client = gpu diff --git a/test/ifrt/low_level.jl b/test/ifrt/low_level.jl new file mode 100644 index 0000000000..350754681a --- /dev/null +++ b/test/ifrt/low_level.jl @@ -0,0 +1,56 @@ +# Testing manual IFRT buffer creation + compilation + execution +using Reactant, Test +using Reactant: XLA +using Reactant.XLA: IFRT + +fn_test1(x, y) = x .+ y +fn_test2(x, y) = x .* y +fn_test3(x, y) = x .+ y' .- x + +@testset "IFRT Low-level API" begin + x = reshape(collect(Float32, 1:64), 8, 8) + y = collect((x .+ 64)') + + pjrt_client = Reactant.XLA.default_backend() + platform_name = lowercase(XLA.platform_name(pjrt_client)) + + ifrt_client = if platform_name == "cpu" + IFRT.CPUClient(; checkcount=false) + elseif platform_name == "gpu" || platform_name == "cuda" + IFRT.GPUClient(; checkcount=false) + elseif platform_name == "tpu" + IFRT.TPUClient(; checkcount=false) + else + error("Unsupported platform: $(platform_name)") + end + + pjrt_x = ConcreteRArray(x) # XXX: Rename to ConcretePJRTArray + pjrt_y = ConcreteRArray(y) # XXX: Rename to ConcretePJRTArray + + ifrt_x = IFRT.Array(ifrt_client, x) # XXX: Use ConcreteIFRTArray once ready + ifrt_y = IFRT.Array(ifrt_client, y) # XXX: Use ConcreteIFRTArray once ready + + @testset for fn in (fn_test1, fn_test2, fn_test3) + pjrt_result = @jit fn(pjrt_x, pjrt_y) + + mlir_mod, mlir_fn_res = Reactant.Compiler.compile_mlir(fn, (pjrt_x, pjrt_y)) + + ifrt_loaded_executable = XLA.compile( + ifrt_client, + XLA.default_device(ifrt_client), + mlir_mod; + num_outputs=length(mlir_fn_res.linear_results), + num_parameters=length(mlir_fn_res.linear_args), + mlir_fn_res.is_sharded, + global_device_ids=Int64[], + num_replicas=1, + num_partitions=1, + ) + + ifrt_result = XLA.execute( + ifrt_loaded_executable, (ifrt_x.buffer, ifrt_y.buffer), UInt8.((0, 0)), Val(1) + ) + + @test convert(Array, only(ifrt_result)) ≈ Array(pjrt_result) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 7f382bed8b..1e848d1715 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,6 +62,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Custom Number Types" include("custom_number_types.jl") end @safetestset "Sharding" include("sharding.jl") + + @testset "IFRT" begin + @safetestset "IFRT Low-Level API" include("ifrt/low_level.jl") + end end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
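The new `XLA.AbstractAsyncBuffer` supertype in `src/xla/Buffer.jl` now carries the await/convert logic that previously lived on `PJRT.AsyncBuffer`, so `PJRT.AsyncBuffer` and `IFRT.AsyncArray` share one interface. A minimal sketch of that interface (illustrative only; `buf` is a placeholder for any async buffer):

    XLA.is_ready(buf)            # true once no pending future remains
    XLA.await(buf)               # blocks on the future, then clears it
    host = convert(Array, buf)   # awaits, then copies the device data into a Base.Array
    size(buf), eltype(buf)       # forwarded to the wrapped buffer (`buf.buffer`)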
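The `primitive_types_list` rewrite in `src/xla/Utils.jl` generates both directions of the XLA element-type mapping; the IFRT bindings rely on the reverse direction through `XLA.julia_type` (for example in `Base.eltype(::IFRT.Array)`). A small illustration of the generated methods, not lines from this patch:

    XLA.primitive_type(Float32)   # 11
    XLA.julia_type(11)            # Float32
    XLA.julia_type(Int32(5))      # Int64 (any Integer argument is routed through `Val`)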
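Client construction is now keyword-based on both backends, which is what lets `initialize_default_pjrt_clients!` splat a single `common_kwargs` named tuple into every constructor. A sketch of direct construction for a single-process run (not from this patch; it mirrors the new test's use of `checkcount=false` to bypass the one-client-per-backend assertion):

    pjrt_cpu = Reactant.XLA.PJRT.CPUClient(; node_id=0, num_nodes=1, asynchronous=true, checkcount=false)
    ifrt_cpu = Reactant.XLA.IFRT.CPUClient(; node_id=0, num_nodes=1, checkcount=false)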