Skip to content

Commit 9b0f119

Browse files
committed
add: random tie break implemented for uncertainty sampling methods
1 parent a58f317 commit 9b0f119

File tree

2 files changed

+48
-14
lines changed

2 files changed

+48
-14
lines changed

modAL/uncertainty.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from sklearn.exceptions import NotFittedError
99
from sklearn.base import BaseEstimator
1010

11-
from modAL.utils.selection import multi_argmax
1211
from modAL.utils.data import modALinput
12+
from modAL.utils.selection import multi_argmax, shuffled_argmax
1313

1414

1515
def _proba_uncertainty(proba: np.ndarray) -> np.ndarray:
@@ -131,61 +131,83 @@ def classifier_entropy(classifier: BaseEstimator, X: modALinput, **predict_proba
131131

132132

133133
def uncertainty_sampling(classifier: BaseEstimator, X: modALinput,
134-
n_instances: int = 1, **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
134+
n_instances: int = 1, random_tie_break: bool = False,
135+
**uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
135136
"""
136137
Uncertainty sampling query strategy. Selects the least sure instances for labelling.
137138
138139
Args:
139140
classifier: The classifier for which the labels are to be queried.
140141
X: The pool of samples to query from.
141142
n_instances: Number of samples to be queried.
142-
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty measure function.
143+
random_tie_break: If True, shuffles utility scores to randomize the order. This
144+
can be used to break the tie when the highest utility score is not unique.
145+
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty
146+
measure function.
143147
144148
Returns:
145-
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
149+
The indices of the instances from X chosen to be labelled;
150+
the instances from X chosen to be labelled.
146151
"""
147152
uncertainty = classifier_uncertainty(classifier, X, **uncertainty_measure_kwargs)
148-
query_idx = multi_argmax(uncertainty, n_instances=n_instances)
153+
154+
if not random_tie_break:
155+
query_idx = multi_argmax(uncertainty, n_instances=n_instances)
156+
else:
157+
query_idx = shuffled_argmax(uncertainty, n_instances=n_instances)
149158

150159
return query_idx, X[query_idx]
151160

152161

153162
def margin_sampling(classifier: BaseEstimator, X: modALinput,
154-
n_instances: int = 1, **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
163+
n_instances: int = 1, random_tie_break: bool = False,
164+
**uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
155165
"""
156166
Margin sampling query strategy. Selects the instances where the difference between the first most likely and second
157167
most likely classes are the smallest.
158-
159168
Args:
160169
classifier: The classifier for which the labels are to be queried.
161170
X: The pool of samples to query from.
162171
n_instances: Number of samples to be queried.
163172
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty measure function.
164-
165173
Returns:
166174
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
167175
"""
168176
margin = classifier_margin(classifier, X, **uncertainty_measure_kwargs)
169-
query_idx = multi_argmax(-margin, n_instances=n_instances)
177+
178+
if not random_tie_break:
179+
query_idx = multi_argmax(-margin, n_instances=n_instances)
180+
else:
181+
query_idx = shuffled_argmax(-margin, n_instances=n_instances)
170182

171183
return query_idx, X[query_idx]
172184

173185

174186
def entropy_sampling(classifier: BaseEstimator, X: modALinput,
175-
n_instances: int = 1, **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
187+
n_instances: int = 1, random_tie_break: bool = False,
188+
**uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
176189
"""
177-
Entropy sampling query strategy. Selects the instances where the class probabilities have the largest entropy.
190+
Entropy sampling query strategy. Selects the instances where the class probabilities
191+
have the largest entropy.
178192
179193
Args:
180194
classifier: The classifier for which the labels are to be queried.
181195
X: The pool of samples to query from.
182196
n_instances: Number of samples to be queried.
183-
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty measure function.
197+
random_tie_break: If True, shuffles utility scores to randomize the order. This
198+
can be used to break the tie when the highest utility score is not unique.
199+
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty
200+
measure function.
184201
185202
Returns:
186-
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
203+
The indices of the instances from X chosen to be labelled;
204+
the instances from X chosen to be labelled.
187205
"""
188206
entropy = classifier_entropy(classifier, X, **uncertainty_measure_kwargs)
189-
query_idx = multi_argmax(entropy, n_instances=n_instances)
207+
208+
if not random_tie_break:
209+
query_idx = multi_argmax(entropy, n_instances=n_instances)
210+
else:
211+
query_idx = shuffled_argmax(entropy, n_instances=n_instances)
190212

191213
return query_idx, X[query_idx]

tests/core_tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,10 @@ def test_uncertainty_sampling(self):
528528
query_idx, query_instance = modAL.uncertainty.uncertainty_sampling(
529529
classifier, np.random.rand(n_samples, n_classes)
530530
)
531+
shuffled_query_idx, shuffled_query_instance = modAL.uncertainty.uncertainty_sampling(
532+
classifier, np.random.rand(n_samples, n_classes),
533+
random_tie_break=True
534+
)
531535
np.testing.assert_array_equal(query_idx, true_query_idx)
532536

533537
def test_margin_sampling(self):
@@ -541,6 +545,10 @@ def test_margin_sampling(self):
541545
query_idx, query_instance = modAL.uncertainty.margin_sampling(
542546
classifier, np.random.rand(n_samples, n_classes)
543547
)
548+
shuffled_query_idx, shuffled_query_instance = modAL.uncertainty.margin_sampling(
549+
classifier, np.random.rand(n_samples, n_classes),
550+
random_tie_break=True
551+
)
544552
np.testing.assert_array_equal(query_idx, true_query_idx)
545553

546554
def test_entropy_sampling(self):
@@ -555,6 +563,10 @@ def test_entropy_sampling(self):
555563
query_idx, query_instance = modAL.uncertainty.entropy_sampling(
556564
classifier, np.random.rand(n_samples, n_classes)
557565
)
566+
shuffled_query_idx, shuffled_query_instance = modAL.uncertainty.entropy_sampling(
567+
classifier, np.random.rand(n_samples, n_classes),
568+
random_tie_break=True
569+
)
558570
np.testing.assert_array_equal(query_idx, true_query_idx)
559571

560572

0 commit comments

Comments
 (0)