Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TracedStepRangeLen #960

Merged
merged 18 commits into from
Mar 20, 2025
11 changes: 11 additions & 0 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -419,6 +419,17 @@ function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRNumber{T}) wher
return res
end

import Reactant.TracedRNumberOverrides.TracedStepRangeLen

# Adapt a `TracedStepRangeLen` for a Reactant kernel launch by adapting each
# of its four fields individually and rebuilding the range.
function Adapt.adapt_storage(::ReactantKernelAdaptor, r::TracedStepRangeLen)
    adapt_field(x) = Adapt.adapt(ReactantKernelAdaptor(), x)
    return TracedStepRangeLen(
        adapt_field(r.ref), adapt_field(r.step), adapt_field(r.len), adapt_field(r.offset)
    )
end

# Since we cache these objects we cannot cache data containing MLIR operations (e.g. the entry must be a string
# and not the operation itself).
struct LLVMFunc{F,tt}
4 changes: 2 additions & 2 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
@@ -33,7 +33,8 @@ end
end

@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
(isbitstype(T) || ancestor(obj) isa RArray) && return Base.getfield(obj, field)
(isbitstype(T) || ancestor(obj) isa RArray || obj isa AbstractRange) &&
return Base.getfield(obj, field)
return Base.getindex(obj, field)
end

@@ -1472,7 +1473,6 @@ function codegen_flatten!(
is_sharded &&
runtime isa Val{:PJRT} &&
(flatten_names = vcat(eachrow(reshape(flatten_names, length(mesh), :))...))

return flatten_names, flatten_code
end

257 changes: 254 additions & 3 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
@@ -4,6 +4,8 @@ using ..Reactant:
Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype
using ReactantCore

import Base.TwicePrecision

ReactantCore.is_traced(::TracedRNumber, seen) = true
ReactantCore.is_traced(::TracedRNumber) = true

@@ -262,6 +264,42 @@ function Base.ifelse(
end
end

# Double-word (compensated) multiplication of two `TwicePrecision` values whose
# components are traced numbers. Mirrors `Base.:*(::TwicePrecision, ::TwicePrecision)`
# from range.jl; statement order matters for the error compensation.
function Base.:*(
    x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T}
) where {T<:TracedRNumber}
    # Exact product of the high words, split into high/low parts.
    zh, zl = Base.mul12(x.hi, y.hi)
    # Fold the cross terms and the low part back into a canonical hi/lo pair.
    hi, lo = Base.canonicalize2(zh, (x.hi * y.lo + x.lo * y.hi) + zl)
    # Preserve signed zeros and non-finite results (Inf/NaN) from the raw product;
    # `ifelse` keeps both branches traceable (no data-dependent control flow).
    hi = ifelse(iszero(zh) | !isfinite(zh), zh, hi)
    lo = ifelse(iszero(zl) | !isfinite(zl), zl, lo)

    return Base.TwicePrecision{T}(hi, lo)
end

# Error-compensated (Knuth/2Sum-style) addition of two `TwicePrecision` values
# with traced components. The branch choosing which operand has the larger
# magnitude is data-dependent, so it is expressed with `@trace` to become a
# traced conditional rather than a Julia-level branch.
function Base.:+(
    x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T}
) where {T<:TracedRNumber}
    r = x.hi + y.hi
    # `s` is the rounding error of `r` plus the low-order words, accumulated in
    # the order that minimizes cancellation (larger-magnitude operand first).
    @trace s = if abs(x.hi) > abs(y.hi)
        begin
            (((x.hi - r) + y.hi) + y.lo) + x.lo
        end
    else
        begin
            (((y.hi - r) + x.hi) + x.lo) + y.lo
        end
    end
    return Base.TwicePrecision(Base.canonicalize2(r, s)...)
end

# Scale a `TwicePrecision` by a traced scalar. The `v == 0` case is kept
# separate (as in Base's range.jl) to preserve signed zeros; since `v` is
# traced, the value-dependent branch is expressed with `@trace`.
function Base.:*(x::TwicePrecision, v::TracedRNumber)
    @trace result = if v == 0
        TwicePrecision(x.hi * v, x.lo * v)
    else
        # Promote `v` to TwicePrecision and use full double-word multiply.
        x * TwicePrecision(oftype(x.hi * v, v))
    end
    return result
end

