|
1 | 1 | using Reactant, Test
|
2 | 2 |
|
3 |
| -@testset "sort" begin end |
| 3 | +@testset "sort & sortperm" begin |
| 4 | + x = randn(10) |
| 5 | + x_ra = Reactant.to_rarray(x) |
4 | 6 |
|
5 |
| -@testset "sortperm" begin end |
| 7 | + srt_rev(x) = sort(x; rev=true) |
| 8 | + srtperm_rev(x) = sortperm(x; rev=true) |
| 9 | + srt_by(x) = sort(x; by=abs2) |
| 10 | + srtperm_by(x) = sortperm(x; by=abs2) |
| 11 | + srt_lt(x) = sort(x; lt=(a, b) -> a > b) |
| 12 | + srtperm_lt(x) = sortperm(x; lt=(a, b) -> a > b) |
| 13 | + |
| 14 | + @test @jit(sort(x_ra)) == sort(x) |
| 15 | + @test @jit(srt_rev(x_ra)) == srt_rev(x) |
| 16 | + @test @jit(srt_lt(x_ra)) == srt_lt(x) |
| 17 | + @test @jit(srt_by(x_ra)) == srt_by(x) |
| 18 | + @test @jit(sortperm(x_ra)) == sortperm(x) |
| 19 | + @test @jit(srtperm_rev(x_ra)) == srtperm_rev(x) |
| 20 | + @test @jit(srtperm_lt(x_ra)) == srtperm_lt(x) |
| 21 | + @test @jit(srtperm_by(x_ra)) == srtperm_by(x) |
| 22 | + |
| 23 | + x = rand(10) |
| 24 | + x_ra = Reactant.to_rarray(x) |
| 25 | + @jit sort!(x_ra) |
| 26 | + @test x_ra == sort(x) |
6 | 27 |
|
7 |
| -@testset "partialsort" begin end |
| 28 | + x = rand(10) |
| 29 | + x_ra = Reactant.to_rarray(x) |
| 30 | + ix = similar(x_ra, Int) |
| 31 | + @jit sortperm!(ix, x_ra) |
| 32 | + @test ix == sortperm(x) |
| 33 | + |
| 34 | + x = rand(10, 4, 3) |
| 35 | + x_ra = Reactant.to_rarray(x) |
8 | 36 |
|
9 |
| -@testset "partialsortperm" begin end |
| 37 | + srt(x, d) = sort(x; dims=d) |
| 38 | + srt_rev(x, d) = sort(x; dims=d, rev=true) |
| 39 | + srt_by(x, d) = sort(x; dims=d, by=abs2) |
| 40 | + srt_lt(x, d) = sort(x; dims=d, lt=(a, b) -> a > b) |
| 41 | + srtperm(x, d) = sortperm(x; dims=d) |
| 42 | + srtperm_rev(x, d) = sortperm(x; dims=d, rev=true) |
| 43 | + srtperm_by(x, d) = sortperm(x; dims=d, by=abs2) |
| 44 | + srtperm_lt(x, d) = sortperm(x; dims=d, lt=(a, b) -> a > b) |
| 45 | + |
| 46 | + @testset for d in 1:ndims(x) |
| 47 | + @test @jit(srt(x_ra, d)) == srt(x, d) |
| 48 | + @test @jit(srtperm(x_ra, d)) == srtperm(x, d) |
| 49 | + @test @jit(srt_rev(x_ra, d)) == srt_rev(x, d) |
| 50 | + @test @jit(srtperm_rev(x_ra, d)) == srtperm_rev(x, d) |
| 51 | + @test @jit(srt_by(x_ra, d)) == srt_by(x, d) |
| 52 | + @test @jit(srtperm_by(x_ra, d)) == srtperm_by(x, d) |
| 53 | + @test @jit(srt_lt(x_ra, d)) == srt_lt(x, d) |
| 54 | + @test @jit(srtperm_lt(x_ra, d)) == srtperm_lt(x, d) |
| 55 | + end |
| 56 | +end |
| 57 | + |
| 58 | +@testset "partialsort & partialsortperm" begin |
| 59 | + x = randn(10) |
| 60 | + x_ra = Reactant.to_rarray(x) |
| 61 | + |
| 62 | + @test @jit(partialsort(x_ra, 1:5)) == partialsort(x, 1:5) |
| 63 | + @test @jit(partialsortperm(x_ra, 1:5)) == partialsortperm(x, 1:5) |
| 64 | + @test @jit(partialsort(x_ra, 4)) == partialsort(x, 4) |
| 65 | + @test @jit(partialsortperm(x_ra, 4)) == partialsortperm(x, 4) |
| 66 | + |
| 67 | + psrt_rev(x, k) = partialsort(x, k; rev=true) |
| 68 | + psrtperm_rev(x, k) = partialsortperm(x, k; rev=true) |
| 69 | + psrt_by(x, k) = partialsort(x, k; by=abs2) |
| 70 | + psrtperm_by(x, k) = partialsortperm(x, k; by=abs2) |
| 71 | + psrt_lt(x, k) = partialsort(x, k; lt=(a, b) -> a > b) |
| 72 | + psrtperm_lt(x, k) = partialsortperm(x, k; lt=(a, b) -> a > b) |
| 73 | + |
| 74 | + @test @jit(psrt_rev(x_ra, 1:5)) == psrt_rev(x, 1:5) |
| 75 | + @test @jit(psrtperm_rev(x_ra, 1:5)) == psrtperm_rev(x, 1:5) |
| 76 | + @test @jit(psrt_by(x_ra, 1:5)) == psrt_by(x, 1:5) |
| 77 | + @test @jit(psrtperm_by(x_ra, 1:5)) == psrtperm_by(x, 1:5) |
| 78 | + @test @jit(psrt_lt(x_ra, 1:5)) == psrt_lt(x, 1:5) |
| 79 | + @test @jit(psrtperm_lt(x_ra, 1:5)) == psrtperm_lt(x, 1:5) |
| 80 | + |
| 81 | + x = randn(10) |
| 82 | + x_ra = Reactant.to_rarray(x) |
| 83 | + @jit partialsort!(x_ra, 1:5) |
| 84 | + partialsort!(x, 1:5) |
| 85 | + @test Array(x_ra)[1:5] == x[1:5] |
| 86 | + |
| 87 | + x = randn(10) |
| 88 | + x_ra = Reactant.to_rarray(x) |
| 89 | + @jit partialsort!(x_ra, 3) |
| 90 | + partialsort!(x, 3) |
| 91 | + @test @allowscalar(x_ra[3]) == x[3] |
| 92 | + |
| 93 | + x = randn(10) |
| 94 | + x_ra = Reactant.to_rarray(x) |
| 95 | + |
| 96 | + ix = similar(x, Int) |
| 97 | + ix_ra = Reactant.to_rarray(ix) |
| 98 | + @jit partialsortperm!(ix_ra, x_ra, 1:5) |
| 99 | + partialsortperm!(ix, x, 1:5) |
| 100 | + @test Array(ix_ra)[1:5] == ix[1:5] |
| 101 | + |
| 102 | + ix = similar(x, Int) |
| 103 | + ix_ra = Reactant.to_rarray(ix) |
| 104 | + @jit partialsortperm!(ix_ra, x_ra, 3) |
| 105 | + partialsortperm!(ix, x, 3) |
| 106 | + @test @allowscalar(ix_ra[3]) == ix[3] |
| 107 | +end |
10 | 108 |
|
11 | 109 | @testset "argmin / argmax" begin
|
12 | 110 | x = rand(2, 3)
|
|
0 commit comments