diff --git a/fastdoubleparser-dev/src/main/java/ch.randelshofer.fastdoubleparser/ch/randelshofer/fastdoubleparser/FftMultiplier.java b/fastdoubleparser-dev/src/main/java/ch.randelshofer.fastdoubleparser/ch/randelshofer/fastdoubleparser/FftMultiplier.java index 494720d7..35404a8c 100644 --- a/fastdoubleparser-dev/src/main/java/ch.randelshofer.fastdoubleparser/ch/randelshofer/fastdoubleparser/FftMultiplier.java +++ b/fastdoubleparser-dev/src/main/java/ch.randelshofer.fastdoubleparser/ch/randelshofer/fastdoubleparser/FftMultiplier.java @@ -48,7 +48,7 @@ class FftMultiplier { /** * for FFTs of length up to 2^19 */ - private static final int ROOTS_CACHE2_SIZE = 20; + private static final int ROOTS2_CACHE_SIZE = 20; /** * The threshold value for using 3-way Toom-Cook multiplication. */ @@ -58,7 +58,7 @@ class FftMultiplier { * elements representing all (2^(k+2))-th roots between 0 and pi/2. * Used for FFT multiplication. */ - private volatile static ComplexVector[] ROOTS2_CACHE = new ComplexVector[ROOTS_CACHE2_SIZE]; + private volatile static ComplexVector[] ROOTS2_CACHE = new ComplexVector[ROOTS2_CACHE_SIZE]; /** * Sets of complex roots of unity. The set at index k contains 3*2^k * elements representing all (3*2^(k+2))-th roots between 0 and pi/2. @@ -66,6 +66,11 @@ class FftMultiplier { */ private volatile static ComplexVector[] ROOTS3_CACHE = new ComplexVector[ROOTS3_CACHE_SIZE]; + private static final ComplexVector ONE; + static { + ONE = new ComplexVector(1); + ONE.set(0, 1.0, 0.0); + } /** * Returns the maximum number of bits that one double precision number can fit without * causing the multiplication to be incorrect. @@ -118,10 +123,7 @@ static int bitsPerFftPoint(int bitLen) { */ private static ComplexVector calculateRootsOfUnity(int n) { if (n == 1) { - ComplexVector v = new ComplexVector(1); - v.real(0, 1); - v.imag(0, 0); - return v; + return ONE; } ComplexVector roots = new ComplexVector(n); roots.set(0, 1.0, 0.0); @@ -139,6 +141,36 @@ private static ComplexVector calculateRootsOfUnity(int n) { return roots; } + private static ComplexVector calculateRootsOfUnity(int n, ComplexVector prev) { + if (n == 1) { + return ONE; + } + ComplexVector roots = new ComplexVector(n); + roots.set(0, 1.0, 0.0); + double cos = COS_0_25; + double sin = SIN_0_25; + roots.set(n / 2, cos, sin); + + double angleTerm = 0.5 * Math.PI / n; + int ratio = n / prev.length; + for (int i = 1, j = 1; j < n / 2; i++, j += ratio) { + for (int k = 0; k < ratio - 1; k++) { + int outIdx = j + k; + double angle = angleTerm * outIdx; + cos = Math.cos(angle); + sin = Math.sin(angle); + roots.set(outIdx, cos, sin); + roots.set(n - outIdx, sin, cos); + } + cos = prev.real(i); + sin = prev.imag(i); + int outIdx = j + ratio - 1; + roots.set(outIdx, cos, sin); + roots.set(n - outIdx, sin, cos); + } + return roots; + } + /** * Performs an FFT of length 2^n on the vector {@code a}. * This is a decimation-in-frequency implementation. @@ -348,21 +380,33 @@ static BigInteger fromFftVector(ComplexVector fftVec, int signum, int bitsPerFft * * @param logN for a transform of length 2^logN */ - private static ComplexVector[] getRootsOfUnity2(int logN) { + static ComplexVector[] getRootsOfUnity2(int logN) { ComplexVector[] roots = new ComplexVector[logN + 1]; - for (int i = logN; i >= 0; i -= 2) { - if (i < ROOTS_CACHE2_SIZE) { + for (int i = logN % 2; i <= logN; i += 2) { + if (i < ROOTS2_CACHE_SIZE) { if (ROOTS2_CACHE[i] == null) { - ROOTS2_CACHE[i] = calculateRootsOfUnity(1 << i); + ROOTS2_CACHE[i] = getRootOfUnity(1, i, ROOTS2_CACHE); } roots[i] = ROOTS2_CACHE[i]; } else { - roots[i] = calculateRootsOfUnity(1 << i); + roots[i] = getRootOfUnity(1, i, ROOTS2_CACHE); } } return roots; } + private static ComplexVector getRootOfUnity(int b, int e, ComplexVector[] roots) { + int nearest = floorEntry(e, roots); + return nearest >= 2 + ? calculateRootsOfUnity(b << e, roots[nearest]) + : calculateRootsOfUnity(b << e); + } + + private static int floorEntry(int i, ComplexVector[] roots) { + while (i >= 2 && roots[i] == null) { i--; } + return i; + } + /** * Returns sets of complex roots of unity. For k=logN, logN-2, logN-4, ..., * the return value contains all k-th roots between 0 and pi/2. @@ -372,11 +416,11 @@ private static ComplexVector[] getRootsOfUnity2(int logN) { private static ComplexVector getRootsOfUnity3(int logN) { if (logN < ROOTS3_CACHE_SIZE) { if (ROOTS3_CACHE[logN] == null) { - ROOTS3_CACHE[logN] = calculateRootsOfUnity(3 << logN); + ROOTS3_CACHE[logN] = getRootOfUnity(3, logN, ROOTS3_CACHE); } return ROOTS3_CACHE[logN]; } else { - return calculateRootsOfUnity(3 << logN); + return getRootOfUnity(3, logN, ROOTS3_CACHE); } }