Skip to content

Commit 4fd0492

Browse files
authored
feat: sorting and related functions (#529)
* feat: implement sort
* feat: generalize Ops.sort to take in multiple args
* feat: implement perm related functions
* feat: implement partialsort
* feat: implement argmin and argmax
* fix: general support for other kwargs
* feat: keep lazy indexing
* feat: support lt and by by directly emitting sort
* feat: more argmin/argmax support and testing
* fix: always return tuple from sort
* feat: findmin/findmax/findlast/findfirst
* fix: more tests and fixes for find functions
* test: sort and partial sort functions
1 parent b436b48 commit 4fd0492

File tree

7 files changed

+514
-24
lines changed

7 files changed

+514
-24
lines changed

Diff for: src/ConcreteRArray.jl

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ function Base.rtoldefault(::Type{ConcreteRNumber{T}}) where {T}
2828
return ConcreteRNumber(Base.rtoldefault(T))
2929
end
3030

31+
# Strides are derived purely from `size(x)`, starting from a stride of 1 along the first
# dimension.
Base.strides(x::ConcreteRArray) = Base.size_to_strides(1, size(x)...)
32+
3133
# Ensure the device and client are the same as the input
3234
function Base.float(x::ConcreteRNumber{T}) where {T}
3335
client = XLA.client(x.data)

Diff for: src/Ops.jl

+49-20
Original file line numberDiff line numberDiff line change
@@ -956,19 +956,26 @@ function broadcast_in_dim(
956956
end
957957

958958
@noinline function sort(
959-
x::TracedRArray{T,N};
959+
xs::TracedRArray...;
960960
comparator,
961961
dimension=1,
962962
is_stable=false,
963963
location=mlir_stacktrace("sort", @__FILE__, @__LINE__),
964-
) where {T,N}
964+
)
965965
#C4:
966-
@assert 0 < dimension <= ndims(x) "$x invalid dimension"
966+
for x in xs
967+
@assert 0 < dimension <= ndims(x) "$x invalid dimension"
968+
end
967969

968-
(a, b) = (Reactant.ConcreteRNumber(T(0)), Reactant.ConcreteRNumber(T(0)))
970+
sample_inputs = Vector{Reactant.ConcreteRNumber}(undef, length(xs) * 2)
971+
for i in eachindex(xs)
972+
T = Reactant.unwrapped_eltype(xs[i])
973+
sample_inputs[2i - 1] = Reactant.ConcreteRNumber(T(0))
974+
sample_inputs[2i] = Reactant.ConcreteRNumber(T(0))
975+
end
969976
func = Reactant.TracedUtils.make_mlir_fn(
970977
comparator,
971-
(a, b),
978+
(sample_inputs...,),
972979
(),
973980
"comparator";
974981
no_args_in_result=true,
@@ -993,30 +1000,52 @@ end
9931000
dimension = MLIR.IR.Attribute(dimension - 1)
9941001
is_stable = MLIR.IR.Attribute(is_stable)
9951002

996-
res = MLIR.IR.result(
997-
stablehlo.sort(
998-
[x.mlir_data];
999-
result_0=[mlir_type(TracedRArray{T,N}, size(x))],
1000-
dimension,
1001-
is_stable,
1002-
comparator,
1003-
location,
1004-
),
1003+
op = stablehlo.sort(
1004+
[x.mlir_data for x in xs];
1005+
result_0=[mlir_type(typeof(x), size(x)) for x in xs],
1006+
dimension,
1007+
is_stable,
1008+
comparator,
1009+
location,
10051010
)
1006-
return TracedRArray{T,N}((), res, size(x))
1011+
return [
1012+
TracedRArray{Reactant.unwrapped_eltype(xs[i]),ndims(xs[i])}(
1013+
(), MLIR.IR.result(op, i), size(xs[i])
1014+
) for i in eachindex(xs)
1015+
]
10071016
end
10081017

10091018
@noinline function top_k(
1010-
x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__)
1019+
x::TracedRArray{T,N},
1020+
k;
1021+
dimension::Integer=N,
1022+
location=mlir_stacktrace("top_k", @__FILE__, @__LINE__),
10111023
) where {T,N}
1024+
@assert 1 <= dimension <= N
1025+
if dimension != N # chlo.top_k performs the operation along the last dimension
1026+
pdims = collect(Int64, 1:N)
1027+
pdims[dimension] = N
1028+
pdims[N] = dimension
1029+
x = permutedims(x, pdims)
1030+
end
1031+
10121032
rsize = [size(x)[1:(end - 1)]..., k]
10131033
values = mlir_type(TracedRArray{T,N}, rsize)
10141034
indices = mlir_type(TracedRArray{Int32,N}, rsize)
10151035
op = chlo.top_k(x.mlir_data; values, indices, k, location)
1016-
return (;
1017-
values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize),
1018-
indices=TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize),
1019-
)
1036+
indices = add(
1037+
TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize),
1038+
constant(fill(Int32(1), Tuple(rsize))),
1039+
) # return the 1-indexed index
1040+
indices = convert(TracedRArray{Int64,N}, indices) # julia indexes with Int64 generally
1041+
values = TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize)
1042+
1043+
if dimension != N
1044+
values = permutedims(values, invperm(pdims))
1045+
indices = permutedims(indices, invperm(pdims))
1046+
end
1047+
1048+
return (; values, indices)
10201049
end
10211050

