Skip to content

Commit a58f317

Browse files
committed
add: tests for shuffled_argmax
1 parent 61ad80d commit a58f317

File tree

1 file changed

+34
-24
lines changed

1 file changed

+34
-24
lines changed

tests/core_tests.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,40 @@ def test_data_vstack(self):
159159
# not supported formats
160160
self.assertRaises(TypeError, modAL.utils.data.data_vstack, (1, 1))
161161

162+
# functions from modAL.utils.selection
163+
164+
def test_multi_argmax(self):
165+
for n_pool in range(2, 100):
166+
for n_instances in range(1, n_pool):
167+
utility = np.zeros(n_pool)
168+
max_idx = np.random.choice(range(n_pool), size=n_instances, replace=False)
169+
utility[max_idx] = 1e-10 + np.random.rand(n_instances, )
170+
np.testing.assert_equal(
171+
np.sort(modAL.utils.selection.multi_argmax(utility, n_instances)),
172+
np.sort(max_idx)
173+
)
174+
175+
def test_shuffled_argmax(self):
176+
for n_pool in range(1, 100):
177+
for n_instances in range(1, n_pool+1):
178+
values = np.random.permutation(n_pool)
179+
true_query_idx = np.argsort(values)[:n_instances]
180+
181+
np.testing.assert_equal(
182+
true_query_idx,
183+
modAL.utils.selection.shuffled_argmax(values, n_instances)
184+
)
185+
186+
def test_weighted_random(self):
187+
for n_pool in range(2, 100):
188+
for n_instances in range(1, n_pool):
189+
utility = np.ones(n_pool)
190+
query_idx = modAL.utils.selection.weighted_random(utility, n_instances)
191+
# testing for correct number of returned indices
192+
np.testing.assert_equal(len(query_idx), n_instances)
193+
# testing for uniqueness of each query index
194+
np.testing.assert_equal(len(query_idx), len(np.unique(query_idx)))
195+
162196

163197
class TestAcquisitionFunctions(unittest.TestCase):
164198
def test_acquisition_functions(self):
@@ -524,30 +558,6 @@ def test_entropy_sampling(self):
524558
np.testing.assert_array_equal(query_idx, true_query_idx)
525559

526560

527-
class TestQueries(unittest.TestCase):
528-
529-
def test_multi_argmax(self):
530-
for n_pool in range(2, 100):
531-
for n_instances in range(1, n_pool):
532-
utility = np.zeros(n_pool)
533-
max_idx = np.random.choice(range(n_pool), size=n_instances, replace=False)
534-
utility[max_idx] = 1e-10 + np.random.rand(n_instances, )
535-
np.testing.assert_equal(
536-
np.sort(modAL.utils.selection.multi_argmax(utility, n_instances)),
537-
np.sort(max_idx)
538-
)
539-
540-
def test_weighted_random(self):
541-
for n_pool in range(2, 100):
542-
for n_instances in range(1, n_pool):
543-
utility = np.ones(n_pool)
544-
query_idx = modAL.utils.selection.weighted_random(utility, n_instances)
545-
# testing for correct number of returned indices
546-
np.testing.assert_equal(len(query_idx), n_instances)
547-
# testing for uniqueness of each query index
548-
np.testing.assert_equal(len(query_idx), len(np.unique(query_idx)))
549-
550-
551561
class TestActiveLearner(unittest.TestCase):
552562

553563
def test_add_training_data(self):

0 commit comments

Comments
 (0)