Skip to content

Commit b7133e8

Browse files
add Groups support in feature selection for cross validate (#804)
* add "groups" for cross_valiate enabling group based split support * add groups docstring * update probe_feature_selector for groups based cross_valiete support * update single_feature_importance for groups based cross_valiete support * update smart_correlation_selector for groups based cross_valiete support * update tests for groups argument * sort imports * update recursive feture selectors with groups * update SelectByShuffle with groups * update SelectByTargetMean with groups * remove unnecessary groups from feature selection tests * linting * remove groups from shuffle feature selector * add tests for groups * linting
1 parent 5dfceb8 commit b7133e8

18 files changed

+432
-48
lines changed

feature_engine/_docstrings/selection/_docstring.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@
5353
documentation.
5454
""".rstrip()
5555

56+
_groups_docstring = """groups: array-like of shape (n_samples,), default=None
57+
Group labels for the samples used while splitting the dataset into train/test set.
58+
Only used in conjunction with a “Group” cv instance (e.g., GroupKFold).
59+
""".rstrip()
60+
5661
_initial_model_performance_docstring = """initial_model_performance_:
5762
The model's performance when trained with the original dataset.
5863
""".rstrip()

feature_engine/selection/base_recursive_selector.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ class BaseRecursiveSelector(BaseSelector):
6666
across calls. For more details check Scikit-learn's `cross_validate`'s
6767
documentation.
6868
69+
groups: Array-like of shape (n_samples,), default=None
70+
Group labels for the samples used while splitting
71+
the dataset into train/test set. Only used in conjunction with a
72+
“Group” cv instance (e.g., GroupKFold).
73+
6974
confirm_variables: bool, default=False
7075
If set to True, variables that are not present in the input dataframe will be
7176
removed from the list of variables. Only used when passing a variable list to
@@ -105,6 +110,7 @@ def __init__(
105110
estimator,
106111
scoring: str = "roc_auc",
107112
cv=3,
113+
groups=None,
108114
threshold: Union[int, float] = 0.01,
109115
variables: Variables = None,
110116
confirm_variables: bool = False,
@@ -119,6 +125,7 @@ def __init__(
119125
self.scoring = scoring
120126
self.threshold = threshold
121127
self.cv = cv
128+
self.groups = groups
122129

123130
def fit(self, X: pd.DataFrame, y: pd.Series):
124131
"""
@@ -155,10 +162,11 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
155162

156163
# train model with all features and cross-validation
157164
model = cross_validate(
158-
self.estimator,
159-
X[self.variables_],
160-
y,
165+
estimator=self.estimator,
166+
X=X[self.variables_],
167+
y=y,
161168
cv=self._cv,
169+
groups=self.groups,
162170
scoring=self.scoring,
163171
return_estimator=True,
164172
)

feature_engine/selection/base_selection_functions.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def single_feature_performance(
167167
estimator,
168168
cv,
169169
scoring,
170+
groups=None,
170171
):
171172
"""
172173
Trains one estimator per feature and determines the performance of that estimator.
@@ -191,6 +192,11 @@ def single_feature_performance(
191192
scoring:
192193
The performance metric. Any supported by the Scikit-learn estimator.
193194
195+
groups: Array-like of shape (n_samples,), default=None
196+
Group labels for the samples used while splitting
197+
the dataset into train/test set. Only used in conjunction with a
198+
“Group” cv instance (e.g., GroupKFold).
199+
194200
Returns
195201
-------
196202
feature_performance: dict
@@ -213,6 +219,7 @@ def single_feature_performance(
213219
X[feature].to_frame(),
214220
y,
215221
cv=cv,
222+
groups=groups,
216223
return_estimator=False,
217224
scoring=scoring,
218225
)
@@ -228,6 +235,7 @@ def find_feature_importance(
228235
estimator,
229236
cv,
230237
scoring,
238+
groups=None,
231239
):
232240
"""
233241
Trains an estimator using cross-validation and derives feature importance from it.
@@ -253,6 +261,11 @@ def find_feature_importance(
253261
scoring:
254262
The performance metric. Any supported by the Scikit-learn estimator.
255263
264+
groups: Array-like of shape (n_samples,), default=None
265+
Group labels for the samples used while splitting
266+
the dataset into train/test set. Only used in conjunction with a
267+
“Group” cv instance (e.g., GroupKFold).
268+
256269
Returns
257270
-------
258271
feature_importance: pd.Series
@@ -271,6 +284,7 @@ def find_feature_importance(
271284
X,
272285
y,
273286
cv=cv,
287+
groups=groups,
274288
scoring=scoring,
275289
return_estimator=True,
276290
)

feature_engine/selection/probe_feature_selection.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from feature_engine._docstrings.methods import _fit_transform_docstring
1515
from feature_engine._docstrings.selection._docstring import (
1616
_cv_docstring,
17+
_groups_docstring,
1718
_features_to_drop_docstring,
1819
_fit_docstring,
1920
_get_support_docstring,
@@ -40,6 +41,7 @@
4041
estimator=_estimator_docstring,
4142
scoring=_scoring_docstring,
4243
cv=_cv_docstring,
44+
groups=_groups_docstring,
4345
confirm_variables=_confirm_variables_docstring,
4446
variables=_variables_numerical_docstring,
4547
feature_names_in_=_feature_names_in_docstring,
@@ -104,6 +106,8 @@ class ProbeFeatureSelection(BaseSelector):
104106
105107
{cv}
106108
109+
{groups}
110+
107111
Attributes
108112
----------
109113
probe_features_:
@@ -173,6 +177,7 @@ def __init__(
173177
n_probes: int = 1,
174178
distribution: str = "normal",
175179
cv=5,
180+
groups=None,
176181
random_state: int = 0,
177182
confirm_variables: bool = False,
178183
):
@@ -203,6 +208,7 @@ def __init__(
203208
self.scoring = scoring
204209
self.distribution = distribution
205210
self.cv = cv
211+
self.groups = groups
206212
self.n_probes = n_probes
207213
self.random_state = random_state
208214

@@ -238,20 +244,26 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
238244
if self.collective is True:
239245
# train model using entire dataset and derive feature importance
240246
f_importance_mean, f_importance_std = find_feature_importance(
241-
X_new, y, self.estimator, self.cv, self.scoring,
247+
X=X_new,
248+
y=y,
249+
estimator=self.estimator,
250+
cv=self.cv,
251+
groups=self.groups,
252+
scoring=self.scoring,
242253
)
243254
self.feature_importances_ = f_importance_mean
244255
self.feature_importances_std_ = f_importance_std
245256

246257
else:
247258
# trains a model per feature (single feature models)
248259
f_importance_mean, f_importance_std = single_feature_performance(
249-
X_new,
250-
y,
251-
X_new.columns,
252-
self.estimator,
253-
self.cv,
254-
self.scoring,
260+
X=X_new,
261+
y=y,
262+
variables=X_new.columns,
263+
estimator=self.estimator,
264+
cv=self.cv,
265+
groups=self.groups,
266+
scoring=self.scoring,
255267
)
256268
self.feature_importances_ = pd.Series(f_importance_mean)
257269
self.feature_importances_std_ = pd.Series(f_importance_std)

feature_engine/selection/recursive_feature_addition.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
_features_to_drop_docstring,
2020
_fit_docstring,
2121
_get_support_docstring,
22+
_groups_docstring,
2223
_initial_model_performance_docstring,
2324
_scoring_docstring,
2425
_threshold_docstring,
@@ -35,6 +36,7 @@
3536
scoring=_scoring_docstring,
3637
threshold=_threshold_docstring,
3738
cv=_cv_docstring,
39+
groups=_groups_docstring,
3840
variables=_variables_numerical_docstring,
3941
confirm_variables=_confirm_variables_docstring,
4042
initial_model_performance_=_initial_model_performance_docstring,
@@ -87,6 +89,8 @@ class RecursiveFeatureAddition(BaseRecursiveSelector):
8789
8890
{cv}
8991
92+
{groups}
93+
9094
{confirm_variables}
9195
9296
Attributes
@@ -167,10 +171,11 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
167171

168172
# Run baseline model using only the most important feature
169173
baseline_model = cross_validate(
170-
self.estimator,
171-
X[first_most_important_feature].to_frame(),
172-
y,
174+
estimator=self.estimator,
175+
X=X[first_most_important_feature].to_frame(),
176+
y=y,
173177
cv=self._cv,
178+
groups=self.groups,
174179
scoring=self.scoring,
175180
return_estimator=True,
176181
)
@@ -194,10 +199,11 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
194199

195200
# Add feature and train new model
196201
model_tmp = cross_validate(
197-
self.estimator,
198-
X[_selected_features + [feature]],
199-
y,
202+
estimator=self.estimator,
203+
X=X[_selected_features + [feature]],
204+
y=y,
200205
cv=self._cv,
206+
groups=self.groups,
201207
scoring=self.scoring,
202208
return_estimator=True,
203209
)

feature_engine/selection/recursive_feature_elimination.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
_features_to_drop_docstring,
2020
_fit_docstring,
2121
_get_support_docstring,
22+
_groups_docstring,
2223
_initial_model_performance_docstring,
2324
_scoring_docstring,
2425
_threshold_docstring,
@@ -35,6 +36,7 @@
3536
scoring=_scoring_docstring,
3637
threshold=_threshold_docstring,
3738
cv=_cv_docstring,
39+
groups=_groups_docstring,
3840
variables=_variables_numerical_docstring,
3941
confirm_variables=_confirm_variables_docstring,
4042
initial_model_performance_=_initial_model_performance_docstring,
@@ -88,6 +90,8 @@ class RecursiveFeatureElimination(BaseRecursiveSelector):
8890
8991
{cv}
9092
93+
{groups}
94+
9195
{confirm_variables}
9296
9397
Attributes
@@ -186,10 +190,11 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
186190

187191
# remove feature and train new model
188192
model_tmp = cross_validate(
189-
self.estimator,
190-
X_tmp.drop(columns=feature),
191-
y,
193+
estimator=self.estimator,
194+
X=X_tmp.drop(columns=feature),
195+
y=y,
192196
cv=self._cv,
197+
groups=self.groups,
193198
scoring=self.scoring,
194199
return_estimator=False,
195200
)
@@ -213,10 +218,11 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
213218
X_tmp = X_tmp.drop(columns=feature)
214219

215220
baseline_model = cross_validate(
216-
self.estimator,
217-
X_tmp,
218-
y,
221+
estimator=self.estimator,
222+
X=X_tmp,
223+
y=y,
219224
cv=self._cv,
225+
groups=self.groups,
220226
return_estimator=False,
221227
scoring=self.scoring,
222228
)

feature_engine/selection/shuffle_features.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,9 @@ def fit(
226226

227227
# train model with all features and cross-validation
228228
model = cross_validate(
229-
self.estimator,
230-
X[self.variables_],
231-
y,
229+
estimator=self.estimator,
230+
X=X[self.variables_],
231+
y=y,
232232
cv=cv,
233233
return_estimator=True,
234234
scoring=self.scoring,

feature_engine/selection/single_feature_performance.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from feature_engine._docstrings.methods import _fit_transform_docstring
1717
from feature_engine._docstrings.selection._docstring import (
1818
_cv_docstring,
19+
_groups_docstring,
1920
_estimator_docstring,
2021
_features_to_drop_docstring,
2122
_fit_docstring,
@@ -45,6 +46,7 @@
4546
scoring=_scoring_docstring,
4647
threshold=_threshold_docstring,
4748
cv=_cv_docstring,
49+
groups=_groups_docstring,
4850
variables=_variables_numerical_docstring,
4951
confirm_variables=_confirm_variables_docstring,
5052
initial_model_performance_=_initial_model_performance_docstring,
@@ -83,6 +85,8 @@ class SelectBySingleFeaturePerformance(BaseSelector):
8385
8486
{cv}
8587
88+
{groups}
89+
8690
{confirm_variables}
8791
8892
Attributes
@@ -147,6 +151,7 @@ def __init__(
147151
estimator,
148152
scoring: str = "roc_auc",
149153
cv=3,
154+
groups=None,
150155
threshold: Union[int, float, None] = None,
151156
variables: Variables = None,
152157
confirm_variables: bool = False,
@@ -177,6 +182,7 @@ def __init__(
177182
self.scoring = scoring
178183
self.threshold = threshold
179184
self.cv = cv
185+
self.groups = groups
180186

181187
def fit(self, X: pd.DataFrame, y: pd.Series):
182188
"""
@@ -209,7 +215,13 @@ def fit(self, X: pd.DataFrame, y: pd.Series):
209215
)
210216

211217
self.feature_performance_, _ = single_feature_performance(
212-
X, y, self.variables_, self.estimator, self.cv, self.scoring
218+
X=X,
219+
y=y,
220+
variables=self.variables_,
221+
estimator=self.estimator,
222+
cv=self.cv,
223+
groups=self.groups,
224+
scoring=self.scoring,
213225
)
214226

215227
# select features

0 commit comments

Comments
 (0)