Skip to content

Commit af71893

Browse files
authored
fix cv generator in feature selection classes (#803)
* fix cv generator in feature selection classes * add tests| * fix error shuffle features * fix style
1 parent e3f99db commit af71893

File tree

6 files changed

+89
-4
lines changed

6 files changed

+89
-4
lines changed

feature_engine/selection/shuffle_features.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from types import GeneratorType
12
from typing import List, MutableSequence, Union
23

34
import numpy as np
@@ -221,12 +222,14 @@ def fit(
221222
# check that there are more than 1 variable to select from
222223
self._check_variable_number()
223224

225+
cv = list(self.cv) if isinstance(self.cv, GeneratorType) else self.cv
226+
224227
# train model with all features and cross-validation
225228
model = cross_validate(
226229
self.estimator,
227230
X[self.variables_],
228231
y,
229-
cv=self.cv,
232+
cv=cv,
230233
return_estimator=True,
231234
scoring=self.scoring,
232235
params={"sample_weight": sample_weight},
@@ -236,7 +239,7 @@ def fit(
236239
self.initial_model_performance_ = model["test_score"].mean()
237240

238241
# extract the validation folds
239-
cv_ = check_cv(self.cv, y=y, classifier=is_classifier(self.estimator))
242+
cv_ = check_cv(cv, y=y, classifier=is_classifier(self.estimator))
240243
validation_indices = [val_index for _, val_index in cv_.split(X, y)]
241244

242245
# get performance metric

feature_engine/selection/smart_correlation_selection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from types import GeneratorType
12
from typing import List, Union
23

34
import pandas as pd
@@ -317,13 +318,14 @@ def fit(self, X: pd.DataFrame, y: pd.Series = None):
317318
# select best performing feature according to estimator
318319
if self.selection_method == "model_performance":
319320
correlated_dict = dict()
321+
cv = list(self.cv) if isinstance(self.cv, GeneratorType) else self.cv
320322
for feature_group in correlated_groups:
321323
feature_performance, _ = single_feature_performance(
322324
X,
323325
y,
324326
feature_group,
325327
self.estimator,
326-
self.cv,
328+
cv,
327329
self.scoring,
328330
)
329331
# get most important feature

feature_engine/selection/target_mean_selection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from types import GeneratorType
12
from typing import List, Union
23

34
import pandas as pd
@@ -299,6 +300,8 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
299300

300301
self.feature_performance_ = {}
301302

303+
cv = list(self.cv) if isinstance(self.cv, GeneratorType) else self.cv
304+
302305
for variable in self.variables_:
303306
# clone estimator
304307
estimator = clone(est)
@@ -310,7 +313,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
310313
estimator,
311314
X,
312315
y,
313-
cv=self.cv,
316+
cv=cv,
314317
scoring=self.scoring,
315318
)
316319

tests/test_selection/test_shuffle_features.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import numpy as np
22
import pandas as pd
33
import pytest
4+
45
from sklearn.ensemble import RandomForestClassifier
56
from sklearn.linear_model import LinearRegression
7+
from sklearn.model_selection import StratifiedKFold
68
from sklearn.tree import DecisionTreeRegressor
79

810
from feature_engine.selection import SelectByShuffling
@@ -93,6 +95,42 @@ def test_regression_cv_2_and_mse(load_diabetes_dataset):
9395
pd.testing.assert_frame_equal(sel.transform(X), Xtransformed)
9496

9597

98+
def test_cv_generator(df_test):
99+
X, y = df_test
100+
cv = StratifiedKFold(n_splits=3)
101+
102+
X, y = df_test
103+
sel = SelectByShuffling(
104+
RandomForestClassifier(random_state=1),
105+
threshold=0.01,
106+
random_state=1,
107+
cv=3,
108+
)
109+
sel.fit(X, y)
110+
111+
# expected result
112+
Xtransformed = pd.DataFrame(X["var_7"].copy())
113+
pd.testing.assert_frame_equal(sel.transform(X), Xtransformed)
114+
115+
sel = SelectByShuffling(
116+
RandomForestClassifier(random_state=1),
117+
threshold=0.01,
118+
random_state=1,
119+
cv=cv,
120+
)
121+
sel.fit(X, y)
122+
pd.testing.assert_frame_equal(sel.transform(X), Xtransformed)
123+
124+
sel = SelectByShuffling(
125+
RandomForestClassifier(random_state=1),
126+
threshold=0.01,
127+
random_state=1,
128+
cv=cv.split(X, y),
129+
)
130+
sel.fit(X, y)
131+
pd.testing.assert_frame_equal(sel.transform(X), Xtransformed)
132+
133+
96134
def test_raises_threshold_error():
97135
with pytest.raises(ValueError):
98136
SelectByShuffling(RandomForestClassifier(random_state=1), threshold="hello")

tests/test_selection/test_smart_correlation_selection.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
from sklearn.datasets import make_classification
55
from sklearn.ensemble import RandomForestClassifier
6+
from sklearn.model_selection import KFold
67

78
from feature_engine.selection import SmartCorrelatedSelection
89
from tests.estimator_checks.init_params_allowed_values_checks import (
@@ -213,6 +214,27 @@ def test_model_performance_2_correlated_groups(df_test):
213214
pd.testing.assert_frame_equal(Xt, df)
214215

215216

217+
def test_cv_generator(df_single):
218+
X, y = df_single
219+
cv = KFold(3)
220+
221+
transformer = SmartCorrelatedSelection(
222+
variables=None,
223+
method="pearson",
224+
threshold=0.8,
225+
missing_values="raise",
226+
selection_method="model_performance",
227+
estimator=RandomForestClassifier(n_estimators=10, random_state=1),
228+
scoring="roc_auc",
229+
cv=cv.split(X, y),
230+
)
231+
232+
Xt = transformer.fit_transform(X, y)
233+
234+
df = X[["var_0", "var_2", "var_3", "var_4", "var_5"]].copy()
235+
pd.testing.assert_frame_equal(Xt, df)
236+
237+
216238
def test_error_if_select_model_performance_and_y_is_none(df_single):
217239
X, y = df_single
218240

tests/test_selection/test_target_mean_selection.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pandas as pd
22
import pytest
33

4+
from sklearn.model_selection import StratifiedKFold
5+
46
from feature_engine.selection import SelectByTargetMeanPerformance
57

68

@@ -161,6 +163,21 @@ def test_regression():
161163
pd.testing.assert_frame_equal(sel.transform(X), Xtransformed)
162164

163165

166+
def test_cv_generator():
167+
X, y = df_classification()
168+
cv = StratifiedKFold(n_splits=2)
169+
sel = SelectByTargetMeanPerformance(
170+
variables=None,
171+
scoring="accuracy",
172+
threshold=None,
173+
bins=2,
174+
strategy="equal_width",
175+
cv=cv.split(X, y),
176+
)
177+
sel.fit(X, y)
178+
pd.testing.assert_frame_equal(sel.transform(X), X[["cat_var_A", "num_var_A"]])
179+
180+
164181
def test_error_wrong_params():
165182
with pytest.raises(ValueError):
166183
SelectByTargetMeanPerformance(scoring="mean_squared")

0 commit comments

Comments
 (0)