Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/AbstractFFTs.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module AbstractFFTs

using Base.ScopedValues

export fft, ifft, bfft, fft!, ifft!, bfft!,
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
Expand Down
103 changes: 85 additions & 18 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,55 +58,102 @@ to1(x::AbstractArray) = _to1(axes(x), x)
_to1(::Tuple{Base.OneTo,Vararg{Base.OneTo}}, x) = x
_to1(::Tuple, x) = copy1(eltype(x), x)

# Abstract FFT Backend
export AbstractFFTBackend, fft_backend
abstract type AbstractFFTBackend end
struct BackendReference
ref::Ref{Union{Missing, AbstractFFTBackend}}
BackendReference(val::Union{Missing, AbstractFFTBackend}) = new(Ref{Union{Missing, AbstractFFTBackend}}(val))
end
Base.setindex!(ref::BackendReference, val::Union{Missing, AbstractFFTBackend}) = ref.ref[] = val
Base.setindex!(ref::BackendReference, val::Module) = setindex!(ref, val.backend())
Base.getindex(ref::BackendReference) = getindex(ref.ref)::Union{Missing, AbstractFFTBackend}
Base.convert(::Type{BackendReference}, val::AbstractFFTBackend) = BackendReference(val)
const fft_backend = ScopedValue(BackendReference(missing))

"""
set_active_backend!(back::Union{Missing, Module, AbstractFFTBackend})

Set the default FFT plan backend. A module `back` must implement `back.backend()`.
"""
function set_active_backend!(back::Union{Missing, AbstractFFTBackend, Module})
fft_backend[][] = back
end
active_backend() = fft_backend[][]
function no_backend_error()
error(
"""
No default backend available!
Make sure to also "import/using" an FFT backend such as FFTW, FFTA or RustFFT.
"""
)
end

for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft, :brfft, :irfft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray, args...; kws...) = $f(active_backend(), x, args...; kws...)
$pf(x::AbstractArray, args...; kws...) = $pf(active_backend(), x, args...; kws...)
$f(::Missing, x::AbstractArray, args...; kws...) = no_backend_error()
$pf(::Missing, x::AbstractArray, args...; kws...) = no_backend_error()
end
end
# implementations only need to provide plan_X(x, region)
# for X in (:fft, :bfft, ...):
for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray) = $f(x, 1:ndims(x))
$f(x::AbstractArray, region) = (y = to1(x); $pf(y, region) * y)
$pf(x::AbstractArray; kws...) = (y = to1(x); $pf(y, 1:ndims(y); kws...))
$f(b::AbstractFFTBackend, x::AbstractArray) = $f(b, x, 1:ndims(x))
$f(b::AbstractFFTBackend, x::AbstractArray, region) = (y = to1(x); $pf(b, y, region) * y)
$pf(b::AbstractFFTBackend, x::AbstractArray; kws...) = (y = to1(x); $pf(b, y, 1:ndims(y); kws...))
end
end

"""
plan_ifft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_ifft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Same as [`plan_fft`](@ref), but produces a plan that performs inverse transforms
[`ifft`](@ref).
[`ifft`](@ref). Uses active `backend` if no explicit `backend` is provided.
"""
plan_ifft

"""
plan_ifft!(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_ifft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Same as [`plan_ifft`](@ref), but operates in-place on `A`.
"""
plan_ifft!

"""
plan_bfft!(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_bfft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Same as [`plan_bfft`](@ref), but operates in-place on `A`.
"""
plan_bfft!

"""
plan_bfft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_bfft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Same as [`plan_fft`](@ref), but produces a plan that performs an unnormalized
backwards transform [`bfft`](@ref).
backwards transform [`bfft`](@ref). Uses active `backend` if no explicit `backend` is provided.
"""
plan_bfft

