Skip to content

Commit e29f045

Browse files
authored
Specialized dot(u, D::Diagonal{<:Any,<:Union{Ones,Fill}}, v) (#138)
1 parent 708b830 commit e29f045

File tree

4 files changed

+46
-4
lines changed

4 files changed

+46
-4
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "0.11.3"
3+
version = "0.11.4"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/FillArrays.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
99
show, view, in
1010

1111
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
12-
norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec
12+
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec
1313

1414
import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape
1515

src/fillalgebra.jl

+17
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,23 @@ function *(a::Transpose{T, <:AbstractVector{T}}, b::Zeros{T, 1}) where T<:Real
138138
end
139139
*(a::Transpose{T, <:AbstractMatrix{T}}, b::Zeros{T, 1}) where T<:Real = mult_zeros(a, b)
140140

141+
function dot(u::AbstractVector, E::Eye, v::AbstractVector)
142+
length(u) == size(E,1) && length(v) == size(E,2) ||
143+
throw(DimensionMismatch("dot product arguments have dimensions $(length(u))×$(size(E))×$(length(v))"))
144+
dot(u, v)
145+
end
146+
147+
function dot(u::AbstractVector, D::Diagonal{<:Any,<:Fill}, v::AbstractVector)
148+
length(u) == size(D,1) && length(v) == size(D,2) ||
149+
throw(DimensionMismatch("dot product arguments have dimensions $(length(u))×$(size(D))×$(length(v))"))
150+
D.diag.value*dot(u, v)
151+
end
152+
153+
function dot(u::AbstractVector{T}, D::Diagonal{U,<:Zeros}, v::AbstractVector{V}) where {T,U,V}
154+
length(u) == size(D,1) && length(v) == size(D,2) ||
155+
throw(DimensionMismatch("dot product arguments have dimensions $(length(u))×$(size(D))×$(length(v))"))
156+
zero(promote_type(T,U,V))
157+
end
141158

142159
+(a::Zeros) = a
143160
-(a::Zeros) = a

test/runtests.jl

+27-2
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ end
244244
@test Z[:,1] Z[1:5,1] Zeros(5)
245245
@test Z[1,:] Z[1,1:6] Zeros(6)
246246
@test Z[:,:] Z[1:5,1:6] Z[1:5,:] Z[:,1:6] Z
247-
247+
248248
A = Fill(2.0,5,6,7)
249249
Z = Zeros(5,6,7)
250250
@test A[:,1,1] A[1:5,1,1] Fill(2.0,5)
@@ -1098,6 +1098,31 @@ end
10981098
end
10991099
end
11001100

1101+
@testset "dot products" begin
1102+
n = 15
1103+
o = Ones(1:n)
1104+
z = Zeros(1:n)
1105+
D = Diagonal(o)
1106+
Z = Diagonal(z)
1107+
1108+
Random.seed!(5)
1109+
u = rand(n)
1110+
v = rand(n)
1111+
1112+
@test dot(u, D, v) == dot(u, v)
1113+
@test dot(u, 2D, v) == 2dot(u, v)
1114+
@test dot(u, Z, v) == 0
1115+
1116+
@test_throws DimensionMismatch dot(u[1:end-1], D, v)
1117+
@test_throws DimensionMismatch dot(u[1:end-1], D, v[1:end-1])
1118+
1119+
@test_throws DimensionMismatch dot(u, 2D, v[1:end-1])
1120+
@test_throws DimensionMismatch dot(u, 2D, v[1:end-1])
1121+
1122+
@test_throws DimensionMismatch dot(u, Z, v[1:end-1])
1123+
@test_throws DimensionMismatch dot(u, Z, v[1:end-1])
1124+
end
1125+
11011126
if VERSION  v"1.5"
11021127
@testset "print" begin
11031128
@test stringmime("text/plain", Zeros(3)) == "3-element Zeros{Float64}"
@@ -1203,4 +1228,4 @@ end
12031228
@test FillArrays.getindex_value(transpose(a)) == FillArrays.unique_value(transpose(a)) == 2.0
12041229
@test convert(Fill, transpose(a)) Fill(2.0,1,5)
12051230
end
1206-
end
1231+
end

0 commit comments

Comments
 (0)