diff --git a/Makefile.in b/Makefile.in index d87350671a..7b4d3953fe 100644 --- a/Makefile.in +++ b/Makefile.in @@ -189,7 +189,8 @@ HEADER_DIRS := \ fmpz_mod_mpoly_factor fmpq_mpoly_factor \ fq_nmod_mpoly_factor fq_zech_mpoly_factor \ \ - fft @FFT_SMALL@ fmpz_poly_q fmpz_lll \ + fft n_fft @FFT_SMALL@ \ + fmpz_poly_q fmpz_lll \ n_poly arith qsieve aprcl \ \ nf nf_elem qfb \ diff --git a/src/n_fft.h b/src/n_fft.h new file mode 100644 index 0000000000..6676fe9fd2 --- /dev/null +++ b/src/n_fft.h @@ -0,0 +1,258 @@ +/* + Copyright (C) 2024 Vincent Neiger + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . +*/ + +#ifndef N_FFT_H +#define N_FFT_H + +#include "ulong_extras.h" + +#define N_FFT_CTX_DEFAULT_DEPTH 12 + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * TODO[short term] augment precomputations with inverse roots + * TODO[short term] add testing for general variants, not only node0 + * TODO[long term] large depth can lead to heavy memory usage + * --> provide precomputation-free functions + * TODO[long term] on zen4 (likely on other cpus as well) ctx_init becomes + * slower at some point, losing a factor 4 or more, probably due to caching; + * what is annoying is that the depth where it becomes slower is significantly + * smaller (~13-14) when tab_iw has been incorporated compared to without + * tab_iw (it was depth ~20-21); see if this can be understood, and maybe play + * with vectorization for those simple functions + * TODO[later] provide forward function which reduces output to [0..n) ? + * unclear this is useful... to be decided later + */ + +/** n_fft context: + * parameters and tabulated powers of the primitive root of unity "w". 
 **/

typedef struct
{
    ulong mod;                    // modulus, odd prime
    ulong max_depth;              // maximum supported depth (w has order 2**max_depth)
    ulong cofactor;               // prime is 1 + cofactor * 2**max_depth
    ulong depth;                  // depth supported by current precomputation
    nn_ptr tab_w;                 // tabulated powers of w, see below (allocated by init, released by n_fft_ctx_clear)
    nn_ptr tab_iw;                // tabulated powers of 1/w, see below (allocated by init, released by n_fft_ctx_clear)
    ulong tab_w2[2*FLINT_BITS];   // tabulated powers w**(2**k), see below
    ulong tab_inv2[2*FLINT_BITS]; // tabulated inverses of 2**k, see below
} n_fft_ctx_struct;
// array-of-one typedef: a context is declared as a value and passed as a pointer
typedef n_fft_ctx_struct n_fft_ctx_t[1];


/** Requirements (not checked upon init):
 *     - mod is an odd prime < 2**(FLINT_BITS-2)
 *     - max_depth must be >= 3 (so, 8 must divide mod - 1)
 * Total memory cost of precomputations for arrays tab_{w,iw,w2,inv2}:
 *     at most 2 * (2*FLINT_BITS + 2**depth) ulong's
 */

/** tab_w2:
 *    - length 2*FLINT_BITS, with undefined entries at index 2*(max_depth-1) and beyond
 *    - contains powers w**d for d a power of 2, and corresponding
 *      precomputations for modular multiplication:
 *       -- for 0 <= k < max_depth-1, tab_w2[2*k] = w**(2**(max_depth-2-k))
 *          and tab_w2[2*k+1] = floor(tab_w2[2*k] * 2**FLINT_BITS / mod)
 *       -- for 2*(max_depth-1) <= k < 2*FLINT_BITS, tab_w2[k] is undefined
 *
 *  --> one can retrieve w as tab_w2[2 * (max_depth-2)]
 *  --> the first elements are tab_w2 = [I, I_pr, J, J_pr, ...]
 *      where I is a square root of -1 and J is a square root of I
 */

/** tab_w:
 *    - length 2**depth
 *    - contains 2**(depth-1) first powers of w in (max_depth-1)-bit reversed order,
 *      and corresponding precomputations for modular multiplication:
 *       -- for 0 <= k < 2**(depth-1), tab_w[2*k] = w**(br[k])
 *          and tab_w[2*k+1] = floor(tab_w[2*k] * 2**FLINT_BITS / mod)
 *      where br = [0, 2**(max_depth-2), 2**(max_depth-3), 3 * 2**(max_depth-3), ...]
 *      is the bit reversal permutation of length 2**(max_depth-1)
 *      (https://en.wikipedia.org/wiki/Bit-reversal_permutation)
 *
 * In particular the first elements are
 *       tab_w = [1, 1_pr, I, I_pr, J, J_pr, IJ, IJ_pr, ...]
 * where I is a square root of -1, J is a square root of I, and IJ = I*J. Note
 * that powers of w beyond 2**(max_depth-1), for example -1, -I, -J, etc. are
 * not stored.
 **/

/** tab_iw: same as tab_w but for the primitive root 1/w */

/** tab_inv2:
 *    - length 2*FLINT_BITS, with undefined entries at index 2*max_depth and beyond
 *    - contains the modular inverses of 2**k, and corresponding
 *      precomputations for modular multiplication:
 *       -- for 0 <= k < max_depth, tab_inv2[2*k] = the inverse of 2**(k+1)
 *          modulo mod, and tab_inv2[2*k+1] = floor(tab_inv2[2*k] * 2**FLINT_BITS / mod)
 *       -- for 2*max_depth <= k < 2*FLINT_BITS, tab_inv2[k] is undefined
 *
 * Recall F->mod == 1 + cofactor * 2**max_depth, so
 *        1 == F->mod - cofactor * 2**(max_depth - k) * 2**k
 *  --> the inverse of 2**k in (0, F->mod) is
 *        F->mod - cofactor * 2**(max_depth - k),
 * we do not really need to store it, but we want the precomputations as well
 */

/** Note for init functions, when depth is provided:
 *   - if it is < 3, it is pretended that it is 3
 *   - if it is more than F->max_depth (the maximum possible with the given
 *     prime), it is reduced to F->max_depth
 * After calling init, precomputations support DFTs of length up to 2**depth
 */

// initialize with given root and given depth
// (w must have order 2**max_depth modulo mod == 1 + cofactor * 2**max_depth)
void n_fft_ctx_init2_root(n_fft_ctx_t F, ulong w, ulong max_depth, ulong cofactor, ulong depth, ulong mod);

// find primitive root, initialize with given depth
void n_fft_ctx_init2(n_fft_ctx_t F, ulong depth, ulong p);

// same as n_fft_ctx_init2_root, with default depth N_FFT_CTX_DEFAULT_DEPTH
FLINT_FORCE_INLINE
void n_fft_ctx_init_root(n_fft_ctx_t F, ulong w, ulong max_depth, ulong cofactor, ulong p)
{ n_fft_ctx_init2_root(F, w, max_depth, cofactor, N_FFT_CTX_DEFAULT_DEPTH, p); }

// same as n_fft_ctx_init2, with default depth N_FFT_CTX_DEFAULT_DEPTH
FLINT_FORCE_INLINE
void n_fft_ctx_init(n_fft_ctx_t F, ulong p)
{ n_fft_ctx_init2(F, N_FFT_CTX_DEFAULT_DEPTH, p); }

// grows F->depth and precomputations to support DFTs of depth up to depth
void n_fft_ctx_fit_depth(n_fft_ctx_t F, ulong depth);

// releases the tables tab_w and tab_iw of F
void n_fft_ctx_clear(n_fft_ctx_t F);



/** Lightweight argument bundle handed to the transform kernels: the modulus,
 * its double, and one table of roots (tab_w or tab_iw of a context). The
 * table is only borrowed (nn_srcptr), it is not owned by this struct.
 */
typedef struct
{
    ulong mod;                 // modulus, odd prime
    ulong mod2;                // 2*mod  (storing helps for speed)
    //ulong mod4;                // 4*mod  (storing helps for speed)
    nn_srcptr tab_w;           // tabulated powers of w, same layout as n_fft_ctx_struct.tab_w
} n_fft_args_struct;
typedef n_fft_args_struct n_fft_args_t[1];

// fill F from a modulus and a table of roots; mod2 is derived as 2*mod
FLINT_FORCE_INLINE
void n_fft_set_args(n_fft_args_t F, ulong mod, nn_srcptr tab_w)
{
    F->mod = mod;
    F->mod2 = 2*mod;
    F->tab_w = tab_w;
}





/** dft:
 * transforms / inverse transforms / transposed transforms
 * at length a power of 2
 */

void dft_node0_lazy14(nn_ptr p, ulong depth, n_fft_args_t F);

/** 2**depth-point DFT
 * * in [0..n) / out [0..4n) / max < 4n
 * * In-place transform p of length len == 2**depth into
 * the concatenation of
 *       [sum(p[i] * w_k**i for i in range(len)), sum(p[i] * (-w_k)**i for i in range(len))]
 * for k in range(len/2),
 * where w_k = F->tab_w[2*k] for 0 <= k < 2**(depth-1)
 * * By construction these evaluation points are the roots of the polynomial
 * x**len - 1, precisely they are all powers of the chosen len-th primitive
 * root of unity with exponents listed in bit reversed order
 * * Requirements (not checked): depth <= F.depth
 */
FLINT_FORCE_INLINE void n_fft_dft(nn_ptr p, ulong depth, n_fft_ctx_t F)
{
    n_fft_args_t Fargs;
    n_fft_set_args(Fargs, F->mod, F->tab_w);
    dft_node0_lazy14(p, depth, Fargs);
}