10221051
@noinline function iota(

Diff for: src/TracedRArray.jl

+249-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using ..Reactant:
1010
ReactantPrimitive,
1111
WrappedTracedRArray,
1212
AnyTracedRArray,
13+
AnyTracedRVector,
1314
Ops,
1415
MLIR,
1516
ancestor,
@@ -19,10 +20,12 @@ using ..Reactant:
1920
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array
2021

2122
using ReactantCore: ReactantCore
22-
using GPUArraysCore: GPUArraysCore
23+
using GPUArraysCore: GPUArraysCore, @allowscalar
2324

2425
ReactantCore.is_traced(::TracedRArray) = true
2526

27+
# Strides are derived purely from `size(x)`, starting from a stride of 1 along the first
# dimension.
Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...)
28+
2629
function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N}
2730
@assert ndims(x) == N
2831
if x isa TracedRArray
@@ -86,6 +89,17 @@ function scalar_index_to_cartesian(idx::AbstractVector{T}, sz::NTuple{N,Int}) wh
8689
return idxs
8790
end
8891

92+
# Convert a single 1-based linear index `idx` into a tuple of ZERO-based per-dimension
# indices for an array of size `sz`, peeling off one dimension at a time starting from
# the first one.
#
# Arguments:
#   - `idx`: 1-based linear index (any `Number` type `T`; the sizes are converted to `T`
#     before the remainder / integer division).
#   - `sz`: size tuple of the array being indexed.
#
# Returns an `N`-tuple of zero-based indices (callers add 1 to index a Julia array).
# For `N == 0` the empty tuple is returned.
function scalar_index_to_cartesian(idx::T, sz::NTuple{N,Int}) where {T<:Number,N}
    remaining = idx - 1 # switch to zero-based arithmetic
    idxs = ()
    for i in 1:N
        extent = T(sz[i])
        idxs = (idxs..., remaining % extent)
        remaining = remaining ÷ extent
    end
    return idxs
end
102+
89103
function Base.getindex(
90104
a::TracedRArray{T,N}, indices::Union{Int,TracedRNumber{Int}}
91105
) where {T,N}
@@ -509,7 +523,10 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)
509523

510524
args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)
511525

512-
res = TracedUtils.elem_apply(bc.f, args...)
526+
res = TracedUtils.promote_to(
527+
TracedRArray{unwrapped_eltype(dest),ndims(dest)},
528+
TracedUtils.elem_apply(bc.f, args...),
529+
)
513530
TracedUtils.set_mlir_data!(dest, res.mlir_data)
514531
return dest
515532
end
@@ -687,4 +704,234 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs)
687704
return cat(res...; dims)
688705
end
689706

