Skip to content

Commit c8d1b0a

Browse files
authored
Revert "TracedStepRangeLen (#960)"
This reverts commit 8ea467f.
1 parent 0815fac commit c8d1b0a

File tree

5 files changed

+5
-364
lines changed

5 files changed

+5
-364
lines changed

ext/ReactantCUDAExt.jl

-11
Original file line numberDiff line numberDiff line change
@@ -419,17 +419,6 @@ function Adapt.adapt_storage(::ReactantKernelAdaptor, xs::TracedRNumber{T}) wher
419419
return res
420420
end
421421

422-
import Reactant.TracedRNumberOverrides.TracedStepRangeLen
423-
424-
function Adapt.adapt_storage(::ReactantKernelAdaptor, r::TracedStepRangeLen)
425-
return TracedStepRangeLen(
426-
Adapt.adapt(ReactantKernelAdaptor(), r.ref),
427-
Adapt.adapt(ReactantKernelAdaptor(), r.step),
428-
Adapt.adapt(ReactantKernelAdaptor(), r.len),
429-
Adapt.adapt(ReactantKernelAdaptor(), r.offset),
430-
)
431-
end
432-
433422
# Since we cache these objects we cannot cache data containing MLIR operations (e.g. the entry must be a string
434423
# and not the operation itself).
435424
struct LLVMFunc{F,tt}

src/Compiler.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ end
3333
end
3434

3535
@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T}
36-
(isbitstype(T) || ancestor(obj) isa RArray || obj isa AbstractRange) &&
37-
return Base.getfield(obj, field)
36+
(isbitstype(T) || ancestor(obj) isa RArray) && return Base.getfield(obj, field)
3837
return Base.getindex(obj, field)
3938
end
4039

