diff --git a/docs/make.jl b/docs/make.jl index 64fbf78042..42100ab21a 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,8 +24,11 @@ examples = [ pages = [ "Reactant.jl" => "index.md", "Introduction" => ["Getting Started" => "introduction/index.md"], - "Tutorials" => - ["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"], + "Tutorials" => [ + "Overview" => "tutorials/index.md", + "Profiling" => "tutorials/profiling.md", + "Batching Functions with `Reactant.Ops.batch`" => "tutorials/batching.md", + ], "API Reference" => [ "Reactant API" => "api/api.md", "Ops" => "api/ops.md", @@ -38,6 +41,11 @@ pages = [ "Func" => "api/func.md", "StableHLO" => "api/stablehlo.md", "VHLO" => "api/vhlo.md", + "GPU" => "api/gpu.md", + "LLVM" => "api/llvm.md", + "NVVM" => "api/nvvm.md", + "TPU" => "api/tpu.md", + "Triton" => "api/triton.md", ], "MLIR API" => "api/mlirc.md", "XLA" => "api/xla.md", diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 46b7347039..75f78b0dfa 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -56,8 +56,12 @@ export default defineConfig({ { text: "Tutorials", items: [ - {text: "Overview", link: "/tutorials/"}, + { text: "Overview", link: "/tutorials/" }, {text: "Profiling", link: "/tutorials/profiling"}, + { + text: "Batching Functions with `Reactant.Ops.batch`", + link: "/tutorials/batching" + }, ], }, { @@ -112,6 +116,10 @@ export default defineConfig({ items: [ { text: "Overview", link: "/tutorials/" }, { text: "Profiling", link: "/tutorials/profiling" }, + { + text: "Batching Functions with `Reactant.Ops.batch`", + link: "/tutorials/batching", + }, ], }, "/api/": { diff --git a/docs/src/tutorials/batching.md b/docs/src/tutorials/batching.md new file mode 100644 index 0000000000..d3ca778847 --- /dev/null +++ b/docs/src/tutorials/batching.md @@ -0,0 +1,3 @@ +# [Batching Functions with [`Reactant.Ops.batch`](@ref)](@id batching-tutorial) + + diff --git 
a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index 87c2c8ddd3..7be1f71047 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -1,5 +1,6 @@ # Tutorials - [Profiling](@ref profiling). + - [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) We are currently working on adding more tutorials to Reactant!! Please check back soon! diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index e5f6e4fc3b..81e3dc4a8c 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -758,7 +758,7 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( @nospecialize(A::Type{<:CuTracedArray}), seen, mode::Reactant.TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(args::Vararg) ) return A end @@ -767,18 +767,18 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( @nospecialize(A::Type{<:CUDA.CuArray}), seen, mode::Reactant.TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(args::Vararg) ) T = eltype(A) N = ndims(A) if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive return Reactant.ConcreteRArray{T,N} else - TT = Reactant.traced_type_inner(T, seen, mode, track_numbers) + TT = Reactant.traced_type_inner(T, seen, mode, args...) 
if TT === T return A else - return Array{Reactant.traced_type_inner(T, seen, mode, track_numbers),N} + return Array{Reactant.traced_type_inner(T, seen, mode, args...),N} end end end diff --git a/ext/ReactantOffsetArraysExt.jl b/ext/ReactantOffsetArraysExt.jl index fc77ef0e1e..798366b682 100644 --- a/ext/ReactantOffsetArraysExt.jl +++ b/ext/ReactantOffsetArraysExt.jl @@ -8,11 +8,11 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( @nospecialize(OA::Type{<:OffsetArray}), seen, mode::Reactant.TraceMode, - @nospecialize(track_numbers::Type = Union{}) + @nospecialize(args::Vararg) ) N = ndims(OA) T = OffsetArrays.parenttype(OA) - T2 = Reactant.traced_type_inner(T, seen, mode, track_numbers) + T2 = Reactant.traced_type_inner(T, seen, mode, args...) return OffsetArray{eltype(T2),N,T2} end diff --git a/src/Ops.jl b/src/Ops.jl index b953e1c596..ac1b1b61a4 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -12,7 +12,7 @@ using ..Reactant: RNumber, MissingTracedValue, unwrapped_eltype -using Functors: fmap +using Functors: Functors, fmap function mlir_type(x::Union{RNumber,RArray}) return MLIR.IR.TensorType(size(x), MLIR.IR.Type(unwrapped_eltype(x))) @@ -1967,4 +1967,235 @@ end return corrected_traced_results end +""" + batch( + inputs::Vector{<:Union{<:TracedRArray,<:MLIR.IR.Value}}, + output_types::Vector{<:MLIR.IR.Type}, + batch_shape::Vector{Int64}; + fn, + location=mlir_stacktrace("batch", @__FILE__, @__LINE__), + ) + +Generates a Reactant.MLIR.Dialects.enzyme.batch operation. It is recommended to use +`Ops.batch(f, args, batch_dims, result_dims)` or `Ops.elem_apply(f, args...)` instead +of calling this directly. + +!!! warning + + This function batches the inputs based on the starting dimensions of the inputs. This + aligns with the default ordering in Python frameworks like JAX and PyTorch, but is + opposite to the default ordering in Julia. 
+"""
+@noinline function batch(
+    inputs::Vector{<:Union{<:TracedRArray,<:MLIR.IR.Value}},
+    output_types::Vector{<:MLIR.IR.Type},
+    batch_shape::Vector{Int64};
+    fn,
+    location=mlir_stacktrace("batch", @__FILE__, @__LINE__),
+)
+    op = MLIR.Dialects.enzyme.batch(
+        [i isa TracedRArray ? i.mlir_data : i for i in inputs]; # unwrap to raw MLIR values
+        outputs=output_types,
+        fn=MLIR.IR.FlatSymbolRefAttribute( # the callee is referenced by its "sym_name"
+            String(Reactant.TracedUtils.get_attribute_by_name(fn, "sym_name"))
+        ),
+        batch_shape=MLIR.IR.DenseArrayAttribute(batch_shape),
+        location,
+    )
+
+    return [ # one TracedRArray per op result; eltype/rank/size come from its output type
+        TracedRArray{MLIR.IR.julia_type(eltype(out_type)),ndims(out_type)}(
+            (), MLIR.IR.result(op, i), size(out_type) # `()`: presumably empty paths - TODO confirm
+        ) for (i, out_type) in enumerate(output_types)
+    ]
+end
+
+# By default this function treats the last dimension of each leaf as the batch dimension,
+# which is the standard Julia ordering for batching. The inputs are permuted so that the
+# batch dimension comes first before `batch_internal` is called below.
+"""
+    batch(f, args...; batch_dims=nothing, result_dims=nothing)
+
+Map `f` over the arguments `args` along the batch dimensions `batch_dims` and return the results with the corresponding batch dimensions specified by `result_dims`. (For users
+familiar with `jax`, this operation corresponds to `jax.vmap`.)
+
+If `batch_dims` is `nothing`, we assume that the last dimension of each leaf of `args` is the batch dimension. If `result_dims` is `nothing`, we assume that the last dimension of each leaf of the returned values is the batch dimension.
+
+To avoid batching a specific leaf, pass `nothing` for the corresponding `batch_dims`.
+
+## Examples
+
+For usage examples, see the [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) tutorial.
+
+!!! danger
+
+    Mutation inside a batched function is not supported yet and will lead to unexpected results.
+"""
+@noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing)
+    batch_sizes = Int64[]
+    batching_dims = if batch_dims === nothing
+        fmap(args) do x
+            tmp = ndims(x) # default: the last dimension of every leaf is the batch dim
+            push!(batch_sizes, size(x, tmp))
+            return tmp
+        end
+    else
+        fmap(args, batch_dims) do x, dim
+            @assert dim isa Integer || dim === nothing # validate before `size(x, dim)`
+            dim !== nothing && push!(batch_sizes, size(x, dim)) # `nothing` => leaf is not batched
+            return dim
+        end
+    end
+
+    batch_sizes_no_ones = filter(x -> x != 1, batch_sizes) # size-1 batch dims are expanded to B
+    @assert allequal(batch_sizes_no_ones) "batching dimensions must be equal"
+    B = isempty(batch_sizes_no_ones) ? 1 : first(batch_sizes_no_ones)
+
+    corrected_args = fmap(args, batching_dims) do arg, dim
+        if dim === nothing # unbatched leaf: repeat it B times along a new leading dimension
+            return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...])
+        end
+        if size(arg, dim) == 1 && size(arg, dim) != B # size-1 batch dim: expand it to B
+            new_dims = collect(Int64, size(arg))
+            new_dims[dim] = B
+            arg = broadcast_in_dim(arg, collect(1:ndims(arg)), new_dims)
+        end
+        order = collect(Int64, 1:ndims(arg))
+        order[dim] = 1
+        order[1] = dim
+        return permutedims(arg, order) # swap `dim` with 1 so the batch dim comes first
+    end
+
+    results = batch_internal(f, corrected_args...)
+
+    if result_dims === nothing
+        return fmap(results) do result
+            order = Int64[2:ndims(result)..., 1]
+            return permutedims(result, order)
+        end
+    end
+
+    return fmap(results, result_dims) do result, dim
+        @assert dim !== nothing "Result batch dimension cannot be `nothing`"
+
+        order = collect(Int64, 1:ndims(result))
+        order[dim] = 1
+        order[1] = dim
+        return permutedims(result, order)
+    end
+end
+
+"""
+    elem_apply(f, args...)
+
+This is equivalent to `f.(args...)` but generates optimized code using
+Reactant.MLIR.Dialects.enzyme.batch.
+""" +@noinline function elem_apply(f, args::Vararg) + return batch_internal(f, args...; batchmode=Reactant.BatchScalar) +end + +@noinline function elem_apply( + ::Type{T}, x::TracedRArray{T} +) where {T<:Reactant.ReactantPrimitive} + return x +end + +@noinline function elem_apply( + ::Type{T}, x::TracedRArray +) where {T<:Reactant.ReactantPrimitive} + # Special Path to prevent going down a despecialized path + return elem_apply(Reactant.TracedUtils.TypeCast{T}(), x) +end + +@noinline function batch_internal(f, args::Vararg; batchmode=Reactant.BatchArray) + @assert batchmode != Reactant.BatchNone + + if batchmode == Reactant.BatchScalar + if all(iszero ∘ ndims, args) + scalar_args = map(args) do arg + return Reactant.TracedUtils.promote_to( + TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg + ) + end + return Reactant.call_with_reactant(f, scalar_args...) + end + end + + fnwrap, func2, _, result, seen_args, _, linear_args, _, linear_results = Reactant.TracedUtils.make_mlir_fn( + f, + args, + (), + string(f) * (batchmode == Reactant.BatchArray ? 
"_batch" : "_broadcast_scalar"), + false; + batchmode, + no_args_in_result=batchmode == Reactant.BatchScalar, + do_transpose=false, + ) + + if batchmode == Reactant.BatchArray + batch_sizes = [size(k, 1) for k in keys(seen_args) if k isa Reactant.TracedType] + @assert allequal(batch_sizes) "batching dimensions must be equal" + B = first(batch_sizes) + else + input_shapes = [size(k) for k in keys(seen_args) if k isa Reactant.TracedType] + @assert allequal(input_shapes) "input shapes are $(input_shapes)" + output_shape = first(input_shapes) + end + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = Reactant.TracedUtils.get_argidx(a) + if idx == 1 && fnwrap + Reactant.TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + fnwrap && (idx -= 1) + Reactant.TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + res = batch( + batch_inputs, + [ + MLIR.IR.TensorType( + batchmode == Reactant.BatchArray ? (B, size(arg)...) : output_shape, + MLIR.IR.Type(Reactant.unwrapped_eltype(arg)), + ) for arg in linear_results + ], + batchmode == Reactant.BatchArray ? Int64[B] : collect(Int64, output_shape); + fn=func2, + ) + + residx = 1 + for a in linear_results + if Reactant.TracedUtils.has_residx(a) + path = Reactant.TracedUtils.get_residx(a) + Reactant.TracedUtils.set!(result, path[2:end], res[residx]) + residx += 1 + else + idx, path = Reactant.TracedUtils.get_argidx(a) + if idx == 1 && fnwrap + Reactant.TracedUtils.set!(f, path[3:end], res[residx]) + residx += 1 + else + fnwrap && (idx -= 1) + Reactant.TracedUtils.set!(args[idx], path[3:end], res[residx]) + residx += 1 + end + end + end + + traced2_result = Reactant.make_tracer( + Reactant.OrderedIdDict(), + result, + (), + Reactant.TracedSetPath; + tobatch=batchmode == Reactant.BatchArray ? 
(B,) : output_shape, + batchmode, + ) + func2.operation = MLIR.API.MlirOperation(C_NULL) + + return traced2_result +end + end # module Ops diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 165a9c71eb..af395bf90b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -549,7 +549,7 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) res = TracedUtils.promote_to( TracedRArray{unwrapped_eltype(dest),ndims(dest)}, - TracedUtils.elem_apply(bc.f, args...), + Ops.elem_apply(bc.f, args...), ) TracedUtils.set_mlir_data!(dest, res.mlir_data) return dest @@ -563,8 +563,8 @@ function _copyto!(dest::AbstractArray{<:TracedRNumber}, bc::Broadcasted) args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) - res = TracedUtils.elem_apply(bc.f, args...) - for I in 1:length(dest) + res = Ops.elem_apply(bc.f, args...) + for I in eachindex(dest) dest[I] = Reactant.@allowscalar res[I] end return dest diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 8802ed0833..7f44f09b80 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -107,7 +107,7 @@ function make_mlir_fn( kwargs, name="main", concretein=true; - toscalar=false, + batchmode=Reactant.BatchNone, return_dialect=:func, do_transpose=true, no_args_in_result=false, @@ -121,7 +121,7 @@ function make_mlir_fn( kwargs, name, concretein; - toscalar, + batchmode, return_dialect, do_transpose, no_args_in_result, @@ -138,7 +138,7 @@ function make_mlir_fn( args[i], (:args, i), concretein ? 
Reactant.ConcreteToTraced : Reactant.TracedSetPath; - toscalar, + batchmode ) end @@ -148,7 +148,7 @@ function make_mlir_fn( push!(linear_args, v) end - in_tys = if toscalar + in_tys = if batchmode == Reactant.BatchScalar [ MLIR.IR.TensorType((), MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for arg in linear_args @@ -283,19 +283,6 @@ function __take_region(compiled_fn) return region end -elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x - -struct TypeCast{T<:ReactantPrimitive} <: Function end - -function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} - return TracedUtils.promote_to(TracedRNumber{T}, x) -end - -function elem_apply(::Type{T}, x::TracedRArray) where {T<:ReactantPrimitive} - # Special Path to prevent going down a despecialized path - return elem_apply(TypeCast{T}(), x) -end - function promote_to end function get_attribute_by_name(operation, name) @@ -339,7 +326,11 @@ function set!(x, path, tostore; emptypath=false) x = Reactant.Compiler.traced_getfield(x, p) end - set_mlir_data!(x, tostore) + if tostore isa TracedRArray + set_mlir_data!(x, tostore.mlir_data) + else + set_mlir_data!(x, tostore) + end return emptypath && set_paths!(x, ()) end @@ -368,93 +359,6 @@ function has_residx(x) return false end -function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} - if all(iszero ∘ ndims, args) - scalar_args = map(args) do arg - return promote_to(TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg) - end - return f(scalar_args...) 
- end - - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( - f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true - ) - - invmap = IdDict() - for (k, v) in seen_args - invmap[v] = k - end - - keys_seen = [k for k in keys(seen_args) if k isa Reactant.TracedType] - input_shapes = size.(keys_seen) - # by the time we reach here all args must have same size - @assert allequal(input_shapes) "input shapes are $(input_shapes)" - OutShape = isempty(seen_args) ? nothing : first(input_shapes) - @assert !isnothing(OutShape) - - in_tys2 = [Ops.mlir_type(invmap[arg]) for arg in linear_args] - - out_tys2 = [ - MLIR.IR.TensorType(OutShape, MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for - arg in linear_results - ] - - fname = get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - - batch_inputs = MLIR.IR.Value[] - - for a in linear_args - idx, path = TracedUtils.get_argidx(a) - if idx == 1 && fnwrap - push_val!(batch_inputs, f, path[3:end]) - else - if fnwrap - idx -= 1 - end - push_val!(batch_inputs, args[idx], path[3:end]) - end - end - - res = MLIR.Dialects.enzyme.batch( - batch_inputs; - outputs=out_tys2, - fn=fname, - batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in OutShape]), - ) - - residx = 1 - - for a in linear_results - if TracedUtils.has_residx(a) - path = TracedUtils.get_residx(a) - TracedUtils.set!(result, path[2:end], MLIR.IR.result(res, residx)) - residx += 1 - else - idx, path = TracedUtils.get_argidx(a) - if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], MLIR.IR.result(res, residx)) - residx += 1 - else - if fnwrap - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) - residx += 1 - end - end - end - - seen_results = OrderedIdDict() - traced2_result = Reactant.make_tracer( - seen_results, result, (), Reactant.TracedSetPath; tobatch=OutShape - ) - - func2.operation = 
MLIR.API.MlirOperation(C_NULL) - - return traced2_result -end - function broadcast_to_size(arg::AbstractArray{<:TracedRNumber}, rsize) return broadcast_to_size(reshape(Ops.vcat(arg...), size(arg)...), rsize) end @@ -494,4 +398,8 @@ end return Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize)) end +struct TypeCast{T<:Reactant.ReactantPrimitive} <: Function end + +@noinline (f::TypeCast{T})(x::TracedRNumber) where {T} = promote_to(TracedRNumber{T}, x) + end diff --git a/src/Tracing.jl b/src/Tracing.jl index d07c5178c7..48d6d0b011 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -7,8 +7,14 @@ NoStopTracedTrack = 6 end +@enum BatchMode begin + BatchNone = 1 + BatchScalar = 2 + BatchArray = 3 +end + Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(track_numbers::Type) + @nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(args::Vararg) ) if T === Any return T @@ -39,8 +45,8 @@ Base.@nospecializeinfer function traced_type_inner( if T isa Union return Union{ - traced_type_inner(T.a, seen, mode, track_numbers), - traced_type_inner(T.b, seen, mode, track_numbers), + traced_type_inner(T.a, seen, mode, args...), + traced_type_inner(T.b, seen, mode, args...), } end @@ -67,7 +73,7 @@ Base.@nospecializeinfer function traced_type_inner( subTys = Type[] for f in 1:fieldcount(T) subT = fieldtype(T, f) - subTT = traced_type_inner(subT, seen2, mode, track_numbers) + subTT = traced_type_inner(subT, seen2, mode, args...) changed |= subT != subTT push!(subTys, subTT) end @@ -85,17 +91,17 @@ Base.@nospecializeinfer function traced_type_inner( subParms = [] for (i, SST) in enumerate(T.parameters) if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive - TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers) + TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, args...) 
push!(subParms, TrT) elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber{<:ReactantPrimitive} - TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, track_numbers) + TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, args...) push!(subParms, TrT) else if SST isa Type - TrT = traced_type_inner(SST, seen, mode, track_numbers) + TrT = traced_type_inner(SST, seen, mode, args...) push!(subParms, TrT) else push!(subParms, SST) @@ -115,7 +121,7 @@ Base.@nospecializeinfer function traced_type_inner( for f in 1:fieldcount(T) subT = fieldtype(T, f) subT2 = fieldtype(TT2, f) - subTT = traced_type_inner(subT, seen3, mode, track_numbers) + subTT = traced_type_inner(subT, seen3, mode, args...) if subT2 != subTT legal = false break @@ -134,10 +140,7 @@ Base.@nospecializeinfer function traced_type_inner( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{Union{}}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(T::Type{Union{}}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) return T end @@ -154,10 +157,7 @@ for T in ( RNumber, ) @eval Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:$T}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(T::Type{<:$T}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) return T end @@ -167,7 +167,8 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(T::Type{<:ReactantPrimitive}), seen, @nospecialize(mode::TraceMode), - @nospecialize(track_numbers::Type) + @nospecialize(track_numbers::Type), + @nospecialize(args::Vararg) ) if Mode == ArrayToConcrete && T <: track_numbers return ConcreteRNumber{T} @@ -181,20 +182,17 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(C::Type{<:Complex}), seen, @nospecialize(mode::TraceMode), - @nospecialize(track_numbers::Type) + @nospecialize(args::Vararg) ) if !(C isa UnionAll) - return 
Complex{traced_type_inner(C.parameters[1], seen, mode, track_numbers)} + return Complex{traced_type_inner(C.parameters[1], seen, mode, args...)} else return C end end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:Function}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(T::Type{<:Function}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) # functions are directly returned if sizeof(T) == 0 @@ -206,7 +204,7 @@ Base.@nospecializeinfer function traced_type_inner( changed = false traced_fieldtypes = Type[] for i in 1:N - next = traced_type_inner(fieldtype(T, i), seen, mode, track_numbers) + next = traced_type_inner(fieldtype(T, i), seen, mode, args...) changed |= next != fieldtype(T, i) push!(traced_fieldtypes, next) end @@ -223,10 +221,7 @@ end (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:Tuple}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(T::Type{<:Tuple}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll throw(AssertionError("Type $T is not concrete type or concrete tuple")) @@ -235,21 +230,18 @@ Base.@nospecializeinfer function traced_type_inner( throw(AssertionError("Type tuple of vararg $T is not supported")) end TT = [ - traced_type_inner(T.parameters[i], seen, mode, track_numbers) for + traced_type_inner(T.parameters[i], seen, mode, args...) 
for i in 1:length(T.parameters) ] return Tuple{TT...} end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:NamedTuple}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(T::Type{<:NamedTuple}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) N = T.parameters[1] V = T.parameters[2] - return NamedTuple{N,traced_type_inner(V, seen, mode, track_numbers)} + return NamedTuple{N,traced_type_inner(V, seen, mode, args...)} end Base.@nospecializeinfer @inline dict_key(::Type{<:AbstractDict}) = nothing @@ -263,14 +255,14 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(T::Type{<:AbstractDict}), seen, mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(args::Vararg) ) V = dict_value(T) if V === nothing return T else K = dict_key(T) - V2 = traced_type_inner(V, seen, mode, track_numbers) + V2 = traced_type_inner(V, seen, mode, args...) if V == V2 return T end @@ -291,38 +283,27 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(T0::Type{<:ConcreteRNumber}), seen, mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(args::Vararg) ) - T = T0.parameters[1] if mode == ConcreteToTraced - return TracedRNumber{T} + return TracedRNumber{T0.parameters[1]} elseif mode == TracedToConcrete - return ConcreteRNumber{T} + return T0 else throw("Abstract RNumber cannot be made concrete") end end -Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) = - UnionAll(TV.var, base_typet(TV.body)) -Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) = - TracedRArray{TV.parameters...} - -Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) = - UnionAll(TV.var, base_typec(TV.body)) -Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) = - (TV <: TracedRArray ? 
ConcreteRArray : ConcreteRNumber){TV.parameters...} - Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ConcreteRArray}), + @nospecialize(CA::Type{<:ConcreteRArray}), seen, mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(args::Vararg) ) if mode == ConcreteToTraced - return base_typet(T) + return TracedRArray{CA.parameters[1],CA.parameters[2]} elseif mode == TracedToConcrete - return T + return CA else throw("Abstract RArray cannot be made concrete") end @@ -332,7 +313,7 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(T::Type{<:ConcreteRNG}), seen, mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(args::Vararg) ) if mode == ConcreteToTraced return TracedRNG @@ -344,28 +325,89 @@ Base.@nospecializeinfer function traced_type_inner( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:TracedType}), + ::Type{<:MissingTracedValue}, + seen, + mode::TraceMode, + @nospecialize(track_numbers), + @nospecialize(batchmode), + @nospecialize(tobatch) +) + return error("This should not happen...") +end + +Base.@nospecializeinfer function traced_type_inner( + TR::Type{<:TracedRNumber}, seen, mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(track_numbers), + @nospecialize(batchmode), + @nospecialize(tobatch) ) - T <: MissingTracedValue && error("TODO") + T = TR.parameters[1] if mode == ConcreteToTraced - throw("TracedRArray $T cannot be traced") + throw("TracedRArray $(TracedRArray{T,N}) cannot be traced") elseif mode == TracedToConcrete - return base_typec(T) - elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath - return T + return ConcreteRNumber{T} + elseif mode == TracedTrack || mode == NoStopTracedTrack + return TracedRNumber{T} + elseif mode == TracedSetPath + if batchmode == BatchNone + return TracedRNumber{T} + elseif batchmode == BatchScalar + if tobatch === nothing + return TracedRNumber{T} + else + return 
TracedRArray{T,length(tobatch)} + end + elseif batchmode == BatchArray + if tobatch === nothing + error("For scalars with batchmode=BatchArray, tobatch must be specified") + else + TracedRArray{T,length(tobatch)} + end + else + error("Unknown batchmode $batchmode") + end else - throw("Abstract RArray $T cannot be made concrete in mode $mode") + throw("$(TracedRNumber{T}) cannot be made concrete in mode $mode") end end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:TracedRNG}), + TR::Type{<:TracedRArray}, seen, mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(track_numbers), + @nospecialize(batchmode), + @nospecialize(tobatch) +) + T = TR.parameters[1] + N = TR.parameters[2] + if mode == ConcreteToTraced + throw("TracedRArray $(TracedRArray{T,N}) cannot be traced") + elseif mode == TracedToConcrete + return ConcreteRArray{T,N} + elseif mode == TracedTrack || mode == NoStopTracedTrack + return TracedRArray{T,N} + elseif mode == TracedSetPath + if batchmode == BatchNone + return TracedRArray{T,N} + elseif batchmode == BatchArray + if tobatch === nothing + TracedRArray{T,N - 1} + else + TracedRArray{T,N + length(tobatch)} + end + else + error("Not implemented") + end + else + throw("Abstract RArray $(TracedRArray{T,N}) cannot be made concrete in mode $mode") + end +end + +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:TracedRNG}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) if mode == ConcreteToTraced throw("TracedRNG cannot be traced") @@ -379,38 +421,29 @@ Base.@nospecializeinfer function traced_type_inner( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:XLAArray}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(T::Type{<:XLAArray}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) throw("XLA $T array cannot be traced") end Base.@nospecializeinfer function traced_type_inner( - 
@nospecialize(A::Type{<:Array}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(A::Type{<:Array}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) T = eltype(A) N = ndims(A) if mode == ArrayToConcrete && T <: ReactantPrimitive return ConcreteRArray{T,N} else - return Array{traced_type_inner(T, seen, mode, track_numbers),N} + return Array{traced_type_inner(T, seen, mode, args...),N} end end for P in (Ptr, Core.LLVMPtr, Base.RefValue) @eval Base.@nospecializeinfer function traced_type_inner( - @nospecialize(PT::Type{<:$P}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(PT::Type{<:$P}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) T = eltype(PT) - return $P{traced_type_inner(T, seen, mode, track_numbers)} + return $P{traced_type_inner(T, seen, mode, args...)} end end @@ -418,19 +451,19 @@ Base.@nospecializeinfer function traced_type_inner( @nospecialize(VT::Type{<:Val}), seen, @nospecialize(mode::TraceMode), - @nospecialize(track_numbers::Type) + @nospecialize(args::Vararg) ) if VT isa UnionAll return VT end T = VT.parameters[1] - if traced_type_inner(typeof(T), seen, mode, track_numbers) == typeof(T) + if traced_type_inner(typeof(T), seen, mode, args...) 
== typeof(T) return Val{T} end throw("Val type $(Val{T}) cannot be traced") end -const traced_type_cache = Dict{Tuple{TraceMode,Type},Dict{Type,Type}}() +const traced_type_cache = Dict{Tuple{TraceMode,Type,Any,Any},Dict{Type,Type}}() # function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Type{<:Val}), @nospecialize(track_numbers::Type)) # @nospecialize @@ -524,17 +557,17 @@ const traced_type_cache = Dict{Tuple{TraceMode,Type},Dict{Type,Type}}() # end Base.@assume_effects :total @inline function traced_type( - T::Type, ::Val{mode}, track_numbers::Type + T::Type, ::Val{mode}, track_numbers::Type, batchmode, tobatch ) where {mode} cache = nothing - cache_key = (mode, track_numbers) + cache_key = (mode, track_numbers, batchmode, tobatch) if haskey(traced_type_cache, cache_key) cache = traced_type_cache[cache_key] else cache = Dict{Type,Type}() traced_type_cache[cache_key] = cache end - return res1 = traced_type_inner(T, cache, mode, track_numbers) + return traced_type_inner(T, cache, mode, track_numbers, batchmode, tobatch) end abstract type TracedTypeException <: Exception end @@ -575,16 +608,16 @@ function make_tracer( @nospecialize(prev), @nospecialize(path), mode; - toscalar=false, - tobatch=nothing, @nospecialize(track_numbers::Type = Union{}), + @nospecialize(batchmode = BatchNone), + @nospecialize(tobatch = nothing), kwargs..., ) if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end RT = Core.Typeof(prev) - TT = traced_type(RT, Val(mode), track_numbers) + TT = traced_type(RT, Val(mode), track_numbers, batchmode, tobatch) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) @@ -605,9 +638,8 @@ function make_tracer( xi, append_path(path, i), mode; - toscalar, - tobatch, track_numbers, + batchmode, kwargs..., ) if xi !== xi2 @@ -633,14 +665,7 @@ function make_tracer( if isdefined(prev, i) xi = Base.getfield(prev, i) xi2 = make_tracer( - seen, - xi, - 
append_path(path, i), - mode; - toscalar, - tobatch, - track_numbers, - kwargs..., + seen, xi, append_path(path, i), mode; track_numbers, kwargs... ) if xi !== xi2 changed = true @@ -700,7 +725,7 @@ function make_tracer( @nospecialize(prev::TracedRArray{T,N}), @nospecialize(path), mode; - toscalar=false, + batchmode=BatchNone, tobatch=nothing, kwargs..., ) where {T,N} @@ -725,12 +750,23 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - res = if toscalar - TracedRNumber{T}((path,), nothing) - elseif tobatch !== nothing - error("This should not happen...") - else + res = if batchmode == BatchNone + @assert tobatch === nothing TracedRArray{T,N}((path,), prev.mlir_data, size(prev)) + elseif batchmode == BatchScalar + if tobatch === nothing + TracedRNumber{T}((path,), nothing) + else + error("BatchScalar + tobatch for TracedRArray doesn't make sense") + end + else + if tobatch === nothing + TracedRArray{T,N - 1}((path,), nothing, size(prev)[2:end]) + else + TracedRArray{T,N + length(tobatch)}( + (path,), prev.mlir_data, (tobatch..., size(prev)...) 
+ ) + end end seen[prev] = res return res @@ -754,7 +790,7 @@ function make_tracer( @nospecialize(path), mode; tobatch=nothing, - toscalar=false, + batchmode=BatchNone, kwargs..., ) where {T} if mode == ConcreteToTraced @@ -778,12 +814,23 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - res = if toscalar - TracedRNumber{T}((path,), nothing) - elseif tobatch !== nothing - TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) - else + res = if batchmode == BatchNone + @assert tobatch === nothing TracedRNumber{T}((path,), prev.mlir_data) + elseif batchmode == BatchScalar + if tobatch === nothing + TracedRNumber{T}((path,), nothing) + else + TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) + end + elseif batchmode == BatchArray + if tobatch === nothing + error("For scalars with batchmode=BatchArray, tobatch must be specified") + else + TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) + end + else + error("Unknown batchmode $batchmode") end seen[prev] = res return res @@ -874,21 +921,11 @@ make_tracer(seen, @nospecialize(prev::Type), @nospecialize(path), mode; kwargs.. make_tracer(seen, prev::Symbol, @nospecialize(path), mode; kwargs...) = prev function make_tracer( - seen, - @nospecialize(prev::Complex), - @nospecialize(path), - mode; - toscalar=false, - tobatch=nothing, - kwargs..., + seen, @nospecialize(prev::Complex), @nospecialize(path), mode; kwargs... ) return Complex( - make_tracer( - seen, prev.re, append_path(path, :re), mode; toscalar, tobatch, kwargs... - ), - make_tracer( - seen, prev.im, append_path(path, :im), mode; toscalar, tobatch, kwargs... 
- ), + make_tracer(seen, prev.re, append_path(path, :re), mode; kwargs...), + make_tracer(seen, prev.im, append_path(path, :im), mode; kwargs...), ) end @@ -898,6 +935,8 @@ function make_tracer( @nospecialize(path), mode; @nospecialize(track_numbers::Type = Union{}), + @nospecialize(batchmode = BatchNone), + @nospecialize(tobatch = nothing), kwargs..., ) RT = Core.Typeof(prev) @@ -907,14 +946,23 @@ function make_tracer( if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive return seen[prev] = ConcreteRArray(prev) end - TT = traced_type(eltype(RT), Val(mode), track_numbers) + TT = traced_type(eltype(RT), Val(mode), track_numbers, batchmode, tobatch) newa = Array{TT,ndims(RT)}(undef, size(prev)) seen[prev] = newa same = true for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] - nv = make_tracer(seen, pv, append_path(path, I), mode; track_numbers, kwargs...) + nv = make_tracer( + seen, + pv, + append_path(path, I), + mode; + track_numbers, + batchmode, + tobatch, + kwargs..., + ) if pv !== nv same = false end @@ -943,12 +991,14 @@ function make_tracer( @nospecialize(path), mode; @nospecialize(track_numbers::Type = Union{}), + @nospecialize(batchmode = BatchNone), + @nospecialize(tobatch = nothing), kwargs..., ) NT = Core.Typeof(prev) A = NT.parameters[1] RT = NT.parameters[2] - return NamedTuple{A,traced_type(RT, Val(mode), track_numbers)}(( + return NamedTuple{A,traced_type(RT, Val(mode), track_numbers, batchmode, tobatch)}(( ( make_tracer( seen, @@ -956,6 +1006,8 @@ function make_tracer( append_path(path, i), mode; track_numbers, + batchmode, + tobatch, kwargs..., ) for i in 1:length(A) )..., diff --git a/test/basic.jl b/test/basic.jl index c3952549a6..ad282f332c 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -34,19 +34,18 @@ end end sinexp(x) = sin(exp(x)) -sinexpbc(x) = sinexp.(x) @testset "Broadcast combined" begin x = rand(2, 10) - r_res = sinexpbc(x) + r_res = sinexp.(x) a = Reactant.ConcreteRArray(x) - c_res = @allowscalar sinexpbc(a) + 
c_res = @allowscalar sinexp.(a) @test c_res ≈ r_res - @test @jit(sinexpbc(a)) ≈ r_res + @test @jit(sinexp.(a)) ≈ r_res end sumexp(x) = sum(exp, x) @@ -82,13 +81,11 @@ end @test f_res ≈ r_res end -bcast_cos(x) = cos.(x) - @testset "Basic cos" begin x = rand(3, 2) c = Reactant.ConcreteRArray(x) - @test @jit(bcast_cos(c)) ≈ cos.(x) + @test @jit(cos.(c)) ≈ cos.(x) end f_var(args...) = sum(args) @@ -376,7 +373,7 @@ end b = Reactant.to_rarray(_b) c = Reactant.to_rarray(_c) - # vcat test + # vcat test y = @jit vcat(a, b) @test y == vcat(a, _b) @test y isa ConcreteRArray{typeof_a,1} diff --git a/test/batching.jl b/test/batching.jl new file mode 100644 index 0000000000..cd6ae6bbf5 --- /dev/null +++ b/test/batching.jl @@ -0,0 +1,2 @@ +using Reactant, Test + diff --git a/test/runtests.jl b/test/runtests.jl index be17750042..9c6e094508 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") @safetestset "Control Flow" include("control_flow.jl") @safetestset "Sorting" include("sorting.jl") + @safetestset "Batching" include("batching.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" diff --git a/test/tracing.jl b/test/tracing.jl index c196f562b6..0538ecb7fe 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -144,10 +144,18 @@ using Test Base.Pairs{Symbol,Union{}}, ), ] - tracedty = traced_type(origty, Val(ConcreteToTraced), Union{}) + tracedty = traced_type( + origty, Val(ConcreteToTraced), Union{}, Reactant.BatchNone, nothing + ) @test tracedty == targetty - tracedty2 = traced_type(origty, Val(ConcreteToTraced), ReactantPrimitive) + tracedty2 = traced_type( + origty, + Val(ConcreteToTraced), + ReactantPrimitive, + Reactant.BatchNone, + nothing, + ) @test tracedty2 == targetty end @@ -158,7 +166,7 @@ using Test TracedRArray{Float64,3}, ] @test_throws Union{ErrorException,String} traced_type( - 
type, Val(ConcreteToTraced), Union{} + type, Val(ConcreteToTraced), Union{}, Reactant.BatchNone, nothing ) end end @@ -167,7 +175,9 @@ using Test x::Vector{Float64} y::Union{Nothing,Node} end - @test_throws NoFieldMatchError traced_type(Node, Val(ArrayToConcrete), Union{}) + @test_throws NoFieldMatchError traced_type( + Node, Val(ArrayToConcrete), Union{}, Reactant.BatchNone, nothing + ) end end