// FIXME in progress
// not tested yet --> test == applying dft yields identity
// DOC. Note: output < n.
+void idft_node0_lazy12(nn_ptr p, ulong depth, n_fft_args_t F); +FLINT_FORCE_INLINE void n_fft_idft(nn_ptr p, ulong depth, n_fft_ctx_t F) +{ + n_fft_args_t Fargs; + n_fft_set_args(Fargs, F->mod, F->tab_iw); + idft_node0_lazy12(p, depth, Fargs); + + if (depth > 0) + { + const ulong inv2 = F->tab_inv2[2*depth-2]; + const ulong inv2_pr = F->tab_inv2[2*depth-1]; + //ulong p_hi, p_lo; + for (ulong k = 0; k < (UWORD(1) << depth); k++) + { + p[k] = n_mulmod_shoup(inv2, p[k], inv2_pr, F->mod); + //umul_ppmm(p_hi, p_lo, inv2_pr, p[k]); + //p[k] = inv2 * p[k] - p_hi * F->mod; + } + // NOTE: apparently no gain from lazy variant, so + // probably better to use non-lazy one (ensures output < n) + } + // FIXME see if that can be made less expensive at least for depths not too + // small, by integrating into base cases of dft_node0 +} + + + +// FIXME in progress +// not tested yet --> test == naive version? +// DOC. Note: output < 2n (?). +FLINT_FORCE_INLINE void n_fft_dft_t(nn_ptr p, ulong depth, n_fft_ctx_t F) +{ + n_fft_args_t Fargs; + n_fft_set_args(Fargs, F->mod, F->tab_w); + idft_node0_lazy12(p, depth, Fargs); +} + +// FIXME in progress +// not tested yet --> test == applying dft_t yields identity? +// DOC. Note: output < n. +FLINT_FORCE_INLINE void n_fft_idft_t(nn_ptr p, ulong depth, n_fft_ctx_t F) +{ + n_fft_args_t Fargs; + n_fft_set_args(Fargs, F->mod, F->tab_iw); + dft_node0_lazy14(p, depth, Fargs); + + if (depth > 0) + { + // see comments in idft concerning this loop + const ulong inv2 = F->tab_inv2[2*depth-2]; + const ulong inv2_pr = F->tab_inv2[2*depth-1]; + for (ulong k = 0; k < (UWORD(1) << depth); k++) + p[k] = n_mulmod_shoup(inv2, p[k], inv2_pr, F->mod); + } +} + + + + +#ifdef __cplusplus +} +#endif + +#endif /* N_FFT_H */ diff --git a/src/n_fft/ctx_init.c b/src/n_fft/ctx_init.c new file mode 100644 index 0000000000..d05a5f9463 --- /dev/null +++ b/src/n_fft/ctx_init.c @@ -0,0 +1,174 @@ +/* + Copyright (C) 2024 Vincent Neiger + + This file is part of FLINT. 
+ + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . +*/ + +#include "n_fft.h" + +/** Given the precomputed quotient a_pr for modular multiplication by a mod n, + * a_pr == floor(a * 2**FLINT_BITS / n) + * where we assume 0 < a < n and n does not divide a * 2**FLINT_BITS, + * this returns the quotient for mulmod by -a mod n, + * floor( (n-a) * 2**FLINT_BITS / n) + * == 2**FLINT_BITS - ceil(a * 2**FLINT_BITS / n) + * == 2**FLINT_BITS - a_pr + * + * Note: the requirement "n does not divide a * 2**FLINT_BITS" follows + * from the other requirement 0 < a < n as soon as n is odd; in n_fft.h + * we will only use this for odd primes + */ +FLINT_FORCE_INLINE ulong n_mulmod_precomp_shoup_negate(ulong a_pr) +{ + return UWORD_MAX - a_pr; +} + +void n_fft_ctx_init2_root(n_fft_ctx_t F, ulong w, ulong max_depth, ulong cofactor, ulong depth, ulong p) +{ + if (depth < 3) + depth = 3; + if (max_depth < depth) + depth = max_depth; + + // fill basic attributes + F->mod = p; + F->max_depth = max_depth; + F->cofactor = cofactor; + F->depth = 3; // to be able to call fit_depth below + + // fill tab_w2 + ulong pr_quo, pr_rem, ww; + ww = w; + n_mulmod_precomp_shoup_quo_rem(&pr_quo, &pr_rem, ww, p); + F->tab_w2[2*(max_depth-2)] = ww; + F->tab_w2[2*(max_depth-2)+1] = pr_quo; + for (slong k = max_depth-3; k >= 0; k--) + { + // ww <- ww**2 and its precomputed quotient + n_mulmod_and_precomp_shoup(&ww, &pr_quo, ww, ww, pr_quo, pr_rem, pr_quo, p); + pr_rem = n_mulmod_precomp_shoup_rem_from_quo(pr_quo, p); + F->tab_w2[2*k] = ww; + F->tab_w2[2*k+1] = pr_quo; + } + // at this stage, pr_quo and pr_rem are for k == 0 i.e. 
for I == tab_w2[0] + + // fill tab_inv2 + for (ulong k = 0; k < max_depth; k++) + { + F->tab_inv2[2*k] = p - (cofactor << (max_depth - k-1)); + F->tab_inv2[2*k+1] = n_mulmod_precomp_shoup(F->tab_inv2[2*k], p); + } + + // fill tab_w and tab_iw for depth 3 + ulong len = UWORD(1) << (depth-1); // len >= 4 + F->tab_w = (nn_ptr) flint_malloc(2*len * sizeof(ulong)); + F->tab_iw = (nn_ptr) flint_malloc(2*len * sizeof(ulong)); + + // w**0 == iw**0 == 1 + F->tab_w[0] = UWORD(1); + F->tab_w[1] = n_mulmod_precomp_shoup(UWORD(1), p); + F->tab_iw[0] = UWORD(1); + F->tab_iw[1] = F->tab_w[1]; + + // w**(L/4) == I and iw**(L/4) == -I, L == 2**max_depth + F->tab_w[2] = F->tab_w2[0]; + F->tab_w[3] = F->tab_w2[1]; + F->tab_iw[2] = p - F->tab_w2[0]; + F->tab_iw[3] = n_mulmod_precomp_shoup_negate(F->tab_w2[1]); + + // w**(L/8) == J and w**(3L/8) == I*J + F->tab_w[4] = F->tab_w2[2]; + F->tab_w[5] = F->tab_w2[3]; + n_mulmod_and_precomp_shoup(F->tab_w+6, F->tab_w+7, F->tab_w2[0], F->tab_w2[2], pr_quo, pr_rem, F->tab_w2[3], p); + + // iw**(L/8) == -I*J and iw**(3L/8) == -J + F->tab_iw[4] = p - F->tab_w[6]; + F->tab_iw[5] = n_mulmod_precomp_shoup_negate(F->tab_w[7]); + F->tab_iw[6] = p - F->tab_w[4]; + F->tab_iw[7] = n_mulmod_precomp_shoup_negate(F->tab_w[5]); + + // complete tab_w up to specified depth + n_fft_ctx_fit_depth(F, depth); +} + +void n_fft_ctx_init2(n_fft_ctx_t F, ulong depth, ulong p) +{ + FLINT_ASSERT(p > 2 && flint_clz(p) >= 2); // 2 < p < 2**(FLINT_BITS-2) + FLINT_ASSERT(flint_ctz(p - UWORD(1)) >= 3); // p-1 divisible by 8 + + // find the constant and exponent such that p == c * 2**max_depth + 1 + const ulong max_depth = flint_ctz(p - UWORD(1)); + const ulong cofactor = (p - UWORD(1)) >> max_depth; + + // find primitive root w of order 2**max_depth + const ulong prim_root = n_primitive_root_prime(p); + const ulong w = n_powmod2(prim_root, cofactor, p); + + // fill all attributes and tables + n_fft_ctx_init2_root(F, w, max_depth, cofactor, depth, p); +} + +void 
n_fft_ctx_clear(n_fft_ctx_t F) +{ + flint_free(F->tab_w); + flint_free(F->tab_iw); +} + +void n_fft_ctx_fit_depth(n_fft_ctx_t F, ulong depth) +{ + if (F->max_depth < depth) + depth = F->max_depth; + + if (depth > F->depth) + { + ulong len = UWORD(1) << (depth-1); // len >= 8 (since depth >= 4) + F->tab_w = flint_realloc(F->tab_w, 2*len * sizeof(ulong)); + F->tab_iw = flint_realloc(F->tab_iw, 2*len * sizeof(ulong)); + + // tab_w[2] is w**(L/8) * tab_w[0], where L = 2**max_depth, + // tab_w[2*4,2*6] is w**(L/16) * tab_w[2*0,2*2], + // tab_w[2*8,2*10,2*12,2*14] is w**(L/32) * tab_w[2*0,2*2,2*4,2*6], etc. + // recall tab_w2[2*k] == w**(L / 2**(k+2)) + ulong d = F->depth - 1; + ulong llen = UWORD(1) << (F->depth-1); + ulong ww, pr_quo, pr_rem; + for ( ; llen < len; llen <<= 1, d += 1) + { + ww = F->tab_w2[2*d]; + pr_quo = F->tab_w2[2*d+1]; + pr_rem = n_mulmod_precomp_shoup_rem_from_quo(pr_quo, F->mod); + // for each k, tab_w[2*(k+llen)] <- ww * tab_w[2*k], and deduce precomputation + for (ulong k = 0; k < llen; k+=4) + { + n_mulmod_and_precomp_shoup(F->tab_w + 2*llen + 2*(k+0), F->tab_w + 2*llen + 2*(k+0)+1, + ww, F->tab_w[2*(k+0)], + pr_quo, pr_rem, F->tab_w[2*(k+0)+1], F->mod); + n_mulmod_and_precomp_shoup(F->tab_w + 2*llen + 2*(k+1), F->tab_w + 2*llen + 2*(k+1)+1, + ww, F->tab_w[2*(k+1)], + pr_quo, pr_rem, F->tab_w[2*(k+1)+1], F->mod); + n_mulmod_and_precomp_shoup(F->tab_w + 2*llen + 2*(k+2), F->tab_w + 2*llen + 2*(k+2)+1, + ww, F->tab_w[2*(k+2)], + pr_quo, pr_rem, F->tab_w[2*(k+2)+1], F->mod); + n_mulmod_and_precomp_shoup(F->tab_w + 2*llen + 2*(k+3), F->tab_w + 2*llen + 2*(k+3)+1, + ww, F->tab_w[2*(k+3)], + pr_quo, pr_rem, F->tab_w[2*(k+3)+1], F->mod); + + F->tab_iw[2*llen + 2*(llen-1-(k+0))] = F->mod - F->tab_w[2*llen + 2*(k+0)]; + F->tab_iw[2*llen + 2*(llen-1-(k+0)) + 1] = n_mulmod_precomp_shoup_negate(F->tab_w[2*llen + 2*(k+0)+1]); + F->tab_iw[2*llen + 2*(llen-1-(k+1))] = F->mod - F->tab_w[2*llen + 2*(k+1)]; + F->tab_iw[2*llen + 2*(llen-1-(k+1)) + 1] = 
n_mulmod_precomp_shoup_negate(F->tab_w[2*llen + 2*(k+1)+1]); + F->tab_iw[2*llen + 2*(llen-1-(k+2))] = F->mod - F->tab_w[2*llen + 2*(k+2)]; + F->tab_iw[2*llen + 2*(llen-1-(k+2)) + 1] = n_mulmod_precomp_shoup_negate(F->tab_w[2*llen + 2*(k+2)+1]); + F->tab_iw[2*llen + 2*(llen-1-(k+3))] = F->mod - F->tab_w[2*llen + 2*(k+3)]; + F->tab_iw[2*llen + 2*(llen-1-(k+3)) + 1] = n_mulmod_precomp_shoup_negate(F->tab_w[2*llen + 2*(k+3)+1]); + } + } + + F->depth = depth; + } +} diff --git a/src/n_fft/dft.c b/src/n_fft/dft.c new file mode 100644 index 0000000000..97205418e4 --- /dev/null +++ b/src/n_fft/dft.c @@ -0,0 +1,728 @@ +/* + Copyright (C) 2024 Vincent Neiger + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . +*/ + +#include "n_fft.h" +#include "n_fft_macros.h" + +/*-------------*/ +/* 2-point DFT */ +/*-------------*/ + +/** Cooley-Tukey butterfly, general + * * in [0..4n) / out [0..4n) / max < 4n + * * In-place transform + * [1 1] + * [a b] <- [a b] [w -w] + * * n2 is 2*n, w_pr is the precomputed data for multiplication by w mod n + * p_hi, p_lo, u, v are temporaries + */ +#define DFT2_LAZY44(a, b, n, n2, w, w_pr, p_hi, p_lo, u, v) \ + do { \ + u = (a); \ + if (u >= (n2)) \ + u -= (n2); /* [0..2n) */ \ + v = (b); \ + N_MULMOD_PRECOMP_LAZY(v, w, v, w_pr, n, p_hi, p_lo); \ + (a) = u + v; /* [0..4n) */ \ + (b) = u + (n2) - v; /* [0..4n) */ \ + } while(0) + +/*-------------*/ +/* 4-point DFT */ +/*-------------*/ + +/** 4-point DFT, general + * * in [0..4n) / out [0..4n) / max < 4n + * * In-place transform + * [ 1 1 1 1] + * [w2 -w2 w3 -w3] + * [a b c d] <- [a b c d] [w1 w1 -w1 -w1] + * [w1*w2 -w1*w2 -w1*w3 w1*w3] + * * Corresponds to reducing down the tree with nodes + * x^4 - w1**2 + * / \ + * x^2 - w1 x^2 + w1 + * / \ / \ + * x - w2 x 
/*-------------*/
/* 4-point DFT */
/*-------------*/

/** 4-point DFT, general
 * * in [0..4n) / out [0..4n) / max < 4n
 * * In-place transform
 *                                 [ 1          1       1       1]
 *                                 [w2        -w2      w3     -w3]
 *    [a  b  c  d] <- [a  b  c  d] [w1         w1     -w1     -w1]
 *                                 [w1*w2  -w1*w2  -w1*w3   w1*w3]
 * * Corresponds to reducing down the tree with nodes
 *                       x^4 - w1**2
 *                     /             \
 *             x^2 - w1                x^2 + w1
 *             /      \                /      \
 *        x - w2      x + w2      x - w3      x + w3
 * typically w2**2 == w1 and w3 == I*w2 (hence w3**2 == -w1) so that the above
 * is a Vandermonde matrix and this tree really is the subproduct tree built
 * from the four roots w2, -w2, I*w2, -I*w2 of x**4 - w1**2
 * * n2 is 2*n; wK_pr is the precomputed data for multiplication by wK mod n;
 *   p_hi, p_lo, tmp are temporaries
 */
#define DFT4_LAZY44(a, b, c, d,                                \
                    w1, w1_pr, w2, w2_pr, w3, w3_pr,           \
                    n, n2, p_hi, p_lo, tmp)                    \
do {                                                           \
    ulong u0 = (a);                                            \
    ulong u1 = (b);                                            \
    ulong u2 = (c);                                            \
    ulong u3 = (d);                                            \
    if (u0 >= n2)                                              \
        u0 -= n2;                                              \
    if (u1 >= n2)                                              \
        u1 -= n2;                                              \
                                                               \
    N_MULMOD_PRECOMP_LAZY(u2, w1, u2, w1_pr, n, p_hi, p_lo);   \
    tmp = u0;                                                  \
    u0 = u0 + u2;                    /* [0..4n) */             \
    u2 = tmp + n2 - u2;              /* [0..4n) */             \
    if (u0 >= n2)                                              \
        u0 -= n2;                    /* [0..2n) */             \
    if (u2 >= n2)                                              \
        u2 -= n2;                    /* [0..2n) */             \
                                                               \
    N_MULMOD_PRECOMP_LAZY(u3, w1, u3, w1_pr, n, p_hi, p_lo);   \
    tmp = u1;                                                  \
    u1 = u1 + u3;                    /* [0..4n) */             \
    u3 = tmp + n2 - u3;              /* [0..4n) */             \
                                                               \
    N_MULMOD_PRECOMP_LAZY(u1, w2, u1, w2_pr, n, p_hi, p_lo);   \
    tmp = u0;                                                  \
    (a) = u0 + u1;                    /* [0..4n) */            \
    (b) = tmp + n2 - u1;              /* [0..4n) */            \
                                                               \
    N_MULMOD_PRECOMP_LAZY(u3, w3, u3, w3_pr, n, p_hi, p_lo);   \
    tmp = u2;                                                  \
    (c) = u2 + u3;                    /* [0..4n) */            \
    (d) = tmp + n2 - u3;              /* [0..4n) */            \
} while(0)

/*-------------*/
/* 8-point DFT */
/*-------------*/

/** 8-point DFT, node 0
 * * in [0..n) / out [0..4n) / max < 4n
 * * In-place transform p = [p0,p1,p2,p3,p4,p5,p6,p7], seen as a polynomial
 * p(x) = p0 + p1*x + ... + p7*x**7 into its evaluations
 *       p(1), p(-1), p(I), p(-I), p(J), p(-J), p(I*J), p(-I*J)
 * i.e.
the evaluations at all 8-th roots of unity J**k for 0 <= k < 8 in
 * bit-reversed order
 * * Recall [F->tab_w[2*k] for k in range(4)] == [1, I, J, IJ]
 * * mod2 is 2*mod; tab_w is the table of roots with the layout of n_fft_args_struct.tab_w
 */
#define DFT8_NODE0_LAZY14(p0, p1, p2, p3, p4, p5, p6, p7,   \
                          mod, mod2, tab_w)                 \
do {                                                        \
    ulong p_hi, p_lo, tmp;                                  \
                                                            \
    BUTTERFLY_LAZY12(p0, p4, mod, tmp);                     \
    BUTTERFLY_LAZY12(p1, p5, mod, tmp);                     \
    BUTTERFLY_LAZY12(p2, p6, mod, tmp);                     \
    BUTTERFLY_LAZY12(p3, p7, mod, tmp);                     \
                                                            \
    DFT4_NODE0_LAZY24(p0, p1, p2, p3,                       \
                      tab_w[2], tab_w[3],                   \
                      mod, mod2, p_hi, p_lo);               \
    /* could use a lazy24 variant of the next macro, */     \
    /* but the gain is negligible                    */     \
    DFT4_LAZY44(p4, p5, p6, p7,                             \
                tab_w[2], tab_w[3],                         \
                tab_w[4], tab_w[5],                         \
                tab_w[6], tab_w[7],                         \
                mod, mod2, p_hi, p_lo, tmp);                \
} while(0)

/** 8-point DFT, node 0
 * * in [0..2n) / out [0..4n) / max < 4n
 * * apart from these ranges, same specification as DFT8_NODE0_LAZY14
 *   (the first layer uses BUTTERFLY_LAZY24 with mod2 instead of
 *   BUTTERFLY_LAZY12 with mod)
 */
#define DFT8_NODE0_LAZY24(p0, p1, p2, p3, p4, p5, p6, p7,   \
                          mod, mod2, tab_w)                 \
