Skip to content

Commit bb3b579

Browse files
committed
add: tests for disagreement sampling functions added
1 parent e0af35f commit bb3b579

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

modAL/disagreement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,5 +214,5 @@ def max_std_sampling(regressor: BaseEstimator, X: modALinput,
214214
query_idx = multi_argmax(std, n_instances=n_instances)
215215
else:
216216
query_idx = shuffled_argmax(std, n_instances=n_instances)
217-
218-
return query_idx, X[query_idx]
217+
218+
return query_idx, X[query_idx]

tests/core_tests.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from collections import namedtuple
2222

2323
from sklearn.ensemble import RandomForestClassifier
24+
from sklearn.gaussian_process import GaussianProcessRegressor
2425
from sklearn.exceptions import NotFittedError
2526
from sklearn.metrics import confusion_matrix
2627
from sklearn.svm import SVC
@@ -417,6 +418,39 @@ def test_KL_max_disagreement(self):
417418
)
418419
np.testing.assert_almost_equal(returned_KL_disagreement, true_KL_disagreement)
419420

421+
def test_vote_entropy_sampling(self):
422+
for n_samples, n_features, n_classes in product(range(1, 10), range(1, 10), range(1, 10)):
423+
committee = mock.MockCommittee(classes_=np.asarray(range(n_classes)),
424+
vote_return=np.zeros(shape=(n_samples, n_classes), dtype=np.int16))
425+
modAL.disagreement.vote_entropy_sampling(committee, np.random.rand(n_samples, n_features))
426+
modAL.disagreement.vote_entropy_sampling(committee, np.random.rand(n_samples, n_features),
427+
random_tie_break=True)
428+
429+
def test_consensus_entropy_sampling(self):
430+
for n_samples, n_features, n_classes in product(range(1, 10), range(1, 10), range(1, 10)):
431+
committee = mock.MockCommittee(predict_proba_return=np.random.rand(n_samples, n_classes))
432+
modAL.disagreement.consensus_entropy_sampling(committee, np.random.rand(n_samples, n_features))
433+
modAL.disagreement.consensus_entropy_sampling(committee, np.random.rand(n_samples, n_features),
434+
random_tie_break=True)
435+
436+
def test_max_disagreement_sampling(self):
437+
for n_samples, n_features, n_classes, n_learners in product(range(1, 10), range(1, 10), range(1, 10), range(2, 5)):
438+
committee = mock.MockCommittee(
439+
n_learners=n_learners, classes_=range(n_classes),
440+
vote_proba_return=np.zeros(shape=(n_samples, n_learners, n_classes))
441+
)
442+
modAL.disagreement.max_disagreement_sampling(committee, np.random.rand(n_samples, n_features))
443+
modAL.disagreement.max_disagreement_sampling(committee, np.random.rand(n_samples, n_features),
444+
random_tie_break=True)
445+
446+
def test_max_std_sampling(self):
447+
for n_samples, n_features in product(range(1, 10), range(1, 10)):
448+
regressor = GaussianProcessRegressor()
449+
regressor.fit(np.random.rand(n_samples, n_features), np.random.rand(n_samples))
450+
modAL.disagreement.max_std_sampling(regressor, np.random.rand(n_samples, n_features))
451+
modAL.disagreement.max_std_sampling(regressor, np.random.rand(n_samples, n_features),
452+
random_tie_break=True)
453+
420454

421455
class TestEER(unittest.TestCase):
422456
def test_eer(self):

0 commit comments

Comments
 (0)