diff --git a/Project.toml b/Project.toml index dd6da32..57bd01b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ContinuumArrays" uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c" -version = "0.20.0" +version = "0.20.1" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/plans.jl b/src/plans.jl index 497041c..882ee54 100644 --- a/src/plans.jl +++ b/src/plans.jl @@ -31,47 +31,55 @@ end Takes a factorization and supports it applied to different dimensions. """ -struct InvPlan{T, Facts<:Tuple, Dims} <: Plan{T} +struct InvPlan{T, Facts<:Tuple, Pln, Dims} <: Plan{T} factorizations::Facts + plan::Pln dims::Dims end -InvPlan(fact::Tuple, dims) = InvPlan{eltype(fact), typeof(fact), typeof(dims)}(fact, dims) -InvPlan(fact, dims) = InvPlan((fact,), dims) +InvPlan(fact::Tuple, plan, dims) = InvPlan{mapreduce(eltype,promote_type,fact), typeof(fact), typeof(plan), typeof(dims)}(fact, plan, dims) +InvPlan(fact::Tuple, dims) = InvPlan(fact, nothing, dims) +InvPlan(fact, dims...) = InvPlan((fact,), dims...) size(F::InvPlan) = size.(F.factorizations, 1) """ - MulPlan(matrix, dims) + MulPlan(matrix, [plan], dims) -Takes a matrix and supports it applied to different dimensions. +Takes a matrix and supports it applied to different dimensions, after applying a plan. """ -struct MulPlan{T, Fact<:Tuple, Dims} <: Plan{T} +struct MulPlan{T, Fact<:Tuple, Pln, Dims} <: Plan{T} matrices::Fact + plan::Pln dims::Dims end -MulPlan(mats::Tuple, dims) = MulPlan{eltype(mats), typeof(mats), typeof(dims)}(mats, dims) -MulPlan(mats::AbstractMatrix, dims) = MulPlan((mats,), dims) +MulPlan(mats::Tuple, plan, dims) = MulPlan{mapreduce(eltype,promote_type,mats), typeof(mats), typeof(plan), typeof(dims)}(mats, plan, dims) +MulPlan(mats::Tuple, dims) = MulPlan(mats, nothing, dims) +MulPlan(mats::AbstractMatrix, dims...) = MulPlan((mats,), dims...) + +_transformifnotnothing(::Nothing, x) = x +_transformifnotnothing(P, x) = P*x for (Pln,op,fld) in ((:MulPlan, :*, :(:matrices)), (:InvPlan, :\, :(:factorizations))) @eval begin - function *(P::$Pln{<:Any,<:Tuple,Int}, x::AbstractVector) + function *(P::$Pln{<:Any,<:Tuple,<:Any,Int}, x::AbstractVector) @assert P.dims == 1 - $op(only(getfield(P, $fld)), x) # Only a single factorization when dims isa Int + $op(only(getfield(P, $fld)), _transformifnotnothing(P.plan, x)) # Only a single factorization when dims isa Int end - function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractMatrix) + function *(P::$Pln{<:Any,<:Tuple,<:Any,Int}, X::AbstractMatrix) if P.dims == 1 $op(only(getfield(P, $fld)), X) # Only a single factorization when dims isa Int else @assert P.dims == 2 - permutedims($op(only(getfield(P, $fld)), permutedims(X))) + permutedims($op(only(getfield(P, $fld)), permutedims(_transformifnotnothing(P.plan, X)))) end end - function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,3}) + function *(P::$Pln{<:Any,<:Tuple,<:Any,Int}, Xin::AbstractArray{<:Any,3}) + X = _transformifnotnothing(P.plan, Xin) Y = similar(X) if P.dims == 1 for j in axes(X,3) @@ -90,7 +98,8 @@ for (Pln,op,fld) in ((:MulPlan, :*, :(:matrices)), (:InvPlan, :\, :(:factorizati Y end - function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,4}) + function *(P::$Pln{<:Any,<:Tuple,<:Any,Int}, Xin::AbstractArray{<:Any,4}) + X = _transformifnotnothing(P.plan, Xin) Y = similar(X) if P.dims == 1 for j in axes(X,3), l in axes(X,4) @@ -114,9 +123,10 @@ for (Pln,op,fld) in ((:MulPlan, :*, :(:matrices)), (:InvPlan, :\, :(:factorizati - *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray) = error("Overload") + *(P::$Pln{<:Any,<:Tuple,<:Any,Int}, X::AbstractArray) = error("Overload") - function *(P::$Pln, X::AbstractArray) + function *(P::$Pln, Xin::AbstractArray) + X = _transformifnotnothing(P.plan, Xin) for (fac,dim) in zip(getfield(P, $fld), P.dims) X = $Pln(fac, dim) * X end @@ -125,7 +135,7 @@ for (Pln,op,fld) in ((:MulPlan, :*, :(:matrices)), (:InvPlan, :\, :(:factorizati end end -*(A::AbstractMatrix, P::MulPlan) = MulPlan(Ref(A) .* P.matrices, P.dims) +*(A::AbstractMatrix, P::MulPlan) = MulPlan(Ref(A) .* P.matrices, P.plan, P.dims) -inv(P::MulPlan) = InvPlan(map(factorize,P.matrices), P.dims) -inv(P::InvPlan) = MulPlan(convert.(Matrix,P.factorizations), P.dims) \ No newline at end of file +inv(P::MulPlan{<:Any,<:Any,Nothing}) = InvPlan(map(factorize,P.matrices), P.dims) +inv(P::InvPlan{<:Any,<:Any,Nothing}) = MulPlan(convert.(Matrix,P.factorizations), P.dims) \ No newline at end of file