|
8 | 8 | from sklearn.exceptions import NotFittedError
|
9 | 9 | from sklearn.base import BaseEstimator
|
10 | 10 |
|
11 |
| -from modAL.utils.selection import multi_argmax |
12 | 11 | from modAL.utils.data import modALinput
|
| 12 | +from modAL.utils.selection import multi_argmax, shuffled_argmax |
13 | 13 |
|
14 | 14 |
|
15 | 15 | def _proba_uncertainty(proba: np.ndarray) -> np.ndarray:
|
@@ -131,61 +131,83 @@ def classifier_entropy(classifier: BaseEstimator, X: modALinput, **predict_proba
|
131 | 131 |
|
132 | 132 |
|
133 | 133 | def uncertainty_sampling(classifier: BaseEstimator, X: modALinput,
|
134 |
| - n_instances: int = 1, **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]: |
| 134 | + n_instances: int = 1, random_tie_break: bool = False, |
| 135 | + **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]: |
135 | 136 | """
|
136 | 137 | Uncertainty sampling query strategy. Selects the least sure instances for labelling.
|
137 | 138 |
|
138 | 139 | Args:
|
139 | 140 | classifier: The classifier for which the labels are to be queried.
|
140 | 141 | X: The pool of samples to query from.
|
141 | 142 | n_instances: Number of samples to be queried.
|
142 |
| - **uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty measure function. |
| 143 | + random_tie_break: If True, shuffles utility scores to randomize the order. This |
| 144 | + can be used to break the tie when the highest utility score is not unique. |
| 145 | + **uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty |
| 146 | + measure function. |
143 | 147 |
|
144 | 148 | Returns:
|
145 |
| - The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled. |
| 149 | + The indices of the instances from X chosen to be labelled; |
| 150 | + the instances from X chosen to be labelled. |
146 | 151 | """
|
147 | 152 | uncertainty = classifier_uncertainty(classifier, X, **uncertainty_measure_kwargs)
|
148 |
| - query_idx = multi_argmax(uncertainty, n_instances=n_instances) |
| 153 | + |
| 154 | + if not random_tie_break: |
| 155 | + query_idx = multi_argmax(uncertainty, n_instances=n_instances) |
| 156 | + else: |
| 157 | + query_idx = shuffled_argmax(uncertainty, n_instances=n_instances) |
149 | 158 |
|
150 | 159 | return query_idx, X[query_idx]
|
151 | 160 |
|
152 | 161 |
|
153 | 162 | def margin_sampling(classifier: BaseEstimator, X: modALinput,
|
154 |
| - n_instances: int = 1, **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]: |
| 163 | + n_instances: int = 1, random_tie_break: bool = False, |
| 164 | + **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]: |
155 | 165 | """
|
156 | 166 | Margin sampling query strategy. Selects the instances where the difference between the first most likely and second
|
157 | 167 | most likely classes are the smallest.
|
158 |
| -
|
159 | 168 | Args:
|
160 | 169 | classifier: The classifier for which the labels are to be queried.
|
161 | 170 | X: The pool of samples to query from.
|
162 | 171 | n_instances: Number of samples to be queried.
|
163 | 172 | **uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty measure function.
|
164 |
| -
|
165 | 173 | Returns:
|
166 | 174 | The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
|
167 | 175 | """
|
168 | 176 | margin = classifier_margin(classifier, X, **uncertainty_measure_kwargs)
|
169 |
| - query_idx = multi_argmax(-margin, n_instances=n_instances) |
| 177 | + |
| 178 | + if not random_tie_break: |
| 179 | + query_idx = multi_argmax(-margin, n_instances=n_instances) |
| 180 | + else: |
| 181 | + query_idx = shuffled_argmax(-margin, n_instances=n_instances) |
170 | 182 |
|
171 | 183 | return query_idx, X[query_idx]
|
172 | 184 |
|
173 | 185 |
|
174 | 186 | def entropy_sampling(classifier: BaseEstimator, X: modALinput,
|
175 |
| - n_instances: int = 1, **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]: |
| 187 | + n_instances: int = 1, random_tie_break: bool = False, |
| 188 | + **uncertainty_measure_kwargs) -> Tuple[np.ndarray, modALinput]: |
176 | 189 | """
|
177 |
| - Entropy sampling query strategy. Selects the instances where the class probabilities have the largest entropy. |
| 190 | + Entropy sampling query strategy. Selects the instances where the class probabilities |
| 191 | + have the largest entropy. |
178 | 192 |
|
179 | 193 | Args:
|
180 | 194 | classifier: The classifier for which the labels are to be queried.
|
181 | 195 | X: The pool of samples to query from.
|
182 | 196 | n_instances: Number of samples to be queried.
|
183 |
| - **uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty measure function. |
| 197 | + random_tie_break: If True, shuffles utility scores to randomize the order. This |
| 198 | + can be used to break the tie when the highest utility score is not unique. |
| 199 | + **uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty |
| 200 | + measure function. |
184 | 201 |
|
185 | 202 | Returns:
|
186 |
| - The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled. |
| 203 | + The indices of the instances from X chosen to be labelled; |
| 204 | + the instances from X chosen to be labelled. |
187 | 205 | """
|
188 | 206 | entropy = classifier_entropy(classifier, X, **uncertainty_measure_kwargs)
|
189 |
| - query_idx = multi_argmax(entropy, n_instances=n_instances) |
| 207 | + |
| 208 | + if not random_tie_break: |
| 209 | + query_idx = multi_argmax(entropy, n_instances=n_instances) |
| 210 | + else: |
| 211 | + query_idx = shuffled_argmax(entropy, n_instances=n_instances) |
190 | 212 |
|
191 | 213 | return query_idx, X[query_idx]
|
0 commit comments