From 3302ff0fd83c228c98697349a063acef601ef041 Mon Sep 17 00:00:00 2001
From: Moritz Neuberger <31659079+MoritzNeuberger@users.noreply.github.com>
Date: Tue, 6 Aug 2024 18:49:06 +0200
Subject: [PATCH] Added Binomial distribution (#101)

Source code is inspired by the scipy implementation.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hans Dembinski
---
 pyproject.toml                    | 12 +++---
 src/numba_stats/_special.py       |  5 ++-
 src/numba_stats/_util.py          | 53 +++++++++++++++--------
 src/numba_stats/bernstein.py      | 15 ++++---
 src/numba_stats/binom.py          | 71 +++++++++++++++++++++++++++++++
 src/numba_stats/crystalball.py    |  8 ++--
 src/numba_stats/crystalball_ex.py |  2 +-
 src/numba_stats/expon.py          |  4 +-
 src/numba_stats/laplace.py        |  4 +-
 src/numba_stats/norm.py           |  6 +--
 tests/test_binom.py               | 43 +++++++++++++++++++
 tests/test_special.py             | 16 +++++++
 12 files changed, 198 insertions(+), 41 deletions(-)
 create mode 100644 src/numba_stats/binom.py
 create mode 100644 tests/test_binom.py
 create mode 100644 tests/test_special.py

diff --git a/pyproject.toml b/pyproject.toml
index 9f89f02..b20c1ad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,9 @@ addopts = "-q -ra --ff"
 testpaths = ["tests"]
 
 [tool.ruff]
+src = ["src"]
+
+[tool.ruff.lint]
 select = [
     "E",
     "F", # flake8
@@ -58,17 +61,16 @@ select = [
 extend-ignore = [
     "D212", # multi-line-summary-first-line
 ]
-src = ["src"]
 unfixable = [
     "F841", # Removes unused variables
 ]
 
-[tool.ruff.pydocstyle]
-convention = "numpy"
-
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 "setup.py" = ["D"]
 "tests/*.py" = ["B", "D"]
 ".ci/*.py" = ["D"]
 "bench/*.py" = ["D"]
 "docs/*.py" = ["D"]
+
+[tool.ruff.pydocstyle]
+convention = "numpy"
diff --git a/src/numba_stats/_special.py b/src/numba_stats/_special.py
index 73ef19e..8087e52 100644
--- a/src/numba_stats/_special.py
+++ b/src/numba_stats/_special.py
@@ -33,7 +33,10 @@ def get(name, signature):
 gammaincc = get("gammaincc", float64(float64, float64))
 stdtrit = get("stdtrit", float64(float64, float64))
 stdtr = get("stdtr", float64(float64, float64))
-betainc = get("betainc", float64(float64, float64, float64))
 
 # n-ary functions (double)
 voigt_profile = get("voigt_profile", float64(float64, float64, float64))
+xlogy = get("xlogy", float64(float64, float64))
+xlog1py = get("xlog1py", float64(float64, float64))
+# for some reason, getting betainc directly does not work! btdtr is an alias
+betainc = get("btdtr", float64(float64, float64, float64))
diff --git a/src/numba_stats/_util.py b/src/numba_stats/_util.py
index 5c0c406..01dc7e9 100644
--- a/src/numba_stats/_util.py
+++ b/src/numba_stats/_util.py
@@ -16,7 +16,7 @@ def _readonly_carray(T):
     return Array(T, 1, "A", readonly=True)
 
 
-def _jit(arg, cache=True):
+def _jit_custom(signatures, cache=True):
     """
     Wrap numba.njit to reduce boilerplate code.
 
@@ -24,25 +24,44 @@
     types which are used in the implemnetation to float32 or float64. We also want to
     pass specific options consistently: error_model='numpy' and inline='always'. The
     latter is important to profit from auto-parallelization of surrounding code.
+    """
+    return nb.njit(signatures, cache=cache, inline="always", error_model="numpy")
+
+
+def _jit(npar, *, narg=1, cache=True):
+    """
+    Wrap numba.njit to reduce boilerplate code.
+
+    We want to build jitted functions with explicit signatures to restrict the argument
+    types which are used in the implemnetation to float32 or float64. We also want to
+    pass specific options consistently: error_model='numpy' and inline='always'. The
+    latter is important to profit from auto-parallelization of surrounding code.
+
+    This decorator builds signatures with "narg" array arguments followed by "npar"
+    scalar arguments, and it does that for the types float32 or float64.
 
     Parameters
     ----------
-    arg : int
-        Number of arguments. If negative, all arguments of this function are scalars
-        and -arg is the number of arguments. If positive, the first argument is
-        an array, the others are scalars and arg is the number of scalar arguments.
+    npar : int
+        Number of scalar arguments.
+    narg : int, optional (default: 1)
+        Number of array arguments.
+    cache : bool, optional (default: True)
+        Whether to cache the compilation. We must turn this off if the function uses a
+        function pointer from Scipy.
     """
-    if isinstance(arg, list):
-        signatures = arg
-    else:
-        signatures = []
-        for T in (nb.float32, nb.float64):
-            if arg < 0:
-                sig = T(*([T] * -arg))
-            else:
-                sig = T[:](_readonly_carray(T), *[T for _ in range(arg)])
-            signatures.append(sig)
-    return nb.njit(signatures, cache=cache, inline="always", error_model="numpy")
+    assert npar >= 0
+    assert narg >= 0
+    signatures = []
+    for T in (nb.float32, nb.float64):
+        if narg == 0:
+            sig = T(*([T] * npar))
+        else:
+            sig = T[:](
+                *[_readonly_carray(T) for _ in range(narg)], *[T for _ in range(npar)]
+            )
+        signatures.append(sig)
+    return _jit_custom(signatures, cache=cache)
 
 
 def _rvs_jit(arg, cache=True):
@@ -51,7 +70,7 @@ def _rvs_jit(arg, cache=True):
         # extra args at the end are for size and random_state
         sig = T[:](*[T for _ in range(arg)], nb.uint64, nb.optional(nb.uint64))
         signatures.append(sig)
-    return nb.njit(signatures, cache=cache, inline="always", error_model="numpy")
+    return _jit_custom(signatures, cache=cache)
 
 
 @nb.njit(cache=True)
diff --git a/src/numba_stats/bernstein.py b/src/numba_stats/bernstein.py
index eb9ac52..bc98a05 100644
--- a/src/numba_stats/bernstein.py
+++ b/src/numba_stats/bernstein.py
@@ -17,10 +17,15 @@
 """
 
 import numpy as np
-from ._util import _jit, _Floats, _generate_wrappers, _readonly_carray, _trans
+from ._util import (
+    _jit,
+    _Floats,
+    _generate_wrappers,
+    _trans,
+)
 
 
-@_jit([T[:](_readonly_carray(T), _readonly_carray(T)) for T in _Floats])
+@_jit(0, narg=2)
 def _de_castlejau(z, beta):
     # De Casteljau algorithm, numerically stable
     n = len(beta)
@@ -48,7 +53,7 @@ def _beta_int(beta):
     return r
 
 
-@_jit([T[:](_readonly_carray(T), _readonly_carray(T), T, T) for T in _Floats])
+@_jit(2, narg=2)
 def _density(x, beta, xmin, xmax):
     """
     Return density described by a Bernstein polynomial.
@@ -90,9 +95,7 @@ def _density(x, beta, xmin, xmax):
     return _de_castlejau(z, beta)
 
 
-@_jit(
-    [T[:](_readonly_carray(T), _readonly_carray(T), T, T) for T in _Floats], cache=True
-)
+@_jit(2, narg=2)
 def _integral(x, beta, xmin, xmax):
     """
     Return integral of a Bernstein polynomial from xmin to x.
diff --git a/src/numba_stats/binom.py b/src/numba_stats/binom.py
new file mode 100644
index 0000000..c8df4e4
--- /dev/null
+++ b/src/numba_stats/binom.py
@@ -0,0 +1,71 @@
+"""
+Binomial distribution.
+
+See Also
+--------
+scipy.stats.binom: Scipy equivalent.
+"""
+
+import numpy as np
+from ._special import xlogy as _xlogy, xlog1py as _xlog1py, betainc as _betainc
+from math import lgamma as _lgamma
+from ._util import _jit, _generate_wrappers, _prange, _seed
+import numba as nb
+
+_doc_par = """
+k : int
+    Number of successes.
+n : int
+    Number of trials.
+p : float
+    Success probability for each trial.
+"""
+
+
+@_jit(1, narg=2, cache=False)
+def _logpmf(k, n, p):
+    T = type(p)
+    r = np.empty(len(k), T)
+    one = T(1)
+    for i in _prange(len(r)):
+        combiln = _lgamma(n[i] + one) - (
+            _lgamma(k[i] + one) + _lgamma(n[i] - k[i] + one)
+        )
+        r[i] = combiln + _xlogy(k[i], p) + _xlog1py(n[i] - k[i], -p)
+    return r
+
+
+@_jit(1, narg=2, cache=False)
+def _pmf(k, n, p):
+    return np.exp(_logpmf(k, n, p))
+
+
+@_jit(1, narg=2, cache=False)
+def _cdf(k, n, p):
+    T = type(p)
+    r = np.empty(len(k), T)
+    one = T(1)
+    for i in _prange(len(r)):
+        if k[i] == n[i]:
+            r[i] = 1
+        elif p == 0:
+            r[i] = 1
+        elif p == 1:
+            r[i] = 0
+        else:
+            r[i] = 1 - _betainc(k[i] + one, n[i] - k[i], p)
+    return r
+
+
+@nb.njit(
+    nb.int64[:](nb.uint64, nb.float32, nb.uint64, nb.optional(nb.uint64)),
+    cache=True,
+    inline="always",
+    error_model="numpy",
+)
+def _rvs(n, p, size, random_state):
+    _seed(random_state)
+    return np.random.binomial(n, p, size=size)
+
+
+_generate_wrappers(globals())
diff --git a/src/numba_stats/crystalball.py b/src/numba_stats/crystalball.py
index f69694e..741d629 100644
--- a/src/numba_stats/crystalball.py
+++ b/src/numba_stats/crystalball.py
@@ -26,7 +26,7 @@
 """
 
 
-@_jit(-3)
+@_jit(3, narg=0)
 def _log_powerlaw(z, beta, m):
     T = type(beta)
     c = -T(0.5) * beta * beta
@@ -35,7 +35,7 @@ def _log_powerlaw(z, beta, m):
     return log_a - m * np.log(b - z)
 
 
-@_jit(-3)
+@_jit(3, narg=0)
 def _powerlaw_integral(z, beta, m):
     T = type(beta)
     exp_beta = np.exp(-T(0.5) * beta * beta)
@@ -45,14 +45,14 @@ def _powerlaw_integral(z, beta, m):
     return a * (b - z) ** -m1 / m1
 
 
-@_jit(-2)
+@_jit(2, narg=0)
 def _normal_integral(a, b):
     T = type(a)
     sqrt_half = np.sqrt(T(0.5))
     return sqrt_half * np.sqrt(T(np.pi)) * (_erf(b * sqrt_half) - _erf(a * sqrt_half))
 
 
-@_jit(-3)
+@_jit(3, narg=0)
 def _log_density(z, beta, m):
     if z < -beta:
         return _log_powerlaw(z, beta, m)
diff --git a/src/numba_stats/crystalball_ex.py b/src/numba_stats/crystalball_ex.py
index 95b5cb7..c9cc897 100644
--- a/src/numba_stats/crystalball_ex.py
+++ b/src/numba_stats/crystalball_ex.py
@@ -31,7 +31,7 @@
 """
 
 
-@_jit(-3)
+@_jit(3, narg=0)
 def _norm_half(beta, m, scale):
     T = type(beta)
     return (_powerlaw_integral(-beta, beta, m) + _normal_integral(-beta, T(0))) * scale
diff --git a/src/numba_stats/expon.py b/src/numba_stats/expon.py
index d43578e..8653d84 100644
--- a/src/numba_stats/expon.py
+++ b/src/numba_stats/expon.py
@@ -18,13 +18,13 @@
 """
 
 
-@_jit(-1)
+@_jit(1, narg=0)
 def _cdf1(z):
     T = type(z)
     return T(0) if z < 0 else -_expm1(-z)
 
 
-@_jit(-1)
+@_jit(1, narg=0)
 def _ppf1(p):
     return -_log1p(-p)
 
diff --git a/src/numba_stats/laplace.py b/src/numba_stats/laplace.py
index d340848..59c9e06 100644
--- a/src/numba_stats/laplace.py
+++ b/src/numba_stats/laplace.py
@@ -17,12 +17,12 @@
 """
 
 
-@_jit(-1)
+@_jit(1, narg=0)
 def _cdf1(z):
     return 1.0 - 0.5 * np.exp(-z) if z > 0 else 0.5 * np.exp(z)
 
 
-@_jit(-1)
+@_jit(1, narg=0)
 def _ppf1(p):
     return -np.log(2 * (1 - p)) if p > 0.5 else np.log(2 * p)
 
diff --git a/src/numba_stats/norm.py b/src/numba_stats/norm.py
index b6b165b..ac78ac8 100644
--- a/src/numba_stats/norm.py
+++ b/src/numba_stats/norm.py
@@ -19,20 +19,20 @@
 """
 
 
-@_jit(-1)
+@_jit(1, narg=0)
 def _logpdf1(z):
     T = type(z)
     return -T(0.5) * (z * z + T(np.log(2 * np.pi)))
 
 
-@_jit(-1)
+@_jit(1, narg=0)
 def _cdf1(z):
     T = type(z)
     c = T(np.sqrt(0.5))
     return T(0.5) * (T(1.0) + _erf(z * c))
 
 
-@_jit(-1, cache=False)  # cannot cache because of _ndtri
+@_jit(1, narg=0, cache=False)  # cannot cache because of _ndtri
 def _ppf1(p):
     T = type(p)
     return T(_ndtri(p))
diff --git a/tests/test_binom.py b/tests/test_binom.py
new file mode 100644
index 0000000..870a3c9
--- /dev/null
+++ b/tests/test_binom.py
@@ -0,0 +1,43 @@
+import numpy as np
+from numba_stats import binom
+import scipy.stats as sc
+import pytest
+
+
+# NC and KC are all combinations of n and k from 0 to 10
+N = np.arange(10)
+NC = []
+KC = []
+for n in N:
+    for k in range(n + 1):
+        NC.append(n)
+        KC.append(k)
+NC = np.array(NC, np.float64)
+KC = np.array(KC, np.float64)
+
+
+@pytest.mark.parametrize("p", np.linspace(0, 1, 5))
+def test_pmf(p):
+    print(KC, NC)
+    got = binom.pmf(KC, NC, p)
+    expected = sc.binom.pmf(KC, NC, p)
+    np.testing.assert_allclose(got, expected)
+
+
+@pytest.mark.parametrize("p", np.linspace(0, 1, 5))
+def test_cdf(p):
+    got = binom.cdf(KC, NC, p)
+    expected = sc.binom.cdf(KC, NC, p)
+    np.testing.assert_allclose(got, expected)
+
+
+@pytest.mark.parametrize("n", np.linspace(0, 10, 6))
+@pytest.mark.parametrize("p", np.linspace(0, 1, 5))
+def test_rvs(n, p):
+    got = binom.rvs(n, p, size=1000, random_state=1)
+
+    def expected():
+        np.random.seed(1)
+        return np.random.binomial(n, p, 1000)
+
+    np.testing.assert_equal(got, expected())
diff --git a/tests/test_special.py b/tests/test_special.py
new file mode 100644
index 0000000..da2ddbf
--- /dev/null
+++ b/tests/test_special.py
@@ -0,0 +1,16 @@
+from numba_stats import _special as sp
+from scipy import special as sp_ref
+import numba as nb
+import pytest
+import numpy as np
+
+
+@pytest.mark.parametrize("a", [1, 2, 3])
+@pytest.mark.parametrize("b", [1, 2, 3])
+@pytest.mark.parametrize("x", [0.1, 0.5, 0.9])
+def test_betainc(a, b, x):
+    @nb.njit
+    def betainc(a, b, x):
+        return sp.betainc(a, b, x)
+
+    np.testing.assert_allclose(betainc(a, b, x), sp_ref.betainc(a, b, x))
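
For reference (not part of the patch above): a small sketch of the explicit signatures
that the reworked _jit decorator builds. It assumes only that numba is installed; the
helper name readonly_carray is a stand-in for _readonly_carray from
src/numba_stats/_util.py, and the values of npar/narg are just an example.

import numba as nb

def readonly_carray(T):
    # same construction as _readonly_carray in src/numba_stats/_util.py
    return nb.types.Array(T, 1, "A", readonly=True)

npar, narg = 2, 1  # i.e. @_jit(2) with the default narg=1
signatures = [
    # one read-only 1-d array argument followed by npar scalar parameters,
    # generated once for float32 and once for float64
    T[:](*[readonly_carray(T) for _ in range(narg)], *[T for _ in range(npar)])
    for T in (nb.float32, nb.float64)
]
for sig in signatures:
    print(sig)

_jit forwards these signatures to numba.njit(..., cache=..., inline="always",
error_model="numpy") via _jit_custom; @_jit(3, narg=0) instead produces pure scalar
signatures T(T, T, T), which is what the crystalball kernels use after this patch.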
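
A minimal usage sketch of the new binom module, mirroring the calls exercised in
tests/test_binom.py. It assumes numba-stats with this patch and scipy are installed;
the variable names and the specific values of k, n and p are illustrative only.

import numpy as np
import scipy.stats as sc
from numba_stats import binom

# k and n are passed as float64 arrays, matching how the tests call the wrappers;
# p is a plain scalar.
k = np.arange(6, dtype=np.float64)
n = np.full_like(k, 5.0)
p = 0.3

# The pmf is exp(lgamma(n+1) - lgamma(k+1) - lgamma(n-k+1) + k*log(p) + (n-k)*log1p(-p)),
# which is what _logpmf computes via lgamma, xlogy and xlog1py.
np.testing.assert_allclose(binom.pmf(k, n, p), sc.binom.pmf(k, n, p))
np.testing.assert_allclose(binom.cdf(k, n, p), sc.binom.cdf(k, n, p))

# _rvs seeds numpy's global generator and draws with np.random.binomial, so a fixed
# random_state reproduces the equivalent plain-numpy draw (see test_rvs).
x = binom.rvs(5, 0.3, size=1000, random_state=1)
print(x[:10])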
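
Finally, a short check (assuming scipy is installed) of the identity behind the _cdf
kernel: for k < n the binomial CDF equals one minus the regularized incomplete beta
function, P(X <= k) = 1 - I_p(k+1, n-k), which is why _cdf calls betainc(k+1, n-k, p)
and treats k == n, p == 0 and p == 1 as special cases.

import numpy as np
from scipy.special import betainc
from scipy.stats import binom as sc_binom

n, p = 10, 0.3
k = np.arange(n)  # k == n is handled separately in _cdf (the result is exactly 1)
lhs = sc_binom.cdf(k, n, p)
rhs = 1.0 - betainc(k + 1.0, n - k, p)
np.testing.assert_allclose(lhs, rhs)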