Skip to content

Commit 0d24a4d

Browse files
committed
modAL.utils.selection.weighted_random checks added to avoid division with zero
1 parent e6d658c commit 0d24a4d

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

modAL/utils/selection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def weighted_random(weights, n_instances=1):
4747
n_instances random indices based on the weights.
4848
"""
4949
assert n_instances <= len(weights), 'n_instances must be less or equal than the size of utility'
50+
weight_sum = np.sum(weights)
51+
assert weight_sum > 0, 'the sum of weights must be larger than zero'
5052

51-
random_idx = np.random.choice(range(len(weights)), size=n_instances, p=weights / np.sum(weights), replace=False)
53+
random_idx = np.random.choice(range(len(weights)), size=n_instances, p=weights/weight_sum, replace=False)
5254
return random_idx

0 commit comments

Comments
 (0)