@@ -215,6 +215,7 @@ def fit_iter(self, X, y,
215
215
EX_list : typing .Union [list , tuple ] = None ,
216
216
valid_size : float = .2 ,
217
217
n_folds : int = None ,
218
+ kfolds_group : typing .Union [list , np .ndarray ] = None ,
218
219
cv_shuffle : bool = False ,
219
220
warm_start : bool = False ,
220
221
random_state : np .random .Generator = np .random .default_rng (),
@@ -240,7 +241,13 @@ def fit_iter(self, X, y,
240
241
n_folds: int, default is None
241
242
When n_folds is not None, use K-fold cross-validation when
242
243
n_folds > 2. Or, use leave-one-out cross-validation when
243
- n_folds = -1.
244
+ n_folds = -1. For Group K-fold cross-validation, functions as
245
+ `n_splits`.
246
+
247
+ kfolds_group: list or ndarray, default is None
248
+ When kfolds_group is not None, use Group K-fold cross-validation
249
+ with the specified groups. The length of kfolds_group must be
250
+ equal to the number of samples in X.
244
251
245
252
cv_shuffle: bool, default is False
246
253
Whether to perform sample shuffling before splitting the
@@ -277,6 +284,7 @@ def fit_iter(self, X, y,
277
284
EX_list = EX_list ,
278
285
valid_size = valid_size ,
279
286
n_folds = n_folds ,
287
+ kfolds_group = kfolds_group ,
280
288
shuffle = cv_shuffle ,
281
289
random_state = random_state ,
282
290
use_partial_fit = self .use_partial_fit ,
@@ -398,6 +406,7 @@ def fit(self, X, y,
398
406
EX_list : typing .Union [list , tuple ] = None ,
399
407
valid_size : float = .2 ,
400
408
n_folds : int = None ,
409
+ kfolds_group : typing .Union [list , np .ndarray ] = None ,
401
410
cv_shuffle : bool = False ,
402
411
warm_start : bool = False ,
403
412
random_state : np .random .Generator = np .random .default_rng ()
@@ -424,7 +433,13 @@ def fit(self, X, y,
424
433
n_folds: int, default is None
425
434
When n_folds is not None, use K-fold cross-validation when
426
435
n_folds > 2. Or, use leave-one-out cross-validation when
427
- n_folds = -1.
436
+ n_folds = -1. For Group K-fold cross-validation, functions as
437
+ `n_splits`.
438
+
439
+ kfolds_group: list or ndarray, default is None
440
+ When kfolds_group is not None, use Group K-fold cross-validation
441
+ with the specified groups. The length of group_kfolds must be
442
+ equal to the number of samples in X.
428
443
429
444
cv_shuffle: bool, default is False
430
445
Whether to perform sample shuffling before splitting the
@@ -450,6 +465,7 @@ def fit(self, X, y,
450
465
EX_list = EX_list ,
451
466
valid_size = valid_size ,
452
467
n_folds = n_folds ,
468
+ kfolds_group = kfolds_group ,
453
469
cv_shuffle = cv_shuffle ,
454
470
warm_start = warm_start ,
455
471
random_state = random_state )
0 commit comments