diff --git a/src/dspeed/processors/poly_fit.py b/src/dspeed/processors/poly_fit.py index 44fd20f..fceaa66 100644 --- a/src/dspeed/processors/poly_fit.py +++ b/src/dspeed/processors/poly_fit.py @@ -5,12 +5,38 @@ import numpy as np from numba import guvectorize +from ..utils import GUFuncWrapper from ..utils import numba_defaults_kwargs as nb_kwargs +@guvectorize( + [ + "void(float32[:], float64[::1,::1], float32[:])", + "void(float64[:], float64[::1,::1], float64[:])", + ], + "(n),(m,m)->(m)", + **nb_kwargs, +) +def _poly_fitter(w_in: np.ndarray, inv: np.ndarray, poly_pars: np.ndarray) -> None: + """Helper function that fits w_in to order `len(poly_pars)-1` polynomial, + while providing necessary inverse matrix. + """ + if np.isnan(w_in).any(): + return + + arr = np.zeros(len(poly_pars), dtype="float") + for i in range(0, len(w_in), 1): + for j in range(len(poly_pars)): + arr[j] += w_in[i] * (i**j) + + poly_pars[:] = inv @ arr + + def poly_fit(length, deg): + """Factory function for generating a polynomial fitter for an input of length + `length` to a polynomial of order `deg`.""" - vals_array = np.zeros(2 * deg + 1, dtype="float") + vals_array = np.zeros(2 * deg + 1, dtype="float64") for i in range(length): # linear regression @@ -23,26 +49,15 @@ def poly_fit(length, deg): inv = np.linalg.inv(mat) - @guvectorize( - [ - "void(float32[:], float32[:])", - "void(float64[:], float64[:])", - ], + return GUFuncWrapper( + lambda w_in, poly_pars: _poly_fitter(w_in, inv, poly_pars), "(n),(m)", + ["ff", "dd"], + name="poly_fitter", + vectorized=True, + copy_out=False, + doc_string=f"Fit w_in to order {deg} polynomial.", ) - def poly_fitter(w_in: np.ndarray, poly_pars) -> None: - - if np.isnan(w_in).any(): - return - - arr = np.zeros(deg + 1, dtype="float") - for i in range(0, len(w_in), 1): - for j in range(deg + 1): - arr[j] += w_in[i] * (i**j) - - poly_pars[:] = inv @ arr - - return poly_fitter @guvectorize( diff --git a/tests/processors/test_poly_fit.py b/tests/processors/test_poly_fit.py new file mode 100644 index 0000000..f08c974 --- /dev/null +++ b/tests/processors/test_poly_fit.py @@ -0,0 +1,19 @@ +import numpy as np + +from dspeed.processors import poly_fit + + +def test_poly_fit(compare_numba_vs_python): + """Test polynomial fitter""" + + # cubic polynomial coefficients + values + coeffs = np.array([5.0, 3.0, 1.0, -1.0]) + x = np.arange(10) + y = sum(c * x**i for i, c in enumerate(coeffs)) + + # generate processor + filt = poly_fit(len(x), len(coeffs) - 1) + + coeffs_out = np.zeros_like(coeffs) + compare_numba_vs_python(filt, y, coeffs_out) + assert np.all(np.isclose(coeffs_out, coeffs))