"""
plan_fft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_fft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Pre-plan an optimized FFT along given dimensions (`dims`) of arrays matching the shape and
type of `A`. (The first two arguments have the same meaning as for [`fft`](@ref).)
Returns an object `P` which represents the linear operator computed by the FFT, and which
contains all of the information needed to compute `fft(A, dims)` quickly.

Uses active `backend` if no explicit `backend` is provided.

To apply `P` to an array `A`, use `P * A`; in general, the syntax for applying plans is much
like that of matrices. (A plan can only be applied to arrays of the same size as the `A`
for which the plan was created.) You can also apply a plan with a preallocated output array `Â`
Expand All @@ -132,34 +179,40 @@ plans that perform the equivalent of the inverse transforms [`ifft`](@ref) and s
plan_fft

"""
plan_fft!(backend A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_fft!(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Same as [`plan_fft`](@ref), but operates in-place on `A`.
"""
plan_fft!

"""
rfft(backend, A [, dims])
rfft(A [, dims])

Multidimensional FFT of a real array `A`, exploiting the fact that the transform has
conjugate symmetry in order to save roughly half the computational time and storage costs
compared with [`fft`](@ref). If `A` has size `(n_1, ..., n_d)`, the result has size
`(div(n_1,2)+1, ..., n_d)`.

Uses active `backend` if no explicit `backend` is provided.

The optional `dims` argument specifies an iterable subset of one or more dimensions of `A`
to transform, similar to [`fft`](@ref). Instead of (roughly) halving the first
dimension of `A` in the result, the `dims[1]` dimension is (roughly) halved in the same way.
"""
rfft

"""
ifft!(backend, A [, dims])
ifft!(A [, dims])

Same as [`ifft`](@ref), but operates in-place on `A`.
"""
ifft!

"""
ifft(backend, A [, dims])
ifft(A [, dims])

Multidimensional inverse FFT.
Expand All @@ -177,6 +230,7 @@ A multidimensional inverse FFT simply performs this operation along each transfo
ifft

"""
fft!(backend, A [, dims])
fft!(A [, dims])

Same as [`fft`](@ref), but operates in-place on `A`, which must be an array of
Expand All @@ -185,6 +239,7 @@ complex floating-point numbers.
fft!

"""
bfft(backend, A [, dims])
bfft(A [, dims])

Similar to [`ifft`](@ref), but computes an unnormalized inverse (backward)
Expand All @@ -200,6 +255,7 @@ computational steps elsewhere.)
bfft

"""
bfft!(backend, A [, dims])
bfft!(A [, dims])

Same as [`bfft`](@ref), but operates in-place on `A`.
Expand All @@ -215,10 +271,15 @@ for f in (:fft, :bfft, :ifft)
$pf(x::AbstractArray{<:Real}, region; kws...) = $pf(complexfloat(x), region; kws...)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(complexfloat(x), region)
$pf(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(complexfloat(x), region; kws...)
# These methods run into ambig. if a backend does not specialise T further
$f(b::AbstractFFTBackend, x::AbstractArray{<:Real}, region) = $f(b, complexfloat(x), region)
$pf(b::AbstractFFTBackend, x::AbstractArray{<:Real}, region; kws...) = $pf(b, complexfloat(x), region; kws...)
$f(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region) = $f(b, complexfloat(x), region)
$pf(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, region; kws...) = $pf(b, complexfloat(x), region; kws...)
end
end
rfft(x::AbstractArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(realfloat(x), region)
plan_rfft(x::AbstractArray, region; kws...) = plan_rfft(realfloat(x), region; kws...)
rfft(b::AbstractFFTBackend, x::AbstractArray{<:Union{Integer,Rational}}, region=1:ndims(x)) = rfft(b, realfloat(x), region)
plan_rfft(b::AbstractFFTBackend, x::AbstractArray, region; kws...) = plan_rfft(b, realfloat(x), region; kws...)

# only require implementation to provide *(::Plan{T}, ::Array{T})
*(p::Plan{T}, x::AbstractArray) where {T} = p * copy1(T, x)
Expand Down Expand Up @@ -279,10 +340,10 @@ summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))
end
normalization(X, region) = normalization(real(eltype(X)), size(X), region)

plan_ifft(x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft(x, region; kws...), normalization(x, region))
plan_ifft!(x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft!(x, region; kws...), normalization(x, region))
plan_ifft(b::AbstractFFTBackend, x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft(b, x, region; kws...), normalization(x, region))
plan_ifft!(b::AbstractFFTBackend, x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft!(b, x, region; kws...), normalization(x, region))

plan_inv(p::ScaledPlan) = ScaledPlan(plan_inv(p.p), inv(p.scale))
# Don't cache inverse of scaled plan (only inverse of inner plan)
Expand All @@ -302,20 +363,21 @@ LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) =
for f in (:brfft, :irfft)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray, d::Integer) = $f(x, d, 1:ndims(x))
$f(x::AbstractArray, d::Integer, region) = $pf(x, d, region) * x
$pf(x::AbstractArray, d::Integer;kws...) = $pf(x, d, 1:ndims(x);kws...)
$f(b::AbstractFFTBackend, x::AbstractArray, d::Integer) = $f(b, x, d, 1:ndims(x))
$f(b::AbstractFFTBackend, x::AbstractArray, d::Integer, region) = $pf(b, x, d, region) * x
$pf(b::AbstractFFTBackend, x::AbstractArray, d::Integer;kws...) = $pf(b, x, d, 1:ndims(x);kws...)
end
end

for f in (:brfft, :irfft)
@eval begin
$f(x::AbstractArray{<:Real}, d::Integer, region) = $f(complexfloat(x), d, region)
$f(x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(complexfloat(x), d, region)
$f(b::AbstractFFTBackend, x::AbstractArray{<:Real}, d::Integer, region) = $f(b, complexfloat(x), d, region)
$f(b::AbstractFFTBackend, x::AbstractArray{<:Complex{<:Union{Integer,Rational}}}, d::Integer, region) = $f(b, complexfloat(x), d, region)
end
end

"""
irfft(backend, A, d [, dims])
irfft(A, d [, dims])

Inverse of [`rfft`](@ref): for a complex array `A`, gives the corresponding real
Expand All @@ -330,6 +392,7 @@ transformed real array.)
irfft

