Skip to content

Commit e16bbe4

Browse files
authored
New calculus rules (#10)
* `Ax_mul_Bx` --> Generalizes `NonLinearCompose` * `Axt_mul_Bx` * `Ax_mul_Bxt` * `HadamardProd` --> Generalizes `Hadamard` `Hadamard` & `NonLinearCompose` will be deprecated in future version of AbstractOperators.
1 parent 8e0fbd8 commit e16bbe4

12 files changed

+916
-168
lines changed

docs/src/calculus.md

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ DCAT
1212

1313
```@docs
1414
Compose
15-
NonLinearCompose
16-
Hadamard
15+
HadamardProd
16+
Ax_mul_Bx
17+
Axt_mul_Bx
18+
Ax_mul_Bxt
1719
```
1820

1921
## Transformations

src/AbstractOperators.jl

+4
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@ include("calculus/Sum.jl")
5858
include("calculus/AffineAdd.jl")
5959
include("calculus/Jacobian.jl")
6060
include("calculus/NonLinearCompose.jl")
61+
include("calculus/Axt_mul_Bx.jl")
62+
include("calculus/Ax_mul_Bxt.jl")
63+
include("calculus/Ax_mul_Bx.jl")
6164
include("calculus/Hadamard.jl")
65+
include("calculus/HadamardProd.jl")
6266

6367
# Non-Linear operators
6468
include("nonlinearoperators/Pow.jl")

src/calculus/Ax_mul_Bx.jl

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#Ax_mul_Bx
2+
3+
export Ax_mul_Bx
4+
5+
"""
6+
`Ax_mul_Bx(A::AbstractOperator,B::AbstractOperator)`
7+
8+
Create an operator `P` such that:
9+
10+
`P*x == (Ax)*(Bx)`
11+
12+
# Example
13+
14+
```julia
15+
julia> A,B = randn(4,4),randn(4,4);
16+
17+
julia> P = Ax_mul_Bx(MatrixOp(A,4),MatrixOp(B,4))
18+
▒*▒ ℝ^4 -> ℝ^(4, 4)
19+
20+
julia> X = randn(4,4);
21+
22+
julia> P*X == (A*X)*(B*X)
23+
true
24+
25+
```
26+
"""
27+
struct Ax_mul_Bx{
28+
L1 <: AbstractOperator,
29+
L2 <: AbstractOperator,
30+
C <: AbstractArray,
31+
D <: AbstractArray,
32+
} <: NonLinearOperator
33+
A::L1
34+
B::L2
35+
bufA::C
36+
bufB::C
37+
bufC::C
38+
bufD::D
39+
function Ax_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
40+
if ndims(A,1) != 2 || size(A,2) != size(B,2) || size(A,1)[2] != size(B,1)[1]
41+
throw(DimensionMismatch("Cannot compose operators"))
42+
end
43+
new{L1,L2,C,D}(A,B,bufA,bufB,bufC,bufD)
44+
end
45+
end
46+
47+
struct Ax_mul_BxJac{
48+
L1 <: AbstractOperator,
49+
L2 <: AbstractOperator,
50+
C <: AbstractArray,
51+
D <: AbstractArray,
52+
} <: LinearOperator
53+
A::L1
54+
B::L2
55+
bufA::C
56+
bufB::C
57+
bufC::C
58+
bufD::D
59+
end
60+
61+
# Constructors
62+
function Ax_mul_Bx(A::AbstractOperator,B::AbstractOperator)
63+
s,t = size(A,1), codomainType(A)
64+
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
65+
s,t = size(B,1), codomainType(B)
66+
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
67+
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
68+
s,t = size(A,2), domainType(A)
69+
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
70+
Ax_mul_Bx(A,B,bufA,bufB,bufC,bufD)
71+
end
72+
73+
# Jacobian
74+
function Jacobian(P::Ax_mul_Bx{L1,L2,C,D}, x::AbstractArray) where {L1,L2,C,D}
75+
JA, JB = Jacobian(P.A, x), Jacobian(P.B, x)
76+
Ax_mul_BxJac{typeof(JA),typeof(JB),C,D}(JA,JB,P.bufA,P.bufB,P.bufC,P.bufD)
77+
end
78+
79+
# Mappings
80+
function mul!(y, P::Ax_mul_Bx{L1,L2,C,D}, b) where {L1,L2,C,D}
81+
mul!(P.bufA,P.A,b)
82+
mul!(P.bufB,P.B,b)
83+
mul!(y,P.bufA, P.bufB)
84+
end
85+
86+
function mul!(y, J::AdjointOperator{Ax_mul_BxJac{L1,L2,C,D}}, b) where {L1,L2,C,D}
87+
#y .= J.A.B' * ( J.A.bufA'*b ) + J.A.A' * ( b*J.A.bufB' )
88+
mul!(J.A.bufC, J.A.bufA', b)
89+
mul!(y, J.A.B', J.A.bufC)
90+
mul!(J.A.bufA, b, J.A.bufB')
91+
mul!(J.A.bufD, J.A.A', J.A.bufA)
92+
y .+= J.A.bufD
93+
return y
94+
end
95+
96+
size(P::Union{Ax_mul_Bx,Ax_mul_BxJac}) = ((size(P.A,1)[1],size(P.B,1)[2]),size(P.A,2))
97+
98+
fun_name(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)
99+
100+
domainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = domainType(L.A)
101+
codomainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = codomainType(L.A)
102+
103+
# utils
104+
function permute(P::Ax_mul_Bx{L1,L2,C,D},
105+
p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition}
106+
Ax_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
107+
end
108+
109+
remove_displacement(P::Ax_mul_Bx) =
110+
Ax_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)

src/calculus/Ax_mul_Bxt.jl

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#Ax_mul_Bxt
2+
3+
export Ax_mul_Bxt
4+
5+
"""
6+
`Ax_mul_Bxt(A::AbstractOperator,B::AbstractOperator)`
7+
8+
Create an operator `P` such that:
9+
10+
`P == (Ax)*(Bx)'`
11+
12+
# Example: Matrix multiplication
13+
14+
```julia
15+
julia> A,B = randn(4,4),randn(4,4);
16+
17+
julia> P = Ax_mul_Bxt(MatrixOp(A),MatrixOp(B))
18+
▒*▒ ℝ^4 -> ℝ^(4, 4)
19+
20+
julia> x = randn(4);
21+
22+
julia> P*x == (A*x)*(B*x)'
23+
true
24+
25+
```
26+
"""
27+
struct Ax_mul_Bxt{
28+
L1 <: AbstractOperator,
29+
L2 <: AbstractOperator,
30+
C <: AbstractArray,
31+
D <: AbstractArray,
32+
} <: NonLinearOperator
33+
A::L1
34+
B::L2
35+
bufA::C
36+
bufB::C
37+
bufC::C
38+
bufD::D
39+
function Ax_mul_Bxt(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
40+
if ndims(A,1) == 1
41+
if size(A) != size(B)
42+
throw(DimensionMismatch("Cannot compose operators"))
43+
end
44+
elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2)
45+
if size(A,1)[2] != size(B,1)[2]
46+
throw(DimensionMismatch("Cannot compose operators"))
47+
end
48+
else
49+
throw(DimensionMismatch("Cannot compose operators"))
50+
end
51+
new{L1,L2,C,D}(A,B,bufA,bufB,bufC,bufD)
52+
end
53+
end
54+
55+
struct Ax_mul_BxtJac{
56+
L1 <: AbstractOperator,
57+
L2 <: AbstractOperator,
58+
C <: AbstractArray,
59+
D <: AbstractArray,
60+
} <: LinearOperator
61+
A::L1
62+
B::L2
63+
bufA::C
64+
bufB::C
65+
bufC::C
66+
bufD::D
67+
end
68+
69+
# Constructors
70+
function Ax_mul_Bxt(A::AbstractOperator,B::AbstractOperator)
71+
s,t = size(A,1), codomainType(A)
72+
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
73+
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
74+
s,t = size(B,1), codomainType(B)
75+
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
76+
s,t = size(A,2), domainType(A)
77+
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
78+
Ax_mul_Bxt(A,B,bufA,bufB,bufC,bufD)
79+
end
80+
81+
# Jacobian
82+
function Jacobian(P::Ax_mul_Bxt{L1,L2,C,D}, x::AbstractArray) where {L1,L2,C,D}
83+
JA, JB = Jacobian(P.A, x), Jacobian(P.B, x)
84+
Ax_mul_BxtJac{typeof(JA),typeof(JB),C,D}(JA,JB,P.bufA,P.bufB,P.bufC,P.bufD)
85+
end
86+
87+
# Mappings
88+
function mul!(y, P::Ax_mul_Bxt{L1,L2,C,D}, b) where {L1,L2,C,D}
89+
mul!(P.bufA,P.A,b)
90+
mul!(P.bufB,P.B,b)
91+
mul!(y,P.bufA, P.bufB')
92+
end
93+
94+
function mul!(y, J::AdjointOperator{Ax_mul_BxtJac{L1,L2,C,D}}, b) where {L1,L2,C,D}
95+
#y .= J.A.A'*(b*(J.A.bufB)) + J.A.B'*(b'*(J.A.bufA))
96+
mul!(J.A.bufC, b, J.A.bufB)
97+
mul!(y, J.A.A', J.A.bufC)
98+
mul!(J.A.bufB, b', J.A.bufA)
99+
mul!(J.A.bufD, J.A.B', J.A.bufB)
100+
y .+= J.A.bufD
101+
return y
102+
end
103+
104+
size(P::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = ((size(P.A,1)[1],size(P.B,1)[1]),size(P.A,2))
105+
106+
fun_name(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = fun_name(L.A)*"*"*fun_name(L.B)
107+
108+
domainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = domainType(L.A)
109+
codomainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = codomainType(L.A)
110+
111+
# utils
112+
function permute(P::Ax_mul_Bxt{L1,L2,C,D},
113+
p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition}
114+
Ax_mul_Bxt(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
115+
end
116+
117+
remove_displacement(P::Ax_mul_Bxt) =
118+
Ax_mul_Bxt(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)

src/calculus/Axt_mul_Bx.jl

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#Axt_mul_Bx
2+
3+
export Axt_mul_Bx
4+
5+
"""
6+
`Axt_mul_Bx(A::AbstractOperator,B::AbstractOperator)`
7+
8+
Create an operator `P` such that:
9+
10+
`P*x == (Ax)'*(Bx)`
11+
12+
# Example
13+
14+
```julia
15+
julia> A,B = randn(4,4),randn(4,4);
16+
17+
julia> P = Axt_mul_Bx(MatrixOp(A),MatrixOp(B))
18+
▒*▒ ℝ^4 -> ℝ^1
19+
20+
julia> x = randn(4);
21+
22+
julia> P*x == [(A*x)'*(B*x)]
23+
true
24+
25+
```
26+
"""
27+
struct Axt_mul_Bx{N,
28+
L1 <: AbstractOperator,
29+
L2 <: AbstractOperator,
30+
C <: AbstractArray,
31+
D <: AbstractArray,
32+
} <: NonLinearOperator
33+
A::L1
34+
B::L2
35+
bufA::C
36+
bufB::C
37+
bufC::C
38+
bufD::D
39+
function Axt_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
40+
if ndims(A,1) == 1
41+
if size(A) != size(B)
42+
throw(DimensionMismatch("Cannot compose operators"))
43+
end
44+
elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2)
45+
if size(A,1)[1] != size(B,1)[1]
46+
throw(DimensionMismatch("Cannot compose operators"))
47+
end
48+
else
49+
throw(DimensionMismatch("Cannot compose operators"))
50+
end
51+
N = ndims(A,1)
52+
new{N,L1,L2,C,D}(A,B,bufA,bufB,bufC,bufD)
53+
end
54+
end
55+
56+
struct Axt_mul_BxJac{N,
57+
L1 <: AbstractOperator,
58+
L2 <: AbstractOperator,
59+
C <: AbstractArray,
60+
D <: AbstractArray,
61+
} <: LinearOperator
62+
A::L1
63+
B::L2
64+
bufA::C
65+
bufB::C
66+
bufC::C
67+
bufD::D
68+
end
69+
70+
# Constructors
71+
function Axt_mul_Bx(A::AbstractOperator,B::AbstractOperator)
72+
s,t = size(A,1), codomainType(A)
73+
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
74+
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
75+
s,t = size(B,1), codomainType(B)
76+
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
77+
s,t = size(A,2), domainType(A)
78+
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
79+
Axt_mul_Bx(A,B,bufA,bufB,bufC,bufD)
80+
end
81+
82+
# Jacobian
83+
function Jacobian(P::Axt_mul_Bx{N,L1,L2,C,D}, x::AbstractArray) where {N,L1,L2,C,D}
84+
JA, JB = Jacobian(P.A, x), Jacobian(P.B, x)
85+
Axt_mul_BxJac{N,typeof(JA),typeof(JB),C,D}(JA,JB,P.bufA,P.bufB,P.bufC,P.bufD)
86+
end
87+
88+
# Mappings
89+
# N == 1 input is a vector
90+
function mul!(y, P::Axt_mul_Bx{1,L1,L2,C,D}, b) where {L1,L2,C,D}
91+
mul!(P.bufA,P.A,b)
92+
mul!(P.bufB,P.B,b)
93+
y[1] = dot(P.bufA,P.bufB)
94+
end
95+
96+
function mul!(y, J::AdjointOperator{Axt_mul_BxJac{1,L1,L2,C,D}}, b) where {L1,L2,C,D}
97+
#y .= conj(J.A.A'*J.A.bufB+J.A.B'*J.A.bufA).*b[1]
98+
mul!(y, J.A.A', J.A.bufB)
99+
mul!(J.A.bufD, J.A.B', J.A.bufA)
100+
y .= conj.( y .+ J.A.bufD ) .* b[1]
101+
return y
102+
end
103+
104+
# N == 2 input is a matrix
105+
function mul!(y, P::Axt_mul_Bx{2,L1,L2,C,D}, b) where {L1,L2,C,D}
106+
mul!(P.bufA,P.A,b)
107+
mul!(P.bufB,P.B,b)
108+
mul!(y,P.bufA',P.bufB)
109+
return y
110+
end
111+
112+
function mul!(y, J::AdjointOperator{Axt_mul_BxJac{2,L1,L2,C,D}}, b) where {L1,L2,C,D}
113+
# y .= J.A.A'*((J.A.bufB)*b') + J.A.B'*((J.A.bufA)*b)
114+
mul!(J.A.bufC, J.A.bufB, b')
115+
mul!(y, J.A.A', J.A.bufC)
116+
mul!(J.A.bufB, J.A.bufA, b)
117+
mul!(J.A.bufD, J.A.B', J.A.bufB)
118+
y .+= J.A.bufD
119+
return y
120+
end
121+
122+
size(P::Union{Axt_mul_Bx{1},Axt_mul_BxJac{1}}) = ((1,),size(P.A,2))
123+
size(P::Union{Axt_mul_Bx{2},Axt_mul_BxJac{2}}) = ((size(P.A,1)[2],size(P.B,1)[2]),size(P.A,2))
124+
125+
fun_name(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)
126+
127+
domainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = domainType(L.A)
128+
codomainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = codomainType(L.A)
129+
130+
# utils
131+
function permute(P::Axt_mul_Bx{N,L1,L2,C,D},
132+
p::AbstractVector{Int}) where {N,L1,L2,C,D <:ArrayPartition}
133+
Axt_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
134+
end
135+
136+
remove_displacement(P::Axt_mul_Bx) =
137+
Axt_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)

0 commit comments

Comments
 (0)