Skip to content

Commit 7413876

Browse files
hakkeltTamas Hakkel
and
Tamas Hakkel
authored
Storage type (#23)
Co-authored-by: Tamas Hakkel <[email protected]>
1 parent 4dea8fb commit 7413876

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+468
-437
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AbstractOperators"
22
uuid = "d9c5613a-d543-52d8-9afd-8f241a8c3f1c"
3-
version = "0.3"
3+
version = "0.4"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/calculus/AdjointOperator.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ export AdjointOperator
33
"""
44
`AdjointOperator(A::AbstractOperator)`
55
6-
Shorthand constructor:
6+
Shorthand constructor:
77
88
`'(A::AbstractOperator)`
99
@@ -19,7 +19,7 @@ julia> [DFT(10); DCT(10)]'
1919
"""
2020
struct AdjointOperator{T <: AbstractOperator} <: AbstractOperator
2121
A::T
22-
function AdjointOperator(A::T) where {T<:AbstractOperator}
22+
function AdjointOperator(A::T) where {T<:AbstractOperator}
2323
is_linear(A) == false && error("Cannot transpose a nonlinear operator. You might use `jacobian`")
2424
new{T}(A)
2525
end

src/calculus/AffineAdd.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ export AffineAdd
33
"""
44
`AffineAdd(A::AbstractOperator, d, [sign = true])`
55
6-
Affine addition to `AbstractOperator` with an array or scalar `d`.
6+
Affine addition to `AbstractOperator` with an array or scalar `d`.
77
88
Use `sign = false` to perform subtraction.
99
@@ -26,17 +26,17 @@ true
2626
struct AffineAdd{L <: AbstractOperator, D <: Union{AbstractArray, Number}, S} <: AbstractOperator
2727
A::L
2828
d::D
29-
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: AbstractArray}
30-
if size(d) != size(A,1)
29+
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: AbstractArray}
30+
if size(d) != size(A,1)
3131
throw(DimensionMismatch("codomain size of $A not compatible with array `d` of size $(size(d))"))
3232
end
33-
if eltype(d) != codomainType(A)
33+
if eltype(d) != codomainType(A)
3434
error("cannot tilt opertor having codomain type $(codomainType(A)) with array of type $(eltype(d))")
3535
end
3636
new{L,D,sign}(A,d)
3737
end
3838
# scalar
39-
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: Number}
39+
function AffineAdd(A::L, d::D, sign::Bool = true) where {L, D <: Number}
4040
if typeof(d) <: Complex && codomainType(A) <: Real
4141
error("cannot tilt opertor having codomain type $(codomainType(A)) with array of type $(eltype(d))")
4242
end
@@ -46,12 +46,12 @@ end
4646

4747
# Mappings
4848
# array
49-
function mul!(y::DD, T::AffineAdd{L, D, true}, x) where {L <: AbstractOperator, DD, D}
49+
function mul!(y::DD, T::AffineAdd{L, D, true}, x) where {L <: AbstractOperator, DD, D}
5050
mul!(y,T.A,x)
5151
y .+= T.d
5252
end
5353

54-
function mul!(y::DD, T::AffineAdd{L, D, false}, x) where {L <: AbstractOperator, DD, D}
54+
function mul!(y::DD, T::AffineAdd{L, D, false}, x) where {L <: AbstractOperator, DD, D}
5555
mul!(y,T.A,x)
5656
y .-= T.d
5757
end
@@ -70,7 +70,7 @@ is_null(L::AffineAdd) = is_null(L.A)
7070
is_eye(L::AffineAdd) = is_diagonal(L.A)
7171
is_diagonal(L::AffineAdd) = is_diagonal(L.A)
7272
is_invertible(L::AffineAdd) = is_invertible(L.A)
73-
is_AcA_diagonal(L::AffineAdd) = is_AcA_diagonal(L.A)
73+
is_AcA_diagonal(L::AffineAdd) = is_AcA_diagonal(L.A)
7474
is_AAc_diagonal(L::AffineAdd) = is_AAc_diagonal(L.A)
7575
is_full_row_rank(L::AffineAdd) = is_full_row_rank(L.A)
7676
is_full_column_rank(L::AffineAdd) = is_full_column_rank(L.A)
@@ -90,7 +90,7 @@ sign(T::AffineAdd{L,D, true}) where {L,D} = 1
9090

9191
function permute(T::AffineAdd{L,D,S}, p::AbstractVector{Int}) where {L,D,S}
9292
A = permute(T.A,p)
93-
return AffineAdd(A,T.d,S)
93+
return AffineAdd(A,T.d,S)
9494
end
9595

9696
displacement(A::AffineAdd{L,D,true}) where {L,D} = A.d .+ displacement(A.A)

src/calculus/Ax_mul_Bx.jl

+8-11
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,10 @@ end
6060

6161
# Constructors
6262
function Ax_mul_Bx(A::AbstractOperator,B::AbstractOperator)
63-
s,t = size(A,1), codomainType(A)
64-
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
65-
s,t = size(B,1), codomainType(B)
66-
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
67-
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
68-
s,t = size(A,2), domainType(A)
69-
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
63+
bufA = allocateInCodomain(A)
64+
bufB = allocateInCodomain(B)
65+
bufC = allocateInCodomain(B)
66+
bufD = allocateInDomain(A)
7067
Ax_mul_Bx(A,B,bufA,bufB,bufC,bufD)
7168
end
7269

@@ -95,16 +92,16 @@ end
9592

9693
size(P::Union{Ax_mul_Bx,Ax_mul_BxJac}) = ((size(P.A,1)[1],size(P.B,1)[2]),size(P.A,2))
9794

98-
fun_name(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)
95+
fun_name(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)
9996

10097
domainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = domainType(L.A)
10198
codomainType(L::Union{Ax_mul_Bx,Ax_mul_BxJac}) = codomainType(L.A)
10299

103100
# utils
104-
function permute(P::Ax_mul_Bx{L1,L2,C,D},
101+
function permute(P::Ax_mul_Bx{L1,L2,C,D},
105102
p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition}
106-
Ax_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
103+
Ax_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]))
107104
end
108105

109-
remove_displacement(P::Ax_mul_Bx) =
106+
remove_displacement(P::Ax_mul_Bx) =
110107
Ax_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)

src/calculus/Ax_mul_Bxt.jl

+9-12
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ struct Ax_mul_Bxt{
3838
bufD::D
3939
function Ax_mul_Bxt(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
4040
if ndims(A,1) == 1
41-
if size(A) != size(B)
41+
if size(A) != size(B)
4242
throw(DimensionMismatch("Cannot compose operators"))
4343
end
4444
elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2)
45-
if size(A,1)[2] != size(B,1)[2]
45+
if size(A,1)[2] != size(B,1)[2]
4646
throw(DimensionMismatch("Cannot compose operators"))
4747
end
4848
else
@@ -68,13 +68,10 @@ end
6868

6969
# Constructors
7070
function Ax_mul_Bxt(A::AbstractOperator,B::AbstractOperator)
71-
s,t = size(A,1), codomainType(A)
72-
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
73-
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
74-
s,t = size(B,1), codomainType(B)
75-
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
76-
s,t = size(A,2), domainType(A)
77-
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
71+
bufA = allocateInCodomain(A)
72+
bufB = allocateInCodomain(B)
73+
bufC = allocateInCodomain(A)
74+
bufD = allocateInDomain(A)
7875
Ax_mul_Bxt(A,B,bufA,bufB,bufC,bufD)
7976
end
8077

@@ -103,16 +100,16 @@ end
103100

104101
size(P::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = ((size(P.A,1)[1],size(P.B,1)[1]),size(P.A,2))
105102

106-
fun_name(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = fun_name(L.A)*"*"*fun_name(L.B)
103+
fun_name(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = fun_name(L.A)*"*"*fun_name(L.B)
107104

108105
domainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = domainType(L.A)
109106
codomainType(L::Union{Ax_mul_Bxt,Ax_mul_BxtJac}) = codomainType(L.A)
110107

111108
# utils
112-
function permute(P::Ax_mul_Bxt{L1,L2,C,D},
109+
function permute(P::Ax_mul_Bxt{L1,L2,C,D},
113110
p::AbstractVector{Int}) where {L1,L2,C,D <:ArrayPartition}
114111
Ax_mul_Bxt(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
115112
end
116113

117-
remove_displacement(P::Ax_mul_Bxt) =
114+
remove_displacement(P::Ax_mul_Bxt) =
118115
Ax_mul_Bxt(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)

src/calculus/Axt_mul_Bx.jl

+9-12
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ struct Axt_mul_Bx{N,
3838
bufD::D
3939
function Axt_mul_Bx(A::L1, B::L2, bufA::C, bufB::C, bufC::C, bufD::D) where {L1,L2,C,D}
4040
if ndims(A,1) == 1
41-
if size(A) != size(B)
41+
if size(A) != size(B)
4242
throw(DimensionMismatch("Cannot compose operators"))
4343
end
4444
elseif ndims(A,1) == 2 && ndims(B,1) == 2 && size(A,2) == size(B,2)
45-
if size(A,1)[1] != size(B,1)[1]
45+
if size(A,1)[1] != size(B,1)[1]
4646
throw(DimensionMismatch("Cannot compose operators"))
4747
end
4848
else
@@ -69,13 +69,10 @@ end
6969

7070
# Constructors
7171
function Axt_mul_Bx(A::AbstractOperator,B::AbstractOperator)
72-
s,t = size(A,1), codomainType(A)
73-
bufA = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
74-
bufC = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
75-
s,t = size(B,1), codomainType(B)
76-
bufB = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
77-
s,t = size(A,2), domainType(A)
78-
bufD = eltype(s) <: Int ? zeros(t,s) : ArrayPartition(zeros.(t,s)...)
72+
bufA = allocateInCodomain(A)
73+
bufB = allocateInCodomain(B)
74+
bufC = allocateInCodomain(A)
75+
bufD = allocateInDomain(A)
7976
Axt_mul_Bx(A,B,bufA,bufB,bufC,bufD)
8077
end
8178

@@ -122,16 +119,16 @@ end
122119
size(P::Union{Axt_mul_Bx{1},Axt_mul_BxJac{1}}) = ((1,),size(P.A,2))
123120
size(P::Union{Axt_mul_Bx{2},Axt_mul_BxJac{2}}) = ((size(P.A,1)[2],size(P.B,1)[2]),size(P.A,2))
124121

125-
fun_name(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)
122+
fun_name(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = fun_name(L.A)*"*"*fun_name(L.B)
126123

127124
domainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = domainType(L.A)
128125
codomainType(L::Union{Axt_mul_Bx,Axt_mul_BxJac}) = codomainType(L.A)
129126

130127
# utils
131-
function permute(P::Axt_mul_Bx{N,L1,L2,C,D},
128+
function permute(P::Axt_mul_Bx{N,L1,L2,C,D},
132129
p::AbstractVector{Int}) where {N,L1,L2,C,D <:ArrayPartition}
133130
Axt_mul_Bx(permute(P.A,p),permute(P.B,p),P.bufA,P.bufB,P.bufC,ArrayPartition(P.bufD.x[p]) )
134131
end
135132

136-
remove_displacement(P::Axt_mul_Bx) =
133+
remove_displacement(P::Axt_mul_Bx) =
137134
Axt_mul_Bx(remove_displacement(P.A), remove_displacement(P.B), P.bufA, P.bufB, P.bufC, P.bufD)

src/calculus/BroadCast.jl

+9-10
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ julia> B*[1.;2.]
2121
```
2222
2323
"""
24-
struct BroadCast{N,
25-
L <: AbstractOperator,
26-
T <: AbstractArray,
24+
struct BroadCast{N,
25+
L <: AbstractOperator,
26+
T <: AbstractArray,
2727
D <: AbstractArray,
2828
M,
2929
C <: NTuple{M,Colon},
@@ -36,14 +36,14 @@ struct BroadCast{N,
3636
cols::C
3737
idxs::I
3838

39-
function BroadCast(A::L,dim_out::NTuple{N,Int},bufC::T, bufD::D) where {N,
40-
L<:AbstractOperator,
39+
function BroadCast(A::L,dim_out::NTuple{N,Int},bufC::T, bufD::D) where {N,
40+
L<:AbstractOperator,
4141
T<:AbstractArray,
4242
D<:AbstractArray
4343
}
4444
Base.Broadcast.check_broadcast_shape(dim_out,size(A,1))
4545
if size(A,1) != (1,)
46-
M = length(size(A,1))
46+
M = length(size(A,1))
4747
cols = ([Colon() for i = 1:M]...,)
4848
idxs = CartesianIndices((dim_out[M+1:end]...,))
4949
new{N,L,T,D,M,typeof(cols),typeof(idxs)}(A,dim_out,bufC,bufD,cols,idxs)
@@ -52,14 +52,13 @@ struct BroadCast{N,
5252
idxs = CartesianIndices((1,))
5353
new{N,L,T,D,M,NTuple{0,Colon},typeof(idxs)}(A,dim_out,bufC,bufD,(),idxs)
5454
end
55-
5655
end
5756
end
5857

5958
# Constructors
6059

6160
BroadCast(A::L, dim_out::NTuple{N,Int}) where {N,L<:AbstractOperator} =
62-
BroadCast(A, dim_out, zeros(codomainType(A),size(A,1)), zeros(domainType(A),size(A,2)) )
61+
BroadCast(A, dim_out, allocateInCodomain(A), allocateInDomain(A))
6362

6463
# Mappings
6564

@@ -82,7 +81,7 @@ end
8281
function mul!(y::CC, A::AdjointOperator{BroadCast{N,L,T,D,0,C,I}}, b::DD) where {N,L,T,D,C,I,CC,DD}
8382
R = A.A
8483
fill!(y, 0.)
85-
bii = zeros(eltype(b),1)
84+
bii = allocateInCodomain(R.A)
8685
for bi in b
8786
bii[1] = bi
8887
mul!(R.bufD, R.A', bii)
@@ -92,7 +91,7 @@ function mul!(y::CC, A::AdjointOperator{BroadCast{N,L,T,D,0,C,I}}, b::DD) where
9291
end
9392

9493
#TODO make this more general
95-
#length(dim_out) == size(A,1) e.g. a .= b; size(a) = (m,n) size(b) = (1,n) matrix out, column in
94+
#length(dim_out) == size(A,1) e.g. a .= b; size(a) = (m,n) size(b) = (1,n) matrix out, column in
9695
function mul!(y::CC, A::AdjointOperator{BroadCast{2,L,T,D,2,C,I}}, b::DD) where {L,T,D,C,I,CC,DD}
9796
R = A.A
9897
fill!(y, 0.)

src/calculus/Compose.jl

+10-7
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ export Compose
33
"""
44
`Compose(A::AbstractOperator,B::AbstractOperator)`
55
6-
Shorthand constructor:
6+
Shorthand constructor:
77
8-
`A*B`
8+
`A*B`
99
1010
Compose different `AbstractOperator`s. Notice that the domain and codomain of the operators `A` and `B` must match, i.e. `size(A,2) == size(B,1)` and `domainType(A) == codomainType(B)`.
1111
@@ -28,19 +28,22 @@ end
2828

2929
function Compose(L1::AbstractOperator, L2::AbstractOperator)
3030
if size(L1,2) != size(L2,1)
31-
throw(DimensionMismatch("cannot compose operators"))
31+
throw(DimensionMismatch("cannot compose operators with different domain and codomain sizes"))
3232
end
3333
if domainType(L1) != codomainType(L2)
34-
throw(DomainError())
34+
throw(DomainError((domainType(L1),codomainType(L2)), "cannot compose operators with different domain and codomain types"))
3535
end
36-
Compose( L1, L2, Array{domainType(L1)}(undef,size(L2,1)) )
36+
if domainStorageType(L1) != codomainStorageType(L2)
37+
throw(DomainError((domainStorageType(L1),codomainStorageType(L2)), "cannot compose operators with different input and output storage types"))
38+
end
39+
Compose(L1, L2, allocateInCodomain(L2))
3740
end
3841

3942
Compose(L1::AbstractOperator,L2::AbstractOperator,buf::AbstractArray) =
40-
Compose( (L2,L1), (buf,))
43+
Compose((L2,L1), (buf,))
4144

4245
Compose(L1::Compose, L2::AbstractOperator,buf::AbstractArray) =
43-
Compose( (L2,L1.A...), (buf,L1.buf...))
46+
Compose((L2,L1.A...), (buf,L1.buf...))
4447

4548
Compose(L1::AbstractOperator,L2::Compose, buf::AbstractArray) =
4649
Compose((L2.A...,L1), (L2.buf...,buf))

0 commit comments

Comments
 (0)