"""
brfft(backend, A, d [, dims])
brfft(A, d [, dims])

Similar to [`irfft`](@ref) but computes an unnormalized inverse transform (similar
Expand All @@ -351,11 +414,12 @@ function brfft_output_size(sz::Dims{N}, d::Integer, region) where {N}
return ntuple(i -> i == d1 ? d : sz[i], Val(N))
end

plan_irfft(x::AbstractArray{Complex{T}}, d::Integer, region; kws...) where {T} =
ScaledPlan(plan_brfft(x, d, region; kws...),
plan_irfft(b::AbstractFFTBackend, x::AbstractArray{Complex{T}}, d::Integer, region; kws...) where {T} =
ScaledPlan(plan_brfft(b, x, d, region; kws...),
normalization(T, brfft_output_size(x, d, region), region))

"""
plan_irfft(backend, A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_irfft(A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Pre-plan an optimized inverse real-input FFT, similar to [`plan_rfft`](@ref)
Expand Down Expand Up @@ -543,6 +607,7 @@ fftshift(x::Frequencies) = (x.n_nonnegative-x.n:x.n_nonnegative-1)*x.multiplier
##############################################################################

"""
fft(backend, A [, dims])
fft(A [, dims])

Performs a multidimensional FFT of the array `A`. The optional `dims` argument specifies an
Expand Down Expand Up @@ -570,6 +635,7 @@ A multidimensional FFT simply performs this operation along each transformed dim
fft

"""
plan_rfft(backend, A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_rfft(A [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Pre-plan an optimized real-input FFT, similar to [`plan_fft`](@ref) except for
Expand All @@ -579,6 +645,7 @@ size of the transformed result, are the same as for [`rfft`](@ref).
plan_rfft

"""
plan_brfft(backend, A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)
plan_brfft(A, d [, dims]; flags=FFTW.ESTIMATE, timelimit=Inf)

Pre-plan an optimized real-input unnormalized transform, similar to
Expand Down
16 changes: 10 additions & 6 deletions test/TestPlans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ using LinearAlgebra
using AbstractFFTs
using AbstractFFTs: Plan

struct TestBackend <: AbstractFFTBackend end
backend() = TestBackend()
activate!() = AbstractFFTs.set_active_backend!(TestPlans)

mutable struct TestPlan{T,N,G} <: Plan{T}
region::G
sz::NTuple{N,Int}
Expand All @@ -30,10 +34,10 @@ Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()

function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
function AbstractFFTs.plan_fft(::TestBackend, x::AbstractArray{T}, region; kwargs...) where {T}
return TestPlan{T}(region, size(x))
end
function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T}
function AbstractFFTs.plan_bfft(::TestBackend, x::AbstractArray{T}, region; kwargs...) where {T}
return InverseTestPlan{T}(region, size(x))
end

Expand Down Expand Up @@ -119,10 +123,10 @@ end
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d)

function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
function AbstractFFTs.plan_rfft(::TestBackend, x::AbstractArray{T}, region; kwargs...) where {T<:Real}
return TestRPlan{T}(region, size(x))
end
function AbstractFFTs.plan_brfft(x::AbstractArray{Complex{T}}, d, region; kwargs...) where {T}
function AbstractFFTs.plan_brfft(::TestBackend, x::AbstractArray{Complex{T}}, d, region; kwargs...) where {T}
return InverseTestRPlan{T}(d, region, size(x))
end
function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N}
Expand Down Expand Up @@ -265,10 +269,10 @@ Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
AbstractFFTs.fftdims(p::InplaceTestPlan) = fftdims(p.plan)
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)

function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
function AbstractFFTs.plan_fft!(::TestBackend, x::AbstractArray, region; kwargs...)
return InplaceTestPlan(plan_fft(x, region; kwargs...))
end
function AbstractFFTs.plan_bfft!(x::AbstractArray, region; kwargs...)
function AbstractFFTs.plan_bfft!(::TestBackend, x::AbstractArray, region; kwargs...)
return InplaceTestPlan(plan_bfft(x, region; kwargs...))
end

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Random.seed!(1234)

# Load example plan implementation.
include("TestPlans.jl")
TestPlans.activate!()

# Run interface tests for TestPlans
AbstractFFTs.TestUtils.test_complex_ffts(Array)
Expand Down
Loading