diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 69c273c865..786bd8631f 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -419,6 +419,17 @@ function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRNumber{T}) wher return res end +import Reactant.TracedRNumberOverrides.TracedStepRangeLen + +function Adapt.adapt_storage(::ReactantKernelAdaptor, r::TracedStepRangeLen) + return TracedStepRangeLen( + Adapt.adapt(ReactantKernelAdaptor(), r.ref), + Adapt.adapt(ReactantKernelAdaptor(), r.step), + Adapt.adapt(ReactantKernelAdaptor(), r.len), + Adapt.adapt(ReactantKernelAdaptor(), r.offset), + ) +end + # Since we cache these objects we cannot cache data containing MLIR operations (e.g. the entry must be a string # and not the operation itself). struct LLVMFunc{F,tt} diff --git a/src/Compiler.jl b/src/Compiler.jl index 04aeb39e86..8e896cdc89 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -33,7 +33,8 @@ end end @inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T} - (isbitstype(T) || ancestor(obj) isa RArray) && return Base.getfield(obj, field) + (isbitstype(T) || ancestor(obj) isa RArray || obj isa AbstractRange) && + return Base.getfield(obj, field) return Base.getindex(obj, field) end @@ -1472,7 +1473,6 @@ function codegen_flatten!( is_sharded && runtime isa Val{:PJRT} && (flatten_names = vcat(eachrow(reshape(flatten_names, length(mesh), :))...)) - return flatten_names, flatten_code end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 9a550d374f..23791ac1df 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -3,6 +3,9 @@ module TracedRNumberOverrides using ..Reactant: Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype using ReactantCore +using Adapt + +import Base.TwicePrecision ReactantCore.is_traced(::TracedRNumber, seen) = true ReactantCore.is_traced(::TracedRNumber) = true @@ -262,6 +265,42 @@ function Base.ifelse( end end +function Base.:*( + x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T} +) where {T<:TracedRNumber} + zh, zl = Base.mul12(x.hi, y.hi) + hi, lo = Base.canonicalize2(zh, (x.hi * y.lo + x.lo * y.hi) + zl) + hi = ifelse(iszero(zh) | !isfinite(zh), zh, hi) + lo = ifelse(iszero(zl) | !isfinite(zl), zl, lo) + + return Base.TwicePrecision{T}(hi, lo) +end + +function Base.:+( + x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T} +) where {T<:TracedRNumber} + r = x.hi + y.hi + @trace s = if abs(x.hi) > abs(y.hi) + begin + (((x.hi - r) + y.hi) + y.lo) + x.lo + end + else + begin + (((y.hi - r) + x.hi) + x.lo) + y.lo + end + end + return Base.TwicePrecision(Base.canonicalize2(r, s)...) +end + +function Base.:*(x::TwicePrecision, v::TracedRNumber) + @trace result = if v == 0 + TwicePrecision(x.hi * v, x.lo * v) + else + x * TwicePrecision(oftype(x.hi * v, v)) + end + return result +end + for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) T = promote_type(T1, T2) @eval begin @@ -271,18 +310,54 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) TracedUtils.promote_to(TracedRNumber{$(T)}, y), ) end + function Base.:&(x::TracedRNumber{<:$(T1)}, y::$(T2)) + return Ops.and( + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), + ) + end + function Base.:&(x::$(T1), y::TracedRNumber{<:$(T2)}) + return Ops.and( + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), + ) + end function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) return Ops.or( TracedUtils.promote_to(TracedRNumber{$(T)}, x), TracedUtils.promote_to(TracedRNumber{$(T)}, y), ) end + function Base.:|(x::TracedRNumber{<:$(T1)}, y::$(T2)) + return Ops.or( + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), + ) + end + function Base.:|(x::$(T1), y::TracedRNumber{<:$(T2)}) + return Ops.or( + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), + ) + end function Base.xor(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) return Ops.xor( TracedUtils.promote_to(TracedRNumber{$(T)}, x), TracedUtils.promote_to(TracedRNumber{$(T)}, y), ) end + function Base.xor(x::TracedRNumber{<:$(T1)}, y::$(T2)) + return Ops.xor( + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), + ) + end + function Base.xor(x::$(T1), y::TracedRNumber{<:$(T2)}) + return Ops.xor( + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), + ) + end Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x) end end @@ -424,9 +499,188 @@ function Base.getindex( return Base.unsafe_getindex(r, i) end +struct TracedStepRangeLen{T,R,S,L} <: AbstractRange{T} + ref::R + step::S + len::L + offset::L +end + +function Adapt.parent_type(::Type{TracedStepRangeLen{T,R,S,L}}) where {T,R,S,L} + return TracedStepRangeLen{T,R,S,L} +end + +# constructors and interface implementation copied from range.jl +function TracedStepRangeLen{T,R,S}(ref::R, step::S, len, offset=1) where {T,R,S} + return TracedStepRangeLen{T,R,S,typeof(len)}(ref, step, len, offset) +end +function TracedStepRangeLen(ref::R, step::S, len, offset=1) where {R,S} + return TracedStepRangeLen{typeof(ref + zero(step)),R,S,typeof(len)}( + ref, step, len, offset + ) +end +function TracedStepRangeLen{T}( + ref::R, step::S, len::Integer, offset::Integer=1 +) where {T,R,S} + return TracedStepRangeLen{T,R,S,typeof(len)}(ref, step, len, offset) +end + +Base.isempty(r::TracedStepRangeLen) = length(r) == 0 +Base.step(r::TracedStepRangeLen) = r.step +Base.step_hp(r::TracedStepRangeLen) = r.step +Base.length(r::TracedStepRangeLen) = r.len +Base.first(r::TracedStepRangeLen) = Base.unsafe_getindex(r, 1) +Base.last(r::TracedStepRangeLen) = Base.unsafe_getindex(r, r.len) +function Base.iterate(r::TracedStepRangeLen, i::Integer=1) + @inline + i += oneunit(i) + length(r) < i && return nothing + return Base.unsafe_getindex(r, i), i +end + +function _tracedsteprangelen_unsafe_getindex( + r::AbstractRange{T}, i::Union{I,TracedRNumber{I}} +) where {T,I} + finalT = T + offsetT = typeof(r.offset) + if i isa TracedRNumber + if !(T <: TracedRNumber) + finalT = TracedRNumber{T} + end + if !(r.offset isa TracedRNumber) + offsetT = TracedRNumber{offsetT} + end + end + u = convert(offsetT, i) - r.offset + return finalT(r.ref + u * r.step) +end +function Base.unsafe_getindex(r::TracedStepRangeLen, i::Integer) + return _tracedsteprangelen_unsafe_getindex(r, i) +end +function Base.unsafe_getindex(r::TracedStepRangeLen, i::TracedRNumber{<:Integer}) + return _tracedsteprangelen_unsafe_getindex(r, i) +end +Base.getindex(r::TracedStepRangeLen, i::TracedRNumber) = Base.unsafe_getindex(r, i) +function getindex(r::TracedStepRangeLen{T}, s::OrdinalRange{S}) where {T,S<:Integer} + @inline + @boundscheck checkbounds(r, s) + + len = length(s) + sstep = Base.step_hp(s) + rstep = Base.step_hp(r) + L = typeof(len) + if S === Bool + rstep *= one(sstep) + if len == 0 + return TracedStepRangeLen{T}(first(r), rstep, zero(L), oneunit(L)) + elseif len == 1 + if first(s) + return TracedStepRangeLen{T}(first(r), rstep, oneunit(L), oneunit(L)) + else + return TracedStepRangeLen{T}(first(r), rstep, zero(L), oneunit(L)) + end + else # len == 2 + return TracedStepRangeLen{T}(last(r), rstep, oneunit(L), oneunit(L)) + end + else + # Find closest approach to offset by s + ind = LinearIndices(s) + offset = L( + max(min(1 + round(L, (r.offset - first(s)) / sstep), last(ind)), first(ind)) + ) + ref = Base._getindex_hiprec(r, first(s) + (offset - oneunit(offset)) * sstep) + return TracedStepRangeLen{T}(ref, rstep * sstep, len, offset) + end +end +function Base._getindex_hiprec(r::TracedStepRangeLen, i::Integer) # without rounding by T + u = oftype(r.offset, i) - r.offset + return r.ref + u * r.step +end +function Base.:(==)(r::T, s::T) where {T<:TracedStepRangeLen} + return (isempty(r) & isempty(s)) | + ((first(r) == first(s)) & (length(r) == length(s)) & (last(r) == last(s))) +end + +# TODO: if there ever comes a ReactantStepRange: +# ==(r::Union{StepRange{T},StepRangeLen{T,T}}, s::Union{StepRange{T},StepRangeLen{T,T}}) where {T} + +function Base.:-(r::TracedStepRangeLen{T,R,S,L}) where {T,R,S,L} + return TracedStepRangeLen{T,R,S,L}(-r.ref, -r.step, r.len, r.offset) +end + +# TODO: promotion from StepRangeLen{T} to TracedStepRangeLen{T}? +function Base.promote_rule( + ::Type{TracedStepRangeLen{T1,R1,S1,L1}}, ::Type{TracedStepRangeLen{T2,R2,S2,L2}} +) where {T1,T2,R1,R2,S1,S2,L1,L2} + R, S, L = promote_type(R1, R2), promote_type(S1, S2), promote_type(L1, L2) + return Base.el_same( + promote_type(T1, T2), TracedStepRangeLen{T1,R,S,L}, TracedStepRangeLen{T2,R,S,L} + ) +end +TracedStepRangeLen{T,R,S,L}(r::TracedStepRangeLen{T,R,S,L}) where {T,R,S,L} = r +function TracedStepRangeLen{T,R,S,L}(r::TracedStepRangeLen) where {T,R,S,L} + return TracedStepRangeLen{T,R,S,L}( + convert(R, r.ref), convert(S, r.step), convert(L, r.len), convert(L, r.offset) + ) +end +function TracedStepRangeLen{T}(r::TracedStepRangeLen) where {T} + return TracedStepRangeLen(convert(T, r.ref), convert(T, r.step), r.len, r.offset) +end +function Base.promote_rule( + a::Type{TracedStepRangeLen{T,R,S,L}}, ::Type{OR} +) where {T,R,S,L,OR<:AbstractRange} + return promote_rule(a, TracedStepRangeLen{eltype(OR),eltype(OR),eltype(OR),Int}) +end +function TracedStepRangeLen{T,R,S,L}(r::AbstractRange) where {T,R,S,L} + return TracedStepRangeLen{T,R,S,L}(R(first(r)), S(step(r)), length(r)) +end +function TracedStepRangeLen{T}(r::AbstractRange) where {T} + return TracedStepRangeLen(T(first(r)), T(step(r)), length(r)) +end +TracedStepRangeLen(r::AbstractRange) = TracedStepRangeLen{eltype(r)}(r) + +function Base.promote_rule( + ::Type{LinRange{A,L}}, b::Type{TracedStepRangeLen{T2,R2,S2,L2}} +) where {A,L,T2,R2,S2,L2} + return promote_rule(TracedStepRangeLen{A,A,A,L}, b) +end + +function Base._reverse(r::TracedStepRangeLen, ::Colon) + # If `r` is empty, `length(r) - r.offset + 1 will be nonpositive hence + # invalid. As `reverse(r)` is also empty, any offset would work so we keep + # `r.offset` + offset = isempty(r) ? r.offset : length(r) - r.offset + 1 + return typeof(r)(r.ref, negate(r.step), length(r), offset) +end + +# TODO: +, - for TracedStepRangeLen (see Base._define_range_op) + +function (::Type{T})(x::TwicePrecision) where {T<:Reactant.TracedRNumber} + return (T(x.hi) + T(x.lo))::T +end + +function (::Type{T})(x::TwicePrecision) where {T<:Reactant.ConcreteRNumber} + return Reactant.ConcreteRNumber(T(x.hi) - T(x.lo))::T +end + +Base.nbitslen(r::TracedStepRangeLen) = Base.nbitslen(eltype(r), length(r), r.offset) +function TracedStepRangeLen( + ref::TwicePrecision{T}, step::TwicePrecision{T}, len, offset=1 +) where {T} + return TracedStepRangeLen{T,TwicePrecision{T},TwicePrecision{T}}(ref, step, len, offset) +end +function Base.step(r::TracedStepRangeLen{T,TwicePrecision{T},TwicePrecision{T}}) where {T} + return T(r.step) +end + # This assumes that r.step has already been split so that (0:len-1)*r.step.hi is exact function Base.unsafe_getindex( - r::Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision}, + r::Union{ + Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision}, + TracedStepRangeLen{ + T,<:Base.TwicePrecision,<:Base.TwicePrecision,<:Base.TwicePrecision + }, + }, i::TracedRNumber{<:Integer}, ) where {T} # Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12 @@ -449,7 +703,9 @@ function Base.unsafe_getindex( end function Base.searchsortedfirst( - a::AbstractRange{<:Real}, x::TracedRNumber{<:Real}, o::Base.DirectOrdering + a::AbstractRange{<:Union{Real,TracedRNumber}}, + x::TracedRNumber{<:Real}, + o::Base.DirectOrdering, )::TracedRNumber{keytype(a)} # require_one_based_indexing(a) @@ -460,7 +716,7 @@ function Base.searchsortedfirst( !Base.Order.lt(o, f, x), 1, ifelse( - h == 0 || Base.Order.lt(o, l, x), + (h == 0) | Base.Order.lt(o, l, x), length(a) + 1, ifelse(Base.Order.lt(o, a[n], x), n + 1, n), ), diff --git a/src/Tracing.jl b/src/Tracing.jl index 04dd7a9cb6..57e1034be0 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1837,3 +1837,87 @@ end end return @invoke to_rarray_internal(x::Any, track_numbers::Type, sharding, runtime) end + +function Reactant.traced_type_inner( + @nospecialize(RT::Type{<:StepRangeLen}), + seen, + mode::Reactant.TraceMode, + track_numbers::Type, + sharding, + runtime, +) + if !(Number <: track_numbers) + modified_track_numbers = Number + else + modified_track_numbers = track_numbers + end + T, R, S, L = RT.parameters + return TracedRNumberOverrides.TracedStepRangeLen{ + Reactant.traced_type_inner( + T, seen, mode, modified_track_numbers, sharding, runtime + ), + Reactant.traced_type_inner( + R, seen, mode, modified_track_numbers, sharding, runtime + ), + Reactant.traced_type_inner( + S, seen, mode, modified_track_numbers, sharding, runtime + ), + Reactant.traced_type_inner( + L, seen, mode, modified_track_numbers, sharding, runtime + ), + } +end + +function Reactant.make_tracer( + seen, + @nospecialize(prev::StepRangeLen), + @nospecialize(path), + mode; + @nospecialize(sharding = Sharding.NoSharding()), + kwargs..., +) + Reactant.Sharding.is_sharded(sharding) && + error("Cannot specify sharding for StepRangeLen") + if mode == Reactant.TracedToTypes + push!(path, Core.Typeof(prev)) + make_tracer(seen, prev.ref, path, mode; kwargs...) + make_tracer(seen, prev.step, path, mode; kwargs...) + make_tracer(seen, prev.len, path, mode; kwargs...) + make_tracer(seen, prev.offset, path, mode; kwargs...) + return nothing + end + return TracedRNumberOverrides.TracedStepRangeLen( + Reactant.make_tracer( + seen, + prev.ref, + Reactant.append_path(path, :ref), + mode; + kwargs..., + track_numbers=Number, + ), + Reactant.make_tracer( + seen, + prev.step, + Reactant.append_path(path, :step), + mode; + kwargs..., + track_numbers=Number, + ), + Reactant.make_tracer( + seen, + prev.len, + Reactant.append_path(path, :len), + mode; + kwargs..., + track_numbers=Number, + ), + Reactant.make_tracer( + seen, + prev.offset, + Reactant.append_path(path, :offset), + mode; + kwargs..., + track_numbers=Number, + ), + ) +end diff --git a/test/basic.jl b/test/basic.jl index f1b3f74db8..49bfb33fef 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -999,6 +999,14 @@ end @test res[3] == 216 end +@testset "Traced fractional index" begin + times = Reactant.to_rarray(0:0.01:4.5) + res = @jit fractional_idx(times, ConcreteRNumber(2.143)) + @test res[1] == 0.29999999999997334 + @test res[2] == 215 + @test res[3] == 216 +end + mulpi(x) = π * x @testset "Irrational promotion" begin