for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
T = promote_type(T1, T2)
@eval begin
@@ -271,18 +309,54 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.:&(x::TracedRNumber{<:$(T1)}, y::$(T2))
return Ops.and(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.:&(x::$(T1), y::TracedRNumber{<:$(T2)})
return Ops.and(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
return Ops.or(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.:|(x::TracedRNumber{<:$(T1)}, y::$(T2))
return Ops.or(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.:|(x::$(T1), y::TracedRNumber{<:$(T2)})
return Ops.or(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.xor(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
return Ops.xor(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.xor(x::TracedRNumber{<:$(T1)}, y::$(T2))
return Ops.xor(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.xor(x::$(T1), y::TracedRNumber{<:$(T2)})
return Ops.xor(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x)
end
end
@@ -424,9 +498,184 @@ function Base.getindex(
return Base.unsafe_getindex(r, i)
end

"""
    TracedStepRangeLen{T,R,S,L} <: AbstractRange{T}

Traced analogue of `Base.StepRangeLen`: a range of element type `T` defined by
a reference value, a step, a length, and the index (`offset`) of the reference
value within the range. Fields may be `TracedRNumber`s so the range can
participate in tracing.
"""
struct TracedStepRangeLen{T,R,S,L} <: AbstractRange{T}
    ref::R     # reference value: the `offset`-th element of the range
    step::S    # step between consecutive elements
    len::L     # number of elements
    offset::L  # 1-based position of `ref` within the range (as in StepRangeLen)
end

# constructors and interface implementation copied from range.jl

# Explicit element/ref/step types; the length type is inferred from `len`.
function TracedStepRangeLen{T,R,S}(ref::R, step::S, len, offset=1) where {T,R,S}
    return TracedStepRangeLen{T,R,S,typeof(len)}(ref, step, len, offset)
end
# Fully-inferred constructor: element type comes from `ref + zero(step)`,
# matching `Base.StepRangeLen`'s promotion rule.
function TracedStepRangeLen(ref::R, step::S, len, offset=1) where {R,S}
    return TracedStepRangeLen{typeof(ref + zero(step)),R,S,typeof(len)}(
        ref, step, len, offset
    )
end
# Element type given explicitly; ref/step types taken from the arguments.
function TracedStepRangeLen{T}(
    ref::R, step::S, len::Integer, offset::Integer=1
) where {T,R,S}
    return TracedStepRangeLen{T,R,S,typeof(len)}(ref, step, len, offset)
end

# AbstractRange interface, mirroring `Base.StepRangeLen` (range.jl).
Base.isempty(r::TracedStepRangeLen) = length(r) == 0
Base.step(r::TracedStepRangeLen) = r.step
# `step_hp` returns the high-precision step (here the raw field).
Base.step_hp(r::TracedStepRangeLen) = r.step
Base.length(r::TracedStepRangeLen) = r.len
Base.first(r::TracedStepRangeLen) = Base.unsafe_getindex(r, 1)
Base.last(r::TracedStepRangeLen) = Base.unsafe_getindex(r, r.len)
# Iterate the range using the standard protocol: the state is the index of the
# next element to produce, starting at 1.
#
# Bug fix: the previous version incremented `i` *before* the bounds check and
# the element fetch, so with the default state `i = 1` it returned element 2
# first (skipping the first element) and treated a length-1 range as empty.
function Base.iterate(r::TracedStepRangeLen, i::Integer=1)
    @inline
    length(r) < i && return nothing
    return Base.unsafe_getindex(r, i), i + oneunit(i)
end

# Shared kernel for `unsafe_getindex`: computes `r.ref + (i - r.offset) * r.step`.
# When the index is traced but the element (or offset) type is not, the result
# (or the offset conversion target) is promoted to the corresponding
# `TracedRNumber` type so the arithmetic stays in the traced domain.
function _tracedsteprangelen_unsafe_getindex(
    r::AbstractRange{T}, i::Union{I,TracedRNumber{I}}
) where {T,I}
    traced_index = i isa TracedRNumber
    finalT = (traced_index && !(T <: TracedRNumber)) ? TracedRNumber{T} : T
    offsetT = if traced_index && !(r.offset isa TracedRNumber)
        TracedRNumber{typeof(r.offset)}
    else
        typeof(r.offset)
    end
    u = convert(offsetT, i) - r.offset
    return finalT(r.ref + u * r.step)
end
# Unchecked indexing with a plain integer index.
function Base.unsafe_getindex(r::TracedStepRangeLen, i::Integer)
    return _tracedsteprangelen_unsafe_getindex(r, i)
end
# Unchecked indexing with a traced integer index (result becomes traced).
function Base.unsafe_getindex(r::TracedStepRangeLen, i::TracedRNumber{<:Integer})
    return _tracedsteprangelen_unsafe_getindex(r, i)
end
# Traced indices cannot be bounds-checked at trace time, so go straight to
# the unchecked path.
Base.getindex(r::TracedStepRangeLen, i::TracedRNumber) = Base.unsafe_getindex(r, i)
# Index a traced range with an ordinal range, producing a new TracedStepRangeLen
# (adapted from `Base.getindex(::StepRangeLen, ::OrdinalRange)` in range.jl).
#
# Fix: the method must extend `Base.getindex` — the unqualified `function
# getindex` declared a brand-new generic function in this module, so `r[s]`
# never dispatched here (sibling methods above all qualify `Base.getindex`).
function Base.getindex(r::TracedStepRangeLen{T}, s::OrdinalRange{S}) where {T,S<:Integer}
    @inline
    @boundscheck checkbounds(r, s)

    len = length(s)
    sstep = Base.step_hp(s)
    rstep = Base.step_hp(r)
    L = typeof(len)
    if S === Bool
        # Logical indexing with a Bool range: the result has 0, 1, or 2 elements.
        rstep *= one(sstep)
        if len == 0
            return TracedStepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
        elseif len == 1
            if first(s)
                return TracedStepRangeLen{T}(first(r), rstep, oneunit(L), oneunit(L))
            else
                return TracedStepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
            end
        else # len == 2
            return TracedStepRangeLen{T}(last(r), rstep, oneunit(L), oneunit(L))
        end
    else
        # Find closest approach to offset by s, clamped to the valid indices of
        # `s`, so the new reference value is computed with minimal error.
        ind = LinearIndices(s)
        offset = L(
            max(min(1 + round(L, (r.offset - first(s)) / sstep), last(ind)), first(ind))
        )
        ref = Base._getindex_hiprec(r, first(s) + (offset - oneunit(offset)) * sstep)
        return TracedStepRangeLen{T}(ref, rstep * sstep, len, offset)
    end
end
# High-precision indexing: returns `ref + (i - offset) * step` without rounding
# the result to the element type `T` (used when constructing sub-ranges).
function Base._getindex_hiprec(r::TracedStepRangeLen, i::Integer) # without rounding by T
    u = oftype(r.offset, i) - r.offset
    return r.ref + u * r.step
end
# Equality of two traced ranges. Uses `&`/`|` instead of `&&`/`||` so the
# comparison stays traceable when the operands are traced booleans.
function Base.:(==)(r::T, s::T) where {T<:TracedStepRangeLen}
    return (isempty(r) & isempty(s)) |
           ((first(r) == first(s)) & (length(r) == length(s)) & (last(r) == last(s)))
end

# TODO: if there ever comes a ReactantStepRange:
# ==(r::Union{StepRange{T},StepRangeLen{T,T}}, s::Union{StepRange{T},StepRangeLen{T,T}}) where {T}

# Unary negation: negate the reference value and the step; length and offset
# are unchanged.
function Base.:-(r::TracedStepRangeLen{T,R,S,L}) where {T,R,S,L}
    return TracedStepRangeLen{T,R,S,L}(-r.ref, -r.step, r.len, r.offset)
end

# TODO: promotion from StepRangeLen{T} to TracedStepRangeLen{T}?

# Promote two traced range types by promoting each type parameter pairwise
# (mirrors the `StepRangeLen` promote_rule in range.jl).
function Base.promote_rule(
    ::Type{TracedStepRangeLen{T1,R1,S1,L1}}, ::Type{TracedStepRangeLen{T2,R2,S2,L2}}
) where {T1,T2,R1,R2,S1,S2,L1,L2}
    R, S, L = promote_type(R1, R2), promote_type(S1, S2), promote_type(L1, L2)
    return Base.el_same(
        promote_type(T1, T2), TracedStepRangeLen{T1,R,S,L}, TracedStepRangeLen{T2,R,S,L}
    )
end
# Identity conversion: already the requested type.
TracedStepRangeLen{T,R,S,L}(r::TracedStepRangeLen{T,R,S,L}) where {T,R,S,L} = r
# Convert a traced range to another fully-specified traced range type by
# converting each field.
function TracedStepRangeLen{T,R,S,L}(r::TracedStepRangeLen) where {T,R,S,L}
    return TracedStepRangeLen{T,R,S,L}(
        convert(R, r.ref), convert(S, r.step), convert(L, r.len), convert(L, r.offset)
    )
end
# Convert only the element type; length and offset keep their types.
function TracedStepRangeLen{T}(r::TracedStepRangeLen) where {T}
    return TracedStepRangeLen(convert(T, r.ref), convert(T, r.step), r.len, r.offset)
end
# Promote a traced range with any other AbstractRange via its eltype.
function Base.promote_rule(
    a::Type{TracedStepRangeLen{T,R,S,L}}, ::Type{OR}
) where {T,R,S,L,OR<:AbstractRange}
    return promote_rule(a, TracedStepRangeLen{eltype(OR),eltype(OR),eltype(OR),Int})
end
# Construct a traced range from any AbstractRange (first/step/length).
function TracedStepRangeLen{T,R,S,L}(r::AbstractRange) where {T,R,S,L}
    return TracedStepRangeLen{T,R,S,L}(R(first(r)), S(step(r)), length(r))
end
function TracedStepRangeLen{T}(r::AbstractRange) where {T}
    return TracedStepRangeLen(T(first(r)), T(step(r)), length(r))
end
TracedStepRangeLen(r::AbstractRange) = TracedStepRangeLen{eltype(r)}(r)

# Promote `LinRange` with a traced range (as Base does for StepRangeLen).
function Base.promote_rule(
    ::Type{LinRange{A,L}}, b::Type{TracedStepRangeLen{T2,R2,S2,L2}}
) where {A,L,T2,R2,S2,L2}
    return promote_rule(TracedStepRangeLen{A,A,A,L}, b)
end

# Reverse a traced range: negate the step and mirror the offset
# (adapted from `Base._reverse(::StepRangeLen, ::Colon)` in range.jl).
function Base._reverse(r::TracedStepRangeLen, ::Colon)
    # If `r` is empty, `length(r) - r.offset + 1` will be nonpositive hence
    # invalid. As `reverse(r)` is also empty, any offset would work so we keep
    # `r.offset`.
    offset = isempty(r) ? r.offset : length(r) - r.offset + 1
    # Fix: `negate` is a private helper defined inside `Base` (range.jl) and is
    # not in scope here unqualified; qualify it so this method actually runs.
    return typeof(r)(r.ref, Base.negate(r.step), length(r), offset)
end

# TODO: +, - for TracedStepRangeLen (see Base._define_range_op)

# Collapse a double-word value to a single traced number: a TwicePrecision
# represents `hi + lo`, so the two components are summed.
function (::Type{T})(x::TwicePrecision) where {T<:Reactant.TracedRNumber}
    return (T(x.hi) + T(x.lo))::T
end

# Collapse a double-word value to a concrete traced scalar.
# Bug fix: a `TwicePrecision` represents `hi + lo`, so the components must be
# added — the previous version subtracted, yielding `hi - lo` (the
# `TracedRNumber` method above correctly adds).
function (::Type{T})(x::TwicePrecision) where {T<:Reactant.ConcreteRNumber}
    return Reactant.ConcreteRNumber(T(x.hi) + T(x.lo))::T
end

# Number of bits needed to represent indices of this range (see twiceprecision.jl).
Base.nbitslen(r::TracedStepRangeLen) = Base.nbitslen(eltype(r), length(r), r.offset)
# High-precision range: reference and step stored as TwicePrecision{T},
# element type T (mirrors the corresponding StepRangeLen constructor).
function TracedStepRangeLen(
    ref::TwicePrecision{T}, step::TwicePrecision{T}, len, offset=1
) where {T}
    return TracedStepRangeLen{T,TwicePrecision{T},TwicePrecision{T}}(ref, step, len, offset)
end
# `step` of a high-precision range rounds the stored TwicePrecision step to T.
function Base.step(r::TracedStepRangeLen{T,TwicePrecision{T},TwicePrecision{T}}) where {T}
    return T(r.step)
end

# This assumes that r.step has already been split so that (0:len-1)*r.step.hi is exact
function Base.unsafe_getindex(
r::Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision},
r::Union{
Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision},
TracedStepRangeLen{
T,<:Base.TwicePrecision,<:Base.TwicePrecision,<:Base.TwicePrecision
},
},
i::TracedRNumber{<:Integer},
) where {T}
# Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12
@@ -449,7 +698,9 @@ function Base.unsafe_getindex(
end

function Base.searchsortedfirst(
a::AbstractRange{<:Real}, x::TracedRNumber{<:Real}, o::Base.DirectOrdering
a::AbstractRange{<:Union{Real,TracedRNumber}},
x::TracedRNumber{<:Real},
o::Base.DirectOrdering,
)::TracedRNumber{keytype(a)}

# require_one_based_indexing(a)
@@ -460,7 +711,7 @@ function Base.searchsortedfirst(
!Base.Order.lt(o, f, x),
1,
ifelse(
h == 0 || Base.Order.lt(o, l, x),
(h == 0) | Base.Order.lt(o, l, x),
length(a) + 1,
ifelse(Base.Order.lt(o, a[n], x), n + 1, n),
),
84 changes: 84 additions & 0 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
@@ -1837,3 +1837,87 @@ end
end
return @invoke to_rarray_internal(x::Any, track_numbers::Type, sharding, runtime)
end

# Compute the traced type for a `StepRangeLen{T,R,S,L}`: each type parameter is
# traced independently and the result is the corresponding
# `TracedStepRangeLen`. Number tracking is forced on so the range's scalar
# fields become traced numbers.
function Reactant.traced_type_inner(
    @nospecialize(RT::Type{<:StepRangeLen}),
    seen,
    mode::Reactant.TraceMode,
    track_numbers::Type,
    sharding,
    runtime,
)
    # Widen `track_numbers` to cover all `Number`s unless it already does.
    tracked = Number <: track_numbers ? track_numbers : Number
    trace(P) = Reactant.traced_type_inner(P, seen, mode, tracked, sharding, runtime)
    T, R, S, L = RT.parameters
    return TracedRNumberOverrides.TracedStepRangeLen{trace(T),trace(R),trace(S),trace(L)}
end

# Trace a `StepRangeLen` into a `TracedStepRangeLen` by tracing each of its
# four fields under its own sub-path. Sharding is not supported for ranges.
function Reactant.make_tracer(
    seen,
    @nospecialize(prev::StepRangeLen),
    @nospecialize(path),
    mode;
    @nospecialize(sharding = Sharding.NoSharding()),
    kwargs...,
)
    Reactant.Sharding.is_sharded(sharding) &&
        error("Cannot specify sharding for StepRangeLen")
    if mode == Reactant.TracedToTypes
        # Type-encoding mode: record the concrete type, then each field in order.
        push!(path, Core.Typeof(prev))
        for field in (prev.ref, prev.step, prev.len, prev.offset)
            make_tracer(seen, field, path, mode; kwargs...)
        end
        return nothing
    end
    # Numbers are always tracked so the resulting range holds traced scalars.
    trace_field(val, name) = Reactant.make_tracer(
        seen,
        val,
        Reactant.append_path(path, name),
        mode;
        kwargs...,
        track_numbers=Number,
    )
    return TracedRNumberOverrides.TracedStepRangeLen(
        trace_field(prev.ref, :ref),
        trace_field(prev.step, :step),
        trace_field(prev.len, :len),
        trace_field(prev.offset, :offset),
    )
end
8 changes: 8 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
@@ -999,6 +999,14 @@ end
@test res[3] == 216
end

# Regression test: fractional indexing into a traced range with a traced query
# value. NOTE(review): `fractional_idx` is defined elsewhere in the test suite;
# from the expected values it appears to return (fractional part, lower index,
# upper index) for locating 2.143 within 0:0.01:4.5 — confirm against its
# definition.
@testset "Traced fractional index" begin
    times = Reactant.to_rarray(0:0.01:4.5)
    res = @jit fractional_idx(times, ConcreteRNumber(2.143))
    @test res[1] == 0.29999999999997334
    @test res[2] == 215
    @test res[3] == 216
end

mulpi(x) = π * x

@testset "Irrational promotion" begin