Skip to content

Commit 47bc726

Browse files
author
Boyan Hristov
committed
FIX #108 - no longer storing transformed training data for on_transformed strategies
1 parent 4254df9 commit 47bc726

File tree

3 files changed

+45
-11
lines changed

3 files changed

+45
-11
lines changed

modAL/batch.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import scipy.sparse as sp
99
from sklearn.metrics.pairwise import pairwise_distances, pairwise_distances_argmin_min
1010

11-
from modAL.utils.data import data_vstack, modALinput
11+
from modAL.utils.data import data_vstack, modALinput, data_shape
1212
from modAL.models.base import BaseCommittee, BaseLearner
1313
from modAL.uncertainty import classifier_uncertainty
1414

@@ -150,8 +150,10 @@ def ranked_batch(classifier: Union[BaseLearner, BaseCommittee],
150150
if classifier.X_training is None:
151151
best_coldstart_instance_index, labeled = select_cold_start_instance(X=unlabeled, metric=metric, n_jobs=n_jobs)
152152
instance_index_ranking = [best_coldstart_instance_index]
153-
elif classifier.X_training.shape[0] > 0:
154-
labeled = classifier.Xt_training[:] if classifier.on_transformed else classifier.X_training[:]
153+
elif data_shape(classifier.X_training)[0] > 0:
154+
labeled = classifier.transform_without_estimating(
155+
classifier.X_training
156+
) if classifier.on_transformed else classifier.X_training[:]
155157
instance_index_ranking = []
156158

157159
# The maximum number of records to sample.

modAL/models/base.py

-8
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,9 @@ def __init__(self,
6666
self.on_transformed = on_transformed
6767

6868
self.X_training = X_training
69-
self.Xt_training = None
7069
self.y_training = y_training
7170
if X_training is not None:
7271
self._fit_to_known(bootstrap=bootstrap_init, **fit_kwargs)
73-
self.Xt_training = self.transform_without_estimating(self.X_training) if self.on_transformed else None
7472

7573
assert isinstance(force_all_finite, bool), 'force_all_finite must be a bool'
7674
self.force_all_finite = force_all_finite
@@ -92,15 +90,10 @@ def _add_training_data(self, X: modALinput, y: modALinput) -> None:
9290

9391
if self.X_training is None:
9492
self.X_training = X
95-
self.Xt_training = self.transform_without_estimating(self.X_training) if self.on_transformed else None
9693
self.y_training = y
9794
else:
9895
try:
9996
self.X_training = data_vstack((self.X_training, X))
100-
self.Xt_training = data_vstack((
101-
self.Xt_training,
102-
self.transform_without_estimating(X)
103-
)) if self.on_transformed else None
10497
self.y_training = data_vstack((self.y_training, y))
10598
except ValueError:
10699
raise ValueError('the dimensions of the new training data and label must'
@@ -213,7 +206,6 @@ def fit(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwarg
213206
check_X_y(X, y, accept_sparse=True, ensure_2d=False, allow_nd=True, multi_output=True, dtype=None,
214207
force_all_finite=self.force_all_finite)
215208
self.X_training, self.y_training = X, y
216-
self.Xt_training = self.transform_without_estimating(self.X_training) if self.on_transformed else None
217209
return self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
218210

219211
def predict(self, X: modALinput, **predict_kwargs) -> Any:

tests/core_tests.py

+40
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sklearn.multiclass import OneVsRestClassifier
3030
from sklearn.pipeline import make_pipeline
3131
from sklearn.preprocessing import FunctionTransformer
32+
from sklearn.feature_extraction.text import CountVectorizer
3233
from scipy.stats import entropy, norm
3334
from scipy.special import ndtr
3435
from scipy import sparse as sp
@@ -824,6 +825,45 @@ def test_on_transformed(self):
824825
query_idx, query_inst = learner.query(X_pool)
825826
learner.teach(X_pool.iloc[query_idx], y_pool[query_idx])
826827

828+
def test_on_transformed_with_variable_transformation(self):
829+
"""
830+
Learnable transformations naturally change after a model is retrained. Make sure this is handled
831+
properly for on_transformed=True query strategies.
832+
"""
833+
query_strategies = [
834+
modAL.batch.uncertainty_batch_sampling
835+
# add further strategies which work with instance representations
836+
# no further ones as of 09.12.2020
837+
]
838+
839+
X_labeled = ['Dog', 'Cat', 'Tree']
840+
841+
# contains words not seen in the labeled data; training the model on those
842+
# will alter the CountVectorizer transformations
843+
X_pool = ['Airplane', 'House']
844+
845+
y = [0, 1, 1, 0, 1] # irrelevant for test
846+
847+
for query_strategy in query_strategies:
848+
learner = modAL.models.learners.ActiveLearner(
849+
estimator=make_pipeline(
850+
CountVectorizer(),
851+
RandomForestClassifier(n_estimators=10)
852+
),
853+
query_strategy=query_strategy,
854+
X_training=X_labeled, y_training=y[:len(X_labeled)],
855+
on_transformed=True,
856+
)
857+
858+
for _ in range(len(X_pool)):
859+
query_idx, query_instance = learner.query(X_pool, n_instances=1)
860+
i = query_idx[0]
861+
862+
learner.teach(
863+
X=[X_pool[i]],
864+
y=[y[i]]
865+
)
866+
827867
def test_old_query_strategy_interface(self):
828868
n_samples = 10
829869
n_features = 5

0 commit comments

Comments
 (0)