Skip to content

Commit e683252

Browse files
committed
Merge branch 'develop' into release-1.10
2 parents 7379dd2 + 4ed489c commit e683252

File tree

3 files changed

+184
-0
lines changed

3 files changed

+184
-0
lines changed

tests/inputs.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# -*- coding: utf-8 -*-
2+
# pylint: disable=missing-docstring
3+
"""Tests for the entropy() function."""
4+
import numpy
5+
6+
from make_test_ref import SEED
7+
8+
9+
class Pmf:
10+
"""PMF class.
11+
12+
Parameters
13+
----------
14+
alpha : float
15+
Concentration parameter.
16+
k : int
17+
Alphabet size.
18+
zero : float or None
19+
Fraction of bins with exactly zero probability.
20+
21+
"""
22+
23+
def __init__(self, alpha=0.1, k=10000, zero=0):
24+
numpy.random.seed(SEED)
25+
self.alpha = alpha
26+
self.k = k
27+
self.zero = zero
28+
self._pk = self._generate_pk(self.alpha, self.k, self.zero)
29+
self._entropy = None
30+
31+
@property
32+
def pk(self):
33+
return self._pk
34+
35+
@staticmethod
36+
def _generate_pk(alpha, k, zero=0):
37+
"""Return a Dirichlet sample."""
38+
pk = numpy.random.dirichlet([alpha] * k)
39+
if zero:
40+
n_zero = numpy.random.binomial(k, zero)
41+
pk[:n_zero] = 0
42+
pk /= pk.sum()
43+
pk = pk[n_zero:]
44+
return pk
45+
46+
def randomize(self):
47+
"""Reset pk to a random pmf."""
48+
self._pk = self._generate_pk(self.alpha, self.k, self.zero)
49+
self._entropy = None
50+
return self
51+
52+
@staticmethod
53+
def entropy_from_pmf(a):
54+
pk = numpy.asarray(a)
55+
pk = pk[pk > 0]
56+
return -numpy.sum(pk * numpy.log(pk))
57+
58+
@property
59+
def entropy(self):
60+
"""Entropy for PMF"""
61+
if self._entropy is None:
62+
self._entropy = self.entropy_from_pmf(self.pk)
63+
return self._entropy
64+
65+
66+
class Counts:
67+
def __init__(self, n=100, pmf=None, **kwargs):
68+
"""
69+
Counts class.
70+
71+
Parameters
72+
----------
73+
n : int
74+
Number of samples.
75+
pmf : Pmf object, optional
76+
Alphabet size.
77+
78+
"""
79+
numpy.random.seed(SEED)
80+
self.n = n
81+
82+
if pmf is not None:
83+
self.pmf = pmf
84+
else:
85+
self.pmf = Pmf(**kwargs)
86+
87+
self._nk = self._generate_nk(self.n, self.pmf.pk)
88+
self._entropy = None
89+
90+
@property
91+
def nk(self):
92+
return self._nk
93+
94+
@staticmethod
95+
def _generate_nk(n, pk):
96+
"""Return a Multinomial sample."""
97+
return numpy.random.multinomial(n, pk)
98+
99+
def randomize(self):
100+
self.nk = self._generate_nk(self.n, self.pmf.pk)
101+
self._entropy = None
102+
return self
103+
104+
@staticmethod
105+
def entropy_from_counts(a, estimator, **kwargs):
106+
nk = numpy.asarray(a)
107+
estimator.fit(nk, **kwargs)
108+
return estimator.estimate_
109+
110+
def entropy(self, estimator, **kwargs):
111+
"""Entropy estimate from counts using `estimator`."""
112+
if self._entropy is None:
113+
self._entropy = self.entropy_from_counts(self.nk, estimator,
114+
**kwargs)
115+
return self._entropy

tests/test_entropy.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# -*- coding: utf-8 -*-
2+
# pylint: disable=missing-docstring
3+
# pylint: disable=redefined-outer-name
4+
"""Tests for the entropy() function."""
5+
import numpy
6+
import pytest
7+
8+
import ndd
9+
from inputs import Counts, Pmf
10+
from make_test_ref import approx
11+
12+
13+
@pytest.fixture
14+
def pmf():
15+
return Pmf()
16+
17+
18+
@pytest.fixture
19+
def pmf_with_zeros():
20+
return Pmf(zero=0.5)
21+
22+
23+
@pytest.fixture
24+
def counts():
25+
return Counts()
26+
27+
28+
def test_pmf(pmf):
29+
ref = pmf.entropy
30+
assert ndd.entropy(pmf.pk) == approx(ref)
31+
32+
33+
def test_pmf_with_zeros(pmf_with_zeros):
34+
ref = pmf_with_zeros.entropy
35+
print(pmf_with_zeros.pk.sum())
36+
assert ndd.entropy(pmf_with_zeros.pk) == approx(ref)
37+
38+
39+
def test_counts(counts):
40+
estimator = ndd.estimators.AutoEstimator()
41+
assert ndd.entropy(counts.nk) == approx(
42+
counts.entropy(estimator=estimator))
43+
44+
45+
def test_unnormalized_pmf():
46+
counts = numpy.random.random(size=100) # pylint: disable=no-member
47+
pk = counts / counts.sum()
48+
assert ndd.entropy(counts) == approx(Pmf().entropy_from_pmf(pk))

tests/test_estimators.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# -*- coding: utf-8 -*-
2+
# pylint: disable=missing-docstring
3+
# pylint: disable=redefined-outer-name
4+
"""Estimators tests."""
5+
import pytest
6+
7+
import ndd
8+
from inputs import Pmf
9+
from make_test_ref import approx
10+
11+
12+
@pytest.fixture
13+
def pmf():
14+
return Pmf()
15+
16+
17+
def test_PmfPlugin(pmf):
18+
"""Test estimator from PMF."""
19+
estimator = ndd.estimators.PmfPlugin()
20+
ref = pmf.entropy
21+
assert estimator(pmf.pk) == approx(ref)

0 commit comments

Comments
 (0)