Skip to content

Commit

Permalink
Added Binomial distribution (#101)
Browse files Browse the repository at this point in the history
Source code is inspired by the scipy implementation.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hans Dembinski <[email protected]>
  • Loading branch information
3 people authored Aug 6, 2024
1 parent 7da0422 commit 3302ff0
Show file tree
Hide file tree
Showing 12 changed files with 198 additions and 41 deletions.
12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ addopts = "-q -ra --ff"
testpaths = ["tests"]

[tool.ruff]
src = ["src"]

[tool.ruff.lint]
select = [
"E",
"F", # flake8
Expand All @@ -58,17 +61,16 @@ select = [
extend-ignore = [
"D212", # multi-line-summary-first-line
]
src = ["src"]
unfixable = [
"F841", # Removes unused variables
]

[tool.ruff.pydocstyle]
convention = "numpy"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"setup.py" = ["D"]
"tests/*.py" = ["B", "D"]
".ci/*.py" = ["D"]
"bench/*.py" = ["D"]
"docs/*.py" = ["D"]

[tool.ruff.pydocstyle]
convention = "numpy"
5 changes: 4 additions & 1 deletion src/numba_stats/_special.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ def get(name, signature):
gammaincc = get("gammaincc", float64(float64, float64))
stdtrit = get("stdtrit", float64(float64, float64))
stdtr = get("stdtr", float64(float64, float64))
betainc = get("betainc", float64(float64, float64, float64))

# n-ary functions (double)
voigt_profile = get("voigt_profile", float64(float64, float64, float64))
xlogy = get("xlogy", float64(float64, float64))
xlog1py = get("xlog1py", float64(float64, float64))
# for some reason, getting betainc directly does not work! btdtr is an alias
betainc = get("btdtr", float64(float64, float64, float64))
53 changes: 36 additions & 17 deletions src/numba_stats/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,52 @@ def _readonly_carray(T):
return Array(T, 1, "A", readonly=True)


def _jit(arg, cache=True):
def _jit_custom(signatures, cache=True):
"""
Wrap numba.njit to reduce boilerplate code.
We want to build jitted functions with explicit signatures to restrict the argument
types which are used in the implemnetation to float32 or float64. We also want to
pass specific options consistently: error_model='numpy' and inline='always'. The
latter is important to profit from auto-parallelization of surrounding code.
"""
return nb.njit(signatures, cache=cache, inline="always", error_model="numpy")


def _jit(npar, *, narg=1, cache=True):
"""
Wrap numba.njit to reduce boilerplate code.
We want to build jitted functions with explicit signatures to restrict the argument
types which are used in the implemnetation to float32 or float64. We also want to
pass specific options consistently: error_model='numpy' and inline='always'. The
latter is important to profit from auto-parallelization of surrounding code.
This decorator builds signatures with "narg" array arguments followed by "npar"
scalar arguments, and it does that for the types float32 or float64.
Parameters
----------
arg : int
Number of arguments. If negative, all arguments of this function are scalars
and -arg is the number of arguments. If positive, the first argument is
an array, the others are scalars and arg is the number of scalar arguments.
npar : int
Number of scalar arguments.
narg : int, optional (default: 1)
Number of array arguments.
cache : bool, optional (default: True)
Whether to cache the compilation. We must turn this off if the function uses a
function pointer from Scipy.
"""
if isinstance(arg, list):
signatures = arg
else:
signatures = []
for T in (nb.float32, nb.float64):
if arg < 0:
sig = T(*([T] * -arg))
else:
sig = T[:](_readonly_carray(T), *[T for _ in range(arg)])
signatures.append(sig)
return nb.njit(signatures, cache=cache, inline="always", error_model="numpy")
assert npar >= 0
assert narg >= 0
signatures = []
for T in (nb.float32, nb.float64):
if narg == 0:
sig = T(*([T] * npar))
else:
sig = T[:](
*[_readonly_carray(T) for _ in range(narg)], *[T for _ in range(npar)]
)
signatures.append(sig)
return _jit_custom(signatures, cache=cache)


def _rvs_jit(arg, cache=True):
Expand All @@ -51,7 +70,7 @@ def _rvs_jit(arg, cache=True):
# extra args at the end are for size and random_state
sig = T[:](*[T for _ in range(arg)], nb.uint64, nb.optional(nb.uint64))
signatures.append(sig)
return nb.njit(signatures, cache=cache, inline="always", error_model="numpy")
return _jit_custom(signatures, cache=cache)


@nb.njit(cache=True)
Expand Down
15 changes: 9 additions & 6 deletions src/numba_stats/bernstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
"""

import numpy as np
from ._util import _jit, _Floats, _generate_wrappers, _readonly_carray, _trans
from ._util import (
_jit,
_Floats,
_generate_wrappers,
_trans,
)


@_jit([T[:](_readonly_carray(T), _readonly_carray(T)) for T in _Floats])
@_jit(0, narg=2)
def _de_castlejau(z, beta):
# De Casteljau algorithm, numerically stable
n = len(beta)
Expand Down Expand Up @@ -48,7 +53,7 @@ def _beta_int(beta):
return r


@_jit([T[:](_readonly_carray(T), _readonly_carray(T), T, T) for T in _Floats])
@_jit(2, narg=2)
def _density(x, beta, xmin, xmax):
"""
Return density described by a Bernstein polynomial.
Expand Down Expand Up @@ -90,9 +95,7 @@ def _density(x, beta, xmin, xmax):
return _de_castlejau(z, beta)


@_jit(
[T[:](_readonly_carray(T), _readonly_carray(T), T, T) for T in _Floats], cache=True
)
@_jit(2, narg=2)
def _integral(x, beta, xmin, xmax):
"""
Return integral of a Bernstein polynomial from xmin to x.
Expand Down
71 changes: 71 additions & 0 deletions src/numba_stats/binom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
Binomial distribution.
See Also
--------
scipy.stats.binom: Scipy equivalent.
"""

import numpy as np
from ._special import xlogy as _xlogy, xlog1py as _xlog1py, betainc as _betainc
from math import lgamma as _lgamma
from ._util import _jit, _generate_wrappers, _prange, _seed
import numba as nb

_doc_par = """
k : int
Number of successes.
n : int
Number of trials.
p : float
Success probability for each trial.
"""


@_jit(1, narg=2, cache=False)
def _logpmf(k, n, p):
T = type(p)
r = np.empty(len(k), T)
one = T(1)
for i in _prange(len(r)):
combiln = _lgamma(n[i] + one) - (
_lgamma(k[i] + one) + _lgamma(n[i] - k[i] + one)
)
r[i] = combiln + _xlogy(k[i], p) + _xlog1py(n[i] - k[i], -p)
return r


@_jit(1, narg=2, cache=False)
def _pmf(k, n, p):
return np.exp(_logpmf(k, n, p))


@_jit(1, narg=2, cache=False)
def _cdf(k, n, p):
T = type(p)
r = np.empty(len(k), T)
one = T(1)
for i in _prange(len(r)):
if k[i] == n[i]:
r[i] = 1
elif p == 0:
r[i] = 1
elif p == 1:
r[i] = 0
else:
r[i] = 1 - _betainc(k[i] + one, n[i] - k[i], p)
return r


@nb.njit(
nb.int64[:](nb.uint64, nb.float32, nb.uint64, nb.optional(nb.uint64)),
cache=True,
inline="always",
error_model="numpy",
)
def _rvs(n, p, size, random_state):
_seed(random_state)
return np.random.binomial(n, p, size=size)


_generate_wrappers(globals())
8 changes: 4 additions & 4 deletions src/numba_stats/crystalball.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"""


@_jit(-3)
@_jit(3, narg=0)
def _log_powerlaw(z, beta, m):
T = type(beta)
c = -T(0.5) * beta * beta
Expand All @@ -35,7 +35,7 @@ def _log_powerlaw(z, beta, m):
return log_a - m * np.log(b - z)


@_jit(-3)
@_jit(3, narg=0)
def _powerlaw_integral(z, beta, m):
T = type(beta)
exp_beta = np.exp(-T(0.5) * beta * beta)
Expand All @@ -45,14 +45,14 @@ def _powerlaw_integral(z, beta, m):
return a * (b - z) ** -m1 / m1


@_jit(-2)
@_jit(2, narg=0)
def _normal_integral(a, b):
T = type(a)
sqrt_half = np.sqrt(T(0.5))
return sqrt_half * np.sqrt(T(np.pi)) * (_erf(b * sqrt_half) - _erf(a * sqrt_half))


@_jit(-3)
@_jit(3, narg=0)
def _log_density(z, beta, m):
if z < -beta:
return _log_powerlaw(z, beta, m)
Expand Down
2 changes: 1 addition & 1 deletion src/numba_stats/crystalball_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""


@_jit(-3)
@_jit(3, narg=0)
def _norm_half(beta, m, scale):
T = type(beta)
return (_powerlaw_integral(-beta, beta, m) + _normal_integral(-beta, T(0))) * scale
Expand Down
4 changes: 2 additions & 2 deletions src/numba_stats/expon.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
"""


@_jit(-1)
@_jit(1, narg=0)
def _cdf1(z):
T = type(z)
return T(0) if z < 0 else -_expm1(-z)


@_jit(-1)
@_jit(1, narg=0)
def _ppf1(p):
return -_log1p(-p)

Expand Down
4 changes: 2 additions & 2 deletions src/numba_stats/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
"""


@_jit(-1)
@_jit(1, narg=0)
def _cdf1(z):
return 1.0 - 0.5 * np.exp(-z) if z > 0 else 0.5 * np.exp(z)


@_jit(-1)
@_jit(1, narg=0)
def _ppf1(p):
return -np.log(2 * (1 - p)) if p > 0.5 else np.log(2 * p)

Expand Down
6 changes: 3 additions & 3 deletions src/numba_stats/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,20 @@
"""


@_jit(-1)
@_jit(1, narg=0)
def _logpdf1(z):
T = type(z)
return -T(0.5) * (z * z + T(np.log(2 * np.pi)))


@_jit(-1)
@_jit(1, narg=0)
def _cdf1(z):
T = type(z)
c = T(np.sqrt(0.5))
return T(0.5) * (T(1.0) + _erf(z * c))


@_jit(-1, cache=False) # cannot cache because of _ndtri
@_jit(1, narg=0, cache=False) # cannot cache because of _ndtri
def _ppf1(p):
T = type(p)
return T(_ndtri(p))
Expand Down
43 changes: 43 additions & 0 deletions tests/test_binom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np
from numba_stats import binom
import scipy.stats as sc
import pytest


# NC and KC are all combinations of n and k from 0 to 10
N = np.arange(10)
NC = []
KC = []
for n in N:
for k in range(n + 1):
NC.append(n)
KC.append(k)
NC = np.array(NC, np.float64)
KC = np.array(KC, np.float64)


@pytest.mark.parametrize("p", np.linspace(0, 1, 5))
def test_pmf(p):
print(KC, NC)
got = binom.pmf(KC, NC, p)
expected = sc.binom.pmf(KC, NC, p)
np.testing.assert_allclose(got, expected)


@pytest.mark.parametrize("p", np.linspace(0, 1, 5))
def test_cdf(p):
got = binom.cdf(KC, NC, p)
expected = sc.binom.cdf(KC, NC, p)
np.testing.assert_allclose(got, expected)


@pytest.mark.parametrize("n", np.linspace(0, 10, 6))
@pytest.mark.parametrize("p", np.linspace(0, 1, 5))
def test_rvs(n, p):
got = binom.rvs(n, p, size=1000, random_state=1)

def expected():
np.random.seed(1)
return np.random.binomial(n, p, 1000)

np.testing.assert_equal(got, expected())
16 changes: 16 additions & 0 deletions tests/test_special.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from numba_stats import _special as sp
from scipy import special as sp_ref
import numba as nb
import pytest
import numpy as np


@pytest.mark.parametrize("a", [1, 2, 3])
@pytest.mark.parametrize("b", [1, 2, 3])
@pytest.mark.parametrize("x", [0.1, 0.5, 0.9])
def test_betainc(a, b, x):
@nb.njit
def betainc(a, b, x):
return sp.betainc(a, b, x)

np.testing.assert_allclose(betainc(a, b, x), sp_ref.betainc(a, b, x))

0 comments on commit 3302ff0

Please sign in to comment.