@@ -3,9 +3,6 @@ module TracedRNumberOverrides
3
3
using .. Reactant:
4
4
Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype
5
5
using ReactantCore
6
- using Adapt
7
-
8
- import Base. TwicePrecision
9
6
10
7
ReactantCore. is_traced (:: TracedRNumber , seen) = true
11
8
ReactantCore. is_traced (:: TracedRNumber ) = true
@@ -265,42 +262,6 @@ function Base.ifelse(
265
262
end
266
263
end
267
264
268
- function Base.:* (
269
- x:: Base.TwicePrecision{T} , y:: Base.TwicePrecision{T}
270
- ) where {T<: TracedRNumber }
271
- zh, zl = Base. mul12 (x. hi, y. hi)
272
- hi, lo = Base. canonicalize2 (zh, (x. hi * y. lo + x. lo * y. hi) + zl)
273
- hi = ifelse (iszero (zh) | ! isfinite (zh), zh, hi)
274
- lo = ifelse (iszero (zl) | ! isfinite (zl), zl, lo)
275
-
276
- return Base. TwicePrecision {T} (hi, lo)
277
- end
278
-
279
- function Base.:+ (
280
- x:: Base.TwicePrecision{T} , y:: Base.TwicePrecision{T}
281
- ) where {T<: TracedRNumber }
282
- r = x. hi + y. hi
283
- @trace s = if abs (x. hi) > abs (y. hi)
284
- begin
285
- (((x. hi - r) + y. hi) + y. lo) + x. lo
286
- end
287
- else
288
- begin
289
- (((y. hi - r) + x. hi) + x. lo) + y. lo
290
- end
291
- end
292
- return Base. TwicePrecision (Base. canonicalize2 (r, s)... )
293
- end
294
-
295
- function Base.:* (x:: TwicePrecision , v:: TracedRNumber )
296
- @trace result = if v == 0
297
- TwicePrecision (x. hi * v, x. lo * v)
298
- else
299
- x * TwicePrecision (oftype (x. hi * v, v))
300
- end
301
- return result
302
- end
303
-
304
265
for (T1, T2) in zip ((Bool, Integer), (Bool, Integer))
305
266
T = promote_type (T1, T2)
306
267
@eval begin
@@ -310,54 +271,18 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer))
310
271
TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
311
272
)
312
273
end
313
- function Base.:& (x:: TracedRNumber{<:$(T1)} , y:: $ (T2))
314
- return Ops. and (
315
- TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
316
- TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
317
- )
318
- end
319
- function Base.:& (x:: $ (T1), y:: TracedRNumber{<:$(T2)} )
320
- return Ops. and (
321
- TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
322
- TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
323
- )
324
- end
325
274
function Base.:| (x:: TracedRNumber{<:$(T1)} , y:: TracedRNumber{<:$(T2)} )
326
275
return Ops. or (
327
276
TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
328
277
TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
329
278
)
330
279
end
331
- function Base.:| (x:: TracedRNumber{<:$(T1)} , y:: $ (T2))
332
- return Ops. or (
333
- TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
334
- TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
335
- )
336
- end
337
- function Base.:| (x:: $ (T1), y:: TracedRNumber{<:$(T2)} )
338
- return Ops. or (
339
- TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
340
- TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
341
- )
342
- end
343
280
function Base. xor (x:: TracedRNumber{<:$(T1)} , y:: TracedRNumber{<:$(T2)} )
344
281
return Ops. xor (
345
282
TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
346
283
TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
347
284
)
348
285
end
349
- function Base. xor (x:: TracedRNumber{<:$(T1)} , y:: $ (T2))
350
- return Ops. xor (
351
- TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
352
- TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
353
- )
354
- end
355
- function Base. xor (x:: $ (T1), y:: TracedRNumber{<:$(T2)} )
356
- return Ops. xor (
357
- TracedUtils. promote_to (TracedRNumber{$ (T)}, x),
358
- TracedUtils. promote_to (TracedRNumber{$ (T)}, y),
359
- )
360
- end
361
286
Base.:! (x:: TracedRNumber{<:$(T1)} ) = Ops. not (x)
362
287
end
363
288
end
@@ -499,188 +424,9 @@ function Base.getindex(
499
424
return Base. unsafe_getindex (r, i)
500
425
end
501
426
502
- struct TracedStepRangeLen{T,R,S,L} <: AbstractRange{T}
503
- ref:: R
504
- step:: S
505
- len:: L
506
- offset:: L
507
- end
508
-
509
- function Adapt. parent_type (:: Type{TracedStepRangeLen{T,R,S,L}} ) where {T,R,S,L}
510
- return TracedStepRangeLen{T,R,S,L}
511
- end
512
-
513
- # constructors and interface implementation copied from range.jl
514
- function TracedStepRangeLen {T,R,S} (ref:: R , step:: S , len, offset= 1 ) where {T,R,S}
515
- return TracedStepRangeLen {T,R,S,typeof(len)} (ref, step, len, offset)
516
- end
517
- function TracedStepRangeLen (ref:: R , step:: S , len, offset= 1 ) where {R,S}
518
- return TracedStepRangeLen {typeof(ref + zero(step)),R,S,typeof(len)} (
519
- ref, step, len, offset
520
- )
521
- end
522
- function TracedStepRangeLen {T} (
523
- ref:: R , step:: S , len:: Integer , offset:: Integer = 1
524
- ) where {T,R,S}
525
- return TracedStepRangeLen {T,R,S,typeof(len)} (ref, step, len, offset)
526
- end
527
-
528
- Base. isempty (r:: TracedStepRangeLen ) = length (r) == 0
529
- Base. step (r:: TracedStepRangeLen ) = r. step
530
- Base. step_hp (r:: TracedStepRangeLen ) = r. step
531
- Base. length (r:: TracedStepRangeLen ) = r. len
532
- Base. first (r:: TracedStepRangeLen ) = Base. unsafe_getindex (r, 1 )
533
- Base. last (r:: TracedStepRangeLen ) = Base. unsafe_getindex (r, r. len)
534
- function Base. iterate (r:: TracedStepRangeLen , i:: Integer = 1 )
535
- @inline
536
- i += oneunit (i)
537
- length (r) < i && return nothing
538
- return Base. unsafe_getindex (r, i), i
539
- end
540
-
541
- function _tracedsteprangelen_unsafe_getindex (
542
- r:: AbstractRange{T} , i:: Union{I,TracedRNumber{I}}
543
- ) where {T,I}
544
- finalT = T
545
- offsetT = typeof (r. offset)
546
- if i isa TracedRNumber
547
- if ! (T <: TracedRNumber )
548
- finalT = TracedRNumber{T}
549
- end
550
- if ! (r. offset isa TracedRNumber)
551
- offsetT = TracedRNumber{offsetT}
552
- end
553
- end
554
- u = convert (offsetT, i) - r. offset
555
- return finalT (r. ref + u * r. step)
556
- end
557
- function Base. unsafe_getindex (r:: TracedStepRangeLen , i:: Integer )
558
- return _tracedsteprangelen_unsafe_getindex (r, i)
559
- end
560
- function Base. unsafe_getindex (r:: TracedStepRangeLen , i:: TracedRNumber{<:Integer} )
561
- return _tracedsteprangelen_unsafe_getindex (r, i)
562
- end
563
- Base. getindex (r:: TracedStepRangeLen , i:: TracedRNumber ) = Base. unsafe_getindex (r, i)
564
- function getindex (r:: TracedStepRangeLen{T} , s:: OrdinalRange{S} ) where {T,S<: Integer }
565
- @inline
566
- @boundscheck checkbounds (r, s)
567
-
568
- len = length (s)
569
- sstep = Base. step_hp (s)
570
- rstep = Base. step_hp (r)
571
- L = typeof (len)
572
- if S === Bool
573
- rstep *= one (sstep)
574
- if len == 0
575
- return TracedStepRangeLen {T} (first (r), rstep, zero (L), oneunit (L))
576
- elseif len == 1
577
- if first (s)
578
- return TracedStepRangeLen {T} (first (r), rstep, oneunit (L), oneunit (L))
579
- else
580
- return TracedStepRangeLen {T} (first (r), rstep, zero (L), oneunit (L))
581
- end
582
- else # len == 2
583
- return TracedStepRangeLen {T} (last (r), rstep, oneunit (L), oneunit (L))
584
- end
585
- else
586
- # Find closest approach to offset by s
587
- ind = LinearIndices (s)
588
- offset = L (
589
- max (min (1 + round (L, (r. offset - first (s)) / sstep), last (ind)), first (ind))
590
- )
591
- ref = Base. _getindex_hiprec (r, first (s) + (offset - oneunit (offset)) * sstep)
592
- return TracedStepRangeLen {T} (ref, rstep * sstep, len, offset)
593
- end
594
- end
595
- function Base. _getindex_hiprec (r:: TracedStepRangeLen , i:: Integer ) # without rounding by T
596
- u = oftype (r. offset, i) - r. offset
597
- return r. ref + u * r. step
598
- end
599
- function Base.:(== )(r:: T , s:: T ) where {T<: TracedStepRangeLen }
600
- return (isempty (r) & isempty (s)) |
601
- ((first (r) == first (s)) & (length (r) == length (s)) & (last (r) == last (s)))
602
- end
603
-
604
- # TODO : if there ever comes a ReactantStepRange:
605
- # ==(r::Union{StepRange{T},StepRangeLen{T,T}}, s::Union{StepRange{T},StepRangeLen{T,T}}) where {T}
606
-
607
- function Base.:- (r:: TracedStepRangeLen{T,R,S,L} ) where {T,R,S,L}
608
- return TracedStepRangeLen {T,R,S,L} (- r. ref, - r. step, r. len, r. offset)
609
- end
610
-
611
- # TODO : promotion from StepRangeLen{T} to TracedStepRangeLen{T}?
612
- function Base. promote_rule (
613
- :: Type{TracedStepRangeLen{T1,R1,S1,L1}} , :: Type{TracedStepRangeLen{T2,R2,S2,L2}}
614
- ) where {T1,T2,R1,R2,S1,S2,L1,L2}
615
- R, S, L = promote_type (R1, R2), promote_type (S1, S2), promote_type (L1, L2)
616
- return Base. el_same (
617
- promote_type (T1, T2), TracedStepRangeLen{T1,R,S,L}, TracedStepRangeLen{T2,R,S,L}
618
- )
619
- end
620
- TracedStepRangeLen {T,R,S,L} (r:: TracedStepRangeLen{T,R,S,L} ) where {T,R,S,L} = r
621
- function TracedStepRangeLen {T,R,S,L} (r:: TracedStepRangeLen ) where {T,R,S,L}
622
- return TracedStepRangeLen {T,R,S,L} (
623
- convert (R, r. ref), convert (S, r. step), convert (L, r. len), convert (L, r. offset)
624
- )
625
- end
626
- function TracedStepRangeLen {T} (r:: TracedStepRangeLen ) where {T}
627
- return TracedStepRangeLen (convert (T, r. ref), convert (T, r. step), r. len, r. offset)
628
- end
629
- function Base. promote_rule (
630
- a:: Type{TracedStepRangeLen{T,R,S,L}} , :: Type{OR}
631
- ) where {T,R,S,L,OR<: AbstractRange }
632
- return promote_rule (a, TracedStepRangeLen{eltype (OR),eltype (OR),eltype (OR),Int})
633
- end
634
- function TracedStepRangeLen {T,R,S,L} (r:: AbstractRange ) where {T,R,S,L}
635
- return TracedStepRangeLen {T,R,S,L} (R (first (r)), S (step (r)), length (r))
636
- end
637
- function TracedStepRangeLen {T} (r:: AbstractRange ) where {T}
638
- return TracedStepRangeLen (T (first (r)), T (step (r)), length (r))
639
- end
640
- TracedStepRangeLen (r:: AbstractRange ) = TracedStepRangeLen {eltype(r)} (r)
641
-
642
- function Base. promote_rule (
643
- :: Type{LinRange{A,L}} , b:: Type{TracedStepRangeLen{T2,R2,S2,L2}}
644
- ) where {A,L,T2,R2,S2,L2}
645
- return promote_rule (TracedStepRangeLen{A,A,A,L}, b)
646
- end
647
-
648
- function Base. _reverse (r:: TracedStepRangeLen , :: Colon )
649
- # If `r` is empty, `length(r) - r.offset + 1 will be nonpositive hence
650
- # invalid. As `reverse(r)` is also empty, any offset would work so we keep
651
- # `r.offset`
652
- offset = isempty (r) ? r. offset : length (r) - r. offset + 1
653
- return typeof (r)(r. ref, negate (r. step), length (r), offset)
654
- end
655
-
656
- # TODO : +, - for TracedStepRangeLen (see Base._define_range_op)
657
-
658
- function (:: Type{T} )(x:: TwicePrecision ) where {T<: Reactant.TracedRNumber }
659
- return (T (x. hi) + T (x. lo)):: T
660
- end
661
-
662
- function (:: Type{T} )(x:: TwicePrecision ) where {T<: Reactant.ConcreteRNumber }
663
- return Reactant. ConcreteRNumber (T (x. hi) - T (x. lo)):: T
664
- end
665
-
666
- Base. nbitslen (r:: TracedStepRangeLen ) = Base. nbitslen (eltype (r), length (r), r. offset)
667
- function TracedStepRangeLen (
668
- ref:: TwicePrecision{T} , step:: TwicePrecision{T} , len, offset= 1
669
- ) where {T}
670
- return TracedStepRangeLen {T,TwicePrecision{T},TwicePrecision{T}} (ref, step, len, offset)
671
- end
672
- function Base. step (r:: TracedStepRangeLen{T,TwicePrecision{T},TwicePrecision{T}} ) where {T}
673
- return T (r. step)
674
- end
675
-
676
427
# This assumes that r.step has already been split so that (0:len-1)*r.step.hi is exact
677
428
function Base. unsafe_getindex (
678
- r:: Union {
679
- Base. StepRangeLen{T,<: Base.TwicePrecision ,<: Base.TwicePrecision },
680
- TracedStepRangeLen{
681
- T,<: Base.TwicePrecision ,<: Base.TwicePrecision ,<: Base.TwicePrecision
682
- },
683
- },
429
+ r:: Base.StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision} ,
684
430
i:: TracedRNumber{<:Integer} ,
685
431
) where {T}
686
432
# Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12
@@ -703,9 +449,7 @@ function Base.unsafe_getindex(
703
449
end
704
450
705
451
function Base. searchsortedfirst (
706
- a:: AbstractRange{<:Union{Real,TracedRNumber}} ,
707
- x:: TracedRNumber{<:Real} ,
708
- o:: Base.DirectOrdering ,
452
+ a:: AbstractRange{<:Real} , x:: TracedRNumber{<:Real} , o:: Base.DirectOrdering
709
453
):: TracedRNumber{keytype(a)}
710
454
711
455
# require_one_based_indexing(a)
@@ -716,7 +460,7 @@ function Base.searchsortedfirst(
716
460
! Base. Order. lt (o, f, x),
717
461
1 ,
718
462
ifelse (
719
- ( h == 0 ) | Base. Order. lt (o, l, x),
463
+ h == 0 | | Base. Order. lt (o, l, x),
720
464
length (a) + 1 ,
721
465
ifelse (Base. Order. lt (o, a[n], x), n + 1 , n),
722
466
),
0 commit comments