@@ -4,25 +4,21 @@
// octave (mkoctfile) needs this otherwise it doesn't know what int64_t is!
#include <complex>

- #include <cuComplex.h>
#include <cufinufft/types.h>

#include <cuda_runtime.h>
#include <thrust/extrema.h>
+ #include <tuple>
#include <type_traits>
#include <utility> // for std::forward

- #include <finufft_errors.h>
+ #include <common/common.h>

#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <cmath>

- #ifndef M_PI
- #define M_PI 3.14159265358979323846
- #endif
-
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)
#else
__inline__ __device__ double atomicAdd(double *address, double val) {
@@ -72,6 +68,8 @@ template<typename T> __forceinline__ __device__ auto interval(const int ns, cons
namespace cufinufft {
namespace utils {

+ using namespace finufft::common;
+
class WithCudaDevice {
public:
  explicit WithCudaDevice(const int device) : orig_device_{get_orig_device()} {
@@ -90,10 +88,8 @@ class WithCudaDevice {
  }
};

- // math helpers whose source is in src/cuda/utils.cpp
- CUFINUFFT_BIGINT next235beven(CUFINUFFT_BIGINT n, CUFINUFFT_BIGINT b);
- void gaussquad(int n, double *xgl, double *wgl);
- std::tuple<double, double> leg_eval(int n, double x);
+ // math helpers whose source is in src/utils.cpp
+ long next235beven(long n, long b);

template<typename T> T infnorm(int n, std::complex<T> *a) {
  T nrm = 0.0;
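
(Aside: the hunk above narrows the math-helper declarations to next235beven alone, with its definition moved from src/cuda/utils.cpp to src/utils.cpp. For readers unfamiliar with its contract, here is a self-contained sketch: the smallest even number >= n that is a multiple of b and has no prime factors other than 2, 3, and 5. This is an illustrative re-implementation, not the library's code, and it assumes b is itself 2,3,5-smooth; cufinufft calls it with b = 1.)

#include <cstdio>

long next235beven_sketch(long n, long b) {
  if (n <= 2) n = 2;
  if (n % 2 == 1) n += 1;                  // start from an even candidate
  long nplus  = n - 2;                     // cancels the += 2 on the first pass
  long numdiv = 2;                         // any value > 1 enters the loop
  while (numdiv > 1 || (b > 1 && nplus % b != 0)) {
    nplus += 2;                            // candidates stay even
    numdiv = nplus;
    while (numdiv % 2 == 0) numdiv /= 2;   // strip factors of 2
    while (numdiv % 3 == 0) numdiv /= 3;   // strip factors of 3
    while (numdiv % 5 == 0) numdiv /= 5;   // strip factors of 5
  }                                        // numdiv == 1 means 2,3,5-smooth
  return nplus;
}

int main() {
  std::printf("%ld\n", next235beven_sketch(1000, 1)); // 1000 (= 2^3 * 5^3)
  std::printf("%ld\n", next235beven_sketch(1001, 1)); // 1024 (= 2^10)
  return 0;
}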
@@ -124,8 +120,8 @@ static __forceinline__ __device__ void atomicAddComplexShared(
 * on shared memory are supported so we leverage them
 */
template<typename T>
- static __forceinline__ __device__ void atomicAddComplexGlobal(
-     cuda_complex<T> *address, cuda_complex<T> res) {
+ static __forceinline__ __device__ void atomicAddComplexGlobal(cuda_complex<T> *address,
+                                                               cuda_complex<T> res) {
  if constexpr (
      std::is_same_v<cuda_complex<T>, float2> && COMPUTE_CAPABILITY_90_OR_HIGHER) {
    atomicAdd(address, res);
@@ -150,7 +146,7 @@ template<typename T> auto arrayrange(int n, T *a, cudaStream_t stream) {

// Writes out w = half-width and c = center of an interval enclosing all a[n]'s
// Only chooses a nonzero center if this increases w by less than fraction
- // ARRAYWIDCEN_GROWFRAC defined in defs.h.
+ // ARRAYWIDCEN_GROWFRAC defined in common/constants.h.
// This prevents rephasings which don't grow nf by much. 6/8/17
// If n==0, w and c are not finite.
template<typename T> auto arraywidcen(int n, T *a, cudaStream_t stream) {
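
(Aside: the centering rule documented above can be illustrated with a host-side sketch mirroring the CPU finufft version of arraywidcen; the GPU function in this hunk instead does its min/max reduction with thrust on the given stream. The default growfrac of 0.1 and the function names below are assumptions for illustration, and n is assumed positive.)

#include <algorithm>
#include <cmath>
#include <utility>

template<typename T>
std::pair<T, T> arraywidcen_sketch(int n, const T *a, T growfrac = T(0.1)) {
  const auto [lo, hi] = std::minmax_element(a, a + n);
  T w = (*hi - *lo) / 2;            // half-width of the enclosing interval
  T c = (*hi + *lo) / 2;            // its center
  if (std::abs(c) < growfrac * w) { // center too small to justify a rephasing:
    w += std::abs(c);               // grow the width slightly instead...
    c = 0;                          // ...and keep the grid uncentered
  }
  return {w, c};
}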
@@ -180,41 +176,27 @@ auto set_nhg_type3(T S, T X, const cufinufft_opts &opts,
  else
    Ssafe = std::max(Ssafe, T(1) / X);
  // use the safe X and S...
- T nfd = 2.0 * opts.upsampfac * Ssafe * Xsafe / M_PI + nss;
+ T nfd = 2.0 * opts.upsampfac * Ssafe * Xsafe / PI + nss;
  if (!std::isfinite(nfd)) nfd = 0.0; // use FLT to catch inf
  auto nf = (int)nfd;
  // printf("initial nf=%lld, ns=%d\n",*nf,spopts.nspread);
  // catch too small nf, and nan or +-inf, otherwise spread fails...
  if (nf < 2 * spopts.nspread) nf = 2 * spopts.nspread;
- if (nf < MAX_NF) // otherwise will fail anyway
-   nf = utils::next235beven(nf, 1); // expensive at huge nf
+ if (nf < MAX_NF)            // otherwise will fail anyway
+   nf = next235beven(nf, 1); // expensive at huge nf
  // Note: b is 1 because type 3 uses a type 2 plan, so it should not need the extra
  // condition that seems to be used by Block Gather as type 2 are only GM-sort
- auto h = 2 * T(M_PI) / nf; // upsampled grid spacing
+ auto h = 2 * T(PI) / nf; // upsampled grid spacing
  auto gam = T(nf) / (2.0 * opts.upsampfac * Ssafe); // x scale fac to x'
  return std::make_tuple(nf, h, gam);
}

- // Generalized dispatcher for any function requiring ns-based dispatch
- template<typename Func, typename T, int ns, typename... Args>
- int dispatch_ns(Func &&func, int target_ns, Args &&...args) {
-   if constexpr (ns > MAX_NSPREAD) {
-     return FINUFFT_ERR_METHOD_NOTVALID; // Stop recursion
-   } else {
-     if (target_ns == ns) {
-       return std::forward<Func>(func).template operator()<ns>(
-           std::forward<Args>(args)...);
-     }
-     return dispatch_ns<Func, T, ns + 1>(std::forward<Func>(func), target_ns,
-                                         std::forward<Args>(args)...);
-   }
- }
-
- // Wrapper function that starts the dispatch recursion
+ // Wrapper around the generic dispatcher for nspread-based dispatch
template<typename Func, typename T, typename... Args>
- int launch_dispatch_ns(Func &&func, int target_ns, Args &&...args) {
-   return dispatch_ns<Func, T, MIN_NSPREAD>(std::forward<Func>(func), target_ns,
-                                            std::forward<Args>(args)...);
+ auto launch_dispatch_ns(Func &&func, int target_ns, Args &&...args) {
+   using NsSeq = make_range<MIN_NSPREAD, MAX_NSPREAD>;
+   auto params = std::make_tuple(DispatchParam<NsSeq>{target_ns});
+   return dispatch(std::forward<Func>(func), params, std::forward<Args>(args)...);
}

/**
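
(Aside: the new launch_dispatch_ns delegates to a generic dispatch()/make_range/DispatchParam facility from <common/common.h>, which this diff does not show. The sketch below is a minimal, self-contained illustration of how such a facility can turn a runtime kernel width into a compile-time template argument; every name and signature in it is an illustrative stand-in, not the finufft::common implementation. A plausible motivation for the change, suggested by the tuple of parameters, is that one dispatcher can then serve several runtime-to-compile-time parameters, not just ns.)

#include <cstdio>
#include <tuple>
#include <utility>

template<int Low, int High> struct make_range { // assumed: closed range [Low, High]
  static constexpr int low = Low, high = High;
};

template<typename Range> struct DispatchParam { // assumed: runtime value + its range
  int value;
};

// Try each N in [Range::low, Range::high]; on a match, call func with N baked in
// as a template argument, so the callee can unroll loops over N at compile time.
template<int N, typename Range, typename Func, typename... Args>
int dispatch_impl(Func &&func, int target, Args &&...args) {
  if constexpr (N > Range::high) {
    return -1; // no match: the real code returns an error code such as
               // FINUFFT_ERR_METHOD_NOTVALID
  } else {
    if (target == N)
      return std::forward<Func>(func).template operator()<N>(
          std::forward<Args>(args)...);
    return dispatch_impl<N + 1, Range>(std::forward<Func>(func), target,
                                       std::forward<Args>(args)...);
  }
}

template<typename Func, typename Range, typename... Args>
int dispatch(Func &&func, std::tuple<DispatchParam<Range>> params, Args &&...args) {
  return dispatch_impl<Range::low, Range>(
      std::forward<Func>(func), std::get<0>(params).value,
      std::forward<Args>(args)...);
}

struct PrintNs { // example functor: ns is a constant expression inside operator()
  template<int ns> int operator()(const char *tag) const {
    std::printf("%s: instantiated with ns = %d\n", tag, ns);
    return 0;
  }
};

int main() {
  using NsSeq = make_range<2, 16>; // stand-ins for MIN_NSPREAD..MAX_NSPREAD
  auto params = std::make_tuple(DispatchParam<NsSeq>{7});
  return dispatch(PrintNs{}, params, "demo"); // prints "demo: instantiated with ns = 7"
}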