Skip to content

Commit dafa186

Browse files
committed
test: sort and partial sort functions
1 parent c98e2c3 commit dafa186

File tree

2 files changed

+140
-24
lines changed

2 files changed

+140
-24
lines changed

src/TracedRArray.jl

+38-20
Original file line numberDiff line numberDiff line change
@@ -714,13 +714,18 @@ end
714714

715715
function Base.sort!(
716716
x::AnyTracedRArray;
717-
dims::Integer,
717+
dims::Union{Integer,Nothing}=nothing,
718718
lt=isless,
719719
by=identity,
720720
rev::Bool=false,
721721
alg=missing,
722722
order=missing,
723723
)
724+
if dims === nothing
725+
@assert ndims(x) == 1
726+
dims = 1
727+
end
728+
724729
@assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`"
725730
@assert order === missing "Reactant doesn't support `order` kwarg for `sort!`"
726731

@@ -740,13 +745,18 @@ end
740745
function Base.sortperm!(
741746
ix::AnyTracedRArray{Int,N},
742747
x::AnyTracedRArray{<:Any,N};
743-
dims::Integer,
748+
dims::Union{Integer,Nothing}=nothing,
744749
lt=isless,
745750
by=identity,
746751
rev::Bool=false,
747752
alg=missing,
748753
order=missing,
749754
) where {N}
755+
if dims === nothing
756+
@assert ndims(x) == 1
757+
dims = 1
758+
end
759+
750760
@assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`"
751761
@assert order === missing "Reactant doesn't support `order` kwarg for `sortperm!`"
752762

@@ -761,6 +771,7 @@ end
761771
function Base.partialsort(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...)
762772
values, _ = overloaded_partialsort(x, k; kwargs...)
763773
k = k .- minimum(k) .+ 1
774+
k isa Integer && return @allowscalar(values[k])
764775
return view(values, k)
765776
end
766777