do {                                                        \
    ulong p_hi, p_lo, tmp;                                  \
                                                            \
    BUTTERFLY_LAZY24(p0, p4, mod2, tmp);                    \
    BUTTERFLY_LAZY24(p1, p5, mod2, tmp);                    \
    BUTTERFLY_LAZY24(p2, p6, mod2, tmp);                    \
    BUTTERFLY_LAZY24(p3, p7, mod2, tmp);                    \
                                                            \
    DFT4_NODE0_LAZY24(p0, p1, p2, p3,                       \
                      tab_w[2], tab_w[3],                   \
                      mod, mod2, p_hi, p_lo);               \
    DFT4_LAZY44(p4, p5, p6, p7,                             \
                tab_w[2], tab_w[3],                         \
                tab_w[4], tab_w[5],                         \
                tab_w[6], tab_w[7],                         \
                mod, mod2, p_hi, p_lo, tmp);                \
} while(0)
/** 8-point DFT
 * * in [0..4n) / out [0..4n) / max < 4n
 * * In-place transform p = [p0,p1,p2,p3,p4,p5,p6,p7], seen as a polynomial
 * p(x) = p0 + p1*x + ... + p7*x**7 into its evaluations
 *       p(w0), p(-w0), p(w1), p(-w1), p(w2), p(-w2), p(w3), p(-w3)
 * where w_k = F->tab_w[8*node + 2*k] for 0 <= k < 4
 * * By construction these 8 evaluation points are the 8 roots of the
 * polynomial x**8 - w**2, where w = F->tab_w[2*node] is the root used by the
 * first layer of butterflies
 */
#define DFT8_LAZY44(p0, p1, p2, p3, p4, p5, p6, p7,                 \
                    node, mod, mod2, tab_w)                         \
do {                                                                \
    ulong p_hi, p_lo, u, v;                                         \
                                                                    \
    const ulong w = tab_w[2*(node)];                                \
    const ulong w_pr = tab_w[2*(node)+1];                           \
    DFT2_LAZY44(p0, p4, mod, mod2, w, w_pr, p_hi, p_lo, u, v);      \
    DFT2_LAZY44(p1, p5, mod, mod2, w, w_pr, p_hi, p_lo, u, v);      \
    DFT2_LAZY44(p2, p6, mod, mod2, w, w_pr, p_hi, p_lo, u, v);      \
    DFT2_LAZY44(p3, p7, mod, mod2, w, w_pr, p_hi, p_lo, u, v);      \
                                                                    \
    DFT4_LAZY44(p0, p1, p2, p3,                                     \
                tab_w[4*(node)], tab_w[4*(node)+1],                 \
                tab_w[8*(node)], tab_w[8*(node)+1],                 \
                tab_w[8*(node)+2], tab_w[8*(node)+3],               \
                mod, mod2, p_hi, p_lo, u);                          \
                                                                    \
    DFT4_LAZY44(p4, p5, p6, p7,                                     \
                tab_w[4*(node)+2], tab_w[4*(node)+3],               \
                tab_w[8*(node)+4], tab_w[8*(node)+5],               \
                tab_w[8*(node)+6], tab_w[8*(node)+7],               \
                mod, mod2, p_hi, p_lo, u);                          \
} while(0)

/*--------------*/
/* 16-point DFT */
/*--------------*/

/** 16-point DFT, node 0
 * * in [0..n) / out [0..4n) / max < 4n
 * * Apart from this range, same specification as dft_node0_lazy24, for depth==4
 */
#define DFT16_NODE0_LAZY14(p0, p1, p2, p3, p4, p5, p6, p7,          \
                           p8, p9, p10, p11, p12, p13, p14, p15,    \
                           mod, mod2, tab_w)                        \
do {                                                                \
    ulong p_hi, p_lo, tmp;                                          \
                                                                    \
    DFT4_NODE0_LAZY14(p0, p4, p8, p12,                              \
                      tab_w[2], tab_w[3],                           \
                      mod, mod2, p_hi, p_lo);                       \
    if (p0 >= mod2)                                                 \
        p0 -= mod2;                                                 \
    DFT4_NODE0_LAZY14(p1, p5, p9, p13,                              \
                      tab_w[2], tab_w[3],                           \
                      mod, mod2, p_hi, p_lo);                       \
    if (p1 >= mod2)                                                 \
        p1 -= mod2;                                                 \
    DFT4_NODE0_LAZY14(p2, p6, p10, p14,                             \
                      tab_w[2], tab_w[3],                           \
                      mod, mod2, p_hi, p_lo);                       \
    if (p2 >= mod2)                                                 \
        p2 -= mod2;                                                 \
    DFT4_NODE0_LAZY14(p3, p7, p11, p15,                             \
                      tab_w[2], tab_w[3],                           \
                      mod, mod2, p_hi, p_lo);                       \
    if (p3 >= mod2)                                                 \
        p3 -= mod2;                                                 \
                                                                    \
    /* next line requires < 2n, */                                  \
    /* hence the four reductions above */                           \
    DFT4_NODE0_LAZY24(p0, p1, p2, p3,                               \
                      tab_w[2], tab_w[3],                           \
                      mod, mod2, p_hi, p_lo);                       \
    DFT4_LAZY44(p4, p5, p6, p7,                                     \
                tab_w[2], tab_w[3],                                 \
                tab_w[4], tab_w[5],                                 \
                tab_w[6], tab_w[7],                                 \
                mod, mod2, p_hi, p_lo, tmp);                        \
    DFT4_LAZY44(p8, p9, p10, p11,                                   \
                tab_w[4], tab_w[5],                                 \
                tab_w[8], tab_w[9],                                 \
                tab_w[10], tab_w[11],                               \
                mod, mod2, p_hi, p_lo, tmp);                        \
    DFT4_LAZY44(p12, p13, p14, p15,                                 \
                tab_w[6], tab_w[7],                                 \
                tab_w[12], tab_w[13],                               \
                tab_w[14], tab_w[15],                               \
                mod, mod2, p_hi, p_lo, tmp);                        \
} while(0)

/** 16-point DFT, node 0
 * * in [0..2n) / out [0..4n) / max < 4n
 * * Same specification as dft_node0_lazy24, for depth==4
 */
#define DFT16_NODE0_LAZY24(p0, p1, p2, p3, p4, p5, p6, p7,          \
                           p8, p9, p10, p11, p12, p13, p14, p15,    \
                           mod, mod2, tab_w)                        \
do {                                                                \
    ulong p_hi, p_lo, tmp;                                          \
                                                                    \
    DFT4_NODE0_LAZY24(p0, p4, p8, p12,                              \
                      tab_w[2], tab_w[3],                           \
                      mod, mod2, p_hi, p_lo);                       \
    if (p0 >= mod2)                                                 \
        p0 -= mod2;                                                 \
    DFT4_NODE0_LAZY24(p1, p5, p9, p13,                              \
                      tab_w[2], tab_w[3],                           \
                      mod, mod2, p_hi, p_lo);                       \
    if (p1 >= mod2)                                                 \
        p1 -= mod2;                                                 \
    DFT4_NODE0_LAZY24(p2, p6, p10, p14,                             \
                      tab_w[2], tab_w[3],                           \
                      mod, mod2, p_hi, p_lo);                       \
    if (p2 >= mod2)                                                 \
        p2 -= mod2;                                                 \
    DFT4_NODE0_LAZY24(p3, p7, p11, p15,                             \
                      tab_w[2], tab_w[3],                           \
                      mod, mod2, p_hi, p_lo);                       \
    if (p3 >= mod2)                                                 \
        p3 -= mod2;                                                 \
                                                                    \
    /* next line requires < 2n, */                                  \
    /* hence the four reductions above */                           \
    DFT4_NODE0_LAZY24(p0, p1, p2, p3,                               \
                      tab_w[2], tab_w[3],                           \
                      mod, mod2, p_hi, p_lo);                       \
    DFT4_LAZY44(p4, p5, p6, p7,                                     \
                tab_w[2], tab_w[3],                                 \
                tab_w[4], tab_w[5],                                 \
                tab_w[6], tab_w[7],                                 \
                mod, mod2, p_hi, p_lo, tmp);                        \
    DFT4_LAZY44(p8, p9, p10, p11,                                   \
                tab_w[4], tab_w[5],                                 \
                tab_w[8], tab_w[9],                                 \
                tab_w[10], tab_w[11],                               \
                mod, mod2, p_hi, p_lo, tmp);                        \
    DFT4_LAZY44(p12, p13, p14, p15,                                 \
                tab_w[6], tab_w[7],                                 \
                tab_w[12], tab_w[13],                               \
                tab_w[14], tab_w[15],                               \
                mod, mod2, p_hi, p_lo, tmp);                        \
} while(0)

/** 16-point DFT
 * * in [0..4n) / out [0..4n) / max < 4n
 * * Same specification as dft_lazy44, for depth==4
 * * node is parenthesized in every index computation, so that an expression
 *   such as 4*k+1 can be passed safely (as DFT8_LAZY44 already allows)
 */
#define DFT16_LAZY44(p0, p1, p2, p3, p4, p5, p6, p7,                \
                     p8, p9, p10, p11, p12, p13, p14, p15,          \
                     node, mod, mod2, tab_w)                        \
do {                                                                \
    ulong p_hi, p_lo, tmp;                                          \
    ulong w2, w2pre, w, wpre, Iw, Iwpre;                            \
                                                                    \
    w2 = tab_w[2*(node)];                                           \
    w2pre = tab_w[2*(node)+1];                                      \
    w = tab_w[4*(node)];                                            \
    wpre = tab_w[4*(node)+1];                                       \
    Iw = tab_w[4*(node)+2];                                         \
    Iwpre = tab_w[4*(node)+3];                                      \
                                                                    \
    DFT4_LAZY44(p0, p4, p8, p12,                                    \
                w2, w2pre, w, wpre, Iw, Iwpre,                      \
                mod, mod2, p_hi, p_lo, tmp);                        \
    DFT4_LAZY44(p1, p5, p9, p13,                                    \
                w2, w2pre, w, wpre, Iw, Iwpre,                      \
                mod, mod2, p_hi, p_lo, tmp);                        \
    DFT4_LAZY44(p2, p6, p10, p14,                                   \
                w2, w2pre, w, wpre, Iw, Iwpre,                      \
                mod, mod2, p_hi, p_lo, tmp);                        \
    DFT4_LAZY44(p3, p7, p11, p15,                                   \
                w2, w2pre, w, wpre, Iw, Iwpre,                      \
                mod, mod2, p_hi, p_lo, tmp);                        \
                                                                    \
    w2 = tab_w[8*(node)];                                           \
    w2pre = tab_w[8*(node)+1];                                      \
    w = tab_w[16*(node)];                                           \
    wpre = tab_w[16*(node)+1];                                      \
    Iw = tab_w[16*(node)+2];                                        \
    Iwpre = tab_w[16*(node)+3];                                     \
    DFT4_LAZY44(p0, p1, p2, p3,                                     \
                w2, w2pre, w, wpre, Iw, Iwpre,                      \
                mod, mod2, p_hi, p_lo, tmp);                        \
                                                                    \
    w2 = tab_w[8*(node)+2];                                         \
    w2pre = tab_w[8*(node)+3];                                      \
    w = tab_w[16*(node)+4];                                         \
    wpre = tab_w[16*(node)+5];                                      \
    Iw = tab_w[16*(node)+6];                                        \
    Iwpre = tab_w[16*(node)+7];                                     \
    DFT4_LAZY44(p4, p5, p6, p7,                                     \
                w2, w2pre, w, wpre, Iw, Iwpre,                      \
                mod, mod2, p_hi, p_lo, tmp);                        \
                                                                    \
    w2 = tab_w[8*(node)+4];                                         \
    w2pre = tab_w[8*(node)+5];                                      \
    w = tab_w[16*(node)+8];                                         \
    wpre = tab_w[16*(node)+9];                                      \
    Iw = tab_w[16*(node)+10];                                       \
    Iwpre = tab_w[16*(node)+11];                                    \
    DFT4_LAZY44(p8, p9, p10, p11,                                   \
                w2, w2pre, w, wpre, Iw, Iwpre,                      \
                mod, mod2, p_hi, p_lo, tmp);                        \
                                                                    \
    w2 = tab_w[8*(node)+6];                                         \
    w2pre = tab_w[8*(node)+7];                                      \
    w = tab_w[16*(node)+12];                                        \
    wpre = tab_w[16*(node)+13];                                     \
    Iw = tab_w[16*(node)+14];                                       \
    Iwpre = tab_w[16*(node)+15];                                    \
    DFT4_LAZY44(p12, p13, p14, p15,                                 \
                w2, w2pre, w, wpre, Iw, Iwpre,                      \
                mod, mod2, p_hi, p_lo, tmp);                        \
} while(0)

/*--------------*/
/* 32-point DFT */
/*--------------*/

/** 32-point DFT, node 0
 * * in [0..n) / out [0..4n) / max < 4n
 * * Apart from this range, same specification as dft_node0_lazy24, for depth==5
 */
#define DFT32_NODE0_LAZY14(p0, p1, p2, p3, p4, p5, p6, p7,                     \
                           p8, p9, p10, p11, p12, p13, p14, p15,               \
                           p16, p17, p18, p19, p20, p21, p22, p23,             \
                           p24, p25, p26, p27, p28, p29, p30, p31,             \
                           mod, mod2, tab_w)                                   \
do {                                                                           \
    ulong p_hi, p_lo;                                                          \
                                                                               \
    DFT4_NODE0_LAZY14(p0, p8, p16, p24,                                        \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p0 >= mod2)                                                            \
        p0 -= mod2;                                                            \
    DFT4_NODE0_LAZY14(p1, p9, p17, p25,                                        \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p1 >= mod2)                                                            \
        p1 -= mod2;                                                            \
    DFT4_NODE0_LAZY14(p2, p10, p18, p26,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p2 >= mod2)                                                            \
        p2 -= mod2;                                                            \
    DFT4_NODE0_LAZY14(p3, p11, p19, p27,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p3 >= mod2)                                                            \
        p3 -= mod2;                                                            \
    DFT4_NODE0_LAZY14(p4, p12, p20, p28,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p4 >= mod2)                                                            \
        p4 -= mod2;                                                            \
    DFT4_NODE0_LAZY14(p5, p13, p21, p29,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p5 >= mod2)                                                            \
        p5 -= mod2;                                                            \
    DFT4_NODE0_LAZY14(p6, p14, p22, p30,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p6 >= mod2)                                                            \
        p6 -= mod2;                                                            \
    DFT4_NODE0_LAZY14(p7, p15, p23, p31,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p7 >= mod2)                                                            \
        p7 -= mod2;                                                            \
                                                                               \
    /* next line requires < 2n, hence the 8 reductions above */                \
    DFT8_NODE0_LAZY24(p0, p1, p2, p3, p4, p5, p6, p7, mod, mod2, tab_w);       \
    DFT8_LAZY44(p8, p9, p10, p11, p12, p13, p14, p15, 1, mod, mod2, tab_w);    \
    DFT8_LAZY44(p16, p17, p18, p19, p20, p21, p22, p23, 2, mod, mod2, tab_w);  \
    DFT8_LAZY44(p24, p25, p26, p27, p28, p29, p30, p31, 3, mod, mod2, tab_w);  \
} while(0)
/** 32-point DFT, node 0
 * * in [0..2n) / out [0..4n) / max < 4n
 * * Same specification as dft_node0_lazy24, for depth==5
 */
