From 3627e6aa1b0539b6b8f4f43b1fc929c053a64ec4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Mar 2025 09:31:20 -0400 Subject: [PATCH 01/16] feat: query distributed environment state --- docs/src/api/sharding.md | 10 ++++++++++ src/Distributed.jl | 21 +++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/docs/src/api/sharding.md b/docs/src/api/sharding.md index a32b0183a8..92314ffff2 100644 --- a/docs/src/api/sharding.md +++ b/docs/src/api/sharding.md @@ -12,3 +12,13 @@ Currently we haven't documented all the functions in `Reactant.Sharding`. ```@autodocs Modules = [Reactant.Sharding] ``` + +# [Distributed API](@id distributed-api) + +`Reactant.Distributed` module provides a high-level API to run reactant on multiple hosts. + +Currently we haven't documented all the functions in `Reactant.Distributed`. + +```@autodocs +Modules = [Reactant.Distributed] +``` diff --git a/src/Distributed.jl b/src/Distributed.jl index 7eb44d488f..ff5e62c130 100644 --- a/src/Distributed.jl +++ b/src/Distributed.jl @@ -5,6 +5,27 @@ using Sockets const initialized = Ref(false) +""" + local_rank() + +Returns the local rank of the current process. +""" +local_rank() = Reactant.XLA.global_state.process_id + +""" + num_processes() + +Returns the number of processes. +""" +num_processes() = Reactant.XLA.global_state.num_processes + +""" + is_initialized() + +Returns `true` if the distributed environment has been initialized. +""" +is_initialized() = initialized[] + function initialize(; coordinator_address::Union{Nothing,String}=nothing, num_processes::Union{Nothing,Integer}=nothing, From 47fce04540802002139147027b3af7a7236258ee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Mar 2025 10:26:40 -0400 Subject: [PATCH 02/16] fix: define is_addressable fallback --- src/xla/Device.jl | 5 +++++ src/xla/PJRT/Device.jl | 2 ++ 2 files changed, 7 insertions(+) diff --git a/src/xla/Device.jl b/src/xla/Device.jl index f4b27eaa30..3bef856f40 100644 --- a/src/xla/Device.jl +++ b/src/xla/Device.jl @@ -24,3 +24,8 @@ function Base.string(device::AbstractDevice) pname = XLA.platform_name(client) return "$(uppercase(pname)):$(device_ordinal(device)) $(device_kind(device))" end + +# Fallback method, preferably all device implementations overload this +function XLA.is_addressable(device::AbstractDevice) + return device ∈ XLA.addressable_devices(XLA.client(device)) +end diff --git a/src/xla/PJRT/Device.jl b/src/xla/PJRT/Device.jl index c4c4a2caa5..71694008ac 100644 --- a/src/xla/PJRT/Device.jl +++ b/src/xla/PJRT/Device.jl @@ -32,3 +32,5 @@ function XLA.get_local_device_id(device::Device) )::Cint end end + +# TODO: Expose is addressable for pjrt devices in ReactantExtra From 02ac018ddbf55f9f8457ae7feb031b3ade40ec3b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Mar 2025 11:04:10 -0400 Subject: [PATCH 03/16] feat: expose API to get array slices --- src/Sharding.jl | 189 ++++++++++++++++++++++-------------------------- 1 file changed, 88 insertions(+), 101 deletions(-) diff --git a/src/Sharding.jl b/src/Sharding.jl index 53959c25d8..7699ee8d61 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -87,6 +87,13 @@ end function get_shardy_tensor_sharding_attribute end +""" + sharding_to_array_slices(sharding, size_x; client=nothing) + +Given a sharding and an array size, returns the device to array slices mapping. 
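+
+For example, on a hypothetical 2×2 device mesh (the mesh, axis names, and array size
+below are illustrative and not part of this patch):
+
+```julia
+mesh = Reactant.Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y))
+sharding = Reactant.Sharding.NamedSharding(mesh, (:x, :y))
+slices = Reactant.Sharding.sharding_to_array_slices(sharding, (8, 8))
+# slices[i] is the tuple of index ranges of the full array owned by the i-th device,
+# e.g. (1:4, 1:4) for one shard of an 8×8 array on a 2×2 mesh.
+```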
+""" +function sharding_to_array_slices end + """ NoSharding() @@ -115,6 +122,8 @@ function (::NoSharding)(client::XLA.IFRT.Client, device, x::Union{AbstractArray, return XLA.IFRT.AsyncArray(client, x, device), ShardInfo(NoSharding(), nothing) end +sharding_to_array_slices(::NoSharding, size_x; client=nothing) = Base.OneTo.(size_x) + """ NamedSharding( mesh::Mesh, partition_spec::Tuple; @@ -208,6 +217,39 @@ function (sharding::NamedSharding)( return HloSharding(sharding, client, device, x) end +# This doesn't account for the size of the input so in-presence of padding this will be +# incorrect. Hence always use the HloSharding constructor. +function generate_hlo_sharding_from_tensor_attribute(sharding::NamedSharding) + if MLIR.IR._has_context() + ctx = MLIR.IR.context() + else + ctx = MLIR.IR.Context(Reactant.registry[], false) + @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid + end + + MLIR.IR.context!(ctx) do + mesh_op = Reactant.Ops.mesh( + sharding.mesh; mod=MLIR.IR.Module(MLIR.IR.Location(; context=ctx)) + ) + + tensor_sharding_attr = get_shardy_tensor_sharding_attribute( + sharding, ctx, mesh_op.sym_name, mesh_op.mesh_attr; do_transpose=true + ) + + return HloSharding( + XLA.HloSharding( + @ccall MLIR.API.mlir_c.hloShardingFromTensorShardingAttr( + tensor_sharding_attr.attribute::MLIR.API.MlirAttribute, + mesh_op.mesh_attr.attribute::MLIR.API.MlirAttribute, + )::Ptr{Cvoid} + ), + sharding.mesh, + sharding.is_closed, + sharding.priority, + ) + end +end + function get_shardy_tensor_sharding_attribute( sharding::NamedSharding, ctx, mesh_name, mesh_attr; do_transpose=true ) @@ -242,6 +284,12 @@ function get_shardy_tensor_sharding_attribute( ) end +function sharding_to_array_slices(sharding::NamedSharding, size_x; client=nothing) + return sharding_to_array_slices( + generate_hlo_sharding_from_tensor_attribute(sharding), size_x; client + ) +end + # TODO: Something like NamedDims.jl will allow us to support NamedDimsSharding similar to # `levanter` @@ -288,24 +336,25 @@ end return shard_type(HloSharding{M,N}, N) end -function standardize_sharding(sharding::DimsSharding, x::Union{AbstractArray,Number}) +function standardize_sharding(sharding::DimsSharding, size_x) + N = length(size_x) final_dims = map(sharding.dims) do d @assert !iszero(d) "dims cannot contain 0" - return ifelse(d < 0, ndims(x) + d + 1, d) + return ifelse(d < 0, N + d + 1, d) end - dim_indices = ntuple(i -> findfirst(==(i), final_dims), ndims(x)) - partition_spec = ntuple(ndims(x)) do i + dim_indices = ntuple(i -> findfirst(==(i), final_dims), N) + partition_spec = ntuple(N) do i dim_index = dim_indices[i] dim_index === nothing && return nothing # replicated dimension return sharding.partition_spec[dim_index] end - is_closed = ntuple(ndims(x)) do i + is_closed = ntuple(N) do i dim_index = dim_indices[i] dim_index === nothing && return true # replicated dimension return sharding.is_closed[dim_index] end - priority = ntuple(ndims(x)) do i + priority = ntuple(N) do i dim_index = dim_indices[i] dim_index === nothing && return -1 # replicated dimension return sharding.priority[dim_index] @@ -317,7 +366,11 @@ end function (sharding::DimsSharding)( client::XLA.AbstractClient, device, x::Union{AbstractArray,Number} ) - return (standardize_sharding(sharding, x))(client, device, x) + return (standardize_sharding(sharding, size(x)))(client, device, x) +end + +function sharding_to_array_slices(sharding::DimsSharding, size_x; client=nothing) + return sharding_to_array_slices(standardize_sharding(sharding, 
size_x), size_x; client) end # HloSharding @@ -344,75 +397,47 @@ end return ShardInfo{HloSharding{D1,D2},Vector{NTuple{N,UnitRange{Int64}}}} end -# This doesn't account for the size of the input so in-presence of padding this will be -# incorrect. Hence always use the HloSharding constructor. -function generate_hlo_sharding_from_tensor_attribute(sharding::NamedSharding) - if MLIR.IR._has_context() - ctx = MLIR.IR.context() - else - ctx = MLIR.IR.Context(Reactant.registry[], false) - @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid - end - - MLIR.IR.context!(ctx) do - mesh_op = Reactant.Ops.mesh( - sharding.mesh; mod=MLIR.IR.Module(MLIR.IR.Location(; context=ctx)) - ) - - tensor_sharding_attr = get_shardy_tensor_sharding_attribute( - sharding, ctx, mesh_op.sym_name, mesh_op.mesh_attr; do_transpose=true - ) - - return HloSharding( - XLA.HloSharding( - @ccall MLIR.API.mlir_c.hloShardingFromTensorShardingAttr( - tensor_sharding_attr.attribute::MLIR.API.MlirAttribute, - mesh_op.mesh_attr.attribute::MLIR.API.MlirAttribute, - )::Ptr{Cvoid} - ), - sharding.mesh, - sharding.is_closed, - sharding.priority, - ) - end -end - -function HloSharding(sharding::NamedSharding, client::XLA.PJRT.Client, _, x) - hlo_sharding = generate_hlo_sharding_from_tensor_attribute(sharding) - +function sharding_to_array_slices(sharding::HloSharding, size_x; client=nothing) # Check if the input needs to be padded. If so this sharding is not valid and we # need to request the tensor sharding from XLA - condensed_op_sharding = convert(XLA.CondensedOpSharding, hlo_sharding.hlo_sharding) + condensed_op_sharding = convert(XLA.CondensedOpSharding, sharding.hlo_sharding) device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices( - condensed_op_sharding, size(x), hlo_sharding.mesh.logical_device_ids + condensed_op_sharding, size_x, sharding.mesh.logical_device_ids ) if needs_padding - # Compile a dummy function to get the tensor sharding - tmp = if x isa Number - Reactant.ConcretePJRTNumber(zero(eltype(x))) + kws = client === nothing ? (;) : (; client) + tmp = if length(size_x) == 0 + Reactant.ConcreteRNumber(zero(Float32); kws...) else - Reactant.ConcretePJRTArray(ones(eltype(x), size(x)...)) + Reactant.ConcreteRArray(ones(Float32, size_x...); kws...) end _, exec, _, _, _ = Reactant.Compiler.compile_xla( Reactant.Ops.negate, (tmp,); input_shardings=IdDict(tmp => sharding) ) + xla_hlo_sharding = convert( Reactant.XLA.HloSharding, only(Reactant.XLA.get_parameter_shardings(exec)) ) - hlo_sharding = HloSharding( - xla_hlo_sharding, - hlo_sharding.mesh, - hlo_sharding.is_closed, - hlo_sharding.priority, + sharding = HloSharding( + xla_hlo_sharding, sharding.mesh, sharding.is_closed, sharding.priority ) - condensed_op_sharding = convert(XLA.CondensedOpSharding, hlo_sharding.hlo_sharding) + condensed_op_sharding = convert(XLA.CondensedOpSharding, sharding.hlo_sharding) device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices( - condensed_op_sharding, size(x), hlo_sharding.mesh.logical_device_ids + condensed_op_sharding, size_x, sharding.mesh.logical_device_ids ) + + @assert !needs_padding "This shouldn't happen. 
Open an issue on Reactant.jl" end + return device_to_array_slices +end + +function HloSharding(sharding::NamedSharding, client::XLA.PJRT.Client, _, x) + hlo_sharding = generate_hlo_sharding_from_tensor_attribute(sharding) + device_to_array_slices = sharding_to_array_slices(hlo_sharding, size(x); client) + data = ntuple(length(hlo_sharding.mesh)) do i XLA.PJRT.AsyncBuffer( client, @@ -426,39 +451,7 @@ end function HloSharding(sharding::NamedSharding, client::XLA.IFRT.Client, _, x) hlo_sharding = generate_hlo_sharding_from_tensor_attribute(sharding) - - # Check if the input needs to be padded. If so this sharding is not valid and we - # need to request the tensor sharding from XLA - condensed_op_sharding = convert(XLA.CondensedOpSharding, hlo_sharding.hlo_sharding) - device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices( - condensed_op_sharding, size(x), hlo_sharding.mesh.logical_device_ids - ) - - if needs_padding - # Compile a dummy function to get the tensor sharding - tmp = if x isa Number - Reactant.ConcreteIFRTNumber(zero(eltype(x))) - else - Reactant.ConcreteIFRTArray(ones(eltype(x), size(x)...)) - end - _, exec, _, _, _ = Reactant.Compiler.compile_xla( - Reactant.Ops.negate, (tmp,); input_shardings=IdDict(tmp => sharding) - ) - xla_hlo_sharding = convert( - Reactant.XLA.HloSharding, only(Reactant.XLA.get_parameter_shardings(exec)) - ) - hlo_sharding = HloSharding( - xla_hlo_sharding, - hlo_sharding.mesh, - hlo_sharding.is_closed, - hlo_sharding.priority, - ) - - condensed_op_sharding = convert(XLA.CondensedOpSharding, hlo_sharding.hlo_sharding) - device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices( - condensed_op_sharding, size(x), hlo_sharding.mesh.logical_device_ids - ) - end + device_to_array_slices = sharding_to_array_slices(hlo_sharding, size(x); client) ifrt_sharding = XLA.IFRT.Sharding( vec(Reactant.XLA.get_device.((client,), hlo_sharding.mesh.device_ids)), @@ -471,12 +464,7 @@ end function (sharding::HloSharding)( client::XLA.PJRT.Client, ::Nothing, x::Union{AbstractArray,Number} ) - condensed_op_sharding = convert(XLA.CondensedOpSharding, sharding.hlo_sharding) - - device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices( - condensed_op_sharding, size(x), sharding.mesh.logical_device_ids - ) - @assert !needs_padding "This shouldn't happen. Open an issue on Reactant.jl" + device_to_array_slices = sharding_to_array_slices(sharding, size(x); client) data = ntuple(length(sharding.mesh)) do i XLA.PJRT.AsyncBuffer( @@ -492,12 +480,7 @@ end function (sharding::HloSharding)( client::XLA.IFRT.Client, ::Nothing, x::Union{AbstractArray,Number} ) - condensed_op_sharding = convert(XLA.CondensedOpSharding, sharding.hlo_sharding) - - device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices( - condensed_op_sharding, size(x), sharding.mesh.logical_device_ids - ) - @assert !needs_padding "This shouldn't happen. 
Open an issue on Reactant.jl" + device_to_array_slices = sharding_to_array_slices(sharding, size(x); client) ifrt_sharding = XLA.IFRT.Sharding( vec(Reactant.XLA.get_device.((client,), sharding.mesh.device_ids)), @@ -553,6 +536,10 @@ function (sharding::ShardInfo)( return (sharding.sharding)(client, device, x) end +function sharding_to_array_slices(sharding::ShardInfo, size_x; client=nothing) + return sharding_to_array_slices(sharding.sharding, size_x; client) +end + const NoShardInfo = ShardInfo{NoSharding,Nothing} ShardInfo{NoSharding,Nothing}() = ShardInfo(NoSharding(), nothing) From 608099e5a3949686b47065ff4b2927dcdef5bfbe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Mar 2025 11:52:51 -0400 Subject: [PATCH 04/16] fix: generalize getproperty --- src/Sharding.jl | 13 ++++++++----- src/Tracing.jl | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/Sharding.jl b/src/Sharding.jl index 7699ee8d61..1010ede3e3 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -85,6 +85,13 @@ function (T::AbstractSharding)(::XLA.AbstractClient, device, ::Union{AbstractArr ) end +# By default we use same sharding for all leaf nodes +Base.getproperty(sharding::AbstractSharding, name) = sharding +function Base.getproperty(sharding::AbstractSharding, name::Symbol) + name ∈ fieldnames(typeof(sharding)) && return getfield(sharding, name) + return sharding +end + function get_shardy_tensor_sharding_attribute end """ @@ -107,10 +114,6 @@ struct NoSharding <: AbstractSharding end @inline shard_type(::Type{NoSharding}, _) = ShardInfo{NoSharding,Nothing} -# This allows us to mark entire branches as NoSharding -Base.getproperty(::NoSharding, x) = NoSharding() -Base.getproperty(::NoSharding, x::Symbol) = NoSharding() - function (::NoSharding)(client::XLA.PJRT.Client, device, x::Union{AbstractArray,Number}) device === nothing && (device = XLA.default_device(client)) buffer = XLA.PJRT.AsyncBuffer(client, x, device) @@ -523,7 +526,7 @@ end function Base.getproperty(sharding::ShardInfo, name::Symbol) name ∈ (:sharding, :device_to_array_slices) && return getfield(sharding, name) - return getproperty(sharding.sharding, name) + return getproperty(unwrap_shardinfo(sharding), name) end function get_shardy_tensor_sharding_attribute(sharding::ShardInfo, args...; kwargs...) 
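The generalized `Base.getproperty` above is what allows a single sharding annotation to
stand in for every leaf of a nested container during tracing (see the `Tracing.jl` hunk
below, which calls `getproperty(sharding, i)`). A minimal sketch of the behavior,
assuming at least four visible devices (the mesh and the field name `layer_1` are
illustrative, not part of this patch):

```julia
using Reactant

mesh = Reactant.Sharding.Mesh(reshape(Reactant.devices()[1:4], 2, 2), (:x, :y))
sharding = Reactant.Sharding.NamedSharding(mesh, (:x, :y))

sharding.mesh === mesh         # actual struct fields are still returned via getfield
sharding.layer_1 === sharding  # any other name falls back to the sharding itself,
                               # so nested leaves inherit the same sharding
```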
diff --git a/src/Tracing.jl b/src/Tracing.jl index 91d1e40594..7d539adb82 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1094,7 +1094,7 @@ function make_tracer_unknown( newpath, mode; track_numbers, - sharding=Base.getproperty(sharding, i), + sharding=getproperty(sharding, i), runtime, kwargs..., ) From 16e84f9dbdadf4c99134b05cf23075ca41b7bfaa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Mar 2025 13:45:46 -0500 Subject: [PATCH 05/16] feat: dump debug info + another constant function --- src/Ops.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/Ops.jl b/src/Ops.jl index 7c4ebbd2c6..81f2989b11 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -133,6 +133,12 @@ end end end +@noinline function constant( + x::Array{TracedRNumber{T},N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) +) where {T,N} + return reshape(vcat(x...), size(x)...; location) +end + @noinline function constant( x::AbstractArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T,N} From a8c4ed41fc0444149041d35a80d1a94336208a31 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Mar 2025 15:58:11 -0500 Subject: [PATCH 06/16] fix: array construction and more scalar issues --- src/Ops.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 81f2989b11..7c4ebbd2c6 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -133,12 +133,6 @@ end end end -@noinline function constant( - x::Array{TracedRNumber{T},N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) -) where {T,N} - return reshape(vcat(x...), size(x)...; location) -end - @noinline function constant( x::AbstractArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T,N} From 9f8be58af5e57ad357bb86c26e57724b42478202 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Mar 2025 18:27:15 -0500 Subject: [PATCH 07/16] fix: unflatten codegen --- src/Compiler.jl | 41 +++++++++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 8e896cdc89..2d05634ab8 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -32,6 +32,12 @@ end return Base.getfield(obj, field) end +@inline function traced_getfield( + @nospecialize(obj::AbstractArray{<:Union{ConcretePJRTNumber,ConcreteIFRTNumber}}), field +) + return Base.getfield(obj, field) +end + @inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T} (isbitstype(T) || ancestor(obj) isa RArray || obj isa AbstractRange) && return Base.getfield(obj, field) @@ -39,6 +45,7 @@ end end @inline traced_setfield!(@nospecialize(obj), field, val) = Base.setfield!(obj, field, val) + @inline function traced_setfield!( @nospecialize(obj::AbstractArray{T}), field, val ) where {T} @@ -47,6 +54,14 @@ end return Base.setindex!(obj, val, field) end +@inline function traced_setfield!( + @nospecialize(obj::AbstractArray{<:Union{ConcretePJRTNumber,ConcreteIFRTNumber}}), + field, + val, +) + return setfield_carray!(obj, field, val) +end + @inline function traced_setfield!(@nospecialize(obj::Dict), field, val) return Base.setindex!(obj, field, val) end @@ -116,7 +131,7 @@ function create_result( if haskey(result_stores, path) restore = result_stores[path] delete!(result_stores, path) - if path_to_shard_info !== nothing # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -129,7 +144,7 @@ function 
create_result( end # We will set the data for this later - if path_to_shard_info !== nothing # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -146,7 +161,7 @@ function create_result( if haskey(result_stores, path) restore = result_stores[path] delete!(result_stores, path) - if path_to_shard_info !== nothing # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -157,7 +172,7 @@ function create_result( end # We will set the data for this later - if path_to_shard_info !== nothing # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -176,7 +191,7 @@ function create_result( if haskey(result_stores, path) restore = result_stores[path] delete!(result_stores, path) - if path_to_shard_info !== nothing # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -188,7 +203,7 @@ function create_result( end end - if path_to_shard_info !== nothing # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -208,7 +223,7 @@ function create_result( if haskey(result_stores, path) restore = result_stores[path] delete!(result_stores, path) - if path_to_shard_info !== nothing # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -220,7 +235,7 @@ function create_result( end end - if path_to_shard_info !== nothing # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -295,6 +310,16 @@ function create_result( return :($D([$(elems...)])) end +function create_result( + tocopy::Reactant.XLA.AbstractDevice, + path, + result_stores, + path_to_shard_info, + sharding_mesh, +) + return Meta.quot(:($(tocopy))) +end + function create_result( tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char}, path, From 80b93d308e2a3e37a54476e4125f488e8d0b67ba Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Mar 2025 20:45:09 -0500 Subject: [PATCH 08/16] feat: unresharding is atleast partially working --- src/Compiler.jl | 181 ++++++++++++++++++------------- src/xla/IFRT/Array.jl | 28 +++++ src/xla/IFRT/AsyncArray.jl | 12 ++ src/xla/IFRT/LoadedExecutable.jl | 4 +- src/xla/PJRT/LoadedExecutable.jl | 4 +- 5 files changed, 153 insertions(+), 76 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 2d05634ab8..0d80f6cfee 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -86,9 +86,7 @@ end return Base.setfield!(obj, field, (val[idx],)) end -function create_result( - tocopy::T, path, result_stores, path_to_shard_info, sharding_mesh -) where {T} +function create_result(tocopy::T, path, args...) 
where {T} if !isstructtype(typeof(tocopy)) error("cannot copy $tocopy of type $(Core.Typeof(tocopy))") end @@ -99,13 +97,7 @@ function create_result( # If the field is undefined we don't set it. A common example for this is `du2` # for Tridiagonal isdefined(tocopy, i) || continue - ev = create_result( - getfield(tocopy, i), - append_path(path, i), - result_stores, - path_to_shard_info, - sharding_mesh, - ) + ev = create_result(getfield(tocopy, i), append_path(path, i), args...) push!(elems, ev) end @@ -127,11 +119,15 @@ function create_result( result_stores, path_to_shard_info, sharding_mesh, + to_unreshard_results, ) where {T,D,S} if haskey(result_stores, path) restore = result_stores[path] delete!(result_stores, path) - if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) + if haskey(to_unreshard_results, path) + error("XXX: Implement This") + end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -144,7 +140,7 @@ function create_result( end # We will set the data for this later - if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -156,12 +152,20 @@ function create_result( end function create_result( - tocopy::ConcreteIFRTNumber{T,S}, path, result_stores, path_to_shard_info, sharding_mesh + tocopy::ConcreteIFRTNumber{T,S}, + path, + result_stores, + path_to_shard_info, + sharding_mesh, + to_unreshard_results, ) where {T,S} if haskey(result_stores, path) restore = result_stores[path] delete!(result_stores, path) - if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) + if haskey(to_unreshard_results, path) + error("XXX: Implement This") + end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -172,7 +176,7 @@ function create_result( end # We will set the data for this later - if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -187,11 +191,15 @@ function create_result( result_stores, path_to_shard_info, sharding_mesh, + to_unreshard_results, ) where {T,N,D,S} if haskey(result_stores, path) restore = result_stores[path] delete!(result_stores, path) - if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) + if haskey(to_unreshard_results, path) + error("XXX: Implement This") + end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -203,7 +211,11 @@ function create_result( end end - if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding + # We will set the data for this later + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) + if haskey(to_unreshard_results, path) + error("XXX: Implement This") + end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -211,19 +223,28 @@ function create_result( ($(tocopy.data)...,), $(tocopy.shape), $sharding 
)) end - # We will set the data for this later return :(ConcretePJRTArray{$T,$N,$D,$S}( $(tocopy.data), $(tocopy.shape), $(tocopy.sharding) )) end function create_result( - tocopy::ConcreteIFRTArray{T,N,S}, path, result_stores, path_to_shard_info, sharding_mesh + tocopy::ConcreteIFRTArray{T,N,S}, + path, + result_stores, + path_to_shard_info, + sharding_mesh, + to_unreshard_results, ) where {T,N,S} if haskey(result_stores, path) restore = result_stores[path] delete!(result_stores, path) - if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) + if haskey(to_unreshard_results, path) + return :(generate_unresharded_ifrt_array( + $(restore), $(to_unreshard_results[path]), $(T), $(N), $(tocopy.shape) + )) + end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -231,11 +252,17 @@ function create_result( $(restore), $(tocopy.shape), $sharding )) else - return :(ConcreteIFRTArray{$T,$N}($restore, $(tocopy.shape))) + return :(ConcreteIFRTArray{$T,$N}($(restore), $(tocopy.shape))) end end - if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) # restore sharding + # We will set the data for this later + if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) + if haskey(to_unreshard_results, path) + return :(generate_unresharded_ifrt_array( + $(tocopy.data), $(to_unreshard_results[path]), $(T), $(N), $(tocopy.shape) + )) + end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -243,89 +270,67 @@ function create_result( $(tocopy.data), $(tocopy.shape), $sharding )) end - # We will set the data for this later return :(ConcreteIFRTArray{$T,$N,$S}( $(tocopy.data), $(tocopy.shape), $(tocopy.sharding) )) end -function create_result( - tocopy::Array{T,N}, path, result_stores, path_to_shard_info, sharding_mesh -) where {T,N} +function generate_unresharded_ifrt_array( + arr::Reactant.XLA.IFRT.AsyncArray, + (target_device, output_sharding, mesh), + ::Type{T}, + N::Integer, + size_arr, +) where {T} + single_device_arrays = Reactant.XLA.IFRT.disassemble_into_single_device_arrays( + Reactant.XLA.IFRT.replicate_array_to_all_devices(arr, output_sharding, mesh), true + ) + devs = Reactant.XLA.device.(single_device_arrays) + idx = findfirst(isequal(target_device), devs) + return ConcreteIFRTArray{T,N}( + Reactant.XLA.IFRT.AsyncArray(single_device_arrays[idx], nothing), size(arr) + ) +end + +function create_result(tocopy::Array{T,N}, path, args...) where {T,N} elems = Expr[] for (i, v) in enumerate(tocopy) - push!( - elems, - create_result( - v, append_path(path, i), result_stores, path_to_shard_info, sharding_mesh - ), - ) + push!(elems, create_result(v, append_path(path, i), args...)) end # TODO is there a way to not call `reshape` here? what expr is used for array literals? return :(reshape($T[$(elems...)], $(size(tocopy))...)) end -function create_result( - tocopy::Tuple, path, result_stores, path_to_shard_info, sharding_mesh -) +function create_result(tocopy::Tuple, path, args...) 
elems = Union{Symbol,Expr}[] for (k, v) in pairs(tocopy) - push!( - elems, - create_result( - v, append_path(path, k), result_stores, path_to_shard_info, sharding_mesh - ), - ) + push!(elems, create_result(v, append_path(path, k), args...)) end return :(($(elems...),)) end -function create_result( - tocopy::NamedTuple{K,T}, path, result_stores, path_to_shard_info, sharding_mesh -) where {K,T} +function create_result(tocopy::NamedTuple{K,T}, path, args...) where {K,T} elems = Union{Symbol,Expr}[] for (i, (k, v)) in enumerate(pairs(tocopy)) - push!( - elems, - create_result( - v, append_path(path, i), result_stores, path_to_shard_info, sharding_mesh - ), - ) + push!(elems, create_result(v, append_path(path, i), args...)) end return :(NamedTuple{$K}(($(elems...),))) end -function create_result( - tocopy::D, path, result_stores, path_to_shard_info, sharding_mesh -) where {K,V,D<:AbstractDict{K,V}} +function create_result(tocopy::D, path, args...) where {K,V,D<:AbstractDict{K,V}} elems = Expr[] for (i, p) in enumerate(pairs(tocopy)) - push!( - elems, - create_result( - p, append_path(path, i), result_stores, path_to_shard_info, sharding_mesh - ), - ) + push!(elems, create_result(p, append_path(path, i), args...)) end return :($D([$(elems...)])) end -function create_result( - tocopy::Reactant.XLA.AbstractDevice, - path, - result_stores, - path_to_shard_info, - sharding_mesh, -) +function create_result(tocopy::Reactant.XLA.AbstractDevice, args...) return Meta.quot(:($(tocopy))) end function create_result( - tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char}, - path, - result_stores, - path_to_shard_info, - sharding_mesh, + tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char}, args... ) return Meta.quot(tocopy) end @@ -1339,6 +1344,7 @@ function codegen_flatten!( flatten_names = Symbol[] flatten_code = Expr[] runtime = XLA.runtime(client) + resharded_inputs = Dict{Tuple,Any}() if is_sharded inv_seen_args = Reactant.OrderedIdDict() @@ -1392,6 +1398,10 @@ function codegen_flatten!( ) end else + resharded_inputs[path[3:end]] = ( + Reactant.XLA.device(carg), condensed_op_sharding, mesh + ) + push!(flatten_code, :($usbuf = $flatcode)) device_to_array_slices, _ = XLA.sharding_to_concrete_array_indices( condensed_op_sharding, size(carg), mesh.logical_device_ids @@ -1461,6 +1471,10 @@ function codegen_flatten!( @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE." push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf))) else + resharded_inputs[path[3:end]] = ( + Reactant.XLA.device(carg), condensed_op_sharding, mesh + ) + # XXX: Currently we copy to host and then make the transfer to the # sharded devices. This is not ideal, we might be able to do a # device-to-device transfer, maybe using reshard? 
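For orientation: the `resharded_inputs` entries recorded in these hunks are consumed
later by `codegen_unflatten!`, which checks whether a result aliases one of the
resharded arguments and, if so, replicates it back to the recorded device. A minimal,
hypothetical sketch of that bookkeeping (the path and placeholder values below are
invented; the real tuples are `(device, condensed_op_sharding, mesh)` as built above):

```julia
resharded_inputs = Dict{Tuple,Any}()
resharded_inputs[(:args, 1)] = (:original_device, :op_sharding, :mesh)  # placeholders

# Later, roughly what codegen_unflatten! does for a result aliasing argument 1:
arg_path = (:args, 1)
if haskey(resharded_inputs, arg_path)
    to_unreshard = resharded_inputs[arg_path]  # drives the un-resharding of that result
end
```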
@@ -1498,7 +1512,8 @@ function codegen_flatten!( is_sharded && runtime isa Val{:PJRT} && (flatten_names = vcat(eachrow(reshape(flatten_names, length(mesh), :))...)) - return flatten_names, flatten_code + + return flatten_names, flatten_code, resharded_inputs end """ @@ -1518,6 +1533,7 @@ function codegen_unflatten!( linear_result_shard_info, sharding_mesh, client, + resharded_inputs, ) cache_dict = gensym("cache_dict") has_cache_dict = false @@ -1535,6 +1551,8 @@ function codegen_unflatten!( end ctypes = Union{arrtype,numtype} + to_unreshard_results = Dict{Tuple,Any}() + # mutate the result stores to point to the correct concrete results for (concrete_res_name, result, shard_info) in zip(concretized_res_names, linear_results, linear_result_shard_info) @@ -1548,6 +1566,15 @@ function codegen_unflatten!( if path[1] == :result unflatcode = :result path = path[2:end] + + if Reactant.TracedUtils.has_argidx(result) + _, argidx = Reactant.TracedUtils.get_argidx(result) + arg_path = argidx[3:end] + if haskey(resharded_inputs, arg_path) + to_unreshard_results[path] = resharded_inputs[arg_path] + end + end + result_stores[path] = concrete_res_name if path_to_shard_info !== nothing path_to_shard_info[path] = shard_info @@ -1617,7 +1644,12 @@ function codegen_unflatten!( prevkeys = collect(keys(result_stores)) result_code = create_result( - concrete_result, (), result_stores, path_to_shard_info, sharding_mesh + concrete_result, + (), + result_stores, + path_to_shard_info, + sharding_mesh, + to_unreshard_results, ) postkeys = collect(keys(result_stores)) used = [t for t in prevkeys if !in(t, postkeys)] @@ -1878,7 +1910,7 @@ function compile(f, args; sync=false, kwargs...) path_to_shard_info = mlir_fn_res.is_sharded ? Dict{Tuple,Tuple}() : nothing # generate Julia `Thunk` code - flatten_arg_names, flatten_code = codegen_flatten!( + flatten_arg_names, flatten_code, resharded_inputs = codegen_flatten!( linear_args, seen_args, result_stores, @@ -1918,6 +1950,7 @@ function compile(f, args; sync=false, kwargs...) 
linear_result_shard_info, mlir_fn_res.sharding_mesh, client, + resharded_inputs, ) sync_call = if sync diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index d235e1d166..676b5cf2bc 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -202,6 +202,34 @@ function disassemble_into_single_device_arrays(array::Array, only_addressable_de return [Array(unsafe_load(arrays, i)) for i in 1:narrays[]] end +function replicate_array_to_all_devices(array::Array, sharding, mesh) + is_fully_replicated(XLA.sharding(array)) && return array + + hlo_sharding = Reactant.Sharding.HloSharding( + convert(XLA.HloSharding, sharding), + mesh, + ntuple(Returns(1), ndims(array)), + ntuple(Returns(-1), ndims(array)), + ) + shard_info = Reactant.Sharding.ShardInfo( + hlo_sharding, Reactant.Sharding.sharding_to_array_slices(hlo_sharding, size(array)) + ) + sharding_constraint = Reactant.Sharding.NamedSharding( + mesh, ntuple(Returns(nothing), ndims(array)) + ) + data = Reactant.ConcreteIFRTArray{eltype(array),ndims(array), typeof(shard_info)}( + AsyncArray(array, nothing), size(array), shard_info + ) + + fn(x) = Reactant.Ops.sharding_constraint(x, sharding_constraint) + + fn_compiled = Reactant.compile((data,)) do x + return Reactant.Ops.sharding_constraint(x, sharding_constraint) + end + + return fn_compiled(data).data.buffer +end + function XLA.unsafe_buffer_pointer(::Array) return error("IFRT.Array does not support `XLA.unsafe_buffer_pointer`") end diff --git a/src/xla/IFRT/AsyncArray.jl b/src/xla/IFRT/AsyncArray.jl index b049289b13..c28ff320e2 100644 --- a/src/xla/IFRT/AsyncArray.jl +++ b/src/xla/IFRT/AsyncArray.jl @@ -6,3 +6,15 @@ end const AsyncEmptyArray = AsyncArray(Array(C_NULL), nothing) AsyncArray(args...; kwargs...) = AsyncArray(Array(args...; kwargs...), nothing) + +function disassemble_into_single_device_arrays( + x::AsyncArray, only_addressable_devices::Bool +) + wait(x) + return disassemble_into_single_device_arrays(x.buffer, only_addressable_devices) +end + +function replicate_array_to_all_devices(array::AsyncArray, args...) + wait(array) + return replicate_array_to_all_devices(array.buffer, args...) 
+end diff --git a/src/xla/IFRT/LoadedExecutable.jl b/src/xla/IFRT/LoadedExecutable.jl index 4fffdb0ae0..3676f15560 100644 --- a/src/xla/IFRT/LoadedExecutable.jl +++ b/src/xla/IFRT/LoadedExecutable.jl @@ -41,7 +41,9 @@ for (jlop, xlaop, field) in ( ), ) @eval function XLA.$(jlop)(exec::LoadedExecutable) - exec.is_sharded || return XLA.OpSharding[] + if !exec.is_sharded || iszero(exec.$(field)) + return XLA.OpSharding[] + end op_shardings = Ref{NTuple{exec.$(field),Ptr{Cvoid}}}() diff --git a/src/xla/PJRT/LoadedExecutable.jl b/src/xla/PJRT/LoadedExecutable.jl index 11eb845b6e..b2b2e8ea5a 100644 --- a/src/xla/PJRT/LoadedExecutable.jl +++ b/src/xla/PJRT/LoadedExecutable.jl @@ -35,7 +35,9 @@ for (jlop, xlaop, field) in ( (:get_parameter_shardings, :PjRtLoadedExecutableGetParameterShardings, :num_parameters), ) @eval function XLA.$(jlop)(exec::LoadedExecutable) - exec.is_sharded || return XLA.OpSharding[] + if !exec.is_sharded || iszero(exec.$(field)) + return XLA.OpSharding[] + end op_shardings = Ref{NTuple{exec.$(field),Ptr{Cvoid}}}() From 5ee11e5e339a012c67e738bea19e805ea185010e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Mar 2025 20:49:53 -0500 Subject: [PATCH 09/16] fix: replication --- src/Compiler.jl | 13 +++++++++---- src/xla/IFRT/Array.jl | 15 ++++++++------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 0d80f6cfee..956c231695 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -283,13 +283,18 @@ function generate_unresharded_ifrt_array( size_arr, ) where {T} single_device_arrays = Reactant.XLA.IFRT.disassemble_into_single_device_arrays( - Reactant.XLA.IFRT.replicate_array_to_all_devices(arr, output_sharding, mesh), true + Reactant.XLA.IFRT.replicate_array_to_all_devices( + arr, output_sharding, mesh, size_arr + ), + true, ) devs = Reactant.XLA.device.(single_device_arrays) idx = findfirst(isequal(target_device), devs) - return ConcreteIFRTArray{T,N}( - Reactant.XLA.IFRT.AsyncArray(single_device_arrays[idx], nothing), size(arr) - ) + res_arr = Reactant.XLA.IFRT.AsyncArray(single_device_arrays[idx], nothing) + res_arr_size = reverse(size(res_arr)) + @assert size_arr == res_arr_size "Expected size of array to be $(size_arr), but got \ + $(res_arr_size)" + return ConcreteIFRTArray{T,N}(res_arr, size_arr) end function create_result(tocopy::Array{T,N}, path, args...) where {T,N} diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index 676b5cf2bc..fc2d380104 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -162,6 +162,7 @@ function XLA.to_host(buffer::Array, data, reactant_sharding) client = XLA.client(buffer) all_devices = XLA.get_device.((client,), reactant_sharding.mesh.device_ids) + # TODO: Test if the below logic for replication works for distributed cases as well if any(!XLA.is_addressable, all_devices) @warn "Not all devices are addressable. Currently we only fill in the data for \ addressable devices. 
Remaining slices of data in `data` are left \ @@ -202,23 +203,23 @@ function disassemble_into_single_device_arrays(array::Array, only_addressable_de return [Array(unsafe_load(arrays, i)) for i in 1:narrays[]] end -function replicate_array_to_all_devices(array::Array, sharding, mesh) +function replicate_array_to_all_devices(array::Array, sharding, mesh, size_arr) is_fully_replicated(XLA.sharding(array)) && return array hlo_sharding = Reactant.Sharding.HloSharding( convert(XLA.HloSharding, sharding), mesh, - ntuple(Returns(1), ndims(array)), - ntuple(Returns(-1), ndims(array)), + ntuple(Returns(1), length(size_arr)), + ntuple(Returns(-1), length(size_arr)), ) shard_info = Reactant.Sharding.ShardInfo( - hlo_sharding, Reactant.Sharding.sharding_to_array_slices(hlo_sharding, size(array)) + hlo_sharding, Reactant.Sharding.sharding_to_array_slices(hlo_sharding, size_arr) ) sharding_constraint = Reactant.Sharding.NamedSharding( - mesh, ntuple(Returns(nothing), ndims(array)) + mesh, ntuple(Returns(nothing), length(size_arr)) ) - data = Reactant.ConcreteIFRTArray{eltype(array),ndims(array), typeof(shard_info)}( - AsyncArray(array, nothing), size(array), shard_info + data = Reactant.ConcreteIFRTArray{eltype(array),length(size_arr),typeof(shard_info)}( + AsyncArray(array, nothing), size_arr, shard_info ) fn(x) = Reactant.Ops.sharding_constraint(x, sharding_constraint) From e89a3fe19f8af939cf1c08963f88208c55d62a42 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 21 Mar 2025 09:29:45 -0500 Subject: [PATCH 10/16] fix: add unchecked_onto --- src/ConcreteRArray.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 13bda152f1..4e61656d7e 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -25,6 +25,10 @@ end Base.strides(x::AbstractConcreteArray) = Base.size_to_strides(1, size(x)...) +function Base.unchecked_oneto(x::AbstractConcreteNumber{<:Integer}) + return Base.unchecked_oneto(to_number(x)) +end + # Ensure the device and client are the same as the input for numType in (:ConcretePJRTNumber, :ConcreteIFRTNumber) @eval function Base.float(x::$(numType){T}) where {T} From 60640bd36c379460a00b607cec2efcae07e6b643 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 21 Mar 2025 10:34:06 -0500 Subject: [PATCH 11/16] feat: copy data from all processes for to_host --- src/Compiler.jl | 26 ++++++++++++--- src/TracedUtils.jl | 23 ++++++++++++- src/xla/IFRT/Array.jl | 77 ++++++++++++++++++++++++++----------------- 3 files changed, 89 insertions(+), 37 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 956c231695..c5f630443b 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -126,7 +126,7 @@ function create_result( delete!(result_stores, path) if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) if haskey(to_unreshard_results, path) - error("XXX: Implement This") + error("TODO: Not yet Implemented. Use IFRT for this.") end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) @@ -141,6 +141,9 @@ function create_result( # We will set the data for this later if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) + if haskey(to_unreshard_results, path) + error("TODO: Not yet Implemented. 
Use IFRT for this.") + end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -164,7 +167,7 @@ function create_result( delete!(result_stores, path) if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) if haskey(to_unreshard_results, path) - error("XXX: Implement This") + error("TODO: Not yet Implemented.") end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) @@ -177,6 +180,9 @@ function create_result( # We will set the data for this later if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) + if haskey(to_unreshard_results, path) + error("TODO: Not yet Implemented.") + end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) ) @@ -198,7 +204,7 @@ function create_result( delete!(result_stores, path) if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) if haskey(to_unreshard_results, path) - error("XXX: Implement This") + error("TODO: Not yet Implemented. Use IFRT for this.") end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) @@ -214,7 +220,7 @@ function create_result( # We will set the data for this later if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) if haskey(to_unreshard_results, path) - error("XXX: Implement This") + error("TODO: Not yet Implemented. Use IFRT for this.") end sharding = __reconstruct_shardinfo( path, path_to_shard_info, sharding_mesh, ndims(tocopy) @@ -767,6 +773,8 @@ function compile_mlir!( fn_kwargs=(), raise::Union{Bool,String}=false, input_shardings=nothing, + output_shardings=nothing, + do_transpose=true, runtime::Union{Val{:PJRT},Val{:IFRT}}, ) # Explicitly don't use block! to avoid creating a closure, which creates @@ -788,7 +796,15 @@ function compile_mlir!( mlir_fn_res = try Reactant.TracedUtils.make_mlir_fn( - f, args, fn_kwargs, "main", true; input_shardings, runtime + f, + args, + fn_kwargs, + "main", + true; + input_shardings, + output_shardings, + runtime, + do_transpose, ) finally deactivate_raising!(is_raising) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 7a65eb669e..a7f4d4997d 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -186,7 +186,8 @@ function make_mlir_fn( args_in_result::Symbol=:all, construct_function_without_args::Bool=false, do_transpose=true, - input_shardings=nothing, # This is not meant to be used by the user. + input_shardings=nothing, # This is not meant to be used by the user. + output_shardings=nothing, # This is not meant to be used by the user. 
runtime=nothing, ) if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction @@ -201,6 +202,7 @@ function make_mlir_fn( do_transpose, args_in_result, input_shardings, + output_shardings, runtime, ) mlir_fn_res.fnwrapped = true @@ -436,6 +438,25 @@ function make_mlir_fn( ) end end + + # XXX: Generalize the output shardings and expose it to the user + # output_shardings is a Int -> Sharding mapping + if output_shardings !== nothing + for (i, arg) in enumerate(linear_results) + if haskey(output_shardings, i) + sharding = output_shardings[i] + (; sym_name, mesh_attr) = mesh_cache[sharding.mesh] + MLIR.API.mlirFuncSetResultAttr( + func2, + i - 1, + "sdy.sharding", + Reactant.Sharding.get_shardy_tensor_sharding_attribute( + sharding, ctx, sym_name, mesh_attr + ), + ) + end + end + end else sharding_mesh = nothing end diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index fc2d380104..27fe6db970 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -150,6 +150,9 @@ function XLA.to_host(buffer::Array, data, reactant_sharding) if reactant_sharding isa Reactant.Sharding.NoSharding data_buffer = first(single_device_arrays) + data_buffer_shape = reverse(size(data_buffer)) + @assert size(data) == data_buffer_shape "Expected data to be of size \ + $(size(data)), got $(data_buffer_shape)" GC.@preserve data_buffer data begin @ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer( data_buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid} @@ -162,30 +165,33 @@ function XLA.to_host(buffer::Array, data, reactant_sharding) client = XLA.client(buffer) all_devices = XLA.get_device.((client,), reactant_sharding.mesh.device_ids) - # TODO: Test if the below logic for replication works for distributed cases as well - if any(!XLA.is_addressable, all_devices) - @warn "Not all devices are addressable. Currently we only fill in the data for \ - addressable devices. Remaining slices of data in `data` are left \ - untouched." + if any(XLA.is_addressable, all_devices) + # Take a fast path if all devices are addressable + array_slices, _ = XLA.sharding_to_concrete_array_indices( + convert(XLA.CondensedOpSharding, reactant_sharding.hlo_sharding), + size(data), + reactant_sharding.mesh.logical_device_ids, + ) + array_slices = [ + slice for + (device, slice) in zip(all_devices, array_slices) if XLA.is_addressable(device) + ] + + @assert length(array_slices) == length(single_device_arrays) + + for (slice, arr) in zip(array_slices, single_device_arrays) + data_slice = data isa Base.RefValue ? data : data[slice...] + XLA.to_host(arr, data_slice, Reactant.Sharding.NoSharding()) + data isa Base.RefValue || (data[slice...] .= data_slice) + end end - array_slices, _ = XLA.sharding_to_concrete_array_indices( - convert(XLA.CondensedOpSharding, reactant_sharding.hlo_sharding), - size(data), - reactant_sharding.mesh.logical_device_ids, + # Here we need to copy data from all the processes to the host + arr = replicate_array_to_all_devices( + buffer, reactant_sharding, reactant_sharding.mesh, size(data) ) - array_slices = [ - slice for - (device, slice) in zip(all_devices, array_slices) if XLA.is_addressable(device) - ] + XLA.to_host(arr, data, Reactant.Sharding.NoSharding()) - @assert length(array_slices) == length(single_device_arrays) - - for (slice, arr) in zip(array_slices, single_device_arrays) - data_slice = data isa Base.RefValue ? data : data[slice...] - XLA.to_host(arr, data_slice, Reactant.Sharding.NoSharding()) - data isa Base.RefValue || (data[slice...] 
.= data_slice) - end return nothing end @@ -206,27 +212,36 @@ end function replicate_array_to_all_devices(array::Array, sharding, mesh, size_arr) is_fully_replicated(XLA.sharding(array)) && return array - hlo_sharding = Reactant.Sharding.HloSharding( - convert(XLA.HloSharding, sharding), - mesh, - ntuple(Returns(1), length(size_arr)), - ntuple(Returns(-1), length(size_arr)), - ) + if sharding isa Reactant.Sharding.HloSharding + hlo_sharding = sharding + else + hlo_sharding = Reactant.Sharding.HloSharding( + convert(XLA.HloSharding, sharding), + mesh, + ntuple(Returns(1), length(size_arr)), + ntuple(Returns(-1), length(size_arr)), + ) + end + shard_info = Reactant.Sharding.ShardInfo( hlo_sharding, Reactant.Sharding.sharding_to_array_slices(hlo_sharding, size_arr) ) sharding_constraint = Reactant.Sharding.NamedSharding( mesh, ntuple(Returns(nothing), length(size_arr)) ) + data = Reactant.ConcreteIFRTArray{eltype(array),length(size_arr),typeof(shard_info)}( AsyncArray(array, nothing), size_arr, shard_info ) - fn(x) = Reactant.Ops.sharding_constraint(x, sharding_constraint) - - fn_compiled = Reactant.compile((data,)) do x - return Reactant.Ops.sharding_constraint(x, sharding_constraint) - end + # TODO: Directly write the MLIR for this part?? + fn_compiled = Reactant.compile( + identity, + (data,); + shardy_passes=:to_mhlo_shardings, + optimize=false, + output_shardings=Dict(1 => sharding_constraint), + ) return fn_compiled(data).data.buffer end From 905e1064f7ddc7201b4d837ba9053d3a1137e730 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 21 Mar 2025 11:00:24 -0500 Subject: [PATCH 12/16] fix: always assembly from single shards --- src/xla/IFRT/Array.jl | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index 27fe6db970..73c3036b9e 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -32,23 +32,6 @@ end function Array( client::Client, array::Base.Array{T,N}, sharding::Sharding ) where {T<:Reactant.ReactantPrimitive,N} - sizear = collect(Int64, reverse(size(array))) - - if is_single_device_sharding(sharding) || is_fully_replicated(sharding) - buffer = GC.@preserve array sizear begin - @ccall MLIR.API.mlir_c.ifrt_client_make_array_from_host_buffer( - client.client::Ptr{Cvoid}, - array::Ptr{T}, - XLA.primitive_type(T)::Cint, - N::Csize_t, - sizear::Ptr{Int64}, - sharding.ptr::Ptr{Cvoid}, - 0::Cint, # kAlwaysCopy - )::Ptr{Cvoid} - end - return Array(buffer) - end - all_devices = XLA.devices(sharding) array_slices, _ = XLA.sharding_to_concrete_array_indices( convert(XLA.HloSharding, sharding), From 36d55d551f9299846d83f96c71de1c7f7444e46d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 21 Mar 2025 11:07:40 -0500 Subject: [PATCH 13/16] fix: more fixups --- src/ConcreteRArray.jl | 6 ++++-- src/xla/IFRT/Array.jl | 9 +++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 4e61656d7e..c0820c8596 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -25,8 +25,10 @@ end Base.strides(x::AbstractConcreteArray) = Base.size_to_strides(1, size(x)...) 
-function Base.unchecked_oneto(x::AbstractConcreteNumber{<:Integer}) - return Base.unchecked_oneto(to_number(x)) +@static if isdefined(Base, :unchecked_oneto) + function Base.unchecked_oneto(x::AbstractConcreteNumber{<:Integer}) + return Base.unchecked_oneto(to_number(x)) + end end # Ensure the device and client are the same as the input diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index 73c3036b9e..1c86787be6 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -7,6 +7,15 @@ mutable struct Array <: XLA.AbstractBuffer end end +function Array( + client::Client, + array::Reactant.ReactantPrimitive, + device::Device=XLA.default_device(client), + memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))), +) + return Array(client, fill(array), device, memory_kind) +end + function Array( client::Client, array::Base.Array{T,N}, From 50aa163ec44f4725217ebe775df8972ec4226ee6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 21 Mar 2025 12:44:48 -0400 Subject: [PATCH 14/16] fix: traced_getfield for arrays --- src/Compiler.jl | 2 +- test/basic.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index c5f630443b..4d094478db 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -35,7 +35,7 @@ end @inline function traced_getfield( @nospecialize(obj::AbstractArray{<:Union{ConcretePJRTNumber,ConcreteIFRTNumber}}), field ) - return Base.getfield(obj, field) + return Base.getindex(obj, field) end @inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T} diff --git a/test/basic.jl b/test/basic.jl index e35e526386..37c65abe3f 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -512,7 +512,7 @@ end end @testset for op in [round, ceil, floor] - for x in (rand(Float32, (3, 3)), rand(Float64)) + @testset "$(typeof(x)) : $(size(x))" for x in (rand(Float32, (3, 3)), rand(Float64)) intop = Base.Fix1(op, Int) x_ra = Reactant.to_rarray.(x; track_numbers=Number) From 6b4f19aae571ad60f63e554ff63d66422a4a4abf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 21 Mar 2025 13:34:15 -0400 Subject: [PATCH 15/16] fix: sharding with non-divisible cases --- src/Compiler.jl | 25 +++++++++++++++++-------- src/Sharding.jl | 39 ++++++++++++++++++++++++++++----------- src/xla/Sharding.jl | 7 ++++++- 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 4d094478db..711dbf5840 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1334,6 +1334,10 @@ function compile_call_expr(mod, compiler, options::Dict, args...) ) end +function assert_mismatched_sharding(hlo_sharding_from_input, hlo_sharding_from_executable) + @assert hlo_sharding_from_executable == hlo_sharding_from_input "Sharding provided by the user ($(string(hlo_sharding_from_input))) does not match the sharding computed by XLA ($(string(hlo_sharding_from_executable))). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE." +end + """ codegen_flatten! 
@@ -1403,12 +1407,14 @@ function codegen_flatten!( condensed_op_sharding = convert( XLA.CondensedOpSharding, linear_parameter_shardings[i] ) + hlo_sharding_from_executable = convert( + XLA.HloSharding, condensed_op_sharding + ) if Reactant.Sharding.is_sharded(carg) - arg_condensed_op_sharding = convert( - XLA.CondensedOpSharding, carg.sharding.sharding.hlo_sharding - ) # Check if the sharding provided is same as the one we have - @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE." + assert_mismatched_sharding( + carg.sharding.sharding.hlo_sharding, hlo_sharding_from_executable + ) push!(flatten_code, :($usbuf = $flatcode.data)) for j in 1:length(mesh) @@ -1484,12 +1490,15 @@ function codegen_flatten!( condensed_op_sharding = convert( XLA.CondensedOpSharding, linear_parameter_shardings[i] ) + hlo_sharding_from_executable = convert( + XLA.HloSharding, condensed_op_sharding + ) + if Reactant.Sharding.is_sharded(carg) - arg_condensed_op_sharding = convert( - XLA.CondensedOpSharding, carg.sharding.sharding.hlo_sharding - ) # Check if the sharding provided is same as the one we have - @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE." + assert_mismatched_sharding( + carg.sharding.sharding.hlo_sharding, hlo_sharding_from_executable + ) push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf))) else resharded_inputs[path[3:end]] = ( diff --git a/src/Sharding.jl b/src/Sharding.jl index 1010ede3e3..e185131906 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -95,9 +95,13 @@ end function get_shardy_tensor_sharding_attribute end """ - sharding_to_array_slices(sharding, size_x; client=nothing) + sharding_to_array_slices( + sharding, size_x; client=nothing, return_updated_sharding=Val(false) + ) -Given a sharding and an array size, returns the device to array slices mapping. +Given a sharding and an array size, returns the device to array slices mapping. If +`return_updated_sharding` is `Val(true)`, the updated sharding is returned as well (for +inputs requiring padding). """ function sharding_to_array_slices end @@ -125,7 +129,13 @@ function (::NoSharding)(client::XLA.IFRT.Client, device, x::Union{AbstractArray, return XLA.IFRT.AsyncArray(client, x, device), ShardInfo(NoSharding(), nothing) end -sharding_to_array_slices(::NoSharding, size_x; client=nothing) = Base.OneTo.(size_x) +function sharding_to_array_slices( + sharding::NoSharding, size_x; client=nothing, return_updated_sharding=Val(false) +) + slices = Base.OneTo.(size_x) + return_updated_sharding isa Val{true} && return (slices, sharding) + return slices +end """ NamedSharding( @@ -287,9 +297,9 @@ function get_shardy_tensor_sharding_attribute( ) end -function sharding_to_array_slices(sharding::NamedSharding, size_x; client=nothing) +function sharding_to_array_slices(sharding::NamedSharding, size_x; kwargs...) return sharding_to_array_slices( - generate_hlo_sharding_from_tensor_attribute(sharding), size_x; client + generate_hlo_sharding_from_tensor_attribute(sharding), size_x; kwargs... 
) end @@ -372,8 +382,10 @@ function (sharding::DimsSharding)( return (standardize_sharding(sharding, size(x)))(client, device, x) end -function sharding_to_array_slices(sharding::DimsSharding, size_x; client=nothing) - return sharding_to_array_slices(standardize_sharding(sharding, size_x), size_x; client) +function sharding_to_array_slices(sharding::DimsSharding, size_x; kwargs...) + return sharding_to_array_slices( + standardize_sharding(sharding, size_x), size_x; kwargs... + ) end # HloSharding @@ -400,7 +412,9 @@ end return ShardInfo{HloSharding{D1,D2},Vector{NTuple{N,UnitRange{Int64}}}} end -function sharding_to_array_slices(sharding::HloSharding, size_x; client=nothing) +function sharding_to_array_slices( + sharding::HloSharding, size_x; client=nothing, return_updated_sharding=Val(false) +) # Check if the input needs to be padded. If so this sharding is not valid and we # need to request the tensor sharding from XLA condensed_op_sharding = convert(XLA.CondensedOpSharding, sharding.hlo_sharding) @@ -434,12 +448,15 @@ function sharding_to_array_slices(sharding::HloSharding, size_x; client=nothing) @assert !needs_padding "This shouldn't happen. Open an issue on Reactant.jl" end + return_updated_sharding isa Val{true} && return (device_to_array_slices, sharding) return device_to_array_slices end function HloSharding(sharding::NamedSharding, client::XLA.PJRT.Client, _, x) hlo_sharding = generate_hlo_sharding_from_tensor_attribute(sharding) - device_to_array_slices = sharding_to_array_slices(hlo_sharding, size(x); client) + device_to_array_slices, hlo_sharding = sharding_to_array_slices( + hlo_sharding, size(x); client, return_updated_sharding=Val(true) + ) data = ntuple(length(hlo_sharding.mesh)) do i XLA.PJRT.AsyncBuffer( @@ -539,8 +556,8 @@ function (sharding::ShardInfo)( return (sharding.sharding)(client, device, x) end -function sharding_to_array_slices(sharding::ShardInfo, size_x; client=nothing) - return sharding_to_array_slices(sharding.sharding, size_x; client) +function sharding_to_array_slices(sharding::ShardInfo, size_x; kwargs...) + return sharding_to_array_slices(sharding.sharding, size_x; kwargs...) 
end const NoShardInfo = ShardInfo{NoSharding,Nothing} diff --git a/src/xla/Sharding.jl b/src/xla/Sharding.jl index f20b6ac898..a719df2116 100644 --- a/src/xla/Sharding.jl +++ b/src/xla/Sharding.jl @@ -326,6 +326,11 @@ mutable struct HloSharding end end +# TODO: implement equality from the C++ side +function Base.:(==)(hsharding1::HloSharding, hsharding2::HloSharding) + return string(hsharding1) == string(hsharding2) +end + function free_hlo_sharding(hlo_sharding::HloSharding) @ccall MLIR.API.mlir_c.free_hlo_sharding(hlo_sharding.ptr::Ptr{Cvoid})::Cvoid end @@ -367,7 +372,7 @@ function Base.string(hlo_sharding::HloSharding) return unsafe_string_and_free(str) end -function Base.show(io::IO, ::MIME"text/plain", hlo_sharding::HloSharding) +function Base.show(io::IO, hlo_sharding::HloSharding) print(io, "XLA.HloSharding(\"", string(hlo_sharding), "\")") return nothing end From 68d0068873486a8f60b6d2bf5a4c2401372c3041 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 21 Mar 2025 13:58:32 -0400 Subject: [PATCH 16/16] fix: more tracing fixes --- src/Compiler.jl | 6 ++++++ src/Sharding.jl | 14 ++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 711dbf5840..2462f98474 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -34,6 +34,12 @@ end @inline function traced_getfield( @nospecialize(obj::AbstractArray{<:Union{ConcretePJRTNumber,ConcreteIFRTNumber}}), field +) + return Base.getfield(obj, field) +end + +@inline function traced_getfield( + @nospecialize(obj::Array{<:Union{ConcretePJRTNumber,ConcreteIFRTNumber}}), field ) return Base.getindex(obj, field) end diff --git a/src/Sharding.jl b/src/Sharding.jl index e185131906..702496fb87 100644 --- a/src/Sharding.jl +++ b/src/Sharding.jl @@ -453,9 +453,11 @@ function sharding_to_array_slices( end function HloSharding(sharding::NamedSharding, client::XLA.PJRT.Client, _, x) - hlo_sharding = generate_hlo_sharding_from_tensor_attribute(sharding) device_to_array_slices, hlo_sharding = sharding_to_array_slices( - hlo_sharding, size(x); client, return_updated_sharding=Val(true) + generate_hlo_sharding_from_tensor_attribute(sharding), + size(x); + client, + return_updated_sharding=Val(true), ) data = ntuple(length(hlo_sharding.mesh)) do i @@ -470,8 +472,12 @@ function HloSharding(sharding::NamedSharding, client::XLA.PJRT.Client, _, x) end function HloSharding(sharding::NamedSharding, client::XLA.IFRT.Client, _, x) - hlo_sharding = generate_hlo_sharding_from_tensor_attribute(sharding) - device_to_array_slices = sharding_to_array_slices(hlo_sharding, size(x); client) + device_to_array_slices, hlo_sharding = sharding_to_array_slices( + generate_hlo_sharding_from_tensor_attribute(sharding), + size(x); + client, + return_updated_sharding=Val(true), + ) ifrt_sharding = XLA.IFRT.Sharding( vec(Reactant.XLA.get_device.((client,), hlo_sharding.mesh.device_ids)),