@@ -1473,6 +1472,7 @@ function codegen_flatten!(
14731472
is_sharded &&
14741473
runtime isa Val{:PJRT} &&
14751474
(flatten_names = vcat(eachrow(reshape(flatten_names, length(mesh), :))...))
1475+
14761476
return flatten_names, flatten_code
14771477
end
14781478

src/TracedRNumber.jl

+3-259
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@ module TracedRNumberOverrides
33
using ..Reactant:
44
Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype
55
using ReactantCore
6-
using Adapt
7-
8-
import Base.TwicePrecision
96

107
ReactantCore.is_traced(::TracedRNumber, seen) = true
118
ReactantCore.is_traced(::TracedRNumber) = true
@@ -265,42 +262,6 @@ function Base.ifelse(
265262
end
266263
end
267264

268-
function Base.:*(
269-
x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T}
270-
) where {T<:TracedRNumber}
271-
zh, zl = Base.mul12(x.hi, y.hi)
272-
hi, lo = Base.canonicalize2(zh, (x.hi * y.lo + x.lo * y.hi) + zl)
273-
hi = ifelse(iszero(zh) | !isfinite(zh), zh, hi)
274-
lo = ifelse(iszero(zl) | !isfinite(zl), zl, lo)
275-
276-
return Base.TwicePrecision{T}(hi, lo)
277-
end
278-
279-
function Base.:+(
280-
x::Base.TwicePrecision{T}, y::Base.TwicePrecision{T}
281-
) where {T<:TracedRNumber}
282-
r = x.hi + y.hi
283-
@trace s = if abs(x.hi) > abs(y.hi)
284-
begin
285-
(((x.hi - r) + y.hi) + y.lo) + x.lo
286-
end
287-
else
288-
begin
289-
(((y.hi - r) + x.hi) + x.lo) + y.lo
290-
end
291-
end
292-
return Base.TwicePrecision(Base.canonicalize2(r, s)...)
293-
end
294-
295-
function Base.:*(x::TwicePrecision, v::TracedRNumber)
296-
@trace result = if v == 0
297-
TwicePrecision(x.hi * v, x.lo * v)
298-
else
299-
x * TwicePrecision(oftype(x.hi * v, v))
300-
end
301-
return result
302-
end
303-
304265
for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
305266
T = promote_type(T1, T2)
306267
@eval begin
@@ -310,54 +271,18 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
310271
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
311272
)
312273
end
313-
function Base.:&(x::TracedRNumber{<:$(T1)}, y::$(T2))
314-
return Ops.and(
315-
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
316-
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
317-
)
318-
end
319-
function Base.:&(x::$(T1), y::TracedRNumber{<:$(T2)})
320-
return Ops.and(
321-
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
322-
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
323-
)
324-
end
325274
function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
326275
return Ops.or(
327276
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
328277
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
329278
)
330279
end
331-
function Base.:|(x::TracedRNumber{<:$(T1)}, y::$(T2))
332-
return Ops.or(
333-
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
334-
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
335-
)
336-
end
337-
function Base.:|(x::$(T1), y::TracedRNumber{<:$(T2)})
338-
return Ops.or(
339-
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
340-
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
341-
)
342-
end
343280
function Base.xor(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)})
344281
return Ops.xor(
345282
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
346283
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
347284
)
348285
end
349-
function Base.xor(x::TracedRNumber{<:$(T1)}, y::$(T2))
350-
return Ops.xor(
351-
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
352-
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
353-
)
354-
end
355-
function Base.xor(x::$(T1), y::TracedRNumber{<:$(T2)})
356-
return Ops.xor(
357-
TracedUtils.promote_to(TracedRNumber{$(T)}, x),
358-
TracedUtils.promote_to(TracedRNumber{$(T)}, y),
359-
)
360-
end
361286
Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x)
362287
end
363288
end
@@ -499,188 +424,9 @@ function Base.getindex(
499424
return Base.unsafe_getindex(r, i)
500425
end
501426

502-
struct TracedStepRangeLen{T,R,S,L} <: AbstractRange{T}
503-
ref::R
504-
step::S
505-
len::L
506-
offset::L
507-
end
508-
509-
function Adapt.parent_type(::Type{TracedStepRangeLen{T,R,S,L}}) where {T,R,S,L}
510-
return TracedStepRangeLen{T,R,S,L}
511-
end
512-
513-
# constructors and interface implementation copied from range.jl
514-
function TracedStepRangeLen{T,R,S}(ref::R, step::S, len, offset=1) where {T,R,S}
515-
return TracedStepRangeLen{T,R,S,typeof(len)}(ref, step, len, offset)
516-
end
517-
function TracedStepRangeLen(ref::R, step::S, len, offset=1) where {R,S}
518-
return TracedStepRangeLen{typeof(ref + zero(step)),R,S,typeof(len)}(
519-
ref, step, len, offset
520-
)
521-
end
522-
function TracedStepRangeLen{T}(
523-
ref::R, step::S, len::Integer, offset::Integer=1
524-
) where {T,R,S}
525-
return TracedStepRangeLen{T,R,S,typeof(len)}(ref, step, len, offset)
526-
end
527-
528-
Base.isempty(r::TracedStepRangeLen) = length(r) == 0
529-
Base.step(r::TracedStepRangeLen) = r.step
530-
Base.step_hp(r::TracedStepRangeLen) = r.step
531-
Base.length(r::TracedStepRangeLen) = r.len
532-
Base.first(r::TracedStepRangeLen) = Base.unsafe_getindex(r, 1)
533-
Base.last(r::TracedStepRangeLen) = Base.unsafe_getindex(r, r.len)
534-
function Base.iterate(r::TracedStepRangeLen, i::Integer=1)
535-
@inline
536-
i += oneunit(i)
537-
length(r) < i && return nothing
538-
return Base.unsafe_getindex(r, i), i
539-
end
540-
541-
function _tracedsteprangelen_unsafe_getindex(
542-
r::AbstractRange{T}, i::Union{I,TracedRNumber{I}}
543-
) where {T,I}
544-
finalT = T
545-
offsetT = typeof(r.offset)
546-
if i isa TracedRNumber
547-
if !(T <: TracedRNumber)
548-
finalT = TracedRNumber{T}
549-
end
550-
if !(r.offset isa TracedRNumber)
551-
offsetT = TracedRNumber{offsetT}
552-
end
553-
end
554-
u = convert(offsetT, i) - r.offset
555-
return finalT(r.ref + u * r.step)
556-
end
557-
function Base.unsafe_getindex(r::TracedStepRangeLen, i::Integer)
558-
return _tracedsteprangelen_unsafe_getindex(r, i)
559-
end
560-
function Base.unsafe_getindex(r::TracedStepRangeLen, i::TracedRNumber{<:Integer})
561-
return _tracedsteprangelen_unsafe_getindex(r, i)
562-
end
563-
Base.getindex(r::TracedStepRangeLen, i::TracedRNumber) = Base.unsafe_getindex(r, i)
564-
function getindex(r::TracedStepRangeLen{T}, s::OrdinalRange{S}) where {T,S<:Integer}
565-
@inline
566-
@boundscheck checkbounds(r, s)
567-
568-
len = length(s)
569-
sstep = Base.step_hp(s)
570-
rstep = Base.step_hp(r)
571-
L = typeof(len)
572-
if S === Bool
573-
rstep *= one(sstep)
574-
if len == 0
575-
return TracedStepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
576-
elseif len == 1
577-
if first(s)
578-
return TracedStepRangeLen{T}(first(r), rstep, oneunit(L), oneunit(L))
579-
else
580-
return TracedStepRangeLen{T}(first(r), rstep, zero(L), oneunit(L))
581-
end
582-
else # len == 2
583-
return TracedStepRangeLen{T}(last(r), rstep, oneunit(L), oneunit(L))
584-
end
585-
else
586-
# Find closest approach to offset by s
587-
ind = LinearIndices(s)
588-
offset = L(
589-
max(min(1 + round(L, (r.offset - first(s)) / sstep), last(ind)), first(ind))
590-
)
591-
ref = Base._getindex_hiprec(r, first(s) + (offset - oneunit(offset)) * sstep)
592-
return TracedStepRangeLen{T}(ref, rstep * sstep, len, offset)
593-
end
594-
end
595-
function Base._getindex_hiprec(r::TracedStepRangeLen, i::Integer) # without rounding by T
596-
u = oftype(r.offset, i) - r.offset
597-
return r.ref + u * r.step
598-
end
599-
function Base.:(==)(r::T, s::T) where {T<:TracedStepRangeLen}
600-
return (isempty(r) & isempty(s)) |
601-
((first(r) == first(s)) & (length(r) == length(s)) & (last(r) == last(s)))
602-
end
603-
604-
# TODO: if there ever comes a ReactantStepRange:
605-
# ==(r::Union{StepRange{T},StepRangeLen{T,T}}, s::Union{StepRange{T},StepRangeLen{T,T}}) where {T}
606-
607-
function Base.:-(r::TracedStepRangeLen{T,R,S,L}) where {T,R,S,L}
608-
return TracedStepRangeLen{T,R,S,L}(-r.ref, -r.step, r.len, r.offset)
609-
end
610-
611-
# TODO: promotion from StepRangeLen{T} to TracedStepRangeLen{T}?
612-
function Base.promote_rule(
613-
::Type{TracedStepRangeLen{T1,R1,S1,L1}}, ::Type{TracedStepRangeLen{T2,R2,S2,L2}}
614-
) where {T1,T2,R1,R2,S1,S2,L1,L2}
615-
R, S, L = promote_type(R1, R2), promote_type(S1, S2), promote_type(L1, L2)
616-
return Base.el_same(
617-
promote_type(T1, T2), TracedStepRangeLen{T1,R,S,L}, TracedStepRangeLen{T2,R,S,L}
618-
)
619-
end
620-
TracedStepRangeLen{T,R,S,L}(r::TracedStepRangeLen{T,R,S,L}) where {T,R,S,L} = r
621-
function TracedStepRangeLen{T,R,S,L}(r::TracedStepRangeLen) where {T,R,S,L}
622-
return TracedStepRangeLen{T,R,S,L}(
623-
convert(R, r.ref), convert(S, r.step), convert(L, r.len), convert(L, r.offset)
624-
)
625-
end
626-
function TracedStepRangeLen{T}(r::TracedStepRangeLen) where {T}
627-
return TracedStepRangeLen(convert(T, r.ref), convert(T, r.step), r.len, r.offset)
628-
end
629-
function Base.promote_rule(
630-
a::Type{TracedStepRangeLen{T,R,S,L}}, ::Type{OR}
631-
) where {T,R,S,L,OR<:AbstractRange}
632-
return promote_rule(a, TracedStepRangeLen{eltype(OR),eltype(OR),eltype(OR),Int})
633-
end
634-
function TracedStepRangeLen{T,R,S,L}(r::AbstractRange) where {T,R,S,L}
635-
return TracedStepRangeLen{T,R,S,L}(R(first(r)), S(step(r)), length(r))
636-
end
637-
function TracedStepRangeLen{T}(r::AbstractRange) where {T}
638-
return TracedStepRangeLen(T(first(r)), T(step(r)), length(r))
639-
end
640-
TracedStepRangeLen(r::AbstractRange) = TracedStepRangeLen{eltype(r)}(r)
641-
642-
function Base.promote_rule(
643-
::Type{LinRange{A,L}}, b::Type{TracedStepRangeLen{T2,R2,S2,L2}}
644-
) where {A,L,T2,R2,S2,L2}
645-
return promote_rule(TracedStepRangeLen{A,A,A,L}, b)
646-
end
647-
648-
function Base._reverse(r::TracedStepRangeLen, ::Colon)
649-
# If `r` is empty, `length(r) - r.offset + 1 will be nonpositive hence
650-
# invalid. As `reverse(r)` is also empty, any offset would work so we keep
651-
# `r.offset`
652-
offset = isempty(r) ? r.offset : length(r) - r.offset + 1
653-
return typeof(r)(r.ref, negate(r.step), length(r), offset)
654-
end
655-
656-
# TODO: +, - for TracedStepRangeLen (see Base._define_range_op)
657-
658-
function (::Type{T})(x::TwicePrecision) where {T<:Reactant.TracedRNumber}
659-
return (T(x.hi) + T(x.lo))::T
660-
end
661-
662-
function (::Type{T})(x::TwicePrecision) where {T<:Reactant.ConcreteRNumber}
663-
return Reactant.ConcreteRNumber(T(x.hi) - T(x.lo))::T
664-
end
665-
666-
Base.nbitslen(r::TracedStepRangeLen) = Base.nbitslen(eltype(r), length(r), r.offset)
667-
function TracedStepRangeLen(
668-
ref::TwicePrecision{T}, step::TwicePrecision{T}, len, offset=1
669-
) where {T}
670-
return TracedStepRangeLen{T,TwicePrecision{T},TwicePrecision{T}}(ref, step, len, offset)
671-
end
672-
function Base.step(r::TracedStepRangeLen{T,TwicePrecision{T},TwicePrecision{T}}) where {T}
673-
return T(r.step)
674-
end
675-
676427
# This assumes that r.step has already been split so that (0:len-1)*r.step.hi is exact
677428
function Base.unsafe_getindex(
678-
r::Union{
679-
Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision},
680-
TracedStepRangeLen{
681-
T,<:Base.TwicePrecision,<:Base.TwicePrecision,<:Base.TwicePrecision
682-
},
683-
},
429+
r::Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision},
684430
i::TracedRNumber{<:Integer},
685431
) where {T}
686432
# Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12
@@ -703,9 +449,7 @@ function Base.unsafe_getindex(
703449
end
704450

705451
function Base.searchsortedfirst(
706-
a::AbstractRange{<:Union{Real,TracedRNumber}},
707-
x::TracedRNumber{<:Real},
708-
o::Base.DirectOrdering,
452+
a::AbstractRange{<:Real}, x::TracedRNumber{<:Real}, o::Base.DirectOrdering
709453
)::TracedRNumber{keytype(a)}
710454

711455
# require_one_based_indexing(a)
@@ -716,7 +460,7 @@ function Base.searchsortedfirst(
716460
!Base.Order.lt(o, f, x),
717461
1,
718462
ifelse(
719-
(h == 0) | Base.Order.lt(o, l, x),
463+
h == 0 || Base.Order.lt(o, l, x),
720464
length(a) + 1,
721465
ifelse(Base.Order.lt(o, a[n], x), n + 1, n),
722466
),

0 commit comments

Comments
 (0)