707+
# sort
708+
# Out-of-place sort: copy the input and defer to `sort!`. `alg` and `order` are forwarded
# unchanged together with any remaining keyword arguments.
function Base.sort(x::AnyTracedRArray; alg=missing, order=missing, kwargs...)
    out = copy(x)
    return sort!(out; alg=alg, order=order, kwargs...)
end

# Vector method: same as above, but passes `dims=1` ahead of the caller's keyword
# arguments.
function Base.sort(x::AnyTracedRVector; alg=missing, order=missing, kwargs...)
    out = copy(x)
    return sort!(out; alg=alg, order=order, dims=1, kwargs...)
end
714+
715+
# In-place sort of a traced array along `dims`, implemented by emitting `Ops.sort` and
# writing the result's MLIR data back into `x`.
#
# Keywords:
#   - `dims`: dimension to sort along; may only be omitted for 1-dimensional inputs.
#   - `lt`, `by`, `rev`: combined into the comparator passed to `Ops.sort`.
#   - `alg`, `order`: not supported; anything other than `missing` fails the assertion.
#
# Returns `x`.
function Base.sort!(
    x::AnyTracedRArray;
    dims::Union{Integer,Nothing}=nothing,
    lt=isless,
    by=identity,
    rev::Bool=false,
    alg=missing,
    order=missing,
)
    if dims === nothing
        @assert ndims(x) == 1
        dims = 1
    end

    @assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`"
    @assert order === missing "Reactant doesn't support `order` kwarg for `sort!`"

    # Reverse by swapping the operands (as Base's reverse ordering does) instead of
    # negating the result: `!lt(a, b)` is `true` for equal elements, so it is not a
    # strict ordering.
    comparator = rev ? (a, b) -> lt(by(b), by(a)) : (a, b) -> lt(by(a), by(b))
    res = only(Ops.sort(materialize_traced_array(x); dimension=dims, comparator))
    set_mlir_data!(x, get_mlir_data(res))
    return x
end
737+
738+
# Out-of-place `sortperm`: allocate an `Int` array shaped like `x` and defer to
# `sortperm!`. `alg` and `order` are forwarded unchanged with the remaining keywords.
function Base.sortperm(x::AnyTracedRArray; alg=missing, order=missing, kwargs...)
    perm = similar(x, Int)
    return sortperm!(perm, x; alg=alg, order=order, kwargs...)
end

# Vector method: same as above, but passes `dims=1` ahead of the caller's keyword
# arguments.
function Base.sortperm(x::AnyTracedRVector; alg=missing, order=missing, kwargs...)
    perm = similar(x, Int)
    return sortperm!(perm, x; alg=alg, order=order, dims=1, kwargs...)
end
744+
745+
# Compute the sorting permutation of `x` along `dims` and store it in `ix`.
#
# The linear indices of `x` are sorted alongside `x` itself with a single `Ops.sort`
# call; the comparator only inspects the values (`a`, `b`) and ignores the index
# operands (`i1`, `i2`). The sorted index array is written into `ix`.
#
# Keywords:
#   - `dims`: dimension to sort along; may only be omitted for 1-dimensional inputs.
#   - `lt`, `by`, `rev`: combined into the comparator passed to `Ops.sort`.
#   - `alg`, `order`: not supported; anything other than `missing` fails the assertion.
#
# Returns `ix`.
function Base.sortperm!(
    ix::AnyTracedRArray{Int,N},
    x::AnyTracedRArray{<:Any,N};
    dims::Union{Integer,Nothing}=nothing,
    lt=isless,
    by=identity,
    rev::Bool=false,
    alg=missing,
    order=missing,
) where {N}
    if dims === nothing
        @assert ndims(x) == 1
        dims = 1
    end

    @assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`"
    @assert order === missing "Reactant doesn't support `order` kwarg for `sortperm!`"

    # Reverse by swapping the operands (as Base's reverse ordering does) instead of
    # negating the result: `!lt(a, b)` is `true` for equal elements, so it is not a
    # strict ordering.
    comparator = if rev
        (a, b, i1, i2) -> lt(by(b), by(a))
    else
        (a, b, i1, i2) -> lt(by(a), by(b))
    end
    idxs = Ops.constant(collect(LinearIndices(x)))
    _, res = Ops.sort(materialize_traced_array(x), idxs; dimension=dims, comparator)
    set_mlir_data!(ix, get_mlir_data(res))
    return ix
