@@ -10,6 +10,7 @@ using ..Reactant:
10
10
ReactantPrimitive,
11
11
WrappedTracedRArray,
12
12
AnyTracedRArray,
13
+ AnyTracedRVector,
13
14
Ops,
14
15
MLIR,
15
16
ancestor,
@@ -19,10 +20,12 @@ using ..Reactant:
19
20
using .. TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array
20
21
21
22
using ReactantCore: ReactantCore
22
- using GPUArraysCore: GPUArraysCore
23
+ using GPUArraysCore: GPUArraysCore, @allowscalar
23
24
24
25
ReactantCore. is_traced (:: TracedRArray ) = true
25
26
27
+ Base. strides (x:: TracedRArray ) = Base. size_to_strides (1 , size (x)... )
28
+
26
29
function Base. convert (:: Type{TracedRArray{T,N}} , x:: AbstractArray ) where {T,N}
27
30
@assert ndims (x) == N
28
31
if x isa TracedRArray
@@ -86,6 +89,17 @@ function scalar_index_to_cartesian(idx::AbstractVector{T}, sz::NTuple{N,Int}) wh
86
89
return idxs
87
90
end
88
91
92
+ function scalar_index_to_cartesian (idx:: T , sz:: NTuple{N,Int} ) where {T<: Number ,N}
93
+ idx = idx - 1
94
+ idxs = (idx % T (sz[1 ]),)
95
+ idx = idx ÷ T (sz[1 ])
96
+ for i in 2 : N
97
+ idxs = (idxs... , idx % T (sz[i]))
98
+ idx = idx ÷ T (sz[i])
99
+ end
100
+ return idxs
101
+ end
102
+
89
103
function Base. getindex (
90
104
a:: TracedRArray{T,N} , indices:: Union{Int,TracedRNumber{Int}}
91
105
) where {T,N}
@@ -509,7 +523,10 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)
509
523
510
524
args = (TracedUtils. broadcast_to_size (Base. materialize (a), size (bc)) for a in bc. args)
511
525
512
- res = TracedUtils. elem_apply (bc. f, args... )
526
+ res = TracedUtils. promote_to (
527
+ TracedRArray{unwrapped_eltype (dest),ndims (dest)},
528
+ TracedUtils. elem_apply (bc. f, args... ),
529
+ )
513
530
TracedUtils. set_mlir_data! (dest, res. mlir_data)
514
531
return dest
515
532
end
@@ -687,4 +704,234 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs)
687
704
return cat (res... ; dims)
688
705
end
689
706
707
+ # sort
708
+ function Base. sort (x:: AnyTracedRArray ; alg= missing , order= missing , kwargs... )
709
+ return sort! (copy (x); alg, order, kwargs... )
710
+ end
711
+ function Base. sort (x:: AnyTracedRVector ; alg= missing , order= missing , kwargs... )
712
+ return sort! (copy (x); alg, order, dims= 1 , kwargs... )
713
+ end
714
+
715
+ function Base. sort! (
716
+ x:: AnyTracedRArray ;
717
+ dims:: Union{Integer,Nothing} = nothing ,
718
+ lt= isless,
719
+ by= identity,
720
+ rev:: Bool = false ,
721
+ alg= missing ,
722
+ order= missing ,
723
+ )
724
+ if dims === nothing
725
+ @assert ndims (x) == 1
726
+ dims = 1
727
+ end
728
+
729
+ @assert alg === missing " Reactant doesn't support `alg` kwarg for `sort!`"
730
+ @assert order === missing " Reactant doesn't support `order` kwarg for `sort!`"
731
+
732
+ comparator = rev ? (a, b) -> ! lt (by (a), by (b)) : (a, b) -> lt (by (a), by (b))
733
+ res = only (Ops. sort (materialize_traced_array (x); dimension= dims, comparator))
734
+ set_mlir_data! (x, get_mlir_data (res))
735
+ return x
736
+ end
737
+
738
+ function Base. sortperm (x:: AnyTracedRArray ; alg= missing , order= missing , kwargs... )
739
+ return sortperm! (similar (x, Int), x; alg, order, kwargs... )
740
+ end
741
+ function Base. sortperm (x:: AnyTracedRVector ; alg= missing , order= missing , kwargs... )
742
+ return sortperm! (similar (x, Int), x; alg, order, dims= 1 , kwargs... )
743
+ end
744
+
745
+ function Base. sortperm! (
746
+ ix:: AnyTracedRArray{Int,N} ,
747
+ x:: AnyTracedRArray{<:Any,N} ;
748
+ dims:: Union{Integer,Nothing} = nothing ,
749
+ lt= isless,
750
+ by= identity,
751
+ rev:: Bool = false ,
752
+ alg= missing ,
753
+ order= missing ,
754
+ ) where {N}
755
+ if dims === nothing
756
+ @assert ndims (x) == 1
757
+ dims = 1
758
+ end
759
+
760
+ @assert alg === missing " Reactant doesn't support `alg` kwarg for `sortperm!`"
761
+ @assert order === missing " Reactant doesn't support `order` kwarg for `sortperm!`"
762
+
763
+ comparator =
764
+ rev ? (a, b, i1, i2) -> ! lt (by (a), by (b)) : (a, b, i1, i2) -> lt (by (a), by (b))
765
+ idxs = Ops. constant (collect (LinearIndices (x)))
766
+ _, res = Ops. sort (materialize_traced_array (x), idxs; dimension= dims, comparator)
767
+ set_mlir_data! (ix, get_mlir_data (res))
768
+ return ix
769
+ end
770
+
771
+ function Base. partialsort (x:: AnyTracedRVector , k:: Union{Integer,OrdinalRange} ; kwargs... )
772
+ values, _ = overloaded_partialsort (x, k; kwargs... )
773
+ k = k .- minimum (k) .+ 1
774
+ k isa Integer && return @allowscalar (values[k])
775
+ return view (values, k)
776
+ end
777
+
778
+ function Base. partialsort! (x:: AnyTracedRVector , k:: Union{Integer,OrdinalRange} ; kwargs... )
779
+ values, _ = overloaded_partialsort (x, k; kwargs... )
780
+ kget = k .- minimum (k) .+ 1
781
+ val = @allowscalar (values[kget])
782
+ @allowscalar setindex! (x, val, k)
783
+ k isa Integer && return val
784
+ return view (x, k)
785
+ end
786
+
787
+ function Base. partialsortperm (
788
+ x:: AnyTracedRVector , k:: Union{Integer,OrdinalRange} ; kwargs...
789
+ )
790
+ idxs = overloaded_partialsort (x, k; kwargs... )[2 ]
791
+ k = k .- minimum (k) .+ 1
792
+ k isa Integer && return @allowscalar (idxs[k])
793
+ return view (idxs, k)
794
+ end
795
+
796
+ function Base. partialsortperm! (
797
+ ix:: AnyTracedRVector{Int} ,
798
+ x:: AnyTracedRVector ,
799
+ k:: Union{Integer,OrdinalRange} ;
800
+ kwargs... ,
801
+ )
802
+ _, idxs = overloaded_partialsort (x, k; kwargs... )
803
+ kget = k .- minimum (k) .+ 1
804
+ val = @allowscalar (idxs[kget])
805
+ @allowscalar setindex! (ix, val, k)
806
+ k isa Integer && return val
807
+ return view (ix, k)
808
+ end
809
+
810
+ function overloaded_partialsort (
811
+ x:: AnyTracedRVector ,
812
+ k:: Union{Integer,OrdinalRange} ;
813
+ by= identity,
814
+ rev:: Bool = false ,
815
+ lt= isless,
816
+ )
817
+ if lt != = isless || by != = identity
818
+ comparator =
819
+ rev ? (a, b, i1, i2) -> ! lt (by (a), by (b)) : (a, b, i1, i2) -> lt (by (a), by (b))
820
+ idxs = Ops. constant (collect (LinearIndices (x)))
821
+ sorted_x, sorted_idxs = Ops. sort (
822
+ materialize_traced_array (x), idxs; dimension= 1 , comparator
823
+ )
824
+ return sorted_x[1 : maximum (k)], sorted_idxs[1 : maximum (k)]
825
+ end
826
+
827
+ # XXX : If `maxk` is beyond a threshold should we emit a sort directly?
828
+ ! rev && (k = length (x) .- k .+ 1 )
829
+ ! (k isa Integer) && (k = maximum (k))
830
+ (; values, indices) = Ops. top_k (materialize_traced_array (x), k)
831
+ if ! rev
832
+ values = Ops. reverse (values; dimensions= [1 ])
833
+ indices = Ops. reverse (indices; dimensions= [1 ])
834
+ end
835
+ return values, indices
836
+ end
837
+
838
+ # arg* functions
839
+ function Base. argmin (f:: F , x:: AnyTracedRArray ) where {F}
840
+ idx = scalar_index_to_cartesian (argmin (f .(x)), size (x)) .+ 1
841
+ return @allowscalar x[idx... ]
842
+ end
843
+
844
+ function Base. argmax (f:: F , x:: AnyTracedRArray ) where {F}
845
+ idx = scalar_index_to_cartesian (argmax (f .(x)), size (x)) .+ 1
846
+ return @allowscalar x[idx... ]
847
+ end
848
+
849
+ Base. argmin (x:: AnyTracedRArray ; kwargs... ) = findmin (identity, x; kwargs... )[2 ]
850
+ Base. argmax (x:: AnyTracedRArray ; kwargs... ) = findmax (identity, x; kwargs... )[2 ]
851
+
852
+ # find* functions
853
+ Base. findfirst (x:: AnyTracedRArray ) = findfirst (identity, x)
854
+ Base. findlast (x:: AnyTracedRArray ) = findlast (identity, x)
855
+
856
+ function Base. findfirst (f:: Function , x:: AnyTracedRArray )
857
+ fA = materialize_traced_array (vec (f .(x)))
858
+ (; indices) = Ops. top_k (fA, 1 )
859
+ return @allowscalar indices[1 ]
860
+ end
861
+
862
+ function Base. findlast (f:: Function , x:: AnyTracedRArray )
863
+ fA = Ops. reverse (materialize_traced_array (vec (f .(x))); dimensions= [1 ])
864
+ (; indices) = Ops. top_k (fA, 1 )
865
+ return length (x) - @allowscalar (indices[1 ]) + 1
866
+ end
867
+
868
+ Base. findmin (x:: AnyTracedRVector ) = findmin (identity, x; dims= 1 )
869
+ function Base. findmin (x:: AnyTracedRArray ; dims:: Union{Integer,Nothing} = nothing )
870
+ return findmin (identity, x; dims)
871
+ end
872
+
873
+ Base. findmax (x:: AnyTracedRVector ) = findmax (identity, x; dims= 1 )
874
+ function Base. findmax (x:: AnyTracedRArray ; dims:: Union{Integer,Nothing} = nothing )
875
+ return findmax (identity, x; dims)
876
+ end
877
+
878
+ # # To avoid scalar indexing and constructing an array of tuples, we return the linear index
879
+ # # instead of the cartesian index
880
+ function Base. findmin (f, x:: AnyTracedRArray ; dims:: Union{Integer,Nothing} = nothing )
881
+ if dims === nothing
882
+ if ndims (x) == 1
883
+ dims = 1
884
+ else
885
+ return findmin (f, vec (x); dims= 1 )
886
+ end
887
+ end
888
+
889
+ fx = Ops. negate (materialize_traced_array (f .(x)))
890
+ (; values, indices) = Ops. top_k (fx, 1 ; dimension= dims)
891
+
892
+ # Compute linear indices
893
+ strds = strides (x)
894
+ iotas = [Ops. iota (Int64, [size (indices)... ]; iota_dimension= i) for i in 1 : ndims (x)]
895
+ iotas[dims] = Ops. subtract (indices, Ops. constant (fill (Int64 (1 ), size (indices))))
896
+ linear_indices = Ops. constant (fill (Int64 (1 ), size (indices)))
897
+ for d in eachindex (iotas)
898
+ linear_indices = Ops. add (
899
+ linear_indices,
900
+ Ops. multiply (iotas[d], Ops. constant (fill (Int64 (strds[d]), size (iotas[d])))),
901
+ )
902
+ end
903
+
904
+ values = Ops. negate (values)
905
+ ndims (x) == 1 && return @allowscalar (values[1 ], linear_indices[1 ])
906
+ return (values, linear_indices)
907
+ end
908
+
909
+ function Base. findmax (f, x:: AnyTracedRArray ; dims:: Union{Integer,Nothing} = nothing )
910
+ if dims === nothing
911
+ if ndims (x) == 1
912
+ dims = 1
913
+ else
914
+ return findmax (f, vec (x); dims= 1 )
915
+ end
916
+ end
917
+
918
+ fx = materialize_traced_array (f .(x))
919
+ (; values, indices) = Ops. top_k (fx, 1 ; dimension= dims)
920
+
921
+ # Compute linear indices
922
+ strds = strides (x)
923
+ iotas = [Ops. iota (Int64, [size (indices)... ]; iota_dimension= i) for i in 1 : ndims (x)]
924
+ iotas[dims] = Ops. subtract (indices, Ops. constant (fill (Int64 (1 ), size (indices))))
925
+ linear_indices = Ops. constant (fill (Int64 (1 ), size (indices)))
926
+ for d in eachindex (iotas)
927
+ linear_indices = Ops. add (
928
+ linear_indices,
929
+ Ops. multiply (iotas[d], Ops. constant (fill (Int64 (strds[d]), size (iotas[d])))),
930
+ )
931
+ end
932
+
933
+ ndims (x) == 1 && return @allowscalar (values[1 ], linear_indices[1 ])
934
+ return (values, linear_indices)
935
+ end
936
+
690
937
end
0 commit comments