diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 3225916..83f8ea7 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -1,5 +1,7 @@ module AbstractFFTs +using Base.ScopedValues + export fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft, diff --git a/src/definitions.jl b/src/definitions.jl index eb9622a..745adf0 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -58,26 +58,68 @@ to1(x::AbstractArray) = _to1(axes(x), x) _to1(::Tuple{Base.OneTo,Vararg{Base.OneTo}}, x) = x _to1(::Tuple, x) = copy1(eltype(x), x) +# Abstract FFT Backend +export AbstractFFTBackend, fft_backend +abstract type AbstractFFTBackend end +struct BackendReference + ref::Ref{Union{Missing, AbstractFFTBackend}} + BackendReference(val::Union{Missing, AbstractFFTBackend}) = new(Ref{Union{Missing, AbstractFFTBackend}}(val)) +end +Base.setindex!(ref::BackendReference, val::Union{Missing, AbstractFFTBackend}) = ref.ref[] = val +Base.setindex!(ref::BackendReference, val::Module) = setindex!(ref, val.backend()) +Base.getindex(ref::BackendReference) = getindex(ref.ref)::Union{Missing, AbstractFFTBackend} +Base.convert(::Type{BackendReference}, val::AbstractFFTBackend) = BackendReference(val) +const fft_backend = ScopedValue(BackendReference(missing)) + +""" + set_active_backend!(back::Union{Missing, Module, AbstractFFTBackend}) + +Set the default FFT plan backend. A module `back` must implement `back.backend()`. +""" +function set_active_backend!(back::Union{Missing, AbstractFFTBackend, Module}) + fft_backend[][] = back +end +active_backend() = fft_backend[][] +function no_backend_error() + error( + """ + No default backend available! + Make sure to also "import/using" an FFT backend such as FFTW, FFTA or RustFFT. + """ + ) +end + +for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft, :brfft, :irfft) + pf = Symbol("plan_", f) + @eval begin + $f(x::AbstractArray, args...; kws...) = $f(active_backend(), x, args...; kws...) + $pf(x::AbstractArray, args...; kws...) = $pf(active_backend(), x, args...; kws...) + $f(::Missing, x::AbstractArray, args...; kws...) = no_backend_error() + $pf(::Missing, x::AbstractArray, args...; kws...) = no_backend_error() + end +end # implementations only need to provide plan_X(x, region) # for X in (:fft, :bfft, ...): for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray) = $f(x, 1:ndims(x)) - $f(x::AbstractArray, region) = (y = to1(x); $pf(y, region) * y) - $pf(x::AbstractArray; kws...) = (y = to1(x); $pf(y, 1:ndims(y); kws...)) + $f(b::AbstractFFTBackend, x::AbstractArray) = $f(b, x, 1:ndims(x)) + $f(b::AbstractFFTBackend, x::AbstractArray, region) = (y = to1(x); $pf(b, y, region) * y) + $pf(b::AbstractFFTBackend, x::AbstractArray; kws...) = (y = to1(x); $pf(b, y, 1:ndims(y); kws...)) end end """ + plan_ifft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) plan_ifft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) Same as [`plan_fft`](@ref), but produces a plan that performs inverse transforms -[`ifft`](@ref). +[`ifft`](@ref). Uses active `backend` if no explicit `backend` is provided. """ plan_ifft """ + plan_ifft!(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) plan_ifft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) Same as [`plan_ifft`](@ref), but operates in-place on `A`. @@ -85,6 +127,7 @@ Same as [`plan_ifft`](@ref), but operates in-place on `A`. plan_ifft! """ + plan_bfft!(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) plan_bfft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) Same as [`plan_bfft`](@ref), but operates in-place on `A`. @@ -92,14 +135,16 @@ Same as [`plan_bfft`](@ref), but operates in-place on `A`. plan_bfft! """ + plan_bfft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) plan_bfft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) Same as [`plan_fft`](@ref), but produces a plan that performs an unnormalized -backwards transform [`bfft`](@ref). +backwards transform [`bfft`](@ref). Uses active `backend` if no explicit `backend` is provided. """ plan_bfft """ + plan_fft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) plan_fft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) Pre-plan an optimized FFT along given dimensions (`dims`) of arrays matching the shape and @@ -107,6 +152,8 @@ type of `A`. (The first two arguments have the same meaning as for [`fft`](@ref Returns an object `P` which represents the linear operator computed by the FFT, and which contains all of the information needed to compute `fft(A, dims)` quickly. +Uses active `backend` if no explicit `backend` is provided. + To apply `P` to an array `A`, use `P * A`; in general, the syntax for applying plans is much like that of matrices. (A plan can only be applied to arrays of the same size as the `A` for which the plan was created.) You can also apply a plan with a preallocated output array `Â` @@ -132,6 +179,7 @@ plans that perform the equivalent of the inverse transforms [`ifft`](@ref) and s plan_fft """ + plan_fft!(backend A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) plan_fft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) Same as [`plan_fft`](@ref), but operates in-place on `A`. @@ -139,6 +187,7 @@ Same as [`plan_fft`](@ref), but operates in-place on `A`. plan_fft! """ + rfft(backend, A [, dims]) rfft(A [, dims]) Multidimensional FFT of a real array `A`, exploiting the fact that the transform has @@ -146,6 +195,8 @@ conjugate symmetry in order to save roughly half the computational time and stor compared with [`fft`](@ref). If `A` has size `(n_1, ..., n_d)`, the result has size `(div(n_1,2)+1, ..., n_d)`. +Uses active `backend` if no explicit `backend` is provided. + The optional `dims` argument specifies an iterable subset of one or more dimensions of `A` to transform, similar to [`fft`](@ref). Instead of (roughly) halving the first dimension of `A` in the result, the `dims[1]` dimension is (roughly) halved in the same way. @@ -153,6 +204,7 @@ dimension of `A` in the result, the `dims[1]` dimension is (roughly) halved in t rfft """ + ifft!(backend, A [, dims]) ifft!(A [, dims]) Same as [`ifft`](@ref), but operates in-place on `A`. @@ -160,6 +212,7 @@ Same as [`ifft`](@ref), but operates in-place on `A`. ifft! """ + ifft(backend, A [, dims]) ifft(A [, dims]) Multidimensional inverse FFT. @@ -177,6 +230,7 @@ A multidimensional inverse FFT simply performs this operation along each transfo ifft """ + fft!(backend, A [, dims]) fft!(A [, dims]) Same as [`fft`](@ref), but operates in-place on `A`, which must be an array of @@ -185,6 +239,7 @@ complex floating-point numbers. fft! """ + bfft(backend, A [, dims]) bfft(A [, dims]) Similar to [`ifft`](@ref), but computes an unnormalized inverse (backward) @@ -200,6 +255,7 @@ computational steps elsewhere.) bfft """ + bfft!(backend, A [, dims]) bfft!(A [, dims]) Same as [`bfft`](@ref), but operates in-place on `A`. @@ -215,10 +271,15 @@ for f in (:fft, :bfft, :ifft) $pf(x::AbstractArray{<:Real}, region; kws...) = $pf(complexfloat(x), region; kws...) $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(complexfloat(x), region) $pf(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(complexfloat(x), region; kws...) + # These methods run into ambig. if a backend does not specialise T further + $f(b::AbstractFFTBackend, x::AbstractArray{<:Real}, region) = $f(b, complexfloat(x), region) + $pf(b::AbstractFFTBackend, x::AbstractArray{<:Real}, region; kws...) = $pf(b, complexfloat(x), region; kws...) + $f(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(b, complexfloat(x), region) + $pf(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(b, complexfloat(x), region; kws...) end end -rfft(x::AbstractArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(realfloat(x), region) -plan_rfft(x::AbstractArray, region; kws...) = plan_rfft(realfloat(x), region; kws...) +rfft(b::AbstractFFTBackend, x::AbstractArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(b, realfloat(x), region) +plan_rfft(b::AbstractFFTBackend, x::AbstractArray, region; kws...) = plan_rfft(b, realfloat(x), region; kws...) # only require implementation to provide *(::Plan{T}, ::Array{T}) *(p::Plan{T}, x::AbstractArray) where {T} = p * copy1(T, x) @@ -279,10 +340,10 @@ summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p)) end normalization(X, region) = normalization(real(eltype(X)), size(X), region) -plan_ifft(x::AbstractArray, region; kws...) = - ScaledPlan(plan_bfft(x, region; kws...), normalization(x, region)) -plan_ifft!(x::AbstractArray, region; kws...) = - ScaledPlan(plan_bfft!(x, region; kws...), normalization(x, region)) +plan_ifft(b::AbstractFFTBackend, x::AbstractArray, region; kws...) = + ScaledPlan(plan_bfft(b, x, region; kws...), normalization(x, region)) +plan_ifft!(b::AbstractFFTBackend, x::AbstractArray, region; kws...) = + ScaledPlan(plan_bfft!(b, x, region; kws...), normalization(x, region)) plan_inv(p::ScaledPlan) = ScaledPlan(plan_inv(p.p), inv(p.scale)) # Don't cache inverse of scaled plan (only inverse of inner plan) @@ -302,20 +363,21 @@ LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) = for f in (:brfft, :irfft) pf = Symbol("plan_", f) @eval begin - $f(x::AbstractArray, d::Integer) = $f(x, d, 1:ndims(x)) - $f(x::AbstractArray, d::Integer, region) = $pf(x, d, region) * x - $pf(x::AbstractArray, d::Integer;kws...) = $pf(x, d, 1:ndims(x);kws...) + $f(b::AbstractFFTBackend, x::AbstractArray, d::Integer) = $f(b, x, d, 1:ndims(x)) + $f(b::AbstractFFTBackend, x::AbstractArray, d::Integer, region) = $pf(b, x, d, region) * x + $pf(b::AbstractFFTBackend, x::AbstractArray, d::Integer;kws...) = $pf(b, x, d, 1:ndims(x);kws...) end end for f in (:brfft, :irfft) @eval begin - $f(x::AbstractArray{<:Real}, d::Integer, region) = $f(complexfloat(x), d, region) - $f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(complexfloat(x), d, region) + $f(b::AbstractFFTBackend, x::AbstractArray{<:Real}, d::Integer, region) = $f(b, complexfloat(x), d, region) + $f(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(b, complexfloat(x), d, region) end end """ + irfft(backend, A, d [, dims]) irfft(A, d [, dims]) Inverse of [`rfft`](@ref): for a complex array `A`, gives the corresponding real @@ -330,6 +392,7 @@ transformed real array.) irfft """ + brfft(backend, A, d [, dims]) brfft(A, d [, dims]) Similar to [`irfft`](@ref) but computes an unnormalized inverse transform (similar @@ -351,11 +414,12 @@ function brfft_output_size(sz::Dims{N}, d::Integer, region) where {N} return ntuple(i -> i == d1 ? d : sz[i], Val(N)) end -plan_irfft(x::AbstractArray{Complex{T}}, d::Integer, region; kws...) where {T} = - ScaledPlan(plan_brfft(x, d, region; kws...), +plan_irfft(b::AbstractFFTBackend, x::AbstractArray{Complex{T}}, d::Integer, region; kws...) where {T} = + ScaledPlan(plan_brfft(b, x, d, region; kws...), normalization(T, brfft_output_size(x, d, region), region)) """ + plan_irfft(backend, A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) plan_irfft(A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) Pre-plan an optimized inverse real-input FFT, similar to [`plan_rfft`](@ref) @@ -543,6 +607,7 @@ fftshift(x::Frequencies) = (x.n_nonnegative-x.n:x.n_nonnegative-1)*x.multiplier ############################################################################## """ + fft(backend, A [, dims]) fft(A [, dims]) Performs a multidimensional FFT of the array `A`. The optional `dims` argument specifies an @@ -570,6 +635,7 @@ A multidimensional FFT simply performs this operation along each transformed dim fft """ + plan_rfft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) plan_rfft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) Pre-plan an optimized real-input FFT, similar to [`plan_fft`](@ref) except for @@ -579,6 +645,7 @@ size of the transformed result, are the same as for [`rfft`](@ref). plan_rfft """ + plan_brfft(backend, A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) plan_brfft(A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf) Pre-plan an optimized real-input unnormalized transform, similar to diff --git a/test/TestPlans.jl b/test/TestPlans.jl index 1c3459a..e8fd00f 100644 --- a/test/TestPlans.jl +++ b/test/TestPlans.jl @@ -4,6 +4,10 @@ using LinearAlgebra using AbstractFFTs using AbstractFFTs: Plan +struct TestBackend <: AbstractFFTBackend end +backend() = TestBackend() +activate!() = AbstractFFTs.set_active_backend!(TestPlans) + mutable struct TestPlan{T,N,G} <: Plan{T} region::G sz::NTuple{N,Int} @@ -30,10 +34,10 @@ Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle() AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle() -function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T} +function AbstractFFTs.plan_fft(::TestBackend, x::AbstractArray{T}, region; kwargs...) where {T} return TestPlan{T}(region, size(x)) end -function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T} +function AbstractFFTs.plan_bfft(::TestBackend, x::AbstractArray{T}, region; kwargs...) where {T} return InverseTestPlan{T}(region, size(x)) end @@ -119,10 +123,10 @@ end AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle() AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d) -function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real} +function AbstractFFTs.plan_rfft(::TestBackend, x::AbstractArray{T}, region; kwargs...) where {T<:Real} return TestRPlan{T}(region, size(x)) end -function AbstractFFTs.plan_brfft(x::AbstractArray{Complex{T}}, d, region; kwargs...) where {T} +function AbstractFFTs.plan_brfft(::TestBackend, x::AbstractArray{Complex{T}}, d, region; kwargs...) where {T} return InverseTestRPlan{T}(d, region, size(x)) end function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N} @@ -265,10 +269,10 @@ Base.ndims(p::InplaceTestPlan) = ndims(p.plan) AbstractFFTs.fftdims(p::InplaceTestPlan) = fftdims(p.plan) AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan) -function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...) +function AbstractFFTs.plan_fft!(::TestBackend, x::AbstractArray, region; kwargs...) return InplaceTestPlan(plan_fft(x, region; kwargs...)) end -function AbstractFFTs.plan_bfft!(x::AbstractArray, region; kwargs...) +function AbstractFFTs.plan_bfft!(::TestBackend, x::AbstractArray, region; kwargs...) return InplaceTestPlan(plan_bfft(x, region; kwargs...)) end diff --git a/test/runtests.jl b/test/runtests.jl index 0560174..0906b14 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,7 @@ Random.seed!(1234) # Load example plan implementation. include("TestPlans.jl") +TestPlans.activate!() # Run interface tests for TestPlans AbstractFFTs.TestUtils.test_complex_ffts(Array)