diff --git a/src/FFTW.jl b/src/FFTW.jl index 8884704..3d3a344 100644 --- a/src/FFTW.jl +++ b/src/FFTW.jl @@ -4,7 +4,7 @@ using LinearAlgebra, Reexport, Preferences @reexport using AbstractFFTs using Base.Threads -import AbstractFFTs: Plan, ScaledPlan, +import AbstractFFTs: Plan, ScaledPlan, AbstractFFTBackend, fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft, @@ -16,6 +16,11 @@ export dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct! include("providers.jl") +export FFTWBackend +struct FFTWBackend <: AbstractFFTBackend end +backend() = FFTWBackend() +activate!() = AbstractFFTs.set_active_backend!(FFTW) + function fftw_init_check() # If someone is trying to set the provider via the old environment variable, warn them that they # should instead use `set_provider!()` instead. @@ -59,7 +64,9 @@ elseif fftw_provider == "mkl" end const libfftw3 = FakeLazyLibrary(:libfftw3_no_init, fftw_init_check, C_NULL) const libfftw3f = FakeLazyLibrary(:libfftw3f_no_init, fftw_init_check, C_NULL) - +function __init__() + activate!() +end else @static if fftw_provider == "fftw" import FFTW_jll: libfftw3_path as libfftw3_no_init, @@ -74,6 +81,7 @@ elseif fftw_provider == "mkl" end function __init__() fftw_init_check() + activate!() end end diff --git a/src/fft.jl b/src/fft.jl index eb62f61..fce2041 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -771,13 +771,13 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD)) plan_f! = Symbol("plan_",f,"!") idirection = -direction @eval begin - function $plan_f(X::StridedArray{T,N}, region; + function $plan_f(b::FFTWBackend, X::StridedArray{T,N}, region; flags::Integer=ESTIMATE, timelimit::Real=NO_TIMELIMIT, num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N} if num_threads !== nothing plan = set_num_threads(num_threads) do - $plan_f(X, region; flags = flags, timelimit = timelimit) + $plan_f(b, X, region; flags = flags, timelimit = timelimit) end return plan end @@ -785,22 +785,22 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD)) region, flags, timelimit) end - function $plan_f!(X::StridedArray{T,N}, region; + function $plan_f!(::FFTWBackend, X::StridedArray{T,N}, region; flags::Integer=ESTIMATE, timelimit::Real=NO_TIMELIMIT, num_threads::Union{Nothing, Integer} = nothing ) where {T<:fftwComplex,N} if num_threads !== nothing plan = set_num_threads(num_threads) do - $plan_f!(X, region; flags = flags, timelimit = timelimit) + $plan_f!(b, X, region; flags = flags, timelimit = timelimit) end return plan end cFFTWPlan{T,$direction,true,N}(X, X, region, flags, timelimit) end - $plan_f(X::StridedArray{<:fftwComplex}; kws...) = - $plan_f(X, ntuple(identity, ndims(X)); kws...) - $plan_f!(X::StridedArray{<:fftwComplex}; kws...) = - $plan_f!(X, ntuple(identity, ndims(X)); kws...) + $plan_f(b::FFTWBackend, X::StridedArray{<:fftwComplex}; kws...) = + $plan_f(b, X, ntuple(identity, ndims(X)); kws...) + $plan_f!(b, ::FFTWBackend, X::StridedArray{<:fftwComplex}; kws...) = + $plan_f!(b, X, ntuple(identity, ndims(X)); kws...) function plan_inv(p::cFFTWPlan{T,$direction,inplace,N}; num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N,inplace} @@ -843,13 +843,13 @@ end for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) # Note: use $FORWARD and $BACKWARD below because of issue #9775 @eval begin - function plan_rfft(X::StridedArray{$Tr,N}, region; + function plan_rfft(b::FFTWBackend, X::StridedArray{$Tr,N}, region; flags::Integer=ESTIMATE, timelimit::Real=NO_TIMELIMIT, num_threads::Union{Nothing, Integer} = nothing) where N if num_threads !== nothing plan = set_num_threads(num_threads) do - plan_rfft(X, region; flags = flags, timelimit = timelimit) + plan_rfft(b, X, region; flags = flags, timelimit = timelimit) end return plan end @@ -858,13 +858,13 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) rFFTWPlan{$Tr,$FORWARD,false,N}(X, Y, region, flags, timelimit) end - function plan_brfft(X::StridedArray{$Tc,N}, d::Integer, region; + function plan_brfft(::FFTWBackend, X::StridedArray{$Tc,N}, d::Integer, region; flags::Integer=ESTIMATE, timelimit::Real=NO_TIMELIMIT, num_threads::Union{Nothing, Integer} = nothing) where N if num_threads !== nothing plan = set_num_threads(num_threads) do - plan_brfft(X, d, region; flags = flags, timelimit = timelimit) + plan_brfft(b, X, d, region; flags = flags, timelimit = timelimit) end return plan end @@ -884,8 +884,8 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) end end - plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,ntuple(identity, ndims(X));kws...) - plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,ntuple(identity, ndims(X));kws...) + plan_rfft(b::FFTWBackend, X::StridedArray{$Tr};kws...)=plan_rfft(b, X,ntuple(identity, ndims(X));kws...) + plan_brfft(b::FFTWBackend, X::StridedArray{$Tr};kws...)=plan_brfft(b, X,ntuple(identity, ndims(X));kws...) function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N}, num_threads::Union{Nothing, Integer} = nothing) where N diff --git a/test/runtests.jl b/test/runtests.jl index d84eb50..d249884 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,6 +64,8 @@ true_fftd3_m3d[:,:,2] .= -15 @eval begin $f_(x::$A{T,N}) where {T,N} = invoke($f, Tuple{AbstractArray{T,N}}, x) $f_(x::$A{T,N},r::R) where {T,N,R} = invoke($f,Tuple{AbstractArray{T,N},R},x,r) + $f_(b::FFTWBackend, x::$A{T,N}) where {T,N} = invoke($f, Tuple{FFTWBackend, AbstractArray{T,N}}, b, x) + $f_(b::FFTWBackend, x::$A{T,N},r::R) where {T,N,R} = invoke($f,Tuple{FFTWBackend, AbstractArray{T,N},R}, b, x,r) end end end @@ -358,12 +360,16 @@ let end @testset "Base Julia issue #9772, with size $(size(x))" for x in (randn(10),randn(10,12)) + # note: Inference/type-stability "breaks" if multiple FFT backends are loaded + # and one does not supply the backend z = complex(x) y = rfft(x) @inferred rfft(x) + @inferred rfft(FFTW.backend(), x) if ndims(x) == 2 @inferred brfft(x,18) + @inferred brfft(FFTW.backend(), x,18) end @inferred brfft(y,10) @@ -371,6 +377,7 @@ end plan_bfft, plan_fft, plan_ifft, fft, bfft, fft_, ifft) p = @inferred f(z) + @inferred f(FFTW.backend(), z) if isa(p, Plan) @inferred plan_inv(p) end @@ -378,6 +385,7 @@ end for f in (plan_bfft, plan_fft, plan_ifft, plan_rfft, fft, bfft, fft_, ifft) p = @inferred f(x) + @inferred f(FFTW.backend(), x) if isa(p, Plan) @inferred plan_inv(p) end