From 98a73c48881ba2fda8988968ed33a3f0a24ac591 Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Mon, 11 Nov 2024 01:48:11 +0100 Subject: [PATCH 01/12] FFT module revamp. - Added several new functionnalities to the FFT module by changing the dependency from FFTPACK to POCKETFTT. - new optionnal parameters for the API (nthreads, norm, ...) - new functions for cosine and sine transforms (dct, dst, ...) - Switched from dune 2.0 to dune 3.16 (this was required as I ran throught issues with linking while using 2.0) --- .gitmodules | 3 + dune-project | 2 +- examples/dune | 17 +- src/base/core/owl_graph.ml | 2 +- src/owl/dune | 17 +- src/owl/fftpack/fftpack.h | 34 - src/owl/fftpack/fftpack_impl.h | 1450 ------------------------ src/owl/fftpack/owl_fft_d.mli | 65 +- src/owl/fftpack/owl_fft_generic.ml | 109 +- src/owl/fftpack/owl_fft_generic.mli | 112 +- src/owl/fftpack/owl_fft_s.mli | 65 +- src/owl/fftpack/owl_fftpack.ml | 218 +++- src/owl/fftpack/owl_fftpack_float32.c | 41 - src/owl/fftpack/owl_fftpack_float32.cc | 39 + src/owl/fftpack/owl_fftpack_float64.c | 41 - src/owl/fftpack/owl_fftpack_float64.cc | 38 + src/owl/fftpack/owl_fftpack_impl.h | 568 ++++++---- src/owl/fftpack/pocketfft | 1 + src/owl/nlp/owl_nlp_corpus.ml | 2 +- src/owl/nlp/owl_nlp_lda.ml | 2 +- src/owl/nlp/owl_nlp_tfidf.ml | 2 +- src/owl/nlp/owl_nlp_vocabulary.ml | 2 +- 22 files changed, 934 insertions(+), 1896 deletions(-) create mode 100644 .gitmodules delete mode 100644 src/owl/fftpack/fftpack.h delete mode 100644 src/owl/fftpack/fftpack_impl.h delete mode 100644 src/owl/fftpack/owl_fftpack_float32.c create mode 100644 src/owl/fftpack/owl_fftpack_float32.cc delete mode 100644 src/owl/fftpack/owl_fftpack_float64.c create mode 100644 src/owl/fftpack/owl_fftpack_float64.cc create mode 160000 src/owl/fftpack/pocketfft diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..775cc6982 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "src/owl/fftpack/pocketfft"] + path = src/owl/fftpack/pocketfft + url = https://github.com/mreineck/pocketfft diff --git a/dune-project b/dune-project index 04a070d3b..fe4c8519b 100644 --- a/dune-project +++ b/dune-project @@ -1,3 +1,3 @@ -(lang dune 2.0) +(lang dune 3.16) (name owl) diff --git a/examples/dune b/examples/dune index 6d5a16f8e..14e0e23be 100644 --- a/examples/dune +++ b/examples/dune @@ -1,5 +1,5 @@ (executables - (names + (names backprop checkpoint cifar10_vgg @@ -25,6 +25,15 @@ squeezenet test_log tfidf - vgg16 - ) - (libraries owl)) \ No newline at end of file + vgg16) + (libraries owl) + (flags ; in order to make the examples compile correctly even with the warnings. + (:standard + -warn-error + -unused-value-declaration + -warn-error + -unused-var-strict + -warn-error + -unused-var + -warn-error + -unused-field))) diff --git a/src/base/core/owl_graph.ml b/src/base/core/owl_graph.ml index 2c9d544c3..f344b1276 100644 --- a/src/base/core/owl_graph.ml +++ b/src/base/core/owl_graph.ml @@ -13,7 +13,7 @@ type 'a node = mutable next : 'a node array ; (* children of the node *) mutable attr : 'a (* indicate the validity *) - } + } [@@warning "-69"] type order = | BFS diff --git a/src/owl/dune b/src/owl/dune index 0dcc3c817..0bd4c0914 100644 --- a/src/owl/dune +++ b/src/owl/dune @@ -32,6 +32,8 @@ (copy_files# fftpack/*) +(copy_files# fftpack/pocketfft/*.h) + (copy_files# misc/*) (copy_files# nlp/*) @@ -42,6 +44,15 @@ (name owl) (public_name owl) (wrapped false) + (foreign_stubs + (language cxx) + (names + ;; FFTPACK + owl_fftpack_float32 + owl_fftpack_float64) + (flags + :standard + (:include c_flags.sexp))) (foreign_stubs (language c) (names @@ -65,9 +76,6 @@ owl_ndarray_utils_stub owl_slicing_basic_stub owl_slicing_fancy_stub - ;; FFTPACK - owl_fftpack_float32 - owl_fftpack_float64 ;; stats SFMT owl_stats_dist_beta @@ -202,7 +210,8 @@ (:include c_flags.sexp))) (c_library_flags :standard - (:include c_library_flags.sexp)) + (:include c_library_flags.sexp) + -lstdc++) (flags :standard (:include ocaml_flags.sexp)) diff --git a/src/owl/fftpack/fftpack.h b/src/owl/fftpack/fftpack.h deleted file mode 100644 index eba57ad0c..000000000 --- a/src/owl/fftpack/fftpack.h +++ /dev/null @@ -1,34 +0,0 @@ -/* - * OWL - OCaml Scientific Computing - * Copyright (c) 2016-2022 Liang Wang - */ - -/* Refer the doc on http://www.netlib.org/fftpack/doc */ - -#ifdef __cplusplus -extern "C" { -#endif - -// Single precision FFT - -extern void float32_fftpack_cffti(int N, const float wsave[]); -extern void float32_fftpack_cfftf(int N, float c[], const float wsave[]); -extern void float32_fftpack_cfftb(int N, float c[], const float wsave[]); - -extern void float32_fftpack_rffti(int N, const float wsave[]); -extern void float32_fftpack_rfftf(int N, float r[], const float wsave[]); -extern void float32_fftpack_rfftb(int N, float r[], const float wsave[]); - -// Double precision FFT - -extern void float64_fftpack_cffti(int N, const double wsave[]); -extern void float64_fftpack_cfftf(int N, double c[], const double wsave[]); -extern void float64_fftpack_cfftb(int N, double c[], const double wsave[]); - -extern void float64_fftpack_rffti(int N, const double wsave[]); -extern void float64_fftpack_rfftf(int N, double r[], const double wsave[]); -extern void float64_fftpack_rfftb(int N, double r[], const double wsave[]); - -#ifdef __cplusplus -} -#endif diff --git a/src/owl/fftpack/fftpack_impl.h b/src/owl/fftpack/fftpack_impl.h deleted file mode 100644 index 9e90375c7..000000000 --- a/src/owl/fftpack/fftpack_impl.h +++ /dev/null @@ -1,1450 +0,0 @@ -/* - * fftpack.c : A set of FFT routines in C. - * Algorithmically based on Fortran-77 FFTPACK by Paul N. Swarztrauber (Version 4, 1985). - * - * Further adapted into Owl from Numpy library. -*/ - -#include -#include -#include - -#define ref(u,a) u[a] - -#define MAXFAC 13 /* maximum number of factors in factorization of n */ -#define NSPECIAL 4 /* number of factors for which we have special-case routines */ - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef Treal - - -/* ---------------------------------------------------------------------- - passf2, passf3, passf4, passf5, passf. Complex FFT passes fwd and bwd. ------------------------------------------------------------------------ */ - -static void passf2(int ido, int l1, const Treal cc[], Treal ch[], const Treal wa1[], int isign) - /* isign==+1 for backward transform */ - { - int i, k, ah, ac; - Treal ti2, tr2; - if (ido <= 2) { - for (k=0; k= l1) { - for (j=1; j idp) idlj -= idp; - war = wa[idlj - 2]; - wai = wa[idlj-1]; - for (ik=0; ik= l1) { - for (j=1; j= l1) { - for (k=0; k= l1) { - for (j=1; j= l1) { - for (k=0; k= l1) { - for (j=1; j= l1) { - for (j=1; j 5) { - wa[i1-1] = wa[i-1]; - wa[i1] = wa[i]; - } - } - l1 = l2; - } - } /* cffti1 */ - - - /* ------------------------------------------------------------------- -rfftf1, rfftb1, owl_fftpack_rfftf, owl_fftpack_rfftb, rffti1, owl_fftpack_rffti. Treal FFTs. ----------------------------------------------------------------------- */ - -static void rfftf1(int n, Treal c[], Treal ch[], const Treal wa[], const int ifac[MAXFAC+2]) - { - int i; - int k1, l1, l2, na, kh, nf, ip, iw, ix2, ix3, ix4, ido, idl1; - Treal *cinput, *coutput; - nf = ifac[1]; - na = 1; - l2 = n; - iw = n-1; - for (k1 = 1; k1 <= nf; ++k1) { - kh = nf - k1; - ip = ifac[kh + 2]; - l1 = l2 / ip; - ido = n / l2; - idl1 = ido*l1; - iw -= (ip - 1)*ido; - na = !na; - if (na) { - cinput = ch; - coutput = c; - } else { - cinput = c; - coutput = ch; - } - switch (ip) { - case 4: - ix2 = iw + ido; - ix3 = ix2 + ido; - radf4(ido, l1, cinput, coutput, &wa[iw], &wa[ix2], &wa[ix3]); - break; - case 2: - radf2(ido, l1, cinput, coutput, &wa[iw]); - break; - case 3: - ix2 = iw + ido; - radf3(ido, l1, cinput, coutput, &wa[iw], &wa[ix2]); - break; - case 5: - ix2 = iw + ido; - ix3 = ix2 + ido; - ix4 = ix3 + ido; - radf5(ido, l1, cinput, coutput, &wa[iw], &wa[ix2], &wa[ix3], &wa[ix4]); - break; - default: - if (ido == 1) - na = !na; - if (na == 0) { - radfg(ido, ip, l1, idl1, c, ch, &wa[iw]); - na = 1; - } else { - radfg(ido, ip, l1, idl1, ch, c, &wa[iw]); - na = 0; - } - } - l2 = l1; - } - if (na == 1) return; - for (i = 0; i < n; i++) c[i] = ch[i]; - } /* rfftf1 */ - - -static void rfftb1(int n, Treal c[], Treal ch[], const Treal wa[], const int ifac[MAXFAC+2]) - { - int i; - int k1, l1, l2, na, nf, ip, iw, ix2, ix3, ix4, ido, idl1; - Treal *cinput, *coutput; - nf = ifac[1]; - na = 0; - l1 = 1; - iw = 0; - for (k1=1; k1<=nf; k1++) { - ip = ifac[k1 + 1]; - l2 = ip*l1; - ido = n / l2; - idl1 = ido*l1; - if (na) { - cinput = ch; - coutput = c; - } else { - cinput = c; - coutput = ch; - } - switch (ip) { - case 4: - ix2 = iw + ido; - ix3 = ix2 + ido; - radb4(ido, l1, cinput, coutput, &wa[iw], &wa[ix2], &wa[ix3]); - na = !na; - break; - case 2: - radb2(ido, l1, cinput, coutput, &wa[iw]); - na = !na; - break; - case 3: - ix2 = iw + ido; - radb3(ido, l1, cinput, coutput, &wa[iw], &wa[ix2]); - na = !na; - break; - case 5: - ix2 = iw + ido; - ix3 = ix2 + ido; - ix4 = ix3 + ido; - radb5(ido, l1, cinput, coutput, &wa[iw], &wa[ix2], &wa[ix3], &wa[ix4]); - na = !na; - break; - default: - radbg(ido, ip, l1, idl1, cinput, coutput, &wa[iw]); - if (ido == 1) na = !na; - } - l1 = l2; - iw += (ip - 1)*ido; - } - if (na == 0) return; - for (i=0; i (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t +val fft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t + -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t -val ifft : ?axis:int -> (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t +val ifft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t + -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t -val rfft : ?axis:int -> (float, float64_elt) t -> (Complex.t, complex64_elt) t +val rfft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (float, float64_elt) t + -> (Complex.t, complex64_elt) t -val irfft : ?axis:int -> ?n:int -> (Complex.t, complex64_elt) t -> (float, float64_elt) t +val irfft + : ?axis:int + -> ?n:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex64_elt) t + -> (float, float64_elt) t val fft2 : (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t val ifft2 : (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t + +val dct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float64_elt) t + -> (float, float64_elt) t + +val idct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float64_elt) t + -> (float, float64_elt) t + +val dst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float64_elt) t + -> (float, float64_elt) t + +val idst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float64_elt) t + -> (float, float64_elt) t diff --git a/src/owl/fftpack/owl_fft_generic.ml b/src/owl/fftpack/owl_fft_generic.ml index f785e7c46..4d24c219c 100644 --- a/src/owl/fftpack/owl_fft_generic.ml +++ b/src/owl/fftpack/owl_fft_generic.ml @@ -5,68 +5,143 @@ open Owl_dense_ndarray_generic -let fft ?axis x = +let fft ?axis ?(norm : int = 0) ?(nthreads : int = 1) x = let axis = match axis with | Some a -> a - | None -> num_dims x - 1 + | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_cfftf (kind x) x y axis; + Owl_fftpack._owl_cfftf (kind x) x y axis norm nthreads; y -let ifft ?axis x = +let ifft ?axis ?(norm : int = 1) ?(nthreads : int = 1) x = let axis = match axis with | Some a -> a - | None -> num_dims x - 1 + | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_cfftb (kind x) x y axis; - let norm = Complex.{ re = float_of_int (shape y).(axis); im = 0. } in - div_scalar_ y norm; + Owl_fftpack._owl_cfftb (kind x) x y axis norm nthreads; y -let rfft ?axis ~otyp x = +let rfft ?axis ?(norm : int = 0) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = let axis = match axis with | Some a -> a - | None -> num_dims x - 1 + | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let s = shape x in s.(axis) <- (s.(axis) / 2) + 1; let y = empty otyp s in let ityp = kind x in - Owl_fftpack._owl_rfftf ityp otyp x y axis; + Owl_fftpack._owl_rfftf ityp otyp x y axis norm nthreads; y -let irfft ?axis ?n ~otyp x = +let irfft ?axis ?n ?(norm : int = 1) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = let axis = match axis with | Some a -> a - | None -> num_dims x - 1 + | None -> num_dims x - 1 in + let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let s = shape x in let _ = match n with | Some n -> s.(axis) <- n - | None -> s.(axis) <- (s.(axis) - 1) * 2 + | None -> s.(axis) <- (s.(axis) - 1) * 2 in let y = empty otyp s in let ityp = kind x in - Owl_fftpack._owl_rfftb ityp otyp x y axis; - let norm = float_of_int s.(axis) in - div_scalar_ y norm; + Owl_fftpack._owl_rfftb ityp otyp x y axis norm nthreads; y let fft2 x = fft ~axis:0 x |> fft ~axis:1 let ifft2 x = ifft ~axis:0 x |> ifft ~axis:1 + +let dct ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = 1) x = + let axis = + match axis with + | Some a -> a + | None -> num_dims x - 1 + in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); + let ortho = + match ortho with + | Some o -> o + | None -> if norm = 2 then true else false + in + assert (ttype > 0 || ttype < 5); + let y = empty (kind x) (shape x) in + Owl_fftpack._owl_dctf (kind x) x y axis ttype norm ortho nthreads; + y + + +let idct ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads = 1) x = + let axis = + match axis with + | Some a -> a + | None -> num_dims x - 1 + in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); + let ortho = + match ortho with + | Some o -> o + | None -> if norm = 2 then true else false + in + assert (ttype > 0 || ttype < 5); + let y = empty (kind x) (shape x) in + Owl_fftpack._owl_dctb (kind x) x y axis ttype norm ortho nthreads; + y + + +let dst ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = 1) x = + let axis = + match axis with + | Some a -> a + | None -> num_dims x - 1 + in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); + let ortho = + match ortho with + | Some o -> o + | None -> if norm = 2 then true else false + in + assert (ttype > 0 || ttype < 5); + let y = empty (kind x) (shape x) in + Owl_fftpack._owl_dstf (kind x) x y axis ttype norm ortho nthreads; + y + + +let idst ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads = 1) x = + let axis = + match axis with + | Some a -> a + | None -> num_dims x - 1 + in + let axis = if axis < 0 then num_dims x + axis else axis in + assert (axis < num_dims x); + let ortho = + match ortho with + | Some o -> o + | None -> if norm = 2 then true else false + in + assert (ttype > 0 || ttype < 5); + let y = empty (kind x) (shape x) in + Owl_fftpack._owl_dstb (kind x) x y axis ttype norm ortho nthreads; + y diff --git a/src/owl/fftpack/owl_fft_generic.mli b/src/owl/fftpack/owl_fft_generic.mli index d805a3b4c..ef0d63032 100644 --- a/src/owl/fftpack/owl_fft_generic.mli +++ b/src/owl/fftpack/owl_fft_generic.mli @@ -7,45 +7,93 @@ open Owl_dense_ndarray_generic -(** {5 Basic functions} *) +(** {5 Discrete Fourier Transforms functions} *) -val fft : ?axis:int -> (Complex.t, 'a) t -> (Complex.t, 'a) t -(** -[fft ~axis x] performs 1-dimensional FFT on a complex input. [axis] is the -highest dimension if not specified. The return is not scaled. - *) +val fft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, 'a) t + -> (Complex.t, 'a) t +(** [fft ~axis x] performs 1-dimensional FFT on a complex input. [axis] is the + highest dimension if not specified. The return is not scaled. *) -val ifft : ?axis:int -> (Complex.t, 'a) t -> (Complex.t, 'a) t -(** -[ifft ~axis x] performs inverse 1-dimensional FFT on a complex input. The parameter [axis] -indicates the highest dimension by default. - *) +val ifft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, 'a) t + -> (Complex.t, 'a) t +(** [ifft ~axis x] performs inverse 1-dimensional FFT on a complex input. The parameter [axis] + indicates the highest dimension by default. *) -val rfft : ?axis:int -> otyp:(Complex.t, 'a) kind -> (float, 'b) t -> (Complex.t, 'a) t -(** -[rfft ~axis ~otyp x] performs 1-dimensional FFT on real input along the -[axis]. [otyp] is used to specify the output type, it must be the consistent -precision with input [x]. You can skip this parameter by using a submodule -with specific precision such as [Owl.Fft.S] or [Owl.Fft.D]. - *) +val rfft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> otyp:('a, 'b) kind + -> ('c, 'd) t + -> ('a, 'b) t +(** [rfft ~axis ~otyp x] performs 1-dimensional FFT on real input along the + [axis]. [otyp] is used to specify the output type, it must be the consistent + precision with input [x]. You can skip this parameter by using a submodule + with specific precision such as [Owl.Fft.S] or [Owl.Fft.D]. *) - val irfft +val irfft : ?axis:int -> ?n:int - -> otyp:(float, 'a) kind - -> (Complex.t, 'b) t - -> (float, 'a) t -(** -[irfft ~axis ~n x] is the inverse function of [rfft]. Note the [n] parameter -is used to specified the size of output. - *) + -> ?norm:int + -> ?nthreads:int + -> otyp:('a, 'b) kind + -> ('c, 'd) t + -> ('a, 'b) t +(** [irfft ~axis ~n x] is the inverse function of [rfft]. Note the [n] parameter + is used to specified the size of output. *) val fft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t -(** -[fft2 x] performs 2-dimensional FFT on a complex input. The return is not scaled. - *) +(** [fft2 x] performs 2-dimensional FFT on a complex input. The return is not scaled. *) val ifft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t -(** -[ifft2 x] performs inverse 2-dimensional FFT on a complex input. - *) +(** [ifft2 x] performs inverse 2-dimensional FFT on a complex input. *) + +(** {5 Discrete Cosine & Sine Transforms functions} *) + +val dct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> ('a, 'b) t + -> ('a, 'b) t +(** [dct ~axis ~type x] performs 1-dimensional Discrete Cosine Transform (DCT) on a real input. Default type is 2. *) + +val idct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> ('a, 'b) t + -> ('a, 'b) t +(** [idct ~axis ~type x] performs inverse 1-dimensional Discrete Cosine Transform (DCT) on a real input. Default type is 2. *) + +val dst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> ('a, 'b) t + -> ('a, 'b) t +(** [dst ~axis ~type x] performs 1-dimensional Discrete Sine Transform (DST) on a real input. Default type is 2. *) + +val idst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> ('a, 'b) t + -> ('a, 'b) t +(** [idst ~axis ~type x] performs inverse 1-dimensional Discrete Sine Transform (DST) on a real input. Default type is 2. *) diff --git a/src/owl/fftpack/owl_fft_s.mli b/src/owl/fftpack/owl_fft_s.mli index fe7fab5e6..aa4b0a643 100644 --- a/src/owl/fftpack/owl_fft_s.mli +++ b/src/owl/fftpack/owl_fft_s.mli @@ -6,14 +6,71 @@ open Bigarray open Owl_dense_ndarray_generic -val fft : ?axis:int -> (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t +val fft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t + -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t -val ifft : ?axis:int -> (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t +val ifft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t + -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t -val rfft : ?axis:int -> (float, float32_elt) t -> (Complex.t, complex32_elt) t +val rfft + : ?axis:int + -> ?norm:int + -> ?nthreads:int + -> (float, float32_elt) t + -> (Complex.t, complex32_elt) t -val irfft : ?axis:int -> ?n:int -> (Complex.t, complex32_elt) t -> (float, float32_elt) t +val irfft + : ?axis:int + -> ?n:int + -> ?norm:int + -> ?nthreads:int + -> (Complex.t, complex32_elt) t + -> (float, float32_elt) t val fft2 : (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t val ifft2 : (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t + +val dct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float32_elt) t + -> (float, float32_elt) t + +val idct + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float32_elt) t + -> (float, float32_elt) t + +val dst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float32_elt) t + -> (float, float32_elt) t + +val idst + : ?axis:int + -> ?ttype:int + -> ?norm:int + -> ?ortho:bool + -> ?nthreads:int + -> (float, float32_elt) t + -> (float, float32_elt) t diff --git a/src/owl/fftpack/owl_fftpack.ml b/src/owl/fftpack/owl_fftpack.ml index 205b02db7..4e54c22ee 100644 --- a/src/owl/fftpack/owl_fftpack.ml +++ b/src/owl/fftpack/owl_fftpack.ml @@ -6,10 +6,13 @@ open Bigarray open Owl_core_types +(* Forward Real FFT *) external owl_float32_rfftf : (float, float32_elt) owl_arr -> (Complex.t, complex32_elt) owl_arr -> int + -> int + -> int -> unit = "float32_rfftf" @@ -17,24 +20,37 @@ external owl_float64_rfftf : (float, float64_elt) owl_arr -> (Complex.t, complex64_elt) owl_arr -> int + -> int + -> int -> unit = "float64_rfftf" let _owl_rfftf - : type a b c d. - (a, b) kind -> (c, d) kind -> (a, b) owl_arr -> (c, d) owl_arr -> int -> unit + : type a b c d. + (a, b) kind + -> (c, d) kind + -> (a, b) owl_arr + -> (c, d) owl_arr + -> int + -> int + -> int + -> unit = - fun ityp otyp x y axis -> + fun ityp otyp x y axis norm nthreads -> match ityp, otyp with - | Float32, Complex32 -> owl_float32_rfftf x y axis - | Float64, Complex64 -> owl_float64_rfftf x y axis - | _ -> failwith "_owl_rfftf: unsupported operation" + | Float32, Complex32 -> owl_float32_rfftf x y axis norm nthreads + | Float64, Complex64 -> owl_float64_rfftf x y axis norm nthreads + | _ -> failwith "_owl_rfftf: unsupported operation" +(* Backward Real FFT *) + external owl_float32_rfftb : (Complex.t, complex32_elt) owl_arr -> (float, float32_elt) owl_arr -> int + -> int + -> int -> unit = "float32_rfftb" @@ -42,57 +58,191 @@ external owl_float64_rfftb : (Complex.t, complex64_elt) owl_arr -> (float, float64_elt) owl_arr -> int + -> int + -> int -> unit = "float64_rfftb" let _owl_rfftb - : type a b c d. - (a, b) kind -> (c, d) kind -> (a, b) owl_arr -> (c, d) owl_arr -> int -> unit + : type a b c d. + (a, b) kind + -> (c, d) kind + -> (a, b) owl_arr + -> (c, d) owl_arr + -> int + -> int + -> int + -> unit = - fun ityp otyp x y axis -> + fun ityp otyp x y axis norm nthreads -> match ityp, otyp with - | Complex32, Float32 -> owl_float32_rfftb x y axis - | Complex64, Float64 -> owl_float64_rfftb x y axis - | _ -> failwith "_owl_rfftb: unsupported operation" + | Complex32, Float32 -> owl_float32_rfftb x y axis norm nthreads + | Complex64, Float64 -> owl_float64_rfftb x y axis norm nthreads + | _ -> failwith "_owl_rfftb: unsupported operation" -external owl_complex32_cfftf - : (Complex.t, complex32_elt) owl_arr +external owl_complex32_cfft + : bool -> (Complex.t, complex32_elt) owl_arr + -> (Complex.t, complex32_elt) owl_arr + -> int + -> int -> int -> unit - = "float32_cfftf" + = "float64_cfft_bytecode" "float32_cfft" -external owl_complex64_cfftf - : (Complex.t, complex64_elt) owl_arr +external owl_complex64_cfft + : bool + -> (Complex.t, complex64_elt) owl_arr -> (Complex.t, complex64_elt) owl_arr -> int + -> int + -> int -> unit - = "float64_cfftf" + = "float64_cfft_bytecode" "float64_cfft" -let _owl_cfftf : type a b. (a, b) kind -> (a, b) owl_arr -> (a, b) owl_arr -> int -> unit +let _owl_cfftf + : type a b. (a, b) kind -> (a, b) owl_arr -> (a, b) owl_arr -> int -> int -> int -> unit = function - | Complex32 -> owl_complex32_cfftf - | Complex64 -> owl_complex64_cfftf - | _ -> failwith "_owl_cfftf: unsupported operation" + | Complex32 -> true |> owl_complex32_cfft + | Complex64 -> true |> owl_complex64_cfft + | _ -> failwith "_owl_cfftf: unsupported operation" -external owl_complex32_cfftb - : (Complex.t, complex32_elt) owl_arr - -> (Complex.t, complex32_elt) owl_arr +let _owl_cfftb + : type a b. (a, b) kind -> (a, b) owl_arr -> (a, b) owl_arr -> int -> int -> int -> unit + = function + | Complex32 -> false |> owl_complex32_cfft + | Complex64 -> false |> owl_complex64_cfft + | _ -> failwith "_owl_cfftb: unsupported operation" + + +(* DCT and DST *) + +(* little helper to get the inverse type of DSTs and DCTs *) +let inverse_map = function + | 1 -> 1 + | 2 -> 3 + | 3 -> 2 + | 4 -> 4 + | _ -> failwith "unknown transform type" + + +(* DCT *) + +external owl_float32_dct + : (float, float32_elt) owl_arr + -> (float, float32_elt) owl_arr + -> int + -> int + -> int + -> bool -> int -> unit - = "float32_cfftb" + = "float32_dct_bytecode" "float32_dct" -external owl_complex64_cfftb - : (Complex.t, complex64_elt) owl_arr - -> (Complex.t, complex64_elt) owl_arr +external owl_float64_dct + : (float, float64_elt) owl_arr + -> (float, float64_elt) owl_arr + -> int + -> int + -> int + -> bool -> int -> unit - = "float64_cfftb" + = "float64_dct_bytecode" "float64_dct" + +let _owl_dctf + : type a b. + (a, b) kind + -> (a, b) owl_arr + -> (a, b) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit + = function + | Float32 -> owl_float32_dct + | Float64 -> owl_float64_dct + | _ -> failwith "_owl_dctf: unsupported operation" + + +let _owl_dctb + : type a b. + (a, b) kind + -> (a, b) owl_arr + -> (a, b) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit + = + fun ityp x y ttype axis norm ortho nthreads -> + match ityp with + | Float32 -> owl_float32_dct x y (inverse_map ttype) axis norm ortho nthreads + | Float64 -> owl_float64_dct x y (inverse_map ttype) axis norm ortho nthreads + | _ -> failwith "_owl_dctb: unsupported operation" + -let _owl_cfftb : type a b. (a, b) kind -> (a, b) owl_arr -> (a, b) owl_arr -> int -> unit +(* DST *) + +external owl_float32_dst + : (float, float32_elt) owl_arr + -> (float, float32_elt) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit + = "float32_dst_bytecode" "float32_dst" + +external owl_float64_dst + : (float, float64_elt) owl_arr + -> (float, float64_elt) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit + = "float64_dst_bytecode" "float64_dst" + +let _owl_dstf + : type a b. + (a, b) kind + -> (a, b) owl_arr + -> (a, b) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit = function - | Complex32 -> owl_complex32_cfftb - | Complex64 -> owl_complex64_cfftb - | _ -> failwith "_owl_cfftf: unsupported operation" + | Float32 -> owl_float32_dst + | Float64 -> owl_float64_dst + | _ -> failwith "_owl_dstf: unsupported operation" + + +let _owl_dstb + : type a b. + (a, b) kind + -> (a, b) owl_arr + -> (a, b) owl_arr + -> int + -> int + -> int + -> bool + -> int + -> unit + = + fun ityp x y ttype axis norm ortho nthreads -> + match ityp with + | Float32 -> owl_float32_dst x y (inverse_map ttype) axis norm ortho nthreads + | Float64 -> owl_float64_dst x y (inverse_map ttype) axis norm ortho nthreads + | _ -> failwith "_owl_dstb: unsupported operation" diff --git a/src/owl/fftpack/owl_fftpack_float32.c b/src/owl/fftpack/owl_fftpack_float32.c deleted file mode 100644 index 4bbd337aa..000000000 --- a/src/owl/fftpack/owl_fftpack_float32.c +++ /dev/null @@ -1,41 +0,0 @@ -/* - * OWL - OCaml Scientific Computing - * Copyright (c) 2016-2022 Liang Wang - */ - -#include - -#include "owl_core.h" - - -#define Treal float - -#define REAL_COPY owl_float32_copy -#define COMPLEX_COPY owl_complex32_copy -#define FFTPACK_CFFTI float32_fftpack_cffti -#define FFTPACK_CFFTF float32_fftpack_cfftf -#define FFTPACK_CFFTB float32_fftpack_cfftb -#define FFTPACK_RFFTI float32_fftpack_rffti -#define FFTPACK_RFFTF float32_fftpack_rfftf -#define FFTPACK_RFFTB float32_fftpack_rfftb -#define STUB_CFFTF float32_cfftf -#define STUB_CFFTB float32_cfftb -#define STUB_RFFTF float32_rfftf -#define STUB_RFFTB float32_rfftb - -#include "owl_fftpack_impl.h" - -#undef REAL_COPY -#undef COMPLEX_COPY -#undef FFTPACK_CFFTI -#undef FFTPACK_CFFTF -#undef FFTPACK_CFFTB -#undef FFTPACK_RFFTI -#undef FFTPACK_RFFTF -#undef FFTPACK_RFFTB -#undef STUB_CFFTF -#undef STUB_CFFTB -#undef STUB_RFFTF -#undef STUB_RFFTB - -#undef Treal diff --git a/src/owl/fftpack/owl_fftpack_float32.cc b/src/owl/fftpack/owl_fftpack_float32.cc new file mode 100644 index 000000000..c50778385 --- /dev/null +++ b/src/owl/fftpack/owl_fftpack_float32.cc @@ -0,0 +1,39 @@ +/* + * OWL - OCaml Scientific Computing + * Copyright (c) 2016-2022 Liang Wang + */ + +#include + +#define Treal float + +extern "C" +{ +#include "owl_core.h" +} + +#define REAL_COPY owl_float32_copy +#define COMPLEX_COPY owl_complex32_copy +#define STUB_CFFT float32_cfft +#define STUB_CFFT_bytecode float32_cfft_bytecode +#define STUB_RFFTF float32_rfftf +#define STUB_RFFTB float32_rfftb +#define STUB_RDCT float32_dct +#define STUB_RDCT_bytecode float32_dct_bytecode +#define STUB_RDST float32_dst +#define STUB_RDST_bytecode float32_dst_bytecode + +#include "owl_fftpack_impl.h" + +#undef REAL_COPY +#undef COMPLEX_COPY +#undef STUB_CFFT +#undef STUB_CFFT_bytecode +#undef STUB_RFFTF +#undef STUB_RFFTB +#undef STUB_RDCT +#undef STUB_RDCT_bytecode +#undef STUB_RDST +#undef STUB_RDST_bytecode + +#undef Treal diff --git a/src/owl/fftpack/owl_fftpack_float64.c b/src/owl/fftpack/owl_fftpack_float64.c deleted file mode 100644 index 2a9fb7832..000000000 --- a/src/owl/fftpack/owl_fftpack_float64.c +++ /dev/null @@ -1,41 +0,0 @@ -/* - * OWL - OCaml Scientific Computing - * Copyright (c) 2016-2022 Liang Wang - */ - -#include - -#include "owl_core.h" - - -#define Treal double - -#define REAL_COPY owl_float64_copy -#define COMPLEX_COPY owl_complex64_copy -#define FFTPACK_CFFTI float64_fftpack_cffti -#define FFTPACK_CFFTF float64_fftpack_cfftf -#define FFTPACK_CFFTB float64_fftpack_cfftb -#define FFTPACK_RFFTI float64_fftpack_rffti -#define FFTPACK_RFFTF float64_fftpack_rfftf -#define FFTPACK_RFFTB float64_fftpack_rfftb -#define STUB_CFFTF float64_cfftf -#define STUB_CFFTB float64_cfftb -#define STUB_RFFTF float64_rfftf -#define STUB_RFFTB float64_rfftb - -#include "owl_fftpack_impl.h" - -#undef REAL_COPY -#undef COMPLEX_COPY -#undef FFTPACK_CFFTI -#undef FFTPACK_CFFTF -#undef FFTPACK_CFFTB -#undef FFTPACK_RFFTI -#undef FFTPACK_RFFTF -#undef FFTPACK_RFFTB -#undef STUB_CFFTF -#undef STUB_CFFTB -#undef STUB_RFFTF -#undef STUB_RFFTB - -#undef Treal diff --git a/src/owl/fftpack/owl_fftpack_float64.cc b/src/owl/fftpack/owl_fftpack_float64.cc new file mode 100644 index 000000000..b34050df0 --- /dev/null +++ b/src/owl/fftpack/owl_fftpack_float64.cc @@ -0,0 +1,38 @@ +/* + * OWL - OCaml Scientific Computing + * Copyright (c) 2016-2022 Liang Wang + */ + +#include +#define Treal double + +extern "C" +{ +#include "owl_core.h" +} + +#define REAL_COPY owl_float64_copy +#define COMPLEX_COPY owl_complex64_copy +#define STUB_CFFT float64_cfft +#define STUB_CFFT_bytecode float64_cfft_bytecode +#define STUB_RFFTF float64_rfftf +#define STUB_RFFTB float64_rfftb +#define STUB_RDCT float64_dct +#define STUB_RDCT_bytecode float64_dct_bytecode +#define STUB_RDST float64_dst +#define STUB_RDST_bytecode float64_dst_bytecode + +#include "owl_fftpack_impl.h" + +#undef REAL_COPY +#undef COMPLEX_COPY +#undef STUB_CFFT +#undef STUB_CFFT_bytecode +#undef STUB_RFFTF +#undef STUB_RFFTB +#undef STUB_RDCT +#undef STUB_RDCT_bytecode +#undef STUB_RDST +#undef STUB_RDST_bytecode + +#undef Treal diff --git a/src/owl/fftpack/owl_fftpack_impl.h b/src/owl/fftpack/owl_fftpack_impl.h index f62491c6e..f3a36ebc4 100644 --- a/src/owl/fftpack/owl_fftpack_impl.h +++ b/src/owl/fftpack/owl_fftpack_impl.h @@ -3,264 +3,382 @@ * Copyright (c) 2016-2022 Liang Wang */ - #ifdef Treal -#include "fftpack_impl.h" - - -/** Owl's interface function to FFTPACK **/ - - -void FFTPACK_CFFTI (int n, Treal wsave[]) { - if (n == 1) return; - int iw1 = 2 * n; - int iw2 = iw1 + 2 * n; - cffti1(n, wsave + iw1, (int*) (wsave + iw2)); -} - - -void FFTPACK_CFFTF (int n, Treal c[], Treal wsave[]) { - if (n == 1) return; - int iw1 = 2 * n; - int iw2 = iw1 + 2 * n; - cfftf1(n, c, wsave, wsave + iw1, (int*) (wsave + iw2), -1); -} - - -void FFTPACK_CFFTB (int n, Treal c[], Treal wsave[]) { - if (n == 1) return; - int iw1 = 2 * n; - int iw2 = iw1 + 2 * n; - cfftf1(n, c, wsave, wsave + iw1, (int*) (wsave + iw2), +1); -} - - -void FFTPACK_RFFTI (int n, Treal wsave[]) { - if (n == 1) return; - rffti1(n, wsave + n, (int*) (wsave + 2 * n)); -} - - -void FFTPACK_RFFTF (int n, Treal r[], Treal wsave[]) { - if (n == 1) return; - rfftf1(n, r, wsave, wsave + n, (int*) (wsave + 2 * n)); -} - - -void FFTPACK_RFFTB(int n, Treal r[], Treal wsave[]) { - if (n == 1) return; - rfftb1(n, r, wsave, wsave + n, (int*) (wsave + 2 * n)); -} - - -/** Helper functions **/ - - -// uppack from halfcomplex x to complex y -static OWL_INLINE void halfcomplex_unpack (int n, Treal* x, int ofsx, int incx, _Complex Treal* y, int ofsy, int incy) { - int i; - *(y + ofsy) = *(x + ofsx) + 0 * I; - - for (i = 1; i < n - i; i++) { - ofsx += incx + incx; - ofsy += incy; - Treal re = *(x + ofsx - incx); - Treal im = *(x + ofsx); - *(y + ofsy) = re + im * I; +#include "pocketfft_hdronly.h" + +/** Owl's interface function to pocketfft **/ +/** Adapted from scipy's pypocketfft.cxx **/ + +using namespace pocketfft::detail; + +template +T norm_fct(int inorm, size_t N) +{ + switch (inorm) + { + case 0: // "backward" - no normalization for forward transform + return T(1); + case 1: // "forward" - 1/n normalization for forward transform + return T(1) / T(N); + case 2: // "ortho" - 1/sqrt(n) normalization for both directions + return T(1) / std::sqrt(T(N)); + default: + caml_failwith("invalid value for inorm (must be 0, 1, or 2)"); + // This will never be reached + return T(0); } - - if (i == n - i) - *(y + ofsy + incy) = *(x + ofsx + incx) + 0 * I; } - -// pack from complex x to halfcomplex y -static OWL_INLINE void halfcomplex_pack (int n, _Complex Treal* x, int ofsx, int incx, Treal* y, int ofsy, int incy) { - int i; - *(y + ofsy) = creal(*(x + ofsx)); - - for (i = 1; i < n - i; i++) { - ofsx += incx; - ofsy += incy + incy; - *(y + ofsy - incy) = creal(*(x + ofsx)); - *(y + ofsy) = cimag(*(x + ofsx)); +template +T compute_norm_factor(const shape_t &dims, const shape_t &axes, int inorm, size_t fct = 1, int delta = 0) +{ + if (inorm == 0) + return T(1); + size_t N = 1; + for (auto a : axes) + { + N *= fct * size_t(int64_t(dims[a]) + delta); } - - if (i == n - i) - *(y + ofsy + incy) = creal(*(x + ofsx + incx)); + return norm_fct(inorm, N); } - -/** Owl's stub functions **/ - - -value STUB_CFFTF (value vX, value vY, value vD) { - struct caml_ba_array *X = Caml_ba_array_val(vX); - _Complex Treal *X_data = (_Complex Treal *) X->data; - - struct caml_ba_array *Y = Caml_ba_array_val(vY); - _Complex Treal *Y_data = (_Complex Treal *) Y->data; - - int d = Long_val(vD); - int n = X->dim[d]; - size_t ws_sz = 4 * n * sizeof(Treal); - size_t fc_sz = (MAXFAC + 2) * sizeof(int); - void* wsave = malloc(ws_sz + fc_sz); - void* data = malloc(2 * n * sizeof(Treal)); - - int stdx = c_ndarray_stride_dim(X, d); - int slcx = c_ndarray_slice_dim(X,d); - int stdy = c_ndarray_stride_dim(Y, d); - int slcy = c_ndarray_slice_dim(Y,d); - int m = c_ndarray_numel(X) / slcx; - - FFTPACK_CFFTI(n, wsave); - - int ofsx = 0; - int ofsy = 0; - - for (int i = 0; i < m; i ++) { - for (int j = 0; j < stdx; j++) { - COMPLEX_COPY(n, X_data, ofsx + j, stdx, data, 0, 1); - FFTPACK_CFFTF(n, (Treal*) data, wsave); - COMPLEX_COPY(n, data, 0, 1, Y_data, ofsy + j, stdy); +extern "C" +{ + + /** Owl's stub functions **/ + + /** + * Complex-to-complex FFT + * @param forward: true for forward transform, false for backward transform + * @param X: input array + * @param Y: output array + * @param d: dimension along which to perform the transform + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ + value STUB_CFFT(value vForward, value vX, value vY, value vD, value vNorm, value vNthreads) + { + struct caml_ba_array *X = Caml_ba_array_val(vX); + std::complex *X_data = reinterpret_cast *>(X->data); + + struct caml_ba_array *Y = Caml_ba_array_val(vY); + std::complex *Y_data = reinterpret_cast *>(Y->data); + + int d = Long_val(vD); + int n = X->dim[d]; + int norm = Long_val(vNorm); + int nthreads = Long_val(vNthreads); + int forward = Bool_val(vForward); + + shape_t dims; + stride_t stride_in, stride_out; + + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); } - ofsx += slcx; - ofsy += slcy; - } - - free(wsave); - free(data); - - return Val_unit; -} - - -value STUB_CFFTB (value vX, value vY, value vD) { - struct caml_ba_array *X = Caml_ba_array_val(vX); - _Complex Treal *X_data = (_Complex Treal *) X->data; - - struct caml_ba_array *Y = Caml_ba_array_val(vY); - _Complex Treal *Y_data = (_Complex Treal *) Y->data; - - int d = Long_val(vD); - int n = X->dim[d]; - size_t ws_sz = 4 * n * sizeof(Treal); - size_t fc_sz = (MAXFAC + 2) * sizeof(int); - void* wsave = malloc(ws_sz + fc_sz); - void* data = malloc(2 * n * sizeof(Treal)); - - int stdx = c_ndarray_stride_dim(X, d); - int slcx = c_ndarray_slice_dim(X,d); - int stdy = c_ndarray_stride_dim(Y, d); - int slcy = c_ndarray_slice_dim(Y,d); - int m = c_ndarray_numel(X) / slcx; - - FFTPACK_CFFTI(n, wsave); - int ofsx = 0; - int ofsy = 0; - - for (int i = 0; i < m; i ++) { - for (int j = 0; j < stdx; j++) { - COMPLEX_COPY(n, X_data, ofsx + j, stdx, data, 0, 1); - FFTPACK_CFFTB(n, (Treal*) data, wsave); - COMPLEX_COPY(n, data, 0, 1, Y_data, ofsy + j, stdy); + size_t multiplier = sizeof(std::complex); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); } - ofsx += slcx; - ofsy += slcy; - } - - free(wsave); - free(data); - return Val_unit; -} + shape_t axes{static_cast(d)}; + { + Treal norm_factor = compute_norm_factor(dims, axes, norm); + try + { + pocketfft::detail::c2c(dims, stride_in, stride_out, axes, forward, + X_data, Y_data, norm_factor, nthreads); + } + catch (const std::exception &e) + { + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? + } + } + return Val_unit; + } -value STUB_RFFTF (value vX, value vY, value vD) { - struct caml_ba_array *X = Caml_ba_array_val(vX); - Treal *X_data = (Treal *) X->data; + /** + * Complex-to-complex FFT + * @param argv: array of arguments + * @param argn: number of arguments + * @see STUB_CFFT, https://ocaml.org/manual/5.2/intfc.html#ss:c-prim-impl + */ + value STUB_CFFT_bytecode(value *argv, int argn) + { + return STUB_CFFT(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5]); + } - struct caml_ba_array *Y = Caml_ba_array_val(vY); - _Complex Treal *Y_data = (_Complex Treal *) Y->data; + /** + * Forward Real-to-complex FFT + * @param X: input array (real data) + * @param Y: output array (complex data) + * @param d: dimension along which to perform the transform + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ + value STUB_RFFTF(value vX, value vY, value vD, value vNorm, value vNthreads) + { + struct caml_ba_array *X = Caml_ba_array_val(vX); + Treal *X_data = reinterpret_cast(X->data); + + struct caml_ba_array *Y = Caml_ba_array_val(vY); + std::complex *Y_data = reinterpret_cast *>(Y->data); + + int d = Long_val(vD); + int n = X->dim[d]; + int norm = Long_val(vNorm); + int nthreads = Long_val(vNthreads); + + shape_t dims; + stride_t stride_in, stride_out; + + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); + } - int d = Long_val(vD); - int n = X->dim[d]; - size_t ws_sz = 2 * n * sizeof(Treal); - size_t fc_sz = (MAXFAC + 2) * sizeof(int); - void* wsave = malloc(ws_sz + fc_sz); - void* data = malloc(n * sizeof(Treal)); + size_t multiplier = sizeof(Treal); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + } - int stdx = c_ndarray_stride_dim(X, d); - int slcx = c_ndarray_slice_dim(X,d); - int stdy = c_ndarray_stride_dim(Y, d); - int slcy = c_ndarray_slice_dim(Y,d); - int m = c_ndarray_numel(X) / slcx; + multiplier = sizeof(std::complex); + for (int i = 0; i < Y->num_dims; ++i) + { + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } - FFTPACK_RFFTI(n, wsave); + shape_t axes{static_cast(d)}; + { + Treal norm_factor = compute_norm_factor(dims, axes, norm); + try + { + pocketfft::r2c(dims, stride_in, stride_out, axes, pocketfft::FORWARD, + X_data, Y_data, norm_factor, nthreads); + } + catch (const std::exception &e) + { + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? + } + } - int ofsx = 0; - int ofsy = 0; + return Val_unit; + } - for (int i = 0; i < m; i ++) { - for (int j = 0; j < stdx; j++) { - REAL_COPY(n, X_data, ofsx + j, stdx, data, 0, 1); - FFTPACK_RFFTF(n, (Treal*) data, wsave); - halfcomplex_unpack(n, data, 0, 1, Y_data, ofsy + j, stdy); + /** + * Backward Real-to-complex FFT + * @param X: input array (complex data) + * @param Y: output array (real data) + * @param d: dimension along which to perform the transform + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ + value STUB_RFFTB(value vX, value vY, value vD, value vNorm, value vNthreads) + { + struct caml_ba_array *X = Caml_ba_array_val(vX); + std::complex *X_data = reinterpret_cast *>(X->data); + + struct caml_ba_array *Y = Caml_ba_array_val(vY); + Treal *Y_data = reinterpret_cast(Y->data); + + int d = Long_val(vD); + int n = X->dim[d]; + int norm = Long_val(vNorm); + int nthreads = Long_val(vNthreads); + + if (Y->dim[d] != (X->dim[d] - 1) * 2) + caml_failwith("Invalid output dimension for inverse real FFT"); + + shape_t dims; + stride_t stride_in, stride_out; + + int ncomplex = X->dim[d]; + int nreal = Y->dim[d]; + + for (int i = 0; i < X->num_dims; ++i) + { + if (i == d) + { + dims.push_back(static_cast(nreal)); + } + else + { + dims.push_back(static_cast(X->dim[i])); + } } - ofsx += slcx; - ofsy += slcy; - } - free(wsave); - free(data); + size_t multiplier = sizeof(std::complex); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + } - return Val_unit; -} + multiplier = sizeof(Treal); + for (int i = 0; i < Y->num_dims; ++i) + { + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } + shape_t axes{static_cast(d)}; + { + Treal norm_factor = compute_norm_factor(dims, axes, norm); + try + { + pocketfft::c2r(dims, stride_in, stride_out, axes, pocketfft::BACKWARD, + X_data, Y_data, norm_factor, nthreads); + } + catch (const std::exception &e) + { + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? + } + } -value STUB_RFFTB (value vX, value vY, value vD) { - struct caml_ba_array *X = Caml_ba_array_val(vX); - _Complex Treal *X_data = (_Complex Treal *) X->data; + return Val_unit; + } - struct caml_ba_array *Y = Caml_ba_array_val(vY); - Treal *Y_data = (Treal *) Y->data; + /** + * Discrete Cosine Transform + * @param X: input array + * @param Y: output array + * @param d: dimension along which to perform the transform + * @param type: type of DCT (1, 2, 3, or 4) + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ + value STUB_RDCT(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads) + { + struct caml_ba_array *X = Caml_ba_array_val(vX); + Treal *X_data = reinterpret_cast(X->data); + + struct caml_ba_array *Y = Caml_ba_array_val(vY); + Treal *Y_data = reinterpret_cast(Y->data); + + int d = Long_val(vD); + int n = X->dim[d]; + int type = Long_val(vType); + if (type < 1 || type > 4) // should not happen as it's checked on the OCaml side + caml_failwith("invalid value for type (must be 1, 2, 3, or 4)"); + int norm = Long_val(vNorm); + bool ortho = Bool_val(vOrtho); + int nthreads = Long_val(vNthreads); + + shape_t dims; + stride_t stride_in, stride_out; + + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); + } - int d = Long_val(vD); - int n = Y->dim[d]; - size_t ws_sz = 2 * n * sizeof(Treal); - size_t fc_sz = (MAXFAC + 2) * sizeof(int); - void* wsave = malloc(ws_sz + fc_sz); - void* data = malloc(n * sizeof(_Complex Treal)); + size_t multiplier = sizeof(Treal); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } - int stdx = c_ndarray_stride_dim(X, d); - int slcx = c_ndarray_slice_dim(X,d); - int stdy = c_ndarray_stride_dim(Y, d); - int slcy = c_ndarray_slice_dim(Y,d); - int m = c_ndarray_numel(X) / slcx; + shape_t axes{static_cast(d)}; + { + Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, 1) + : compute_norm_factor(dims, axes, norm, 2); + try + { + pocketfft::detail::dct(dims, stride_in, stride_out, axes, type, + X_data, Y_data, norm_factor, ortho, nthreads); + } + catch (const std::exception &e) + { + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? + } + } - FFTPACK_RFFTI(n, wsave); + return Val_unit; + } - int ofsx = 0; - int ofsy = 0; + value STUB_RDCT_bytecode(value *argv, int argn) + { + return STUB_RDCT(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]); + } - for (int i = 0; i < m; i ++) { - for (int j = 0; j < stdx; j++) { - halfcomplex_pack(n, X_data, ofsx + j, stdx, data, 0, 1); - FFTPACK_RFFTB(n, (Treal*) data, wsave); - REAL_COPY(n, (Treal*) data, 0, 1, Y_data, ofsy + j, stdy); + /** + * Discrete Sine Transform + * @param X: input array + * @param Y: output array + * @param d: dimension along which to perform the transform + * @param type: type of DST (1, 2, 3, or 4) + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ + value STUB_RDST(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads) + { + struct caml_ba_array *X = Caml_ba_array_val(vX); + Treal *X_data = reinterpret_cast(X->data); + + struct caml_ba_array *Y = Caml_ba_array_val(vY); + Treal *Y_data = reinterpret_cast(Y->data); + + int d = Long_val(vD); + int n = X->dim[d]; + int type = Long_val(vType); + if (type < 1 || type > 4) // should not happen as it's checked on the OCaml side + caml_failwith("invalid value for type (must be 1, 2, 3, or 4)"); + int norm = Long_val(vNorm); + bool ortho = Bool_val(vOrtho); + int nthreads = Long_val(vNthreads); + + shape_t dims; + stride_t stride_in, stride_out; + + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); } - ofsx += slcx; - ofsy += slcy; - } - free(wsave); - free(data); + size_t multiplier = sizeof(Treal); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } - return Val_unit; -} + shape_t axes{static_cast(d)}; + { + Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, 1) + : compute_norm_factor(dims, axes, norm, 2); + try + { + pocketfft::detail::dst(dims, stride_in, stride_out, axes, type, + X_data, Y_data, norm_factor, ortho, nthreads); + } + catch (const std::exception &e) + { + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? + } + } + return Val_unit; + } -#endif //Treal + value STUB_RDST_bytecode(value *argv, int argn) + { + return STUB_RDST(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]); + } +} // extern "C" +#endif // Treal diff --git a/src/owl/fftpack/pocketfft b/src/owl/fftpack/pocketfft new file mode 160000 index 000000000..bb87ca50d --- /dev/null +++ b/src/owl/fftpack/pocketfft @@ -0,0 +1 @@ +Subproject commit bb87ca50df0478415a12d9011dc374eeed4e9d93 diff --git a/src/owl/nlp/owl_nlp_corpus.ml b/src/owl/nlp/owl_nlp_corpus.ml index 89c42b161..84a7199c0 100644 --- a/src/owl/nlp/owl_nlp_corpus.ml +++ b/src/owl/nlp/owl_nlp_corpus.ml @@ -21,7 +21,7 @@ type t = mutable minlen : int ; (* minimum length of document to save *) mutable docid : int array (* document id, can refer to original data *) - } + } [@@warning "-69"] let _close_if_open = function | Some h -> close_in h diff --git a/src/owl/nlp/owl_nlp_lda.ml b/src/owl/nlp/owl_nlp_lda.ml index 90894f2c0..05cb5dce9 100644 --- a/src/owl/nlp/owl_nlp_lda.ml +++ b/src/owl/nlp/owl_nlp_lda.ml @@ -43,7 +43,7 @@ type model = mutable data : Owl_nlp_corpus.t ; (* training data, tokenised*) mutable vocb : (string, int) Hashtbl.t (* vocabulary, or dictionary if you prefer *) - } + } [@@warning "-69"] let include_token m w d k = m.t__k.(k) <- m.t__k.(k) +. 1.; diff --git a/src/owl/nlp/owl_nlp_tfidf.ml b/src/owl/nlp/owl_nlp_tfidf.ml index d083bffac..6142c285a 100644 --- a/src/owl/nlp/owl_nlp_tfidf.ml +++ b/src/owl/nlp/owl_nlp_tfidf.ml @@ -30,7 +30,7 @@ type t = mutable corpus : Owl_nlp_corpus.t ; (* corpus type *) mutable handle : in_channel option (* file descriptor of the tfidf *) - } + } [@@warning "-69"] (* various types of TF and IDF functions *) diff --git a/src/owl/nlp/owl_nlp_vocabulary.ml b/src/owl/nlp/owl_nlp_vocabulary.ml index 913267dff..195f7c96d 100644 --- a/src/owl/nlp/owl_nlp_vocabulary.ml +++ b/src/owl/nlp/owl_nlp_vocabulary.ml @@ -11,7 +11,7 @@ type t = mutable i2w : (int, string) Hashtbl.t ; (* index -> word *) mutable i2f : (int, int) Hashtbl.t (* index -> freq *) - } + } [@@warning "-69"] let get_w2i d = d.w2i From 6d324e9454f23dddf9c9538d9752bfbcf1b09af7 Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Mon, 11 Nov 2024 14:21:19 +0100 Subject: [PATCH 02/12] Adding abstract types and documentation. - Added the ttrig_transform type to specify the DCT and DST types - Added the tnorm type to specify the normalization option for the FFTs. --- src/owl/fftpack/owl_fft_d.mli | 25 ++++----- src/owl/fftpack/owl_fft_generic.ml | 66 +++++++++++++++--------- src/owl/fftpack/owl_fft_generic.mli | 78 ++++++++++++++++++++--------- src/owl/fftpack/owl_fft_s.mli | 25 ++++----- 4 files changed, 122 insertions(+), 72 deletions(-) diff --git a/src/owl/fftpack/owl_fft_d.mli b/src/owl/fftpack/owl_fft_d.mli index 093791d42..ea80f18bb 100644 --- a/src/owl/fftpack/owl_fft_d.mli +++ b/src/owl/fftpack/owl_fft_d.mli @@ -5,24 +5,25 @@ open Bigarray open Owl_dense_ndarray_generic +open Owl_fft_generic val fft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t val ifft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t -> (Complex.t, complex64_elt) Owl_dense_ndarray_generic.t val rfft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (float, float64_elt) t -> (Complex.t, complex64_elt) t @@ -30,7 +31,7 @@ val rfft val irfft : ?axis:int -> ?n:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex64_elt) t -> (float, float64_elt) t @@ -41,8 +42,8 @@ val ifft2 : (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t val dct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float64_elt) t @@ -50,8 +51,8 @@ val dct val idct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float64_elt) t @@ -59,8 +60,8 @@ val idct val dst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float64_elt) t @@ -68,8 +69,8 @@ val dst val idst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float64_elt) t diff --git a/src/owl/fftpack/owl_fft_generic.ml b/src/owl/fftpack/owl_fft_generic.ml index 4d24c219c..377e7a0c7 100644 --- a/src/owl/fftpack/owl_fft_generic.ml +++ b/src/owl/fftpack/owl_fft_generic.ml @@ -5,7 +5,17 @@ open Owl_dense_ndarray_generic -let fft ?axis ?(norm : int = 0) ?(nthreads : int = 1) x = +type tnorm = + | Backward + | Forward + | Ortho + +let tnorm_to_int = function + | Backward -> 0 + | Forward -> 1 + | Ortho -> 2 + +let fft ?axis ?(norm : tnorm = Backward) ?(nthreads : int = 1) x = let axis = match axis with | Some a -> a @@ -14,11 +24,11 @@ let fft ?axis ?(norm : int = 0) ?(nthreads : int = 1) x = let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_cfftf (kind x) x y axis norm nthreads; + Owl_fftpack._owl_cfftf (kind x) x y axis (tnorm_to_int norm) nthreads; y -let ifft ?axis ?(norm : int = 1) ?(nthreads : int = 1) x = +let ifft ?axis ?(norm : tnorm = Forward) ?(nthreads : int = 1) x = let axis = match axis with | Some a -> a @@ -27,11 +37,11 @@ let ifft ?axis ?(norm : int = 1) ?(nthreads : int = 1) x = let axis = if axis < 0 then num_dims x + axis else axis in assert (axis < num_dims x); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_cfftb (kind x) x y axis norm nthreads; + Owl_fftpack._owl_cfftb (kind x) x y axis (tnorm_to_int norm) nthreads; y -let rfft ?axis ?(norm : int = 0) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = +let rfft ?axis ?(norm : tnorm = Backward) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = let axis = match axis with | Some a -> a @@ -43,11 +53,11 @@ let rfft ?axis ?(norm : int = 0) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x s.(axis) <- (s.(axis) / 2) + 1; let y = empty otyp s in let ityp = kind x in - Owl_fftpack._owl_rfftf ityp otyp x y axis norm nthreads; + Owl_fftpack._owl_rfftf ityp otyp x y axis (tnorm_to_int norm) nthreads; y -let irfft ?axis ?n ?(norm : int = 1) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = +let irfft ?axis ?n ?(norm : tnorm = Forward) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kind) x = let axis = match axis with | Some a -> a @@ -63,7 +73,7 @@ let irfft ?axis ?n ?(norm : int = 1) ?(nthreads : int = 1) ~(otyp : ('a, 'b) kin in let y = empty otyp s in let ityp = kind x in - Owl_fftpack._owl_rfftb ityp otyp x y axis norm nthreads; + Owl_fftpack._owl_rfftb ityp otyp x y axis (tnorm_to_int norm) nthreads; y @@ -71,7 +81,19 @@ let fft2 x = fft ~axis:0 x |> fft ~axis:1 let ifft2 x = ifft ~axis:0 x |> ifft ~axis:1 -let dct ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = 1) x = +type ttrig_transform = + | I + | II + | III + | IV + +let ttrig_transform_to_int = function + | I -> 1 + | II -> 2 + | III -> 3 + | IV -> 4 + +let dct ?axis ?(ttype: ttrig_transform = II) ?(norm : tnorm = Backward) ?(ortho : bool option) ?(nthreads = 1) x = let axis = match axis with | Some a -> a @@ -82,15 +104,14 @@ let dct ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = let ortho = match ortho with | Some o -> o - | None -> if norm = 2 then true else false + | None -> if norm = Ortho then true else false in - assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_dctf (kind x) x y axis ttype norm ortho nthreads; + Owl_fftpack._owl_dctf (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads; y -let idct ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads = 1) x = +let idct ?axis ?(ttype: ttrig_transform = II) ?(norm : tnorm = Forward) ?(ortho : bool option) ?(nthreads = 1) x = let axis = match axis with | Some a -> a @@ -101,15 +122,14 @@ let idct ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads let ortho = match ortho with | Some o -> o - | None -> if norm = 2 then true else false + | None -> if norm = Ortho then true else false in - assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_dctb (kind x) x y axis ttype norm ortho nthreads; + Owl_fftpack._owl_dctb (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads; y -let dst ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = 1) x = +let dst ?axis ?(ttype: ttrig_transform = II) ?(norm : tnorm = Backward) ?(ortho : bool option) ?(nthreads = 1) x = let axis = match axis with | Some a -> a @@ -120,15 +140,14 @@ let dst ?axis ?(ttype = 2) ?(norm : int = 0) ?(ortho : bool option) ?(nthreads = let ortho = match ortho with | Some o -> o - | None -> if norm = 2 then true else false + | None -> if norm = Ortho then true else false in - assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_dstf (kind x) x y axis ttype norm ortho nthreads; + Owl_fftpack._owl_dstf (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads; y -let idst ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads = 1) x = +let idst ?axis ?(ttype = III) ?(norm : tnorm = Forward) ?(ortho : bool option) ?(nthreads = 1) x = let axis = match axis with | Some a -> a @@ -139,9 +158,8 @@ let idst ?axis ?(ttype = 3) ?(norm : int = 1) ?(ortho : bool option) ?(nthreads let ortho = match ortho with | Some o -> o - | None -> if norm = 2 then true else false + | None -> if norm = Ortho then true else false in - assert (ttype > 0 || ttype < 5); let y = empty (kind x) (shape x) in - Owl_fftpack._owl_dstb (kind x) x y axis ttype norm ortho nthreads; + Owl_fftpack._owl_dstb (kind x) x y axis (ttrig_transform_to_int ttype) (tnorm_to_int norm) ortho nthreads; y diff --git a/src/owl/fftpack/owl_fft_generic.mli b/src/owl/fftpack/owl_fft_generic.mli index ef0d63032..d66528bc9 100644 --- a/src/owl/fftpack/owl_fft_generic.mli +++ b/src/owl/fftpack/owl_fft_generic.mli @@ -7,48 +7,59 @@ open Owl_dense_ndarray_generic +(** Normalisation options for transforms. *) +type tnorm = + | Backward (** No normalization on Forward and scaling by 1/N on Backward *) + | Forward (** Normalization by 1/N on the Forward transform. *) + | Ortho (** Forward and Backward are scaled by 1/sqrt(N) *) + (** {5 Discrete Fourier Transforms functions} *) val fft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, 'a) t -> (Complex.t, 'a) t -(** [fft ~axis x] performs 1-dimensional FFT on a complex input. [axis] is the - highest dimension if not specified. The return is not scaled. *) +(** [fft ~axis ~norm x] performs 1-dimensional FFT on a complex input. [axis] is the + highest dimension if not specified. [norm] is the normalization option. By default, [norm] is set to [Backward]. + [nthreads] is the desired number of threads used to compute the fft. *) val ifft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, 'a) t -> (Complex.t, 'a) t (** [ifft ~axis x] performs inverse 1-dimensional FFT on a complex input. The parameter [axis] - indicates the highest dimension by default. *) + indicates the highest dimension by default. [norm] is the normalization option. By default, [norm] is set to [Forward]. + [nthreads] is the desired number of threads used to compute the fft. *) val rfft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> otyp:('a, 'b) kind -> ('c, 'd) t -> ('a, 'b) t (** [rfft ~axis ~otyp x] performs 1-dimensional FFT on real input along the - [axis]. [otyp] is used to specify the output type, it must be the consistent - precision with input [x]. You can skip this parameter by using a submodule - with specific precision such as [Owl.Fft.S] or [Owl.Fft.D]. *) + [axis]. [norm] is the normalization option. By default, [norm] is set to [Backward]. + [nthreads] is the desired number of threads used to compute the fft. + [otyp] is used to specify the output type, it must be the consistent precision with input [x]. + You can skip this parameter by using a submodule with specific precision such as [Owl.Fft.S] or [Owl.Fft.D]. *) val irfft : ?axis:int -> ?n:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> otyp:('a, 'b) kind -> ('c, 'd) t -> ('a, 'b) t -(** [irfft ~axis ~n x] is the inverse function of [rfft]. Note the [n] parameter - is used to specified the size of output. *) +(** [irfft ~axis ~n x] is the inverse function of [rfft]. [norm] is the normalization option. + By default, [norm] is set to [Forward]. + [nthreads] is the desired number of threads used to compute the fft. + Note the [n] parameter is used to specified the size of output. *) val fft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t (** [fft2 x] performs 2-dimensional FFT on a complex input. The return is not scaled. *) @@ -58,42 +69,61 @@ val ifft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t (** {5 Discrete Cosine & Sine Transforms functions} *) +type ttrig_transform = + | I + | II + | III + | IV +(** Trigonometric (Cosine and Sine) transform types. *) + val dct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> ('a, 'b) t -> ('a, 'b) t -(** [dct ~axis ~type x] performs 1-dimensional Discrete Cosine Transform (DCT) on a real input. Default type is 2. *) +(** [dct ?axis ?ttype ?norm ?ortho ?nthreads x] performs 1-dimensional Discrete Cosine Transform (DCT) on a real input. + [ttype] is the DCT type to use for this transform. Default value is [II]. + [norm] is the normalization option. By default, [norm] is set to [Backward]. + [ortho] constrols whether or not we should use the orthogonalized variant of the DCT. *) val idct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> ('a, 'b) t -> ('a, 'b) t -(** [idct ~axis ~type x] performs inverse 1-dimensional Discrete Cosine Transform (DCT) on a real input. Default type is 2. *) +(** [idct ?axis ?ttype ?norm ?ortho ?nthreads x] performs inverse 1-dimensional Discrete Cosine Transform (DCT) on a real input. + [ttype] is the DCT type to use for this transform. Default value is [II]. + [norm] is the normalization option. By default, [norm] is set to [Forward]. + [ortho] constrols whether or not we should use the orthogonalized variant of the DCT. *) val dst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> ('a, 'b) t -> ('a, 'b) t -(** [dst ~axis ~type x] performs 1-dimensional Discrete Sine Transform (DST) on a real input. Default type is 2. *) +(** [dst ?axis ?ttype ?norm ?ortho ?nthreads x] performs 1-dimensional Discrete Sine Transform (DCT) on a real input. + [ttype] is the DCT type to use for this transform. Default value is [II]. + [norm] is the normalization option. By default, [norm] is set to [Backward]. + [ortho] constrols whether or not we should use the orthogonalized variant of the DST. *) val idst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> ('a, 'b) t -> ('a, 'b) t -(** [idst ~axis ~type x] performs inverse 1-dimensional Discrete Sine Transform (DST) on a real input. Default type is 2. *) +(** [idst ?axis ?ttype ?norm ?ortho ?nthreads x] performs inverse 1-dimensional Discrete Sine Transform (DST) on a real input. + [ttype] is the DST type to use for this transform. Default value is [II]. + [norm] is the normalization option. By default, [norm] is set to [Forward]. + [ortho] constrols whether or not we should use the orthogonalized variant of the DST. *) \ No newline at end of file diff --git a/src/owl/fftpack/owl_fft_s.mli b/src/owl/fftpack/owl_fft_s.mli index aa4b0a643..3914d685c 100644 --- a/src/owl/fftpack/owl_fft_s.mli +++ b/src/owl/fftpack/owl_fft_s.mli @@ -5,24 +5,25 @@ open Bigarray open Owl_dense_ndarray_generic +open Owl_fft_generic val fft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t val ifft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t -> (Complex.t, complex32_elt) Owl_dense_ndarray_generic.t val rfft : ?axis:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (float, float32_elt) t -> (Complex.t, complex32_elt) t @@ -30,7 +31,7 @@ val rfft val irfft : ?axis:int -> ?n:int - -> ?norm:int + -> ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex32_elt) t -> (float, float32_elt) t @@ -41,8 +42,8 @@ val ifft2 : (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t val dct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float32_elt) t @@ -50,8 +51,8 @@ val dct val idct : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float32_elt) t @@ -59,8 +60,8 @@ val idct val dst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float32_elt) t @@ -68,8 +69,8 @@ val dst val idst : ?axis:int - -> ?ttype:int - -> ?norm:int + -> ?ttype:ttrig_transform + -> ?norm:tnorm -> ?ortho:bool -> ?nthreads:int -> (float, float32_elt) t From aea4703ef88ca847c29c07e0e24d885c82ad6e92 Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Mon, 11 Nov 2024 20:52:00 +0100 Subject: [PATCH 03/12] Adding 2 FFT example usage. - One "complex" usage: computing a PSD spectrogram from input data - One "simple" usage: computing the maximum frequency peak in time series data using rfft. --- examples/dune | 2 ++ examples/max_freq.ml | 21 +++++++++++++++ examples/specgram.ml | 62 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+) create mode 100644 examples/max_freq.ml create mode 100644 examples/specgram.ml diff --git a/examples/dune b/examples/dune index 14e0e23be..fabf06167 100644 --- a/examples/dune +++ b/examples/dune @@ -19,9 +19,11 @@ lazy_mnist linear_algebra lstm + max_freq mnist_cnn mnist_lenet newton_method + specgram squeezenet test_log tfidf diff --git a/examples/max_freq.ml b/examples/max_freq.ml new file mode 100644 index 000000000..a5a25e373 --- /dev/null +++ b/examples/max_freq.ml @@ -0,0 +1,21 @@ +(** This example shows how to compute the maximum peak frequency in time series data using the FFT module. *) + +module G = Owl.Dense.Ndarray.Generic +module FFT = Owl.Fft.Generic + +let max_freq signal sampling_rate = + (* Apply FFT *) + let fft_result = FFT.rfft ~otyp:Bigarray.Complex32 signal in + (* Get magnitude spectrum *) + let magnitudes = G.abs fft_result in + (* Find peak frequency *) + let max_idx = ref 0 in + let max_val = ref (G.get magnitudes [|0|]) in + for i = 0 to G.numel magnitudes - 1 do + let curr_val = G.get magnitudes [|i|] in + if curr_val > !max_val then ( + max_val := curr_val ; + max_idx := i ) + done ; + (* Convert index to frequency *) + float_of_int !max_idx *. sampling_rate /. float_of_int (G.numel signal) \ No newline at end of file diff --git a/examples/specgram.ml b/examples/specgram.ml new file mode 100644 index 000000000..2ef04a634 --- /dev/null +++ b/examples/specgram.ml @@ -0,0 +1,62 @@ +(** This example, extracted from the SoundML library, shows how to compute a specgram using the FFT module *) + +module G = Owl.Dense.Ndarray.Generic + +(* helper to compute the fft frequencies *) +let fftfreq (n : int) (d : float) = + let nslice = ((n - 1) / 2) + 1 in + let fhalf = + G.linspace Bigarray.Float32 0. (float_of_int nslice) nslice + in + let shalf = + G.linspace Bigarray.Float32 (-.float_of_int nslice) (-1.) nslice + in + let v = G.concatenate ~axis:0 [|fhalf; shalf|] in + Owl.Arr.(1. /. (d *. float_of_int n) $* v) + +(* Computes a one-sided PSD spectrogram with no padding and no detrend *) +let specgram (nfft : int) (fs : int) ?(noverlap : int = 0) (x : (float, Bigarray.float32_elt) G.t) = + let window = Owl.Signal.hann in (* we're using hann window *) + assert (noverlap < nfft) ; + (* we're making copies of the data from x and y to then use in place padding + and operations *) + let x = G.copy x in + (* We're making sure the arrays are at least of size nfft *) + let xshp = G.shape x in + ( if Array.get xshp 0 < nfft then + let delta = nfft - Array.get xshp 0 in + (* we're doing this in place in hope to gain a little bit of speed *) + G.pad_ ~out:x ~v:0. [[0; delta - 1]; [0; 0]] x ) ; + let scale_by_freq = true in + let pad_to = nfft in + let scaling_factor = 2. in + let window = window nfft |> G.cast_d2s in + let window = + G.reshape G.(window * ones Bigarray.float32 [|nfft|]) [|-1; 1|] + in + let res = + G.slide ~window:nfft ~step:(nfft - noverlap) x |> G.transpose + in + (* if we'd add a detrend, we'd need to do it before applying the window *) + let res = G.(res * window) in + (* here comes the rfft compute, if you're processing large audio data, you might want to set ~nthreads to + something that suits both your hardware and your needs. *) + let res = Owl.Fft.S.rfft res ~axis:0 in + let freqs = fftfreq pad_to (1. /. float_of_int fs) in + let conj = G.conj res in + (* using in-place operations to avoid array copy *) + G.mul_ ~out:res conj res; + let slice = if nfft mod 2 = 0 then [[1; -1]; []] else [[1]; []] in + let gslice = G.get_slice slice res in + G.mul_scalar_ ~out:gslice gslice Complex.{re= scaling_factor; im= 0.} ; + G.set_slice slice res gslice ; + if scale_by_freq then ( + let window = G.abs window in + G.div_scalar_ ~out:res res Complex.{re= float_of_int fs; im= 0.} ; + let n = G.sum' (G.pow_scalar window (float_of_int 2)) in + G.div_scalar_ ~out:res res Complex.{re= n; im= 0.} ) + else ( + let window = G.abs window in + let n = Float.pow (G.sum' window) 2. in + G.div_scalar_ ~out:res res Complex.{re= n; im= 0.} ) ; + (res, freqs) \ No newline at end of file From 0dbafb43b66bdae27e6ba42018de608caf854db5 Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Tue, 12 Nov 2024 02:41:51 +0100 Subject: [PATCH 04/12] Adding tests, fixing minor issues. - Fixed an issue where the DCT/DST parameters weren't passed in the right order - Fixed an issue where the norm factor of the DCT wasn't computed correctly (incorrect delta) - Generated a unit_fft file that tests various parameters of FFT. Values are generated using scipy.fft module for the FFT, DCT and DST functions - Changed pocketfft::detail:: namespace usage to just pocketfft:: in the C++ code for more readability --- src/owl/fftpack/owl_fft_generic.mli | 8 +- src/owl/fftpack/owl_fftpack.ml | 12 +- src/owl/fftpack/owl_fftpack_impl.h | 14 +- test/test_runner.ml | 1 + test/unit_fft.ml | 1785 +++++++++++++++++++++++++++ 5 files changed, 1803 insertions(+), 17 deletions(-) create mode 100644 test/unit_fft.ml diff --git a/src/owl/fftpack/owl_fft_generic.mli b/src/owl/fftpack/owl_fft_generic.mli index d66528bc9..73c7a4276 100644 --- a/src/owl/fftpack/owl_fft_generic.mli +++ b/src/owl/fftpack/owl_fft_generic.mli @@ -85,7 +85,7 @@ val dct -> ('a, 'b) t -> ('a, 'b) t (** [dct ?axis ?ttype ?norm ?ortho ?nthreads x] performs 1-dimensional Discrete Cosine Transform (DCT) on a real input. - [ttype] is the DCT type to use for this transform. Default value is [II]. + [ttype] is the DCT type to use for this transform. Default value is [Two]. [norm] is the normalization option. By default, [norm] is set to [Backward]. [ortho] constrols whether or not we should use the orthogonalized variant of the DCT. *) @@ -98,7 +98,7 @@ val idct -> ('a, 'b) t -> ('a, 'b) t (** [idct ?axis ?ttype ?norm ?ortho ?nthreads x] performs inverse 1-dimensional Discrete Cosine Transform (DCT) on a real input. - [ttype] is the DCT type to use for this transform. Default value is [II]. + [ttype] is the DCT type to use for this transform. Default value is [Two]. [norm] is the normalization option. By default, [norm] is set to [Forward]. [ortho] constrols whether or not we should use the orthogonalized variant of the DCT. *) @@ -111,7 +111,7 @@ val dst -> ('a, 'b) t -> ('a, 'b) t (** [dst ?axis ?ttype ?norm ?ortho ?nthreads x] performs 1-dimensional Discrete Sine Transform (DCT) on a real input. - [ttype] is the DCT type to use for this transform. Default value is [II]. + [ttype] is the DCT type to use for this transform. Default value is [Two]. [norm] is the normalization option. By default, [norm] is set to [Backward]. [ortho] constrols whether or not we should use the orthogonalized variant of the DST. *) @@ -124,6 +124,6 @@ val idst -> ('a, 'b) t -> ('a, 'b) t (** [idst ?axis ?ttype ?norm ?ortho ?nthreads x] performs inverse 1-dimensional Discrete Sine Transform (DST) on a real input. - [ttype] is the DST type to use for this transform. Default value is [II]. + [ttype] is the DST type to use for this transform. Default value is [Two]. [norm] is the normalization option. By default, [norm] is set to [Forward]. [ortho] constrols whether or not we should use the orthogonalized variant of the DST. *) \ No newline at end of file diff --git a/src/owl/fftpack/owl_fftpack.ml b/src/owl/fftpack/owl_fftpack.ml index 4e54c22ee..af2f8373b 100644 --- a/src/owl/fftpack/owl_fftpack.ml +++ b/src/owl/fftpack/owl_fftpack.ml @@ -181,10 +181,10 @@ let _owl_dctb -> int -> unit = - fun ityp x y ttype axis norm ortho nthreads -> + fun ityp x y axis ttype norm ortho nthreads -> match ityp with - | Float32 -> owl_float32_dct x y (inverse_map ttype) axis norm ortho nthreads - | Float64 -> owl_float64_dct x y (inverse_map ttype) axis norm ortho nthreads + | Float32 -> owl_float32_dct x y axis (inverse_map ttype) norm ortho nthreads + | Float64 -> owl_float64_dct x y axis (inverse_map ttype) norm ortho nthreads | _ -> failwith "_owl_dctb: unsupported operation" @@ -241,8 +241,8 @@ let _owl_dstb -> int -> unit = - fun ityp x y ttype axis norm ortho nthreads -> + fun ityp x y axis ttype norm ortho nthreads -> match ityp with - | Float32 -> owl_float32_dst x y (inverse_map ttype) axis norm ortho nthreads - | Float64 -> owl_float64_dst x y (inverse_map ttype) axis norm ortho nthreads + | Float32 -> owl_float32_dst x y axis (inverse_map ttype) norm ortho nthreads + | Float64 -> owl_float64_dst x y axis (inverse_map ttype) norm ortho nthreads | _ -> failwith "_owl_dstb: unsupported operation" diff --git a/src/owl/fftpack/owl_fftpack_impl.h b/src/owl/fftpack/owl_fftpack_impl.h index f3a36ebc4..4f6018e35 100644 --- a/src/owl/fftpack/owl_fftpack_impl.h +++ b/src/owl/fftpack/owl_fftpack_impl.h @@ -93,8 +93,8 @@ extern "C" Treal norm_factor = compute_norm_factor(dims, axes, norm); try { - pocketfft::detail::c2c(dims, stride_in, stride_out, axes, forward, - X_data, Y_data, norm_factor, nthreads); + pocketfft::c2c(dims, stride_in, stride_out, axes, forward, + X_data, Y_data, norm_factor, nthreads); } catch (const std::exception &e) { @@ -294,12 +294,12 @@ extern "C" shape_t axes{static_cast(d)}; { - Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, 1) + Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, -1) : compute_norm_factor(dims, axes, norm, 2); try { - pocketfft::detail::dct(dims, stride_in, stride_out, axes, type, - X_data, Y_data, norm_factor, ortho, nthreads); + pocketfft::dct(dims, stride_in, stride_out, axes, type, + X_data, Y_data, norm_factor, ortho, nthreads); } catch (const std::exception &e) { @@ -364,8 +364,8 @@ extern "C" : compute_norm_factor(dims, axes, norm, 2); try { - pocketfft::detail::dst(dims, stride_in, stride_out, axes, type, - X_data, Y_data, norm_factor, ortho, nthreads); + pocketfft::dst(dims, stride_in, stride_out, axes, type, + X_data, Y_data, norm_factor, ortho, nthreads); } catch (const std::exception &e) { diff --git a/test/test_runner.ml b/test/test_runner.ml index cce4db693..7a1f1cd90 100644 --- a/test/test_runner.ml +++ b/test/test_runner.ml @@ -37,4 +37,5 @@ let () = ; "base: complex", Unit_base_complex.test_set ; "base: ndarray core", Unit_base_ndarray_core.test_set ; "base: linalg", Unit_linalg_solver.test_set; "base: signal", Unit_signal.test_set + ; "fft", Unit_fft.test_set ; ("algodiff matrix", Unit_algodiff_matrix.[ Reverse.test; Forward.test ]) ] diff --git a/test/unit_fft.ml b/test/unit_fft.ml new file mode 100644 index 000000000..0a09ea2cb --- /dev/null +++ b/test/unit_fft.ml @@ -0,0 +1,1785 @@ +(** Unit test for functions in fft *) + +open Owl_dense_ndarray_generic +open Owl_fft_generic + +(* Helpers & error *) +let eps = 1e-6 + +let close a b = (sub a b |> abs |> sum') < eps + +let close_complex a b = + let d = (sub a b) in + let d = (abs d) in + let d = (sum' d) in + Complex.(d.re) < eps && Complex.(d.im) < eps + +(* Test functions *) +module To_test = struct + let test_fft_8_backward () = + let input = + [| Complex.{re= 0.37454012; im= 0.60111499} + ; Complex.{re= 0.95071429; im= 0.70807260} + ; Complex.{re= 0.73199391; im= 0.02058449} + ; Complex.{re= 0.59865850; im= 0.96990985} + ; Complex.{re= 0.15601864; im= 0.83244264} + ; Complex.{re= 0.15599452; im= 0.21233912} + ; Complex.{re= 0.05808361; im= 0.18182497} + ; Complex.{re= 0.86617613; im= 0.18340451} |] + in + let x = of_array Bigarray.Complex64 input [|8|] in + let expected = + [| Complex.{re= 3.89217973; im= 3.70969343} + ; Complex.{re= 1.71507597; im= -1.48363292} + ; Complex.{re= -0.49242139; im= 1.58927405} + ; Complex.{re= 0.53532648; im= 0.27540120} + ; Complex.{re= -1.25090718; im= -0.43775904} + ; Complex.{re= -1.60051394; im= -0.32684302} + ; Complex.{re= -0.02661610; im= 0.87302220} + ; Complex.{re= 0.22419742; im= 0.60976410} |] + in + let expected = of_array Bigarray.Complex64 expected [|8|] in + let result = fft ~norm:Backward x in + close_complex result expected + + let test_fft_8_ortho () = + let input = + [| Complex.{re= 0.30424225; im= 0.45606998} + ; Complex.{re= 0.52475643; im= 0.78517598} + ; Complex.{re= 0.43194503; im= 0.19967379} + ; Complex.{re= 0.29122913; im= 0.51423442} + ; Complex.{re= 0.61185288; im= 0.59241456} + ; Complex.{re= 0.13949387; im= 0.04645041} + ; Complex.{re= 0.29214466; im= 0.60754484} + ; Complex.{re= 0.36636186; im= 0.17052412} |] + in + let x = of_array Bigarray.Complex64 input [|8|] in + let expected = + [| Complex.{re= 1.04723442; im= 1.19221306} + ; Complex.{re= 0.13274679; im= -0.07641063} + ; Complex.{re= 0.11980981; im= 0.08294597} + ; Complex.{re= 0.19095755; im= -0.17506446} + ; Complex.{re= 0.11255147; im= 0.11996708} + ; Complex.{re= -0.63866872; im= -0.11885333} + ; Complex.{re= 0.01595855; im= 0.08765482} + ; Complex.{re= -0.12006272; im= 0.17750807} |] + in + let expected = of_array Bigarray.Complex64 expected [|8|] in + let result = fft ~norm:Ortho x in + close_complex result expected + + let test_fft_8_forward () = + let input = + [| Complex.{re= 0.06505159; im= 0.12203824} + ; Complex.{re= 0.94888556; im= 0.49517691} + ; Complex.{re= 0.96563202; im= 0.03438852} + ; Complex.{re= 0.80839735; im= 0.90932041} + ; Complex.{re= 0.30461377; im= 0.25877997} + ; Complex.{re= 0.09767211; im= 0.66252226} + ; Complex.{re= 0.68423301; im= 0.31171107} + ; Complex.{re= 0.44015250; im= 0.52006805} |] + in + let x = of_array Bigarray.Complex64 input [|8|] in + let expected = + [| Complex.{re= 0.53932971; im= 0.41425067} + ; Complex.{re= -0.00230779; im= -0.20925026} + ; Complex.{re= -0.19398613; im= 0.02958885} + ; Complex.{re= -0.01835475; im= -0.04050700} + ; Complex.{re= -0.03444713; im= -0.23252124} + ; Complex.{re= -0.12691340; im= 0.10471506} + ; Complex.{re= -0.12606378; im= -0.02090919} + ; Complex.{re= 0.02779485; im= 0.07667132} |] + in + let expected = of_array Bigarray.Complex64 expected [|8|] in + let result = fft ~norm:Forward x in + close_complex result expected + + let test_fft_4x4_backward () = + let input = + [| Complex.{re= 0.54671025; im= 0.28093451} + ; Complex.{re= 0.18485446; im= 0.54269606} + ; Complex.{re= 0.96958464; im= 0.14092423} + ; Complex.{re= 0.77513283; im= 0.80219698} + ; Complex.{re= 0.93949896; im= 0.07455064} + ; Complex.{re= 0.89482737; im= 0.98688692} + ; Complex.{re= 0.59789997; im= 0.77224475} + ; Complex.{re= 0.92187423; im= 0.19871569} + ; Complex.{re= 0.08849251; im= 0.00552212} + ; Complex.{re= 0.19598286; im= 0.81546146} + ; Complex.{re= 0.04522729; im= 0.70685732} + ; Complex.{re= 0.32533032; im= 0.72900718} + ; Complex.{re= 0.38867730; im= 0.77127033} + ; Complex.{re= 0.27134904; im= 0.07404465} + ; Complex.{re= 0.82873750; im= 0.35846573} + ; Complex.{re= 0.35675332; im= 0.11586906} |] + in + let x = of_array Bigarray.Complex64 input [|4; 4|] in + let expected = + [| Complex.{re= 2.47628212; im= 1.76675177} + ; Complex.{re= -0.68237531; im= 0.73028868} + ; Complex.{re= 0.55630767; im= -0.92303425} + ; Complex.{re= -0.16337347; im= -0.45026809} + ; Complex.{re= 3.35410070; im= 2.03239799} + ; Complex.{re= 1.12977028; im= -0.67064726} + ; Complex.{re= -0.27930272; im= -0.33880728} + ; Complex.{re= -0.44657224; im= -0.72474098} + ; Complex.{re= 0.65503299; im= 2.25684810} + ; Complex.{re= 0.12971950; im= -0.57198775} + ; Complex.{re= -0.38759339; im= -0.83208919} + ; Complex.{re= -0.04318906; im= -0.83068264} + ; Complex.{re= 1.84551716; im= 1.31964982} + ; Complex.{re= -0.48188460; im= 0.49820888} + ; Complex.{re= 0.58931249; im= 0.93982232} + ; Complex.{re= -0.39823580; im= 0.32740033} |] + in + let expected = of_array Bigarray.Complex64 expected [|4; 4|] in + let result = fft ~norm:Backward x in + close_complex result expected + + let test_fft_4x4_ortho () = + let input = + [| Complex.{re= 0.86310345; im= 0.52273285} + ; Complex.{re= 0.62329811; im= 0.42754102} + ; Complex.{re= 0.33089802; im= 0.02541913} + ; Complex.{re= 0.06355835; im= 0.10789143} + ; Complex.{re= 0.31098232; im= 0.03142919} + ; Complex.{re= 0.32518333; im= 0.63641042} + ; Complex.{re= 0.72960615; im= 0.31435597} + ; Complex.{re= 0.63755745; im= 0.50857067} + ; Complex.{re= 0.88721275; im= 0.90756649} + ; Complex.{re= 0.47221494; im= 0.24929222} + ; Complex.{re= 0.11959425; im= 0.41038293} + ; Complex.{re= 0.71324480; im= 0.75555116} + ; Complex.{re= 0.76078504; im= 0.22879817} + ; Complex.{re= 0.56127721; im= 0.07697991} + ; Complex.{re= 0.77096719; im= 0.28975144} + ; Complex.{re= 0.49379560; im= 0.16122128} |] + in + let x = of_array Bigarray.Complex64 input [|4; 4|] in + let expected = + [| Complex.{re= 0.94042897; im= 0.54179221} + ; Complex.{re= 0.42592752; im= -0.03121302} + ; Complex.{re= 0.25357249; im= 0.00635976} + ; Complex.{re= 0.10627794; im= 0.52852678} + ; Complex.{re= 1.00166464; im= 0.74538314} + ; Complex.{re= -0.14539205; im= 0.01472366} + ; Complex.{re= 0.03892386; im= -0.39959800} + ; Complex.{re= -0.27323180; im= -0.29765046} + ; Complex.{re= 1.09613335; im= 1.16139638} + ; Complex.{re= 0.13067979; im= 0.36910671} + ; Complex.{re= -0.08932638; im= 0.15655303} + ; Complex.{re= 0.63693875; im= 0.12807685} + ; Complex.{re= 1.29341245; im= 0.37837541} + ; Complex.{re= -0.04721176; im= -0.06421744} + ; Complex.{re= 0.23833972; im= 0.14017421} + ; Complex.{re= 0.03702961; im= 0.00326417} |] + in + let expected = of_array Bigarray.Complex64 expected [|4; 4|] in + let result = fft ~norm:Ortho x in + close_complex result expected + + let test_fft_4x4_forward () = + let input = + [| Complex.{re= 0.92969763; im= 0.00695213} + ; Complex.{re= 0.80812037; im= 0.51074731} + ; Complex.{re= 0.63340378; im= 0.41741100} + ; Complex.{re= 0.87146062; im= 0.22210781} + ; Complex.{re= 0.80367208; im= 0.11986537} + ; Complex.{re= 0.18657006; im= 0.33761516} + ; Complex.{re= 0.89255899; im= 0.94290972} + ; Complex.{re= 0.53934222; im= 0.32320294} + ; Complex.{re= 0.80744016; im= 0.51879060} + ; Complex.{re= 0.89609128; im= 0.70301896} + ; Complex.{re= 0.31800348; im= 0.36362961} + ; Complex.{re= 0.11005192; im= 0.97178209} + ; Complex.{re= 0.22793517; im= 0.96244729} + ; Complex.{re= 0.42710778; im= 0.25178230} + ; Complex.{re= 0.81801474; im= 0.49724850} + ; Complex.{re= 0.86073059; im= 0.30087832} |] + in + let x = of_array Bigarray.Complex64 input [|4; 4|] in + let expected = + [| Complex.{re= 0.81067061; im= 0.28930455} + ; Complex.{re= 0.14623334; im= -0.08677965} + ; Complex.{re= -0.02911988; im= -0.07712300} + ; Complex.{re= 0.00191359; im= -0.11844978} + ; Complex.{re= 0.60553586; im= 0.43089831} + ; Complex.{re= -0.01861867; im= -0.11756805} + ; Complex.{re= 0.24257971; im= 0.10048926} + ; Complex.{re= -0.02582479; im= -0.29395413} + ; Complex.{re= 0.53289676; im= 0.63930535} + ; Complex.{re= 0.05516839; im= -0.15771958} + ; Complex.{re= 0.02982512; im= -0.19809523} + ; Complex.{re= 0.18954995; im= 0.23530009} + ; Complex.{re= 0.58344710; im= 0.50308907} + ; Complex.{re= -0.15979388; im= 0.22470540} + ; Complex.{re= -0.06047210; im= 0.22675881} + ; Complex.{re= -0.13524589; im= 0.00789399} |] + in + let expected = of_array Bigarray.Complex64 expected [|4; 4|] in + let result = fft ~norm:Forward x in + close_complex result expected + + let test_fft_3x4_backward () = + let input = + [| Complex.{re= 0.28484049; im= 0.67213553} + ; Complex.{re= 0.03688695; im= 0.76161963} + ; Complex.{re= 0.60956430; im= 0.23763755} + ; Complex.{re= 0.50267905; im= 0.72821635} + ; Complex.{re= 0.05147875; im= 0.36778313} + ; Complex.{re= 0.27864647; im= 0.63230580} + ; Complex.{re= 0.90826589; im= 0.63352972} + ; Complex.{re= 0.23956189; im= 0.53577471} + ; Complex.{re= 0.14489487; im= 0.09028977} + ; Complex.{re= 0.48945275; im= 0.83530247} + ; Complex.{re= 0.98565048; im= 0.32078007} + ; Complex.{re= 0.24205527; im= 0.18651851} |] + in + let x = of_array Bigarray.Complex64 input [|3; 4|] in + let expected = + [| Complex.{re= 1.43397069; im= 2.39960909} + ; Complex.{re= -0.29132053; im= 0.90029007} + ; Complex.{re= 0.35483879; im= -0.58006287} + ; Complex.{re= -0.35812709; im= -0.03129411} + ; Complex.{re= 1.47795296; im= 2.16939354} + ; Complex.{re= -0.76025605; im= -0.30483118} + ; Complex.{re= 0.44153625; im= -0.16676772} + ; Complex.{re= -0.95331824; im= -0.22666201} + ; Complex.{re= 1.86205339; im= 1.43289089} + ; Complex.{re= -0.19197160; im= -0.47788778} + ; Complex.{re= 0.39903736; im= -0.61075115} + ; Complex.{re= -1.48953962; im= 0.01690719} |] + in + let expected = of_array Bigarray.Complex64 expected [|3; 4|] in + let result = fft ~norm:Backward x in + close_complex result expected + + let test_fft_3x4_ortho () = + let input = + [| Complex.{re= 0.04077514; im= 0.34106636} + ; Complex.{re= 0.59089297; im= 0.11347352} + ; Complex.{re= 0.67756438; im= 0.92469364} + ; Complex.{re= 0.01658783; im= 0.87733936} + ; Complex.{re= 0.51209307; im= 0.25794163} + ; Complex.{re= 0.22649577; im= 0.65998405} + ; Complex.{re= 0.64517277; im= 0.81722218} + ; Complex.{re= 0.17436643; im= 0.55520082} + ; Complex.{re= 0.69093776; im= 0.52965057} + ; Complex.{re= 0.38673535; im= 0.24185228} + ; Complex.{re= 0.93672997; im= 0.09310277} + ; Complex.{re= 0.13752094; im= 0.89721578} |] + in + let x = of_array Bigarray.Complex64 input [|3; 4|] in + let expected = + [| Complex.{re= 0.66291016; im= 1.12828636} + ; Complex.{re= -0.70032752; im= -0.57896620} + ; Complex.{re= 0.05542934; im= 0.13747352} + ; Complex.{re= 0.06353828; im= -0.00466108} + ; Complex.{re= 0.77906406; im= 1.14517438} + ; Complex.{re= -0.01414824; im= -0.30570492} + ; Complex.{re= 0.37820184; im= -0.07001054} + ; Complex.{re= -0.11893147; im= -0.25357559} + ; Complex.{re= 1.07596195; im= 0.88091075} + ; Complex.{re= -0.45057786; im= 0.09366670} + ; Complex.{re= 0.55170572; im= -0.25815740} + ; Complex.{re= 0.20478565; im= 0.34288111} |] + in + let expected = of_array Bigarray.Complex64 expected [|3; 4|] in + let result = fft ~norm:Ortho x in + close_complex result expected + + let test_fft_3x4_forward () = + let input = + [| Complex.{re= 0.90041804; im= 0.60642904} + ; Complex.{re= 0.63310146; im= 0.00919705} + ; Complex.{re= 0.33902979; im= 0.10147154} + ; Complex.{re= 0.34920958; im= 0.66350174} + ; Complex.{re= 0.72595567; im= 0.00506158} + ; Complex.{re= 0.89711028; im= 0.16080806} + ; Complex.{re= 0.88708645; im= 0.54873377} + ; Complex.{re= 0.77987552; im= 0.69189519} + ; Complex.{re= 0.64203167; im= 0.65196127} + ; Complex.{re= 0.08413997; im= 0.22426932} + ; Complex.{re= 0.16162871; im= 0.71217924} + ; Complex.{re= 0.89855421; im= 0.23724909} |] + in + let x = of_array Bigarray.Complex64 input [|3; 4|] in + let expected = + [| Complex.{re= 0.55543971; im= 0.34514984} + ; Complex.{re= -0.02322911; im= 0.05526640} + ; Complex.{re= 0.06428421; im= 0.00880045} + ; Complex.{re= 0.30392325; im= 0.19721234} + ; Complex.{re= 0.82250696; im= 0.35162464} + ; Complex.{re= -0.17305449; im= -0.16522674} + ; Complex.{re= -0.01598591; im= -0.07472697} + ; Complex.{re= 0.09248909; im= -0.10660936} + ; Complex.{re= 0.44658864; im= 0.45641473} + ; Complex.{re= 0.11685579; im= 0.18854907} + ; Complex.{re= -0.04475844; im= 0.22565553} + ; Complex.{re= 0.12334568; im= -0.21865806} |] + in + let expected = of_array Bigarray.Complex64 expected [|3; 4|] in + let result = fft ~norm:Forward x in + close_complex result expected + + let test_fft_inverse () = + let input = + [| 0.32539970 + ; 0.74649143 + ; 0.64963287 + ; 0.84922343 + ; 0.65761292 + ; 0.56830859 + ; 0.09367477 + ; 0.36771581 |] + in + let x = cast_d2z (of_array Bigarray.Float64 input [|8|]) in + let forward = fft x in + let result = ifft forward in + let expected = + [| Complex.{re= 0.32539970; im= 0.} + ; Complex.{re= 0.74649143; im= 0.} + ; Complex.{re= 0.64963281; im= 0.} + ; Complex.{re= 0.84922343; im= 0.} + ; Complex.{re= 0.65761292; im= 0.} + ; Complex.{re= 0.56830859; im= 0.} + ; Complex.{re= 0.09367481; im= 0.} + ; Complex.{re= 0.36771578; im= 0.} |] + in + let expected = of_array Bigarray.Complex64 expected [|8|] in + close_complex result expected + + let test_rfft_8_backward () = + let input = + [| 0.57690388 + ; 0.49251768 + ; 0.19524299 + ; 0.72245210 + ; 0.28077236 + ; 0.02431597 + ; 0.64547229 + ; 0.17711067 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| Complex.{re= 3.11478806; im= 0.00000000} + ; Complex.{re= 0.24158551; im= -0.26645392} + ; Complex.{re= 0.01696096; im= 0.38272914} + ; Complex.{re= 0.35067752; im= -1.16691256} + ; Complex.{re= 0.28199509; im= 0.00000000} |] + in + let expected = of_array Bigarray.Complex64 expected [|5|] in + let result = rfft ~otyp:Bigarray.Complex64 ~norm:Backward x in + close_complex result expected + + let test_rfft_8_ortho () = + let input = + [| 0.94045860 + ; 0.95392859 + ; 0.91486436 + ; 0.37015870 + ; 0.01545662 + ; 0.92831856 + ; 0.42818415 + ; 0.96665484 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| Complex.{re= 1.95091629; im= 0.00000000} + ; Complex.{re= 0.48256412; im= -0.02934592} + ; Complex.{re= -0.13687231; im= -0.19283991} + ; Complex.{re= 0.17151105; im= 0.31478897} + ; Complex.{re= -0.32530338; im= 0.00000000} |] + in + let expected = of_array Bigarray.Complex64 expected [|5|] in + let result = rfft ~otyp:Bigarray.Complex64 ~norm:Ortho x in + close_complex result expected + + let test_rfft_8_forward () = + let input = + [| 0.96361995 + ; 0.85300946 + ; 0.29444888 + ; 0.38509774 + ; 0.85113668 + ; 0.31692201 + ; 0.16949275 + ; 0.55680126 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| Complex.{re= 0.54881608; im= 0.00000000} + ; Complex.{re= 0.07662089; im= -0.04782681} + ; Complex.{re= 0.16885188; im= -0.02850406} + ; Complex.{re= -0.04850006; im= -0.01658777} + ; Complex.{re= 0.02085848; im= 0.00000000} |] + in + let expected = of_array Bigarray.Complex64 expected [|5|] in + let result = rfft ~otyp:Bigarray.Complex64 ~norm:Forward x in + close_complex result expected + + let test_rfft_4x4_backward () = + let input = + [| 0.93615478 + ; 0.69602978 + ; 0.57006115 + ; 0.09717649 + ; 0.61500722 + ; 0.99005383 + ; 0.14008401 + ; 0.51832968 + ; 0.87737310 + ; 0.74076861 + ; 0.69701576 + ; 0.70248407 + ; 0.35949114 + ; 0.29359186 + ; 0.80936116 + ; 0.81011337 |] + in + let x = of_array Bigarray.Float64 input [|4; 4|] in + let expected = + [| Complex.{re= 2.29942226; im= 0.00000000} + ; Complex.{re= 0.36609361; im= -0.59885329} + ; Complex.{re= 0.71300966; im= 0.00000000} + ; Complex.{re= 2.26347470; im= 0.00000000} + ; Complex.{re= 0.47492322; im= -0.47172421} + ; Complex.{re= -0.75329226; im= 0.00000000} + ; Complex.{re= 3.01764154; im= 0.00000000} + ; Complex.{re= 0.18035734; im= -0.03828453} + ; Complex.{re= 0.13113610; im= 0.00000000} + ; Complex.{re= 2.27255750; im= 0.00000000} + ; Complex.{re= -0.44986999; im= 0.51652157} + ; Complex.{re= 0.06514706; im= 0.00000000} |] + in + let expected = of_array Bigarray.Complex64 expected [|4; 3|] in + let result = rfft ~otyp:Bigarray.Complex64 ~norm:Backward x in + close_complex result expected + + let test_rfft_4x4_ortho () = + let input = + [| 0.86707234 + ; 0.91324055 + ; 0.51134241 + ; 0.50151628 + ; 0.79829520 + ; 0.64996392 + ; 0.70196688 + ; 0.79579270 + ; 0.89000535 + ; 0.33799517 + ; 0.37558296 + ; 0.09398194 + ; 0.57828015 + ; 0.03594228 + ; 0.46559802 + ; 0.54264462 |] + in + let x = of_array Bigarray.Float64 input [|4; 4|] in + let expected = + [| Complex.{re= 1.39658582; im= 0.00000000} + ; Complex.{re= 0.17786495; im= -0.20586213} + ; Complex.{re= -0.01817106; im= 0.00000000} + ; Complex.{re= 1.47300935; im= 0.00000000} + ; Complex.{re= 0.04816415; im= 0.07291437} + ; Complex.{re= 0.02725273; im= 0.00000000} + ; Complex.{re= 0.84878272; im= 0.00000000} + ; Complex.{re= 0.25721121; im= -0.12200661} + ; Complex.{re= 0.41680560; im= 0.00000000} + ; Complex.{re= 0.81123251; im= 0.00000000} + ; Complex.{re= 0.05634106; im= 0.25335118} + ; Complex.{re= 0.23264563; im= 0.00000000} |] + in + let expected = of_array Bigarray.Complex64 expected [|4; 3|] in + let result = rfft ~otyp:Bigarray.Complex64 ~norm:Ortho x in + close_complex result expected + + let test_rfft_4x4_forward () = + let input = + [| 0.28654125 + ; 0.59083325 + ; 0.03050025 + ; 0.03734819 + ; 0.82260054 + ; 0.36019063 + ; 0.12706052 + ; 0.52224326 + ; 0.76999354 + ; 0.21582103 + ; 0.62289047 + ; 0.08534747 + ; 0.05168172 + ; 0.53135461 + ; 0.54063511 + ; 0.63742989 |] + in + let x = of_array Bigarray.Float64 input [|4; 4|] in + let expected = + [| Complex.{re= 0.23630574; im= 0.00000000} + ; Complex.{re= 0.06401025; im= -0.13837127} + ; Complex.{re= -0.07778499; im= 0.00000000} + ; Complex.{re= 0.45802376; im= 0.00000000} + ; Complex.{re= 0.17388502; im= 0.04051315} + ; Complex.{re= 0.01680679; im= 0.00000000} + ; Complex.{re= 0.42351314; im= 0.00000000} + ; Complex.{re= 0.03677577; im= -0.03261839} + ; Complex.{re= 0.27292889; im= 0.00000000} + ; Complex.{re= 0.44027534; im= 0.00000000} + ; Complex.{re= -0.12223835; im= 0.02651882} + ; Complex.{re= -0.14411692; im= 0.00000000} |] + in + let expected = of_array Bigarray.Complex64 expected [|4; 3|] in + let result = rfft ~otyp:Bigarray.Complex64 ~norm:Forward x in + close_complex result expected + + let test_rfft_3x4_backward () = + let input = + [| 0.72609133 + ; 0.97585207 + ; 0.51630032 + ; 0.32295647 + ; 0.79518622 + ; 0.27083224 + ; 0.43897143 + ; 0.07845638 + ; 0.02535074 + ; 0.96264839 + ; 0.83598012 + ; 0.69597423 |] + in + let x = of_array Bigarray.Float64 input [|3; 4|] in + let expected = + [| Complex.{re= 2.54120016; im= 0.00000000} + ; Complex.{re= 0.20979099; im= -0.65289563} + ; Complex.{re= -0.05641687; im= 0.00000000} + ; Complex.{re= 1.58344626; im= 0.00000000} + ; Complex.{re= 0.35621476; im= -0.19237587} + ; Complex.{re= 0.88486898; im= 0.00000000} + ; Complex.{re= 2.51995349; im= 0.00000000} + ; Complex.{re= -0.81062937; im= -0.26667422} + ; Complex.{re= -0.79729176; im= 0.00000000} |] + in + let expected = of_array Bigarray.Complex64 expected [|3; 3|] in + let result = rfft ~otyp:Bigarray.Complex64 ~norm:Backward x in + close_complex result expected + + let test_rfft_3x4_ortho () = + let input = + [| 0.40895295 + ; 0.17329432 + ; 0.15643704 + ; 0.25024289 + ; 0.54922664 + ; 0.71459591 + ; 0.66019738 + ; 0.27993390 + ; 0.95486528 + ; 0.73789692 + ; 0.55435407 + ; 0.61172074 |] + in + let x = of_array Bigarray.Float64 input [|3; 4|] in + let expected = + [| Complex.{re= 0.49446359; im= 0.00000000} + ; Complex.{re= 0.12625796; im= 0.03847429} + ; Complex.{re= 0.07092638; im= 0.00000000} + ; Complex.{re= 1.10197687; im= 0.00000000} + ; Complex.{re= -0.05548536; im= -0.21733101} + ; Complex.{re= 0.10744711; im= 0.00000000} + ; Complex.{re= 1.42941844; im= 0.00000000} + ; Complex.{re= 0.20025562; im= -0.06308808} + ; Complex.{re= 0.07980084; im= 0.00000000} |] + in + let expected = of_array Bigarray.Complex64 expected [|3; 3|] in + let result = rfft ~otyp:Bigarray.Complex64 ~norm:Ortho x in + close_complex result expected + + let test_rfft_3x4_forward () = + let input = + [| 0.41960007 + ; 0.24773099 + ; 0.35597268 + ; 0.75784612 + ; 0.01439349 + ; 0.11607264 + ; 0.04600264 + ; 0.04072880 + ; 0.85546058 + ; 0.70365787 + ; 0.47417384 + ; 0.09783416 |] + in + let x = of_array Bigarray.Float64 input [|3; 4|] in + let expected = + [| Complex.{re= 0.44528747; im= 0.00000000} + ; Complex.{re= 0.01590685; im= 0.12752879} + ; Complex.{re= -0.05750109; im= 0.00000000} + ; Complex.{re= 0.05429939; im= 0.00000000} + ; Complex.{re= -0.00790229; im= -0.01883596} + ; Complex.{re= -0.02410133; im= 0.00000000} + ; Complex.{re= 0.53278160; im= 0.00000000} + ; Complex.{re= 0.09532169; im= -0.15145592} + ; Complex.{re= 0.13203560; im= 0.00000000} |] + in + let expected = of_array Bigarray.Complex64 expected [|3; 3|] in + let result = rfft ~otyp:Bigarray.Complex64 ~norm:Forward x in + close_complex result expected + + let test_rfft_inverse () = + let input = + [| 0.49161586 + ; 0.47347176 + ; 0.17320187 + ; 0.43385166 + ; 0.39850473 + ; 0.61585009 + ; 0.63509363 + ; 0.04530401 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let forward = rfft ~otyp:Bigarray.Complex64 x in + let result = irfft ~otyp:Bigarray.Float64 forward in + let expected = + [| 0.49161588 + ; 0.47347177 + ; 0.17320187 + ; 0.43385165 + ; 0.39850473 + ; 0.61585010 + ; 0.63509365 + ; 0.04530401 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + close result expected + + let test_dct_1_backward () = + let input = + [| 0.37461263 + ; 0.62585992 + ; 0.50313628 + ; 0.85648984 + ; 0.65869361 + ; 0.16293442 + ; 0.07056875 + ; 0.64241928 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 6.77239752 + ; 1.24504578 + ; -1.14123142 + ; -0.88034922 + ; 1.39627695 + ; -0.75523102 + ; -0.08163272 + ; -1.09357774 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Backward ~ttype:I x in + close result expected + + let test_dct_1_ortho () = + let input = + [| 0.02651131 + ; 0.58577555 + ; 0.94023025 + ; 0.57547420 + ; 0.38816991 + ; 0.64328820 + ; 0.45825288 + ; 0.54561681 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 1.51025033 + ; -0.01354993 + ; -0.08824035 + ; -0.38646209 + ; -0.34938586 + ; -0.18381834 + ; 0.12657233 + ; -0.14549664 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Ortho ~ttype:I x in + close result expected + + let test_dct_1_forward () = + let input = + [| 0.94146478 + ; 0.38610265 + ; 0.96119058 + ; 0.90535063 + ; 0.19579114 + ; 0.06936130 + ; 0.10077800 + ; 0.01822183 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 0.44263110 + ; 0.20466121 + ; -0.06257220 + ; -0.10297162 + ; 0.01850824 + ; 0.10350926 + ; 0.06267007 + ; 0.05122380 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Forward ~ttype:I x in + close result expected + + let test_dct_2_backward () = + let input = + [| 0.09444296 + ; 0.68300676 + ; 0.07118865 + ; 0.31897563 + ; 0.84487534 + ; 0.02327194 + ; 0.81446850 + ; 0.28185478 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 6.26416922 + ; -0.73818803 + ; -0.38138762 + ; 0.22999576 + ; -0.07323811 + ; -0.80621248 + ; -3.19520020 + ; 1.18421984 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Backward ~ttype:II x in + close result expected + + let test_dct_2_ortho () = + let input = + [| 0.11816483 + ; 0.69673717 + ; 0.62894285 + ; 0.87747204 + ; 0.73507106 + ; 0.80348092 + ; 0.28203458 + ; 0.17743954 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 1.52711833 + ; 0.10874486 + ; -0.69514889 + ; -0.01905947 + ; -0.17785436 + ; -0.17765704 + ; -0.04242539 + ; -0.26337412 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Ortho ~ttype:II x in + close result expected + + let test_dct_2_forward () = + let input = + [| 0.75061476 + ; 0.80683476 + ; 0.99050516 + ; 0.41261768 + ; 0.37201810 + ; 0.77641296 + ; 0.34080353 + ; 0.93075734 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 0.67257053 + ; 0.04220918 + ; 0.07393602 + ; -0.05915445 + ; -0.03964647 + ; -0.06020422 + ; 0.11441326 + ; -0.01948318 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Forward ~ttype:II x in + close result expected + + let test_dct_3_backward () = + let input = + [| 0.85841274 + ; 0.42899403 + ; 0.75087106 + ; 0.75454289 + ; 0.10312387 + ; 0.90255290 + ; 0.50525236 + ; 0.82645744 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 6.19997597 + ; -1.91606784 + ; 1.79455709 + ; -1.56116223 + ; 0.02140164 + ; 0.34837565 + ; 2.62342930 + ; -0.64320761 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Backward ~ttype:III x in + close result expected + + let test_dct_3_ortho () = + let input = + [| 0.32004961 + ; 0.89552325 + ; 0.38920167 + ; 0.01083765 + ; 0.90538198 + ; 0.09128667 + ; 0.31931365 + ; 0.95006198 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 1.23583686 + ; -0.21741576 + ; 0.51341361 + ; -0.15123919 + ; 0.53597867 + ; -0.78123981 + ; -0.34254304 + ; 0.11244562 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Ortho ~ttype:III x in + close result expected + + let test_dct_3_forward () = + let input = + [| 0.95060712 + ; 0.57343787 + ; 0.63183719 + ; 0.44844553 + ; 0.29321077 + ; 0.32866454 + ; 0.67251843 + ; 0.75237453 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 0.34855044 + ; -0.05782470 + ; 0.15199459 + ; -0.09504779 + ; 0.05543073 + ; 0.00988157 + ; 0.02993466 + ; 0.03238407 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Forward ~ttype:III x in + close result expected + + let test_dct_4_backward () = + let input = + [| 0.79157907 + ; 0.78961813 + ; 0.09120610 + ; 0.49442029 + ; 0.05755876 + ; 0.54952890 + ; 0.44153050 + ; 0.88770419 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 5.03350830 + ; -0.32469130 + ; 2.84329343 + ; -1.06977797 + ; 1.26837552 + ; -1.16236269 + ; -0.15100092 + ; -2.64669466 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Backward ~ttype:IV x in + close result expected + + let test_dct_4_ortho () = + let input = + [| 0.35091501 + ; 0.11706702 + ; 0.14299168 + ; 0.76151061 + ; 0.61821806 + ; 0.10112268 + ; 0.08410680 + ; 0.70096916 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 0.85449934 + ; -0.42461908 + ; -0.10973582 + ; 0.13980620 + ; 0.62992477 + ; -0.38134995 + ; 0.20042425 + ; -0.32184014 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Ortho ~ttype:IV x in + close result expected + + let test_dct_4_forward () = + let input = + [| 0.07276300 + ; 0.82186007 + ; 0.70624220 + ; 0.08134878 + ; 0.08483771 + ; 0.98663956 + ; 0.37427080 + ; 0.37064216 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 0.27606383 + ; -0.10396589 + ; 0.08367845 + ; -0.14917605 + ; -0.16153003 + ; 0.00991914 + ; 0.03385765 + ; -0.09567360 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dct ~norm:Forward ~ttype:IV x in + close result expected + + let test_dct_inverse_1 () = + let input = + [| 0.81279957 + ; 0.94724858 + ; 0.98600107 + ; 0.75337821 + ; 0.37625960 + ; 0.08350071 + ; 0.77714694 + ; 0.55840427 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let forward = dct ~ttype:I x in + let result = idct ~ttype:I forward in + let expected = + [| 0.81279957 + ; 0.94724858 + ; 0.98600106 + ; 0.75337819 + ; 0.37625959 + ; 0.08350072 + ; 0.77714692 + ; 0.55840425 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + close result expected + + let test_dct_inverse_2 () = + let input = + [| 0.42422202 + ; 0.90635437 + ; 0.11119748 + ; 0.49262512 + ; 0.01135364 + ; 0.46866065 + ; 0.05630327 + ; 0.11881792 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let forward = dct ~ttype:II x in + let result = idct ~ttype:II forward in + let expected = + [| 0.42422201 + ; 0.90635439 + ; 0.11119748 + ; 0.49262510 + ; 0.01135364 + ; 0.46866064 + ; 0.05630328 + ; 0.11881792 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + close result expected + + let test_dct_inverse_3 () = + let input = + [| 0.11752625 + ; 0.64921027 + ; 0.74604487 + ; 0.58336878 + ; 0.96217257 + ; 0.37487057 + ; 0.28571209 + ; 0.86859912 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let forward = dct ~ttype:III x in + let result = idct ~ttype:III forward in + let expected = + [| 0.11752625 + ; 0.64921030 + ; 0.74604488 + ; 0.58336877 + ; 0.96217255 + ; 0.37487058 + ; 0.28571209 + ; 0.86859913 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + close result expected + + let test_dct_inverse_4 () = + let input = + [| 0.22359584 + ; 0.96322256 + ; 0.01215447 + ; 0.96987885 + ; 0.04315991 + ; 0.89114314 + ; 0.52770108 + ; 0.99296480 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let forward = dct ~ttype:IV x in + let result = idct ~ttype:IV forward in + let expected = + [| 0.22359584 + ; 0.96322254 + ; 0.01215447 + ; 0.96987883 + ; 0.04315991 + ; 0.89114311 + ; 0.52770111 + ; 0.99296480 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + close result expected + + let test_dst_1_backward () = + let input = + [| 0.07379656 + ; 0.55385429 + ; 0.96930254 + ; 0.52309787 + ; 0.62939864 + ; 0.69574869 + ; 0.45454106 + ; 0.62755805 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 6.93005133 + ; -0.11519809 + ; 0.96519142 + ; -1.35991454 + ; -0.71071649 + ; -1.31527698 + ; 1.01109231 + ; 0.17671105 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Backward ~ttype:I x in + close result expected + + let test_dst_1_ortho () = + let input = + [| 0.58431429 + ; 0.90115803 + ; 0.04544638 + ; 0.28096318 + ; 0.95041150 + ; 0.89026380 + ; 0.45565677 + ; 0.62013263 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 1.55898416 + ; -0.25686294 + ; 0.54292411 + ; 0.60294652 + ; 0.33151725 + ; -0.46979901 + ; -0.08146074 + ; -0.17487633 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Ortho ~ttype:I x in + close result expected + + let test_dst_1_forward () = + let input = + [| 0.27738118 + ; 0.18812115 + ; 0.46369842 + ; 0.35335222 + ; 0.58365613 + ; 0.07773463 + ; 0.97439480 + ; 0.98621076 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 0.28567696 + ; -0.10827438 + ; 0.14328867 + ; -0.12813336 + ; 0.10891043 + ; -0.01470894 + ; -0.02046827 + ; 0.09155916 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Forward ~ttype:I x in + close result expected + + let test_dst_2_backward () = + let input = + [| 0.69816172 + ; 0.53609639 + ; 0.30952761 + ; 0.81379503 + ; 0.68473119 + ; 0.16261694 + ; 0.91092718 + ; 0.82253724 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 5.92580652 + ; -0.41755322 + ; 2.22041273 + ; -1.09627128 + ; 3.83235884 + ; 0.18310541 + ; 0.51656908 + ; 0.53660423 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Backward ~ttype:II x in + close result expected + + let test_dst_2_ortho () = + let input = + [| 0.94979990 + ; 0.72571951 + ; 0.61341518 + ; 0.41824305 + ; 0.93272847 + ; 0.86606389 + ; 0.04521867 + ; 0.02636698 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 1.58695292 + ; 0.27589065 + ; 0.23189719 + ; 0.83829910 + ; 0.13078196 + ; 0.10704315 + ; 0.43739575 + ; 0.17846274 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Ortho ~ttype:II x in + close result expected + + let test_dst_2_forward () = + let input = + [| 0.37646335 + ; 0.81055331 + ; 0.98727614 + ; 0.15041690 + ; 0.59413069 + ; 0.38089085 + ; 0.96991438 + ; 0.84211892 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 0.38684240 + ; 0.00812449 + ; 0.25888899 + ; -0.06962245 + ; 0.05404207 + ; -0.12640207 + ; 0.04120270 + ; 0.09297558 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Forward ~ttype:II x in + close result expected + + let test_dst_3_backward () = + let input = + [| 0.83832872 + ; 0.46869317 + ; 0.41481951 + ; 0.27340707 + ; 0.05637550 + ; 0.86472237 + ; 0.81290102 + ; 0.99971765 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 5.81922865 + ; 0.00653082 + ; 3.16587067 + ; 1.27023125 + ; 0.12993698 + ; 1.53134310 + ; 0.82424980 + ; -0.86656040 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Backward ~ttype:III x in + close result expected + + let test_dst_3_ortho () = + let input = + [| 0.99663681 + ; 0.55543172 + ; 0.76898742 + ; 0.94476575 + ; 0.84964740 + ; 0.24734810 + ; 0.45054415 + ; 0.12915942 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 1.48522890 + ; 1.04713714 + ; 0.11872885 + ; 0.20196684 + ; 0.52027225 + ; 0.27695364 + ; 0.05192040 + ; 0.28477481 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Ortho ~ttype:III x in + close result expected + + let test_dst_3_forward () = + let input = + [| 0.95405102 + ; 0.60617465 + ; 0.22864281 + ; 0.67170066 + ; 0.61812824 + ; 0.35816273 + ; 0.11355759 + ; 0.67157322 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 0.28901333 + ; 0.16782624 + ; 0.07231256 + ; 0.10237385 + ; 0.16434348 + ; 0.00136459 + ; 0.02728951 + ; -0.05439240 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Forward ~ttype:III x in + close result expected + + let test_dst_4_backward () = + let input = + [| 0.52030772 + ; 0.77231836 + ; 0.52016348 + ; 0.85218149 + ; 0.55190682 + ; 0.56093800 + ; 0.87665361 + ; 0.40348285 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 6.44558191 + ; 2.56039333 + ; 1.15299642 + ; 1.19234645 + ; 0.10844552 + ; 1.92546284 + ; -0.05677728 + ; -0.24125953 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Backward ~ttype:IV x in + close result expected + + let test_dst_4_ortho () = + let input = + [| 0.13401523 + ; 0.02878268 + ; 0.75513726 + ; 0.62030953 + ; 0.70407975 + ; 0.21296416 + ; 0.13637148 + ; 0.01454467 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 0.82401651 + ; 0.78515619 + ; -0.21069290 + ; -0.26915047 + ; -0.06921585 + ; -0.04269409 + ; -0.01689269 + ; 0.33836311 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Ortho ~ttype:IV x in + close result expected + + let test_dst_4_forward () = + let input = + [| 0.35058755 + ; 0.58991766 + ; 0.39224404 + ; 0.43747494 + ; 0.90415871 + ; 0.34825546 + ; 0.51398951 + ; 0.78365302 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let expected = + [| 0.36822951 + ; 0.08125728 + ; 0.06015708 + ; 0.00932478 + ; 0.12167707 + ; 0.00318824 + ; -0.05184158 + ; 0.03424457 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + let result = dst ~norm:Forward ~ttype:IV x in + close result expected + + let test_dst_inverse_1 () = + let input = + [| 0.39654279 + ; 0.62208670 + ; 0.86236370 + ; 0.94952065 + ; 0.14707348 + ; 0.92658764 + ; 0.49211630 + ; 0.25824440 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let forward = dst ~ttype:I x in + let result = idst ~ttype:I forward in + let expected = + [| 0.39654278 + ; 0.62208670 + ; 0.86236371 + ; 0.94952062 + ; 0.14707348 + ; 0.92658763 + ; 0.49211629 + ; 0.25824439 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + close result expected + + let test_dst_inverse_2 () = + let input = + [| 0.45913577 + ; 0.98003256 + ; 0.49261808 + ; 0.32875162 + ; 0.63340086 + ; 0.24014562 + ; 0.07586333 + ; 0.12887973 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let forward = dst ~ttype:II x in + let result = idst ~ttype:II forward in + let expected = + [| 0.45913576 + ; 0.98003258 + ; 0.49261809 + ; 0.32875161 + ; 0.63340085 + ; 0.24014562 + ; 0.07586333 + ; 0.12887972 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + close result expected + + let test_dst_inverse_3 () = + let input = + [| 0.12804584 + ; 0.15190269 + ; 0.13882717 + ; 0.64087474 + ; 0.18188009 + ; 0.34566727 + ; 0.89678842 + ; 0.47396165 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let forward = dst ~ttype:III x in + let result = idst ~ttype:III forward in + let expected = + [| 0.12804584 + ; 0.15190269 + ; 0.13882717 + ; 0.64087474 + ; 0.18188008 + ; 0.34566728 + ; 0.89678841 + ; 0.47396164 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + close result expected + + let test_dst_inverse_4 () = + let input = + [| 0.66755772 + ; 0.17231987 + ; 0.19228902 + ; 0.04086862 + ; 0.16893506 + ; 0.27859035 + ; 0.17701049 + ; 0.08870254 |] + in + let x = of_array Bigarray.Float64 input [|8|] in + let forward = dst ~ttype:IV x in + let result = idst ~ttype:IV forward in + let expected = + [| 0.66755774 + ; 0.17231987 + ; 0.19228902 + ; 0.04086862 + ; 0.16893506 + ; 0.27859034 + ; 0.17701048 + ; 0.08870253 |] + in + let expected = of_array Bigarray.Float64 expected [|8|] in + close result expected +end + +(* The actual tests *) +let test_fft_8_backward () = + Alcotest.(check bool) + "test_fft_8_backward" true + (To_test.test_fft_8_backward ()) + +let test_fft_8_ortho () = + Alcotest.(check bool) "test_fft_8_ortho" true (To_test.test_fft_8_ortho ()) + +let test_fft_8_forward () = + Alcotest.(check bool) + "test_fft_8_forward" true + (To_test.test_fft_8_forward ()) + +let test_fft_4x4_backward () = + Alcotest.(check bool) + "test_fft_4x4_backward" true + (To_test.test_fft_4x4_backward ()) + +let test_fft_4x4_ortho () = + Alcotest.(check bool) + "test_fft_4x4_ortho" true + (To_test.test_fft_4x4_ortho ()) + +let test_fft_4x4_forward () = + Alcotest.(check bool) + "test_fft_4x4_forward" true + (To_test.test_fft_4x4_forward ()) + +let test_fft_3x4_backward () = + Alcotest.(check bool) + "test_fft_3x4_backward" true + (To_test.test_fft_3x4_backward ()) + +let test_fft_3x4_ortho () = + Alcotest.(check bool) + "test_fft_3x4_ortho" true + (To_test.test_fft_3x4_ortho ()) + +let test_fft_3x4_forward () = + Alcotest.(check bool) + "test_fft_3x4_forward" true + (To_test.test_fft_3x4_forward ()) + +let test_fft_inverse () = + Alcotest.(check bool) "test_fft_inverse" true (To_test.test_fft_inverse ()) + +let test_rfft_8_backward () = + Alcotest.(check bool) + "test_rfft_8_backward" true + (To_test.test_rfft_8_backward ()) + +let test_rfft_8_ortho () = + Alcotest.(check bool) "test_rfft_8_ortho" true (To_test.test_rfft_8_ortho ()) + +let test_rfft_8_forward () = + Alcotest.(check bool) + "test_rfft_8_forward" true + (To_test.test_rfft_8_forward ()) + +let test_rfft_4x4_backward () = + Alcotest.(check bool) + "test_rfft_4x4_backward" true + (To_test.test_rfft_4x4_backward ()) + +let test_rfft_4x4_ortho () = + Alcotest.(check bool) + "test_rfft_4x4_ortho" true + (To_test.test_rfft_4x4_ortho ()) + +let test_rfft_4x4_forward () = + Alcotest.(check bool) + "test_rfft_4x4_forward" true + (To_test.test_rfft_4x4_forward ()) + +let test_rfft_3x4_backward () = + Alcotest.(check bool) + "test_rfft_3x4_backward" true + (To_test.test_rfft_3x4_backward ()) + +let test_rfft_3x4_ortho () = + Alcotest.(check bool) + "test_rfft_3x4_ortho" true + (To_test.test_rfft_3x4_ortho ()) + +let test_rfft_3x4_forward () = + Alcotest.(check bool) + "test_rfft_3x4_forward" true + (To_test.test_rfft_3x4_forward ()) + +let test_rfft_inverse () = + Alcotest.(check bool) "test_rfft_inverse" true (To_test.test_rfft_inverse ()) + +let test_dct_1_backward () = + Alcotest.(check bool) + "test_dct_1_backward" true + (To_test.test_dct_1_backward ()) + +let test_dct_1_ortho () = + Alcotest.(check bool) "test_dct_1_ortho" true (To_test.test_dct_1_ortho ()) + +let test_dct_1_forward () = + Alcotest.(check bool) + "test_dct_1_forward" true + (To_test.test_dct_1_forward ()) + +let test_dct_2_backward () = + Alcotest.(check bool) + "test_dct_2_backward" true + (To_test.test_dct_2_backward ()) + +let test_dct_2_ortho () = + Alcotest.(check bool) "test_dct_2_ortho" true (To_test.test_dct_2_ortho ()) + +let test_dct_2_forward () = + Alcotest.(check bool) + "test_dct_2_forward" true + (To_test.test_dct_2_forward ()) + +let test_dct_3_backward () = + Alcotest.(check bool) + "test_dct_3_backward" true + (To_test.test_dct_3_backward ()) + +let test_dct_3_ortho () = + Alcotest.(check bool) "test_dct_3_ortho" true (To_test.test_dct_3_ortho ()) + +let test_dct_3_forward () = + Alcotest.(check bool) + "test_dct_3_forward" true + (To_test.test_dct_3_forward ()) + +let test_dct_4_backward () = + Alcotest.(check bool) + "test_dct_4_backward" true + (To_test.test_dct_4_backward ()) + +let test_dct_4_ortho () = + Alcotest.(check bool) "test_dct_4_ortho" true (To_test.test_dct_4_ortho ()) + +let test_dct_4_forward () = + Alcotest.(check bool) + "test_dct_4_forward" true + (To_test.test_dct_4_forward ()) + +let test_dct_inverse_1 () = + Alcotest.(check bool) + "test_dct_inverse_1" true + (To_test.test_dct_inverse_1 ()) + +let test_dct_inverse_2 () = + Alcotest.(check bool) + "test_dct_inverse_2" true + (To_test.test_dct_inverse_2 ()) + +let test_dct_inverse_3 () = + Alcotest.(check bool) + "test_dct_inverse_3" true + (To_test.test_dct_inverse_3 ()) + +let test_dct_inverse_4 () = + Alcotest.(check bool) + "test_dct_inverse_4" true + (To_test.test_dct_inverse_4 ()) + +let test_dst_1_backward () = + Alcotest.(check bool) + "test_dst_1_backward" true + (To_test.test_dst_1_backward ()) + +let test_dst_1_ortho () = + Alcotest.(check bool) "test_dst_1_ortho" true (To_test.test_dst_1_ortho ()) + +let test_dst_1_forward () = + Alcotest.(check bool) + "test_dst_1_forward" true + (To_test.test_dst_1_forward ()) + +let test_dst_2_backward () = + Alcotest.(check bool) + "test_dst_2_backward" true + (To_test.test_dst_2_backward ()) + +let test_dst_2_ortho () = + Alcotest.(check bool) "test_dst_2_ortho" true (To_test.test_dst_2_ortho ()) + +let test_dst_2_forward () = + Alcotest.(check bool) + "test_dst_2_forward" true + (To_test.test_dst_2_forward ()) + +let test_dst_3_backward () = + Alcotest.(check bool) + "test_dst_3_backward" true + (To_test.test_dst_3_backward ()) + +let test_dst_3_ortho () = + Alcotest.(check bool) "test_dst_3_ortho" true (To_test.test_dst_3_ortho ()) + +let test_dst_3_forward () = + Alcotest.(check bool) + "test_dst_3_forward" true + (To_test.test_dst_3_forward ()) + +let test_dst_4_backward () = + Alcotest.(check bool) + "test_dst_4_backward" true + (To_test.test_dst_4_backward ()) + +let test_dst_4_ortho () = + Alcotest.(check bool) "test_dst_4_ortho" true (To_test.test_dst_4_ortho ()) + +let test_dst_4_forward () = + Alcotest.(check bool) + "test_dst_4_forward" true + (To_test.test_dst_4_forward ()) + +let test_dst_inverse_1 () = + Alcotest.(check bool) + "test_dst_inverse_1" true + (To_test.test_dst_inverse_1 ()) + +let test_dst_inverse_2 () = + Alcotest.(check bool) + "test_dst_inverse_2" true + (To_test.test_dst_inverse_2 ()) + +let test_dst_inverse_3 () = + Alcotest.(check bool) + "test_dst_inverse_3" true + (To_test.test_dst_inverse_3 ()) + +let test_dst_inverse_4 () = + Alcotest.(check bool) + "test_dst_inverse_4" true + (To_test.test_dst_inverse_4 ()) + +let test_set = + [ ("test_fft_8_backward", `Slow, test_fft_8_backward) + ; ("test_fft_8_ortho", `Slow, test_fft_8_ortho) + ; ("test_fft_8_forward", `Slow, test_fft_8_forward) + ; ("test_fft_4x4_backward", `Slow, test_fft_4x4_backward) + ; ("test_fft_4x4_ortho", `Slow, test_fft_4x4_ortho) + ; ("test_fft_4x4_forward", `Slow, test_fft_4x4_forward) + ; ("test_fft_3x4_backward", `Slow, test_fft_3x4_backward) + ; ("test_fft_3x4_ortho", `Slow, test_fft_3x4_ortho) + ; ("test_fft_3x4_forward", `Slow, test_fft_3x4_forward) + ; ("test_fft_inverse", `Slow, test_fft_inverse) + ; ("test_rfft_8_backward", `Slow, test_rfft_8_backward) + ; ("test_rfft_8_ortho", `Slow, test_rfft_8_ortho) + ; ("test_rfft_8_forward", `Slow, test_rfft_8_forward) + ; ("test_rfft_4x4_backward", `Slow, test_rfft_4x4_backward) + ; ("test_rfft_4x4_ortho", `Slow, test_rfft_4x4_ortho) + ; ("test_rfft_4x4_forward", `Slow, test_rfft_4x4_forward) + ; ("test_rfft_3x4_backward", `Slow, test_rfft_3x4_backward) + ; ("test_rfft_3x4_ortho", `Slow, test_rfft_3x4_ortho) + ; ("test_rfft_3x4_forward", `Slow, test_rfft_3x4_forward) + ; ("test_rfft_inverse", `Slow, test_rfft_inverse) + ; ("test_dct_1_backward", `Slow, test_dct_1_backward) + ; ("test_dct_1_ortho", `Slow, test_dct_1_ortho) + ; ("test_dct_1_forward", `Slow, test_dct_1_forward) + ; ("test_dct_2_backward", `Slow, test_dct_2_backward) + ; ("test_dct_2_ortho", `Slow, test_dct_2_ortho) + ; ("test_dct_2_forward", `Slow, test_dct_2_forward) + ; ("test_dct_3_backward", `Slow, test_dct_3_backward) + ; ("test_dct_3_ortho", `Slow, test_dct_3_ortho) + ; ("test_dct_3_forward", `Slow, test_dct_3_forward) + ; ("test_dct_4_backward", `Slow, test_dct_4_backward) + ; ("test_dct_4_ortho", `Slow, test_dct_4_ortho) + ; ("test_dct_4_forward", `Slow, test_dct_4_forward) + ; ("test_dct_inverse_1", `Slow, test_dct_inverse_1) + ; ("test_dct_inverse_2", `Slow, test_dct_inverse_2) + ; ("test_dct_inverse_3", `Slow, test_dct_inverse_3) + ; ("test_dct_inverse_4", `Slow, test_dct_inverse_4) + ; ("test_dst_1_backward", `Slow, test_dst_1_backward) + ; ("test_dst_1_ortho", `Slow, test_dst_1_ortho) + ; ("test_dst_1_forward", `Slow, test_dst_1_forward) + ; ("test_dst_2_backward", `Slow, test_dst_2_backward) + ; ("test_dst_2_ortho", `Slow, test_dst_2_ortho) + ; ("test_dst_2_forward", `Slow, test_dst_2_forward) + ; ("test_dst_3_backward", `Slow, test_dst_3_backward) + ; ("test_dst_3_ortho", `Slow, test_dst_3_ortho) + ; ("test_dst_3_forward", `Slow, test_dst_3_forward) + ; ("test_dst_4_backward", `Slow, test_dst_4_backward) + ; ("test_dst_4_ortho", `Slow, test_dst_4_ortho) + ; ("test_dst_4_forward", `Slow, test_dst_4_forward) + ; ("test_dst_inverse_1", `Slow, test_dst_inverse_1) + ; ("test_dst_inverse_2", `Slow, test_dst_inverse_2) + ; ("test_dst_inverse_3", `Slow, test_dst_inverse_3) + ; ("test_dst_inverse_4", `Slow, test_dst_inverse_4) ] From 9fd7bb0e3909db570ede22401bcfbd093264229a Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Mon, 18 Nov 2024 00:08:26 +0100 Subject: [PATCH 05/12] Remove pocketfft submodule from .gitmodules --- .gitmodules | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index 775cc6982..e69de29bb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "src/owl/fftpack/pocketfft"] - path = src/owl/fftpack/pocketfft - url = https://github.com/mreineck/pocketfft From c05f7dc3632d1bd8af94fed67ab38924a2f87cbc Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Mon, 18 Nov 2024 00:09:27 +0100 Subject: [PATCH 06/12] Convert pocketfft submodule to regular directory --- src/owl/fftpack/pocketfft | 1 - src/owl/fftpack/pocketfft/LICENSE.md | 25 + src/owl/fftpack/pocketfft/README.md | 246 ++ src/owl/fftpack/pocketfft/pocketfft_demo.cc | 87 + src/owl/fftpack/pocketfft/pocketfft_hdronly.h | 3743 +++++++++++++++++ 5 files changed, 4101 insertions(+), 1 deletion(-) delete mode 160000 src/owl/fftpack/pocketfft create mode 100644 src/owl/fftpack/pocketfft/LICENSE.md create mode 100644 src/owl/fftpack/pocketfft/README.md create mode 100644 src/owl/fftpack/pocketfft/pocketfft_demo.cc create mode 100644 src/owl/fftpack/pocketfft/pocketfft_hdronly.h diff --git a/src/owl/fftpack/pocketfft b/src/owl/fftpack/pocketfft deleted file mode 160000 index bb87ca50d..000000000 --- a/src/owl/fftpack/pocketfft +++ /dev/null @@ -1 +0,0 @@ -Subproject commit bb87ca50df0478415a12d9011dc374eeed4e9d93 diff --git a/src/owl/fftpack/pocketfft/LICENSE.md b/src/owl/fftpack/pocketfft/LICENSE.md new file mode 100644 index 000000000..c3a4c06a9 --- /dev/null +++ b/src/owl/fftpack/pocketfft/LICENSE.md @@ -0,0 +1,25 @@ +Copyright (C) 2010-2018 Max-Planck-Society +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. +* Neither the name of the copyright holder nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/src/owl/fftpack/pocketfft/README.md b/src/owl/fftpack/pocketfft/README.md new file mode 100644 index 000000000..3792e0617 --- /dev/null +++ b/src/owl/fftpack/pocketfft/README.md @@ -0,0 +1,246 @@ +PocketFFT for C++ +================= + +This is a heavily modified implementation of FFTPack [1,2], with the following +advantages: + +- Strictly C++11 compliant +- More accurate twiddle factor computation +- Worst case complexity for transform sizes with large prime factors is + `N*log(N)`, because Bluestein's algorithm [3] is used for these cases. +- Supports multidimensional arrays and selection of the axes to be transformed. +- Supports `float`, `double`, and `long double` types. +- Supports fully complex and half-complex (i.e. complex-to-real and + real-to-complex) FFTs. For half-complex transforms, several conventions for + representing the complex-valued side are supported (reduced-size complex + array, FFTPACK-style half-complex format and Hartley transform). +- Supports discrete cosine and sine transforms (Types I-IV) +- Makes use of CPU vector instructions when performing 2D and higher-dimensional + transforms, if they are available. +- Has a small internal cache for transform plans, which speeds up repeated + transforms of the same length (most significant for 1D transforms). +- Has optional multi-threading support for multidimensional transforms + + +License +------- + +3-clause BSD (see LICENSE.md) + + +Some code details +----------------- + +Twiddle factor computation: + +- making use of symmetries to reduce number of sin/cos evaluations +- all angles are reduced to the range `[0; pi/4]` for higher accuracy +- if `n` sin/cos pairs are required, the trigonometric functions are only called + `2*sqrt(n)` times; the remaining values are obtained by evaluating the + angle addition theorems in a numerically accurate way. + +Efficient codelets are available for the factors: + +- 2, 3, 4, 5, 7, 8, 11 for complex-valued FFTs +- 2, 3, 4, 5 for real-valued FFTs + +Larger prime factors are handled by somewhat less efficient, generic routines. + +For lengths with very large prime factors, Bluestein's algorithm is used, and +instead of an FFT of length `n`, a convolution of length `n2 >= 2*n-1` +is performed, where `n2` is chosen to be highly composite. + + +[1] Swarztrauber, P. 1982, Vectorizing the Fast Fourier Transforms + (New York: Academic Press), 51 + +[2] https://www.netlib.org/fftpack/ + +[3] https://en.wikipedia.org/wiki/Chirp_Z-transform + + +Configuration options +===================== + +Since this is a header-only library, it can only be configured via preprocessor +macros. + +POCKETFFT_CACHE_SIZE:\ +if 0, disable all caching of FFT plans, else use an LRU cache with the +requested size. If undefined, assume a cache size of 0.\ +NOTE: caching is disabled by default because its benefits are only really +noticeable for short 1D transforms. When using caching with transforms that +have very large axis lengths, it may use up a lot of memory, so only switch this +on if you know you really need it! +Default: undefined + +POCKETFFT_NO_VECTORS:\ +if defined, disable all support for CPU vector instructions.\ +Default: undefined + +POCKETFFT_NO_MULTITHREADING:\ +if defined, multi-threading will be disabled.\ +Default: undefined + + +Programming interface +===================== + +All symbols are encapsulated in the namespace `pocketfft`. + +Arguments +--------- + - `shape[_*]` contains the number of array entries along each axis. + For `c2c` and `r2r` transforms, `shape` is identical for input and output + arrays. For `r2c` transforms the shape of the input array must be specified, + while for `c2r` transforms the shape of the *output* array must be given. + + - `stride_*` describes array strides, i.e. the memory distance (in bytes) + between two neighboring array entries along an axis. + + - `axes` is a vector of nonnegative integers, describing the axes along + which a transform is to be carried out. The order of axes usually does not + matter, except for `r2c` and `c2r` transforms, where the last entry of + `axes` is treated specially. + + - `forward` describes the direction of a transform. Generally a forward + transform has a minus sign in the complex exponent, while the backward + transform has a positive one. Instead if `true`/`false`, the symbolic + constants `FORWARD`/`BACKWARD` can be used. + NOTE: Unlike many other libraries, pocketfft also allows a `forward` argument + in `r2c` and `c2r` transforms, instead of having hard-wired forward `r2c` and + backward `c2r` transforms. Calling `r2c` with `forward=false`, for + example, performs a transform from purely real data in the frequency domain + to Hermitian data in the position domain. + If you want the "traditional" behavior, call `r2c` with `forward=true` and + `c2r` with `forward=false`. + + - `fct` is a floating-point value which is used to scale the result of a + transform. `pocketfft`'s transforms are not normalized, so if normalization + is required, an appropriate scaling factor has to be specified. + + - `data_in` and `data_out` are pointers to the first element of the input + and output data arrays. + + - `nthreads` is a nonnegative integer specifying the number of threads to use + for the operation. A value of 0 means that the number of logical CPU cores + will be used. + This value is only a recommendation. If `pocketfft` is compiled without + multi-threading support, it will be silently ignored. For one-dimensional + transforms, multi-threading is disabled as well. + +General constraints on arguments +-------------------------------- + - `shape[_*]`, `stride_in` and `stride_out` must have the same `size()` + and must not be empty. + - Entries in `shape[_*]` must be >=1. + - If `data_in==data_out`, `stride_in` and `stride_out` must have identical + content. These in-place transforms are fine for `c2c` and `r2r`, but not for + `r2c/c2r`. + - Axes are numbered from 0 to `shape.size()-1`, inclusively. + - Strides are measured in bytes, to allow maximum flexibility. Negative strides + are fine. Strides that lead to multiple accesses of the same memory address + are not allowed. + - All memory addresses resulting from a combination of data pointers and + strides must have sufficient alignment. On x86 CPUs, badly aligned adresses + will only lead to slower execution, but on some other hardware, misaligned + memory accesses will cause a crash. + - The same axis must not be specified more than once in an `axes` argument. + - For `r2c` and `c2r` transforms: the length of the complex array along `axis` + (or the last entry in `axes`) is assumed to be `s/2 + 1`, where `s` is the + length of the corresponding axis of the real array. + +Detailed public interface +------------------------- + +``` +using shape_t = std::vector; +using stride_t = std::vector; + +constexpr bool FORWARD = true, + BACKWARD = false; + +template void c2c(const shape_t &shape, const stride_t &stride_in, + const stride_t &stride_out, const shape_t &axes, bool forward, + const complex *data_in, complex *data_out, T fct, + size_t nthreads=1) + +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const T *data_in, complex *data_out, T fct, + size_t nthreads=1) + +/* This function first carries out an r2c transform along the last axis in axes, + storing the result in data_out. Then, an in-place c2c transform + is carried out in data_out along all other axes. */ +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool forward, const T *data_in, complex *data_out, T fct, + size_t nthreads=1) + +template void c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const complex *data_in, T *data_out, T fct, + size_t nthreads=1) + +/* This function first carries out a c2c transform along all axes except the + last one, storing the result into a temporary array. Then, a c2r transform + is carried out along the last axis, storing the result in data_out. */ +template void c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool forward, const complex *data_in, T *data_out, T fct, + size_t nthreads=1) + +/* This function carries out a FFTPACK-style real-to-halfcomplex or + halfcomplex-to-real transform (depending on the parameter `real2hermitian`) + on all specified axes in the given order. + NOTE: interpreting the result of this function can be complicated when + transforming more than one axis! */ +template void r2r_fftpack(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool real2hermitian, bool forward, const T *data_in, T *data_out, T fct, + size_t nthreads=1) + +/* For every requested axis, this function carries out a forward Fourier + transform, and the real and imaginary parts of the result are added before + the next axis is processed. + This is analogous to FFTW's implementation of the Hartley transform. */ +template void r2r_separable_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1); + +/* This function carries out a full Fourier transform over the requested axes, + and the sum of real and imaginary parts of the result is stored in the output + array. For a single transformed axis, this is identical to + `r2r_separable_hartley`, but when transforming multiple axes, the results + are different. + + NOTE: This function allocates temporary working space with a size + comparable to the input array. */ +template void r2r_genuine_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1); + +/* if ortho==true, the transform is made orthogonal by these additional steps + in every 1D sub-transform: + Type 1 : multiply first and last input value by sqrt(2) + divide first and last output value by sqrt(2) + Type 2 : divide first output value by sqrt(2) + Type 3 : multiply first input value by sqrt(2) + Type 4 : nothing */ +template void dct(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, + size_t nthreads=1); + +/* if ortho==true, the transform is made orthogonal by these additional steps + in every 1D sub-transform: + Type 1 : nothing + Type 2 : divide last output value by sqrt(2) + Type 3 : multiply last input value by sqrt(2) + Type 4 : nothing */ +template void dst(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, + size_t nthreads=1); +``` diff --git a/src/owl/fftpack/pocketfft/pocketfft_demo.cc b/src/owl/fftpack/pocketfft/pocketfft_demo.cc new file mode 100644 index 000000000..38fb6dcbe --- /dev/null +++ b/src/owl/fftpack/pocketfft/pocketfft_demo.cc @@ -0,0 +1,87 @@ +#include +#include +#include +#include +#include "pocketfft_hdronly.h" + +using namespace std; +using namespace pocketfft; + +// floating point RNG which is good enough for simple demos +// Do not use for anything important! +inline double simple_drand() + { + constexpr double norm = 1./RAND_MAX; + return rand()*norm; + } + +template void crand(vector> &v) + { + for (auto & i:v) + i = complex(simple_drand()-0.5, simple_drand()-0.5); + } + +template long double l2err + (const vector &v1, const vector &v2) + { + long double sum1=0, sum2=0; + for (size_t i=0; i), + tmpd=sizeof(complex), + tmpl=sizeof(complex); + for (int i=shape.size()-1; i>=0; --i) + { + stridef[i]=tmpf; + tmpf*=shape[i]; + strided[i]=tmpd; + tmpd*=shape[i]; + stridel[i]=tmpl; + tmpl*=shape[i]; + } + size_t ndata=1; + for (size_t i=0; i> dataf(ndata); + vector> datad(ndata); + vector> datal(ndata); + crand(dataf); + for (size_t i=0; i= 201103L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201103L)) +#error This file requires at least C++11 support. +#endif + +#ifndef POCKETFFT_CACHE_SIZE +#define POCKETFFT_CACHE_SIZE 0 +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if POCKETFFT_CACHE_SIZE!=0 +#include +#include +#endif + +#ifndef POCKETFFT_NO_MULTITHREADING +#include +#include +#include +#include +#include +#include +#include + +#ifdef POCKETFFT_PTHREADS +# include +#endif +#endif + +#if defined(__GNUC__) +#define POCKETFFT_NOINLINE __attribute__((noinline)) +#define POCKETFFT_RESTRICT __restrict__ +#elif defined(_MSC_VER) +#define POCKETFFT_NOINLINE __declspec(noinline) +#define POCKETFFT_RESTRICT __restrict +#else +#define POCKETFFT_NOINLINE +#define POCKETFFT_RESTRICT +#endif + +namespace pocketfft { + +namespace detail { +using std::size_t; +using std::ptrdiff_t; + +// Always use std:: for functions +template T cos(T) = delete; +template T sin(T) = delete; +template T sqrt(T) = delete; + +using shape_t = std::vector; +using stride_t = std::vector; + +constexpr bool FORWARD = true, + BACKWARD = false; + +// only enable vector support for gcc>=5.0 and clang>=5.0 +#ifndef POCKETFFT_NO_VECTORS +#define POCKETFFT_NO_VECTORS +#if defined(__INTEL_COMPILER) +// do nothing. This is necessary because this compiler also sets __GNUC__. +#elif defined(__clang__) +// AppleClang has their own version numbering +#ifdef __apple_build_version__ +# if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1) +# undef POCKETFFT_NO_VECTORS +# endif +#elif __clang_major__ >= 5 +# undef POCKETFFT_NO_VECTORS +#endif +#elif defined(__GNUC__) +#if __GNUC__>=5 +#undef POCKETFFT_NO_VECTORS +#endif +#endif +#endif + +template struct VLEN { static constexpr size_t val=1; }; + +#ifndef POCKETFFT_NO_VECTORS +#if (defined(__AVX512F__)) +template<> struct VLEN { static constexpr size_t val=16; }; +template<> struct VLEN { static constexpr size_t val=8; }; +#elif (defined(__AVX__)) +template<> struct VLEN { static constexpr size_t val=8; }; +template<> struct VLEN { static constexpr size_t val=4; }; +#elif (defined(__SSE2__)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#elif (defined(__VSX__)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#else +#define POCKETFFT_NO_VECTORS +#endif +#endif + +// std::aligned_alloc is a bit cursed ... it doesn't exist on MacOS < 10.15 +// and in musl, and other OSes seem to have even more peculiarities. +// Let's unconditionally work around it for now. +# if 0 +//#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER)) && (__MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_15) +inline void *aligned_alloc(size_t align, size_t size) + { + // aligned_alloc() requires that the requested size is a multiple of "align" + void *ptr = ::aligned_alloc(align,(size+align-1)&(~(align-1))); + if (!ptr) throw std::bad_alloc(); + return ptr; + } +inline void aligned_dealloc(void *ptr) + { free(ptr); } +#else // portable emulation +inline void *aligned_alloc(size_t align, size_t size) + { + align = std::max(align, alignof(max_align_t)); + void *ptr = malloc(size+align); + if (!ptr) throw std::bad_alloc(); + void *res = reinterpret_cast + ((reinterpret_cast(ptr) & ~(uintptr_t(align-1))) + uintptr_t(align)); + (reinterpret_cast(res))[-1] = ptr; + return res; + } +inline void aligned_dealloc(void *ptr) + { if (ptr) free((reinterpret_cast(ptr))[-1]); } +#endif + +template class arr + { + private: + T *p; + size_t sz; + +#if defined(POCKETFFT_NO_VECTORS) + static T *ralloc(size_t num) + { + if (num==0) return nullptr; + void *res = malloc(num*sizeof(T)); + if (!res) throw std::bad_alloc(); + return reinterpret_cast(res); + } + static void dealloc(T *ptr) + { free(ptr); } +#else + static T *ralloc(size_t num) + { + if (num==0) return nullptr; + void *ptr = aligned_alloc(64, num*sizeof(T)); + return static_cast(ptr); + } + static void dealloc(T *ptr) + { aligned_dealloc(ptr); } +#endif + + public: + arr() : p(0), sz(0) {} + arr(size_t n) : p(ralloc(n)), sz(n) {} + arr(arr &&other) + : p(other.p), sz(other.sz) + { other.p=nullptr; other.sz=0; } + ~arr() { dealloc(p); } + + void resize(size_t n) + { + if (n==sz) return; + dealloc(p); + p = ralloc(n); + sz = n; + } + + T &operator[](size_t idx) { return p[idx]; } + const T &operator[](size_t idx) const { return p[idx]; } + + T *data() { return p; } + const T *data() const { return p; } + + size_t size() const { return sz; } + }; + +template struct cmplx { + T r, i; + cmplx() {} + cmplx(T r_, T i_) : r(r_), i(i_) {} + void Set(T r_, T i_) { r=r_; i=i_; } + void Set(T r_) { r=r_; i=T(0); } + cmplx &operator+= (const cmplx &other) + { r+=other.r; i+=other.i; return *this; } + templatecmplx &operator*= (T2 other) + { r*=other; i*=other; return *this; } + templatecmplx &operator*= (const cmplx &other) + { + T tmp = r*other.r - i*other.i; + i = r*other.i + i*other.r; + r = tmp; + return *this; + } + templatecmplx &operator+= (const cmplx &other) + { r+=other.r; i+=other.i; return *this; } + templatecmplx &operator-= (const cmplx &other) + { r-=other.r; i-=other.i; return *this; } + template auto operator* (const T2 &other) const + -> cmplx + { return {r*other, i*other}; } + template auto operator+ (const cmplx &other) const + -> cmplx + { return {r+other.r, i+other.i}; } + template auto operator- (const cmplx &other) const + -> cmplx + { return {r-other.r, i-other.i}; } + template auto operator* (const cmplx &other) const + -> cmplx + { return {r*other.r-i*other.i, r*other.i + i*other.r}; } + template auto special_mul (const cmplx &other) const + -> cmplx + { + using Tres = cmplx; + return fwd ? Tres(r*other.r+i*other.i, i*other.r-r*other.i) + : Tres(r*other.r-i*other.i, r*other.i+i*other.r); + } +}; +template inline void PM(T &a, T &b, T c, T d) + { a=c+d; b=c-d; } +template inline void PMINPLACE(T &a, T &b) + { T t = a; a+=b; b=t-b; } +template inline void MPINPLACE(T &a, T &b) + { T t = a; a-=b; b=t+b; } +template cmplx conj(const cmplx &a) + { return {a.r, -a.i}; } +template void special_mul (const cmplx &v1, const cmplx &v2, cmplx &res) + { + res = fwd ? cmplx(v1.r*v2.r+v1.i*v2.i, v1.i*v2.r-v1.r*v2.i) + : cmplx(v1.r*v2.r-v1.i*v2.i, v1.r*v2.i+v1.i*v2.r); + } + +template void ROT90(cmplx &a) + { auto tmp_=a.r; a.r=-a.i; a.i=tmp_; } +template void ROTX90(cmplx &a) + { auto tmp_= fwd ? -a.r : a.r; a.r = fwd ? a.i : -a.i; a.i=tmp_; } + +// +// twiddle factor section +// +template class sincos_2pibyn + { + private: + using Thigh = typename std::conditional<(sizeof(T)>sizeof(double)), T, double>::type; + size_t N, mask, shift; + arr> v1, v2; + + static cmplx calc(size_t x, size_t n, Thigh ang) + { + x<<=3; + if (x<4*n) // first half + { + if (x<2*n) // first quadrant + { + if (x(std::cos(Thigh(x)*ang), std::sin(Thigh(x)*ang)); + return cmplx(std::sin(Thigh(2*n-x)*ang), std::cos(Thigh(2*n-x)*ang)); + } + else // second quadrant + { + x-=2*n; + if (x(-std::sin(Thigh(x)*ang), std::cos(Thigh(x)*ang)); + return cmplx(-std::cos(Thigh(2*n-x)*ang), std::sin(Thigh(2*n-x)*ang)); + } + } + else + { + x=8*n-x; + if (x<2*n) // third quadrant + { + if (x(std::cos(Thigh(x)*ang), -std::sin(Thigh(x)*ang)); + return cmplx(std::sin(Thigh(2*n-x)*ang), -std::cos(Thigh(2*n-x)*ang)); + } + else // fourth quadrant + { + x-=2*n; + if (x(-std::sin(Thigh(x)*ang), -std::cos(Thigh(x)*ang)); + return cmplx(-std::cos(Thigh(2*n-x)*ang), -std::sin(Thigh(2*n-x)*ang)); + } + } + } + + public: + POCKETFFT_NOINLINE sincos_2pibyn(size_t n) + : N(n) + { + constexpr auto pi = 3.141592653589793238462643383279502884197L; + Thigh ang = Thigh(0.25L*pi/n); + size_t nval = (n+2)/2; + shift = 1; + while((size_t(1)< operator[](size_t idx) const + { + if (2*idx<=N) + { + auto x1=v1[idx&mask], x2=v2[idx>>shift]; + return cmplx(T(x1.r*x2.r-x1.i*x2.i), T(x1.r*x2.i+x1.i*x2.r)); + } + idx = N-idx; + auto x1=v1[idx&mask], x2=v2[idx>>shift]; + return cmplx(T(x1.r*x2.r-x1.i*x2.i), -T(x1.r*x2.i+x1.i*x2.r)); + } + }; + +struct util // hack to avoid duplicate symbols + { + static POCKETFFT_NOINLINE size_t largest_prime_factor (size_t n) + { + size_t res=1; + while ((n&1)==0) + { res=2; n>>=1; } + for (size_t x=3; x*x<=n; x+=2) + while ((n%x)==0) + { res=x; n/=x; } + if (n>1) res=n; + return res; + } + + static POCKETFFT_NOINLINE double cost_guess (size_t n) + { + constexpr double lfp=1.1; // penalty for non-hardcoded larger factors + size_t ni=n; + double result=0.; + while ((n&1)==0) + { result+=2; n>>=1; } + for (size_t x=3; x*x<=n; x+=2) + while ((n%x)==0) + { + result+= (x<=5) ? double(x) : lfp*double(x); // penalize larger prime factors + n/=x; + } + if (n>1) result+=(n<=5) ? double(n) : lfp*double(n); + return result*double(ni); + } + + /* inner workings of good_size_cmplx() */ + template + static POCKETFFT_NOINLINE UIntT good_size_cmplx_typed(UIntT n) + { + static_assert(std::numeric_limits::is_integer && (!std::numeric_limits::is_signed), + "type must be unsigned integer"); + if (n<=12) return n; + if (n>std::numeric_limits::max()/11/2) + { + // The algorithm below doesn't work for this value, the multiplication can overflow. + if (sizeof(UIntT)(n); + if (res<=std::numeric_limits::max()) + return static_cast(res); + } + // Otherwise, this size is ridiculously large, people shouldn't be computing FFTs this large. + throw std::runtime_error("FFT size is too large."); + } + + UIntT bestfac=2*n; + for (UIntT f11=1; f11n) + { + if (x>=1; + } + else + return n; + } + } + return bestfac; + } + /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n) + { + return good_size_cmplx_typed(n); + } + /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n + and a multiple of required_factor. */ + static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n, + size_t required_factor) + { + if (required_factor<1) + throw std::runtime_error("required factor must not be 0"); + return good_size_cmplx((n+required_factor-1)/required_factor) * required_factor; + } + + /* inner workings of good_size_real() */ + template + static POCKETFFT_NOINLINE UIntT good_size_real_typed(UIntT n) + { + static_assert(std::numeric_limits::is_integer && (!std::numeric_limits::is_signed), + "type must be unsigned integer"); + if (n<=6) return n; + if (n>std::numeric_limits::max()/5/2) + { + // The algorithm below doesn't work for this value, the multiplication can overflow. + if (sizeof(UIntT)(n); + if (res<=std::numeric_limits::max()) + return static_cast(res); + } + // Otherwise, this size is ridiculously large, people shouldn't be computing FFTs this large. + throw std::runtime_error("FFT size is too large."); + } + + UIntT bestfac=2*n; + for (UIntT f5=1; f5n) + { + if (x>=1; + } + else + return n; + } + } + return bestfac; + } + /* returns the smallest composite of 2, 3, 5 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_real(size_t n) + { + return good_size_real_typed(n); + } + /* returns the smallest composite of 2, 3, 5 which is >= n + and a multiple of required_factor. */ + static POCKETFFT_NOINLINE size_t good_size_real(size_t n, + size_t required_factor) + { + if (required_factor<1) + throw std::runtime_error("required factor must not be 0"); + return good_size_real((n+required_factor-1)/required_factor) * required_factor; + } + + /* inner workings of prev_good_size_cmplx() */ + template + static POCKETFFT_NOINLINE UIntT prev_good_size_cmplx_typed(UIntT n) + { + static_assert(std::numeric_limits::is_integer && (!std::numeric_limits::is_signed), + "type must be unsigned integer"); + if (n<=12) return n; + if (n>std::numeric_limits::max()/11) + { + // The algorithm below doesn't work for this value, the multiplication can overflow. + if (sizeof(UIntT)(n); + if (res<=std::numeric_limits::max()) + return static_cast(res); + } + // Otherwise, this size is ridiculously large, people shouldn't be computing FFTs this large. + throw std::runtime_error("FFT size is too large."); + } + + UIntT bestfound = 1; + for (UIntT f11 = 1;f11 <= n; f11 *= 11) + for (UIntT f117 = f11; f117 <= n; f117 *= 7) + for (UIntT f1175 = f117; f1175 <= n; f1175 *= 5) + { + UIntT x = f1175; + while (x*2 <= n) x *= 2; + if (x > bestfound) bestfound = x; + while (true) + { + if (x * 3 <= n) x *= 3; + else if (x % 2 == 0) x /= 2; + else break; + + if (x > bestfound) bestfound = x; + } + } + return bestfound; + } + /* returns the largest composite of 2, 3, 5, 7 and 11 which is <= n */ + static POCKETFFT_NOINLINE size_t prev_good_size_cmplx(size_t n) + { + return prev_good_size_cmplx_typed(n); + } + + /* inner workings of prev_good_size_real() */ + template + static POCKETFFT_NOINLINE UIntT prev_good_size_real_typed(UIntT n) + { + static_assert(std::numeric_limits::is_integer && (!std::numeric_limits::is_signed), + "type must be unsigned integer"); + if (n<=6) return n; + if (n>std::numeric_limits::max()/5) + { + // The algorithm below doesn't work for this value, the multiplication can overflow. + if (sizeof(UIntT)(n); + if (res<=std::numeric_limits::max()) + return static_cast(res); + } + // Otherwise, this size is ridiculously large, people shouldn't be computing FFTs this large. + throw std::runtime_error("FFT size is too large."); + } + + UIntT bestfound = 1; + for (UIntT f5 = 1; f5 <= n; f5 *= 5) + { + UIntT x = f5; + while (x*2 <= n) x *= 2; + if (x > bestfound) bestfound = x; + while (true) + { + if (x * 3 <= n) x *= 3; + else if (x % 2 == 0) x /= 2; + else break; + + if (x > bestfound) bestfound = x; + } + } + return bestfound; + } + /* returns the largest composite of 2, 3, 5 which is <= n */ + static POCKETFFT_NOINLINE size_t prev_good_size_real(size_t n) + { + return prev_good_size_real_typed(n); + } + + static size_t prod(const shape_t &shape) + { + size_t res=1; + for (auto sz: shape) + res*=sz; + return res; + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace) + { + auto ndim = shape.size(); + if (ndim<1) throw std::runtime_error("ndim must be >= 1"); + if ((stride_in.size()!=ndim) || (stride_out.size()!=ndim)) + throw std::runtime_error("stride dimension mismatch"); + if (inplace && (stride_in!=stride_out)) + throw std::runtime_error("stride mismatch"); + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace, + const shape_t &axes) + { + sanity_check(shape, stride_in, stride_out, inplace); + auto ndim = shape.size(); + shape_t tmp(ndim,0); + for (auto ax : axes) + { + if (ax>=ndim) throw std::invalid_argument("bad axis number"); + if (++tmp[ax]>1) throw std::invalid_argument("axis specified repeatedly"); + } + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace, + size_t axis) + { + sanity_check(shape, stride_in, stride_out, inplace); + if (axis>=shape.size()) throw std::invalid_argument("bad axis number"); + } + +#ifdef POCKETFFT_NO_MULTITHREADING + static size_t thread_count (size_t /*nthreads*/, const shape_t &/*shape*/, + size_t /*axis*/, size_t /*vlen*/) + { return 1; } +#else + static size_t thread_count (size_t nthreads, const shape_t &shape, + size_t axis, size_t vlen) + { + if (nthreads==1) return 1; + size_t size = prod(shape); + size_t parallel = size / (shape[axis] * vlen); + if (shape[axis] < 1000) + parallel /= 4; + size_t max_threads = nthreads == 0 ? + std::thread::hardware_concurrency() : nthreads; + return std::max(size_t(1), std::min(parallel, max_threads)); + } +#endif + }; + +namespace threading { + +#ifdef POCKETFFT_NO_MULTITHREADING + +constexpr inline size_t thread_id() { return 0; } +constexpr inline size_t num_threads() { return 1; } + +template +void thread_map(size_t /* nthreads */, Func f) + { f(); } + +#else + +inline size_t &thread_id() + { + static thread_local size_t thread_id_=0; + return thread_id_; + } +inline size_t &num_threads() + { + static thread_local size_t num_threads_=1; + return num_threads_; + } +static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency()); + +class latch + { + std::atomic num_left_; + std::mutex mut_; + std::condition_variable completed_; + using lock_t = std::unique_lock; + + public: + latch(size_t n): num_left_(n) {} + + void count_down() + { + lock_t lock(mut_); + if (--num_left_) + return; + completed_.notify_all(); + } + + void wait() + { + lock_t lock(mut_); + completed_.wait(lock, [this]{ return is_ready(); }); + } + bool is_ready() { return num_left_ == 0; } + }; + +template class concurrent_queue + { + std::queue q_; + std::mutex mut_; + std::atomic size_; + using lock_t = std::lock_guard; + + public: + + void push(T val) + { + lock_t lock(mut_); + ++size_; + q_.push(std::move(val)); + } + + bool try_pop(T &val) + { + if (size_ == 0) return false; + lock_t lock(mut_); + // Queue might have been emptied while we acquired the lock + if (q_.empty()) return false; + + val = std::move(q_.front()); + --size_; + q_.pop(); + return true; + } + + bool empty() const { return size_==0; } + }; + +// C++ allocator with support for over-aligned types +template struct aligned_allocator + { + using value_type = T; + template + aligned_allocator(const aligned_allocator&) {} + aligned_allocator() = default; + + T *allocate(size_t n) + { + void* mem = aligned_alloc(alignof(T), n*sizeof(T)); + return static_cast(mem); + } + + void deallocate(T *p, size_t /*n*/) + { aligned_dealloc(p); } + }; + +class thread_pool + { + // A reasonable guess, probably close enough for most hardware + static constexpr size_t cache_line_size = 64; + struct alignas(cache_line_size) worker + { + std::thread thread; + std::condition_variable work_ready; + std::mutex mut; + std::atomic_flag busy_flag = ATOMIC_FLAG_INIT; + std::function work; + + void worker_main( + std::atomic &shutdown_flag, + std::atomic &unscheduled_tasks, + concurrent_queue> &overflow_work) + { + using lock_t = std::unique_lock; + bool expect_work = true; + while (!shutdown_flag || expect_work) + { + std::function local_work; + if (expect_work || unscheduled_tasks == 0) + { + lock_t lock(mut); + // Wait until there is work to be executed + work_ready.wait(lock, [&]{ return (work || shutdown_flag); }); + local_work.swap(work); + expect_work = false; + } + + bool marked_busy = false; + if (local_work) + { + marked_busy = true; + local_work(); + } + + if (!overflow_work.empty()) + { + if (!marked_busy && busy_flag.test_and_set()) + { + expect_work = true; + continue; + } + marked_busy = true; + + while (overflow_work.try_pop(local_work)) + { + --unscheduled_tasks; + local_work(); + } + } + + if (marked_busy) busy_flag.clear(); + } + } + }; + + concurrent_queue> overflow_work_; + std::mutex mut_; + std::vector> workers_; + std::atomic shutdown_; + std::atomic unscheduled_tasks_; + using lock_t = std::lock_guard; + + void create_threads() + { + lock_t lock(mut_); + size_t nthreads=workers_.size(); + for (size_t i=0; ibusy_flag.clear(); + worker->work = nullptr; + worker->thread = std::thread([worker, this] + { + worker->worker_main(shutdown_, unscheduled_tasks_, overflow_work_); + }); + } + catch (...) + { + shutdown_locked(); + throw; + } + } + } + + void shutdown_locked() + { + shutdown_ = true; + for (auto &worker : workers_) + worker.work_ready.notify_all(); + + for (auto &worker : workers_) + if (worker.thread.joinable()) + worker.thread.join(); + } + + public: + explicit thread_pool(size_t nthreads): + workers_(nthreads) + { create_threads(); } + + thread_pool(): thread_pool(max_threads) {} + + ~thread_pool() { shutdown(); } + + void submit(std::function work) + { + lock_t lock(mut_); + if (shutdown_) + throw std::runtime_error("Work item submitted after shutdown"); + + ++unscheduled_tasks_; + + // First check for any idle workers and wake those + for (auto &worker : workers_) + if (!worker.busy_flag.test_and_set()) + { + --unscheduled_tasks_; + { + lock_t lock(worker.mut); + worker.work = std::move(work); + } + worker.work_ready.notify_one(); + return; + } + + // If no workers were idle, push onto the overflow queue for later + overflow_work_.push(std::move(work)); + } + + void shutdown() + { + lock_t lock(mut_); + shutdown_locked(); + } + + void restart() + { + shutdown_ = false; + create_threads(); + } + }; + +inline thread_pool & get_pool() + { + static thread_pool pool; +#ifdef POCKETFFT_PTHREADS + static std::once_flag f; + std::call_once(f, + []{ + pthread_atfork( + +[]{ get_pool().shutdown(); }, // prepare + +[]{ get_pool().restart(); }, // parent + +[]{ get_pool().restart(); } // child + ); + }); +#endif + + return pool; + } + +/** Map a function f over nthreads */ +template +void thread_map(size_t nthreads, Func f) + { + if (nthreads == 0) + nthreads = max_threads; + + if (nthreads == 1) + { f(); return; } + + auto & pool = get_pool(); + latch counter(nthreads); + std::exception_ptr ex; + std::mutex ex_mut; + for (size_t i=0; i lock(ex_mut); + ex = std::current_exception(); + } + counter.count_down(); + }); + } + counter.wait(); + if (ex) + std::rethrow_exception(ex); + } + +#endif + +} + +// +// complex FFTPACK transforms +// + +template class cfftp + { + private: + struct fctdata + { + size_t fct; + cmplx *tw, *tws; + }; + + size_t length; + arr> mem; + std::vector fact; + + void add_factor(size_t factor) + { fact.push_back({factor, nullptr, nullptr}); } + +template void pass2 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+2*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(CC(i,0,k)-CC(i,1,k),WA(0,i),CH(i,k,1)); + } + } + } + +#define POCKETFFT_PREP3(idx) \ + T t0 = CC(idx,0,k), t1, t2; \ + PM (t1,t2,CC(idx,1,k),CC(idx,2,k)); \ + CH(idx,k,0)=t0+t1; +#define POCKETFFT_PARTSTEP3a(u1,u2,twr,twi) \ + { \ + T ca=t0+t1*twr; \ + T cb{-t2.i*twi, t2.r*twi}; \ + PM(CH(0,k,u1),CH(0,k,u2),ca,cb) ;\ + } +#define POCKETFFT_PARTSTEP3b(u1,u2,twr,twi) \ + { \ + T ca=t0+t1*twr; \ + T cb{-t2.i*twi, t2.r*twi}; \ + special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ + } +template void pass3 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r=-0.5, + tw1i= (fwd ? -1: 1) * T0(0.8660254037844386467637231707529362L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+3*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void pass4 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+4*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(t4); + PM(CH(0,k,0),CH(0,k,2),t2,t3); + PM(CH(0,k,1),CH(0,k,3),t1,t4); + } + else + for (size_t k=0; k(t4); + PM(CH(0,k,0),CH(0,k,2),t2,t3); + PM(CH(0,k,1),CH(0,k,3),t1,t4); + } + for (size_t i=1; i(t4); + CH(i,k,0) = t2+t3; + special_mul(t1+t4,WA(0,i),CH(i,k,1)); + special_mul(t2-t3,WA(1,i),CH(i,k,2)); + special_mul(t1-t4,WA(2,i),CH(i,k,3)); + } + } + } + +#define POCKETFFT_PREP5(idx) \ + T t0 = CC(idx,0,k), t1, t2, t3, t4; \ + PM (t1,t4,CC(idx,1,k),CC(idx,4,k)); \ + PM (t2,t3,CC(idx,2,k),CC(idx,3,k)); \ + CH(idx,k,0).r=t0.r+t1.r+t2.r; \ + CH(idx,k,0).i=t0.i+t1.i+t2.i; + +#define POCKETFFT_PARTSTEP5a(u1,u2,twar,twbr,twai,twbi) \ + { \ + T ca,cb; \ + ca.r=t0.r+twar*t1.r+twbr*t2.r; \ + ca.i=t0.i+twar*t1.i+twbr*t2.i; \ + cb.i=twai*t4.r twbi*t3.r; \ + cb.r=-(twai*t4.i twbi*t3.i); \ + PM(CH(0,k,u1),CH(0,k,u2),ca,cb); \ + } + +#define POCKETFFT_PARTSTEP5b(u1,u2,twar,twbr,twai,twbi) \ + { \ + T ca,cb,da,db; \ + ca.r=t0.r+twar*t1.r+twbr*t2.r; \ + ca.i=t0.i+twar*t1.i+twbr*t2.i; \ + cb.i=twai*t4.r twbi*t3.r; \ + cb.r=-(twai*t4.i twbi*t3.i); \ + special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ + } +template void pass5 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.3090169943749474241022934171828191L), + tw1i= (fwd ? -1: 1) * T0(0.9510565162951535721164393333793821L), + tw2r= T0(-0.8090169943749474241022934171828191L), + tw2i= (fwd ? -1: 1) * T0(0.5877852522924731291687059546390728L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+5*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(da,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ + } + +template void pass7(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.6234898018587335305250048840042398L), + tw1i= (fwd ? -1 : 1) * T0(0.7818314824680298087084445266740578L), + tw2r= T0(-0.2225209339563144042889025644967948L), + tw2i= (fwd ? -1 : 1) * T0(0.9749279121818236070181316829939312L), + tw3r= T0(-0.9009688679024191262361023195074451L), + tw3i= (fwd ? -1 : 1) * T0(0.433883739117558120475768332848359L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+7*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void ROTX45(T &a) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (fwd) + { auto tmp_=a.r; a.r=hsqt2*(a.r+a.i); a.i=hsqt2*(a.i-tmp_); } + else + { auto tmp_=a.r; a.r=hsqt2*(a.r-a.i); a.i=hsqt2*(a.i+tmp_); } + } +template void ROTX135(T &a) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (fwd) + { auto tmp_=a.r; a.r=hsqt2*(a.i-a.r); a.i=hsqt2*(-tmp_-a.i); } + else + { auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); } + } + +template void pass8 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+8*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(a3); + + ROTX90(a7); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0,a4,CC(0,0,k),CC(0,4,k)); + PM(a2,a6,CC(0,2,k),CC(0,6,k)); + PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); + PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); + ROTX90(a6); + PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); + PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); + } + else + for (size_t k=0; k(a3); + + ROTX90(a7); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0,a4,CC(0,0,k),CC(0,4,k)); + PM(a2,a6,CC(0,2,k),CC(0,6,k)); + PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); + PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); + ROTX90(a6); + PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); + PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); + } + for (size_t i=1; i(a7); + PMINPLACE(a1,a3); + ROTX90(a3); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + PM(a0,a4,CC(i,0,k),CC(i,4,k)); + PM(a2,a6,CC(i,2,k),CC(i,6,k)); + PMINPLACE(a0,a2); + CH(i,k,0) = a0+a1; + special_mul(a0-a1,WA(3,i),CH(i,k,4)); + special_mul(a2+a3,WA(1,i),CH(i,k,2)); + special_mul(a2-a3,WA(5,i),CH(i,k,6)); + ROTX90(a6); + PMINPLACE(a4,a6); + special_mul(a4+a5,WA(0,i),CH(i,k,1)); + special_mul(a4-a5,WA(4,i),CH(i,k,5)); + special_mul(a6+a7,WA(2,i),CH(i,k,3)); + special_mul(a6-a7,WA(6,i),CH(i,k,7)); + } + } + } + + +#define POCKETFFT_PREP11(idx) \ + T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \ + PM (t2,t11,CC(idx,1,k),CC(idx,10,k)); \ + PM (t3,t10,CC(idx,2,k),CC(idx, 9,k)); \ + PM (t4,t9 ,CC(idx,3,k),CC(idx, 8,k)); \ + PM (t5,t8 ,CC(idx,4,k),CC(idx, 7,k)); \ + PM (t6,t7 ,CC(idx,5,k),CC(idx, 6,k)); \ + CH(idx,k,0).r=t1.r+t2.r+t3.r+t4.r+t5.r+t6.r; \ + CH(idx,k,0).i=t1.i+t2.i+t3.i+t4.i+t5.i+t6.i; + +#define POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,out1,out2) \ + { \ + T ca = t1 + t2*x1 + t3*x2 + t4*x3 + t5*x4 +t6*x5, \ + cb; \ + cb.i=y1*t11.r y2*t10.r y3*t9.r y4*t8.r y5*t7.r; \ + cb.r=-(y1*t11.i y2*t10.i y3*t9.i y4*t8.i y5*t7.i ); \ + PM(out1,out2,ca,cb); \ + } +#define POCKETFFT_PARTSTEP11a(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ + POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,CH(0,k,u1),CH(0,k,u2)) +#define POCKETFFT_PARTSTEP11(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ + { \ + T da,db; \ + POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,da,db) \ + special_mul(da,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ + } + +template void pass11 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.8412535328311811688618116489193677L), + tw1i= (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L), + tw2r= T0(0.4154150130018864255292741492296232L), + tw2i= (fwd ? -1 : 1) * T0(0.9096319953545183714117153830790285L), + tw3r= T0(-0.1423148382732851404437926686163697L), + tw3i= (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L), + tw4r= T0(-0.6548607339452850640569250724662936L), + tw4i= (fwd ? -1 : 1) * T0(0.7557495743542582837740358439723444L), + tw5r= T0(-0.9594929736144973898903680570663277L), + tw5i= (fwd ? -1 : 1) * T0(0.2817325568414296977114179153466169L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+11*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void passg (size_t ido, size_t ip, + size_t l1, T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa, + const cmplx * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph = (ip+1)/2; + size_t idl1 = ido*l1; + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CX = [cc, ido, l1](size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+l1*c)]; }; + auto CX2 = [cc, idl1](size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch, idl1](size_t a, size_t b) -> const T& + { return ch[a+idl1*b]; }; + + arr> wal(ip); + wal[0] = cmplx(1., 0.); + for (size_t i=1; i(csarr[i].r,fwd ? -csarr[i].i : csarr[i].i); + + for (size_t k=0; kip) iwal-=ip; + cmplx xwal=wal[iwal]; + iwal+=l; if (iwal>ip) iwal-=ip; + cmplx xwal2=wal[iwal]; + for (size_t ik=0; ikip) iwal-=ip; + cmplx xwal=wal[iwal]; + for (size_t ik=0; ik(x1,wa[idij],CX(i,k,j)); + idij=(jc-1)*(ido-1)+i-1; + special_mul(x2,wa[idij],CX(i,k,jc)); + } + } + } + } + +template void pass_all(T c[], T0 fct) const + { + if (length==1) { c[0]*=fct; return; } + size_t l1=1; + arr ch(length); + T *p1=c, *p2=ch.data(); + + for(size_t k1=0; k1 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==8) + pass8(ido, l1, p1, p2, fact[k1].tw); + else if(ip==2) + pass2(ido, l1, p1, p2, fact[k1].tw); + else if(ip==3) + pass3 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==5) + pass5 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==7) + pass7 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==11) + pass11 (ido, l1, p1, p2, fact[k1].tw); + else + { + passg(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws); + std::swap(p1,p2); + } + std::swap(p1,p2); + l1=l2; + } + if (p1!=c) + { + if (fct!=1.) + for (size_t i=0; i void exec(T c[], T0 fct, bool fwd) const + { fwd ? pass_all(c, fct) : pass_all(c, fct); } + + private: + POCKETFFT_NOINLINE void factorize() + { + size_t len=length; + while ((len&7)==0) + { add_factor(8); len>>=3; } + while ((len&3)==0) + { add_factor(4); len>>=2; } + if ((len&1)==0) + { + len>>=1; + // factor 2 should be at the front of the factor list + add_factor(2); + std::swap(fact[0].fct, fact.back().fct); + } + for (size_t divisor=3; divisor*divisor<=len; divisor+=2) + while ((len%divisor)==0) + { + add_factor(divisor); + len/=divisor; + } + if (len>1) add_factor(len); + } + + size_t twsize() const + { + size_t twsize=0, l1=1; + for (size_t k=0; k11) + twsize+=ip; + l1*=ip; + } + return twsize; + } + + void comp_twiddle() + { + sincos_2pibyn twiddle(length); + size_t l1=1; + size_t memofs=0; + for (size_t k=0; k11) + { + fact[k].tws=mem.data()+memofs; + memofs+=ip; + for (size_t j=0; j class rfftp + { + private: + struct fctdata + { + size_t fct; + T0 *tw, *tws; + }; + + size_t length; + arr mem; + std::vector fact; + + void add_factor(size_t factor) + { fact.push_back({factor, nullptr, nullptr}); } + +/* (a+ib) = conj(c+id) * (e+if) */ +template inline void MULPM + (T1 &a, T1 &b, T2 c, T2 d, T3 e, T3 f) const + { a=c*e+d*f; b=c*f-d*e; } + +template void radf2 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+2*c)]; }; + + for (size_t k=0; k void radf3(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+3*c)]; }; + + for (size_t k=0; k void radf4(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+4*c)]; }; + + for (size_t k=0; k void radf5(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), + ti11= T0(0.9510565162951535721164393333793821L), + tr12= T0(-0.8090169943749474241022934171828191L), + ti12= T0(0.5877852522924731291687059546390728L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+5*c)]; }; + + for (size_t k=0; k void radfg(size_t ido, size_t ip, size_t l1, + T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph=(ip+1)/2; + size_t idl1 = ido*l1; + + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return ch[a+ido*(b+l1*c)]; }; + auto C1 = [cc,ido,l1] (size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+l1*c)]; }; + auto C2 = [cc,idl1] (size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch,idl1] (size_t a, size_t b) -> T& + { return ch[a+idl1*b]; }; + + if (ido>1) + { + for (size_t j=1, jc=ip-1; j=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ik=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ik=ip) iang-=ip; + T0 ar=csarr[2*iang], ai=csarr[2*iang+1]; + for (size_t ik=0; ik void radb2(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+2*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb3(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+3*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb4(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+4*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb5(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), + ti11= T0(0.9510565162951535721164393333793821L), + tr12= T0(-0.8090169943749474241022934171828191L), + ti12= T0(0.5877852522924731291687059546390728L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+5*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radbg(size_t ido, size_t ip, size_t l1, + T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph=(ip+1)/ 2; + size_t idl1 = ido*l1; + + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto C1 = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto C2 = [cc,idl1](size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch,idl1](size_t a, size_t b) -> T& + { return ch[a+idl1*b]; }; + + for (size_t k=0; kip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ikip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ikip) iang-=ip; + T0 war=csarr[2*iang], wai=csarr[2*iang+1]; + for (size_t ik=0; ik void copy_and_norm(T *c, T *p1, T0 fct) const + { + if (p1!=c) + { + if (fct!=1.) + for (size_t i=0; i void exec(T c[], T0 fct, bool r2hc) const + { + if (length==1) { c[0]*=fct; return; } + size_t nf=fact.size(); + arr ch(length); + T *p1=c, *p2=ch.data(); + + if (r2hc) + for(size_t k1=0, l1=length; k1>=2; } + if ((len%2)==0) + { + len>>=1; + // factor 2 should be at the front of the factor list + add_factor(2); + std::swap(fact[0].fct, fact.back().fct); + } + for (size_t divisor=3; divisor*divisor<=len; divisor+=2) + while ((len%divisor)==0) + { + add_factor(divisor); + len/=divisor; + } + if (len>1) add_factor(len); + } + + size_t twsize() const + { + size_t twsz=0, l1=1; + for (size_t k=0; k5) twsz+=2*ip; + l1*=ip; + } + return twsz; + } + + void comp_twiddle() + { + sincos_2pibyn twid(length); + size_t l1=1; + T0 *ptr=mem.data(); + for (size_t k=0; k5) // special factors required by *g functions + { + fact[k].tws=ptr; ptr+=2*ip; + fact[k].tws[0] = 1.; + fact[k].tws[1] = 0.; + for (size_t i=2, ic=2*ip-2; i<=ic; i+=2, ic-=2) + { + fact[k].tws[i ] = twid[i/2*(length/ip)].r; + fact[k].tws[i+1] = twid[i/2*(length/ip)].i; + fact[k].tws[ic] = twid[i/2*(length/ip)].r; + fact[k].tws[ic+1] = -twid[i/2*(length/ip)].i; + } + } + l1*=ip; + } + } + + public: + POCKETFFT_NOINLINE rfftp(size_t length_) + : length(length_) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + if (length==1) return; + factorize(); + mem.resize(twsize()); + comp_twiddle(); + } +}; + +// +// complex Bluestein transforms +// + +template class fftblue + { + private: + size_t n, n2; + cfftp plan; + arr> mem; + cmplx *bk, *bkf; + + template void fft(cmplx c[], T0 fct) const + { + arr> akf(n2); + + /* initialize a_k and FFT it */ + for (size_t m=0; m(c[m],bk[m],akf[m]); + auto zero = akf[0]*T0(0); + for (size_t m=n; m(bkf[0]); + for (size_t m=1; m<(n2+1)/2; ++m) + { + akf[m] = akf[m].template special_mul(bkf[m]); + akf[n2-m] = akf[n2-m].template special_mul(bkf[m]); + } + if ((n2&1)==0) + akf[n2/2] = akf[n2/2].template special_mul(bkf[n2/2]); + + /* inverse FFT */ + plan.exec (akf.data(),1.,false); + + /* multiply by b_k */ + for (size_t m=0; m(bk[m])*fct; + } + + public: + POCKETFFT_NOINLINE fftblue(size_t length) + : n(length), n2(util::good_size_cmplx(n*2-1)), plan(n2), mem(n+n2/2+1), + bk(mem.data()), bkf(mem.data()+n) + { + /* initialize b_k */ + sincos_2pibyn tmp(2*n); + bk[0].Set(1, 0); + + size_t coeff=0; + for (size_t m=1; m=2*n) coeff-=2*n; + bk[m] = tmp[coeff]; + } + + /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */ + arr> tbkf(n2); + T0 xn2 = T0(1)/T0(n2); + tbkf[0] = bk[0]*xn2; + for (size_t m=1; m void exec(cmplx c[], T0 fct, bool fwd) const + { fwd ? fft(c,fct) : fft(c,fct); } + + template void exec_r(T c[], T0 fct, bool fwd) + { + arr> tmp(n); + if (fwd) + { + auto zero = T0(0)*c[0]; + for (size_t m=0; m(tmp.data(),fct); + c[0] = tmp[0].r; + std::copy_n (&tmp[1].r, n-1, &c[1]); + } + else + { + tmp[0].Set(c[0],c[0]*0); + std::copy_n (c+1, n-1, &tmp[1].r); + if ((n&1)==0) tmp[n/2].i=T0(0)*c[0]; + for (size_t m=1; 2*m(tmp.data(),fct); + for (size_t m=0; m class pocketfft_c + { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_c(size_t length) + : len(length) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length); + if (tmp*tmp <= length) + { + packplan=std::unique_ptr>(new cfftp(length)); + return; + } + double comp1 = util::cost_guess(length); + double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); + comp2*=1.5; /* fudge factor that appears to give good overall performance */ + if (comp2>(new fftblue(length)); + else + packplan=std::unique_ptr>(new cfftp(length)); + } + + template POCKETFFT_NOINLINE void exec(cmplx c[], T0 fct, bool fwd) const + { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec(c,fct,fwd); } + + size_t length() const { return len; } + }; + +// +// flexible (FFTPACK/Bluestein) real-valued 1D transform +// + +template class pocketfft_r + { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_r(size_t length) + : len(length) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length); + if (tmp*tmp <= length) + { + packplan=std::unique_ptr>(new rfftp(length)); + return; + } + double comp1 = 0.5*util::cost_guess(length); + double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); + comp2*=1.5; /* fudge factor that appears to give good overall performance */ + if (comp2>(new fftblue(length)); + else + packplan=std::unique_ptr>(new rfftp(length)); + } + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const + { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec_r(c,fct,fwd); } + + size_t length() const { return len; } + }; + + +// +// sine/cosine transforms +// + +template class T_dct1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dct1(size_t length) + : fftplan(2*(length-1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int /*type*/, bool /*cosine*/) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=fftplan.length(), n=N/2+1; + if (ortho) + { c[0]*=sqrt2; c[n-1]*=sqrt2; } + arr tmp(N); + tmp[0] = c[0]; + for (size_t i=1; i class T_dst1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dst1(size_t length) + : fftplan(2*(length+1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool /*cosine*/) const + { + size_t N=fftplan.length(), n=N/2-1; + arr tmp(N); + tmp[0] = tmp[n+1] = c[0]*0; + for (size_t i=0; i class T_dcst23 + { + private: + pocketfft_r fftplan; + std::vector twiddle; + + public: + POCKETFFT_NOINLINE T_dcst23(size_t length) + : fftplan(length), twiddle(length) + { + sincos_2pibyn tw(4*length); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int type, bool cosine) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + size_t NS2 = (N+1)/2; + if (type==2) + { + if (!cosine) + for (size_t k=1; k class T_dcst4 + { + private: + size_t N; + std::unique_ptr> fft; + std::unique_ptr> rfft; + arr> C2; + + public: + POCKETFFT_NOINLINE T_dcst4(size_t length) + : N(length), + fft((N&1) ? nullptr : new pocketfft_c(N/2)), + rfft((N&1)? new pocketfft_r(N) : nullptr), + C2((N&1) ? 0 : N/2) + { + if ((N&1)==0) + { + sincos_2pibyn tw(16*N); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool cosine) const + { + size_t n2 = N/2; + if (!cosine) + for (size_t k=0, kc=N-1; k y(N); + { + size_t i=0, m=n2; + for (; mexec(y.data(), fct, true); + { + auto SGN = [](size_t i) + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + return (i&2) ? -sqrt2 : sqrt2; + }; + c[n2] = y[0]*SGN(n2+1); + size_t i=0, i1=1, k=1; + for (; k> y(n2); + for(size_t i=0; iexec(y.data(), fct, true); + for(size_t i=0, ic=n2-1; i std::shared_ptr get_plan(size_t length) + { +#if POCKETFFT_CACHE_SIZE==0 + return std::make_shared(length); +#else + constexpr size_t nmax=POCKETFFT_CACHE_SIZE; + static std::array, nmax> cache; + static std::array last_access{{0}}; + static size_t access_counter = 0; + static std::mutex mut; + + auto find_in_cache = [&]() -> std::shared_ptr + { + for (size_t i=0; ilength()==length)) + { + // no need to update if this is already the most recent entry + if (last_access[i]!=access_counter) + { + last_access[i] = ++access_counter; + // Guard against overflow + if (access_counter == 0) + last_access.fill(0); + } + return cache[i]; + } + + return nullptr; + }; + + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + } + auto plan = std::make_shared(length); + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + + size_t lru = 0; + for (size_t i=1; i class cndarr: public arr_info + { + protected: + const char *d; + + public: + cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_) + : arr_info(shape_, stride_), + d(reinterpret_cast(data_)) {} + const T &operator[](ptrdiff_t ofs) const + { return *reinterpret_cast(d+ofs); } + }; + +template class ndarr: public cndarr + { + public: + ndarr(void *data_, const shape_t &shape_, const stride_t &stride_) + : cndarr::cndarr(const_cast(data_), shape_, stride_) + {} + T &operator[](ptrdiff_t ofs) + { return *reinterpret_cast(const_cast(cndarr::d+ofs)); } + }; + +template class multi_iter + { + private: + shape_t pos; + const arr_info &iarr, &oarr; + ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o; + size_t idim, rem; + + void advance_i() + { + for (int i_=int(pos.size())-1; i_>=0; --i_) + { + auto i = size_t(i_); + if (i==idim) continue; + p_ii += iarr.stride(i); + p_oi += oarr.stride(i); + if (++pos[i] < iarr.shape(i)) + return; + pos[i] = 0; + p_ii -= ptrdiff_t(iarr.shape(i))*iarr.stride(i); + p_oi -= ptrdiff_t(oarr.shape(i))*oarr.stride(i); + } + } + + public: + multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_) + : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0), + str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)), + idim(idim_), rem(iarr.size()/iarr.shape(idim)) + { + auto nshares = threading::num_threads(); + if (nshares==1) return; + if (nshares==0) throw std::runtime_error("can't run with zero threads"); + auto myshare = threading::thread_id(); + if (myshare>=nshares) throw std::runtime_error("impossible share requested"); + size_t nbase = rem/nshares; + size_t additional = rem%nshares; + size_t lo = myshare*nbase + ((myshare=0; --i_) + { + auto i = size_t(i_); + p += arr.stride(i); + if (++pos[i] < arr.shape(i)) + return; + pos[i] = 0; + p -= ptrdiff_t(arr.shape(i))*arr.stride(i); + } + } + ptrdiff_t ofs() const { return p; } + size_t remaining() const { return rem; } + }; + +class rev_iter + { + private: + shape_t pos; + const arr_info &arr; + std::vector rev_axis; + std::vector rev_jump; + size_t last_axis, last_size; + shape_t shp; + ptrdiff_t p, rp; + size_t rem; + + public: + rev_iter(const arr_info &arr_, const shape_t &axes) + : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0), + rev_jump(arr_.ndim(), 1), p(0), rp(0) + { + for (auto ax: axes) + rev_axis[ax]=1; + last_axis = axes.back(); + last_size = arr.shape(last_axis)/2 + 1; + shp = arr.shape(); + shp[last_axis] = last_size; + rem=1; + for (auto i: shp) + rem *= i; + } + void advance() + { + --rem; + for (int i_=int(pos.size())-1; i_>=0; --i_) + { + auto i = size_t(i_); + p += arr.stride(i); + if (!rev_axis[i]) + rp += arr.stride(i); + else + { + rp -= arr.stride(i); + if (rev_jump[i]) + { + rp += ptrdiff_t(arr.shape(i))*arr.stride(i); + rev_jump[i] = 0; + } + } + if (++pos[i] < shp[i]) + return; + pos[i] = 0; + p -= ptrdiff_t(shp[i])*arr.stride(i); + if (rev_axis[i]) + { + rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i); + rev_jump[i] = 1; + } + else + rp -= ptrdiff_t(shp[i])*arr.stride(i); + } + } + ptrdiff_t ofs() const { return p; } + ptrdiff_t rev_ofs() const { return rp; } + size_t remaining() const { return rem; } + }; + +template struct VTYPE {}; +template using vtype_t = typename VTYPE::type; + +#ifndef POCKETFFT_NO_VECTORS +template<> struct VTYPE + { + using type = float __attribute__ ((vector_size (VLEN::val*sizeof(float)))); + }; +template<> struct VTYPE + { + using type = double __attribute__ ((vector_size (VLEN::val*sizeof(double)))); + }; +template<> struct VTYPE + { + using type = long double __attribute__ ((vector_size (VLEN::val*sizeof(long double)))); + }; +#endif + +template arr alloc_tmp(const shape_t &shape, + size_t axsize, size_t elemsize) + { + auto othersize = util::prod(shape)/axsize; + auto tmpsize = axsize*((othersize>=VLEN::val) ? VLEN::val : 1); + return arr(tmpsize*elemsize); + } +template arr alloc_tmp(const shape_t &shape, + const shape_t &axes, size_t elemsize) + { + size_t fullsize=util::prod(shape); + size_t tmpsize=0; + for (size_t i=0; i=VLEN::val) ? VLEN::val : 1); + if (sz>tmpsize) tmpsize=sz; + } + return arr(tmpsize*elemsize); + } + +template void copy_input(const multi_iter &it, + const cndarr> &src, cmplx> *POCKETFFT_RESTRICT dst) + { + for (size_t i=0; i void copy_input(const multi_iter &it, + const cndarr &src, vtype_t *POCKETFFT_RESTRICT dst) + { + for (size_t i=0; i void copy_input(const multi_iter &it, + const cndarr &src, T *POCKETFFT_RESTRICT dst) + { + if (dst == &src[it.iofs(0)]) return; // in-place + for (size_t i=0; i void copy_output(const multi_iter &it, + const cmplx> *POCKETFFT_RESTRICT src, ndarr> &dst) + { + for (size_t i=0; i void copy_output(const multi_iter &it, + const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) + { + for (size_t i=0; i void copy_output(const multi_iter &it, + const T *POCKETFFT_RESTRICT src, ndarr &dst) + { + if (src == &dst[it.oofs(0)]) return; // in-place + for (size_t i=0; i struct add_vec { using type = vtype_t; }; +template struct add_vec> + { using type = cmplx>; }; +template using add_vec_t = typename add_vec::type; + +template +POCKETFFT_NOINLINE void general_nd(const cndarr &in, ndarr &out, + const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec, + const bool allow_inplace=true) + { + std::shared_ptr plan; + + for (size_t iax=0; iaxlength())) + plan = get_plan(len); + + threading::thread_map( + util::thread_count(nthreads, in.shape(), axes[iax], VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(in.shape(), len, sizeof(T)); + const auto &tin(iax==0? in : out); + multi_iter it(tin, out, axes[iax]); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + exec(it, tin, out, tdatav, *plan, fct); + } +#endif + while (it.remaining()>0) + { + it.advance(1); + auto buf = allow_inplace && it.stride_out() == sizeof(T) ? + &out[it.oofs(0)] : reinterpret_cast(storage.data()); + exec(it, tin, out, buf, *plan, fct); + } + }); // end of parallel region + fct = T0(1); // factor has been applied, use 1 for remaining axes + } + } + +struct ExecC2C + { + bool forward; + + template void operator () ( + const multi_iter &it, const cndarr> &in, + ndarr> &out, T * buf, const pocketfft_c &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, forward); + copy_output(it, buf, out); + } + }; + +template void copy_hartley(const multi_iter &it, + const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) + { + for (size_t j=0; j void copy_hartley(const multi_iter &it, + const T *POCKETFFT_RESTRICT src, ndarr &dst) + { + dst[it.oofs(0)] = src[0]; + size_t i=1, i1=1, i2=it.length_out()-1; + for (i=1; i void operator () ( + const multi_iter &it, const cndarr &in, ndarr &out, + T * buf, const pocketfft_r &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, true); + copy_hartley(it, buf, out); + } + }; + +struct ExecDcst + { + bool ortho; + int type; + bool cosine; + + template + void operator () (const multi_iter &it, const cndarr &in, + ndarr &out, T * buf, const Tplan &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, ortho, type, cosine); + copy_output(it, buf, out); + } + }; + +template POCKETFFT_NOINLINE void general_r2c( + const cndarr &in, ndarr> &out, size_t axis, bool forward, T fct, + size_t nthreads) + { + auto plan = get_plan>(in.shape(axis)); + size_t len=in.shape(axis); + threading::thread_map( + util::thread_count(nthreads, in.shape(), axis, VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(in.shape(), len, sizeof(T)); + multi_iter it(in, out, axis); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + copy_input(it, in, tdatav); + plan->exec(tdatav, fct, true); + for (size_t j=0; j0) + { + it.advance(1); + auto tdata = reinterpret_cast(storage.data()); + copy_input(it, in, tdata); + plan->exec(tdata, fct, true); + out[it.oofs(0)].Set(tdata[0]); + size_t i=1, ii=1; + if (forward) + for (; i POCKETFFT_NOINLINE void general_c2r( + const cndarr> &in, ndarr &out, size_t axis, bool forward, T fct, + size_t nthreads) + { + auto plan = get_plan>(out.shape(axis)); + size_t len=out.shape(axis); + threading::thread_map( + util::thread_count(nthreads, in.shape(), axis, VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(out.shape(), len, sizeof(T)); + multi_iter it(in, out, axis); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + for (size_t j=0; jexec(tdatav, fct, false); + copy_output(it, tdatav, out); + } +#endif + while (it.remaining()>0) + { + it.advance(1); + auto tdata = reinterpret_cast(storage.data()); + tdata[0]=in[it.iofs(0)].r; + { + size_t i=1, ii=1; + if (forward) + for (; iexec(tdata, fct, false); + copy_output(it, tdata, out); + } + }); // end of parallel region + } + +struct ExecR2R + { + bool r2h, forward; + + template void operator () ( + const multi_iter &it, const cndarr &in, ndarr &out, T * buf, + const pocketfft_r &plan, T0 fct) const + { + copy_input(it, in, buf); + if ((!r2h) && forward) + for (size_t i=2; i void c2c(const shape_t &shape, const stride_t &stride_in, + const stride_t &stride_out, const shape_t &axes, bool forward, + const std::complex *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr> ain(data_in, shape, stride_in); + ndarr> aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, ExecC2C{forward}); + } + +template void dct(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + const ExecDcst exec{ortho, type, true}; + if (type==1) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else if (type==4) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else + general_nd>(ain, aout, axes, fct, nthreads, exec); + } + +template void dst(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + const ExecDcst exec{ortho, type, false}; + if (type==1) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else if (type==4) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else + general_nd>(ain, aout, axes, fct, nthreads, exec); + } + +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const T *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_in)==0) return; + util::sanity_check(shape_in, stride_in, stride_out, false, axis); + cndarr ain(data_in, shape_in, stride_in); + shape_t shape_out(shape_in); + shape_out[axis] = shape_in[axis]/2 + 1; + ndarr> aout(data_out, shape_out, stride_out); + general_r2c(ain, aout, axis, forward, fct, nthreads); + } + +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool forward, const T *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_in)==0) return; + util::sanity_check(shape_in, stride_in, stride_out, false, axes); + r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out, + fct, nthreads); + if (axes.size()==1) return; + + shape_t shape_out(shape_in); + shape_out[axes.back()] = shape_in[axes.back()]/2 + 1; + auto newaxes = shape_t{axes.begin(), --axes.end()}; + c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out, + T(1), nthreads); + } + +template void c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const std::complex *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_out)==0) return; + util::sanity_check(shape_out, stride_in, stride_out, false, axis); + shape_t shape_in(shape_out); + shape_in[axis] = shape_out[axis]/2 + 1; + cndarr> ain(data_in, shape_in, stride_in); + ndarr aout(data_out, shape_out, stride_out); + general_c2r(ain, aout, axis, forward, fct, nthreads); + } + +template void c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool forward, const std::complex *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_out)==0) return; + if (axes.size()==1) + return c2r(shape_out, stride_in, stride_out, axes[0], forward, + data_in, data_out, fct, nthreads); + util::sanity_check(shape_out, stride_in, stride_out, false, axes); + auto shape_in = shape_out; + shape_in[axes.back()] = shape_out[axes.back()]/2 + 1; + auto nval = util::prod(shape_in); + stride_t stride_inter(shape_in.size()); + stride_inter.back() = sizeof(cmplx); + for (int i=int(shape_in.size())-2; i>=0; --i) + stride_inter[size_t(i)] = + stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]); + arr> tmp(nval); + auto newaxes = shape_t{axes.begin(), --axes.end()}; + c2c(shape_in, stride_in, stride_inter, newaxes, forward, data_in, tmp.data(), + T(1), nthreads); + c2r(shape_out, stride_inter, stride_out, axes.back(), forward, + tmp.data(), data_out, fct, nthreads); + } + +template void r2r_fftpack(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool real2hermitian, bool forward, const T *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, + ExecR2R{real2hermitian, forward}); + } + +template void r2r_separable_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, ExecHartley{}, + false); + } + +template void r2r_genuine_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1) + { + if (util::prod(shape)==0) return; + if (axes.size()==1) + return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in, + data_out, fct, nthreads); + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + shape_t tshp(shape); + tshp[axes.back()] = tshp[axes.back()]/2+1; + arr> tdata(util::prod(tshp)); + stride_t tstride(shape.size()); + tstride.back()=sizeof(std::complex); + for (size_t i=tstride.size()-1; i>0; --i) + tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]); + r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads); + cndarr> atmp(tdata.data(), tshp, tstride); + ndarr aout(data_out, shape, stride_out); + simple_iter iin(atmp); + rev_iter iout(aout, axes); + while(iin.remaining()>0) + { + auto v = atmp[iin.ofs()]; + aout[iout.ofs()] = v.r+v.i; + aout[iout.rev_ofs()] = v.r-v.i; + iin.advance(); iout.advance(); + } + } + +} // namespace detail + +using detail::FORWARD; +using detail::BACKWARD; +using detail::shape_t; +using detail::stride_t; +using detail::c2c; +using detail::c2r; +using detail::r2c; +using detail::r2r_fftpack; +using detail::r2r_separable_hartley; +using detail::r2r_genuine_hartley; +using detail::dct; +using detail::dst; + +} // namespace pocketfft + +#undef POCKETFFT_NOINLINE +#undef POCKETFFT_RESTRICT + +#endif // POCKETFFT_HDRONLY_H From 41752c5a36ac6e88d79b2d3096710a746f0ff6d9 Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Mon, 18 Nov 2024 12:57:25 +0100 Subject: [PATCH 07/12] Ubuntu 4.x build fix. MacOS fix attempt. --- src/owl/dune | 1 + src/owl/fftpack/owl_fftpack_float32.cc | 8 ++++++++ src/owl/fftpack/owl_fftpack_float64.cc | 8 ++++++++ src/owl/fftpack/owl_fftpack_impl.h | 8 ++++++++ 4 files changed, 25 insertions(+) diff --git a/src/owl/dune b/src/owl/dune index 0bd4c0914..1aa0bbd45 100644 --- a/src/owl/dune +++ b/src/owl/dune @@ -52,6 +52,7 @@ owl_fftpack_float64) (flags :standard + -std=c++11 (:include c_flags.sexp))) (foreign_stubs (language c) diff --git a/src/owl/fftpack/owl_fftpack_float32.cc b/src/owl/fftpack/owl_fftpack_float32.cc index c50778385..f18d99ae7 100644 --- a/src/owl/fftpack/owl_fftpack_float32.cc +++ b/src/owl/fftpack/owl_fftpack_float32.cc @@ -10,6 +10,14 @@ extern "C" { #include "owl_core.h" + value float32_cfft(value vForward, value vX, value vY, value vD, value vNorm, value vNthreads); + value float32_cfft_bytecode(value *argv, int argn); + value float32_rfftf(value vX, value vY, value vD, value vNorm, value vNthreads); + value float32_rfftb(value vX, value vY, value vD, value vNorm, value vNthreads); + value float32_dct(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); + value float32_dct_bytecode(value *argv, int argn); + value float32_dst(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); + value float32_dst_bytecode(value *argv, int argn); } #define REAL_COPY owl_float32_copy diff --git a/src/owl/fftpack/owl_fftpack_float64.cc b/src/owl/fftpack/owl_fftpack_float64.cc index b34050df0..2051941c9 100644 --- a/src/owl/fftpack/owl_fftpack_float64.cc +++ b/src/owl/fftpack/owl_fftpack_float64.cc @@ -9,6 +9,14 @@ extern "C" { #include "owl_core.h" + value float64_cfft(value vForward, value vX, value vY, value vD, value vNorm, value vNthreads); + value float64_cfft_bytecode(value *argv, int argn); + value float64_rfftf(value vX, value vY, value vD, value vNorm, value vNthreads); + value float64_rfftb(value vX, value vY, value vD, value vNorm, value vNthreads); + value float64_dct(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); + value float64_dct_bytecode(value *argv, int argn); + value float64_dst(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); + value float64_dst_bytecode(value *argv, int argn); } #define REAL_COPY owl_float64_copy diff --git a/src/owl/fftpack/owl_fftpack_impl.h b/src/owl/fftpack/owl_fftpack_impl.h index 4f6018e35..3def4d2e4 100644 --- a/src/owl/fftpack/owl_fftpack_impl.h +++ b/src/owl/fftpack/owl_fftpack_impl.h @@ -5,8 +5,16 @@ #ifdef Treal +#ifdef CAML_COMPATIBILITY_H +#undef invalid_argument /* For version < 5.0 of the OCaml compiler, allowing std::invalid_argument to be used */ +#endif + #include "pocketfft_hdronly.h" +#ifdef CAML_COMPATIBILITY_H +#define invalid_argument caml_invalid_argument +#endif + /** Owl's interface function to pocketfft **/ /** Adapted from scipy's pypocketfft.cxx **/ From e9a1f08b31690528d5f8911c61479d07fc00c250 Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Mon, 18 Nov 2024 15:04:00 +0100 Subject: [PATCH 08/12] Moved the definition outside extern C. - Kept the declaration inside the extern C but moved away the declaration, as an attempt to fix MacOS builds. --- src/owl/fftpack/owl_fftpack_impl.h | 540 ++++++++++++++--------------- 1 file changed, 268 insertions(+), 272 deletions(-) diff --git a/src/owl/fftpack/owl_fftpack_impl.h b/src/owl/fftpack/owl_fftpack_impl.h index 3def4d2e4..12bb48420 100644 --- a/src/owl/fftpack/owl_fftpack_impl.h +++ b/src/owl/fftpack/owl_fftpack_impl.h @@ -51,342 +51,338 @@ T compute_norm_factor(const shape_t &dims, const shape_t &axes, int inorm, size_ return norm_fct(inorm, N); } -extern "C" +/** Owl's stub functions **/ + +/** + * Complex-to-complex FFT + * @param forward: true for forward transform, false for backward transform + * @param X: input array + * @param Y: output array + * @param d: dimension along which to perform the transform + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ +value STUB_CFFT(value vForward, value vX, value vY, value vD, value vNorm, value vNthreads) { + struct caml_ba_array *X = Caml_ba_array_val(vX); + std::complex *X_data = reinterpret_cast *>(X->data); - /** Owl's stub functions **/ - - /** - * Complex-to-complex FFT - * @param forward: true for forward transform, false for backward transform - * @param X: input array - * @param Y: output array - * @param d: dimension along which to perform the transform - * @param norm: normalization factor - * @param nthreads: number of threads to use - * - * @return unit - */ - value STUB_CFFT(value vForward, value vX, value vY, value vD, value vNorm, value vNthreads) - { - struct caml_ba_array *X = Caml_ba_array_val(vX); - std::complex *X_data = reinterpret_cast *>(X->data); + struct caml_ba_array *Y = Caml_ba_array_val(vY); + std::complex *Y_data = reinterpret_cast *>(Y->data); - struct caml_ba_array *Y = Caml_ba_array_val(vY); - std::complex *Y_data = reinterpret_cast *>(Y->data); + int d = Long_val(vD); + int n = X->dim[d]; + int norm = Long_val(vNorm); + int nthreads = Long_val(vNthreads); + int forward = Bool_val(vForward); - int d = Long_val(vD); - int n = X->dim[d]; - int norm = Long_val(vNorm); - int nthreads = Long_val(vNthreads); - int forward = Bool_val(vForward); + shape_t dims; + stride_t stride_in, stride_out; - shape_t dims; - stride_t stride_in, stride_out; + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); + } - for (int i = 0; i < X->num_dims; ++i) - { - dims.push_back(static_cast(X->dim[i])); - } + size_t multiplier = sizeof(std::complex); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } - size_t multiplier = sizeof(std::complex); - for (int i = 0; i < X->num_dims; ++i) + shape_t axes{static_cast(d)}; + { + Treal norm_factor = compute_norm_factor(dims, axes, norm); + try { - stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); - stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + pocketfft::c2c(dims, stride_in, stride_out, axes, forward, + X_data, Y_data, norm_factor, nthreads); } - - shape_t axes{static_cast(d)}; + catch (const std::exception &e) { - Treal norm_factor = compute_norm_factor(dims, axes, norm); - try - { - pocketfft::c2c(dims, stride_in, stride_out, axes, forward, - X_data, Y_data, norm_factor, nthreads); - } - catch (const std::exception &e) - { - caml_failwith(e.what()); // maybe raise an OCaml exception here ?? - } + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? } - - return Val_unit; } - /** - * Complex-to-complex FFT - * @param argv: array of arguments - * @param argn: number of arguments - * @see STUB_CFFT, https://ocaml.org/manual/5.2/intfc.html#ss:c-prim-impl - */ - value STUB_CFFT_bytecode(value *argv, int argn) - { - return STUB_CFFT(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5]); - } + return Val_unit; +} - /** - * Forward Real-to-complex FFT - * @param X: input array (real data) - * @param Y: output array (complex data) - * @param d: dimension along which to perform the transform - * @param norm: normalization factor - * @param nthreads: number of threads to use - * - * @return unit - */ - value STUB_RFFTF(value vX, value vY, value vD, value vNorm, value vNthreads) - { - struct caml_ba_array *X = Caml_ba_array_val(vX); - Treal *X_data = reinterpret_cast(X->data); +/** + * Complex-to-complex FFT + * @param argv: array of arguments + * @param argn: number of arguments + * @see STUB_CFFT, https://ocaml.org/manual/5.2/intfc.html#ss:c-prim-impl + */ +value STUB_CFFT_bytecode(value *argv, int argn) +{ + return STUB_CFFT(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5]); +} - struct caml_ba_array *Y = Caml_ba_array_val(vY); - std::complex *Y_data = reinterpret_cast *>(Y->data); +/** + * Forward Real-to-complex FFT + * @param X: input array (real data) + * @param Y: output array (complex data) + * @param d: dimension along which to perform the transform + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ +value STUB_RFFTF(value vX, value vY, value vD, value vNorm, value vNthreads) +{ + struct caml_ba_array *X = Caml_ba_array_val(vX); + Treal *X_data = reinterpret_cast(X->data); - int d = Long_val(vD); - int n = X->dim[d]; - int norm = Long_val(vNorm); - int nthreads = Long_val(vNthreads); + struct caml_ba_array *Y = Caml_ba_array_val(vY); + std::complex *Y_data = reinterpret_cast *>(Y->data); - shape_t dims; - stride_t stride_in, stride_out; + int d = Long_val(vD); + int n = X->dim[d]; + int norm = Long_val(vNorm); + int nthreads = Long_val(vNthreads); - for (int i = 0; i < X->num_dims; ++i) - { - dims.push_back(static_cast(X->dim[i])); - } + shape_t dims; + stride_t stride_in, stride_out; - size_t multiplier = sizeof(Treal); - for (int i = 0; i < X->num_dims; ++i) - { - stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); - } + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); + } + + size_t multiplier = sizeof(Treal); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + } - multiplier = sizeof(std::complex); - for (int i = 0; i < Y->num_dims; ++i) + multiplier = sizeof(std::complex); + for (int i = 0; i < Y->num_dims; ++i) + { + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } + + shape_t axes{static_cast(d)}; + { + Treal norm_factor = compute_norm_factor(dims, axes, norm); + try { - stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + pocketfft::r2c(dims, stride_in, stride_out, axes, pocketfft::FORWARD, + X_data, Y_data, norm_factor, nthreads); } - - shape_t axes{static_cast(d)}; + catch (const std::exception &e) { - Treal norm_factor = compute_norm_factor(dims, axes, norm); - try - { - pocketfft::r2c(dims, stride_in, stride_out, axes, pocketfft::FORWARD, - X_data, Y_data, norm_factor, nthreads); - } - catch (const std::exception &e) - { - caml_failwith(e.what()); // maybe raise an OCaml exception here ?? - } + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? } - - return Val_unit; } - /** - * Backward Real-to-complex FFT - * @param X: input array (complex data) - * @param Y: output array (real data) - * @param d: dimension along which to perform the transform - * @param norm: normalization factor - * @param nthreads: number of threads to use - * - * @return unit - */ - value STUB_RFFTB(value vX, value vY, value vD, value vNorm, value vNthreads) - { - struct caml_ba_array *X = Caml_ba_array_val(vX); - std::complex *X_data = reinterpret_cast *>(X->data); + return Val_unit; +} + +/** + * Backward Real-to-complex FFT + * @param X: input array (complex data) + * @param Y: output array (real data) + * @param d: dimension along which to perform the transform + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ +value STUB_RFFTB(value vX, value vY, value vD, value vNorm, value vNthreads) +{ + struct caml_ba_array *X = Caml_ba_array_val(vX); + std::complex *X_data = reinterpret_cast *>(X->data); - struct caml_ba_array *Y = Caml_ba_array_val(vY); - Treal *Y_data = reinterpret_cast(Y->data); + struct caml_ba_array *Y = Caml_ba_array_val(vY); + Treal *Y_data = reinterpret_cast(Y->data); - int d = Long_val(vD); - int n = X->dim[d]; - int norm = Long_val(vNorm); - int nthreads = Long_val(vNthreads); + int d = Long_val(vD); + int n = X->dim[d]; + int norm = Long_val(vNorm); + int nthreads = Long_val(vNthreads); - if (Y->dim[d] != (X->dim[d] - 1) * 2) - caml_failwith("Invalid output dimension for inverse real FFT"); + if (Y->dim[d] != (X->dim[d] - 1) * 2) + caml_failwith("Invalid output dimension for inverse real FFT"); - shape_t dims; - stride_t stride_in, stride_out; + shape_t dims; + stride_t stride_in, stride_out; - int ncomplex = X->dim[d]; - int nreal = Y->dim[d]; + int ncomplex = X->dim[d]; + int nreal = Y->dim[d]; - for (int i = 0; i < X->num_dims; ++i) + for (int i = 0; i < X->num_dims; ++i) + { + if (i == d) { - if (i == d) - { - dims.push_back(static_cast(nreal)); - } - else - { - dims.push_back(static_cast(X->dim[i])); - } + dims.push_back(static_cast(nreal)); } - - size_t multiplier = sizeof(std::complex); - for (int i = 0; i < X->num_dims; ++i) + else { - stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + dims.push_back(static_cast(X->dim[i])); } + } + + size_t multiplier = sizeof(std::complex); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + } - multiplier = sizeof(Treal); - for (int i = 0; i < Y->num_dims; ++i) + multiplier = sizeof(Treal); + for (int i = 0; i < Y->num_dims; ++i) + { + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } + + shape_t axes{static_cast(d)}; + { + Treal norm_factor = compute_norm_factor(dims, axes, norm); + try { - stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + pocketfft::c2r(dims, stride_in, stride_out, axes, pocketfft::BACKWARD, + X_data, Y_data, norm_factor, nthreads); } - - shape_t axes{static_cast(d)}; + catch (const std::exception &e) { - Treal norm_factor = compute_norm_factor(dims, axes, norm); - try - { - pocketfft::c2r(dims, stride_in, stride_out, axes, pocketfft::BACKWARD, - X_data, Y_data, norm_factor, nthreads); - } - catch (const std::exception &e) - { - caml_failwith(e.what()); // maybe raise an OCaml exception here ?? - } + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? } - - return Val_unit; } - /** - * Discrete Cosine Transform - * @param X: input array - * @param Y: output array - * @param d: dimension along which to perform the transform - * @param type: type of DCT (1, 2, 3, or 4) - * @param norm: normalization factor - * @param nthreads: number of threads to use - * - * @return unit - */ - value STUB_RDCT(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads) - { - struct caml_ba_array *X = Caml_ba_array_val(vX); - Treal *X_data = reinterpret_cast(X->data); + return Val_unit; +} - struct caml_ba_array *Y = Caml_ba_array_val(vY); - Treal *Y_data = reinterpret_cast(Y->data); +/** + * Discrete Cosine Transform + * @param X: input array + * @param Y: output array + * @param d: dimension along which to perform the transform + * @param type: type of DCT (1, 2, 3, or 4) + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ +value STUB_RDCT(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads) +{ + struct caml_ba_array *X = Caml_ba_array_val(vX); + Treal *X_data = reinterpret_cast(X->data); - int d = Long_val(vD); - int n = X->dim[d]; - int type = Long_val(vType); - if (type < 1 || type > 4) // should not happen as it's checked on the OCaml side - caml_failwith("invalid value for type (must be 1, 2, 3, or 4)"); - int norm = Long_val(vNorm); - bool ortho = Bool_val(vOrtho); - int nthreads = Long_val(vNthreads); + struct caml_ba_array *Y = Caml_ba_array_val(vY); + Treal *Y_data = reinterpret_cast(Y->data); - shape_t dims; - stride_t stride_in, stride_out; + int d = Long_val(vD); + int n = X->dim[d]; + int type = Long_val(vType); + if (type < 1 || type > 4) // should not happen as it's checked on the OCaml side + caml_failwith("invalid value for type (must be 1, 2, 3, or 4)"); + int norm = Long_val(vNorm); + bool ortho = Bool_val(vOrtho); + int nthreads = Long_val(vNthreads); - for (int i = 0; i < X->num_dims; ++i) - { - dims.push_back(static_cast(X->dim[i])); - } + shape_t dims; + stride_t stride_in, stride_out; + + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); + } + + size_t multiplier = sizeof(Treal); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } - size_t multiplier = sizeof(Treal); - for (int i = 0; i < X->num_dims; ++i) + shape_t axes{static_cast(d)}; + { + Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, -1) + : compute_norm_factor(dims, axes, norm, 2); + try { - stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); - stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + pocketfft::dct(dims, stride_in, stride_out, axes, type, + X_data, Y_data, norm_factor, ortho, nthreads); } - - shape_t axes{static_cast(d)}; + catch (const std::exception &e) { - Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, -1) - : compute_norm_factor(dims, axes, norm, 2); - try - { - pocketfft::dct(dims, stride_in, stride_out, axes, type, - X_data, Y_data, norm_factor, ortho, nthreads); - } - catch (const std::exception &e) - { - caml_failwith(e.what()); // maybe raise an OCaml exception here ?? - } + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? } - - return Val_unit; } - value STUB_RDCT_bytecode(value *argv, int argn) - { - return STUB_RDCT(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]); - } + return Val_unit; +} - /** - * Discrete Sine Transform - * @param X: input array - * @param Y: output array - * @param d: dimension along which to perform the transform - * @param type: type of DST (1, 2, 3, or 4) - * @param norm: normalization factor - * @param nthreads: number of threads to use - * - * @return unit - */ - value STUB_RDST(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads) - { - struct caml_ba_array *X = Caml_ba_array_val(vX); - Treal *X_data = reinterpret_cast(X->data); +value STUB_RDCT_bytecode(value *argv, int argn) +{ + return STUB_RDCT(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]); +} + +/** + * Discrete Sine Transform + * @param X: input array + * @param Y: output array + * @param d: dimension along which to perform the transform + * @param type: type of DST (1, 2, 3, or 4) + * @param norm: normalization factor + * @param nthreads: number of threads to use + * + * @return unit + */ +value STUB_RDST(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads) +{ + struct caml_ba_array *X = Caml_ba_array_val(vX); + Treal *X_data = reinterpret_cast(X->data); - struct caml_ba_array *Y = Caml_ba_array_val(vY); - Treal *Y_data = reinterpret_cast(Y->data); + struct caml_ba_array *Y = Caml_ba_array_val(vY); + Treal *Y_data = reinterpret_cast(Y->data); - int d = Long_val(vD); - int n = X->dim[d]; - int type = Long_val(vType); - if (type < 1 || type > 4) // should not happen as it's checked on the OCaml side - caml_failwith("invalid value for type (must be 1, 2, 3, or 4)"); - int norm = Long_val(vNorm); - bool ortho = Bool_val(vOrtho); - int nthreads = Long_val(vNthreads); + int d = Long_val(vD); + int n = X->dim[d]; + int type = Long_val(vType); + if (type < 1 || type > 4) // should not happen as it's checked on the OCaml side + caml_failwith("invalid value for type (must be 1, 2, 3, or 4)"); + int norm = Long_val(vNorm); + bool ortho = Bool_val(vOrtho); + int nthreads = Long_val(vNthreads); - shape_t dims; - stride_t stride_in, stride_out; + shape_t dims; + stride_t stride_in, stride_out; - for (int i = 0; i < X->num_dims; ++i) - { - dims.push_back(static_cast(X->dim[i])); - } + for (int i = 0; i < X->num_dims; ++i) + { + dims.push_back(static_cast(X->dim[i])); + } - size_t multiplier = sizeof(Treal); - for (int i = 0; i < X->num_dims; ++i) + size_t multiplier = sizeof(Treal); + for (int i = 0; i < X->num_dims; ++i) + { + stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); + stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + } + + shape_t axes{static_cast(d)}; + { + Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, 1) + : compute_norm_factor(dims, axes, norm, 2); + try { - stride_in.push_back(c_ndarray_stride_dim(X, i) * multiplier); - stride_out.push_back(c_ndarray_stride_dim(Y, i) * multiplier); + pocketfft::dst(dims, stride_in, stride_out, axes, type, + X_data, Y_data, norm_factor, ortho, nthreads); } - - shape_t axes{static_cast(d)}; + catch (const std::exception &e) { - Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, 1) - : compute_norm_factor(dims, axes, norm, 2); - try - { - pocketfft::dst(dims, stride_in, stride_out, axes, type, - X_data, Y_data, norm_factor, ortho, nthreads); - } - catch (const std::exception &e) - { - caml_failwith(e.what()); // maybe raise an OCaml exception here ?? - } + caml_failwith(e.what()); // maybe raise an OCaml exception here ?? } - - return Val_unit; } - value STUB_RDST_bytecode(value *argv, int argn) - { - return STUB_RDST(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]); - } -} // extern "C" + return Val_unit; +} + +value STUB_RDST_bytecode(value *argv, int argn) +{ + return STUB_RDST(argv[0], argv[1], argv[2], argv[3], argv[4], argv[5], argv[6]); +} #endif // Treal From 5fcd0c3cd36f26a6f2a8f7bd96a2a306dfa3dd4d Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Mon, 18 Nov 2024 15:43:05 +0100 Subject: [PATCH 09/12] owl_core preprocessor for cpp compiling. --- src/owl/core/owl_core.h | 258 +++++++++++++------------ src/owl/fftpack/owl_fftpack_float32.cc | 30 ++- src/owl/fftpack/owl_fftpack_float64.cc | 30 ++- 3 files changed, 161 insertions(+), 157 deletions(-) diff --git a/src/owl/core/owl_core.h b/src/owl/core/owl_core.h index 9a7cfd62f..01e7b743b 100644 --- a/src/owl/core/owl_core.h +++ b/src/owl/core/owl_core.h @@ -11,220 +11,232 @@ #include "owl_ndarray_contract.h" #include // DEBUG +#ifdef __cplusplus +extern "C" +{ +#endif -/** Core function declaration **/ + /** Core function declaration **/ + extern int64_t c_ndarray_numel(struct caml_ba_array *X); -extern int64_t c_ndarray_numel (struct caml_ba_array *X); + extern int64_t c_ndarray_stride_dim(struct caml_ba_array *X, int d); -extern int64_t c_ndarray_stride_dim (struct caml_ba_array *X, int d); + extern int64_t c_ndarray_slice_dim(struct caml_ba_array *X, int d); -extern int64_t c_ndarray_slice_dim (struct caml_ba_array *X, int d); + extern void c_float32_ndarray_transpose(struct slice_pair *sp); -extern void c_float32_ndarray_transpose (struct slice_pair *sp); + extern void c_float64_ndarray_transpose(struct slice_pair *sp); -extern void c_float64_ndarray_transpose (struct slice_pair *sp); + extern void c_complex32_ndarray_transpose(struct slice_pair *sp); -extern void c_complex32_ndarray_transpose (struct slice_pair *sp); + extern void c_complex64_ndarray_transpose(struct slice_pair *sp); -extern void c_complex64_ndarray_transpose (struct slice_pair *sp); + extern void c_float32_ndarray_contract_one(struct contract_pair *sp); -extern void c_float32_ndarray_contract_one (struct contract_pair *sp); + extern void c_float64_ndarray_contract_one(struct contract_pair *sp); -extern void c_float64_ndarray_contract_one (struct contract_pair *sp); + extern void c_complex32_ndarray_contract_one(struct contract_pair *sp); -extern void c_complex32_ndarray_contract_one (struct contract_pair *sp); + extern void c_complex64_ndarray_contract_one(struct contract_pair *sp); -extern void c_complex64_ndarray_contract_one (struct contract_pair *sp); + extern void c_float32_ndarray_contract_two(struct contract_pair *sp); -extern void c_float32_ndarray_contract_two (struct contract_pair *sp); + extern void c_float64_ndarray_contract_two(struct contract_pair *sp); -extern void c_float64_ndarray_contract_two (struct contract_pair *sp); + extern void c_complex32_ndarray_contract_two(struct contract_pair *sp); -extern void c_complex32_ndarray_contract_two (struct contract_pair *sp); + extern void c_complex64_ndarray_contract_two(struct contract_pair *sp); -extern void c_complex64_ndarray_contract_two (struct contract_pair *sp); + extern void c_float32_matrix_swap_rows(float *x, int m, int n, int i, int j); -extern void c_float32_matrix_swap_rows (float *x, int m, int n, int i, int j); + extern void c_float64_matrix_swap_rows(double *x, int m, int n, int i, int j); -extern void c_float64_matrix_swap_rows (double *x, int m, int n, int i, int j); + extern void c_complex32_matrix_swap_rows(_Complex float *x, int m, int n, int i, int j); -extern void c_complex32_matrix_swap_rows (_Complex float *x, int m, int n, int i, int j); + extern void c_complex64_matrix_swap_rows(_Complex double *x, int m, int n, int i, int j); -extern void c_complex64_matrix_swap_rows (_Complex double *x, int m, int n, int i, int j); + extern void c_float32_matrix_swap_cols(float *x, int m, int n, int i, int j); -extern void c_float32_matrix_swap_cols (float *x, int m, int n, int i, int j); + extern void c_float64_matrix_swap_cols(double *x, int m, int n, int i, int j); -extern void c_float64_matrix_swap_cols (double *x, int m, int n, int i, int j); + extern void c_complex32_matrix_swap_cols(_Complex float *x, int m, int n, int i, int j); -extern void c_complex32_matrix_swap_cols (_Complex float *x, int m, int n, int i, int j); + extern void c_complex64_matrix_swap_cols(_Complex double *x, int m, int n, int i, int j); -extern void c_complex64_matrix_swap_cols (_Complex double *x, int m, int n, int i, int j); + extern void c_float32_matrix_transpose(float *x, float *y, int m, int n); -extern void c_float32_matrix_transpose (float *x, float *y, int m, int n); + extern void c_float64_matrix_transpose(double *x, double *y, int m, int n); -extern void c_float64_matrix_transpose (double *x, double *y, int m, int n); + extern void c_complex32_matrix_transpose(_Complex float *x, _Complex float *y, int m, int n); -extern void c_complex32_matrix_transpose (_Complex float *x, _Complex float *y, int m, int n); + extern void c_complex64_matrix_transpose(_Complex double *x, _Complex double *y, int m, int n); -extern void c_complex64_matrix_transpose (_Complex double *x, _Complex double *y, int m, int n); + extern void c_ndarray_stride(struct caml_ba_array *X, int64_t *stride); -extern void c_ndarray_stride (struct caml_ba_array *X, int64_t *stride); + extern void c_ndarray_slice(struct caml_ba_array *X, int64_t *slice); -extern void c_ndarray_slice (struct caml_ba_array *X, int64_t *slice); + extern void c_slicing_stride(struct caml_ba_array *X, int64_t *slice, int64_t *stride); -extern void c_slicing_stride (struct caml_ba_array *X, int64_t *slice, int64_t *stride); + extern void c_slicing_offset(struct caml_ba_array *X, int64_t *slice, int64_t *offset); -extern void c_slicing_offset (struct caml_ba_array *X, int64_t *slice, int64_t *offset); + extern void c_float32_ndarray_get_slice(struct slice_pair *sp); -extern void c_float32_ndarray_get_slice (struct slice_pair *sp); + extern void c_float64_ndarray_get_slice(struct slice_pair *sp); -extern void c_float64_ndarray_get_slice (struct slice_pair *sp); + extern void c_complex32_ndarray_get_slice(struct slice_pair *sp); -extern void c_complex32_ndarray_get_slice (struct slice_pair *sp); + extern void c_complex64_ndarray_get_slice(struct slice_pair *sp); -extern void c_complex64_ndarray_get_slice (struct slice_pair *sp); + extern void c_float32_ndarray_set_slice(struct slice_pair *sp); -extern void c_float32_ndarray_set_slice (struct slice_pair *sp); + extern void c_float64_ndarray_set_slice(struct slice_pair *sp); -extern void c_float64_ndarray_set_slice (struct slice_pair *sp); + extern void c_complex32_ndarray_set_slice(struct slice_pair *sp); -extern void c_complex32_ndarray_set_slice (struct slice_pair *sp); + extern void c_complex64_ndarray_set_slice(struct slice_pair *sp); -extern void c_complex64_ndarray_set_slice (struct slice_pair *sp); + extern void c_float32_ndarray_get_fancy(struct fancy_pair *sp); -extern void c_float32_ndarray_get_fancy (struct fancy_pair *sp); + extern void c_float64_ndarray_get_fancy(struct fancy_pair *sp); -extern void c_float64_ndarray_get_fancy (struct fancy_pair *sp); + extern void c_complex32_ndarray_get_fancy(struct fancy_pair *sp); -extern void c_complex32_ndarray_get_fancy (struct fancy_pair *sp); + extern void c_complex64_ndarray_get_fancy(struct fancy_pair *sp); -extern void c_complex64_ndarray_get_fancy (struct fancy_pair *sp); + extern void c_float32_ndarray_set_fancy(struct fancy_pair *sp); -extern void c_float32_ndarray_set_fancy (struct fancy_pair *sp); + extern void c_float64_ndarray_set_fancy(struct fancy_pair *sp); -extern void c_float64_ndarray_set_fancy (struct fancy_pair *sp); + extern void c_complex32_ndarray_set_fancy(struct fancy_pair *sp); -extern void c_complex32_ndarray_set_fancy (struct fancy_pair *sp); + extern void c_complex64_ndarray_set_fancy(struct fancy_pair *sp); -extern void c_complex64_ndarray_set_fancy (struct fancy_pair *sp); + // compare two numbers (real & complex & int) +#define CEQF(X, Y) ((crealf(X) == crealf(Y)) && (cimagf(X) == cimagf(Y))) -// compare two numbers (real & complex & int) +#define CEQ(X, Y) ((creal(X) == creal(Y)) && (cimag(X) == cimag(Y))) -#define CEQF(X,Y) ((crealf(X) == crealf(Y)) && (cimagf(X) == cimagf(Y))) +#define CNEQF(X, Y) ((crealf(X) != crealf(Y)) || (cimagf(X) != cimagf(Y))) -#define CEQ(X,Y) ((creal(X) == creal(Y)) && (cimag(X) == cimag(Y))) +#define CNEQ(X, Y) ((creal(X) != creal(Y)) || (cimag(X) != cimag(Y))) -#define CNEQF(X,Y) ((crealf(X) != crealf(Y)) || (cimagf(X) != cimagf(Y))) +#define CLTF(X, Y) ((cabsf(X) < cabsf(Y)) || ((cabsf(X) == cabsf(Y)) && (cargf(X) < cargf(Y)))) -#define CNEQ(X,Y) ((creal(X) != creal(Y)) || (cimag(X) != cimag(Y))) +#define CGTF(X, Y) ((cabsf(X) > cabsf(Y)) || ((cabsf(X) == cabsf(Y)) && (cargf(X) > cargf(Y)))) -#define CLTF(X,Y) ((cabsf(X) < cabsf(Y)) || ((cabsf(X) == cabsf(Y)) && (cargf(X) < cargf(Y)))) +#define CLEF(X, Y) !CGTF(X, Y) -#define CGTF(X,Y) ((cabsf(X) > cabsf(Y)) || ((cabsf(X) == cabsf(Y)) && (cargf(X) > cargf(Y)))) +#define CGEF(X, Y) !CLTF(X, Y) -#define CLEF(X,Y) !CGTF(X,Y) +#define CLT(X, Y) ((cabs(X) < cabs(Y)) || ((cabs(X) == cabs(Y)) && (carg(X) < carg(Y)))) -#define CGEF(X,Y) !CLTF(X,Y) +#define CGT(X, Y) ((cabs(X) > cabs(Y)) || ((cabs(X) == cabs(Y)) && (carg(X) > carg(Y)))) -#define CLT(X,Y) ((cabs(X) < cabs(Y)) || ((cabs(X) == cabs(Y)) && (carg(X) < carg(Y)))) +#define CLE(X, Y) !CGT(X, Y) -#define CGT(X,Y) ((cabs(X) > cabs(Y)) || ((cabs(X) == cabs(Y)) && (carg(X) > carg(Y)))) +#define CGE(X, Y) !CLT(X, Y) -#define CLE(X,Y) !CGT(X,Y) + extern int float32_cmp(const void *a, const void *b); -#define CGE(X,Y) !CLT(X,Y) + extern int float64_cmp(const void *a, const void *b); -extern int float32_cmp (const void * a, const void * b); + extern int complex32_cmp(const void *a, const void *b); -extern int float64_cmp (const void * a, const void * b); + extern int complex64_cmp(const void *a, const void *b); -extern int complex32_cmp (const void * a, const void * b); + extern int int8_cmp(const void *a, const void *b); -extern int complex64_cmp (const void * a, const void * b); + extern int uint8_cmp(const void *a, const void *b); -extern int int8_cmp (const void * a, const void * b); + extern int int16_cmp(const void *a, const void *b); -extern int uint8_cmp (const void * a, const void * b); + extern int uint16_cmp(const void *a, const void *b); -extern int int16_cmp (const void * a, const void * b); + extern int int32_cmp(const void *a, const void *b); -extern int uint16_cmp (const void * a, const void * b); + extern int int64_cmp(const void *a, const void *b); -extern int int32_cmp (const void * a, const void * b); + extern int float32_cmp_r(const void *a, const void *b, const void *z); -extern int int64_cmp (const void * a, const void * b); + extern int float64_cmp_r(const void *a, const void *b, const void *z); -extern int float32_cmp_r (const void * a, const void * b, const void * z); + extern int complex32_cmp_r(const void *a, const void *b, const void *z); -extern int float64_cmp_r (const void * a, const void * b, const void * z); + extern int complex64_cmp_r(const void *a, const void *b, const void *z); -extern int complex32_cmp_r (const void * a, const void * b, const void * z); + extern int int8_cmp_r(const void *a, const void *b, const void *z); -extern int complex64_cmp_r (const void * a, const void * b, const void * z); + extern int uint8_cmp_r(const void *a, const void *b, const void *z); -extern int int8_cmp_r (const void * a, const void * b, const void * z); + extern int int16_cmp_r(const void *a, const void *b, const void *z); -extern int uint8_cmp_r (const void * a, const void * b, const void * z); + extern int uint16_cmp_r(const void *a, const void *b, const void *z); -extern int int16_cmp_r (const void * a, const void * b, const void * z); + extern int int32_cmp_r(const void *a, const void *b, const void *z); -extern int uint16_cmp_r (const void * a, const void * b, const void * z); + extern int int64_cmp_r(const void *a, const void *b, const void *z); -extern int int32_cmp_r (const void * a, const void * b, const void * z); + // acquire CPU cache sizes -extern int int64_cmp_r (const void * a, const void * b, const void * z); + extern void query_cache_sizes(int *l1p, int *l2p, int *l3p); + // copy two double type numbers, for interfacing to foreign functions + OWL_INLINE value cp_two_doubles(double d0, double d1) + { + value res = caml_alloc_small(2 * Double_wosize, Double_array_tag); + Store_double_field(res, 0, d0); + Store_double_field(res, 1, d1); + return res; + } -// acquire CPU cache sizes - -extern void query_cache_sizes(int* l1p, int* l2p, int* l3p); - - -// copy two double type numbers, for interfacing to foreign functions -OWL_INLINE value cp_two_doubles(double d0, double d1) { - value res = caml_alloc_small(2 * Double_wosize, Double_array_tag); - Store_double_field(res, 0, d0); - Store_double_field(res, 1, d1); - return res; -} - -// copy x to y with given offset and stride -OWL_INLINE void owl_float32_copy (int N, float* x, int ofsx, int incx, float* y, int ofsy, int incy) { - for (int i = 0; i < N; i++) { - *(y + ofsy) = *(x + ofsx); - ofsx += incx; - ofsy += incy; + // copy x to y with given offset and stride + OWL_INLINE void owl_float32_copy(int N, float *x, int ofsx, int incx, float *y, int ofsy, int incy) + { + for (int i = 0; i < N; i++) + { + *(y + ofsy) = *(x + ofsx); + ofsx += incx; + ofsy += incy; + } } -} -// copy x to y with given offset and stride -OWL_INLINE void owl_float64_copy (int N, double* x, int ofsx, int incx, double* y, int ofsy, int incy) { - for (int i = 0; i < N; i++) { - *(y + ofsy) = *(x + ofsx); - ofsx += incx; - ofsy += incy; + // copy x to y with given offset and stride + OWL_INLINE void owl_float64_copy(int N, double *x, int ofsx, int incx, double *y, int ofsy, int incy) + { + for (int i = 0; i < N; i++) + { + *(y + ofsy) = *(x + ofsx); + ofsx += incx; + ofsy += incy; + } } -} -// copy x to y with given offset and stride -OWL_INLINE void owl_complex32_copy (int N, _Complex float* x, int ofsx, int incx, _Complex float* y, int ofsy, int incy) { - for (int i = 0; i < N; i++) { - *(y + ofsy) = *(x + ofsx); - ofsx += incx; - ofsy += incy; + // copy x to y with given offset and stride + OWL_INLINE void owl_complex32_copy(int N, _Complex float *x, int ofsx, int incx, _Complex float *y, int ofsy, int incy) + { + for (int i = 0; i < N; i++) + { + *(y + ofsy) = *(x + ofsx); + ofsx += incx; + ofsy += incy; + } } -} -// copy x to y with given offset and stride -OWL_INLINE void owl_complex64_copy (int N, _Complex double* x, int ofsx, int incx, _Complex double* y, int ofsy, int incy) { - for (int i = 0; i < N; i++) { - *(y + ofsy) = *(x + ofsx); - ofsx += incx; - ofsy += incy; + // copy x to y with given offset and stride + OWL_INLINE void owl_complex64_copy(int N, _Complex double *x, int ofsx, int incx, _Complex double *y, int ofsy, int incy) + { + for (int i = 0; i < N; i++) + { + *(y + ofsy) = *(x + ofsx); + ofsx += incx; + ofsy += incy; + } } -} +#ifdef __cplusplus +} +#endif -#endif /* OWL_CORE_H */ +#endif /* OWL_CORE_H */ diff --git a/src/owl/fftpack/owl_fftpack_float32.cc b/src/owl/fftpack/owl_fftpack_float32.cc index f18d99ae7..6cd12da54 100644 --- a/src/owl/fftpack/owl_fftpack_float32.cc +++ b/src/owl/fftpack/owl_fftpack_float32.cc @@ -4,24 +4,10 @@ */ #include +#include "owl_core.h" #define Treal float -extern "C" -{ -#include "owl_core.h" - value float32_cfft(value vForward, value vX, value vY, value vD, value vNorm, value vNthreads); - value float32_cfft_bytecode(value *argv, int argn); - value float32_rfftf(value vX, value vY, value vD, value vNorm, value vNthreads); - value float32_rfftb(value vX, value vY, value vD, value vNorm, value vNthreads); - value float32_dct(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); - value float32_dct_bytecode(value *argv, int argn); - value float32_dst(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); - value float32_dst_bytecode(value *argv, int argn); -} - -#define REAL_COPY owl_float32_copy -#define COMPLEX_COPY owl_complex32_copy #define STUB_CFFT float32_cfft #define STUB_CFFT_bytecode float32_cfft_bytecode #define STUB_RFFTF float32_rfftf @@ -31,10 +17,20 @@ extern "C" #define STUB_RDST float32_dst #define STUB_RDST_bytecode float32_dst_bytecode +extern "C" +{ + value STUB_CFFT(value vForward, value vX, value vY, value vD, value vNorm, value vNthreads); + value STUB_CFFT_bytecode(value *argv, int argn); + value STUB_RFFTF(value vX, value vY, value vD, value vNorm, value vNthreads); + value STUB_RFFTB(value vX, value vY, value vD, value vNorm, value vNthreads); + value STUB_RDCT(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); + value STUB_RDCT_bytecode(value *argv, int argn); + value STUB_RDST(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); + value STUB_RDST_bytecode(value *argv, int argn); +} + #include "owl_fftpack_impl.h" -#undef REAL_COPY -#undef COMPLEX_COPY #undef STUB_CFFT #undef STUB_CFFT_bytecode #undef STUB_RFFTF diff --git a/src/owl/fftpack/owl_fftpack_float64.cc b/src/owl/fftpack/owl_fftpack_float64.cc index 2051941c9..2547a3e70 100644 --- a/src/owl/fftpack/owl_fftpack_float64.cc +++ b/src/owl/fftpack/owl_fftpack_float64.cc @@ -4,23 +4,9 @@ */ #include -#define Treal double - -extern "C" -{ #include "owl_core.h" - value float64_cfft(value vForward, value vX, value vY, value vD, value vNorm, value vNthreads); - value float64_cfft_bytecode(value *argv, int argn); - value float64_rfftf(value vX, value vY, value vD, value vNorm, value vNthreads); - value float64_rfftb(value vX, value vY, value vD, value vNorm, value vNthreads); - value float64_dct(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); - value float64_dct_bytecode(value *argv, int argn); - value float64_dst(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); - value float64_dst_bytecode(value *argv, int argn); -} +#define Treal double -#define REAL_COPY owl_float64_copy -#define COMPLEX_COPY owl_complex64_copy #define STUB_CFFT float64_cfft #define STUB_CFFT_bytecode float64_cfft_bytecode #define STUB_RFFTF float64_rfftf @@ -30,10 +16,20 @@ extern "C" #define STUB_RDST float64_dst #define STUB_RDST_bytecode float64_dst_bytecode +extern "C" +{ + value STUB_CFFT(value vForward, value vX, value vY, value vD, value vNorm, value vNthreads); + value STUB_CFFT_bytecode(value *argv, int argn); + value STUB_RFFTF(value vX, value vY, value vD, value vNorm, value vNthreads); + value STUB_RFFTB(value vX, value vY, value vD, value vNorm, value vNthreads); + value STUB_RDCT(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); + value STUB_RDCT_bytecode(value *argv, int argn); + value STUB_RDST(value vX, value vY, value vD, value vType, value vNorm, value vOrtho, value vNthreads); + value STUB_RDST_bytecode(value *argv, int argn); +} + #include "owl_fftpack_impl.h" -#undef REAL_COPY -#undef COMPLEX_COPY #undef STUB_CFFT #undef STUB_CFFT_bytecode #undef STUB_RFFTF From 85e5ae98c58e103e0f60234d20f8137df5a28e9a Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Tue, 19 Nov 2024 09:45:02 +0100 Subject: [PATCH 10/12] Adapted the fft2 and ifft2 prototypes to match new FFT module. - Added norm and nthreads parameters. --- src/owl/fftpack/owl_fft_d.mli | 4 ++-- src/owl/fftpack/owl_fft_generic.ml | 6 ++++-- src/owl/fftpack/owl_fft_generic.mli | 10 ++++++---- src/owl/fftpack/owl_fft_s.mli | 4 ++-- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/owl/fftpack/owl_fft_d.mli b/src/owl/fftpack/owl_fft_d.mli index ea80f18bb..583c25fb9 100644 --- a/src/owl/fftpack/owl_fft_d.mli +++ b/src/owl/fftpack/owl_fft_d.mli @@ -36,9 +36,9 @@ val irfft -> (Complex.t, complex64_elt) t -> (float, float64_elt) t -val fft2 : (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t +val fft2 : ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t -val ifft2 : (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t +val ifft2 : ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex64_elt) t -> (Complex.t, complex64_elt) t val dct : ?axis:int diff --git a/src/owl/fftpack/owl_fft_generic.ml b/src/owl/fftpack/owl_fft_generic.ml index 377e7a0c7..859bb28d4 100644 --- a/src/owl/fftpack/owl_fft_generic.ml +++ b/src/owl/fftpack/owl_fft_generic.ml @@ -77,9 +77,11 @@ let irfft ?axis ?n ?(norm : tnorm = Forward) ?(nthreads : int = 1) ~(otyp : ('a, y -let fft2 x = fft ~axis:0 x |> fft ~axis:1 +let fft2 ?(norm : tnorm = Backward) ?(nthreads : int = 1) x = + (fft ~axis:0 ~norm ~nthreads x) |> (fft ~axis:1 ~norm ~nthreads) -let ifft2 x = ifft ~axis:0 x |> ifft ~axis:1 +let ifft2 ?(norm : tnorm = Forward) ?(nthreads : int = 1) x = + (ifft ~axis:0 ~norm ~nthreads x) |> (ifft ~axis:1 ~norm ~nthreads) type ttrig_transform = | I diff --git a/src/owl/fftpack/owl_fft_generic.mli b/src/owl/fftpack/owl_fft_generic.mli index 73c7a4276..4d95575a6 100644 --- a/src/owl/fftpack/owl_fft_generic.mli +++ b/src/owl/fftpack/owl_fft_generic.mli @@ -61,11 +61,13 @@ val irfft [nthreads] is the desired number of threads used to compute the fft. Note the [n] parameter is used to specified the size of output. *) -val fft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t -(** [fft2 x] performs 2-dimensional FFT on a complex input. The return is not scaled. *) +val fft2 : ?norm:tnorm -> ?nthreads:int -> (Complex.t, 'b) t -> (Complex.t, 'b) t +(** [fft2 ~norm ~nthreads x] performs 2-dimensional FFT on a complex input. [norm] is the normalization option. + By default, [norm] is set to [Forward]. [nthreads] is the desired number of threads used to compute each of the fft. *) -val ifft2 : (Complex.t, 'a) t -> (Complex.t, 'a) t -(** [ifft2 x] performs inverse 2-dimensional FFT on a complex input. *) +val ifft2 : ?norm:tnorm -> ?nthreads:int -> (Complex.t, 'b) t -> (Complex.t, 'b) t +(** [ifft2 ~norm ~nthreads x] performs 2-dimensional inverse FFT on a complex input. [norm] is the normalization option. + By default, [norm] is set to [Backward]. [nthreads] is the desired number of threads used to compute each of the ifft. *) (** {5 Discrete Cosine & Sine Transforms functions} *) diff --git a/src/owl/fftpack/owl_fft_s.mli b/src/owl/fftpack/owl_fft_s.mli index 3914d685c..63a2c931d 100644 --- a/src/owl/fftpack/owl_fft_s.mli +++ b/src/owl/fftpack/owl_fft_s.mli @@ -36,9 +36,9 @@ val irfft -> (Complex.t, complex32_elt) t -> (float, float32_elt) t -val fft2 : (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t +val fft2 : ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t -val ifft2 : (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t +val ifft2 : ?norm:tnorm -> ?nthreads:int -> (Complex.t, complex32_elt) t -> (Complex.t, complex32_elt) t val dct : ?axis:int From 2ab8c8dd3e617f7e1556019e0ae1fb67ddfef595 Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Tue, 19 Nov 2024 09:46:46 +0100 Subject: [PATCH 11/12] Fixing documentation strings. --- src/owl/fftpack/owl_fft_generic.mli | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/owl/fftpack/owl_fft_generic.mli b/src/owl/fftpack/owl_fft_generic.mli index 4d95575a6..68ccbfb23 100644 --- a/src/owl/fftpack/owl_fft_generic.mli +++ b/src/owl/fftpack/owl_fft_generic.mli @@ -63,11 +63,11 @@ val irfft val fft2 : ?norm:tnorm -> ?nthreads:int -> (Complex.t, 'b) t -> (Complex.t, 'b) t (** [fft2 ~norm ~nthreads x] performs 2-dimensional FFT on a complex input. [norm] is the normalization option. - By default, [norm] is set to [Forward]. [nthreads] is the desired number of threads used to compute each of the fft. *) + By default, [norm] is set to [Backward]. [nthreads] is the desired number of threads used to compute each of the fft. *) val ifft2 : ?norm:tnorm -> ?nthreads:int -> (Complex.t, 'b) t -> (Complex.t, 'b) t (** [ifft2 ~norm ~nthreads x] performs 2-dimensional inverse FFT on a complex input. [norm] is the normalization option. - By default, [norm] is set to [Backward]. [nthreads] is the desired number of threads used to compute each of the ifft. *) + By default, [norm] is set to [Forward]. [nthreads] is the desired number of threads used to compute each of the ifft. *) (** {5 Discrete Cosine & Sine Transforms functions} *) From 8fa2d85ead7f1fc18b7561e5dd30c224bc88808f Mon Sep 17 00:00:00 2001 From: Gabriel Santamaria Date: Tue, 19 Nov 2024 18:49:35 +0100 Subject: [PATCH 12/12] Little code change for more similarities with scipy's implementation. --- src/owl/fftpack/owl_fft_generic.ml | 4 +-- src/owl/fftpack/owl_fftpack_impl.h | 46 +++++++++++++++--------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/owl/fftpack/owl_fft_generic.ml b/src/owl/fftpack/owl_fft_generic.ml index 859bb28d4..c3abac5fa 100644 --- a/src/owl/fftpack/owl_fft_generic.ml +++ b/src/owl/fftpack/owl_fft_generic.ml @@ -12,8 +12,8 @@ type tnorm = let tnorm_to_int = function | Backward -> 0 - | Forward -> 1 - | Ortho -> 2 + | Forward -> 2 + | Ortho -> 1 let fft ?axis ?(norm : tnorm = Backward) ?(nthreads : int = 1) x = let axis = diff --git a/src/owl/fftpack/owl_fftpack_impl.h b/src/owl/fftpack/owl_fftpack_impl.h index 12bb48420..92249cca9 100644 --- a/src/owl/fftpack/owl_fftpack_impl.h +++ b/src/owl/fftpack/owl_fftpack_impl.h @@ -20,34 +20,34 @@ using namespace pocketfft::detail; +using ldbl_t = typename std::conditional< + sizeof(long double) == sizeof(double), double, long double>::type; + template T norm_fct(int inorm, size_t N) { - switch (inorm) - { - case 0: // "backward" - no normalization for forward transform + if (inorm == 0) return T(1); - case 1: // "forward" - 1/n normalization for forward transform - return T(1) / T(N); - case 2: // "ortho" - 1/sqrt(n) normalization for both directions - return T(1) / std::sqrt(T(N)); - default: - caml_failwith("invalid value for inorm (must be 0, 1, or 2)"); - // This will never be reached - return T(0); - } + if (inorm == 2) + return T(1 / ldbl_t(N)); + if (inorm == 1) + return T(1 / sqrt(ldbl_t(N))); + caml_failwith("invalid value for norm (must be 0, 1, or 2)"); // could make use of caml exections + // This will never be reached + return T(0); } template -T compute_norm_factor(const shape_t &dims, const shape_t &axes, int inorm, size_t fct = 1, int delta = 0) +T norm_fct(int inorm, const shape_t &shape, + const shape_t &axes, size_t fct = 1, int delta = 0) { if (inorm == 0) return T(1); + size_t N = 1; for (auto a : axes) - { - N *= fct * size_t(int64_t(dims[a]) + delta); - } + N *= fct * size_t(int64_t(shape[a]) + delta); + return norm_fct(inorm, N); } @@ -95,7 +95,7 @@ value STUB_CFFT(value vForward, value vX, value vY, value vD, value vNorm, value shape_t axes{static_cast(d)}; { - Treal norm_factor = compute_norm_factor(dims, axes, norm); + Treal norm_factor = norm_fct(norm, dims, axes); try { pocketfft::c2c(dims, stride_in, stride_out, axes, forward, @@ -166,7 +166,7 @@ value STUB_RFFTF(value vX, value vY, value vD, value vNorm, value vNthreads) shape_t axes{static_cast(d)}; { - Treal norm_factor = compute_norm_factor(dims, axes, norm); + Treal norm_factor = norm_fct(norm, dims, axes); try { pocketfft::r2c(dims, stride_in, stride_out, axes, pocketfft::FORWARD, @@ -239,7 +239,7 @@ value STUB_RFFTB(value vX, value vY, value vD, value vNorm, value vNthreads) shape_t axes{static_cast(d)}; { - Treal norm_factor = compute_norm_factor(dims, axes, norm); + Treal norm_factor = norm_fct(norm, dims, axes); try { pocketfft::c2r(dims, stride_in, stride_out, axes, pocketfft::BACKWARD, @@ -299,8 +299,8 @@ value STUB_RDCT(value vX, value vY, value vD, value vType, value vNorm, value vO shape_t axes{static_cast(d)}; { - Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, -1) - : compute_norm_factor(dims, axes, norm, 2); + Treal norm_factor = (type == 1) ? norm_fct(norm, dims, axes, 2, -1) + : norm_fct(norm, dims, axes, 2); try { pocketfft::dct(dims, stride_in, stride_out, axes, type, @@ -365,8 +365,8 @@ value STUB_RDST(value vX, value vY, value vD, value vType, value vNorm, value vO shape_t axes{static_cast(d)}; { - Treal norm_factor = (type == 1) ? compute_norm_factor(dims, axes, norm, 2, 1) - : compute_norm_factor(dims, axes, norm, 2); + Treal norm_factor = (type == 1) ? norm_fct(norm, dims, axes, 2, 1) + : norm_fct(norm, dims, axes, 2); try { pocketfft::dst(dims, stride_in, stride_out, axes, type,