@@ -769,7 +780,31 @@ function Base.partialsort!(x::AnyTracedRVector, k::Union{Integer,OrdinalRange};
769780
kget = k .- minimum(k) .+ 1
770781
val = @allowscalar(values[kget])
771782
@allowscalar setindex!(x, val, k)
772-
return val
783+
k isa Integer && return val
784+
return view(x, k)
785+
end
786+
787+
function Base.partialsortperm(
788+
x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...
789+
)
790+
idxs = overloaded_partialsort(x, k; kwargs...)[2]
791+
k = k .- minimum(k) .+ 1
792+
k isa Integer && return @allowscalar(idxs[k])
793+
return view(idxs, k)
794+
end
795+
796+
function Base.partialsortperm!(
797+
ix::AnyTracedRVector{Int},
798+
x::AnyTracedRVector,
799+
k::Union{Integer,OrdinalRange};
800+
kwargs...,
801+
)
802+
_, idxs = overloaded_partialsort(x, k; kwargs...)
803+
kget = k .- minimum(k) .+ 1
804+
val = @allowscalar(idxs[kget])
805+
@allowscalar setindex!(ix, val, k)
806+
k isa Integer && return val
807+
return view(ix, k)
773808
end
774809

775810
function overloaded_partialsort(
@@ -800,23 +835,6 @@ function overloaded_partialsort(
800835
return values, indices
801836
end
802837

803-
function Base.partialsortperm(
804-
x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...
805-
)
806-
return view(overloaded_partialsort(x, k; kwargs...)[2], k)
807-
end
808-
809-
function Base.partialsortperm!(
810-
ix::AnyTracedRVector{Int},
811-
x::AnyTracedRVector,
812-
k::Union{Integer,OrdinalRange};
813-
kwargs...,
814-
)
815-
_, idxs = overloaded_partialsort(x, k; kwargs...)
816-
@allowscalar setindex!(ix, idxs[k], k)
817-
return view(ix, k)
818-
end
819-
820838
# arg* functions
821839
function Base.argmin(f::F, x::AnyTracedRArray) where {F}
822840
idx = scalar_index_to_cartesian(argmin(f.(x)), size(x)) .+ 1

test/sorting.jl

+102-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,110 @@
11
using Reactant, Test
22

3-
@testset "sort" begin end
3+
@testset "sort & sortperm" begin
4+
x = randn(10)
5+
x_ra = Reactant.to_rarray(x)
46

5-
@testset "sortperm" begin end
7+
srt_rev(x) = sort(x; rev=true)
8+
srtperm_rev(x) = sortperm(x; rev=true)
9+
srt_by(x) = sort(x; by=abs2)
10+
srtperm_by(x) = sortperm(x; by=abs2)
11+
srt_lt(x) = sort(x; lt=(a, b) -> a > b)
12+
srtperm_lt(x) = sortperm(x; lt=(a, b) -> a > b)
13+
14+
@test @jit(sort(x_ra)) == sort(x)
15+
@test @jit(srt_rev(x_ra)) == srt_rev(x)
16+
@test @jit(srt_lt(x_ra)) == srt_lt(x)
17+
@test @jit(srt_by(x_ra)) == srt_by(x)
18+
@test @jit(sortperm(x_ra)) == sortperm(x)
19+
@test @jit(srtperm_rev(x_ra)) == srtperm_rev(x)
20+
@test @jit(srtperm_lt(x_ra)) == srtperm_lt(x)
21+
@test @jit(srtperm_by(x_ra)) == srtperm_by(x)
22+
23+
x = rand(10)
24+
x_ra = Reactant.to_rarray(x)
25+
@jit sort!(x_ra)
26+
@test x_ra == sort(x)
627

7-
@testset "partialsort" begin end
28+
x = rand(10)
29+
x_ra = Reactant.to_rarray(x)
30+
ix = similar(x_ra, Int)
31+
@jit sortperm!(ix, x_ra)
32+
@test ix == sortperm(x)
33+
34+
x = rand(10, 4, 3)
35+
x_ra = Reactant.to_rarray(x)
836

9-
@testset "partialsortperm" begin end
37+
srt(x, d) = sort(x; dims=d)
38+
srt_rev(x, d) = sort(x; dims=d, rev=true)
39+
srt_by(x, d) = sort(x; dims=d, by=abs2)
40+
srt_lt(x, d) = sort(x; dims=d, lt=(a, b) -> a > b)
41+
srtperm(x, d) = sortperm(x; dims=d)
42+
srtperm_rev(x, d) = sortperm(x; dims=d, rev=true)
43+
srtperm_by(x, d) = sortperm(x; dims=d, by=abs2)
44+
srtperm_lt(x, d) = sortperm(x; dims=d, lt=(a, b) -> a > b)
45+
46+
@testset for d in 1:ndims(x)
47+
@test @jit(srt(x_ra, d)) == srt(x, d)
48+
@test @jit(srtperm(x_ra, d)) == srtperm(x, d)
49+
@test @jit(srt_rev(x_ra, d)) == srt_rev(x, d)
50+
@test @jit(srtperm_rev(x_ra, d)) == srtperm_rev(x, d)
51+
@test @jit(srt_by(x_ra, d)) == srt_by(x, d)
52+
@test @jit(srtperm_by(x_ra, d)) == srtperm_by(x, d)
53+
@test @jit(srt_lt(x_ra, d)) == srt_lt(x, d)
54+
@test @jit(srtperm_lt(x_ra, d)) == srtperm_lt(x, d)
55+
end
56+
end
57+
58+
@testset "partialsort & partialsortperm" begin
59+
x = randn(10)
60+
x_ra = Reactant.to_rarray(x)
61+
62+
@test @jit(partialsort(x_ra, 1:5)) == partialsort(x, 1:5)
63+
@test @jit(partialsortperm(x_ra, 1:5)) == partialsortperm(x, 1:5)
64+
@test @jit(partialsort(x_ra, 4)) == partialsort(x, 4)
65+
@test @jit(partialsortperm(x_ra, 4)) == partialsortperm(x, 4)
66+
67+
psrt_rev(x, k) = partialsort(x, k; rev=true)
68+
psrtperm_rev(x, k) = partialsortperm(x, k; rev=true)
69+
psrt_by(x, k) = partialsort(x, k; by=abs2)
70+
psrtperm_by(x, k) = partialsortperm(x, k; by=abs2)
71+
psrt_lt(x, k) = partialsort(x, k; lt=(a, b) -> a > b)
72+
psrtperm_lt(x, k) = partialsortperm(x, k; lt=(a, b) -> a > b)
73+
74+
@test @jit(psrt_rev(x_ra, 1:5)) == psrt_rev(x, 1:5)
75+
@test @jit(psrtperm_rev(x_ra, 1:5)) == psrtperm_rev(x, 1:5)
76+
@test @jit(psrt_by(x_ra, 1:5)) == psrt_by(x, 1:5)
77+
@test @jit(psrtperm_by(x_ra, 1:5)) == psrtperm_by(x, 1:5)
78+
@test @jit(psrt_lt(x_ra, 1:5)) == psrt_lt(x, 1:5)
79+
@test @jit(psrtperm_lt(x_ra, 1:5)) == psrtperm_lt(x, 1:5)
80+
81+
x = randn(10)
82+
x_ra = Reactant.to_rarray(x)
83+
@jit partialsort!(x_ra, 1:5)
84+
partialsort!(x, 1:5)
85+
@test Array(x_ra)[1:5] == x[1:5]
86+
87+
x = randn(10)
88+
x_ra = Reactant.to_rarray(x)
89+
@jit partialsort!(x_ra, 3)
90+
partialsort!(x, 3)
91+
@test @allowscalar(x_ra[3]) == x[3]
92+
93+
x = randn(10)
94+
x_ra = Reactant.to_rarray(x)
95+
96+
ix = similar(x, Int)
97+
ix_ra = Reactant.to_rarray(ix)
98+
@jit partialsortperm!(ix_ra, x_ra, 1:5)
99+
partialsortperm!(ix, x, 1:5)
100+
@test Array(ix_ra)[1:5] == ix[1:5]
101+
102+
ix = similar(x, Int)
103+
ix_ra = Reactant.to_rarray(ix)
104+
@jit partialsortperm!(ix_ra, x_ra, 3)
105+
partialsortperm!(ix, x, 3)
106+
@test @allowscalar(ix_ra[3]) == ix[3]
107+
end
10108

11109
@testset "argmin / argmax" begin
12110
x = rand(2, 3)

0 commit comments

Comments
 (0)