diff --git a/docs/make.jl b/docs/make.jl index 573fe9fc1..3ade9df04 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -12,6 +12,7 @@ makedocs(; "sparsearrays.md", "tuples.md", "wrapping.md", + "index_labels.md", ] ) diff --git a/docs/src/index_labels.md b/docs/src/index_labels.md new file mode 100644 index 000000000..884d30cef --- /dev/null +++ b/docs/src/index_labels.md @@ -0,0 +1,9 @@ +# Index Labels Interface + +The following ArrayInterface functions provide support for indices with labels. + +```@docs +ArrayInterface.has_index_labels +ArrayInterface.index_labels +``` + diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index ed616a87f..a6ea297f8 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -1030,6 +1030,55 @@ ensures_sorted(@nospecialize( T::Type{<:AbstractRange})) = true ensures_sorted(T::Type) = is_forwarding_wrapper(T) ? ensures_sorted(parent_type(T)) : false ensures_sorted(@nospecialize(x)) = ensures_sorted(typeof(x)) +const INDEX_LABELS_EXTENDED_HELP = """ +## Extended help + +Structures that explicitly provide labels along their indices must define both +`has_index_labels` and `index_labels`. Wrappers that don't change the layout +of their parent data and define `is_forwarding_wrapper` will propagate these methods +freely, but all other wrappers must define these two methods in order to propagate +labelled indices information. + +Labeled indices are expected to hold the following properties: +* `length(index_labels(x)) == ndims(x)` +* `map(length, index_labels(x)) == size(x)` +""" + +""" + has_index_labels(T::Type) -> Bool + +Returns `true` if instances of `T` have labeled indices. Structures overloading this +method are also responsible for defining [`ArrayInterface.index_labels`](@ref). + +$INDEX_LABELS_EXTENDED_HELP +""" +function has_index_labels(T::Type) + is_forwarding_wrapper(T) ? has_index_labels(parent_type(T)) : false +end + +""" + index_labels(x) -> Tuple{Vararg{Any, ndims(x)}} + index_labels(x, dim) -> itr + +Returns a tuple of labels assigned to each axis or a collection of labels corresponding to +each index along `dim` of `x`. + +$INDEX_LABELS_EXTENDED_HELP +""" +function index_labels(x::T) where {T} + has_index_labels(T) || _throw_index_labels(T) + is_forwarding_wrapper(T) || _violated_index_label_interface(T) + return index_labels(parent(x)) +end +index_labels(x, dim::Integer) = index_labels(x)[Int(dim)] + +@noinline function _throw_index_labels(T::DataType) + throw(ArgumentError("Objects of type $T do not support `index_labels`")) +end +@noinline function _violated_index_label_interface(T::DataType) + throw(ArgumentError("`has_index_labels($(T)) == true` but does not have `ArrayInterface.index_labels(::$T)` defined.")) +end + ## Extensions import Requires diff --git a/test/core.jl b/test/core.jl index bd0cd6cf3..53196eb7b 100644 --- a/test/core.jl +++ b/test/core.jl @@ -7,6 +7,21 @@ using Random using SparseArrays using Test +struct LabeledIndicesArray{T,N,P<:AbstractArray{T,N},L} <: AbstractArray{T,N} + parent::P + labels::L + + LabeledIndicesArray(p::P, labels::L) where {P,L} = new{eltype(P),ndims(p),P,L}(p, labels) +end +ArrayInterface.is_forwarding_wrapper(::Type{<:LabeledIndicesArray}) = true +Base.parent(x::LabeledIndicesArray) = getfield(x, :parent) +ArrayInterface.parent_type(::Type{T}) where {P,T<:LabeledIndicesArray{<:Any,<:Any,P}} = P +ArrayInterface.index_labels(x::LabeledIndicesArray) = getfield(x, :labels) +ArrayInterface.has_index_labels(T::Type{<:LabeledIndicesArray}) = true +ArrayInterface.is_forwarding_wrapper(::Type{<:LabeledIndicesArray}) = true +Base.size(x::LabeledIndicesArray) = size(parent(x)) +Base.@propagate_inbounds Base.getindex(x::LabeledIndicesArray, inds...) = parent(x)[inds...] + # ensure we are correctly parsing these ArrayInterface.@assume_effects :total foo(x::Bool) = x ArrayInterface.@assume_effects bar(x::Bool) = x @@ -282,4 +297,22 @@ end end @test ArrayInterface.ldlt_instance(SymTridiagonal(A' * A)) isa typeof(ldlt(SymTridiagonal(A' * A))) end -end \ No newline at end of file +end + +@testset "index_labels" begin + a = ones(2, 3) + lia = LabeledIndicesArray(a, ([:a, :b], ["x", "y", "z"])) + + @test @inferred(ArrayInterface.has_index_labels(typeof(lia))) + @test !@inferred(ArrayInterface.has_index_labels(typeof(a))) + + @test @inferred(ArrayInterface.index_labels(lia)) == lia.labels + @test ArrayInterface.index_labels(lia, 1) == lia.labels[1] + @test_throws ArgumentError ArrayInterface.index_labels(a) + + # throw errors when interface isn't implemented correctly + struct IllegalLabelledIndices end + ArrayInterface.has_index_labels(::Type{IllegalLabelledIndices}) = true + @test_throws ArgumentError ArrayInterface.index_labels(IllegalLabelledIndices()) +end +