end
770+
771+
# Return the element(s) of `x` selected by `k`, taken from the values produced by
# `overloaded_partialsort`. `k` is shifted so that `minimum(k)` maps to position 1 of
# those values.
# NOTE(review): this relies on the first returned value having rank `minimum(k)` —
# confirm for `rev=true` and for custom `lt` / `by`.
function Base.partialsort(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...)
    sorted_values = first(overloaded_partialsort(x, k; kwargs...))
    offsets = k .- minimum(k) .+ 1
    if offsets isa Integer
        return @allowscalar(sorted_values[offsets])
    end
    return view(sorted_values, offsets)
end
777+
778+
# Write the element(s) selected by `k` into positions `k` of `x`, using the values
# produced by `overloaded_partialsort`. Returns the scalar for an integer `k`, otherwise
# a view of `x` at `k`.
# NOTE(review): this relies on the first returned value having rank `minimum(k)` —
# confirm for `rev=true` and for custom `lt` / `by`.
function Base.partialsort!(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...)
    sorted_values = first(overloaded_partialsort(x, k; kwargs...))
    offsets = k .- minimum(k) .+ 1
    selected = @allowscalar(sorted_values[offsets])
    @allowscalar setindex!(x, selected, k)
    if k isa Integer
        return selected
    end
    return view(x, k)
end
786+
787+
# Return the index/indices into `x` of the element(s) selected by `k`, taken from the
# index array produced by `overloaded_partialsort`. `k` is shifted so that `minimum(k)`
# maps to position 1 of that array.
# NOTE(review): this relies on the first returned index having rank `minimum(k)` —
# confirm for `rev=true` and for custom `lt` / `by`.
function Base.partialsortperm(
    x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...
)
    sorted_idxs = last(overloaded_partialsort(x, k; kwargs...))
    offsets = k .- minimum(k) .+ 1
    if offsets isa Integer
        return @allowscalar(sorted_idxs[offsets])
    end
    return view(sorted_idxs, offsets)
end
795+
796+
# Write the index/indices of the element(s) selected by `k` into positions `k` of `ix`,
# using the index array produced by `overloaded_partialsort`. Returns the scalar for an
# integer `k`, otherwise a view of `ix` at `k`.
# NOTE(review): this relies on the first returned index having rank `minimum(k)` —
# confirm for `rev=true` and for custom `lt` / `by`.
function Base.partialsortperm!(
    ix::AnyTracedRVector{Int},
    x::AnyTracedRVector,
    k::Union{Integer,OrdinalRange};
    kwargs...,
)
    sorted_idxs = last(overloaded_partialsort(x, k; kwargs...))
    offsets = k .- minimum(k) .+ 1
    selected = @allowscalar(sorted_idxs[offsets])
    @allowscalar setindex!(ix, selected, k)
    if k isa Integer
        return selected
    end
    return view(ix, k)
