@@ -159,6 +159,40 @@ def test_data_vstack(self):
159
159
# not supported formats
160
160
self .assertRaises (TypeError , modAL .utils .data .data_vstack , (1 , 1 ))
161
161
162
+ # functions from modAL.utils.selection
163
+
164
+ def test_multi_argmax (self ):
165
+ for n_pool in range (2 , 100 ):
166
+ for n_instances in range (1 , n_pool ):
167
+ utility = np .zeros (n_pool )
168
+ max_idx = np .random .choice (range (n_pool ), size = n_instances , replace = False )
169
+ utility [max_idx ] = 1e-10 + np .random .rand (n_instances , )
170
+ np .testing .assert_equal (
171
+ np .sort (modAL .utils .selection .multi_argmax (utility , n_instances )),
172
+ np .sort (max_idx )
173
+ )
174
+
175
+ def test_shuffled_argmax (self ):
176
+ for n_pool in range (1 , 100 ):
177
+ for n_instances in range (1 , n_pool + 1 ):
178
+ values = np .random .permutation (n_pool )
179
+ true_query_idx = np .argsort (values )[:n_instances ]
180
+
181
+ np .testing .assert_equal (
182
+ true_query_idx ,
183
+ modAL .utils .selection .shuffled_argmax (values , n_instances )
184
+ )
185
+
186
+ def test_weighted_random (self ):
187
+ for n_pool in range (2 , 100 ):
188
+ for n_instances in range (1 , n_pool ):
189
+ utility = np .ones (n_pool )
190
+ query_idx = modAL .utils .selection .weighted_random (utility , n_instances )
191
+ # testing for correct number of returned indices
192
+ np .testing .assert_equal (len (query_idx ), n_instances )
193
+ # testing for uniqueness of each query index
194
+ np .testing .assert_equal (len (query_idx ), len (np .unique (query_idx )))
195
+
162
196
163
197
class TestAcquisitionFunctions (unittest .TestCase ):
164
198
def test_acquisition_functions (self ):
@@ -524,30 +558,6 @@ def test_entropy_sampling(self):
524
558
np .testing .assert_array_equal (query_idx , true_query_idx )
525
559
526
560
527
- class TestQueries (unittest .TestCase ):
528
-
529
- def test_multi_argmax (self ):
530
- for n_pool in range (2 , 100 ):
531
- for n_instances in range (1 , n_pool ):
532
- utility = np .zeros (n_pool )
533
- max_idx = np .random .choice (range (n_pool ), size = n_instances , replace = False )
534
- utility [max_idx ] = 1e-10 + np .random .rand (n_instances , )
535
- np .testing .assert_equal (
536
- np .sort (modAL .utils .selection .multi_argmax (utility , n_instances )),
537
- np .sort (max_idx )
538
- )
539
-
540
- def test_weighted_random (self ):
541
- for n_pool in range (2 , 100 ):
542
- for n_instances in range (1 , n_pool ):
543
- utility = np .ones (n_pool )
544
- query_idx = modAL .utils .selection .weighted_random (utility , n_instances )
545
- # testing for correct number of returned indices
546
- np .testing .assert_equal (len (query_idx ), n_instances )
547
- # testing for uniqueness of each query index
548
- np .testing .assert_equal (len (query_idx ), len (np .unique (query_idx )))
549
-
550
-
551
561
class TestActiveLearner (unittest .TestCase ):
552
562
553
563
def test_add_training_data (self ):
0 commit comments