TracedStepRangeLen #960

Merged 18 commits on Mar 20, 2025
11 changes: 11 additions & 0 deletions ext/ReactantCUDAExt.jl
@@ -419,6 +419,17 @@ function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRNumber{T}) where {T}
return res
end

import Reactant.TracedRNumberOverrides.TracedStepRangeLen

function Adapt.adapt_storage(::ReactantKernelAdaptor, r::TracedStepRangeLen)
return TracedStepRangeLen(
Adapt.adapt(ReactantKernelAdaptor(), r.ref),
Adapt.adapt(ReactantKernelAdaptor(), r.step),
Adapt.adapt(ReactantKernelAdaptor(), r.len),
Adapt.adapt(ReactantKernelAdaptor(), r.offset),
)
end
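
The adaptor above simply rebuilds the range by adapting each field. The same recursive pattern can be sketched with plain Adapt.jl and toy stand-ins (ToyAdaptor and ToyRange are illustrative names, not Reactant types):

using Adapt

struct ToyAdaptor end
struct ToyRange{R,S,L}
    ref::R
    step::S
    len::L
    offset::L
end

# Leaf rule: convert Float64 payloads (a stand-in for rewriting host scalars into
# kernel-compatible traced values); all other fields pass through unchanged.
Adapt.adapt_storage(::ToyAdaptor, x::Float64) = Float32(x)

# Container rule: adapt every field and reassemble, mirroring adapt_storage above.
function Adapt.adapt_storage(to::ToyAdaptor, r::ToyRange)
    return ToyRange(
        Adapt.adapt(to, r.ref),
        Adapt.adapt(to, r.step),
        Adapt.adapt(to, r.len),
        Adapt.adapt(to, r.offset),
    )
end

Adapt.adapt(ToyAdaptor(), ToyRange(1.0, 0.5, 5, 1))  # ToyRange{Float32,Float32,Int}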

# Since we cache these objects we cannot cache data containing MLIR operations (e.g. the entry must be a string
# and not the operation itself).
struct LLVMFunc{F,tt}
4 changes: 2 additions & 2 deletions src/Compiler.jl
@@ -33,7 +33,8 @@ end
end

@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
- (isbitstype(T) || ancestor(obj) isa RArray) && return Base.getfield(obj, field)
+ (isbitstype(T) || ancestor(obj) isa RArray || obj isa AbstractRange) &&
+ return Base.getfield(obj, field)
return Base.getindex(obj, field)
end
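
The added `obj isa AbstractRange` clause lets reconstruction read a range's internals as struct fields. A plain-Julia illustration (no Reactant types) of why `getfield`, not `getindex`, is the right accessor for ranges:

# A StepRangeLen keeps its state in the fields ref, step, len, offset.
r = StepRangeLen(1.0, 0.5, 5)   # the range 1.0:0.5:3.0
getfield(r, :step)              # 0.5 — what traced_getfield needs to rebuild the object
# getindex(r, :step) would throw: ranges are indexed by integers, not field names.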

@@ -1472,7 +1473,6 @@ function codegen_flatten!(
is_sharded &&
runtime isa Val{:PJRT} &&
(flatten_names = vcat(eachrow(reshape(flatten_names, length(mesh), :))...))

return flatten_names, flatten_code
end

255 changes: 252 additions & 3 deletions src/TracedRNumber.jl
@@ -4,6 +4,8 @@ using ..Reactant:
Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype
using ReactantCore

import Base.TwicePrecision

ReactantCore.is_traced(::TracedRNumber, seen) = true
ReactantCore.is_traced(::TracedRNumber) = true

@@ -262,6 +264,42 @@ function Base.ifelse(
end
end

function Base.:*(
x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T}
) where {T<:TracedRNumber}
zh, zl = Base.mul12(x.hi, y.hi)
hi, lo = Base.canonicalize2(zh, (x.hi * y.lo + x.lo * y.hi) + zl)
hi = ifelse(iszero(zh) | !isfinite(zh), zh, hi)
lo = ifelse(iszero(zl) | !isfinite(zl), zl, lo)

return Base.TwicePrecision{T}(hi, lo)
end

function Base.:+(
x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T}
) where {T<:TracedRNumber}
r = x.hi + y.hi
@trace s = if abs(x.hi) > abs(y.hi)
begin
(((x.hi - r) + y.hi) + y.lo) + x.lo
end
else
begin
(((y.hi - r) + x.hi) + x.lo) + y.lo
end
end
return Base.TwicePrecision(Base.canonicalize2(r, s)...)
end

function Base.:*(x::TwicePrecision, v::TracedRNumber)
@trace result = if v == 0
TwicePrecision(x.hi * v, x.lo * v)
else
x * TwicePrecision(oftype(x.hi * v, v))
end
return result
end
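
These overloads mirror Base's double-double (`TwicePrecision`) arithmetic but replace data-dependent branches with `ifelse`/`@trace`, since a traced boolean cannot drive native Julia control flow. A plain-Float64 sketch of the hi/lo representation the code manipulates (Base internals `TwicePrecision` and `add12`, no traced types involved):

a = Base.TwicePrecision(0.1, 0.0)   # value stored as hi + lo
b = Base.TwicePrecision(0.2, 0.0)
c = a + b                           # Base's double-double addition
c.hi, c.lo                          # hi is the rounded sum; lo is the tiny residual the rounding dropped
hi, lo = Base.add12(0.1, 0.2)       # error-free split used by the methods above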

for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
T = promote_type(T1, T2)
@eval begin
@@ -271,18 +309,54 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.:&(x::TracedRNumber{<:$(T1)}, y::$(T2))
return Ops.and(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.:&(x::$(T1), y::TracedRNumber{<:$(T2)})
return Ops.and(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
return Ops.or(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.:|(x::TracedRNumber{<:$(T1)}, y::$(T2))
return Ops.or(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.:|(x::$(T1), y::TracedRNumber{<:$(T2)})
return Ops.or(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.xor(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
return Ops.xor(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.xor(x::TracedRNumber{<:$(T1)}, y::$(T2))
return Ops.xor(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
function Base.xor(x::$(T1), y::TracedRNumber{<:$(T2)})
return Ops.xor(
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
)
end
Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x)
end
end
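
Each generated method promotes both operands to a common `TracedRNumber` and emits a single `Ops.and`/`Ops.or`/`Ops.xor`, so traced and plain integers mix freely. A hedged usage sketch, assuming the usual `Reactant.@compile` workflow and that `ConcreteRNumber` wraps a host scalar:

using Reactant

masked(x, m) = (x & m) | 0x0f   # traced & traced, then traced | plain literal

x = Reactant.ConcreteRNumber(0xaa)
m = Reactant.ConcreteRNumber(0x3c)
f = Reactant.@compile masked(x, m)
f(x, m)                         # 0x2f, returned as a concrete Reactant number
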
@@ -424,9 +498,182 @@ function Base.getindex(
return Base.unsafe_getindex(r, i)
end

struct TracedStepRangeLen{T,R,S,L} <: AbstractRange{T}
ref::R
step::S
len::L
offset::L
end

# constructors and interface implementation copied from range.jl
function TracedStepRangeLen{T,R,S}(ref::R, step::S, len, offset=1) where {T,R,S}
return TracedStepRangeLen{T,R,S,typeof(len)}(ref, step, len, offset)
end
function TracedStepRangeLen(ref::R, step::S, len, offset=1) where {R,S}
return TracedStepRangeLen{typeof(ref + zero(step)),R,S,typeof(len)}(
ref, step, len, offset
)
end
function TracedStepRangeLen{T}(
ref::R, step::S, len::Integer, offset::Integer=1
) where {T,R,S}
return TracedStepRangeLen{T,R,S,typeof(len)}(ref, step, len, offset)
end
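
As with `Base.StepRangeLen`, the element at index `i` is `ref + (i - offset) * step`; only the field types differ, so that `ref`, `step`, or `len` may themselves be traced. The same field semantics in plain Julia:

r = StepRangeLen(2.0, 0.5, 7, 3)   # ref = 2.0, step = 0.5, len = 7, offset = 3
r[3]                               # 2.0 — the ref value sits at the offset
r[1]                               # 2.0 + (1 - 3) * 0.5 == 1.0
collect(r)                         # [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]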

Base.isempty(r::TracedStepRangeLen) = length(r) == 0
Base.step(r::TracedStepRangeLen) = r.step
Base.step_hp(r::TracedStepRangeLen) = r.step
Base.length(r::TracedStepRangeLen) = r.len
Base.first(r::TracedStepRangeLen) = Base.unsafe_getindex(r, 1)
Base.last(r::TracedStepRangeLen) = Base.unsafe_getindex(r, r.len)
function Base.iterate(r::TracedStepRangeLen, i::Integer=zero(length(r)))
@inline
i += oneunit(i)
length(r) < i && return nothing
return Base.unsafe_getindex(r, i), i
end

function _tracedsteprangelen_unsafe_getindex(
r::AbstractRange{T}, i::Union{I,TracedRNumber{I}}
) where {T,I}
u = if i isa TracedRNumber
finalT = T <: TracedRNumber ? T : TracedRNumber{T}
convert(TracedRNumber{typeof(r.offset)}, i) - r.offset
else
finalT = T
oftype(r.offset, i) - r.offset
end
return finalT(r.ref + u * r.step)
end
function Base.unsafe_getindex(r::TracedStepRangeLen, i::Integer)
return _tracedsteprangelen_unsafe_getindex(r, i)
end
function Base.unsafe_getindex(r::TracedStepRangeLen, i::TracedRNumber{<:Integer})
return _tracedsteprangelen_unsafe_getindex(r, i)
end
Base.getindex(r::TracedStepRangeLen, i::TracedRNumber) = Base.unsafe_getindex(r, i)
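
The index decides the result type: a plain integer takes the untraced path and returns an ordinary `T`, while a `TracedRNumber` index produces a `TracedRNumber{T}` (or keeps `T` when it is already traced). A small sketch; the traced case is shown as a comment because it only arises inside a traced function:

r = TracedStepRangeLen(0.0, 0.25, 9)   # 0.0, 0.25, …, 2.0
Base.unsafe_getindex(r, 4)             # 0.75::Float64 — plain index, untraced result
# With i::TracedRNumber{Int}:  r[i]    # ::TracedRNumber{Float64}, via the methods above
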
function Base.getindex(r::TracedStepRangeLen{T}, s::OrdinalRange{S}) where {T,S<:Integer}
@inline
@boundscheck checkbounds(r, s)

len = length(s)
sstep = Base.step_hp(s)
rstep = Base.step_hp(r)
L = typeof(len)
if S === Bool
rstep *= one(sstep)
if len == 0
return TracedStepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
elseif len == 1
if first(s)
return TracedStepRangeLen{T}(first(r), rstep, oneunit(L), oneunit(L))
else
return TracedStepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
end
else # len == 2
return TracedStepRangeLen{T}(last(r), rstep, oneunit(L), oneunit(L))
end
else
# Find closest approach to offset by s
ind = LinearIndices(s)
offset = L(
max(min(1 + round(L, (r.offset - first(s)) / sstep), last(ind)), first(ind))
)
ref = Base._getindex_hiprec(r, first(s) + (offset - oneunit(offset)) * sstep)
return TracedStepRangeLen{T}(ref, rstep * sstep, len, offset)
end
end
function Base._getindex_hiprec(r::TracedStepRangeLen, i::Integer) # without rounding by T
u = oftype(r.offset, i) - r.offset
return r.ref + u * r.step
end
function Base.:(==)(r::T, s::T) where {T<:TracedStepRangeLen}
return (isempty(r) & isempty(s)) |
((first(r) == first(s)) & (length(r) == length(s)) & (last(r) == last(s)))
end

# TODO: if there ever comes a ReactantStepRange:
# ==(r::Union{StepRange{T},StepRangeLen{T,T}}, s::Union{StepRange{T},StepRangeLen{T,T}}) where {T}

function Base.:-(r::TracedStepRangeLen{T,R,S,L}) where {T,R,S,L}
return TracedStepRangeLen{T,R,S,L}(-r.ref, -r.step, r.len, r.offset)
end

# TODO: promotion from StepRangeLen{T} to TracedStepRangeLen{T}?
function Base.promote_rule(
::Type{TracedStepRangeLen{T1,R1,S1,L1}}, ::Type{TracedStepRangeLen{T2,R2,S2,L2}}
) where {T1,T2,R1,R2,S1,S2,L1,L2}
R, S, L = promote_type(R1, R2), promote_type(S1, S2), promote_type(L1, L2)
return Base.el_same(
promote_type(T1, T2), TracedStepRangeLen{T1,R,S,L}, TracedStepRangeLen{T2,R,S,L}
)
end
TracedStepRangeLen{T,R,S,L}(r::TracedStepRangeLen{T,R,S,L}) where {T,R,S,L} = r
function TracedStepRangeLen{T,R,S,L}(r::TracedStepRangeLen) where {T,R,S,L}
return TracedStepRangeLen{T,R,S,L}(
convert(R, r.ref), convert(S, r.step), convert(L, r.len), convert(L, r.offset)
)
end
function TracedStepRangeLen{T}(r::TracedStepRangeLen) where {T}
return TracedStepRangeLen(convert(T, r.ref), convert(T, r.step), r.len, r.offset)
end
function Base.promote_rule(
a::Type{TracedStepRangeLen{T,R,S,L}}, ::Type{OR}
) where {T,R,S,L,OR<:AbstractRange}
return promote_rule(a, TracedStepRangeLen{eltype(OR),eltype(OR),eltype(OR),Int})
end
function TracedStepRangeLen{T,R,S,L}(r::AbstractRange) where {T,R,S,L}
return TracedStepRangeLen{T,R,S,L}(R(first(r)), S(step(r)), length(r))
end
function TracedStepRangeLen{T}(r::AbstractRange) where {T}
return TracedStepRangeLen(T(first(r)), T(step(r)), length(r))
end
TracedStepRangeLen(r::AbstractRange) = TracedStepRangeLen{eltype(r)}(r)

function Base.promote_rule(
::Type{LinRange{A,L}}, b::Type{TracedStepRangeLen{T2,R2,S2,L2}}
) where {A,L,T2,R2,S2,L2}
return promote_rule(TracedStepRangeLen{A,A,A,L}, b)
end
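
The conversion constructors and promotion rules let an ordinary range be lifted into a `TracedStepRangeLen` without touching its values, and let mixed range types agree on a common traced type. A quick sketch:

tr = TracedStepRangeLen(1.0:0.5:3.0)            # lift an ordinary Float64 range
(length(tr), step(tr), first(tr), last(tr))     # (5, 0.5, 1.0, 3.0)
promote_type(typeof(tr), typeof(TracedStepRangeLen(0:4)))
# TracedStepRangeLen{Float64,Float64,Float64,Int}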

function Base._reverse(r::TracedStepRangeLen, ::Colon)
# If `r` is empty, `length(r) - r.offset + 1` will be nonpositive hence
# invalid. As `reverse(r)` is also empty, any offset would work so we keep
# `r.offset`
offset = isempty(r) ? r.offset : length(r) - r.offset + 1
return typeof(r)(r.ref, -r.step, length(r), offset)
end

# TODO: +, - for TracedStepRangeLen (see Base._define_range_op)

function (::Type{T})(x::TwicePrecision) where {T<:Reactant.TracedRNumber}
return (T(x.hi) + T(x.lo))::T
end

function (::Type{T})(x::TwicePrecision) where {T<:Reactant.ConcreteRNumber}
return Reactant.ConcreteRNumber(T(x.hi) + T(x.lo))::T
end

Base.nbitslen(r::TracedStepRangeLen) = Base.nbitslen(eltype(r), length(r), r.offset)
function TracedStepRangeLen(
ref::TwicePrecision{T}, step::TwicePrecision{T}, len, offset=1
) where {T}
return TracedStepRangeLen{T,TwicePrecision{T},TwicePrecision{T}}(ref, step, len, offset)
end
function Base.step(r::TracedStepRangeLen{T,TwicePrecision{T},TwicePrecision{T}}) where {T}
return T(r.step)
end
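
The TwicePrecision-based constructor and the specialized `step` method make a high-precision traced range behave like Base's `Float64` ranges, while the conversions above collapse a `TwicePrecision` back to a single traced value as `hi + lo`. A plain-Julia sketch of the high-precision path:

ref = Base.TwicePrecision(1.0, 0.0)
stp = Base.TwicePrecision(0.1, 0.0)
hr  = TracedStepRangeLen(ref, stp, 11)   # ref/step stored in double-double form
step(hr)                                 # 0.1::Float64, via the specialized Base.step above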

# This assumes that r.step has already been split so that (0:len-1)*r.step.hi is exact
function Base.unsafe_getindex(
- r::Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision},
+ r::Union{
+ Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision},
+ TracedStepRangeLen{
+ T,<:Base.TwicePrecision,<:Base.TwicePrecision,<:Base.TwicePrecision
+ },
+ },
i::TracedRNumber{<:Integer},
) where {T}
# Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12
@@ -449,7 +696,9 @@ end
end

function Base.searchsortedfirst(
- a::AbstractRange{<:Real}, x::TracedRNumber{<:Real}, o::Base.DirectOrdering
+ a::AbstractRange{<:Union{Real,TracedRNumber}},
+ x::TracedRNumber{<:Real},
+ o::Base.DirectOrdering,
)::TracedRNumber{keytype(a)}

# require_one_based_indexing(a)
@@ -460,7 +709,7 @@
!Base.Order.lt(o, f, x),
1,
ifelse(
- h == 0 || Base.Order.lt(o, l, x),
+ (h == 0) | Base.Order.lt(o, l, x),
length(a) + 1,
ifelse(Base.Order.lt(o, a[n], x), n + 1, n),
),