Skip to content

Commit 5dfceb8

Browse files
authored
fix cv generator recursive selectors (#805)
1 parent af71893 commit 5dfceb8

File tree

5 files changed

+52
-5
lines changed

5 files changed

+52
-5
lines changed

feature_engine/selection/base_recursive_selector.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from types import GeneratorType
12
from typing import List, Union
23

34
import pandas as pd
@@ -144,6 +145,8 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
144145
else:
145146
self.variables_ = check_numerical_variables(X, self.variables)
146147

148+
self._cv = list(self.cv) if isinstance(self.cv, GeneratorType) else self.cv
149+
147150
# check that there are more than 1 variable to select from
148151
self._check_variable_number()
149152

@@ -155,7 +158,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
155158
self.estimator,
156159
X[self.variables_],
157160
y,
158-
cv=self.cv,
161+
cv=self._cv,
159162
scoring=self.scoring,
160163
return_estimator=True,
161164
)

feature_engine/selection/recursive_feature_addition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
170170
self.estimator,
171171
X[first_most_important_feature].to_frame(),
172172
y,
173-
cv=self.cv,
173+
cv=self._cv,
174174
scoring=self.scoring,
175175
return_estimator=True,
176176
)
@@ -197,7 +197,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
197197
self.estimator,
198198
X[_selected_features + [feature]],
199199
y,
200-
cv=self.cv,
200+
cv=self._cv,
201201
scoring=self.scoring,
202202
return_estimator=True,
203203
)

feature_engine/selection/recursive_feature_elimination.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
189189
self.estimator,
190190
X_tmp.drop(columns=feature),
191191
y,
192-
cv=self.cv,
192+
cv=self._cv,
193193
scoring=self.scoring,
194194
return_estimator=False,
195195
)
@@ -216,7 +216,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
216216
self.estimator,
217217
X_tmp,
218218
y,
219-
cv=self.cv,
219+
cv=self._cv,
220220
return_estimator=False,
221221
scoring=self.scoring,
222222
)

tests/test_selection/test_recursive_feature_addition.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
from sklearn.ensemble import RandomForestClassifier
44
from sklearn.linear_model import Lasso, LogisticRegression, LinearRegression
5+
from sklearn.model_selection import KFold
56
from sklearn.tree import DecisionTreeRegressor
67

78
from feature_engine.selection import RecursiveFeatureAddition
@@ -254,3 +255,23 @@ def test_feature_importance(load_diabetes_dataset):
254255

255256
assert round(sel.feature_importances_, 2).to_list() == imps
256257
assert round(sel.feature_importances_std_, 2).to_list() == imps_std
258+
259+
260+
def test_cv_generator(load_diabetes_dataset):
# Checks that RecursiveFeatureAddition returns the same transformed frame
# whether `cv` is given as an integer, a KFold splitter object, or the
# generator produced by KFold.split().
261+
X, y = load_diabetes_dataset
262+
linear_model = LinearRegression()
263+
cv = KFold(n_splits=3)
264+
# Reference result: cv passed as a plain integer (3 folds).
sel = RecursiveFeatureAddition(estimator=linear_model, scoring="r2", cv=3).fit(X, y)
265+
expected = sel.transform(X)
266+
267+
# Case 1: cv passed as the KFold splitter object itself.
sel = RecursiveFeatureAddition(estimator=linear_model, scoring="r2", cv=cv).fit(
268+
X, y
269+
)
270+
test1 = sel.transform(X)
271+
pd.testing.assert_frame_equal(expected, test1)
272+
273+
# Case 2: cv passed as the one-shot generator returned by cv.split(X, y).
# A generator is exhausted after a single pass, so this is the scenario the
# commit title ("fix cv generator recursive selectors") refers to.
sel = RecursiveFeatureAddition(
274+
estimator=linear_model, scoring="r2", cv=cv.split(X, y)
275+
).fit(X, y)
276+
test2 = sel.transform(X)
277+
pd.testing.assert_frame_equal(expected, test2)

tests/test_selection/test_recursive_feature_elimination.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
from sklearn.ensemble import RandomForestClassifier
44
from sklearn.linear_model import Lasso, LinearRegression, LogisticRegression
5+
from sklearn.model_selection import KFold
56
from sklearn.tree import DecisionTreeRegressor
67

78
from feature_engine.selection import RecursiveFeatureElimination
@@ -276,3 +277,25 @@ def test_feature_importance(load_diabetes_dataset):
276277

277278
assert round(sel.feature_importances_, 2).to_list() == imps
278279
assert round(sel.feature_importances_std_, 2).to_list() == imps_std
280+
281+
282+
def test_cv_generator(load_diabetes_dataset):
# Checks that RecursiveFeatureElimination returns the same transformed frame
# whether `cv` is given as an integer, a KFold splitter object, or the
# generator produced by KFold.split().
283+
X, y = load_diabetes_dataset
284+
linear_model = LinearRegression()
285+
cv = KFold(n_splits=3)
286+
# Reference result: cv passed as a plain integer (3 folds).
sel = RecursiveFeatureElimination(estimator=linear_model, scoring="r2", cv=3).fit(
287+
X, y
288+
)
289+
expected = sel.transform(X)
290+
291+
# Case 1: cv passed as the KFold splitter object itself.
sel = RecursiveFeatureElimination(estimator=linear_model, scoring="r2", cv=cv).fit(
292+
X, y
293+
)
294+
test1 = sel.transform(X)
295+
pd.testing.assert_frame_equal(expected, test1)
296+
297+
# Case 2: cv passed as the one-shot generator returned by cv.split(X, y).
# A generator is exhausted after a single pass, so this is the scenario the
# commit title ("fix cv generator recursive selectors") refers to.
sel = RecursiveFeatureElimination(
298+
estimator=linear_model, scoring="r2", cv=cv.split(X, y)
299+
).fit(X, y)
300+
test2 = sel.transform(X)
301+
pd.testing.assert_frame_equal(expected, test2)

0 commit comments

Comments
 (0)