From 632233775677f6fb4c4083327960e8f5ea7bb9af Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 15 Jan 2025 22:03:19 -0500 Subject: [PATCH 01/13] feat: implement `Ops.batch` --- src/Ops.jl | 112 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/src/Ops.jl b/src/Ops.jl index b953e1c596..16df2836f0 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1967,4 +1967,116 @@ end return corrected_traced_results end +@noinline function batch( + f, + args::Vector{<:TracedRArray}, + batch_dims::Vector{Union{Int,Nothing}}, + result_dims::Union{Vector{Int},Nothing}=nothing, +) + @assert length(batch_dims) == length(args) + + batch_sizes = [dim === nothing ? 1 : size(x, dim) for (x, dim) in zip(args, batch_dims)] + filter!(x -> x != 1, batch_sizes) + @assert allequal(batch_sizes) "batching dimensions must be equal" + B = length(batch_sizes) == 0 ? 1 : first(batch_sizes) + + args = map(zip(args, batch_dims)) do (arg, dim) + if dim === nothing + return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...]) + end + order = collect(1:ndims(arg)) + order[dim] = 1 + order[1] = dim + return permutedims(arg, order) + end + results = batch(f, args) + result_dims === nothing && (result_dims = ones(Int64, length(results))) + @assert length(results) == length(result_dims) + return map(zip(results, result_dims)) do (result, dim) + order = collect(1:ndims(result)) + order[dim] = 1 + order[1] = dim + return permutedims(result, order) + end +end + +@noinline function batch(f, args::Vector{<:TracedRArray}) + batch_sizes = [size(x, 1) for x in args] + @assert allequal(batch_sizes) "batching dimensions must be equal" + B = first(batch_sizes) + + in_tys = [ + MLIR.IR.TensorType(size(arg)[2:end], MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) + for arg in args + ] + + sym_visibility = MLIR.IR.Attribute("private") + + mod = MLIR.IR.mmodule() + func = MLIR.IR.block!(MLIR.IR.body(mod)) do + return MLIR.Dialects.func.func_(; + sym_name=string(f) * "_batch_tmp", + function_type=MLIR.IR.FunctionType(in_tys, []), + body=MLIR.IR.Region(), + sym_visibility, + ) + end + fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in args]) + push!(MLIR.IR.region(func, 1), fnbody) + + linear_args = [ + TracedRArray{Reactant.unwrapped_eltype(arg),ndims(arg) - 1}( + (), nothing, size(arg)[2:end] + ) for arg in args + ] + + MLIR.IR.activate!(fnbody) + result = try + for (i, arg) in enumerate(linear_args) + raw_arg = MLIR.IR.argument(fnbody, i) + Reactant.TracedUtils.set_mlir_data!(arg, raw_arg) + end + res = Reactant.call_with_reactant(f, linear_args...) + (res isa TracedRArray || res isa TracedRNumber) && (res = [res]) + MLIR.Dialects.func.return_([r.mlir_data for r in res]) + res + finally + MLIR.IR.deactivate!(fnbody) + end + + comp_func = MLIR.IR.block!(MLIR.IR.body(mod)) do + return MLIR.Dialects.func.func_(; + sym_name=string(f) * "_batch", + function_type=MLIR.IR.FunctionType(in_tys, [mlir_type(r) for r in result]), + body=MLIR.IR.Region(), + sym_visibility, + ) + end + MLIR.API.mlirRegionTakeBody(MLIR.IR.region(comp_func, 1), MLIR.IR.region(func, 1)) + MLIR.API.mlirOperationDestroy(func.operation) + func.operation = MLIR.API.MlirOperation(C_NULL) + + fname = Reactant.TracedUtils.get_attribute_by_name(comp_func, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = [x.mlir_data for x in args] + out_tys = [ + MLIR.IR.TensorType((B, size(r)...), MLIR.IR.Type(Reactant.unwrapped_eltype(r))) for + r in result + ] + + op = MLIR.Dialects.enzyme.batch( + batch_inputs; + outputs=out_tys, + fn=fname, + batch_shape=MLIR.IR.DenseArrayAttribute(Int64[B]), + ) + + return [ + TracedRArray{Reactant.unwrapped_eltype(r),ndims(r) + 1}( + (), MLIR.IR.result(op, i), (B, size(r)...) + ) for (i, r) in enumerate(result) + ] +end + end # module Ops From 11fdc36b8f375c93f366b39d2812ea2f8b37f4d9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Jan 2025 09:10:37 -0500 Subject: [PATCH 02/13] refactor: cleanup elem_apply --- src/Ops.jl | 29 +++++++++++++++++++---------- src/TracedUtils.jl | 24 +++++++++--------------- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 16df2836f0..fc573b4257 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2056,26 +2056,35 @@ end MLIR.API.mlirOperationDestroy(func.operation) func.operation = MLIR.API.MlirOperation(C_NULL) - fname = Reactant.TracedUtils.get_attribute_by_name(comp_func, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - - batch_inputs = [x.mlir_data for x in args] out_tys = [ MLIR.IR.TensorType((B, size(r)...), MLIR.IR.Type(Reactant.unwrapped_eltype(r))) for r in result ] + return batch(args, out_tys, Int64[B]; fn=comp_func) +end + +@noinline function batch( + inputs::Vector{<:Union{<:TracedRArray,<:MLIR.IR.Value}}, + output_types::Vector{<:MLIR.IR.Type}, + batch_shape::Vector{Int64}; + fn, + location=mlir_stacktrace("batch", @__FILE__, @__LINE__), +) + fname = Reactant.TracedUtils.get_attribute_by_name(fn, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) op = MLIR.Dialects.enzyme.batch( - batch_inputs; - outputs=out_tys, + [i isa TracedRArray ? i.mlir_data : i for i in inputs]; + outputs=output_types, fn=fname, - batch_shape=MLIR.IR.DenseArrayAttribute(Int64[B]), + batch_shape=MLIR.IR.DenseArrayAttribute(batch_shape), + location, ) return [ - TracedRArray{Reactant.unwrapped_eltype(r),ndims(r) + 1}( - (), MLIR.IR.result(op, i), (B, size(r)...) - ) for (i, r) in enumerate(result) + TracedRArray{MLIR.IR.julia_type(eltype(out_type)),ndims(out_type)}( + (), MLIR.IR.result(op, i), size(out_type) + ) for (i, out_type) in enumerate(output_types) ] end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 8802ed0833..a7e47cf010 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -339,7 +339,11 @@ function set!(x, path, tostore; emptypath=false) x = Reactant.Compiler.traced_getfield(x, p) end - set_mlir_data!(x, tostore) + if tostore isa TracedRArray + set_mlir_data!(x, tostore.mlir_data) + else + set_mlir_data!(x, tostore) + end return emptypath && set_paths!(x, ()) end @@ -392,16 +396,11 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} OutShape = isempty(seen_args) ? nothing : first(input_shapes) @assert !isnothing(OutShape) - in_tys2 = [Ops.mlir_type(invmap[arg]) for arg in linear_args] - out_tys2 = [ MLIR.IR.TensorType(OutShape, MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for arg in linear_results ] - fname = get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - batch_inputs = MLIR.IR.Value[] for a in linear_args @@ -416,30 +415,25 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} end end - res = MLIR.Dialects.enzyme.batch( - batch_inputs; - outputs=out_tys2, - fn=fname, - batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in OutShape]), - ) + res = Ops.batch(batch_inputs, out_tys2, collect(Int64, OutShape); fn=func2) residx = 1 for a in linear_results if TracedUtils.has_residx(a) path = TracedUtils.get_residx(a) - TracedUtils.set!(result, path[2:end], MLIR.IR.result(res, residx)) + TracedUtils.set!(result, path[2:end], res[residx]) residx += 1 else idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], MLIR.IR.result(res, residx)) + TracedUtils.set!(f, path[3:end], res[residx]) residx += 1 else if fnwrap idx -= 1 end - TracedUtils.set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) + TracedUtils.set!(args[idx], path[3:end], res[residx]) residx += 1 end end From e5113bbc9936d10ab16842ba513ed3ec78281fa6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Jan 2025 11:45:04 -0500 Subject: [PATCH 03/13] refactor: move elem_apply to ops and generalize batching in tracing --- src/Ops.jl | 110 +++++++++++++++++++++++++++++++++++++++++--- src/TracedRArray.jl | 6 +-- src/TracedUtils.jl | 103 +++-------------------------------------- src/Tracing.jl | 60 +++++++++--------------- 4 files changed, 135 insertions(+), 144 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index fc573b4257..341cbcac58 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1970,8 +1970,8 @@ end @noinline function batch( f, args::Vector{<:TracedRArray}, - batch_dims::Vector{Union{Int,Nothing}}, - result_dims::Union{Vector{Int},Nothing}=nothing, + batch_dims::Vector{<:Union{Int64,Nothing}}, + result_dims::Union{Vector{Int64},Nothing}=nothing, ) @assert length(batch_dims) == length(args) @@ -1984,6 +1984,11 @@ end if dim === nothing return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...]) end + if size(arg, dim) == 1 && size(arg, dim) != B + new_dims = collect(Int64, size(arg)) + new_dims[dim] = B + arg = broadcast_in_dim(arg, collect(1:ndims(arg)), new_dims) + end order = collect(1:ndims(arg)) order[dim] = 1 order[1] = dim @@ -2070,13 +2075,12 @@ end fn, location=mlir_stacktrace("batch", @__FILE__, @__LINE__), ) - fname = Reactant.TracedUtils.get_attribute_by_name(fn, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - op = MLIR.Dialects.enzyme.batch( [i isa TracedRArray ? i.mlir_data : i for i in inputs]; outputs=output_types, - fn=fname, + fn=MLIR.IR.FlatSymbolRefAttribute( + String(Reactant.TracedUtils.get_attribute_by_name(fn, "sym_name")) + ), batch_shape=MLIR.IR.DenseArrayAttribute(batch_shape), location, ) @@ -2088,4 +2092,98 @@ end ] end +""" + elem_apply(f, args...) + +This is equivalent to `f.(args...)` but generates optimized code using +Reactant.MLIR.Dialects.enzyme.batch. +""" +@noinline function elem_apply(f, args::Vararg) + if all(iszero ∘ ndims, args) + scalar_args = map(args) do arg + return Reactant.TracedUtils.promote_to( + TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg + ) + end + return f(scalar_args...) + end + + fnwrap, func2, _, result, seen_args, _, linear_args, _, linear_results = Reactant.TracedUtils.make_mlir_fn( + f, + args, + (), + string(f) * "_broadcast_scalar", + false; + batchmode=Reactant.BatchScalar, + no_args_in_result=true, + ) + + input_shapes = [size(k) for k in keys(seen_args) if k isa Reactant.TracedType] + @assert allequal(input_shapes) "input shapes are $(input_shapes)" + OutShape = first(input_shapes) + + batch_inputs = MLIR.IR.Value[] + for a in linear_args + idx, path = Reactant.TracedUtils.get_argidx(a) + if idx == 1 && fnwrap + Reactant.TracedUtils.push_val!(batch_inputs, f, path[3:end]) + else + fnwrap && (idx -= 1) + Reactant.TracedUtils.push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + res = batch( + batch_inputs, + [ + MLIR.IR.TensorType(OutShape, MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for + arg in linear_results + ], + collect(Int64, OutShape); + fn=func2, + ) + + residx = 1 + for a in linear_results + if Reactant.TracedUtils.has_residx(a) + path = Reactant.TracedUtils.get_residx(a) + Reactant.TracedUtils.set!(result, path[2:end], res[residx]) + residx += 1 + else + idx, path = Reactant.TracedUtils.get_argidx(a) + if idx == 1 && fnwrap + Reactant.TracedUtils.set!(f, path[3:end], res[residx]) + residx += 1 + else + fnwrap && (idx -= 1) + Reactant.TracedUtils.set!(args[idx], path[3:end], res[residx]) + residx += 1 + end + end + end + + seen_results = Reactant.OrderedIdDict() + traced2_result = Reactant.make_tracer( + seen_results, result, (), Reactant.TracedSetPath; tobatch=OutShape + ) + func2.operation = MLIR.API.MlirOperation(C_NULL) + + return traced2_result +end + +@noinline elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Reactant.ReactantPrimitive} = + x + +struct TypeCast{T<:Reactant.ReactantPrimitive} <: Function end + +@noinline (f::TypeCast{T})(x::TracedRNumber) where {T} = + Reactant.TracedUtils.promote_to(TracedRNumber{T}, x) + +@noinline function elem_apply( + ::Type{T}, x::TracedRArray +) where {T<:Reactant.ReactantPrimitive} + # Special Path to prevent going down a despecialized path + return elem_apply(TypeCast{T}(), x) +end + end # module Ops diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 165a9c71eb..af395bf90b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -549,7 +549,7 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) res = TracedUtils.promote_to( TracedRArray{unwrapped_eltype(dest),ndims(dest)}, - TracedUtils.elem_apply(bc.f, args...), + Ops.elem_apply(bc.f, args...), ) TracedUtils.set_mlir_data!(dest, res.mlir_data) return dest @@ -563,8 +563,8 @@ function _copyto!(dest::AbstractArray{<:TracedRNumber}, bc::Broadcasted) args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) - res = TracedUtils.elem_apply(bc.f, args...) - for I in 1:length(dest) + res = Ops.elem_apply(bc.f, args...) + for I in eachindex(dest) dest[I] = Reactant.@allowscalar res[I] end return dest diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index a7e47cf010..28cdcbd185 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -107,7 +107,7 @@ function make_mlir_fn( kwargs, name="main", concretein=true; - toscalar=false, + batchmode=Reactant.BatchNone, return_dialect=:func, do_transpose=true, no_args_in_result=false, @@ -121,7 +121,7 @@ function make_mlir_fn( kwargs, name, concretein; - toscalar, + batchmode, return_dialect, do_transpose, no_args_in_result, @@ -138,7 +138,7 @@ function make_mlir_fn( args[i], (:args, i), concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath; - toscalar, + batchmode ) end @@ -148,11 +148,12 @@ function make_mlir_fn( push!(linear_args, v) end - in_tys = if toscalar + in_tys = if batchmode == Reactant.BatchScalar [ - MLIR.IR.TensorType((), MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for - arg in linear_args + MLIR.IR.TensorType((), MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for arg in linear_args ] + elseif batchmode == Reactant.BatchArray + error("Not implemented") elseif do_transpose [transpose_ty(Ops.mlir_type(arg)) for arg in linear_args] else @@ -283,19 +284,6 @@ function __take_region(compiled_fn) return region end -elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x - -struct TypeCast{T<:ReactantPrimitive} <: Function end - -function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} - return TracedUtils.promote_to(TracedRNumber{T}, x) -end - -function elem_apply(::Type{T}, x::TracedRArray) where {T<:ReactantPrimitive} - # Special Path to prevent going down a despecialized path - return elem_apply(TypeCast{T}(), x) -end - function promote_to end function get_attribute_by_name(operation, name) @@ -372,83 +360,6 @@ function has_residx(x) return false end -function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} - if all(iszero ∘ ndims, args) - scalar_args = map(args) do arg - return promote_to(TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg) - end - return f(scalar_args...) - end - - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( - f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true - ) - - invmap = IdDict() - for (k, v) in seen_args - invmap[v] = k - end - - keys_seen = [k for k in keys(seen_args) if k isa Reactant.TracedType] - input_shapes = size.(keys_seen) - # by the time we reach here all args must have same size - @assert allequal(input_shapes) "input shapes are $(input_shapes)" - OutShape = isempty(seen_args) ? nothing : first(input_shapes) - @assert !isnothing(OutShape) - - out_tys2 = [ - MLIR.IR.TensorType(OutShape, MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for - arg in linear_results - ] - - batch_inputs = MLIR.IR.Value[] - - for a in linear_args - idx, path = TracedUtils.get_argidx(a) - if idx == 1 && fnwrap - push_val!(batch_inputs, f, path[3:end]) - else - if fnwrap - idx -= 1 - end - push_val!(batch_inputs, args[idx], path[3:end]) - end - end - - res = Ops.batch(batch_inputs, out_tys2, collect(Int64, OutShape); fn=func2) - - residx = 1 - - for a in linear_results - if TracedUtils.has_residx(a) - path = TracedUtils.get_residx(a) - TracedUtils.set!(result, path[2:end], res[residx]) - residx += 1 - else - idx, path = TracedUtils.get_argidx(a) - if idx == 1 && fnwrap - TracedUtils.set!(f, path[3:end], res[residx]) - residx += 1 - else - if fnwrap - idx -= 1 - end - TracedUtils.set!(args[idx], path[3:end], res[residx]) - residx += 1 - end - end - end - - seen_results = OrderedIdDict() - traced2_result = Reactant.make_tracer( - seen_results, result, (), Reactant.TracedSetPath; tobatch=OutShape - ) - - func2.operation = MLIR.API.MlirOperation(C_NULL) - - return traced2_result -end - function broadcast_to_size(arg::AbstractArray{<:TracedRNumber}, rsize) return broadcast_to_size(reshape(Ops.vcat(arg...), size(arg)...), rsize) end diff --git a/src/Tracing.jl b/src/Tracing.jl index d07c5178c7..14eea554be 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -7,9 +7,13 @@ NoStopTracedTrack = 6 end -Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(track_numbers::Type) -) +@enum BatchMode begin + BatchNone = 1 + BatchScalar = 2 + BatchArray = 3 +end + +Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) if T === Any return T end @@ -575,9 +579,7 @@ function make_tracer( @nospecialize(prev), @nospecialize(path), mode; - toscalar=false, - tobatch=nothing, - @nospecialize(track_numbers::Type = Union{}), + @nospecialize(track_numbers::Type=Union{}), kwargs..., ) if mode != NoStopTracedTrack && haskey(seen, prev) @@ -601,14 +603,7 @@ function make_tracer( if isdefined(prev, i) xi = Base.getfield(prev, i) xi2 = make_tracer( - seen, - xi, - append_path(path, i), - mode; - toscalar, - tobatch, - track_numbers, - kwargs..., + seen, xi, append_path(path, i), mode; track_numbers, kwargs... ) if xi !== xi2 changed = true @@ -633,14 +628,7 @@ function make_tracer( if isdefined(prev, i) xi = Base.getfield(prev, i) xi2 = make_tracer( - seen, - xi, - append_path(path, i), - mode; - toscalar, - tobatch, - track_numbers, - kwargs..., + seen, xi, append_path(path, i), mode; track_numbers, kwargs... ) if xi !== xi2 changed = true @@ -700,7 +688,7 @@ function make_tracer( @nospecialize(prev::TracedRArray{T,N}), @nospecialize(path), mode; - toscalar=false, + batchmode=BatchNone, tobatch=nothing, kwargs..., ) where {T,N} @@ -725,8 +713,10 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - res = if toscalar + res = if batchmode == BatchScalar TracedRNumber{T}((path,), nothing) + elseif batchmode == BatchArray + error("Not implemented") elseif tobatch !== nothing error("This should not happen...") else @@ -754,7 +744,7 @@ function make_tracer( @nospecialize(path), mode; tobatch=nothing, - toscalar=false, + batchmode=BatchNone, kwargs..., ) where {T} if mode == ConcreteToTraced @@ -778,8 +768,10 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - res = if toscalar + res = if batchmode == BatchScalar TracedRNumber{T}((path,), nothing) + elseif batchmode == BatchArray + error("Cannot BatchArray on a scalar") elseif tobatch !== nothing TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) else @@ -874,21 +866,11 @@ make_tracer(seen, @nospecialize(prev::Type), @nospecialize(path), mode; kwargs.. make_tracer(seen, prev::Symbol, @nospecialize(path), mode; kwargs...) = prev function make_tracer( - seen, - @nospecialize(prev::Complex), - @nospecialize(path), - mode; - toscalar=false, - tobatch=nothing, - kwargs..., + seen, @nospecialize(prev::Complex), @nospecialize(path), mode; kwargs... ) return Complex( - make_tracer( - seen, prev.re, append_path(path, :re), mode; toscalar, tobatch, kwargs... - ), - make_tracer( - seen, prev.im, append_path(path, :im), mode; toscalar, tobatch, kwargs... - ), + make_tracer(seen, prev.re, append_path(path, :re), mode; kwargs...), + make_tracer(seen, prev.im, append_path(path, :im), mode; kwargs...), ) end From c52022d4e6ef61503ff945e960e89b92ed164a89 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Jan 2025 12:54:05 -0500 Subject: [PATCH 04/13] feat: support arbitrary structures for batching --- src/Compiler.jl | 4 +- src/Ops.jl | 233 ++++++++++++++++++++------------------------- src/TracedUtils.jl | 9 +- src/Tracing.jl | 40 +++++--- 4 files changed, 138 insertions(+), 148 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 7b07942b4a..d1be9afe16 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -795,7 +795,7 @@ function codegen_unflatten!( paths = ( ( p for p in Reactant.TracedUtils.get_paths(result) if - length(p) > 0 && (p[1] == :result || p[1] == :resargs) + length(p) ≥ 1 && (p[1] == :result || p[1] == :resargs) )..., ) for path in paths @@ -865,7 +865,7 @@ function codegen_unflatten!( paths = ( ( p for p in Reactant.TracedUtils.get_paths(result) if - length(p) > 0 && (p[1] == :result || p[1] == :resargs || p[1] == :args) + length(p) ≥ 1 && (p[1] == :result || p[1] == :resargs || p[1] == :args) )..., ) diff --git a/src/Ops.jl b/src/Ops.jl index 341cbcac58..8f246de16a 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1967,107 +1967,25 @@ end return corrected_traced_results end -@noinline function batch( - f, - args::Vector{<:TracedRArray}, - batch_dims::Vector{<:Union{Int64,Nothing}}, - result_dims::Union{Vector{Int64},Nothing}=nothing, -) - @assert length(batch_dims) == length(args) - - batch_sizes = [dim === nothing ? 1 : size(x, dim) for (x, dim) in zip(args, batch_dims)] - filter!(x -> x != 1, batch_sizes) - @assert allequal(batch_sizes) "batching dimensions must be equal" - B = length(batch_sizes) == 0 ? 1 : first(batch_sizes) - - args = map(zip(args, batch_dims)) do (arg, dim) - if dim === nothing - return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...]) - end - if size(arg, dim) == 1 && size(arg, dim) != B - new_dims = collect(Int64, size(arg)) - new_dims[dim] = B - arg = broadcast_in_dim(arg, collect(1:ndims(arg)), new_dims) - end - order = collect(1:ndims(arg)) - order[dim] = 1 - order[1] = dim - return permutedims(arg, order) - end - results = batch(f, args) - result_dims === nothing && (result_dims = ones(Int64, length(results))) - @assert length(results) == length(result_dims) - return map(zip(results, result_dims)) do (result, dim) - order = collect(1:ndims(result)) - order[dim] = 1 - order[1] = dim - return permutedims(result, order) - end -end - -@noinline function batch(f, args::Vector{<:TracedRArray}) - batch_sizes = [size(x, 1) for x in args] - @assert allequal(batch_sizes) "batching dimensions must be equal" - B = first(batch_sizes) - - in_tys = [ - MLIR.IR.TensorType(size(arg)[2:end], MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) - for arg in args - ] - - sym_visibility = MLIR.IR.Attribute("private") - - mod = MLIR.IR.mmodule() - func = MLIR.IR.block!(MLIR.IR.body(mod)) do - return MLIR.Dialects.func.func_(; - sym_name=string(f) * "_batch_tmp", - function_type=MLIR.IR.FunctionType(in_tys, []), - body=MLIR.IR.Region(), - sym_visibility, - ) - end - fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in args]) - push!(MLIR.IR.region(func, 1), fnbody) - - linear_args = [ - TracedRArray{Reactant.unwrapped_eltype(arg),ndims(arg) - 1}( - (), nothing, size(arg)[2:end] - ) for arg in args - ] - - MLIR.IR.activate!(fnbody) - result = try - for (i, arg) in enumerate(linear_args) - raw_arg = MLIR.IR.argument(fnbody, i) - Reactant.TracedUtils.set_mlir_data!(arg, raw_arg) - end - res = Reactant.call_with_reactant(f, linear_args...) - (res isa TracedRArray || res isa TracedRNumber) && (res = [res]) - MLIR.Dialects.func.return_([r.mlir_data for r in res]) - res - finally - MLIR.IR.deactivate!(fnbody) - end +""" + batch( + inputs::Vector{<:Union{<:TracedRArray,<:MLIR.IR.Value}}, + output_types::Vector{<:MLIR.IR.Type}, + batch_shape::Vector{Int64}; + fn, + location=mlir_stacktrace("batch", @__FILE__, @__LINE__), + ) - comp_func = MLIR.IR.block!(MLIR.IR.body(mod)) do - return MLIR.Dialects.func.func_(; - sym_name=string(f) * "_batch", - function_type=MLIR.IR.FunctionType(in_tys, [mlir_type(r) for r in result]), - body=MLIR.IR.Region(), - sym_visibility, - ) - end - MLIR.API.mlirRegionTakeBody(MLIR.IR.region(comp_func, 1), MLIR.IR.region(func, 1)) - MLIR.API.mlirOperationDestroy(func.operation) - func.operation = MLIR.API.MlirOperation(C_NULL) +Generates a Reactant.MLIR.Dialects.enzyme.batch operation. It is recommended to use +`Ops.batch(f, args, batch_dims, result_dims)` or `Ops.elem_apply(f, args...)` instead +of calling this directly. - out_tys = [ - MLIR.IR.TensorType((B, size(r)...), MLIR.IR.Type(Reactant.unwrapped_eltype(r))) for - r in result - ] - return batch(args, out_tys, Int64[B]; fn=comp_func) -end +!!! warning + This function batches the inputs based on the starting dimensions of the inputs. This + aligns with the default ordering in Python frameworks like JAX and PyTorch, but is + opposite to the default ordering in Julia. +""" @noinline function batch( inputs::Vector{<:Union{<:TracedRArray,<:MLIR.IR.Value}}, output_types::Vector{<:MLIR.IR.Type}, @@ -2092,6 +2010,43 @@ end ] end +# This function assumes that the last dimension of each element is the batch dimension by +# default. This is the standard Julia ordering for batching. We permutedims the ordering to +# make sure the first dimension is the batch dimension when calling `batch_internal` below. +@noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing) + # @assert length(batch_dims) == length(args) + + # batch_sizes = [dim === nothing ? 1 : size(x, dim) for (x, dim) in zip(args, batch_dims)] + # filter!(x -> x != 1, batch_sizes) + # @assert allequal(batch_sizes) "batching dimensions must be equal" + # B = length(batch_sizes) == 0 ? 1 : first(batch_sizes) + + # args = map(zip(args, batch_dims)) do (arg, dim) + # if dim === nothing + # return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...]) + # end + # if size(arg, dim) == 1 && size(arg, dim) != B + # new_dims = collect(Int64, size(arg)) + # new_dims[dim] = B + # arg = broadcast_in_dim(arg, collect(1:ndims(arg)), new_dims) + # end + # order = collect(1:ndims(arg)) + # order[dim] = 1 + # order[1] = dim + # return permutedims(arg, order) + # end + # results = batch(f, args) + # result_dims === nothing && (result_dims = ones(Int64, length(results))) + # @assert length(results) == length(result_dims) + # return map(zip(results, result_dims)) do (result, dim) + # order = collect(1:ndims(result)) + # order[dim] = 1 + # order[1] = dim + # return permutedims(result, order) + # end + return error(1) +end + """ elem_apply(f, args...) @@ -2099,28 +2054,56 @@ This is equivalent to `f.(args...)` but generates optimized code using Reactant.MLIR.Dialects.enzyme.batch. """ @noinline function elem_apply(f, args::Vararg) - if all(iszero ∘ ndims, args) - scalar_args = map(args) do arg - return Reactant.TracedUtils.promote_to( - TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg - ) + return batch_internal(f, args...; batchmode=Reactant.BatchScalar) +end + +@noinline function elem_apply( + ::Type{T}, x::TracedRArray{T} +) where {T<:Reactant.ReactantPrimitive} + return x +end + +@noinline function elem_apply( + ::Type{T}, x::TracedRArray +) where {T<:Reactant.ReactantPrimitive} + # Special Path to prevent going down a despecialized path + return elem_apply(Reactant.TracedUtils.TypeCast{T}(), x) +end + +@noinline function batch_internal(f, args::Vararg; batchmode=Reactant.BatchArray) + @assert batchmode != Reactant.BatchNone + + if batchmode == Reactant.BatchScalar + if all(iszero ∘ ndims, args) + scalar_args = map(args) do arg + return Reactant.TracedUtils.promote_to( + TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg + ) + end + return f(scalar_args...) end - return f(scalar_args...) end fnwrap, func2, _, result, seen_args, _, linear_args, _, linear_results = Reactant.TracedUtils.make_mlir_fn( f, args, (), - string(f) * "_broadcast_scalar", + string(f) * (batchmode == Reactant.BatchArray ? "_batch" : "_broadcast_scalar"), false; - batchmode=Reactant.BatchScalar, - no_args_in_result=true, + batchmode, + no_args_in_result=batchmode == Reactant.BatchScalar, + do_transpose=false, ) - input_shapes = [size(k) for k in keys(seen_args) if k isa Reactant.TracedType] - @assert allequal(input_shapes) "input shapes are $(input_shapes)" - OutShape = first(input_shapes) + if batchmode == Reactant.BatchArray + batch_sizes = [size(k, 1) for k in keys(seen_args) if k isa Reactant.TracedType] + @assert allequal(batch_sizes) "batching dimensions must be equal" + B = first(batch_sizes) + else + input_shapes = [size(k) for k in keys(seen_args) if k isa Reactant.TracedType] + @assert allequal(input_shapes) "input shapes are $(input_shapes)" + output_shape = first(input_shapes) + end batch_inputs = MLIR.IR.Value[] for a in linear_args @@ -2136,10 +2119,12 @@ Reactant.MLIR.Dialects.enzyme.batch. res = batch( batch_inputs, [ - MLIR.IR.TensorType(OutShape, MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for - arg in linear_results + MLIR.IR.TensorType( + batchmode == Reactant.BatchArray ? (B, size(arg)...) : output_shape, + MLIR.IR.Type(Reactant.unwrapped_eltype(arg)), + ) for arg in linear_results ], - collect(Int64, OutShape); + batchmode == Reactant.BatchArray ? Int64[B] : collect(Int64, output_shape); fn=func2, ) @@ -2164,26 +2149,16 @@ Reactant.MLIR.Dialects.enzyme.batch. seen_results = Reactant.OrderedIdDict() traced2_result = Reactant.make_tracer( - seen_results, result, (), Reactant.TracedSetPath; tobatch=OutShape + seen_results, + result, + (), + Reactant.TracedSetPath; + tobatch=batchmode == Reactant.BatchArray ? output_shape : (B,), + batchmode, ) func2.operation = MLIR.API.MlirOperation(C_NULL) return traced2_result end -@noinline elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Reactant.ReactantPrimitive} = - x - -struct TypeCast{T<:Reactant.ReactantPrimitive} <: Function end - -@noinline (f::TypeCast{T})(x::TracedRNumber) where {T} = - Reactant.TracedUtils.promote_to(TracedRNumber{T}, x) - -@noinline function elem_apply( - ::Type{T}, x::TracedRArray -) where {T<:Reactant.ReactantPrimitive} - # Special Path to prevent going down a despecialized path - return elem_apply(TypeCast{T}(), x) -end - end # module Ops diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 28cdcbd185..7f44f09b80 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -150,10 +150,9 @@ function make_mlir_fn( in_tys = if batchmode == Reactant.BatchScalar [ - MLIR.IR.TensorType((), MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for arg in linear_args + MLIR.IR.TensorType((), MLIR.IR.Type(Reactant.unwrapped_eltype(arg))) for + arg in linear_args ] - elseif batchmode == Reactant.BatchArray - error("Not implemented") elseif do_transpose [transpose_ty(Ops.mlir_type(arg)) for arg in linear_args] else @@ -399,4 +398,8 @@ end return Ops.broadcast_in_dim(x, collect(Int64, 1:ndims(x)), collect(Int64, rsize)) end +struct TypeCast{T<:Reactant.ReactantPrimitive} <: Function end + +@noinline (f::TypeCast{T})(x::TracedRNumber) where {T} = promote_to(TracedRNumber{T}, x) + end diff --git a/src/Tracing.jl b/src/Tracing.jl index 14eea554be..fb1d01fbef 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -713,14 +713,23 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - res = if batchmode == BatchScalar - TracedRNumber{T}((path,), nothing) - elseif batchmode == BatchArray - error("Not implemented") - elseif tobatch !== nothing - error("This should not happen...") - else + res = if batchmode == BatchNone + @assert tobatch === nothing TracedRArray{T,N}((path,), prev.mlir_data, size(prev)) + elseif batchmode == BatchScalar + if tobatch === nothing + TracedRNumber{T}((path,), nothing) + else + error("BatchScalar + tobatch for TracedRArray doesn't make sense") + end + else + if tobatch === nothing + TracedRArray{T,N - 1}((path,), nothing, size(prev)[2:end]) + else + TracedRArray{T,N + length(tobatch)}( + (path,), prev.mlir_data, (tobatch..., size(prev)...) + ) + end end seen[prev] = res return res @@ -768,14 +777,17 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - res = if batchmode == BatchScalar - TracedRNumber{T}((path,), nothing) - elseif batchmode == BatchArray - error("Cannot BatchArray on a scalar") - elseif tobatch !== nothing - TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) - else + res = if batchmode == BatchNone + @assert tobatch === nothing TracedRNumber{T}((path,), prev.mlir_data) + elseif batchmode == BatchScalar + if tobatch === nothing + TracedRNumber{T}((path,), nothing) + else + TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) + end + else + error("Cannot BatchArray on a scalar") end seen[prev] = res return res From fc03c4e7704e60f4c92a2007666a41ce5852933d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Jan 2025 13:52:24 -0500 Subject: [PATCH 05/13] feat: attempt a tracer impl for rewriting the args --- src/Ops.jl | 106 +++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 83 insertions(+), 23 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 8f246de16a..c5eb837103 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2013,29 +2013,87 @@ end # This function assumes that the last dimension of each element is the batch dimension by # default. This is the standard Julia ordering for batching. We permutedims the ordering to # make sure the first dimension is the batch dimension when calling `batch_internal` below. +# XXX: Mutation inside a batched function is not supported yet (need to set the results +# correctly) @noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing) - # @assert length(batch_dims) == length(args) - - # batch_sizes = [dim === nothing ? 1 : size(x, dim) for (x, dim) in zip(args, batch_dims)] - # filter!(x -> x != 1, batch_sizes) - # @assert allequal(batch_sizes) "batching dimensions must be equal" - # B = length(batch_sizes) == 0 ? 1 : first(batch_sizes) - - # args = map(zip(args, batch_dims)) do (arg, dim) - # if dim === nothing - # return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...]) - # end - # if size(arg, dim) == 1 && size(arg, dim) != B - # new_dims = collect(Int64, size(arg)) - # new_dims[dim] = B - # arg = broadcast_in_dim(arg, collect(1:ndims(arg)), new_dims) - # end - # order = collect(1:ndims(arg)) - # order[dim] = 1 - # order[1] = dim - # return permutedims(arg, order) + N = length(args) + seen_args = Reactant.OrderedIdDict() + traced_args = Vector{Any}(undef, N) + for i in 1:N + @inbounds traced_args[i] = Reactant.make_tracer( + seen_args, args[i], (:batchargs, i), Reactant.NoStopTracedTrack + ) + end + linear_args = [v for v in values(seen_args) if v isa Reactant.TracedType] + + batching_dims = Union{Int64,Nothing}[] + batch_sizes = Int64[] + arg_paths = Tuple[] + for (i, arg) in enumerate(linear_args) + if batch_dims === nothing # assume last dimension is batch dimension + push!(batching_dims, ndims(arg)) + else + paths = Reactant.TracedUtils.get_paths(arg) + path = paths[findfirst(p -> p[1] == :batchargs, paths)][3:end] + push!(arg_paths, path) + bdim = batch_dims[i] + for p in path + bdim = Reactant.Compiler.traced_getfield(bdim, p) + end + if bdim === nothing # Input is not batched + push!(batching_dims, nothing) + else + @assert bdim isa Integer "batching dimension must be an integer or nothing" + push!(batching_dims, bdim) + end + end + batching_dims[i] !== nothing && push!(batch_sizes, size(arg, batching_dims[i])) + end + + batch_sizes_no_ones = filter(x -> x != 1, batch_sizes) + @assert allequal(batch_sizes) "batching dimensions must be equal" + B = length(batch_sizes_no_ones) == 0 ? 1 : first(batch_sizes_no_ones) + + corrected_linear_args = map(zip(linear_args, batching_dims)) do (arg, dim) + if dim === nothing # repeat the input along dim=0 + return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...]) + end + if size(arg, dim) == 1 && size(arg, dim) != B # If batch_dim is 1, then expand that dim + new_dims = collect(Int64, size(arg)) + new_dims[dim] = B + arg = broadcast_in_dim(arg, collect(1:ndims(arg)), new_dims) + end + order = collect(Int64, 1:ndims(arg)) + order[dim] = 1 + order[1] = dim + return permutedims(arg, order) # Ensure batch dim is moved to the first position + end + + # argidx = 1 + # for (i, arg) in enumerate(corrected_linear_args) + # Reactant.TracedUtils.set!(args[argidx], arg_paths[i], arg) + # argidx += 1 + # end + + # @show linear_args + # @show corrected_linear_args + + # traced_args = Vector{Any}(undef, N) + # seen_args = Reactant.OrderedIdDict() + # for i in 1:N + # @inbounds traced_args[i] = Reactant.make_tracer( + # seen_args, args[i], (), Reactant.TracedSetPath + # ) # end - # results = batch(f, args) + + # @show corrected_linear_args + # @show args + # @show traced_args + + # @show traced_args[1].x + # @show traced_args[2].x[1] + + results = batch_internal(f, traced_args...) # result_dims === nothing && (result_dims = ones(Int64, length(results))) # @assert length(results) == length(result_dims) # return map(zip(results, result_dims)) do (result, dim) @@ -2044,6 +2102,9 @@ end # order[1] = dim # return permutedims(result, order) # end + + # TODO: Restore the args here? + return error(1) end @@ -2147,9 +2208,8 @@ end end end - seen_results = Reactant.OrderedIdDict() traced2_result = Reactant.make_tracer( - seen_results, + Reactant.OrderedIdDict(), result, (), Reactant.TracedSetPath; From c509de5ef2178c434e9af2fa7aca4a24e76d3d90 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Jan 2025 14:29:56 -0500 Subject: [PATCH 06/13] feat: support arbitrary structures for batching --- ext/ReactantCUDAExt.jl | 8 +- ext/ReactantOffsetArraysExt.jl | 4 +- src/Ops.jl | 95 ++++---------- src/Reactant.jl | 2 + src/Tracing.jl | 225 +++++++++++++-------------------- 5 files changed, 125 insertions(+), 209 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index e5f6e4fc3b..81e3dc4a8c 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -758,7 +758,7 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( @nospecialize(A::Type{<:CuTracedArray}), seen, mode::Reactant.TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(args::Vararg) ) return A end @@ -767,18 +767,18 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( @nospecialize(A::Type{<:CUDA.CuArray}), seen, mode::Reactant.TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(args::Vararg) ) T = eltype(A) N = ndims(A) if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive return Reactant.ConcreteRArray{T,N} else - TT = Reactant.traced_type_inner(T, seen, mode, track_numbers) + TT = Reactant.traced_type_inner(T, seen, mode, args...) if TT === T return A else - return Array{Reactant.traced_type_inner(T, seen, mode, track_numbers),N} + return Array{Reactant.traced_type_inner(T, seen, mode, args...),N} end end end diff --git a/ext/ReactantOffsetArraysExt.jl b/ext/ReactantOffsetArraysExt.jl index fc77ef0e1e..798366b682 100644 --- a/ext/ReactantOffsetArraysExt.jl +++ b/ext/ReactantOffsetArraysExt.jl @@ -8,11 +8,11 @@ Base.@nospecializeinfer function Reactant.traced_type_inner( @nospecialize(OA::Type{<:OffsetArray}), seen, mode::Reactant.TraceMode, - @nospecialize(track_numbers::Type = Union{}) + @nospecialize(args::Vararg) ) N = ndims(OA) T = OffsetArrays.parenttype(OA) - T2 = Reactant.traced_type_inner(T, seen, mode, track_numbers) + T2 = Reactant.traced_type_inner(T, seen, mode, args...) return OffsetArray{eltype(T2),N,T2} end diff --git a/src/Ops.jl b/src/Ops.jl index c5eb837103..50d3b2cbfc 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -12,7 +12,7 @@ using ..Reactant: RNumber, MissingTracedValue, unwrapped_eltype -using Functors: fmap +using Functors: Functors, fmap function mlir_type(x::Union{RNumber,RArray}) return MLIR.IR.TensorType(size(x), MLIR.IR.Type(unwrapped_eltype(x))) @@ -2016,45 +2016,26 @@ end # XXX: Mutation inside a batched function is not supported yet (need to set the results # correctly) @noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing) - N = length(args) - seen_args = Reactant.OrderedIdDict() - traced_args = Vector{Any}(undef, N) - for i in 1:N - @inbounds traced_args[i] = Reactant.make_tracer( - seen_args, args[i], (:batchargs, i), Reactant.NoStopTracedTrack - ) - end - linear_args = [v for v in values(seen_args) if v isa Reactant.TracedType] - - batching_dims = Union{Int64,Nothing}[] batch_sizes = Int64[] - arg_paths = Tuple[] - for (i, arg) in enumerate(linear_args) - if batch_dims === nothing # assume last dimension is batch dimension - push!(batching_dims, ndims(arg)) - else - paths = Reactant.TracedUtils.get_paths(arg) - path = paths[findfirst(p -> p[1] == :batchargs, paths)][3:end] - push!(arg_paths, path) - bdim = batch_dims[i] - for p in path - bdim = Reactant.Compiler.traced_getfield(bdim, p) - end - if bdim === nothing # Input is not batched - push!(batching_dims, nothing) - else - @assert bdim isa Integer "batching dimension must be an integer or nothing" - push!(batching_dims, bdim) - end + batching_dims = if batch_dims === nothing + fmap(args) do x + tmp = ndims(x) + push!(batch_sizes, size(x, tmp)) + return tmp + end + else + fmap(args, batch_dims) do x, dim + dim !== nothing && push!(batch_sizes, size(x, dim)) + @assert dim isa Integer || dim === nothing + dim end - batching_dims[i] !== nothing && push!(batch_sizes, size(arg, batching_dims[i])) end batch_sizes_no_ones = filter(x -> x != 1, batch_sizes) @assert allequal(batch_sizes) "batching dimensions must be equal" B = length(batch_sizes_no_ones) == 0 ? 1 : first(batch_sizes_no_ones) - corrected_linear_args = map(zip(linear_args, batching_dims)) do (arg, dim) + corrected_args = fmap(args, batching_dims) do arg, dim if dim === nothing # repeat the input along dim=0 return broadcast_in_dim(arg, collect(1:ndims(arg)) .+ 1, Int64[B, size(arg)...]) end @@ -2069,43 +2050,21 @@ end return permutedims(arg, order) # Ensure batch dim is moved to the first position end - # argidx = 1 - # for (i, arg) in enumerate(corrected_linear_args) - # Reactant.TracedUtils.set!(args[argidx], arg_paths[i], arg) - # argidx += 1 - # end - - # @show linear_args - # @show corrected_linear_args - - # traced_args = Vector{Any}(undef, N) - # seen_args = Reactant.OrderedIdDict() - # for i in 1:N - # @inbounds traced_args[i] = Reactant.make_tracer( - # seen_args, args[i], (), Reactant.TracedSetPath - # ) - # end - - # @show corrected_linear_args - # @show args - # @show traced_args - - # @show traced_args[1].x - # @show traced_args[2].x[1] + results = batch_internal(f, corrected_args...) - results = batch_internal(f, traced_args...) - # result_dims === nothing && (result_dims = ones(Int64, length(results))) - # @assert length(results) == length(result_dims) - # return map(zip(results, result_dims)) do (result, dim) - # order = collect(1:ndims(result)) - # order[dim] = 1 - # order[1] = dim - # return permutedims(result, order) - # end - - # TODO: Restore the args here? + if result_dims === nothing + return fmap(results) do result + order = Int64[2:ndims(result)..., 1] + return permutedims(result, order) + end + end - return error(1) + return fmap(results, result_dims) do result, dim + order = collect(Int64, 1:ndims(result)) + order[dim] = 1 + order[1] = dim + return permutedims(result, order) + end end """ @@ -2213,7 +2172,7 @@ end result, (), Reactant.TracedSetPath; - tobatch=batchmode == Reactant.BatchArray ? output_shape : (B,), + tobatch=batchmode == Reactant.BatchArray ? (B,) : output_shape, batchmode, ) func2.operation = MLIR.API.MlirOperation(C_NULL) diff --git a/src/Reactant.jl b/src/Reactant.jl index 41f6ab9298..0dcae598cb 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -9,6 +9,8 @@ using Functors: @leaf using Adapt: Adapt, WrappedArray using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)` +using Functors: @leaf + export @allowscalar # re-exported from GPUArraysCore # auxiliary types and functions diff --git a/src/Tracing.jl b/src/Tracing.jl index fb1d01fbef..e1f718bc18 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -13,7 +13,7 @@ end BatchArray = 3 end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(track_numbers::Type)) +Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(args::Vararg)) if T === Any return T end @@ -43,8 +43,8 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, if T isa Union return Union{ - traced_type_inner(T.a, seen, mode, track_numbers), - traced_type_inner(T.b, seen, mode, track_numbers), + traced_type_inner(T.a, seen, mode, args...), + traced_type_inner(T.b, seen, mode, args...), } end @@ -71,7 +71,7 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, subTys = Type[] for f in 1:fieldcount(T) subT = fieldtype(T, f) - subTT = traced_type_inner(subT, seen2, mode, track_numbers) + subTT = traced_type_inner(subT, seen2, mode, args...) changed |= subT != subTT push!(subTys, subTT) end @@ -89,17 +89,17 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, subParms = [] for (i, SST) in enumerate(T.parameters) if wrapped_carray && i == 1 && SST isa Type && SST <: ReactantPrimitive - TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, track_numbers) + TrT = traced_type_inner(ConcreteRNumber{SST}, seen, mode, args...) push!(subParms, TrT) elseif wrapped_tracedarray && i == 1 && SST isa Type && SST <: TracedRNumber{<:ReactantPrimitive} - TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, track_numbers) + TrT = traced_type_inner(unwrapped_eltype(SST), seen, mode, args...) push!(subParms, TrT) else if SST isa Type - TrT = traced_type_inner(SST, seen, mode, track_numbers) + TrT = traced_type_inner(SST, seen, mode, args...) push!(subParms, TrT) else push!(subParms, SST) @@ -119,7 +119,7 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, for f in 1:fieldcount(T) subT = fieldtype(T, f) subT2 = fieldtype(TT2, f) - subTT = traced_type_inner(subT, seen3, mode, track_numbers) + subTT = traced_type_inner(subT, seen3, mode, args...) if subT2 != subTT legal = false break @@ -137,41 +137,18 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, throw(NoFieldMatchError(T, TT2)) end -Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{Union{}}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) -) +Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{Union{}}), seen, mode::TraceMode, @nospecialize(args::Vararg)) return T end -for T in ( - DataType, - Module, - Nothing, - Symbol, - AbstractChar, - AbstractString, - AbstractFloat, - Integer, - RNumber, -) - @eval Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:$T}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) - ) +for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, AbstractFloat, Integer, RNumber) + @eval Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:$T}), seen, mode::TraceMode, @nospecialize(args::Vararg)) return T end end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ReactantPrimitive}), - seen, - @nospecialize(mode::TraceMode), - @nospecialize(track_numbers::Type) + @nospecialize(T::Type{<:ReactantPrimitive}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(args::Vararg) ) if Mode == ArrayToConcrete && T <: track_numbers return ConcreteRNumber{T} @@ -182,24 +159,16 @@ Base.@nospecializeinfer function traced_type_inner( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(C::Type{<:Complex}), - seen, - @nospecialize(mode::TraceMode), - @nospecialize(track_numbers::Type) + @nospecialize(C::Type{<:Complex}), seen, @nospecialize(mode::TraceMode), @nospecialize(args::Vararg) ) if !(C isa UnionAll) - return Complex{traced_type_inner(C.parameters[1], seen, mode, track_numbers)} + return Complex{traced_type_inner(C.parameters[1], seen, mode, args...)} else return C end end -Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:Function}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) -) +Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Function}), seen, mode::TraceMode, @nospecialize(args::Vararg)) # functions are directly returned if sizeof(T) == 0 return T @@ -210,7 +179,7 @@ Base.@nospecializeinfer function traced_type_inner( changed = false traced_fieldtypes = Type[] for i in 1:N - next = traced_type_inner(fieldtype(T, i), seen, mode, track_numbers) + next = traced_type_inner(fieldtype(T, i), seen, mode, args...) changed |= next != fieldtype(T, i) push!(traced_fieldtypes, next) end @@ -226,12 +195,7 @@ end @inline is_concrete_tuple(x::T2) where {T2} = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) -Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:Tuple}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) -) +Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Tuple}), seen, mode::TraceMode, @nospecialize(args::Vararg)) if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll throw(AssertionError("Type $T is not concrete type or concrete tuple")) elseif is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) @@ -239,21 +203,16 @@ Base.@nospecializeinfer function traced_type_inner( throw(AssertionError("Type tuple of vararg $T is not supported")) end TT = [ - traced_type_inner(T.parameters[i], seen, mode, track_numbers) for + traced_type_inner(T.parameters[i], seen, mode, args...) for i in 1:length(T.parameters) ] return Tuple{TT...} end -Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:NamedTuple}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) -) +Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:NamedTuple}), seen, mode::TraceMode, @nospecialize(args::Vararg)) N = T.parameters[1] V = T.parameters[2] - return NamedTuple{N,traced_type_inner(V, seen, mode, track_numbers)} + return NamedTuple{N,traced_type_inner(V, seen, mode, args...)} end Base.@nospecializeinfer @inline dict_key(::Type{<:AbstractDict}) = nothing @@ -263,18 +222,13 @@ Base.@nospecializeinfer @inline dict_value( ::Type{<:(AbstractDict{K,V} where {K})} ) where {V} = V -Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:AbstractDict}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) -) +Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:AbstractDict}), seen, mode::TraceMode, @nospecialize(args::Vararg)) V = dict_value(T) if V === nothing return T else K = dict_key(T) - V2 = traced_type_inner(V, seen, mode, track_numbers) + V2 = traced_type_inner(V, seen, mode, args...) if V == V2 return T end @@ -292,10 +246,7 @@ Base.@nospecializeinfer function traced_type_inner( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T0::Type{<:ConcreteRNumber}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(T0::Type{<:ConcreteRNumber}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) T = T0.parameters[1] if mode == ConcreteToTraced @@ -318,10 +269,7 @@ Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) = (TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...} Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ConcreteRArray}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(T::Type{<:ConcreteRArray}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) if mode == ConcreteToTraced return base_typet(T) @@ -332,12 +280,7 @@ Base.@nospecializeinfer function traced_type_inner( end end -Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ConcreteRNG}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) -) +Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:ConcreteRNG}), seen, mode::TraceMode, @nospecialize(args::Vararg)) if mode == ConcreteToTraced return TracedRNG elseif mode == TracedToConcrete @@ -348,29 +291,43 @@ Base.@nospecializeinfer function traced_type_inner( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:TracedType}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + ::Type{<:MissingTracedValue}, seen, mode::TraceMode, @nospecialize(track_numbers), @nospecialize(batchmode), @nospecialize(tobatch) ) - T <: MissingTracedValue && error("TODO") + error("This should not happen...") +end + +@inline base_typec(TV::TT) where {TT<:UnionAll} = UnionAll(TV.var, base_typec(TV.body)) +@inline base_typec(TV::TT) where {TT<:DataType} = ConcreteRArray{TV.parameters...} + +Base.@nospecializeinfer function traced_type_inner( + TR::Type{<:TracedRArray}, seen, mode::TraceMode, @nospecialize(track_numbers), @nospecialize(batchmode), @nospecialize(tobatch) +) + T = TR.parameters[1] + N = TR.parameters[2] if mode == ConcreteToTraced - throw("TracedRArray $T cannot be traced") + throw("TracedRArray $(TracedRArray{T,N}) cannot be traced") elseif mode == TracedToConcrete - return base_typec(T) - elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath - return T + return base_typec(TracedRArray{T,N}) + elseif mode == TracedTrack || mode == NoStopTracedTrack + return TracedRArray{T,N} + elseif mode == TracedSetPath + if batchmode == BatchNone + return T + elseif batchmode == BatchArray + if tobatch === nothing + TracedRArray{T,N - 1} + else + TracedRArray{T,N + length(tobatch)} + end + else + error("Not implemented") + end else - throw("Abstract RArray $T cannot be made concrete in mode $mode") + throw("Abstract RArray $(TracedRArray{T,N}) cannot be made concrete in mode $mode") end end -Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:TracedRNG}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) -) +Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:TracedRNG}), seen, mode::TraceMode, @nospecialize(args::Vararg)) if mode == ConcreteToTraced throw("TracedRNG cannot be traced") elseif mode == TracedToConcrete @@ -382,53 +339,35 @@ Base.@nospecializeinfer function traced_type_inner( end end -Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:XLAArray}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) -) +Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:XLAArray}), seen, mode::TraceMode, @nospecialize(args::Vararg)) throw("XLA $T array cannot be traced") end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(A::Type{<:Array}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) + @nospecialize(A::Type{<:Array}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) T = eltype(A) N = ndims(A) if mode == ArrayToConcrete && T <: ReactantPrimitive return ConcreteRArray{T,N} else - return Array{traced_type_inner(T, seen, mode, track_numbers),N} + return Array{traced_type_inner(T, seen, mode, args...),N} end end for P in (Ptr, Core.LLVMPtr, Base.RefValue) - @eval Base.@nospecializeinfer function traced_type_inner( - @nospecialize(PT::Type{<:$P}), - seen, - mode::TraceMode, - @nospecialize(track_numbers::Type) - ) + @eval Base.@nospecializeinfer function traced_type_inner(@nospecialize(PT::Type{<:$P}), seen, mode::TraceMode, @nospecialize(args::Vararg)) T = eltype(PT) - return $P{traced_type_inner(T, seen, mode, track_numbers)} + return $P{traced_type_inner(T, seen, mode, args...)} end end -Base.@nospecializeinfer function traced_type_inner( - @nospecialize(VT::Type{<:Val}), - seen, - @nospecialize(mode::TraceMode), - @nospecialize(track_numbers::Type) -) +Base.@nospecializeinfer function traced_type_inner(@nospecialize(VT::Type{<:Val}), seen, @nospecialize(mode::TraceMode), @nospecialize(args::Vararg)) if VT isa UnionAll return VT end T = VT.parameters[1] - if traced_type_inner(typeof(T), seen, mode, track_numbers) == typeof(T) + if traced_type_inner(typeof(T), seen, mode, args...) == typeof(T) return Val{T} end throw("Val type $(Val{T}) cannot be traced") @@ -580,13 +519,15 @@ function make_tracer( @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), + @nospecialize(batchmode=BatchNone), + @nospecialize(tobatch=nothing), kwargs..., ) if mode != NoStopTracedTrack && haskey(seen, prev) return seen[prev] end RT = Core.Typeof(prev) - TT = traced_type(RT, Val(mode), track_numbers) + TT = traced_type(RT, Val(mode), track_numbers, batchmode, tobatch) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) @@ -603,7 +544,13 @@ function make_tracer( if isdefined(prev, i) xi = Base.getfield(prev, i) xi2 = make_tracer( - seen, xi, append_path(path, i), mode; track_numbers, kwargs... + seen, + xi, + append_path(path, i), + mode; + track_numbers, + batchmode, + kwargs..., ) if xi !== xi2 changed = true @@ -887,12 +834,7 @@ function make_tracer( end function make_tracer( - seen, - @nospecialize(prev::Array), - @nospecialize(path), - mode; - @nospecialize(track_numbers::Type = Union{}), - kwargs..., + seen, @nospecialize(prev::Array), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), @nospecialize(batchmode=BatchNone), @nospecialize(tobatch=nothing), kwargs... ) RT = Core.Typeof(prev) if mode != NoStopTracedTrack && haskey(seen, prev) @@ -901,14 +843,23 @@ function make_tracer( if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive return seen[prev] = ConcreteRArray(prev) end - TT = traced_type(eltype(RT), Val(mode), track_numbers) + TT = traced_type(eltype(RT), Val(mode), track_numbers, batchmode, tobatch) newa = Array{TT,ndims(RT)}(undef, size(prev)) seen[prev] = newa same = true for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] - nv = make_tracer(seen, pv, append_path(path, I), mode; track_numbers, kwargs...) + nv = make_tracer( + seen, + pv, + append_path(path, I), + mode; + track_numbers, + batchmode, + tobatch, + kwargs..., + ) if pv !== nv same = false end @@ -936,13 +887,15 @@ function make_tracer( @nospecialize(prev::NamedTuple), @nospecialize(path), mode; - @nospecialize(track_numbers::Type = Union{}), + @nospecialize(track_numbers::Type=Union{}), + @nospecialize(batchmode=BatchNone), + @nospecialize(tobatch=nothing), kwargs..., ) NT = Core.Typeof(prev) A = NT.parameters[1] RT = NT.parameters[2] - return NamedTuple{A,traced_type(RT, Val(mode), track_numbers)}(( + return NamedTuple{A,traced_type(RT, Val(mode), track_numbers, batchmode, tobatch)}(( ( make_tracer( seen, @@ -950,6 +903,8 @@ function make_tracer( append_path(path, i), mode; track_numbers, + batchmode, + tobatch, kwargs..., ) for i in 1:length(A) )..., From 08620450e67c970049448d683eb13c0b1c0f0d0b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Jan 2025 14:59:28 -0500 Subject: [PATCH 07/13] fix: tests used older defn --- src/Tracing.jl | 2 +- test/tracing.jl | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index e1f718bc18..7cb80ee173 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -312,7 +312,7 @@ Base.@nospecializeinfer function traced_type_inner( return TracedRArray{T,N} elseif mode == TracedSetPath if batchmode == BatchNone - return T + return TracedRArray{T,N} elseif batchmode == BatchArray if tobatch === nothing TracedRArray{T,N - 1} diff --git a/test/tracing.jl b/test/tracing.jl index c196f562b6..e72634ac5a 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -144,7 +144,9 @@ using Test Base.Pairs{Symbol,Union{}}, ), ] - tracedty = traced_type(origty, Val(ConcreteToTraced), Union{}) + tracedty = traced_type( + origty, Val(ConcreteToTraced), Union{}, Reactant.BatchNone, nothing + ) @test tracedty == targetty tracedty2 = traced_type(origty, Val(ConcreteToTraced), ReactantPrimitive) @@ -158,7 +160,7 @@ using Test TracedRArray{Float64,3}, ] @test_throws Union{ErrorException,String} traced_type( - type, Val(ConcreteToTraced), Union{} + type, Val(ConcreteToTraced), Union{}, Reactant.BatchNone, nothing ) end end @@ -167,7 +169,9 @@ using Test x::Vector{Float64} y::Union{Nothing,Node} end - @test_throws NoFieldMatchError traced_type(Node, Val(ArrayToConcrete), Union{}) + @test_throws NoFieldMatchError traced_type( + Node, Val(ArrayToConcrete), Union{}, Reactant.BatchNone, nothing + ) end end From bcc4d5988023c42a266a7ecd0e836eee5599d234 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Jan 2025 22:57:29 -0500 Subject: [PATCH 08/13] fix: update traced_type --- src/Tracing.jl | 138 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 104 insertions(+), 34 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index 7cb80ee173..bad5f3997c 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -13,7 +13,9 @@ end BatchArray = 3 end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(args::Vararg)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(args::Vararg) +) if T === Any return T end @@ -137,18 +139,36 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, throw(NoFieldMatchError(T, TT2)) end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{Union{}}), seen, mode::TraceMode, @nospecialize(args::Vararg)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{Union{}}), seen, mode::TraceMode, @nospecialize(args::Vararg) +) return T end -for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, AbstractFloat, Integer, RNumber) - @eval Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:$T}), seen, mode::TraceMode, @nospecialize(args::Vararg)) +for T in ( + DataType, + Module, + Nothing, + Symbol, + AbstractChar, + AbstractString, + AbstractFloat, + Integer, + RNumber, +) + @eval Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:$T}), seen, mode::TraceMode, @nospecialize(args::Vararg) + ) return T end end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ReactantPrimitive}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(args::Vararg) + @nospecialize(T::Type{<:ReactantPrimitive}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(args::Vararg) ) if Mode == ArrayToConcrete && T <: track_numbers return ConcreteRNumber{T} @@ -159,7 +179,10 @@ Base.@nospecializeinfer function traced_type_inner( end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(C::Type{<:Complex}), seen, @nospecialize(mode::TraceMode), @nospecialize(args::Vararg) + @nospecialize(C::Type{<:Complex}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(args::Vararg) ) if !(C isa UnionAll) return Complex{traced_type_inner(C.parameters[1], seen, mode, args...)} @@ -168,7 +191,9 @@ Base.@nospecializeinfer function traced_type_inner( end end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Function}), seen, mode::TraceMode, @nospecialize(args::Vararg)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:Function}), seen, mode::TraceMode, @nospecialize(args::Vararg) +) # functions are directly returned if sizeof(T) == 0 return T @@ -195,7 +220,9 @@ end @inline is_concrete_tuple(x::T2) where {T2} = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Tuple}), seen, mode::TraceMode, @nospecialize(args::Vararg)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:Tuple}), seen, mode::TraceMode, @nospecialize(args::Vararg) +) if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll throw(AssertionError("Type $T is not concrete type or concrete tuple")) elseif is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) @@ -209,7 +236,9 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Tuple return Tuple{TT...} end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:NamedTuple}), seen, mode::TraceMode, @nospecialize(args::Vararg)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:NamedTuple}), seen, mode::TraceMode, @nospecialize(args::Vararg) +) N = T.parameters[1] V = T.parameters[2] return NamedTuple{N,traced_type_inner(V, seen, mode, args...)} @@ -222,7 +251,12 @@ Base.@nospecializeinfer @inline dict_value( ::Type{<:(AbstractDict{K,V} where {K})} ) where {V} = V -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:AbstractDict}), seen, mode::TraceMode, @nospecialize(args::Vararg)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:AbstractDict}), + seen, + mode::TraceMode, + @nospecialize(args::Vararg) +) V = dict_value(T) if V === nothing return T @@ -246,7 +280,10 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Abstr end Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T0::Type{<:ConcreteRNumber}), seen, mode::TraceMode, @nospecialize(args::Vararg) + @nospecialize(T0::Type{<:ConcreteRNumber}), + seen, + mode::TraceMode, + @nospecialize(args::Vararg) ) T = T0.parameters[1] if mode == ConcreteToTraced @@ -269,7 +306,10 @@ Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) = (TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...} Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ConcreteRArray}), seen, mode::TraceMode, @nospecialize(args::Vararg) + @nospecialize(T::Type{<:ConcreteRArray}), + seen, + mode::TraceMode, + @nospecialize(args::Vararg) ) if mode == ConcreteToTraced return base_typet(T) @@ -280,7 +320,12 @@ Base.@nospecializeinfer function traced_type_inner( end end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:ConcreteRNG}), seen, mode::TraceMode, @nospecialize(args::Vararg)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:ConcreteRNG}), + seen, + mode::TraceMode, + @nospecialize(args::Vararg) +) if mode == ConcreteToTraced return TracedRNG elseif mode == TracedToConcrete @@ -291,16 +336,23 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Concr end Base.@nospecializeinfer function traced_type_inner( - ::Type{<:MissingTracedValue}, seen, mode::TraceMode, @nospecialize(track_numbers), @nospecialize(batchmode), @nospecialize(tobatch) + ::Type{<:MissingTracedValue}, + seen, + mode::TraceMode, + @nospecialize(track_numbers), + @nospecialize(batchmode), + @nospecialize(tobatch) ) - error("This should not happen...") + return error("This should not happen...") end -@inline base_typec(TV::TT) where {TT<:UnionAll} = UnionAll(TV.var, base_typec(TV.body)) -@inline base_typec(TV::TT) where {TT<:DataType} = ConcreteRArray{TV.parameters...} - Base.@nospecializeinfer function traced_type_inner( - TR::Type{<:TracedRArray}, seen, mode::TraceMode, @nospecialize(track_numbers), @nospecialize(batchmode), @nospecialize(tobatch) + TR::Type{<:TracedRArray}, + seen, + mode::TraceMode, + @nospecialize(track_numbers), + @nospecialize(batchmode), + @nospecialize(tobatch) ) T = TR.parameters[1] N = TR.parameters[2] @@ -327,7 +379,9 @@ Base.@nospecializeinfer function traced_type_inner( end end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:TracedRNG}), seen, mode::TraceMode, @nospecialize(args::Vararg)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:TracedRNG}), seen, mode::TraceMode, @nospecialize(args::Vararg) +) if mode == ConcreteToTraced throw("TracedRNG cannot be traced") elseif mode == TracedToConcrete @@ -339,7 +393,9 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Trace end end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:XLAArray}), seen, mode::TraceMode, @nospecialize(args::Vararg)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(T::Type{<:XLAArray}), seen, mode::TraceMode, @nospecialize(args::Vararg) +) throw("XLA $T array cannot be traced") end @@ -356,13 +412,20 @@ Base.@nospecializeinfer function traced_type_inner( end for P in (Ptr, Core.LLVMPtr, Base.RefValue) - @eval Base.@nospecializeinfer function traced_type_inner(@nospecialize(PT::Type{<:$P}), seen, mode::TraceMode, @nospecialize(args::Vararg)) + @eval Base.@nospecializeinfer function traced_type_inner( + @nospecialize(PT::Type{<:$P}), seen, mode::TraceMode, @nospecialize(args::Vararg) + ) T = eltype(PT) return $P{traced_type_inner(T, seen, mode, args...)} end end -Base.@nospecializeinfer function traced_type_inner(@nospecialize(VT::Type{<:Val}), seen, @nospecialize(mode::TraceMode), @nospecialize(args::Vararg)) +Base.@nospecializeinfer function traced_type_inner( + @nospecialize(VT::Type{<:Val}), + seen, + @nospecialize(mode::TraceMode), + @nospecialize(args::Vararg) +) if VT isa UnionAll return VT end @@ -373,7 +436,7 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(VT::Type{<:Val} throw("Val type $(Val{T}) cannot be traced") end -const traced_type_cache = Dict{Tuple{TraceMode,Type},Dict{Type,Type}}() +const traced_type_cache = Dict{Tuple{TraceMode,Type,Any,Any},Dict{Type,Type}}() # function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Type{<:Val}), @nospecialize(track_numbers::Type)) # @nospecialize @@ -467,17 +530,17 @@ const traced_type_cache = Dict{Tuple{TraceMode,Type},Dict{Type,Type}}() # end Base.@assume_effects :total @inline function traced_type( - T::Type, ::Val{mode}, track_numbers::Type + T::Type, ::Val{mode}, track_numbers::Type, batchmode, tobatch ) where {mode} cache = nothing - cache_key = (mode, track_numbers) + cache_key = (mode, track_numbers, batchmode, tobatch) if haskey(traced_type_cache, cache_key) cache = traced_type_cache[cache_key] else cache = Dict{Type,Type}() traced_type_cache[cache_key] = cache end - return res1 = traced_type_inner(T, cache, mode, track_numbers) + return traced_type_inner(T, cache, mode, track_numbers, batchmode, tobatch) end abstract type TracedTypeException <: Exception end @@ -518,9 +581,9 @@ function make_tracer( @nospecialize(prev), @nospecialize(path), mode; - @nospecialize(track_numbers::Type=Union{}), - @nospecialize(batchmode=BatchNone), - @nospecialize(tobatch=nothing), + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(batchmode = BatchNone), + @nospecialize(tobatch = nothing), kwargs..., ) if mode != NoStopTracedTrack && haskey(seen, prev) @@ -834,7 +897,14 @@ function make_tracer( end function make_tracer( - seen, @nospecialize(prev::Array), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), @nospecialize(batchmode=BatchNone), @nospecialize(tobatch=nothing), kwargs... + seen, + @nospecialize(prev::Array), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(batchmode = BatchNone), + @nospecialize(tobatch = nothing), + kwargs..., ) RT = Core.Typeof(prev) if mode != NoStopTracedTrack && haskey(seen, prev) @@ -887,9 +957,9 @@ function make_tracer( @nospecialize(prev::NamedTuple), @nospecialize(path), mode; - @nospecialize(track_numbers::Type=Union{}), - @nospecialize(batchmode=BatchNone), - @nospecialize(tobatch=nothing), + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(batchmode = BatchNone), + @nospecialize(tobatch = nothing), kwargs..., ) NT = Core.Typeof(prev) From 6f776a090d178246b14373cd5aac8f7607ff12df Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Jan 2025 23:03:05 -0500 Subject: [PATCH 09/13] test: update call --- test/tracing.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/tracing.jl b/test/tracing.jl index e72634ac5a..0538ecb7fe 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -149,7 +149,13 @@ using Test ) @test tracedty == targetty - tracedty2 = traced_type(origty, Val(ConcreteToTraced), ReactantPrimitive) + tracedty2 = traced_type( + origty, + Val(ConcreteToTraced), + ReactantPrimitive, + Reactant.BatchNone, + nothing, + ) @test tracedty2 == targetty end From 0ad6330e7ff670285d19fa2012eeaaf3c92dc7a5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Jan 2025 17:33:47 -0500 Subject: [PATCH 10/13] fix: tracing --- src/Tracing.jl | 55 ++++++++++++++++++++++++++++++++++---------------- test/basic.jl | 13 +++++------- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index bad5f3997c..181fa6c7f4 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -285,36 +285,25 @@ Base.@nospecializeinfer function traced_type_inner( mode::TraceMode, @nospecialize(args::Vararg) ) - T = T0.parameters[1] if mode == ConcreteToTraced - return TracedRNumber{T} + return TracedRNumber{T0.parameters[1]} elseif mode == TracedToConcrete - return ConcreteRNumber{T} + return T0 else throw("Abstract RNumber cannot be made concrete") end end -Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) = - UnionAll(TV.var, base_typet(TV.body)) -Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) = - TracedRArray{TV.parameters...} - -Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) = - UnionAll(TV.var, base_typec(TV.body)) -Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) = - (TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...} - Base.@nospecializeinfer function traced_type_inner( - @nospecialize(T::Type{<:ConcreteRArray}), + @nospecialize(CA::Type{<:ConcreteRArray}), seen, mode::TraceMode, @nospecialize(args::Vararg) ) if mode == ConcreteToTraced - return base_typet(T) + return TracedRArray{CA.parameters[1],CA.parameters[2]} elseif mode == TracedToConcrete - return T + return CA else throw("Abstract RArray cannot be made concrete") end @@ -346,6 +335,38 @@ Base.@nospecializeinfer function traced_type_inner( return error("This should not happen...") end +Base.@nospecializeinfer function traced_type_inner( + TR::Type{<:TracedRNumber}, + seen, + mode::TraceMode, + @nospecialize(track_numbers), + @nospecialize(batchmode), + @nospecialize(tobatch) +) + T = TR.parameters[1] + if mode == ConcreteToTraced + throw("TracedRArray $(TracedRArray{T,N}) cannot be traced") + elseif mode == TracedToConcrete + return ConcreteRNumber{T} + elseif mode == TracedTrack || mode == NoStopTracedTrack + return TracedRNumber{T} + elseif mode == TracedSetPath + if batchmode == BatchNone + return TracedRNumber{T} + elseif batchmode == BatchScalar + if tobatch === nothing + return TracedRNumber{T} + else + return TracedRArray{T,length(tobatch)} + end + else + error("Cannot BatchArray on a scalar") + end + else + throw("$(TracedRNumber{T}) cannot be made concrete in mode $mode") + end +end + Base.@nospecializeinfer function traced_type_inner( TR::Type{<:TracedRArray}, seen, @@ -359,7 +380,7 @@ Base.@nospecializeinfer function traced_type_inner( if mode == ConcreteToTraced throw("TracedRArray $(TracedRArray{T,N}) cannot be traced") elseif mode == TracedToConcrete - return base_typec(TracedRArray{T,N}) + return ConcreteRArray{T,N} elseif mode == TracedTrack || mode == NoStopTracedTrack return TracedRArray{T,N} elseif mode == TracedSetPath diff --git a/test/basic.jl b/test/basic.jl index c3952549a6..ad282f332c 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -34,19 +34,18 @@ end end sinexp(x) = sin(exp(x)) -sinexpbc(x) = sinexp.(x) @testset "Broadcast combined" begin x = rand(2, 10) - r_res = sinexpbc(x) + r_res = sinexp.(x) a = Reactant.ConcreteRArray(x) - c_res = @allowscalar sinexpbc(a) + c_res = @allowscalar sinexp.(a) @test c_res ≈ r_res - @test @jit(sinexpbc(a)) ≈ r_res + @test @jit(sinexp.(a)) ≈ r_res end sumexp(x) = sum(exp, x) @@ -82,13 +81,11 @@ end @test f_res ≈ r_res end -bcast_cos(x) = cos.(x) - @testset "Basic cos" begin x = rand(3, 2) c = Reactant.ConcreteRArray(x) - @test @jit(bcast_cos(c)) ≈ cos.(x) + @test @jit(cos.(c)) ≈ cos.(x) end f_var(args...) = sum(args) @@ -376,7 +373,7 @@ end b = Reactant.to_rarray(_b) c = Reactant.to_rarray(_c) - # vcat test + # vcat test y = @jit vcat(a, b) @test y == vcat(a, _b) @test y isa ConcreteRArray{typeof_a,1} From dd222194957a57e5c2fa5aab823b5f6cb1609827 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 17 Jan 2025 17:45:11 -0500 Subject: [PATCH 11/13] docs: setup batching tutorial --- docs/make.jl | 12 ++++++++++-- docs/src/.vitepress/config.mts | 10 +++++++++- docs/src/tutorials/batching.md | 3 +++ docs/src/tutorials/index.md | 1 + src/Compiler.jl | 4 ++-- src/Ops.jl | 22 ++++++++++++++++++++-- src/Reactant.jl | 2 -- test/batching.jl | 2 ++ test/runtests.jl | 1 + 9 files changed, 48 insertions(+), 9 deletions(-) create mode 100644 docs/src/tutorials/batching.md create mode 100644 test/batching.jl diff --git a/docs/make.jl b/docs/make.jl index 64fbf78042..42100ab21a 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,8 +24,11 @@ examples = [ pages = [ "Reactant.jl" => "index.md", "Introduction" => ["Getting Started" => "introduction/index.md"], - "Tutorials" => - ["Overview" => "tutorials/index.md", "Profiling" => "tutorials/profiling.md"], + "Tutorials" => [ + "Overview" => "tutorials/index.md", + "Profiling" => "tutorials/profiling.md", + "Batching Functions with `Reactant.Ops.batch`" => "tutorials/batching.md", + ], "API Reference" => [ "Reactant API" => "api/api.md", "Ops" => "api/ops.md", @@ -38,6 +41,11 @@ pages = [ "Func" => "api/func.md", "StableHLO" => "api/stablehlo.md", "VHLO" => "api/vhlo.md", + "GPU" => "api/gpu.md", + "LLVM" => "api/llvm.md", + "NVVM" => "api/nvvm.md", + "TPU" => "api/tpu.md", + "Triton" => "api/triton.md", ], "MLIR API" => "api/mlirc.md", "XLA" => "api/xla.md", diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 46b7347039..75f78b0dfa 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -56,8 +56,12 @@ export default defineConfig({ { text: "Tutorials", items: [ - {text: "Overview", link: "/tutorials/"}, + { text: "Overview", link: "/tutorials/" }, {text: "Profiling", link: "/tutorials/profiling"}, + { + text: "Batching Functions with `Reactant.Ops.batch`", + link: "/tutorials/batching" + }, ], }, { @@ -112,6 +116,10 @@ export default defineConfig({ items: [ { text: "Overview", link: "/tutorials/" }, { text: "Profiling", link: "/tutorials/profiling" }, + { + text: "Batching Functions with `Reactant.Ops.batch`", + link: "/tutorials/batching", + }, ], }, "/api/": { diff --git a/docs/src/tutorials/batching.md b/docs/src/tutorials/batching.md new file mode 100644 index 0000000000..d3ca778847 --- /dev/null +++ b/docs/src/tutorials/batching.md @@ -0,0 +1,3 @@ +# [Batching Functions with [`Reactant.Ops.batch`](@ref)](@id batching-tutorial) + + diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md index 87c2c8ddd3..7be1f71047 100644 --- a/docs/src/tutorials/index.md +++ b/docs/src/tutorials/index.md @@ -1,5 +1,6 @@ # Tutorials - [Profiling](@ref profiling). + - [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) We are currently working on adding more tutorials to Reactant!! Please check back soon! diff --git a/src/Compiler.jl b/src/Compiler.jl index d1be9afe16..7b07942b4a 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -795,7 +795,7 @@ function codegen_unflatten!( paths = ( ( p for p in Reactant.TracedUtils.get_paths(result) if - length(p) ≥ 1 && (p[1] == :result || p[1] == :resargs) + length(p) > 0 && (p[1] == :result || p[1] == :resargs) )..., ) for path in paths @@ -865,7 +865,7 @@ function codegen_unflatten!( paths = ( ( p for p in Reactant.TracedUtils.get_paths(result) if - length(p) ≥ 1 && (p[1] == :result || p[1] == :resargs || p[1] == :args) + length(p) > 0 && (p[1] == :result || p[1] == :resargs || p[1] == :args) )..., ) diff --git a/src/Ops.jl b/src/Ops.jl index 50d3b2cbfc..e001ce9fca 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2013,8 +2013,24 @@ end # This function assumes that the last dimension of each element is the batch dimension by # default. This is the standard Julia ordering for batching. We permutedims the ordering to # make sure the first dimension is the batch dimension when calling `batch_internal` below. -# XXX: Mutation inside a batched function is not supported yet (need to set the results -# correctly) +""" + batch(f, args...; batch_dims=nothing, result_dims=nothing) + +Map `f` over the arguments `args` along the batch dimensions `batch_dims` and return the results with the corresponding batch dimensions specified by `result_dims`. (For users +familiar with `jax`, this operation corresponds to `jax.vmap`.) + +If `batch_dims` is `nothing`, we assume that the last dimension of each leaf of `args` is the batch dimension. If `result_dims` is `nothing`, we assume that the last dimension of each leaf of the returned values is the batch dimension. + +To avoid batching a specific leaf, pass `nothing` for the corresponding `batch_dims`. + +## Examples + +For usage examples, see the [Batching Functions with `Reactant.Ops.batch`](@ref batching-tutorial) tutorial. + +!!! danger + + Mutation inside a batched function is not supported yet and will lead to unexpected results. +""" @noinline function batch(f, args...; batch_dims=nothing, result_dims=nothing) batch_sizes = Int64[] batching_dims = if batch_dims === nothing @@ -2060,6 +2076,8 @@ end end return fmap(results, result_dims) do result, dim + @assert dim !== nothing "Result batch dimension cannot be `nothing`" + order = collect(Int64, 1:ndims(result)) order[dim] = 1 order[1] = dim diff --git a/src/Reactant.jl b/src/Reactant.jl index 0dcae598cb..41f6ab9298 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -9,8 +9,6 @@ using Functors: @leaf using Adapt: Adapt, WrappedArray using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)` -using Functors: @leaf - export @allowscalar # re-exported from GPUArraysCore # auxiliary types and functions diff --git a/test/batching.jl b/test/batching.jl new file mode 100644 index 0000000000..cd6ae6bbf5 --- /dev/null +++ b/test/batching.jl @@ -0,0 +1,2 @@ +using Reactant, Test + diff --git a/test/runtests.jl b/test/runtests.jl index be17750042..9c6e094508 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") @safetestset "Control Flow" include("control_flow.jl") @safetestset "Sorting" include("sorting.jl") + @safetestset "Batching" include("batching.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" From f29babcb7c22e2b73f03462e93c2190062aee630 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 19 Jan 2025 20:49:53 -0500 Subject: [PATCH 12/13] fix: call with reactant --- src/Ops.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ops.jl b/src/Ops.jl index e001ce9fca..ac1b1b61a4 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2118,7 +2118,7 @@ end TracedRNumber{Reactant.unwrapped_eltype(arg)}, arg ) end - return f(scalar_args...) + return Reactant.call_with_reactant(f, scalar_args...) end end From 2241e54281014d0be3322128304076edd68fc819 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Jan 2025 11:11:09 -0500 Subject: [PATCH 13/13] fix: convert scalars to arrays in batcharray mode --- src/Tracing.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index 181fa6c7f4..48d6d0b011 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -359,8 +359,14 @@ Base.@nospecializeinfer function traced_type_inner( else return TracedRArray{T,length(tobatch)} end + elseif batchmode == BatchArray + if tobatch === nothing + error("For scalars with batchmode=BatchArray, tobatch must be specified") + else + TracedRArray{T,length(tobatch)} + end else - error("Cannot BatchArray on a scalar") + error("Unknown batchmode $batchmode") end else throw("$(TracedRNumber{T}) cannot be made concrete in mode $mode") @@ -817,8 +823,14 @@ function make_tracer( else TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) end + elseif batchmode == BatchArray + if tobatch === nothing + error("For scalars with batchmode=BatchArray, tobatch must be specified") + else + TracedRArray{T,length(tobatch)}((path,), prev.mlir_data, tobatch) + end else - error("Cannot BatchArray on a scalar") + error("Unknown batchmode $batchmode") end seen[prev] = res return res