#define DFT32_NODE0_LAZY24(p0, p1, p2, p3, p4, p5, p6, p7,                     \
                           p8, p9, p10, p11, p12, p13, p14, p15,               \
                           p16, p17, p18, p19, p20, p21, p22, p23,             \
                           p24, p25, p26, p27, p28, p29, p30, p31,             \
                           mod, mod2, tab_w)                                   \
do {                                                                           \
    ulong p_hi, p_lo;                                                          \
                                                                               \
    DFT4_NODE0_LAZY24(p0, p8, p16, p24,                                        \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p0 >= mod2)                                                            \
        p0 -= mod2;                                                            \
    DFT4_NODE0_LAZY24(p1, p9, p17, p25,                                        \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p1 >= mod2)                                                            \
        p1 -= mod2;                                                            \
    DFT4_NODE0_LAZY24(p2, p10, p18, p26,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p2 >= mod2)                                                            \
        p2 -= mod2;                                                            \
    DFT4_NODE0_LAZY24(p3, p11, p19, p27,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p3 >= mod2)                                                            \
        p3 -= mod2;                                                            \
    DFT4_NODE0_LAZY24(p4, p12, p20, p28,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p4 >= mod2)                                                            \
        p4 -= mod2;                                                            \
    DFT4_NODE0_LAZY24(p5, p13, p21, p29,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p5 >= mod2)                                                            \
        p5 -= mod2;                                                            \
    DFT4_NODE0_LAZY24(p6, p14, p22, p30,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p6 >= mod2)                                                            \
        p6 -= mod2;                                                            \
    DFT4_NODE0_LAZY24(p7, p15, p23, p31,                                       \
                      tab_w[2], tab_w[3],                                      \
                      mod, mod2, p_hi, p_lo);                                  \
    if (p7 >= mod2)                                                            \
        p7 -= mod2;                                                            \
                                                                               \
    /* next line requires < 2n, hence the 8 reductions above */                \
    DFT8_NODE0_LAZY24(p0, p1, p2, p3, p4, p5, p6, p7, mod, mod2, tab_w);       \
    DFT8_LAZY44(p8, p9, p10, p11, p12, p13, p14, p15, 1, mod, mod2, tab_w);    \
    DFT8_LAZY44(p16, p17, p18, p19, p20, p21, p22, p23, 2, mod, mod2, tab_w);  \
    DFT8_LAZY44(p24, p25, p26, p27, p28, p29, p30, p31, 3, mod, mod2, tab_w);  \
} while(0)

/** 32-point DFT
 * * in [0..4n) / out [0..4n) / max < 4n
 * * Same specification as dft_lazy44, for depth==5
 * * node is parenthesized in every index computation, so that an expression
 *   can be passed safely (as DFT8_LAZY44 already allows)
 */
#define DFT32_LAZY44(p0, p1, p2, p3, p4, p5, p6, p7,                                            \
                     p8, p9, p10, p11, p12, p13, p14, p15,                                      \
                     p16, p17, p18, p19, p20, p21, p22, p23,                                    \
                     p24, p25, p26, p27, p28, p29, p30, p31,                                    \
                     node, mod, mod2, tab_w)                                                    \
