Skip to content

Commit e0af35f

Browse files
committed
add: random tie break for disagreement sampling
1 parent b2eeede commit e0af35f

File tree

1 file changed

+52
-16
lines changed

1 file changed

+52
-16
lines changed

modAL/disagreement.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sklearn.base import BaseEstimator
1111

1212
from modAL.utils.data import modALinput
13-
from modAL.utils.selection import multi_argmax
13+
from modAL.utils.selection import multi_argmax, shuffled_argmax
1414
from modAL.models.base import BaseCommittee
1515

1616

@@ -103,80 +103,116 @@ def KL_max_disagreement(committee: BaseCommittee, X: modALinput, **predict_proba
103103

104104

105105
def vote_entropy_sampling(committee: BaseCommittee, X: modALinput,
106-
n_instances: int = 1,**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
106+
n_instances: int = 1, random_tie_break=False,
107+
**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
107108
"""
108109
Vote entropy sampling strategy.
109110
110111
Args:
111112
committee: The committee for which the labels are to be queried.
112113
X: The pool of samples to query from.
113114
n_instances: Number of samples to be queried.
114-
**disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement measure function.
115+
random_tie_break: If True, shuffles utility scores to randomize the order. This
116+
can be used to break the tie when the highest utility score is not unique.
117+
**disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
118+
measure function.
115119
116120
Returns:
117-
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
121+
The indices of the instances from X chosen to be labelled;
122+
the instances from X chosen to be labelled.
118123
"""
119124
disagreement = vote_entropy(committee, X, **disagreement_measure_kwargs)
120-
query_idx = multi_argmax(disagreement, n_instances=n_instances)
125+
126+
if not random_tie_break:
127+
query_idx = multi_argmax(disagreement, n_instances=n_instances)
128+
else:
129+
query_idx = shuffled_argmax(disagreement, n_instances=n_instances)
121130

122131
return query_idx, X[query_idx]
123132

124133

125134
def consensus_entropy_sampling(committee: BaseCommittee, X: modALinput,
126-
n_instances: int = 1,**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
135+
n_instances: int = 1, random_tie_break=False,
136+
**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
127137
"""
128138
Consensus entropy sampling strategy.
129139
130140
Args:
131141
committee: The committee for which the labels are to be queried.
132142
X: The pool of samples to query from.
133143
n_instances: Number of samples to be queried.
134-
**disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement measure function.
144+
random_tie_break: If True, shuffles utility scores to randomize the order. This
145+
can be used to break the tie when the highest utility score is not unique.
146+
**disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
147+
measure function.
135148
136149
Returns:
137-
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
150+
The indices of the instances from X chosen to be labelled;
151+
the instances from X chosen to be labelled.
138152
"""
139153
disagreement = consensus_entropy(committee, X, **disagreement_measure_kwargs)
140-
query_idx = multi_argmax(disagreement, n_instances=n_instances)
154+
155+
if not random_tie_break:
156+
query_idx = multi_argmax(disagreement, n_instances=n_instances)
157+
else:
158+
query_idx = shuffled_argmax(disagreement, n_instances=n_instances)
141159

142160
return query_idx, X[query_idx]
143161

144162

145163
def max_disagreement_sampling(committee: BaseCommittee, X: modALinput,
146-
n_instances: int = 1,**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
164+
n_instances: int = 1, random_tie_break=False,
165+
**disagreement_measure_kwargs) -> Tuple[np.ndarray, modALinput]:
147166
"""
148167
Maximum disagreement sampling strategy.
149168
150169
Args:
151170
committee: The committee for which the labels are to be queried.
152171
X: The pool of samples to query from.
153172
n_instances: Number of samples to be queried.
154-
**disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement measure function.
173+
random_tie_break: If True, shuffles utility scores to randomize the order. This
174+
can be used to break the tie when the highest utility score is not unique.
175+
**disagreement_measure_kwargs: Keyword arguments to be passed for the disagreement
176+
measure function.
155177
156178
Returns:
157-
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
179+
The indices of the instances from X chosen to be labelled;
180+
the instances from X chosen to be labelled.
158181
"""
159182
disagreement = KL_max_disagreement(committee, X, **disagreement_measure_kwargs)
160-
query_idx = multi_argmax(disagreement, n_instances=n_instances)
183+
184+
if not random_tie_break:
185+
query_idx = multi_argmax(disagreement, n_instances=n_instances)
186+
else:
187+
query_idx = shuffled_argmax(disagreement, n_instances=n_instances)
161188

162189
return query_idx, X[query_idx]
163190

164191

165192
def max_std_sampling(regressor: BaseEstimator, X: modALinput,
166-
n_instances: int = 1, **predict_kwargs) -> Tuple[np.ndarray, modALinput]:
193+
n_instances: int = 1, random_tie_break=False,
194+
**predict_kwargs) -> Tuple[np.ndarray, modALinput]:
167195
"""
168196
Regressor standard deviation sampling strategy.
169197
170198
Args:
171199
regressor: The regressor for which the labels are to be queried.
172200
X: The pool of samples to query from.
173201
n_instances: Number of samples to be queried.
202+
random_tie_break: If True, shuffles utility scores to randomize the order. This
203+
can be used to break the tie when the highest utility score is not unique.
174204
**predict_kwargs: Keyword arguments to be passed to :meth:`predict` of the CommiteeRegressor.
175205
176206
Returns:
177-
The indices of the instances from X chosen to be labelled; the instances from X chosen to be labelled.
207+
The indices of the instances from X chosen to be labelled;
208+
the instances from X chosen to be labelled.
178209
"""
179210
_, std = regressor.predict(X, return_std=True, **predict_kwargs)
180211
std = std.reshape(X.shape[0], )
181-
query_idx = multi_argmax(std, n_instances=n_instances)
212+
213+
if not random_tie_break:
214+
query_idx = multi_argmax(std, n_instances=n_instances)
215+
else:
216+
query_idx = shuffled_argmax(std, n_instances=n_instances)
217+
182218
return query_idx, X[query_idx]

0 commit comments

Comments
 (0)