Skip to content

Commit ba9dba0

Browse files
authored
view returns a Fill (#84) (#130)
* view returns a Fill (#84) * add tests * size -> axes in broadcasted
1 parent 3cc0c92 commit ba9dba0

File tree

4 files changed

+46
-16
lines changed

4 files changed

+46
-16
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "0.10.2"
3+
version = "0.11"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/FillArrays.jl

+25-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
66
+, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!,
77
any, all, axes, isone, iterate, unique, allunique, permutedims, inv,
88
copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero,
9-
show
9+
show, view
1010

1111
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
1212
norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec
@@ -152,7 +152,9 @@ convert(::Type{Fill}, arr::AbstractArray{T}) where T = Fill{T}(unique_value(arr)
152152
convert(::Type{Fill{T}}, arr::AbstractArray) where T = Fill{T}(unique_value(arr), axes(arr))
153153
convert(::Type{Fill{T,N}}, arr::AbstractArray{<:Any,N}) where {T,N} = Fill{T,N}(unique_value(arr), axes(arr))
154154
convert(::Type{Fill{T,N,Axes}}, arr::AbstractArray{<:Any,N}) where {T,N,Axes} = Fill{T,N,Axes}(unique_value(arr), axes(arr))
155-
convert(::Type{T}, F::T) where T<:Fill = F # ambiguity fix
155+
# ambiguity fix
156+
convert(::Type{Fill}, arr::Fill{T}) where T = Fill{T}(unique_value(arr), axes(arr))
157+
convert(::Type{T}, F::T) where T<:Fill = F
156158

157159

158160

@@ -211,14 +213,14 @@ reshape(parent::AbstractFill, dims::Integer...) = reshape(parent, dims)
211213
reshape(parent::AbstractFill, dims::Union{Int,Colon}...) = reshape(parent, dims)
212214
reshape(parent::AbstractFill, dims::Union{Integer,Colon}...) = reshape(parent, dims)
213215

214-
reshape(parent::AbstractFill, dims::Tuple{Vararg{Union{Integer,Colon}}}) =
216+
reshape(parent::AbstractFill, dims::Tuple{Vararg{Union{Integer,Colon}}}) =
215217
fill_reshape(parent, Base._reshape_uncolon(parent, dims)...)
216-
reshape(parent::AbstractFill, dims::Tuple{Vararg{Union{Int,Colon}}}) =
218+
reshape(parent::AbstractFill, dims::Tuple{Vararg{Union{Int,Colon}}}) =
217219
fill_reshape(parent, Base._reshape_uncolon(parent, dims)...)
218-
reshape(parent::AbstractFill, shp::Tuple{Union{Integer,Base.OneTo}, Vararg{Union{Integer,Base.OneTo}}}) =
219-
reshape(parent, Base.to_shape(shp))
220-
reshape(parent::AbstractFill, dims::Dims) = Base._reshape(parent, dims)
221-
reshape(parent::AbstractFill, dims::Tuple{Integer, Vararg{Integer}}) = Base._reshape(parent, dims)
220+
reshape(parent::AbstractFill, shp::Tuple{Union{Integer,Base.OneTo}, Vararg{Union{Integer,Base.OneTo}}}) =
221+
reshape(parent, Base.to_shape(shp))
222+
reshape(parent::AbstractFill, dims::Dims) = Base._reshape(parent, dims)
223+
reshape(parent::AbstractFill, dims::Tuple{Integer, Vararg{Integer}}) = Base._reshape(parent, dims)
222224
Base._reshape(parent::AbstractFill, dims::Dims) = fill_reshape(parent, dims...)
223225
Base._reshape(parent::AbstractFill, dims::Tuple{Integer,Vararg{Integer}}) = fill_reshape(parent, dims...)
224226
# Resolves ambiguity error with `_reshape(v::AbstractArray{T, 1}, dims::Tuple{Int})`
@@ -344,7 +346,7 @@ for f in (:triu, :triu!, :tril, :tril!)
344346
end
345347

346348

347-
Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::AbstractString) =
349+
Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::AbstractString) =
348350
i == j ? s : Base.replace_with_centered_mark(s)
349351

350352

@@ -378,7 +380,7 @@ end
378380

379381
Eye(n::Integer, m::Integer) = RectDiagonal(Ones(min(n,m)), n, m)
380382
Eye{T}(n::Integer, m::Integer) where T = RectDiagonal{T}(Ones{T}(min(n,m)), n, m)
381-
function Eye{T}((a,b)::NTuple{2,AbstractUnitRange{Int}}) where T
383+
function Eye{T}((a,b)::NTuple{2,AbstractUnitRange{Int}}) where T
382384
ab = length(a)  length(b) ? a : b
383385
RectDiagonal{T}(Ones{T}((ab,)), (a,b))
384386
end
@@ -605,7 +607,7 @@ if VERSION ≥ v"1.5"
605607
Base.array_summary(io::IO, a::Fill{T}, inds::Tuple{Vararg{Base.OneTo}}) where T =
606608
print(io, Base.dims2string(length.(inds)), " Fill{$T}")
607609
Base.array_summary(io::IO, a::Eye{T}, inds::Tuple{Vararg{Base.OneTo}}) where T =
608-
print(io, Base.dims2string(length.(inds)), " Eye{$T}")
610+
print(io, Base.dims2string(length.(inds)), " Eye{$T}")
609611
end
610612

611613
Base.show(io::IO, ::MIME"text/plain", x::Union{Eye,AbstractFill}) = show(io, x)
@@ -617,4 +619,16 @@ Base.show(io::IO, ::MIME"text/plain", x::Union{Eye,AbstractFill}) = show(io, x)
617619
getindex_value(a::LinearAlgebra.AdjOrTrans) = getindex_value(parent(a))
618620
getindex_value(a::SubArray) = getindex_value(parent(a))
619621

622+
623+
##
624+
# view
625+
##
626+
627+
Base.@propagate_inbounds view(A::AbstractFill{<:Any,N}, kr::AbstractArray{Bool,N}) where N = getindex(A, kr)
628+
Base.@propagate_inbounds view(A::AbstractFill{<:Any,1}, kr::AbstractVector{Bool}) = getindex(A, kr)
629+
Base.@propagate_inbounds view(A::AbstractFill{<:Any,N}, I::Vararg{Union{Real, AbstractArray}, N}) where N =
630+
getindex(A, I...)
631+
Base.@propagate_inbounds view(A::AbstractFill{<:Any,N}, I::Vararg{Real, N}) where N =
632+
Base.invoke(view, Tuple{AbstractArray,Vararg{Any,N}}, A, I...)
633+
620634
end # module

src/fillbroadcast.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::A
162162
return broadcasted(*, a, _broadcast_getindex_value(b))
163163
end
164164

165-
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = Fill(op(getindex_value(r),x), size(r))
166-
broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = Fill(op(x, getindex_value(r)), size(r))
167-
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = Fill(op(getindex_value(r),x[]), size(r))
168-
broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = Fill(op(x[], getindex_value(r)), size(r))
165+
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = Fill(op(getindex_value(r),x), axes(r))
166+
broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = Fill(op(x, getindex_value(r)), axes(r))
167+
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = Fill(op(getindex_value(r),x[]), axes(r))
168+
broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = Fill(op(x[], getindex_value(r)), axes(r))

test/runtests.jl

+16
Original file line numberDiff line numberDiff line change
@@ -1145,11 +1145,27 @@ end
11451145
end
11461146

11471147
@testset "FillArray interface" begin
1148+
@testset "SubArray" begin
1149+
a = Fill(2.0,5)
1150+
v = SubArray(a,(1:2,))
1151+
@test FillArrays.getindex_value(v) == FillArrays.unique_value(v) == 2.0
1152+
@test convert(Fill, v) Fill(2.0,2)
1153+
end
1154+
11481155
@testset "views" begin
11491156
a = Fill(2.0,5)
11501157
v = view(a,1:2)
1158+
@test v isa Fill
11511159
@test FillArrays.getindex_value(v) == FillArrays.unique_value(v) == 2.0
11521160
@test convert(Fill, v) Fill(2.0,2)
1161+
@test view(a,1) isa SubArray
1162+
end
1163+
1164+
@testset "view with bool" begin
1165+
a = Fill(2.0,5)
1166+
@test a[[true,false,false,true,false]] view(a,[true,false,false,true,false])
1167+
a = Fill(2.0,2,2)
1168+
@test a[[true false; false true]] view(a, [true false; false true])
11531169
end
11541170

11551171
@testset "adjtrans" begin

0 commit comments

Comments
 (0)