@@ -38,11 +38,11 @@ struct Axt_mul_Bx{N,
38
38
bufD:: D
39
39
function Axt_mul_Bx (A:: L1 , B:: L2 , bufA:: C , bufB:: C , bufC:: C , bufD:: D ) where {L1,L2,C,D}
40
40
if ndims (A,1 ) == 1
41
- if size (A) != size (B)
41
+ if size (A) != size (B)
42
42
throw (DimensionMismatch (" Cannot compose operators" ))
43
43
end
44
44
elseif ndims (A,1 ) == 2 && ndims (B,1 ) == 2 && size (A,2 ) == size (B,2 )
45
- if size (A,1 )[1 ] != size (B,1 )[1 ]
45
+ if size (A,1 )[1 ] != size (B,1 )[1 ]
46
46
throw (DimensionMismatch (" Cannot compose operators" ))
47
47
end
48
48
else
69
69
70
70
# Constructors
71
71
function Axt_mul_Bx (A:: AbstractOperator ,B:: AbstractOperator )
72
- s,t = size (A,1 ), codomainType (A)
73
- bufA = eltype (s) <: Int ? zeros (t,s) : ArrayPartition (zeros .(t,s)... )
74
- bufC = eltype (s) <: Int ? zeros (t,s) : ArrayPartition (zeros .(t,s)... )
75
- s,t = size (B,1 ), codomainType (B)
76
- bufB = eltype (s) <: Int ? zeros (t,s) : ArrayPartition (zeros .(t,s)... )
77
- s,t = size (A,2 ), domainType (A)
78
- bufD = eltype (s) <: Int ? zeros (t,s) : ArrayPartition (zeros .(t,s)... )
72
+ bufA = allocateInCodomain (A)
73
+ bufB = allocateInCodomain (B)
74
+ bufC = allocateInCodomain (A)
75
+ bufD = allocateInDomain (A)
79
76
Axt_mul_Bx (A,B,bufA,bufB,bufC,bufD)
80
77
end
81
78
@@ -122,16 +119,16 @@ end
122
119
size (P:: Union{Axt_mul_Bx{1},Axt_mul_BxJac{1}} ) = ((1 ,),size (P. A,2 ))
123
120
size (P:: Union{Axt_mul_Bx{2},Axt_mul_BxJac{2}} ) = ((size (P. A,1 )[2 ],size (P. B,1 )[2 ]),size (P. A,2 ))
124
121
125
- fun_name (L:: Union{Axt_mul_Bx,Axt_mul_BxJac} ) = fun_name (L. A)* " *" * fun_name (L. B)
122
+ fun_name (L:: Union{Axt_mul_Bx,Axt_mul_BxJac} ) = fun_name (L. A)* " *" * fun_name (L. B)
126
123
127
124
domainType (L:: Union{Axt_mul_Bx,Axt_mul_BxJac} ) = domainType (L. A)
128
125
codomainType (L:: Union{Axt_mul_Bx,Axt_mul_BxJac} ) = codomainType (L. A)
129
126
130
127
# utils
131
- function permute (P:: Axt_mul_Bx{N,L1,L2,C,D} ,
128
+ function permute (P:: Axt_mul_Bx{N,L1,L2,C,D} ,
132
129
p:: AbstractVector{Int} ) where {N,L1,L2,C,D <: ArrayPartition }
133
130
Axt_mul_Bx (permute (P. A,p),permute (P. B,p),P. bufA,P. bufB,P. bufC,ArrayPartition (P. bufD. x[p]) )
134
131
end
135
132
136
- remove_displacement (P:: Axt_mul_Bx ) =
133
+ remove_displacement (P:: Axt_mul_Bx ) =
137
134
Axt_mul_Bx (remove_displacement (P. A), remove_displacement (P. B), P. bufA, P. bufB, P. bufC, P. bufD)
0 commit comments