Description
The `plan_*fft` functions in AbstractFFTs accept keyword arguments, but the methods of these functions provided by `CUDA.CUFFT` do not. Code that passes keyword arguments to these functions (for example, to influence FFTW planning for `Array`s) therefore does not work with `CuArray`s, because the presence of keyword arguments prevents dispatch to the `CUDA.CUFFT` methods. If FFTW is not loaded, this results in a `MethodError`. If FFTW is loaded, the call dispatches to an FFTW method, which then throws an `ArgumentError` because it tries to take the CPU address of the `CuArray`.
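The failure mode can be reproduced without a GPU. This is a minimal sketch under toy assumptions: `FakeGPUArray` and `toy_plan` are hypothetical stand-ins for `CuArray` and `plan_fft`, and the one-argument fallback only roughly mirrors the AbstractFFTs method seen in frame [1] of the first stack trace in the MWE, which picks a default region and forwards its keywords to a two-argument method.

```julia
# Hypothetical stand-ins, not the real AbstractFFTs / CUDA.CUFFT definitions.
struct FakeGPUArray{T}
    data::Vector{T}
end

# Generic fallback: picks a default region and forwards any keywords,
# roughly imitating the AbstractFFTs one-argument method.
toy_plan(x; kws...) = toy_plan(x, 1:1; kws...)

# "GPU" method that declares no keyword arguments.
toy_plan(x::FakeGPUArray, region) = "GPU plan"

A = FakeGPUArray(ComplexF32[1, 2, 3, 4])
println(toy_plan(A))          # works: no keywords are forwarded

try
    toy_plan(A; flags = 0)    # forwarded keyword is rejected by the GPU method
catch err
    println(typeof(err))      # MethodError
end
```

Without keywords the forwarded call reaches the keyword-free method; with `flags = 0` it fails with the same kind of `MethodError` ("got unsupported keyword argument") as in the MWE.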
Since `CUDA.CUFFT` has no use for these keyword arguments, I think it would be sufficient to add `; ignored_kwargs...` to the method signatures, or to create alternate methods that accept `; ignored_kwargs...` and simply call the keyword-free methods.
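A sketch of the first option, again using the hypothetical `FakeGPUArray` / `toy_plan` stand-ins rather than the actual `CUDA.CUFFT` source: the two-argument GPU method gains a trailing `; ignored_kwargs...`, so calls carrying FFTW-style keywords reach it and the keywords are simply discarded. The keyword names used below (`flags`, `timelimit`) are taken from the FFTW stack trace and serve only as examples.

```julia
# Hypothetical stand-ins, not the real AbstractFFTs / CUDA.CUFFT definitions.
struct FakeGPUArray{T}
    data::Vector{T}
end

# Generic fallback that picks a default region and forwards keywords.
toy_plan(x; kws...) = toy_plan(x, 1:1; kws...)

# Proposed fix: accept and silently ignore keyword arguments
# (e.g. FFTW-specific `flags` or `timelimit`) that the GPU backend does not use.
toy_plan(x::FakeGPUArray, region; ignored_kwargs...) = "GPU plan"

A = FakeGPUArray(ComplexF32[1, 2, 3, 4])
println(toy_plan(A))                             # GPU plan
println(toy_plan(A; flags = 0, timelimit = 1.0)) # GPU plan
```

Because the keywords are swallowed rather than validated, a misspelled keyword would also be silently ignored; that trade-off seems acceptable given that the CUFFT methods do not use any of them.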
Here is an MWE showing the problem both with and without FFTW:
```julia
julia> using CUDA, CUDA.CUFFT

julia> A=CuArray{ComplexF32}(undef, 4);

julia> plan_fft(A)
CUFFT complex forward plan for 4-element CuArray of ComplexF32

julia> plan_fft(A; flags=0)
ERROR: MethodError: no method matching plan_fft(::CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}, ::UnitRange{Int64}; flags=0)
Closest candidates are:
  plan_fft(::CuArray{T, N}, ::Any) where {T<:Union{ComplexF32, ComplexF64}, N} at ~/.julia/packages/CUDA/GGwVa/lib/cufft/fft.jl:307 got unsupported keyword argument "flags"
  plan_fft(::CuArray{<:Complex{<:Union{Integer, Rational}}}, ::Any) at ~/.julia/packages/CUDA/GGwVa/lib/cufft/fft.jl:276 got unsupported keyword argument "flags"
  plan_fft(::CuArray{<:Real}, ::Any) at ~/.julia/packages/CUDA/GGwVa/lib/cufft/fft.jl:274 got unsupported keyword argument "flags"
  ...
Stacktrace:
 [1] plan_fft(x::CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}; kws::Base.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:flags,), Tuple{Int64}}})
   @ AbstractFFTs ~/.julia/packages/AbstractFFTs/SFNY3/src/definitions.jl:64
 [2] top-level scope
   @ REPL[5]:1
 [3] top-level scope
   @ ~/.julia/packages/CUDA/GGwVa/src/initialization.jl:52

julia> using FFTW

julia> plan_fft(A; flags=0)
ERROR: ArgumentError: cannot take the CPU address of a CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}
Stacktrace:
 [1] unsafe_convert(#unused#::Type{Ptr{ComplexF32}}, x::CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer})
   @ CUDA ~/.julia/packages/CUDA/GGwVa/src/array.jl:319
 [2] macro expansion
   @ ~/.julia/packages/FFTW/sfy1o/src/fft.jl:592 [inlined]
 [3] (FFTW.cFFTWPlan{ComplexF32, -1, false, 1})(X::CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}, Y::Vector{ComplexF32}, region::UnitRange{Int64}, flags::Int64, timelimit::Float64)
   @ FFTW ~/.julia/packages/FFTW/sfy1o/src/FFTW.jl:49
 [4] plan_fft(X::CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}, region::UnitRange{Int64}; flags::Int64, timelimit::Float64, num_threads::Nothing)
   @ FFTW ~/.julia/packages/FFTW/sfy1o/src/fft.jl:719
 [5] plan_fft(X::CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}; kws::Base.Pairs{Symbol, Int64, Tuple{Symbol}, NamedTuple{(:flags,), Tuple{Int64}}})
   @ FFTW ~/.julia/packages/FFTW/sfy1o/src/fft.jl:735
 [6] top-level scope
   @ REPL[7]:1
 [7] top-level scope
   @ ~/.julia/packages/CUDA/GGwVa/src/initialization.jl:52
```