end
809+
810+
# Shared implementation behind `partialsort`, `partialsort!`, `partialsortperm` and
# `partialsortperm!`.
#
# Returns `(values, indices)` where position 1 of both arrays corresponds to rank
# `minimum(k)` in the requested ordering, so callers can index them with
# `k .- minimum(k) .+ 1`. The arrays may extend beyond rank `maximum(k)`.
#
# Keywords:
#   - `by`, `lt`: when either differs from the default, a full `Ops.sort` is emitted
#     (values and linear indices sorted together).
#   - `rev`: reverse the ordering.
#
# With the default `lt` / `by`, `Ops.top_k` is used instead of a full sort.
function overloaded_partialsort(
    x::AnyTracedRVector,
    k::Union{Integer,OrdinalRange};
    by=identity,
    rev::Bool=false,
    lt=isless,
)
    kmin, kmax = minimum(k), maximum(k)

    if lt !== isless || by !== identity
        # Reverse by swapping the operands (as Base's reverse ordering does); negating
        # `lt` would report `true` for equal elements.
        comparator = if rev
            (a, b, i1, i2) -> lt(by(b), by(a))
        else
            (a, b, i1, i2) -> lt(by(a), by(b))
        end
        idxs = Ops.constant(collect(LinearIndices(x)))
        sorted_x, sorted_idxs = Ops.sort(
            materialize_traced_array(x), idxs; dimension=1, comparator
        )
        # Only ranks kmin:kmax are needed; starting at kmin keeps position 1 == rank kmin.
        return sorted_x[kmin:kmax], sorted_idxs[kmin:kmax]
    end

    # XXX: If `maxk` is beyond a threshold should we emit a sort directly?
    if rev
        # The first `kmax` results of top_k are the ranks 1:kmax in reversed order; drop
        # the leading `kmin - 1` so that position 1 == rank kmin.
        (; values, indices) = Ops.top_k(materialize_traced_array(x), kmax)
        if kmin > 1
            values = values[kmin:kmax]
            indices = indices[kmin:kmax]
        end
        return values, indices
    end

    # Forward order: take the `length(x) - kmin + 1` top_k results and flip them, which
    # leaves rank kmin of the forward ordering in position 1.
    (; values, indices) = Ops.top_k(materialize_traced_array(x), length(x) - kmin + 1)
    values = Ops.reverse(values; dimensions=[1])
    indices = Ops.reverse(indices; dimensions=[1])
    return values, indices
end
837+
838+
# arg* functions
839+
# Return the element of `x` at the position given by `argmin(f.(x))`. The linear index
# is converted to zero-based per-dimension indices and shifted back to 1-based before
# indexing `x`.
function Base.argmin(f::F, x::AnyTracedRArray) where {F}
    linear = argmin(f.(x))
    cartesian = scalar_index_to_cartesian(linear, size(x)) .+ 1
    return @allowscalar x[cartesian...]
end
843+
844+
# Return the element of `x` at the position given by `argmax(f.(x))`. The linear index
# is converted to zero-based per-dimension indices and shifted back to 1-based before
# indexing `x`.
function Base.argmax(f::F, x::AnyTracedRArray) where {F}
    linear = argmax(f.(x))
    cartesian = scalar_index_to_cartesian(linear, size(x)) .+ 1
    return @allowscalar x[cartesian...]
end
848+
849+
# `argmin` / `argmax` are the index component (second entry) of `findmin` / `findmax`.
function Base.argmin(x::AnyTracedRArray; kwargs...)
    return last(findmin(identity, x; kwargs...))
end

function Base.argmax(x::AnyTracedRArray; kwargs...)
    return last(findmax(identity, x; kwargs...))
end
851+
852+
# find* functions
853+
# Predicate-free forms: forward to the predicate versions with `identity`.
function Base.findfirst(x::AnyTracedRArray)
    return findfirst(identity, x)
end

function Base.findlast(x::AnyTracedRArray)
    return findlast(identity, x)
end
855+
856+
# Apply `f` elementwise, flatten the result, and return the index reported by a
# `top_k` of size 1 over it. The result is a linear index into `x`.
function Base.findfirst(f::Function, x::AnyTracedRArray)
    flags = materialize_traced_array(vec(f.(x)))
    top = Ops.top_k(flags, 1)
    return @allowscalar top.indices[1]
