|
21 | 21 | from collections import namedtuple
|
22 | 22 |
|
23 | 23 | from sklearn.ensemble import RandomForestClassifier
|
| 24 | +from sklearn.gaussian_process import GaussianProcessRegressor |
24 | 25 | from sklearn.exceptions import NotFittedError
|
25 | 26 | from sklearn.metrics import confusion_matrix
|
26 | 27 | from sklearn.svm import SVC
|
@@ -417,6 +418,39 @@ def test_KL_max_disagreement(self):
|
417 | 418 | )
|
418 | 419 | np.testing.assert_almost_equal(returned_KL_disagreement, true_KL_disagreement)
|
419 | 420 |
|
| 421 | + def test_vote_entropy_sampling(self): |
| 422 | + for n_samples, n_features, n_classes in product(range(1, 10), range(1, 10), range(1, 10)): |
| 423 | + committee = mock.MockCommittee(classes_=np.asarray(range(n_classes)), |
| 424 | + vote_return=np.zeros(shape=(n_samples, n_classes), dtype=np.int16)) |
| 425 | + modAL.disagreement.vote_entropy_sampling(committee, np.random.rand(n_samples, n_features)) |
| 426 | + modAL.disagreement.vote_entropy_sampling(committee, np.random.rand(n_samples, n_features), |
| 427 | + random_tie_break=True) |
| 428 | + |
| 429 | + def test_consensus_entropy_sampling(self): |
| 430 | + for n_samples, n_features, n_classes in product(range(1, 10), range(1, 10), range(1, 10)): |
| 431 | + committee = mock.MockCommittee(predict_proba_return=np.random.rand(n_samples, n_classes)) |
| 432 | + modAL.disagreement.consensus_entropy_sampling(committee, np.random.rand(n_samples, n_features)) |
| 433 | + modAL.disagreement.consensus_entropy_sampling(committee, np.random.rand(n_samples, n_features), |
| 434 | + random_tie_break=True) |
| 435 | + |
| 436 | + def test_max_disagreement_sampling(self): |
| 437 | + for n_samples, n_features, n_classes, n_learners in product(range(1, 10), range(1, 10), range(1, 10), range(2, 5)): |
| 438 | + committee = mock.MockCommittee( |
| 439 | + n_learners=n_learners, classes_=range(n_classes), |
| 440 | + vote_proba_return=np.zeros(shape=(n_samples, n_learners, n_classes)) |
| 441 | + ) |
| 442 | + modAL.disagreement.max_disagreement_sampling(committee, np.random.rand(n_samples, n_features)) |
| 443 | + modAL.disagreement.max_disagreement_sampling(committee, np.random.rand(n_samples, n_features), |
| 444 | + random_tie_break=True) |
| 445 | + |
| 446 | + def test_max_std_sampling(self): |
| 447 | + for n_samples, n_features in product(range(1, 10), range(1, 10)): |
| 448 | + regressor = GaussianProcessRegressor() |
| 449 | + regressor.fit(np.random.rand(n_samples, n_features), np.random.rand(n_samples)) |
| 450 | + modAL.disagreement.max_std_sampling(regressor, np.random.rand(n_samples, n_features)) |
| 451 | + modAL.disagreement.max_std_sampling(regressor, np.random.rand(n_samples, n_features), |
| 452 | + random_tie_break=True) |
| 453 | + |
420 | 454 |
|
421 | 455 | class TestEER(unittest.TestCase):
|
422 | 456 | def test_eer(self):
|
|
0 commit comments