Skip to content

Commit e6d658c

Browse files
committed
modAL.utils.combination.make_query_strategy tested
1 parent fea5343 commit e6d658c

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/core_tests.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,29 @@ def test_product(self):
101101
np.prod([X_in**exponent for exponent in exponents], axis=0)
102102
)
103103

104+
def test_make_query_strategy(self):
105+
query_strategy = modAL.utils.combination.make_query_strategy(
106+
utility_measure=modAL.uncertainty.classifier_uncertainty,
107+
selector=modAL.utils.selection.multi_argmax
108+
)
109+
110+
for n_samples in range(1, 10):
111+
for n_classes in range(1, 10):
112+
proba = np.random.rand(n_samples, n_classes)
113+
proba = proba/np.sum(proba, axis=1).reshape(n_samples, 1)
114+
X = np.random.rand(n_samples, 3)
115+
116+
learner = modAL.models.ActiveLearner(
117+
estimator=mock.MockClassifier(predict_proba_return=proba)
118+
)
119+
120+
query_1 = query_strategy(learner, X)
121+
query_2 = modAL.uncertainty.uncertainty_sampling(learner, X)
122+
123+
np.testing.assert_equal(query_1[0], query_2[0])
124+
np.testing.assert_almost_equal(query_1[1], query_2[1])
125+
126+
104127

105128
class TestUncertainties(unittest.TestCase):
106129

0 commit comments

Comments
 (0)