end
861+
862+
# Same approach as `findfirst`, applied to the reversed flattened `f.(x)`; the reported
# position is then mapped back to an index counted from the front of `x`.
function Base.findlast(f::Function, x::AnyTracedRArray)
    flags = materialize_traced_array(vec(f.(x)))
    reversed_flags = Ops.reverse(flags; dimensions=[1])
    top = Ops.top_k(reversed_flags, 1)
    position_from_end = @allowscalar(top.indices[1])
    return length(x) - position_from_end + 1
end
867+
868+
# Predicate-free `findmin`: forward to the `findmin(f, x; dims)` method with `identity`.
# Vectors always reduce along dimension 1.
function Base.findmin(x::AnyTracedRVector)
    return findmin(identity, x; dims=1)
end

Base.findmin(x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) =
    findmin(identity, x; dims=dims)
872+
873+
# Predicate-free `findmax`: forward to the `findmax(f, x; dims)` method with `identity`.
# Vectors always reduce along dimension 1.
function Base.findmax(x::AnyTracedRVector)
    return findmax(identity, x; dims=1)
end

Base.findmax(x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) =
    findmax(identity, x; dims=dims)
877+
878+
## To avoid scalar indexing and constructing an array of tuples, we return the linear index
## instead of the cartesian index

# Turn the 1-based positions along `dims` reported by `Ops.top_k(...; dimension=dims)`
# into 1-based linear indices into `x`.
#
# For every dimension a per-element offset is built: `Ops.iota` along that dimension for
# all dimensions except `dims`, and `indices - 1` for `dims` itself. The offsets are then
# combined with the strides of `x`, starting from a constant 1.
# NOTE(review): the formula assumes `Ops.iota` counts from 0 — confirm.
function _topk_linear_indices(x::AnyTracedRArray, indices, dims::Integer)
    strds = strides(x)
    offsets = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)]
    offsets[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices))))
    linear_indices = Ops.constant(fill(Int64(1), size(indices)))
    for d in eachindex(offsets)
        linear_indices = Ops.add(
            linear_indices,
            Ops.multiply(offsets[d], Ops.constant(fill(Int64(strds[d]), size(offsets[d])))),
        )
    end
    return linear_indices
end

# `findmin(f, x; dims)`: minimum of `f.(x)` along `dims` together with its linear index.
#
# Without `dims`, a multi-dimensional `x` is flattened with `vec` first. The minimum is
# obtained by negating `f.(x)`, taking a `top_k` of size 1, and negating the selected
# values back.
#
# Returns a `(value, index)` pair of scalars for 1-dimensional `x`, otherwise a pair of
# arrays `(values, linear_indices)`.
function Base.findmin(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing)
    if dims === nothing
        ndims(x) == 1 || return findmin(f, vec(x); dims=1)
        dims = 1
    end

    fx = Ops.negate(materialize_traced_array(f.(x)))
    (; values, indices) = Ops.top_k(fx, 1; dimension=dims)
    linear_indices = _topk_linear_indices(x, indices, dims)

    values = Ops.negate(values)
    ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1])
    return (values, linear_indices)
end

# `findmax(f, x; dims)`: maximum of `f.(x)` along `dims` together with its linear index.
#
# Without `dims`, a multi-dimensional `x` is flattened with `vec` first. The maximum is
# obtained from a `top_k` of size 1 over `f.(x)`.
#
# Returns a `(value, index)` pair of scalars for 1-dimensional `x`, otherwise a pair of
# arrays `(values, linear_indices)`.
function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing)
    if dims === nothing
        ndims(x) == 1 || return findmax(f, vec(x); dims=1)
        dims = 1
    end

    fx = materialize_traced_array(f.(x))
    (; values, indices) = Ops.top_k(fx, 1; dimension=dims)
    linear_indices = _topk_linear_indices(x, indices, dims)

    ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1])
    return (values, linear_indices)
end
690937
end

0 commit comments

Comments
 (0)