Skip to content

Commit 50c25e2

Browse files
authored
Merge branch 'hyperopt:master' into master
2 parents 191efe9 + 840cfc9 commit 50c25e2

File tree

3 files changed

+68
-7
lines changed

3 files changed

+68
-7
lines changed

hpsklearn/estimator/_cost_fn.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
LeaveOneOut, \
1010
StratifiedKFold, \
1111
KFold, \
12+
GroupKFold, \
1213
PredefinedSplit
1314
from sklearn.metrics import accuracy_score, r2_score
1415

@@ -24,6 +25,7 @@ def _cost_fn(argd,
2425
EX_list: typing.Union[list, tuple] = None,
2526
valid_size: float = 0.2,
2627
n_folds: int = None,
28+
kfolds_group: typing.Union[list, np.ndarray] = None,
2729
shuffle: bool = False,
2830
random_state: typing.Union[int, np.random.Generator] = np.random.default_rng(),
2931
use_partial_fit: bool = False,
@@ -55,7 +57,13 @@ def _cost_fn(argd,
5557
n_folds: int, default is None
5658
When n_folds is not None, use K-fold cross-validation when
5759
n_folds > 2. Or, use leave-one-out cross-validation when
58-
n_folds = -1.
60+
n_folds = -1. For Group K-fold cross-validation, functions as
61+
`n_splits`.
62+
63+
kfolds_group: list or ndarray, default is None
64+
When kfolds_group is not None, use Group K-fold cross-validation
65+
with the specified groups. The length of kfolds_group must be
66+
equal to the number of samples in X.
5967
6068
shuffle: bool, default is False
6169
Whether to perform sample shuffling before splitting the
@@ -145,10 +153,14 @@ def _cost_fn(argd,
145153
random_state=random_state_sklearn
146154
).split(X, y)
147155
else:
148-
info(f"Will use K-fold CV with K: {n_folds} and Shuffle: {shuffle}")
149-
cv_iter = KFold(n_splits=n_folds,
150-
shuffle=shuffle,
151-
random_state=random_state_sklearn).split(X)
156+
if kfolds_group is not None:
157+
info(f"Will use Group K-fold CV with K: {n_folds} and Shuffle: {shuffle}")
158+
cv_iter = GroupKFold(n_splits=n_folds).split(X, y, kfolds_group)
159+
else:
160+
info(f"Will use K-fold CV with K: {n_folds} and Shuffle: {shuffle}")
161+
cv_iter = KFold(n_splits=n_folds,
162+
shuffle=shuffle,
163+
random_state=random_state_sklearn).split(X)
152164
else:
153165
if not shuffle: # always choose the last samples.
154166
info(f"Will use the last {valid_size} portion of samples for validation")

hpsklearn/estimator/estimator.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def fit_iter(self, X, y,
215215
EX_list: typing.Union[list, tuple] = None,
216216
valid_size: float = .2,
217217
n_folds: int = None,
218+
kfolds_group: typing.Union[list, np.ndarray] = None,
218219
cv_shuffle: bool = False,
219220
warm_start: bool = False,
220221
random_state: np.random.Generator = np.random.default_rng(),
@@ -240,7 +241,13 @@ def fit_iter(self, X, y,
240241
n_folds: int, default is None
241242
When n_folds is not None, use K-fold cross-validation when
242243
n_folds > 2. Or, use leave-one-out cross-validation when
243-
n_folds = -1.
244+
n_folds = -1. For Group K-fold cross-validation, functions as
245+
`n_splits`.
246+
247+
kfolds_group: list or ndarray, default is None
248+
When kfolds_group is not None, use Group K-fold cross-validation
249+
with the specified groups. The length of kfolds_group must be
250+
equal to the number of samples in X.
244251
245252
cv_shuffle: bool, default is False
246253
Whether to perform sample shuffling before splitting the
@@ -277,6 +284,7 @@ def fit_iter(self, X, y,
277284
EX_list=EX_list,
278285
valid_size=valid_size,
279286
n_folds=n_folds,
287+
kfolds_group=kfolds_group,
280288
shuffle=cv_shuffle,
281289
random_state=random_state,
282290
use_partial_fit=self.use_partial_fit,
@@ -398,6 +406,7 @@ def fit(self, X, y,
398406
EX_list: typing.Union[list, tuple] = None,
399407
valid_size: float = .2,
400408
n_folds: int = None,
409+
kfolds_group: typing.Union[list, np.ndarray] = None,
401410
cv_shuffle: bool = False,
402411
warm_start: bool = False,
403412
random_state: np.random.Generator = np.random.default_rng()
@@ -424,7 +433,13 @@ def fit(self, X, y,
424433
n_folds: int, default is None
425434
When n_folds is not None, use K-fold cross-validation when
426435
n_folds > 2. Or, use leave-one-out cross-validation when
427-
n_folds = -1.
436+
n_folds = -1. For Group K-fold cross-validation, functions as
437+
`n_splits`.
438+
439+
kfolds_group: list or ndarray, default is None
440+
When kfolds_group is not None, use Group K-fold cross-validation
441+
with the specified groups. The length of group_kfolds must be
442+
equal to the number of samples in X.
428443
429444
cv_shuffle: bool, default is False
430445
Whether to perform sample shuffling before splitting the
@@ -450,6 +465,7 @@ def fit(self, X, y,
450465
EX_list=EX_list,
451466
valid_size=valid_size,
452467
n_folds=n_folds,
468+
kfolds_group=kfolds_group,
453469
cv_shuffle=cv_shuffle,
454470
warm_start=warm_start,
455471
random_state=random_state)

tests/test_estimator/test_estimator.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,5 +246,38 @@ def test_crossvalidation(self):
246246
cls.fit(X, y, cv_shuffle=True, n_folds=5)
247247

248248

249+
class TestGroupCrossValidation(unittest.TestCase):
250+
"""
251+
Class for testing estimator with group cross validation
252+
"""
253+
254+
def setUp(self):
255+
"""
256+
Setup the random seed
257+
"""
258+
np.random.seed(123)
259+
260+
@RetryOnTrialsException
261+
def test_crossvalidation(self):
262+
"""
263+
Demonstrate performing a group k-fold CV using the fit() method.
264+
"""
265+
# Generate some random data
266+
X = np.hstack([
267+
np.vstack([
268+
np.random.normal(0, 1, size=(1000, 10)),
269+
np.random.normal(1, 1, size=(1000, 10)),
270+
]),
271+
np.random.normal(0, 1, size=(2000, 10)),
272+
])
273+
y = np.zeros(2000)
274+
y[:1000] = 1
275+
276+
# Try to fit a model
277+
cls = HyperoptEstimator(classifier=sgd_classifier("sgd", loss="log"), preprocessing=[])
278+
cls.fit(X, y, cv_shuffle=True, n_folds=5,
279+
kfolds_group=np.array([0]*500 + [1]*500 + [2]*500 + [3]*500)) # noqa: E226
280+
281+
249282
if __name__ == '__main__':
250283
unittest.main()

0 commit comments

Comments
 (0)