Skip to content

Minimal interface for index labels #412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ makedocs(;
"sparsearrays.md",
"tuples.md",
"wrapping.md",
"index_labels.md",
]
)

Expand Down
9 changes: 9 additions & 0 deletions docs/src/index_labels.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Index Labels Interface

The following ArrayInterface functions provide support for indices with labels.

```@docs
ArrayInterface.has_index_labels
ArrayInterface.index_labels
```

49 changes: 49 additions & 0 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,55 @@ ensures_sorted(@nospecialize( T::Type{<:AbstractRange})) = true
ensures_sorted(T::Type) = is_forwarding_wrapper(T) ? ensures_sorted(parent_type(T)) : false
ensures_sorted(@nospecialize(x)) = ensures_sorted(typeof(x))

const INDEX_LABELS_EXTENDED_HELP = """
## Extended help

Structures that explicitly provide labels along their indices must define both
`has_index_labels` and `index_labels`. Wrappers that don't change the layout
of their parent data and define `is_forwarding_wrapper` will propagate these methods
freely, but all other wrappers must define these two methods in order to propagate
labelled indices information.

Labeled indices are expected to hold the following properties:
* `length(index_labels(x)) == ndims(x)`
* `map(length, index_labels(x)) == size(x)`
"""

"""
has_index_labels(T::Type) -> Bool

Returns `true` if instances of `T` have labeled indices. Structures overloading this
method are also responsible for defining [`ArrayInterface.index_labels`](@ref).

$INDEX_LABELS_EXTENDED_HELP
"""
function has_index_labels(T::Type)
is_forwarding_wrapper(T) ? has_index_labels(parent_type(T)) : false
end

"""
index_labels(x) -> Tuple{Vararg{Any, ndims(x)}}
index_labels(x, dim) -> itr

Returns a tuple of labels assigned to each axis or a collection of labels corresponding to
each index along `dim` of `x`.

$INDEX_LABELS_EXTENDED_HELP
"""
function index_labels(x::T) where {T}
has_index_labels(T) || _throw_index_labels(T)
is_forwarding_wrapper(T) || _violated_index_label_interface(T)
return index_labels(parent(x))
end
index_labels(x, dim::Integer) = index_labels(x)[Int(dim)]

@noinline function _throw_index_labels(T::DataType)
throw(ArgumentError("Objects of type $T do not support `index_labels`"))
end
@noinline function _violated_index_label_interface(T::DataType)
throw(ArgumentError("`has_index_labels($(T)) == true` but does not have `ArrayInterface.index_labels(::$T)` defined."))
end

## Extensions

import Requires
Expand Down
35 changes: 34 additions & 1 deletion test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,21 @@ using Random
using SparseArrays
using Test

struct LabeledIndicesArray{T,N,P<:AbstractArray{T,N},L} <: AbstractArray{T,N}
parent::P
labels::L

LabeledIndicesArray(p::P, labels::L) where {P,L} = new{eltype(P),ndims(p),P,L}(p, labels)
end
ArrayInterface.is_forwarding_wrapper(::Type{<:LabeledIndicesArray}) = true
Base.parent(x::LabeledIndicesArray) = getfield(x, :parent)
ArrayInterface.parent_type(::Type{T}) where {P,T<:LabeledIndicesArray{<:Any,<:Any,P}} = P
ArrayInterface.index_labels(x::LabeledIndicesArray) = getfield(x, :labels)
ArrayInterface.has_index_labels(T::Type{<:LabeledIndicesArray}) = true
ArrayInterface.is_forwarding_wrapper(::Type{<:LabeledIndicesArray}) = true
Base.size(x::LabeledIndicesArray) = size(parent(x))
Base.@propagate_inbounds Base.getindex(x::LabeledIndicesArray, inds...) = parent(x)[inds...]

# ensure we are correctly parsing these
ArrayInterface.@assume_effects :total foo(x::Bool) = x
ArrayInterface.@assume_effects bar(x::Bool) = x
Expand Down Expand Up @@ -282,4 +297,22 @@ end
end
@test ArrayInterface.ldlt_instance(SymTridiagonal(A' * A)) isa typeof(ldlt(SymTridiagonal(A' * A)))
end
end
end

@testset "index_labels" begin
a = ones(2, 3)
lia = LabeledIndicesArray(a, ([:a, :b], ["x", "y", "z"]))

@test @inferred(ArrayInterface.has_index_labels(typeof(lia)))
@test !@inferred(ArrayInterface.has_index_labels(typeof(a)))

@test @inferred(ArrayInterface.index_labels(lia)) == lia.labels
@test ArrayInterface.index_labels(lia, 1) == lia.labels[1]
@test_throws ArgumentError ArrayInterface.index_labels(a)

# throw errors when interface isn't implemented correctly
struct IllegalLabelledIndices end
ArrayInterface.has_index_labels(::Type{IllegalLabelledIndices}) = true
@test_throws ArgumentError ArrayInterface.index_labels(IllegalLabelledIndices())
end