10
10
from sklearn.base import BaseEstimator

from modAL.utils.data import modALinput
from modAL.utils.selection import multi_argmax, shuffled_argmax
from modAL.models.base import BaseCommittee
15
15
16
16
@@ -103,80 +103,116 @@ def KL_max_disagreement(committee: BaseCommittee, X: modALinput, **predict_proba
103
103
104
104
105
105
def vote_entropy_sampling(committee: BaseCommittee, X: modALinput,
                          n_instances: int = 1, random_tie_break: bool = False,
                          **disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
    """
    Vote entropy sampling strategy.

    Args:
        committee: The committee for which the labels are to be queried.
        X: The pool of samples to query from.
        n_instances: Number of samples to be queried.
        random_tie_break: If True, shuffles utility scores to randomize the order. This
            can be used to break the tie when the highest utility score is not unique.
        **disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
            measure function.

    Returns:
        The indices of the instances from X chosen to be labelled;
        the instances from X chosen to be labelled.
    """
    disagreement = vote_entropy(committee, X, **disagreement_measure_kwargs)

    # Both selectors share the same call signature, so pick one and call it once.
    select = shuffled_argmax if random_tie_break else multi_argmax
    query_idx = select(disagreement, n_instances=n_instances)

    return query_idx, X[query_idx]
123
132
124
133
125
134
def consensus_entropy_sampling(committee: BaseCommittee, X: modALinput,
                               n_instances: int = 1, random_tie_break: bool = False,
                               **disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
    """
    Consensus entropy sampling strategy.

    Args:
        committee: The committee for which the labels are to be queried.
        X: The pool of samples to query from.
        n_instances: Number of samples to be queried.
        random_tie_break: If True, shuffles utility scores to randomize the order. This
            can be used to break the tie when the highest utility score is not unique.
        **disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
            measure function.

    Returns:
        The indices of the instances from X chosen to be labelled;
        the instances from X chosen to be labelled.
    """
    disagreement = consensus_entropy(committee, X, **disagreement_measure_kwargs)

    # Both selectors share the same call signature, so pick one and call it once.
    select = shuffled_argmax if random_tie_break else multi_argmax
    query_idx = select(disagreement, n_instances=n_instances)

    return query_idx, X[query_idx]
143
161
144
162
145
163
def max_disagreement_sampling(committee: BaseCommittee, X: modALinput,
                              n_instances: int = 1, random_tie_break: bool = False,
                              **disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
    """
    Maximum disagreement sampling strategy.

    Args:
        committee: The committee for which the labels are to be queried.
        X: The pool of samples to query from.
        n_instances: Number of samples to be queried.
        random_tie_break: If True, shuffles utility scores to randomize the order. This
            can be used to break the tie when the highest utility score is not unique.
        **disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
            measure function.

    Returns:
        The indices of the instances from X chosen to be labelled;
        the instances from X chosen to be labelled.
    """
    disagreement = KL_max_disagreement(committee, X, **disagreement_measure_kwargs)

    # Both selectors share the same call signature, so pick one and call it once.
    select = shuffled_argmax if random_tie_break else multi_argmax
    query_idx = select(disagreement, n_instances=n_instances)

    return query_idx, X[query_idx]
163
190
164
191
165
192
def max_std_sampling(regressor: BaseEstimator, X: modALinput,
                     n_instances: int = 1, random_tie_break: bool = False,
                     **predict_kwargs) -> Tuple[np.ndarray, modALinput]:
    """
    Regressor standard deviation sampling strategy.

    Args:
        regressor: The regressor for which the labels are to be queried.
        X: The pool of samples to query from.
        n_instances: Number of samples to be queried.
        random_tie_break: If True, shuffles utility scores to randomize the order. This
            can be used to break the tie when the highest utility score is not unique.
        **predict_kwargs: Keyword arguments to be passed to :meth:`predict` of the CommitteeRegressor.

    Returns:
        The indices of the instances from X chosen to be labelled;
        the instances from X chosen to be labelled.
    """
    # Only the standard deviation is used as the utility; the predicted mean is discarded.
    _, std = regressor.predict(X, return_std=True, **predict_kwargs)
    # Flatten to one score per sample in X before ranking.
    std = std.reshape(X.shape[0], )

    # Both selectors share the same call signature, so pick one and call it once.
    select = shuffled_argmax if random_tie_break else multi_argmax
    query_idx = select(std, n_instances=n_instances)

    return query_idx, X[query_idx]
0 commit comments