do {                                                                                            \
    ulong p_hi, p_lo, tmp;                                                                      \
                                                                                                \
    ulong w2 = tab_w[2*(node)];                                                                 \
    ulong w2pre = tab_w[2*(node)+1];                                                            \
    ulong w = tab_w[4*(node)];                                                                  \
    ulong wpre = tab_w[4*(node)+1];                                                             \
    ulong Iw = tab_w[4*(node)+2];                                                               \
    ulong Iwpre = tab_w[4*(node)+3];                                                            \
    DFT4_LAZY44(p0, p8, p16, p24, w2, w2pre, w, wpre, Iw, Iwpre, mod, mod2, p_hi, p_lo, tmp);   \
    DFT4_LAZY44(p1, p9, p17, p25, w2, w2pre, w, wpre, Iw, Iwpre, mod, mod2, p_hi, p_lo, tmp);   \
    DFT4_LAZY44(p2, p10, p18, p26, w2, w2pre, w, wpre, Iw, Iwpre, mod, mod2, p_hi, p_lo, tmp);  \
    DFT4_LAZY44(p3, p11, p19, p27, w2, w2pre, w, wpre, Iw, Iwpre, mod, mod2, p_hi, p_lo, tmp);  \
    DFT4_LAZY44(p4, p12, p20, p28, w2, w2pre, w, wpre, Iw, Iwpre, mod, mod2, p_hi, p_lo, tmp);  \
    DFT4_LAZY44(p5, p13, p21, p29, w2, w2pre, w, wpre, Iw, Iwpre, mod, mod2, p_hi, p_lo, tmp);  \
    DFT4_LAZY44(p6, p14, p22, p30, w2, w2pre, w, wpre, Iw, Iwpre, mod, mod2, p_hi, p_lo, tmp);  \
    DFT4_LAZY44(p7, p15, p23, p31, w2, w2pre, w, wpre, Iw, Iwpre, mod, mod2, p_hi, p_lo, tmp);  \
                                                                                                \
    /* DFT8_LAZY44 accepts inputs in [0..4n): no reduction needed here */                       \
    DFT8_LAZY44(p0, p1, p2, p3, p4, p5, p6, p7, 4*(node), mod, mod2, tab_w);                    \
    DFT8_LAZY44(p8, p9, p10, p11, p12, p13, p14, p15, 4*(node)+1, mod, mod2, tab_w);            \
    DFT8_LAZY44(p16, p17, p18, p19, p20, p21, p22, p23, 4*(node)+2, mod, mod2, tab_w);          \
    DFT8_LAZY44(p24, p25, p26, p27, p28, p29, p30, p31, 4*(node)+3, mod, mod2, tab_w);          \
} while(0)
depth + * (node+1) * 2**depth <= 2**F.depth (length of F->tab_w) + */ +void dft_lazy44(nn_ptr p, ulong depth, ulong node, n_fft_args_t F) +{ + if (depth == 3) + { + DFT8_LAZY44(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], node, F->mod, F->mod2, F->tab_w); + } + else if (depth == 4) + { + DFT16_LAZY44(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], + p[8], p[9], p[10], p[11], p[12], p[13], p[14], p[15], + node, F->mod, F->mod2, F->tab_w); + } + else if (depth == 5) + { + DFT32_LAZY44(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], + p[8], p[9], p[10], p[11], p[12], p[13], p[14], p[15], + p[16], p[17], p[18], p[19], p[20], p[21], p[22], p[23], + p[24], p[25], p[26], p[27], p[28], p[29], p[30], p[31], + node, F->mod, F->mod2, F->tab_w); + } + else + { + const ulong len = UWORD(1) << depth; + + // 4-point butterflies, unrolled by 4 (given depth >= 3, this branch has depth >= 6, so 4 divides len/4) + // in: [0..4n), out: [0..4n) + const nn_ptr p0 = p; + const nn_ptr p1 = p+len/4; + const nn_ptr p2 = p+2*len/4; + const nn_ptr p3 = p+3*len/4; + const ulong w2 = F->tab_w[2*node]; + const ulong w2pre = F->tab_w[2*node+1]; + const ulong w = F->tab_w[4*node]; + const ulong wpre = F->tab_w[4*node+1]; + const ulong Iw = F->tab_w[4*node+2]; + const ulong Iwpre = F->tab_w[4*node+3]; + ulong p_hi, p_lo, tmp; + + for (ulong k = 0; k < len/4; k+=4) + { + DFT4_LAZY44(p0[k+0], p1[k+0], p2[k+0], p3[k+0], w2, w2pre, w, wpre, Iw, Iwpre, F->mod, F->mod2, p_hi, p_lo, tmp); + DFT4_LAZY44(p0[k+1], p1[k+1], p2[k+1], p3[k+1], w2, w2pre, w, wpre, Iw, Iwpre, F->mod, F->mod2, p_hi, p_lo, tmp); + DFT4_LAZY44(p0[k+2], p1[k+2], p2[k+2], p3[k+2], w2, w2pre, w, wpre, Iw, Iwpre, F->mod, F->mod2, p_hi, p_lo, tmp); + DFT4_LAZY44(p0[k+3], p1[k+3], p2[k+3], p3[k+3], w2, w2pre, w, wpre, Iw, Iwpre, F->mod, F->mod2, p_hi, p_lo, tmp); + } + + // 4 recursive calls with depth-2, one per quarter of p, at the child nodes 4*node .. 4*node+3 + dft_lazy44(p0, depth-2, 4*node, F); + dft_lazy44(p1, depth-2, 4*node+1, F); + dft_lazy44(p2, depth-2, 4*node+2, F); + dft_lazy44(p3, depth-2, 4*node+3, F); + } +} + +/** 2**depth-point DFT + * * in [0..2n) / out
[0..4n) / max < 4n + * * In-place transform p of length len == 2**depth into + * the concatenation of + * [sum(p[i] * w_k**i for i in range(len)), sum(p[i] * (-w_k)**i for i in range(len))] + * for k in range(len/2), + * where w_k = F->tab_w[2*k] for 0 <= k < 2**(depth-1) + * * By construction these evaluation points are the roots of the polynomial + * x**len - 1, precisely they are all powers of the chosen len-th primitive + * root of unity with exponents listed in bit reversed order + * * Requirements (not checked): 3 <= depth <= F.depth + */ +void dft_node0_lazy24(nn_ptr p, ulong depth, n_fft_args_t F) +{ + if (depth == 3) + { + DFT8_NODE0_LAZY24(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], F->mod, F->mod2, F->tab_w); + } + else if (depth == 4) + { + DFT16_NODE0_LAZY24(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], + p[8], p[9], p[10], p[11], p[12], p[13], p[14], p[15], + F->mod, F->mod2, F->tab_w); + } + else if (depth == 5) + { + DFT32_NODE0_LAZY24(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], + p[8], p[9], p[10], p[11], p[12], p[13], p[14], p[15], + p[16], p[17], p[18], p[19], p[20], p[21], p[22], p[23], + p[24], p[25], p[26], p[27], p[28], p[29], p[30], p[31], + F->mod, F->mod2, F->tab_w); + } + else + { + const ulong len = UWORD(1) << depth; + + // 4-point butterflies + // input p0,p1,p2,p3 in [0..2n) x [0..2n) x [0..2n) x [0..2n) + // output p0,p1,p2,p3 in [0..2n) x [0..4n) x [0..4n) x [0..4n) (p0: after the conditional subtraction in the loop) + const nn_ptr p0 = p; + const nn_ptr p1 = p + len/4; + const nn_ptr p2 = p + 2*len/4; + const nn_ptr p3 = p + 3*len/4; + ulong p_hi, p_lo; + for (ulong k = 0; k < len/4; k++) + { + DFT4_NODE0_LAZY24(p0[k], p1[k], p2[k], p3[k], F->tab_w[2], F->tab_w[3], F->mod, F->mod2, p_hi, p_lo); + if (p0[k] >= F->mod2) + p0[k] -= F->mod2; + } + + // 4 recursive calls with depth-2: node 0 on p0 (which needs input in [0..2n)), nodes 1,2,3 on p1,p2,p3 + dft_node0_lazy24(p0, depth-2, F); + dft_lazy44(p1, depth-2, 1, F); + dft_lazy44(p2, depth-2, 2, F); + dft_lazy44(p3, depth-2, 3, F); + } +} + +/** 2**depth-point DFT + * * in [0..n) / out [0..4n) /
max < 4n + * * In-place transform p of length len == 2**depth into + * the concatenation of + * [sum(p[i] * w_k**i for i in range(len)), sum(p[i] * (-w_k)**i for i in range(len))] + * for k in range(len/2), + * where w_k = F->tab_w[2*k] for 0 <= k < 2**(depth-1) + * * By construction these evaluation points are the roots of the polynomial + * x**len - 1, precisely they are all powers of the chosen len-th primitive + * root of unity with exponents listed in bit reversed order + * * Requirements (not checked): depth <= F.depth (depth == 0 leaves p unchanged) + */ +void dft_node0_lazy14(nn_ptr p, ulong depth, n_fft_args_t F) +{ + if (depth == 4) + { + DFT16_NODE0_LAZY14(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], + p[8], p[9], p[10], p[11], p[12], p[13], p[14], p[15], + F->mod, F->mod2, F->tab_w); + } + else if (depth == 5) + { + DFT32_NODE0_LAZY14(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], + p[8], p[9], p[10], p[11], p[12], p[13], p[14], p[15], + p[16], p[17], p[18], p[19], p[20], p[21], p[22], p[23], + p[24], p[25], p[26], p[27], p[28], p[29], p[30], p[31], + F->mod, F->mod2, F->tab_w); + } + else if (depth > 5) + { + const ulong len = UWORD(1) << depth; + + // 4-point butterflies + // input p0,p1,p2,p3 in [0..n) x [0..n) x [0..n) x [0..n) + // output p0,p1,p2,p3 in [0..2n) x [0..4n) x [0..4n) x [0..4n) (p0: after the conditional subtraction in the loop) + const nn_ptr p0 = p; + const nn_ptr p1 = p + len/4; + const nn_ptr p2 = p + 2*len/4; + const nn_ptr p3 = p + 3*len/4; + ulong p_hi, p_lo; + for (ulong k = 0; k < len/4; k++) + { + DFT4_NODE0_LAZY14(p0[k], p1[k], p2[k], p3[k], F->tab_w[2], F->tab_w[3], F->mod, F->mod2, p_hi, p_lo); + if (p0[k] >= F->mod2) + p0[k] -= F->mod2; + } + + // 4 recursive calls with depth-2: node 0 on p0 (now in [0..2n), hence the lazy24 variant), nodes 1,2,3 on p1,p2,p3 + dft_node0_lazy24(p0, depth-2, F); + dft_lazy44(p1, depth-2, 1, F); + dft_lazy44(p2, depth-2, 2, F); + dft_lazy44(p3, depth-2, 3, F); + } + else if (depth == 3) + { + DFT8_NODE0_LAZY14(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], F->mod, F->mod2, F->tab_w); + } + else if (depth == 2) + { + ulong p_hi, p_lo; +
DFT4_NODE0_LAZY14(p[0], p[1], p[2], p[3], F->tab_w[2], F->tab_w[3], F->mod, F->mod2, p_hi, p_lo); + } + else if (depth == 1) + { + ulong tmp; + BUTTERFLY_LAZY12(p[0], p[1], F->mod, tmp); + } +} + diff --git a/src/n_fft/idft.c b/src/n_fft/idft.c new file mode 100644 index 0000000000..50036b3c3d --- /dev/null +++ b/src/n_fft/idft.c @@ -0,0 +1,321 @@ +/* + Copyright (C) 2024 Vincent Neiger + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See <https://www.gnu.org/licenses/>. +*/ + +#include "n_fft.h" +#include "n_fft_macros.h" + +/*--------------*/ +/* 2-point IDFT */ +/*--------------*/ + +/** Gentleman-Sande butterfly, general + * * in [0..2n) / out [0..2n) / max < 4n + * * In-place transform + * [1 iw] + * [a b] <- [a b] [1 -iw] + * * n2 is 2*n; the macro arguments w, w_pr are iw and the precomputed data for multiplication by iw mod n; + * p_hi, p_lo, tmp are temporaries + * * can be seen as interpolation at points w = 1 / iw and -w, up to a scaling + * by 1/2, since the inverse of [1 iw] is 1/2 * [1 1] + * [1 -iw] [w -w] + */ +// TODO make order of arguments consistent +#define IDFT2_LAZY22(a, b, n, n2, w, w_pr, p_hi, p_lo, tmp) \ +do { \ + tmp = (a) + (n2) - (b); /* [0..4n) */ \ + (a) = (a) + (b); /* [0..4n) */ \ + if ((a) >= (n2)) \ + (a) -= (n2); /* [0..2n) */ \ + N_MULMOD_PRECOMP_LAZY((b), w, tmp, w_pr, n, p_hi, p_lo); \ + /* --> (b) in [0..2n) */ \ +} while(0) + +// move in macros?
+// in [0..4n) x [0..2n) -> out [0..4n) x [0..4n) +// TODO rename +#define BUTTERFLY_LAZY4244(a, b, n2, tmp) \ +do { \ + tmp = (a); \ + if (tmp >= (n2)) \ + tmp -= (n2); /* [0..2n) */ \ + (a) = tmp + (b); /* [0..4n) */ \ + (b) = tmp + (n2) - (b); /* [0..4n) */ \ +} while(0) + +/*--------------*/ +/* 4-point IDFT */ +/*--------------*/ + +/** 4-point IDFT, general + * * in [0..2n) / out [0..2n) / max < 4n (IDFT4_LAZY12 below: same transform, in [0..n) / out [0..2n)) + * * In-place transform + * [ 1 iw2 iw1 iw1*iw2] + * [ 1 -iw2 iw1 -iw1*iw2] + * [a b c d] <- [a b c d] [ 1 iw3 -iw1 -iw1*iw3] + * [ 1 -iw3 -iw1 iw1*iw3] + * [1 iw2 0 0] [1 0 w1 0] + * == [a b c d] [1 -iw2 0 0] [0 1 0 w1] + * [0 0 1 iw3] [1 0 -w1 0] + * [0 0 1 -iw3] [0 1 0 -w1] + * * Corresponds, up to scaling by 1/4, to going up the tree with nodes + * x^4 - w1**2 + * / \ + * x^2 - w1 x^2 + w1 + * / \ / \ + * x - w2 x + w2 x - w3 x + w3 + * typically w2**2 == w1 and w3 == I*w2 (hence w3**2 == -w1) so that the above + * is the inverse of a Vandermonde matrix and this tree really is the + * subproduct tree built from the four roots w2, -w2, I*w2, -I*w2 of x**4 - w1**2 + */ +#define IDFT4_LAZY22(a, b, c, d, \ + w1, w1_pr, w2, w2_pr, w3, w3_pr, \ + n, n2, p_hi, p_lo) \ +do { \ + const ulong v0 = (a); \ + const ulong v1 = (b); \ + const ulong v2 = (c); \ + const ulong v3 = (d); \ + ulong v4 = v0 + v1; /* < 4*n */ \ + if (v4 >= (n2)) \ + v4 -= (n2); /* < 2*n */ \ + ulong v5; \ + N_MULMOD_PRECOMP_LAZY(v5, (w2), v0 + (n2) - v1, (w2_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ + ulong v6 = v2 + v3; /* < 4*n */ \ + if (v6 >= (n2)) \ + v6 -= (n2); /* < 2*n */ \ + ulong v7; \ + N_MULMOD_PRECOMP_LAZY(v7, (w3), v2 + (n2) - v3, (w3_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ + \ + (a) = v4 + v6; \ + if ((a) >= (n2)) \ + (a) -= (n2); /* < 2*n */ \ + (b) = v5 + v7; \ + if ((b) >= (n2)) \ + (b) -= (n2); /* < 2*n */ \ + N_MULMOD_PRECOMP_LAZY((c), (w1), v4 + (n2) - v6, (w1_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ + N_MULMOD_PRECOMP_LAZY((d), (w1), v5 + (n2) - v7, (w1_pr), (n), \ + p_hi, p_lo);
/* < 2*n */ \ +} while(0) + +#define IDFT4_LAZY12(a, b, c, d, \ + w1, w1_pr, w2, w2_pr, w3, w3_pr, \ + n, n2, p_hi, p_lo) \ +do { \ + const ulong v0 = (a); \ + const ulong v1 = (b); \ + const ulong v2 = (c); \ + const ulong v3 = (d); \ + ulong v4 = v0 + v1; /* < 2*n */ \ + ulong v5; \ + N_MULMOD_PRECOMP_LAZY(v5, (w2), v0 + (n) - v1, (w2_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ + ulong v6 = v2 + v3; /* < 2*n */ \ + ulong v7; \ + N_MULMOD_PRECOMP_LAZY(v7, (w3), v2 + (n) - v3, (w3_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ + \ + (a) = v4 + v6; /* < 4*n */ \ + if ((a) >= (n2)) \ + (a) -= (n2); /* < 2*n */ \ + (b) = v5 + v7; /* < 4*n */ \ + if ((b) >= (n2)) \ + (b) -= (n2); /* < 2*n */ \ + N_MULMOD_PRECOMP_LAZY((c), (w1), v4 + (n2) - v6, (w1_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ + N_MULMOD_PRECOMP_LAZY((d), (w1), v5 + (n2) - v7, (w1_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ +} while(0) + +/*--------------*/ +/* 8-point IDFT */ +/*--------------*/ + +// 8-point IDFT at node 0: in [0..n) / out [0..4n) / max < 4n (IDFT4_NODE0_LAZY14 and IDFT4_LAZY12 on the halves, then BUTTERFLY_LAZY4244); TODO full doc +#define IDFT8_NODE0_LAZY14(p0, p1, p2, p3, p4, p5, p6, p7, \ + mod, mod2, tab_w) \ +do { \ + ulong p_hi, p_lo, tmp; \ + \ + IDFT4_NODE0_LAZY14(p0, p1, p2, p3, \ + tab_w[2], tab_w[3], \ + mod, mod2, p_hi, p_lo); \ + IDFT4_LAZY12(p4, p5, p6, p7, \ + tab_w[2], tab_w[3], \ + tab_w[4], tab_w[5], \ + tab_w[6], tab_w[7], \ + mod, mod2, p_hi, p_lo); \ + \ + BUTTERFLY_LAZY4244(p0, p4, mod2, tmp); \ + BUTTERFLY_LAZY4244(p1, p5, mod2, tmp); \ + BUTTERFLY_LAZY4244(p2, p6, mod2, tmp); \ + BUTTERFLY_LAZY4244(p3, p7, mod2, tmp); \ +} while(0) + + +/** 8-point IDFT + * TODO clean, check laziness; NOTE: this is an inverse transform despite the DFT8_ prefix (consider renaming to IDFT8_LAZY12) + * * in [0..n) / out [0..2n) / max < 4n (two IDFT4_LAZY12 followed by four IDFT2_LAZY22) + */ +#define DFT8_LAZY12(p0, p1, p2, p3, p4, p5, p6, p7, \ + node, mod, mod2, tab_w) \ +do { \ + ulong p_hi, p_lo, tmp; \ + \ + const ulong w = tab_w[2*(node)]; \ + const ulong w_pr = tab_w[2*(node)+1]; \ + \ + IDFT4_LAZY12(p0, p1, p2, p3, \ + tab_w[4*(node)], tab_w[4*(node)+1], \ + tab_w[8*(node)], tab_w[8*(node)+1], \ + tab_w[8*(node)+2], tab_w[8*(node)+3], \ + mod, mod2, p_hi, p_lo); \ + \ +
IDFT4_LAZY12(p4, p5, p6, p7, \ + tab_w[4*(node)+2], tab_w[4*(node)+3], \ + tab_w[8*(node)+4], tab_w[8*(node)+5], \ + tab_w[8*(node)+6], tab_w[8*(node)+7], \ + mod, mod2, p_hi, p_lo); \ + \ + IDFT2_LAZY22(p0, p4, mod, mod2, w, w_pr, p_hi, p_lo, tmp); \ + IDFT2_LAZY22(p1, p5, mod, mod2, w, w_pr, p_hi, p_lo, tmp); \ + IDFT2_LAZY22(p2, p6, mod, mod2, w, w_pr, p_hi, p_lo, tmp); \ + IDFT2_LAZY22(p3, p7, mod, mod2, w, w_pr, p_hi, p_lo, tmp); \ +} while(0) + + + + + + + + + + + + + + +/*--------------*/ +/* general IDFT */ +/*--------------*/ + + +// 2**depth-point IDFT at node `node`, in place: in [0..n) / out [0..2n); requirement (not checked): depth >= 1 (depth == 0 would fall into the recursive branch with depth-2 wrapping around); TODO complete doc +// TODO make sure this is tested (code coverage: including for small depths) +void idft_lazy12(nn_ptr p, ulong depth, ulong node, n_fft_args_t F) +{ + if (depth == 1) + { + ulong p_hi, p_lo, tmp; + IDFT2_LAZY22(p[0], p[1], F->mod, F->mod2, F->tab_w[2*node], F->tab_w[2*node+1], p_hi, p_lo, tmp); + } + else if (depth == 2) + { + ulong p_hi, p_lo; + IDFT4_LAZY12(p[0], p[1], p[2], p[3], + F->tab_w[2*node], F->tab_w[2*node+1], + F->tab_w[4*node], F->tab_w[4*node+1], + F->tab_w[4*node+2], F->tab_w[4*node+3], + F->mod, F->mod2, p_hi, p_lo); + } + else if (depth == 3) + { + DFT8_LAZY12(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], + node, F->mod, F->mod2, F->tab_w); + } + else + { + const ulong len = UWORD(1) << depth; + + // 4 recursive calls with depth-2 (each leaves its quarter in [0..2n)), done before the butterflies since we go up the tree + const nn_ptr p0 = p; + const nn_ptr p1 = p + len/4; + const nn_ptr p2 = p + 2*len/4; + const nn_ptr p3 = p + 3*len/4; + idft_lazy12(p0, depth-2, 4*node, F); + idft_lazy12(p1, depth-2, 4*node+1, F); + idft_lazy12(p2, depth-2, 4*node+2, F); + idft_lazy12(p3, depth-2, 4*node+3, F); + + const ulong w2 = F->tab_w[2*node]; + const ulong w2_pr = F->tab_w[2*node+1]; + const ulong w = F->tab_w[4*node]; + const ulong w_pr = F->tab_w[4*node+1]; + const ulong Iw = F->tab_w[4*node+2]; + const ulong Iw_pr = F->tab_w[4*node+3]; + ulong p_hi, p_lo; + + for (ulong k = 0; k < len/4; k+=4) + { + IDFT4_LAZY22(p0[k+0], p1[k+0], p2[k+0], p3[k+0], w2, w2_pr, w, w_pr, Iw, Iw_pr, F->mod,
F->mod2, p_hi, p_lo); + IDFT4_LAZY22(p0[k+1], p1[k+1], p2[k+1], p3[k+1], w2, w2_pr, w, w_pr, Iw, Iw_pr, F->mod, F->mod2, p_hi, p_lo); + IDFT4_LAZY22(p0[k+2], p1[k+2], p2[k+2], p3[k+2], w2, w2_pr, w, w_pr, Iw, Iw_pr, F->mod, F->mod2, p_hi, p_lo); + IDFT4_LAZY22(p0[k+3], p1[k+3], p2[k+3], p3[k+3], w2, w2_pr, w, w_pr, Iw, Iw_pr, F->mod, F->mod2, p_hi, p_lo); + } + } +} + +void idft_node0_lazy12(nn_ptr p, ulong depth, n_fft_args_t F) +{ + if (depth == 0) + return; + + if (depth == 1) + { + ulong tmp; + BUTTERFLY_LAZY12(p[0], p[1], F->mod, tmp); + } + else if (depth == 2) + { + ulong p_hi, p_lo; + IDFT4_NODE0_LAZY14(p[0], p[1], p[2], p[3], F->tab_w[2], F->tab_w[3], + F->mod, F->mod2, p_hi, p_lo); + } + else if (depth == 3) + { + // TODO to be improved; NOTE: as for depth == 2 and depth >= 4, the output here is in [0..4n), not [0..2n) as the function name suggests + IDFT8_NODE0_LAZY14(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], + F->mod, F->mod2, F->tab_w); + } + else + { + const ulong len = UWORD(1) << depth; + + // 4 recursive calls with depth-2: node 0 on p0, general nodes 1,2,3 on p1,p2,p3 + const nn_ptr p0 = p; + const nn_ptr p1 = p + len/4; + const nn_ptr p2 = p + 2*len/4; + const nn_ptr p3 = p + 3*len/4; + idft_node0_lazy12(p0, depth-2, F); + idft_lazy12(p1, depth-2, 1, F); + idft_lazy12(p2, depth-2, 2, F); + idft_lazy12(p3, depth-2, 3, F); + + // 4-point butterflies, unrolled by 4 (depth >= 4 here, so 4 divides len/4) + // input p0 in [0,4n), p1,p2,p3 in [0,2n) + // output p0,p1,p2,p3 in [0,4n) + ulong p_hi, p_lo; + for (ulong k = 0; k < len/4; k+=4) + { + IDFT4_NODE0_LAZY4222(p0[k+0], p1[k+0], p2[k+0], p3[k+0], + F->tab_w[2], F->tab_w[3], F->mod, F->mod2, p_hi, p_lo); + IDFT4_NODE0_LAZY4222(p0[k+1], p1[k+1], p2[k+1], p3[k+1], + F->tab_w[2], F->tab_w[3], F->mod, F->mod2, p_hi, p_lo); + IDFT4_NODE0_LAZY4222(p0[k+2], p1[k+2], p2[k+2], p3[k+2], + F->tab_w[2], F->tab_w[3], F->mod, F->mod2, p_hi, p_lo); + IDFT4_NODE0_LAZY4222(p0[k+3], p1[k+3], p2[k+3], p3[k+3], + F->tab_w[2], F->tab_w[3], F->mod, F->mod2, p_hi, p_lo); + } + } +} diff --git a/src/n_fft/n_fft_macros.h b/src/n_fft/n_fft_macros.h new file mode 100644 index 0000000000..4817f0ceb2 --- /dev/null +++
b/src/n_fft/n_fft_macros.h @@ -0,0 +1,256 @@ +/* + Copyright (C) 2024 Vincent Neiger + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See <https://www.gnu.org/licenses/>. +*/ + +#ifndef N_FFT_MACROS_H +#define N_FFT_MACROS_H + +/*---------*/ +/* helpers */ +/*---------*/ + +/** Shoup's modular multiplication with precomputation, lazy + * (does not perform the excess correction step) + * --> computes either r or r+n and stores it in res, where r = (a*b) % n + * --> a_pr is the precomputed data for multiplication by a modulo n, p_hi and p_lo are temporaries + */ +#define N_MULMOD_PRECOMP_LAZY(res, a, b, a_pr, n, p_hi, p_lo) \ + do { \ + umul_ppmm(p_hi, p_lo, (a_pr), (b)); \ + res = (a) * (b) - p_hi * (n); \ + } while(0) + +/*---------------------*/ +/* radix-2 butterflies */ +/*---------------------*/ + +/** Butterfly radix 2 + * * in [0..n) x [0..n) / out [0..2n) x [0..2n) / max < 2n + * * In-place transform + * [1 1] + * [a b] <- [a b] [1 -1] + * * n is the modulus, tmp is a temporary + */ +#define BUTTERFLY_LAZY12(a, b, n, tmp) \ + do { \ + tmp = (b); \ + (b) = (a) + (n) - tmp; \ + (a) = (a) + tmp; \ + } while(0) + +/** Butterfly radix 2 + * * in [0..2n) x [0..2n) / out [0..2n) x [0..4n) / max < 4n + * * In-place transform + * [1 1] + * [a b] <- [a b] [1 -1] + * * n2 is 2*n, tmp is a temporary + */ +#define BUTTERFLY_LAZY24(a, b, n2, tmp) \ + do { \ + tmp = (b); \ + (b) = (a) + (n2) - tmp; \ + (a) = (a) + tmp; \ + if ((a) >= (n2)) \ + (a) -= (n2); \ + } while(0) + +/*---------------------*/ +/* radix-4 butterflies */ +/*---------------------*/ + +/** 4-point butterfly, evaluation + * * in [0..n) / out [0..4n) / max < 4n + * * In-place transform + * [1 1 1 1] + * [1 -1 I -I] + * [a b c d] <- [a b c d] [1 1 -1 -1] + * [1 -1 -I I] + * [1 0 1 0] [1 1 0 0] + * == [a b c d] [0 1 0 I] [1 -1 0 0]
+ * [1 0 -1 0] [0 0 1 1] + * [0 1 0 -I] [0 0 1 -1] + * * Corresponds to reducing down the tree with nodes + * x^4 - 1 + * / \ + * x^2 - 1 x^2 + 1 + * / \ / \ + * x - 1 x + 1 x - I x + I + * where I is typically a square root of -1 + * (but this property is not exploited) + * * n is the modulus and n2 == 2*n, p_hi, p_lo are temporaries + */ +#define DFT4_NODE0_LAZY14(a, b, c, d, I, I_pr, n, n2, p_hi, p_lo) \ + do { \ + const ulong v0 = (a); \ + const ulong v1 = (b); \ + const ulong v2 = (c); \ + const ulong v3 = (d); \ + ulong v4 = v0 + v2; /* < 2*n */ \ + ulong v5 = v0 + (n) - v2; /* < 2*n */ \ + ulong v6 = v1 + v3; /* < 2*n */ \ + ulong v7; \ + N_MULMOD_PRECOMP_LAZY(v7, (I), v1 + (n) - v3, (I_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ + (a) = v4 + v6; /* < 4*n */ \ + (b) = v4 + (n2) - v6; /* < 4*n */ \ + (c) = v5 + v7; /* < 4*n */ \ + (d) = v5 + (n2) - v7; /* < 4*n */ \ + } while(0) + +/** 4-point butterfly, evaluation + * * in [0..2n) / out [0..4n) / max < 4n + * * other than this, same specification as DFT4_NODE0_LAZY14 + */ +#define DFT4_NODE0_LAZY24(a, b, c, d, I, I_pr, n, n2, p_hi, p_lo) \ + do { \ + const ulong v0 = (a); \ + const ulong v1 = (b); \ + const ulong v2 = (c); \ + const ulong v3 = (d); \ + ulong v4 = v0 + v2; /* < 4*n */ \ + if (v4 >= (n2)) \ + v4 -= (n2); /* < 2*n */ \ + ulong v5 = v0 + (n2) - v2; /* < 4*n */ \ + if (v5 >= (n2)) \ + v5 -= (n2); /* < 2*n */ \ + ulong v6 = v1 + v3; /* < 4*n */ \ + if (v6 >= (n2)) \ + v6 -= (n2); /* < 2*n */ \ + ulong v7; \ + N_MULMOD_PRECOMP_LAZY(v7, (I), v1 + (n2) - v3, (I_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ + (a) = v4 + v6; /* < 4*n */ \ + (b) = v4 + (n2) - v6; /* < 4*n */ \ + (c) = v5 + v7; /* < 4*n */ \ + (d) = v5 + (n2) - v7; /* < 4*n */ \ + } while(0) + + +/** 4-point butterfly, interpolation + * * in [0..n) / out [0..2n) for IDFT4_NODE0_LAZY12, out [0..4n) for IDFT4_NODE0_LAZY14 / max < 4n + * * In-place transform + * [1 1 1 1] + * [1 -1 1 -1] + * [a b c d] <- [a b c d] [1 -I -1 I] + * [1 I -1 -I] + * [1 1 0 0] [1 0 1 0] + * == [a b c d] [1 -1 0 0] [0 1 0 1] + * [0 0 1 I] [1 0 -1 0] +
* [0 0 1 -I] [0 1 0 -1] + * + * * If I**2 == -1, this matrix is the inverse of the one above; this + * corresponds to interpolation at 1, -1, I, -I, up to scaling by 1/4; or to + * going up the tree with nodes + * x^4 - 1 + * / \ + * x^2 - 1 x^2 + 1 + * / \ / \ + * x - 1 x + 1 x - I x + I + */ +#define IDFT4_NODE0_LAZY12(a, b, c, d, I, I_pr, n, n2, p_hi, p_lo) \ +do { \ + const ulong v0 = (a); \ + const ulong v1 = (b); \ + const ulong v2 = (c); \ + const ulong v3 = (d); \ + ulong v4 = v0 + v1; /* < 2*n */ \ + ulong v5 = v0 + (n) - v1; /* < 2*n */ \ + ulong v6 = v2 + v3; /* < 2*n */ \ + ulong v7; \ + N_MULMOD_PRECOMP_LAZY(v7, (I), v2 + (n) - v3, (I_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ + (a) = v4 + v6; /* < 4*n */ \ + if ((a) >= (n2)) \ + (a) -= (n2); /* < 2*n */ \ + (b) = v5 + v7; /* < 4*n */ \ + if ((b) >= (n2)) \ + (b) -= (n2); /* < 2*n */ \ + (c) = v4 + (n2) - v6; /* < 4*n */ \ + if ((c) >= (n2)) \ + (c) -= (n2); /* < 2*n */ \ + (d) = v5 + (n2) - v7; /* < 4*n */ \ + if ((d) >= (n2)) \ + (d) -= (n2); /* < 2*n */ \ +} while(0) + +#define IDFT4_NODE0_LAZY14(a, b, c, d, I, I_pr, n, n2, p_hi, p_lo) \ +do { \ + const ulong v0 = (a); \ + const ulong v1 = (b); \ + const ulong v2 = (c); \ + const ulong v3 = (d); \ + ulong v4 = v0 + v1; /* < 2*n */ \ + ulong v5 = v0 + (n) - v1; /* < 2*n */ \ + ulong v6 = v2 + v3; /* < 2*n */ \ + ulong v7; \ + N_MULMOD_PRECOMP_LAZY(v7, (I), v2 + (n) - v3, (I_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ + (a) = v4 + v6; /* < 4*n */ \ + (b) = v5 + v7; /* < 4*n */ \ + (c) = v4 + (n2) - v6; /* < 4*n */ \ + (d) = v5 + (n2) - v7; /* < 4*n */ \ +} while(0) + +/** 4-point butterfly, interpolation + * * in [0..2n) / out [0..4n) / max < 4n + * * other than this, same specification as IDFT4_NODE0_LAZY14 + */ +#define IDFT4_NODE0_LAZY24(a, b, c, d, I, I_pr, n, n2, p_hi, p_lo) \ +do { \ + const ulong v0 = (a); \ + const ulong v1 = (b); \ + const ulong v2 = (c); \ + const ulong v3 = (d); \ + ulong v4 = v0 + v1; /* < 4*n */ \ + if (v4 >= (n2)) \ + v4 -=
(n2); /* < 2*n */ \ + ulong v5 = v0 + (n2) - v1; /* < 4*n */ \ + if (v5 >= (n2)) \ + v5 -= (n2); /* < 2*n */ \ + ulong v6 = v2 + v3; /* < 4*n */ \ + if (v6 >= (n2)) \ + v6 -= (n2); /* < 2*n */ \ + ulong v7; \ + N_MULMOD_PRECOMP_LAZY(v7, (I), v2 + (n2) - v3, (I_pr), (n), p_hi, p_lo); /* < 2*n */ \ + (a) = v4 + v6; /* < 4*n */ \ + (b) = v5 + v7; /* < 4*n */ \ + (c) = v4 + (n2) - v6; /* < 4*n */ \ + (d) = v5 + (n2) - v7; /* < 4*n */ \ +} while(0) +/* same computation as IDFT4_NODE0_LAZY24, except that (a) may be in [0..4n): it is first reduced to [0..2n) */ +#define IDFT4_NODE0_LAZY4222(a, b, c, d, I, I_pr, n, n2, p_hi, p_lo) \ +do { \ + ulong v0 = (a); \ + const ulong v1 = (b); \ + const ulong v2 = (c); \ + const ulong v3 = (d); \ + if (v0 >= (n2)) \ + v0 -= (n2); /* < 2*n */ \ + ulong v4 = v0 + v1; /* < 4*n */ \ + if (v4 >= (n2)) \ + v4 -= (n2); /* < 2*n */ \ + ulong v5 = v0 + (n2) - v1; /* < 4*n */ \ + if (v5 >= (n2)) \ + v5 -= (n2); /* < 2*n */ \ + ulong v6 = v2 + v3; /* < 4*n */ \ + if (v6 >= (n2)) \ + v6 -= (n2); /* < 2*n */ \ + ulong v7; \ + N_MULMOD_PRECOMP_LAZY(v7, (I), v2 + (n2) - v3, (I_pr), (n), \ + p_hi, p_lo); /* < 2*n */ \ + (a) = v4 + v6; /* < 4*n */ \ + (b) = v5 + v7; /* < 4*n */ \ + (c) = v4 + (n2) - v6; /* < 4*n */ \ + (d) = v5 + (n2) - v7; /* < 4*n */ \ +} while(0) + + + +#endif /* N_FFT_MACROS_H */ diff --git a/src/n_fft/profile/p-dft.c b/src/n_fft/profile/p-dft.c new file mode 100644 index 0000000000..bfa4174faf --- /dev/null +++ b/src/n_fft/profile/p-dft.c @@ -0,0 +1,142 @@ +#include "profiler.h" +#include "nmod_vec.h" +#include "fft_small.h" +#include "n_fft.h" + +#define num_primes 7 + +typedef struct +{ + ulong prime; + ulong depth; + ulong stride; +} info_t; + +#define SAMPLE(fun, _variant) \ +void sample_##fun##_variant(void * arg, ulong count) \ +{ \ + info_t * info = (info_t *) arg; \ + const ulong p = info->prime; \ + const ulong depth = info->depth; \ + const ulong stride = info->stride; \ + \ + const ulong len = stride * (UWORD(1) << depth); \ + const ulong rep = FLINT_MAX(1, FLINT_MIN(1000, 1000000/len)); \ + \ + /* modulus, roots of unity */ \ +
n_fft_ctx_t F; \ + n_fft_ctx_init2(F, depth, p); \ + \ + FLINT_TEST_INIT(state); \ + \ + ulong * coeffs = _nmod_vec_init(len); \ + for (ulong k = 0; k < len; k++) \ + coeffs[k] = n_randint(state, p); \ + \ + for (ulong i = 0; i < count; i++) \ + { \ + prof_start(); \ + for (ulong j = 0; j < rep; j++) \ + n_fft_##fun##_variant(coeffs, depth, F); \ + prof_stop(); \ + } \ + _nmod_vec_clear(coeffs); /* release the input vector allocated above */ \ + n_fft_ctx_clear(F); \ + FLINT_TEST_CLEAR(state); \ +} \ + +SAMPLE(dft, ) +SAMPLE(idft, ) +SAMPLE(dft_t, ) +SAMPLE(idft_t, ) +//SAMPLE(n_fft_dft, _stride) +/* profiling target for comparison: fft_small's sd_fft_trunc at the same depth, on a random vector converted to double */ +void sample_sd_fft(void * arg, ulong count) +{ + info_t * info = (info_t *) arg; + const ulong p = info->prime; + const ulong depth = info->depth; + + const ulong len = UWORD(1) << depth; + const ulong rep = FLINT_MAX(1, FLINT_MIN(1000, 1000000/len)); + + sd_fft_ctx_t Q; + sd_fft_ctx_init_prime(Q, p); + sd_fft_ctx_fit_depth(Q, depth); + + ulong sz = sd_fft_ctx_data_size(depth)*sizeof(double); + + FLINT_TEST_INIT(state); + + nmod_t mod; + nmod_init(&mod, p); + ulong * coeffs = _nmod_vec_init(len); + _nmod_vec_randtest(coeffs, state, len, mod); + + double* data = flint_aligned_alloc(4096, n_round_up(sz, 4096)); + for (ulong i = 0; i < len; i++) + data[i] = coeffs[i]; + _nmod_vec_clear(coeffs); /* coeffs only served to fill data */ + for (ulong i = 0; i < count; i++) + { + prof_start(); + for (ulong j = 0; j < rep; j++) + sd_fft_trunc(Q, data, depth, len, len); + prof_stop(); + } + flint_aligned_free(data); + sd_fft_ctx_clear(Q); + FLINT_TEST_CLEAR(state); +} + +int main() +{ + flint_printf("- depth is log(fft length)\n"); + flint_printf("- timing DFT (length power of 2) for several bit lengths and depths\n"); + flint_printf("depth\tsd_fft\tdft\tidft\tdft_t\tidft_t\n"); + + ulong primes[num_primes] = { + 786433, // 20 bits, 1 + 2**18 * 3 + 1073479681, // 30 bits, 1 + 2**30 - 2**18 == 1 + 2**18 * (2**12 - 1) + 2013265921, // 31 bits, 1 + 2**27 * 3 * 5 + 2748779069441, // 42 bits, 1 + 2**39 * 5 + 1108307720798209, // 50 bits, 1 + 2**44 * 3**2 * 7 + 1139410705724735489, // 60 bits, 1 + 2**52 * 11 * 23 + 4611686018427322369 // 62
bits: 1 + 2**62 - 2**16 == 1 + 2**16 * (2**46 - 1) + }; + ulong max_depths[num_primes] = { 18, 18, 25, 25, 25, 25, 16 }; + + for (ulong k = 4; k < 5; k++) // only primes[4] (50 bits) is profiled + { + for (ulong depth = 3; depth <= max_depths[k]; depth++) + { + flint_printf("%wu\t", depth); + + info_t info; + info.prime = primes[k]; + info.depth = depth; + info.stride = 1; + + const ulong len = UWORD(1) << depth; + const ulong rep = FLINT_MAX(1, FLINT_MIN(1000, 1000000/len)); + + double min[5]; + double max; + + prof_repeat(min+0, &max, sample_sd_fft, (void *) &info); + prof_repeat(min+1, &max, sample_dft, (void *) &info); + prof_repeat(min+2, &max, sample_idft, (void *) &info); + prof_repeat(min+3, &max, sample_dft_t, (void *) &info); + prof_repeat(min+4, &max, sample_idft_t, (void *) &info); + + flint_printf("%.1e\t%.1e\t%.1e\t%.1e\t%.1e\t\n", + min[0]/(double)1000000/rep, + min[1]/(double)1000000/rep, + min[2]/(double)1000000/rep, + min[3]/(double)1000000/rep, + min[4]/(double)1000000/rep + ); + } + } + return 0; +} diff --git a/src/n_fft/profile/p-init.c b/src/n_fft/profile/p-init.c new file mode 100644 index 0000000000..f19117066a --- /dev/null +++ b/src/n_fft/profile/p-init.c @@ -0,0 +1,126 @@ +/* + Copyright (C) 2024 Vincent Neiger + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See <https://www.gnu.org/licenses/>.
+*/ + +#include "flint.h" +#include "nmod.h" +#include "profiler.h" +#include "n_fft.h" + +#define num_primes 5 + +typedef struct +{ + ulong prime; + ulong depth; + ulong maxdepth; +} info_t; +/* profiling target: n_fft_ctx_init2_root followed by n_fft_ctx_clear, repeated rep times per timed run */ +void sample_init2_root(void * arg, ulong count) +{ + info_t * info = (info_t *) arg; + ulong p = info->prime; + ulong depth = info->depth; + ulong maxdepth = info->maxdepth; + + const ulong len = UWORD(1) << depth; + const ulong rep = FLINT_MAX(1, FLINT_MIN(1000, 1000000/len)); + + // modulus, roots of unity + nmod_t mod; + nmod_init(&mod, p); + ulong cofactor = (p - 1) >> maxdepth; + ulong w0 = nmod_pow_ui(n_primitive_root_prime(p), cofactor, mod); + ulong w = nmod_pow_ui(w0, UWORD(1) << (maxdepth - depth), mod); + + FLINT_TEST_INIT(state); + + for (ulong i = 0; i < count; i++) + { + prof_start(); + for (ulong j = 0; j < rep; j++) + { + n_fft_ctx_t F; + n_fft_ctx_init2_root(F, w, depth, cofactor, depth, p); + n_fft_ctx_clear(F); + } + prof_stop(); + } + + FLINT_TEST_CLEAR(state); +} + +/*-----------------------------------------------------------------*/ +/* initialize context for FFT for several bit lengths and depths */ +/*-----------------------------------------------------------------*/ +void time_fft_init(ulong * primes, ulong * max_depths) +{ + for (ulong depth = 3; depth <= 25; depth++) + { + flint_printf("%wu\t", depth); + for (ulong k = 0; k < num_primes; k++) + { + if (depth <= max_depths[k]) + { + info_t info; + info.prime = primes[k]; + info.maxdepth = max_depths[k]; + info.depth = depth; + + const ulong len = UWORD(1) << depth; + const ulong rep = FLINT_MAX(1, FLINT_MIN(1000, 1000000/len)); + + double min; + double max; + + prof_repeat(&min, &max, sample_init2_root, (void *) &info); + + flint_printf("%.1e|%.1e\t", + min/(double)1000000/rep, + min/(double)FLINT_CLOCK_SCALE_FACTOR/len/rep + ); + } + else + flint_printf(" na | na \t"); + } + flint_printf("\n"); + } + +} + +/*------------------------------------------------------------*/ +/* main just calls time_fft_init() */
+/*------------------------------------------------------------*/ +int main() +{ + printf("- depth == precomputing w**k, 0 <= k < 2**depth\n"); + printf("- timing init FFT context + clear at this depth:\n"); + printf(" t_raw == raw time\n"); + printf(" t_unit == raw time divided by 2**depth * clock scale factor\n"); + printf("\n"); + + printf(" \t 20 bits \t 31 bits \t 42 bits \t 50 bits \t 60 bits \n"); + printf("depth\tt_raw | t_unit\tt_raw | t_unit\tt_raw | t_unit\tt_raw | t_unit\tt_raw | t_unit\n"); + + // TODO fix for FLINT_BITS==32 (the 42, 50 and 60-bit primes below do not fit in a 32-bit ulong) + ulong primes[num_primes] = { + 786433, // 20 bits, 1 + 2**18 * 3 + 2013265921, // 31 bits, 1 + 2**27 * 3 * 5 + 2748779069441, // 42 bits, 1 + 2**39 * 5 + 1108307720798209, // 50 bits, 1 + 2**44 * 3**2 * 7 + 1139410705724735489, // 60 bits, 1 + 2**52 * 11 * 23 + }; + ulong max_depths[num_primes] = { 18, 25, 25, 25, 25 }; + + time_fft_init(primes, max_depths); + + return 0; +} + diff --git a/src/n_fft/test/main.c b/src/n_fft/test/main.c new file mode 100644 index 0000000000..296d96f361 --- /dev/null +++ b/src/n_fft/test/main.c @@ -0,0 +1,29 @@ +/* + Copyright (C) 2024 Vincent Neiger + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See <https://www.gnu.org/licenses/>.
+*/ + +/* Include functions *********************************************************/ + +#include "t-init.c" +#include "t-dft.c" +#include "t-idft.c" + +/* Array of test functions ***************************************************/ + +test_struct tests[] = +{ + TEST_FUNCTION(n_fft_ctx_init2), + TEST_FUNCTION(n_fft_dft), + TEST_FUNCTION(n_fft_idft), +}; + +/* main function *************************************************************/ + +TEST_MAIN(tests) diff --git a/src/n_fft/test/t-dft.c b/src/n_fft/test/t-dft.c new file mode 100644 index 0000000000..d6c7bd66ec --- /dev/null +++ b/src/n_fft/test/t-dft.c @@ -0,0 +1,145 @@ +/* + Copyright (C) 2024 Vincent Neiger + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See <https://www.gnu.org/licenses/>. +*/ + +#include "flint.h" +#include "test_helpers.h" +#include "ulong_extras.h" +#include "nmod.h" +#include "nmod_poly.h" +#include "nmod_vec.h" +#include "n_fft.h" + +#define MAX_EVAL_DEPTH 11 // must be <= 12 (the test primes are only guaranteed to have max_depth >= 12) + +// vector equality up to reduction modulo mod.n: returns 1 if vec1[k] and vec2[k] agree after NMOD_RED for all k < len, else 0 +int nmod_vec_red_equal(nn_srcptr vec1, nn_srcptr vec2, ulong len, nmod_t mod) +{ + for (ulong k = 0; k < len; k++) + { + ulong v1; + ulong v2; + NMOD_RED(v1, vec1[k], mod); + NMOD_RED(v2, vec2[k], mod); + if (v1 != v2) + return 0; + } + + return 1; +} + +// testing that all elements of "vec" are less than "bound": returns 1 if so, else 0 +int nmod_vec_range(nn_srcptr vec, ulong len, ulong bound) +{ + for (ulong k = 0; k < len; k++) + if (vec[k] >= bound) + return 0; + + return 1; +} + + +TEST_FUNCTION_START(n_fft_dft, state) +{ + int i; + + for (i = 0; i < 200 * flint_test_multiplier(); i++) + { + // take some FFT prime p with max_depth >= 12 + ulong max_depth, prime; + + // half of tests == fixed large prime, close to limit + // 62 bits: prime = 4611686018427322369 == 2**62 - 2**16 + 1 + //
30 bits: prime = 1073479681 == 2**30 - 2**18 + 1 + if (i > 100) +#if FLINT_BITS == 64 + prime = UWORD(4611686018427322369); +#else // FLINT_BITS == 32 + prime = UWORD(1073479681); +#endif + else + { + max_depth = 12 + n_randint(state, 6); + prime = 1 + (UWORD(1) << max_depth); + while (! n_is_prime(prime)) + prime += (UWORD(1) << max_depth); + } + max_depth = flint_ctz(prime-1); + + nmod_t mod; + nmod_init(&mod, prime); + + // init FFT root tables + n_fft_ctx_t F; + n_fft_ctx_init2(F, MAX_EVAL_DEPTH, prime); + + // retrieve roots, used later for multipoint evaluation + nn_ptr roots = flint_malloc((UWORD(1) << MAX_EVAL_DEPTH) * sizeof(ulong)); + for (ulong k = 0; k < (UWORD(1) << (MAX_EVAL_DEPTH-1)); k++) + { + roots[2*k] = F->tab_w[2*k]; + roots[2*k+1] = prime - F->tab_w[2*k]; // < prime since F->tab_w[2*k] != 0 + } + + for (ulong depth = 0; depth <= MAX_EVAL_DEPTH; depth++) + { + const ulong len = (UWORD(1) << depth); + + // choose random poly of degree < len + nmod_poly_t pol; + nmod_poly_init(pol, mod.n); + nmod_poly_randtest(pol, state, len); + + // evals via general multipoint evaluation + nn_ptr evals_br = _nmod_vec_init(len); + if (len == 1) + evals_br[0] = nmod_poly_evaluate_nmod(pol, UWORD(1)); + else + nmod_poly_evaluate_nmod_vec(evals_br, pol, roots, len); + + // evals by DFT + ulong * p = _nmod_vec_init(len); + _nmod_vec_set(p, pol->coeffs, len); + + n_fft_dft(p, depth, F); + + int res = nmod_vec_red_equal(evals_br, p, len, mod); + + if (!res) + TEST_FUNCTION_FAIL( + "prime = %wu\n" + "root of unity = %wu\n" + "max_depth = %wu\n" + "depth = %wu\n" + "failed equality test\n", + prime, F->tab_w2[2*(max_depth-2)], max_depth, depth); + + res = nmod_vec_range(p, len, 4*mod.n); + + if (!res) + TEST_FUNCTION_FAIL( + "prime = %wu\n" + "root of unity = %wu\n" + "max_depth = %wu\n" + "depth = %wu\n" + "failed range test\n", + prime, F->tab_w2[2*(max_depth-2)], max_depth, depth); + + _nmod_vec_clear(p); + nmod_poly_clear(pol); + _nmod_vec_clear(evals_br); + } + + 
flint_free(roots); + n_fft_ctx_clear(F); + } + + TEST_FUNCTION_END(state); +} diff --git a/src/n_fft/test/t-idft.c b/src/n_fft/test/t-idft.c new file mode 100644 index 0000000000..14c8e1ca93 --- /dev/null +++ b/src/n_fft/test/t-idft.c @@ -0,0 +1,158 @@ +/* + Copyright (C) 2024 Vincent Neiger + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . +*/ + +#include "flint.h" +#include "test_helpers.h" +#include "ulong_extras.h" +#include "nmod.h" +#include "nmod_poly.h" +#include "nmod_vec.h" +#include "n_fft.h" + +#define MAX_EVAL_DEPTH 11 // must be <= 12 + +// vector equality up to reduction mod +/* int nmod_vec_red_equal(nn_srcptr vec1, nn_srcptr vec2, ulong len, nmod_t mod) */ +/* { */ +/* for (ulong k = 0; k < len; k++) */ +/* { */ +/* ulong v1; */ +/* ulong v2; */ +/* NMOD_RED(v1, vec1[k], mod); */ +/* NMOD_RED(v2, vec2[k], mod); */ +/* if (v1 != v2) */ +/* return 0; */ +/* } */ + +/* return 1; */ +/* } */ + +/* // testing that all elements of "vec" are less than "bound" */ +/* int nmod_vec_range(nn_srcptr vec, ulong len, ulong bound) */ +/* { */ +/* for (ulong k = 0; k < len; k++) */ +/* if (vec[k] >= bound) */ +/* return 0; */ + +/* return 1; */ +/* } */ + + +TEST_FUNCTION_START(n_fft_idft, state) +{ + int i; + + for (i = 0; i < 200 * flint_test_multiplier(); i++) + { + // take some FFT prime p with max_depth >= 12 + ulong max_depth, prime; + + // half of tests == fixed large prime, close to limit + // 62 bits: prime = 4611686018427322369 == 2**62 - 2**16 + 1 + // 30 bits: prime = 1073479681 == 2**30 - 2**18 + 1 + if (i > 100) +#if FLINT_BITS == 64 + prime = UWORD(4611686018427322369); +#else // FLINT_BITS == 32 + prime = UWORD(1073479681); +#endif + else + { + max_depth = 12 + n_randint(state, 6); + prime = 1 + (UWORD(1) << 
max_depth); + while (! n_is_prime(prime)) + prime += (UWORD(1) << max_depth); + } + max_depth = flint_ctz(prime-1); + + nmod_t mod; + nmod_init(&mod, prime); + + // init FFT root tables + n_fft_ctx_t F; + n_fft_ctx_init2(F, MAX_EVAL_DEPTH, prime); + + // retrieve roots, used later for multipoint evaluation + nn_ptr roots = flint_malloc((UWORD(1) << MAX_EVAL_DEPTH) * sizeof(ulong)); + for (ulong k = 0; k < (UWORD(1) << (MAX_EVAL_DEPTH-1)); k++) + { + roots[2*k] = F->tab_w[2*k]; + roots[2*k+1] = prime - F->tab_w[2*k]; // < prime since F->tab_w[2*k] != 0 + } + + for (ulong depth = 0; depth <= MAX_EVAL_DEPTH; depth++) + { + const ulong len = (UWORD(1) << depth); + + // choose random evals of degree == len + nn_ptr evals = flint_malloc(len * sizeof(ulong)); + for (ulong k = 0; k < len; k++) + evals[k] = n_randint(state, prime); + + // general interpolation + nmod_poly_t pol; + nmod_poly_init(pol, prime); + nmod_poly_interpolate_nmod_vec(pol, roots, evals, len); + + // evals by IDFT + ulong * p = _nmod_vec_init(len); + _nmod_vec_set(p, evals, len); + + n_fft_idft(p, depth, F); + + int res = _nmod_vec_equal(pol->coeffs, p, len); + + if (!res) + { + _nmod_vec_print(p, len, mod); + _nmod_vec_print(pol->coeffs, len, mod); + TEST_FUNCTION_FAIL( + "prime = %wu\n" + "root of unity = %wu\n" + "max_depth = %wu\n" + "depth = %wu\n" + "failed equality test\n", + prime, F->tab_w2[2*(max_depth-2)], max_depth, depth); + } + + //int res = nmod_vec_red_equal(evals_br, p, len, mod); + + //if (!res) + // TEST_FUNCTION_FAIL( + // "prime = %wu\n" + // "root of unity = %wu\n" + // "max_depth = %wu\n" + // "depth = %wu\n" + // "failed equality test\n", + // prime, F->tab_w2[2*(max_depth-2)], max_depth, depth); + + //res = nmod_vec_range(p, len, 4*mod.n); + + //if (!res) + // TEST_FUNCTION_FAIL( + // "prime = %wu\n" + // "root of unity = %wu\n" + // "max_depth = %wu\n" + // "depth = %wu\n" + // "failed range test\n", + // prime, F->tab_w2[2*(max_depth-2)], max_depth, depth); + + 
_nmod_vec_clear(p);
+            flint_free(evals);
+            nmod_poly_clear(pol);
+        }
+
+        flint_free(roots);
+        n_fft_ctx_clear(F);
+    }
+
+    TEST_FUNCTION_END(state);
+}
diff --git a/src/n_fft/test/t-init.c b/src/n_fft/test/t-init.c
new file mode 100644
index 0000000000..30449469c6
--- /dev/null
+++ b/src/n_fft/test/t-init.c
@@ -0,0 +1,166 @@
+/*
+    Copyright (C) 2024 Vincent Neiger
+
+    This file is part of FLINT.
+
+    FLINT is free software: you can redistribute it and/or modify it under
+    the terms of the GNU Lesser General Public License (LGPL) as published
+    by the Free Software Foundation; either version 3 of the License, or
+    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
+*/
+
+#include "test_helpers.h"
+#include "ulong_extras.h"
+#include "n_fft.h"
+
+// return bit reversal index of k for given nbits:
+// e.g. br_index([0,1,2,3], 4) == [0, 8, 4, 12]
+static inline ulong br_index(ulong k, ulong nbits)
+{
+    k = ((k >> 1) & 0x55555555) | ((k & 0x55555555) << 1);
+    k = ((k >> 2) & 0x33333333) | ((k & 0x33333333) << 2);
+    k = ((k >> 4) & 0x0F0F0F0F) | ((k & 0x0F0F0F0F) << 4);
+    k = ((k >> 8) & 0x00FF00FF) | ((k & 0x00FF00FF) << 8);
+    k = ( k >> 16             ) | ( k               << 16);
+#if FLINT_BITS == 64
+    k = ( k >> 32             ) | ( k               << 32);
+#endif // FLINT_BITS == 64
+
+    return k >> (FLINT_BITS - nbits);
+}
+
+// check the attributes and tables of F against direct computations;
+// returns 0 if all checks pass, otherwise a nonzero code (1..13)
+// identifying the first failed check
+static int test_one(n_fft_ctx_t F, ulong max_depth, ulong depth, ulong p, flint_rand_t state)
+{
+    // if depth < 3, init is supposed to behave as if depth == 3
+    depth = FLINT_MAX(3, depth);
+
+    // check all basic attributes
+    if (F->mod != p)
+        return 1;
+
+    if (F->max_depth != max_depth)
+        return 2;
+
+    if ((1 + (F->cofactor << max_depth)) != p)
+        return 3;
+
+    if (F->depth != depth)
+        return 4;
+
+    // retrieve primitive root and its inverse
+    const ulong w = F->tab_w2[2*(max_depth-2)];
+    const ulong iw = n_invmod(w, p);
+
+    // check the primitive root:
+    // w**(2**max_depth) must be 1, i.e. the order of w divides 2**max_depth
+    if (n_powmod2(w, UWORD(1)<<max_depth, p) != 1)
+        return 5;
+
+    // check all entries of tab_w2
+    for (ulong k = 0; k < max_depth-1; k++)
+    {
+        ulong w2 = F->tab_w2[2*k];
+        if (w2 != n_powmod2(w, UWORD(1)<<(max_depth-2-k), p))
+            return 6;
+        if (F->tab_w2[2*k+1] != n_mulmod_precomp_shoup(w2, p))
+            return 7;
+    }
+
+    // check all entries of tab_inv2
+    for (ulong k = 0; k < max_depth; k++)
+    {
+        ulong inv2 = F->tab_inv2[2*k];
+        if (inv2 != n_invmod((UWORD(1)<<(k+1)), p))
+            return 8;
+        if (F->tab_inv2[2*k+1] != n_mulmod_precomp_shoup(inv2, p))
+            return 9;
+    }
+
+    // check a few random entries of tab_w and tab_iw
+    for (ulong j = 0; j < 1000; j++)
+    {
+        ulong k = n_randint(state, UWORD(1) << (F->depth - 1));
+        ulong exp = br_index(k, F->max_depth - 1);
+
+        ulong wk = F->tab_w[2*k];
+        if (wk != n_powmod2(w, exp, p))
+            return 10;
+        if (F->tab_w[2*k+1] != n_mulmod_precomp_shoup(wk, p))
+            return 11;
+
+        ulong iwk = F->tab_iw[2*k];
+        if (iwk != n_powmod2(iw, exp, p))
+            return 12;
+        if (F->tab_iw[2*k+1] != n_mulmod_precomp_shoup(iwk, p))
+            return 13;
+    }
+
+    return 0;
+}
+
+TEST_FUNCTION_START(n_fft_ctx_init2, state)
+{
+    int i;
+
+    for (i = 0; i < 1000 * flint_test_multiplier(); i++)
+    {
+        ulong p, max_depth;
+        if (i % 20 != 0)
+        {
+            // take random prime in [17, 2**(FLINT_BITS-2))
+#if FLINT_BITS == 64
+            ulong bits = 5 + n_randint(state, 58);
+#else
+            ulong bits = 5 + n_randint(state, 25);
+#endif
+            p = n_randprime(state, bits, 1);
+            max_depth = flint_ctz(p-1);
+
+            // we need p such that 8 divides p-1
+            while (max_depth < 3)
+            {
+                p = n_randprime(state, bits, 1);
+                max_depth = flint_ctz(p-1);
+            }
+        }
+        else
+        {
+            // the above will most often have max_depth 3 or 4
+            // every now and then we want p with larger max_depth
+#if FLINT_BITS == 64
+            max_depth = 40 + n_randint(state, 10);
+#else
+            max_depth = 10 + n_randint(state, 10);
+#endif
+            p = 1 + (UWORD(1) << max_depth);
+            while (! n_is_prime(p))
+                p += (UWORD(1) << max_depth);
+            max_depth = flint_ctz(p-1);
+        }
+
+        // take depth in [0, min(12, max_depth)): n_randint's bound is exclusive
+        ulong depth = n_randint(state, FLINT_MIN(12, max_depth));
+
+        // init
+        n_fft_ctx_t F;
+        n_fft_ctx_init2(F, depth, p);
+
+        int res = test_one(F, max_depth, depth, p, state);
+
+        if (res)
+            TEST_FUNCTION_FAIL(
+                    "prime = %wu\n"
+                    "root of unity = %wu\n"
+                    "max_depth = %wu\n"
+                    "depth = %wu\n"
+                    "error code = %d\n",
+                    p, F->tab_w2[2*(max_depth-2)], max_depth, depth, res);
+
+        n_fft_ctx_clear(F);
+    }
+
+    TEST_FUNCTION_END(state);
+}
diff --git a/src/nmod_vec/profile/p-dot.c b/src/nmod_vec/profile/p-dot.c
index 6d226710be..217f715704 100644
--- a/src/nmod_vec/profile/p-dot.c
+++ b/src/nmod_vec/profile/p-dot.c
@@ -9,9 +9,9 @@
     (at your option) any later version.  See <https://www.gnu.org/licenses/>.
 */
 
