Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/FFTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using LinearAlgebra, Reexport, Preferences
@reexport using AbstractFFTs
using Base.Threads

import AbstractFFTs: Plan, ScaledPlan,
import AbstractFFTs: Plan, ScaledPlan, AbstractFFTBackend,
fft, ifft, bfft, fft!, ifft!, bfft!,
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
Expand All @@ -16,6 +16,11 @@ export dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct!

include("providers.jl")

export FFTWBackend
struct FFTWBackend <: AbstractFFTBackend end
backend() = FFTWBackend()
activate!() = AbstractFFTs.set_active_backend!(FFTW)

function fftw_init_check()
# If someone is trying to set the provider via the old environment variable, warn them that they
# should instead use `set_provider!()` instead.
Expand Down Expand Up @@ -59,7 +64,9 @@ elseif fftw_provider == "mkl"
end
const libfftw3 = FakeLazyLibrary(:libfftw3_no_init, fftw_init_check, C_NULL)
const libfftw3f = FakeLazyLibrary(:libfftw3f_no_init, fftw_init_check, C_NULL)

function __init__()
activate!()
end
else
@static if fftw_provider == "fftw"
import FFTW_jll: libfftw3_path as libfftw3_no_init,
Expand All @@ -74,6 +81,7 @@ elseif fftw_provider == "mkl"
end
function __init__()
fftw_init_check()
activate!()
end
end

Expand Down
28 changes: 14 additions & 14 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -771,36 +771,36 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD))
plan_f! = Symbol("plan_",f,"!")
idirection = -direction
@eval begin
function $plan_f(X::StridedArray{T,N}, region;
function $plan_f(b::FFTWBackend, X::StridedArray{T,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N}
if num_threads !== nothing
plan = set_num_threads(num_threads) do
$plan_f(X, region; flags = flags, timelimit = timelimit)
$plan_f(b, X, region; flags = flags, timelimit = timelimit)
end
return plan
end
cFFTWPlan{T,$direction,false,N}(X, fakesimilar(flags, X, T),
region, flags, timelimit)
end

function $plan_f!(X::StridedArray{T,N}, region;
function $plan_f!(::FFTWBackend, X::StridedArray{T,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing ) where {T<:fftwComplex,N}
if num_threads !== nothing
plan = set_num_threads(num_threads) do
$plan_f!(X, region; flags = flags, timelimit = timelimit)
$plan_f!(b, X, region; flags = flags, timelimit = timelimit)
end
return plan
end
cFFTWPlan{T,$direction,true,N}(X, X, region, flags, timelimit)
end
$plan_f(X::StridedArray{<:fftwComplex}; kws...) =
$plan_f(X, ntuple(identity, ndims(X)); kws...)
$plan_f!(X::StridedArray{<:fftwComplex}; kws...) =
$plan_f!(X, ntuple(identity, ndims(X)); kws...)
$plan_f(b::FFTWBackend, X::StridedArray{<:fftwComplex}; kws...) =
$plan_f(b, X, ntuple(identity, ndims(X)); kws...)
$plan_f!(b, ::FFTWBackend, X::StridedArray{<:fftwComplex}; kws...) =
$plan_f!(b, X, ntuple(identity, ndims(X)); kws...)

function plan_inv(p::cFFTWPlan{T,$direction,inplace,N};
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N,inplace}
Expand Down Expand Up @@ -843,13 +843,13 @@ end
for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
# Note: use $FORWARD and $BACKWARD below because of issue #9775
@eval begin
function plan_rfft(X::StridedArray{$Tr,N}, region;
function plan_rfft(b::FFTWBackend, X::StridedArray{$Tr,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing) where N
if num_threads !== nothing
plan = set_num_threads(num_threads) do
plan_rfft(X, region; flags = flags, timelimit = timelimit)
plan_rfft(b, X, region; flags = flags, timelimit = timelimit)
end
return plan
end
Expand All @@ -858,13 +858,13 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
rFFTWPlan{$Tr,$FORWARD,false,N}(X, Y, region, flags, timelimit)
end

function plan_brfft(X::StridedArray{$Tc,N}, d::Integer, region;
function plan_brfft(::FFTWBackend, X::StridedArray{$Tc,N}, d::Integer, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT,
num_threads::Union{Nothing, Integer} = nothing) where N
if num_threads !== nothing
plan = set_num_threads(num_threads) do
plan_brfft(X, d, region; flags = flags, timelimit = timelimit)
plan_brfft(b, X, d, region; flags = flags, timelimit = timelimit)
end
return plan
end
Expand All @@ -884,8 +884,8 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
end
end

plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,ntuple(identity, ndims(X));kws...)
plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,ntuple(identity, ndims(X));kws...)
plan_rfft(b::FFTWBackend, X::StridedArray{$Tr};kws...)=plan_rfft(b, X,ntuple(identity, ndims(X));kws...)
plan_brfft(b::FFTWBackend, X::StridedArray{$Tr};kws...)=plan_brfft(b, X,ntuple(identity, ndims(X));kws...)

function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N},
num_threads::Union{Nothing, Integer} = nothing) where N
Expand Down
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ true_fftd3_m3d[:,:,2] .= -15
@eval begin
$f_(x::$A{T,N}) where {T,N} = invoke($f, Tuple{AbstractArray{T,N}}, x)
$f_(x::$A{T,N},r::R) where {T,N,R} = invoke($f,Tuple{AbstractArray{T,N},R},x,r)
$f_(b::FFTWBackend, x::$A{T,N}) where {T,N} = invoke($f, Tuple{FFTWBackend, AbstractArray{T,N}}, b, x)
$f_(b::FFTWBackend, x::$A{T,N},r::R) where {T,N,R} = invoke($f,Tuple{FFTWBackend, AbstractArray{T,N},R}, b, x,r)
end
end
end
Expand Down Expand Up @@ -358,26 +360,32 @@ let
end

@testset "Base Julia issue #9772, with size $(size(x))" for x in (randn(10),randn(10,12))
# note: Inference/type-stability "breaks" if multiple FFT backends are loaded
# and one does not supply the backend
z = complex(x)
y = rfft(x)
@inferred rfft(x)
@inferred rfft(FFTW.backend(), x)

if ndims(x) == 2
@inferred brfft(x,18)
@inferred brfft(FFTW.backend(), x,18)
end

@inferred brfft(y,10)
for f in (plan_bfft!, plan_fft!, plan_ifft!,
plan_bfft, plan_fft, plan_ifft,
fft, bfft, fft_, ifft)
p = @inferred f(z)
@inferred f(FFTW.backend(), z)
if isa(p, Plan)
@inferred plan_inv(p)
end
end
for f in (plan_bfft, plan_fft, plan_ifft,
plan_rfft, fft, bfft, fft_, ifft)
p = @inferred f(x)
@inferred f(FFTW.backend(), x)
if isa(p, Plan)
@inferred plan_inv(p)
end
Expand Down
Loading