Skip to content

Commit ac13645

Browse files
Propagate AxisArray copy / view down to taking copies / views of its axes as well.
1 parent 5b1fd0e commit ac13645

File tree

1 file changed

+35
-20
lines changed

1 file changed

+35
-20
lines changed

src/indexing.jl

+35-20
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,16 @@ Base.eachindex(A::AxisArray) = eachindex(A.data)
4848
This internal function determines the new set of axes that are constructed upon
4949
indexing with I.
5050
"""
51-
reaxis(A::AxisArray, I::Idx...) = _reaxis(make_axes_match(axes(A), I), I)
51+
reaxis(A::AxisArray, copy::Val, I::Idx...) = _reaxis(make_axes_match(axes(A), I), copy, I)
5252
# Linear indexing
53-
reaxis(A::AxisArray{<:Any,1}, I::AbstractArray{Int}) = _new_axes(A.axes[1], I)
54-
reaxis(A::AxisArray, I::AbstractArray{Int}) = default_axes(I)
55-
reaxis(A::AxisArray{<:Any,1}, I::Real) = ()
56-
reaxis(A::AxisArray, I::Real) = ()
57-
reaxis(A::AxisArray{<:Any,1}, I::Colon) = _new_axes(A.axes[1], Base.axes(A, 1))
58-
reaxis(A::AxisArray, I::Colon) = default_axes(Base.OneTo(length(A)))
59-
reaxis(A::AxisArray{<:Any,1}, I::AbstractArray{Bool}) = _new_axes(A.axes[1], findall(I))
60-
reaxis(A::AxisArray, I::AbstractArray{Bool}) = default_axes(findall(I))
53+
reaxis(A::AxisArray{<:Any,1}, copy::Val, I::AbstractArray{Int}) = _new_axes(A.axes[1], copy, I)
54+
reaxis(A::AxisArray, copy::Val, I::AbstractArray{Int}) = default_axes(I)
55+
reaxis(A::AxisArray{<:Any,1}, copy::Val, I::Real) = ()
56+
reaxis(A::AxisArray, copy::Val, I::Real) = ()
57+
reaxis(A::AxisArray{<:Any,1}, copy::Val, I::Colon) = _new_axes(A.axes[1], copy, Base.axes(A, 1))
58+
reaxis(A::AxisArray, copy::Val, I::Colon) = default_axes(Base.OneTo(length(A)))
59+
reaxis(A::AxisArray{<:Any,1}, copy::Val, I::AbstractArray{Bool}) = _new_axes(A.axes[1], copy, findall(I))
60+
reaxis(A::AxisArray, copy::Val, I::AbstractArray{Bool}) = default_axes(findall(I))
6161

6262
# Ensure the number of axes matches the number of indexing dimensions
6363
@inline function make_axes_match(axs, idxs)
@@ -66,28 +66,43 @@ reaxis(A::AxisArray, I::AbstractArray{Bool}) = default_axes(findall(I))
6666
end
6767

6868
# Now we can reaxis without worrying about mismatched axes/indices
69-
@inline _reaxis(axs::Tuple{}, idxs::Tuple{}) = ()
69+
@inline _reaxis(axs::Tuple{}, copy::Val, idxs::Tuple{}) = ()
7070
# Scalars are dropped
7171
const ScalarIndex = Union{Real, AbstractArray{<:Any, 0}}
72-
@inline _reaxis(axs::Tuple, idxs::Tuple{ScalarIndex, Vararg{Any}}) = _reaxis(tail(axs), tail(idxs))
72+
@inline _reaxis(axs::Tuple, copy::Val, idxs::Tuple{ScalarIndex, Vararg{Any}}) = _reaxis(tail(axs), copy, tail(idxs))
7373
# Colon passes straight through
74-
@inline _reaxis(axs::Tuple, idxs::Tuple{Colon, Vararg{Any}}) = (axs[1], _reaxis(tail(axs), tail(idxs))...)
74+
@inline _reaxis(axs::Tuple, copy::Val, idxs::Tuple{Colon, Vararg{Any}}) = (axs[1], _reaxis(tail(axs), copy, tail(idxs))...)
7575
# But arrays can add or change dimensions and accompanying axis names
76-
@inline _reaxis(axs::Tuple, idxs::Tuple{AbstractArray, Vararg{Any}}) =
77-
(_new_axes(axs[1], idxs[1])..., _reaxis(tail(axs), tail(idxs))...)
76+
@inline _reaxis(axs::Tuple, copy::Val, idxs::Tuple{AbstractArray, Vararg{Any}}) =
77+
(_new_axes(axs[1], copy, idxs[1])..., _reaxis(tail(axs), copy, tail(idxs))...)
7878

7979
# Vectors simply create new axes with the same name; just subsetted by their value
80-
@inline _new_axes(ax::Axis{name}, idx::AbstractVector) where {name} = (Axis{name}(ax.val[idx]),)
80+
@inline _new_axes(ax::Axis{name}, copy::Val{true}, idx::AbstractVector) where {name} = (Axis{name}(ax.val[idx]),)
81+
@inline _new_axes(ax::Axis{name}, copy::Val{false}, idx::AbstractVector) where {name} = (Axis{name}(view(ax.val, idx)),)
82+
83+
# @inline _new_axes(ax::Axis{name}, copy::Val{false}, idx::AxisArray{T,1,D,Ax}) where {Ax, D, T, name} = _new_axes(ax, copy, idx)
84+
8185
# Arrays create multiple axes with _N appended to the axis name containing their indices
82-
@generated function _new_axes(ax::Axis{name}, idx::AbstractArray{<:Any,N}) where {name,N}
86+
@generated function _new_axes(ax::Axis{name}, copy::Val, idx::AbstractArray{<:Any,N}) where {name, N}
8387
newaxes = Expr(:tuple)
8488
for i=1:N
8589
push!(newaxes.args, :($(Axis{Symbol(name, "_", i)})(Base.axes(idx, $i))))
8690
end
8791
newaxes
8892
end
93+
8994
# And indexing with an AxisArray joins the name and overrides the values
90-
@generated function _new_axes(ax::Axis{name}, idx::AxisArray{<:Any, N}) where {name,N}
95+
@generated function _new_axes(ax::Axis{name}, copy::Val{true}, idx::AxisArray{<:Any, N}) where {name,N}
96+
newaxes = Expr(:tuple)
97+
idxnames = axisnames(idx)
98+
for i=1:N
99+
push!(newaxes.args, :($(Axis{Symbol(name, "_", idxnames[i])})(idx.axes[$i].val)))
100+
end
101+
newaxes
102+
end
103+
104+
# TODO: this is duplicated from the above
105+
@generated function _new_axes(ax::Axis{name}, copy::Val{false}, idx::AxisArray{<:Any, N}) where {name,N}
91106
newaxes = Expr(:tuple)
92107
idxnames = axisnames(idx)
93108
for i=1:N
@@ -97,19 +112,19 @@ end
97112
end
98113

99114
@propagate_inbounds function Base.getindex(A::AxisArray, idxs::Idx...)
100-
AxisArray(A.data[idxs...], reaxis(A, idxs...))
115+
AxisArray(A.data[idxs...], reaxis(A, Val(true), idxs...))
101116
end
102117

103118
# To resolve ambiguities, we need several definitions
104119
using Base: AbstractCartesianIndex
105-
@propagate_inbounds Base.view(A::AxisArray, idxs::Idx...) = AxisArray(view(A.data, idxs...), reaxis(A, idxs...))
120+
@propagate_inbounds Base.view(A::AxisArray, idxs::Idx...) = AxisArray(view(A.data, idxs...), reaxis(A, Val(false), idxs...))
106121

107122
# Setindex is so much simpler. Just assign it to the data:
108123
@propagate_inbounds Base.setindex!(A::AxisArray, v, idxs::Idx...) = (A.data[idxs...] = v)
109124

110125
# Logical indexing
111126
@propagate_inbounds function Base.getindex(A::AxisArray, idx::AbstractArray{Bool})
112-
AxisArray(A.data[idx], reaxis(A, idx))
127+
AxisArray(A.data[idx], reaxis(A, Val(true), idx))
113128
end
114129
@propagate_inbounds Base.setindex!(A::AxisArray, v, idx::AbstractArray{Bool}) = (A.data[idx] = v)
115130

0 commit comments

Comments
 (0)