-#include <stdio.h>
 #include <stdlib.h>  // for atoi
 
+#include "ulong_extras.h"
 #include "profiler.h"
 #include "nmod.h"
 #include "nmod_vec.h"
diff --git a/src/ulong_extras/profile/p-powmod.c b/src/ulong_extras/profile/p-powmod.c
new file mode 100644
index 0000000000..0a8e00c10e
--- /dev/null
+++ b/src/ulong_extras/profile/p-powmod.c
@@ -0,0 +1,152 @@
+/*
+    Copyright 2024 (C) Vincent Neiger
+
+    This file is part of FLINT.
+
+    FLINT is free software: you can redistribute it and/or modify it under
+    the terms of the GNU Lesser General Public License (LGPL) as published
+    by the Free Software Foundation; either version 3 of the License, or
+    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
+ */
+
+#include "profiler.h"
+#include "ulong_extras.h"
+#include "double_extras.h"
+
+#define NB_ITER 1000
+
+typedef struct
+{
+   ulong bits;
+   ulong exp;
+} info_t;
+
+// sample functions: each times NB_ITER modular exponentiations per random modulus
+void sample_preinv(void * arg, ulong count)
+{
+    info_t * info = (info_t *) arg;
+    ulong exp = info->exp;
+    ulong bits = info->bits;
+    nn_ptr array = (nn_ptr) flint_malloc(NB_ITER*sizeof(ulong));
+    FLINT_TEST_INIT(state);
+
+    for (ulong i = 0; i < count; i++)
+    {
+        ulong n = n_randbits(state, bits);  // 0 < n < 2**(FLINT_BITS)
+        ulong ninv = n_preinvert_limb(n);
+        ulong norm = flint_clz(n);
+
+        for (ulong j = 0; j < NB_ITER; j++)
+            array[j] = n_randint(state, n);  // 0 <= array[j] < n
+
+        prof_start();
+        for (ulong j = 0; j < NB_ITER; j++)
+            array[j] = n_powmod_ui_preinv(array[j], exp, n, ninv, norm);
+        prof_stop();
+    }
+
+    flint_free(array);
+    FLINT_TEST_CLEAR(state);
+}
+
+void sample_preinv2(void * arg, ulong count)
+{
+    info_t * info = (info_t *) arg;
+    ulong exp = info->exp;
+    ulong bits = info->bits;
+    nn_ptr array = (nn_ptr) flint_malloc(NB_ITER*sizeof(ulong));
+    FLINT_TEST_INIT(state);
+
+    for (ulong i = 0; i < count; i++)
+    {
+        ulong n = n_randbits(state, bits);  // 0 < n < 2**(FLINT_BITS)
+        ulong ninv = n_preinvert_limb(n);
+
+        for (ulong j = 0; j < NB_ITER; j++)
+            array[j] = n_randlimb(state);
+
+        prof_start();
+        for (ulong j = 0; j < NB_ITER; j++)
+            array[j] = n_powmod2_ui_preinv(array[j], exp, n, ninv);
+        prof_stop();
+    }
+
+    flint_free(array);
+    FLINT_TEST_CLEAR(state);
+}
+
+void sample_precomp(void * arg, ulong count)
+{
+    info_t * info = (info_t *) arg;
+    ulong exp = info->exp;
+    ulong bits = info->bits;
+    nn_ptr array = (nn_ptr) flint_malloc(NB_ITER*sizeof(ulong));
+    FLINT_TEST_INIT(state);
+
+    for (ulong i = 0; i < count; i++)
+    {
+        ulong n = n_randbits(state, bits);  // 0 < n < 2**bits
+        double ninv = n_precompute_inverse(n);
+
+        for (ulong j = 0; j < NB_ITER; j++)
+            array[j] = n_randint(state, n);  // 0 <= array[j] < n
+
+        prof_start();
+        for (ulong j = 0; j < NB_ITER; j++)
+            array[j] = n_powmod_ui_precomp(array[j], exp, n, ninv);
+        prof_stop();
+    }
+
+    flint_free(array);
+    FLINT_TEST_CLEAR(state);
+}
+
+int main(void)
+{
+    double min, max;
+
+    const ulong bits_nb = 5;
+    ulong bits_list[] = {20, 30, 50, 60, 64};
+    const ulong exp_nb = 11;
+    ulong exp_list[] = {5, 10, 20, 40, 80, 160, 1000, 10000, 100000, 1000000L, 10000000L};
+
+    flint_printf("compute an exponentiation a**e mod n, with nbits(n) = b\n");
+    flint_printf(" computation is repeated on the element of a %wu-length array\n", (ulong) NB_ITER);
+    flint_printf(" time is divided by %wu * FLINT_CLOCK_SCALE_FACTOR * log_2(exp)\n", (ulong) NB_ITER);
+    flint_printf("timings are: powmod_ui_precomp | powmod_ui_preinv | powmod2_ui_preinv\n");
+    flint_printf("b \\ e\t");
+    for (ulong e = 0; e < exp_nb; e++)
+        flint_printf("%wu\t\t", exp_list[e]);
+    flint_printf("\n");
+
+    info_t info;
+
+    for (ulong b = 0; b < bits_nb; b++)
+    {
+        info.bits = bits_list[b];
+        flint_printf("%wu\t", info.bits);
+
+        for (ulong e = 0; e < exp_nb; e++)
+        {
+            info.exp = exp_list[e];
+            double log_exp = d_log2((double)info.exp);
+
+            if (info.bits <= 53)  // the double-based precomp variant is only timed up to 53 bits
+            {
+                prof_repeat(&min, &max, sample_precomp, (void *) &info);
+                flint_printf("%4.1f|", min/(NB_ITER * FLINT_CLOCK_SCALE_FACTOR * log_exp));
+            }
+            else
+                flint_printf(" na |");
+
+            prof_repeat(&min, &max, sample_preinv, (void *) &info);
+            flint_printf("%4.1f|", min/(NB_ITER * FLINT_CLOCK_SCALE_FACTOR * log_exp));
+
+            prof_repeat(&min, &max, sample_preinv2, (void *) &info);
+            flint_printf("%4.1f\t", min/(NB_ITER * FLINT_CLOCK_SCALE_FACTOR * log_exp));
+        }
+        flint_printf("\n");
+    }
+
+    return 0;
+}