@@ -234,27 +234,22 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
234234 self ._check_resampling_strategy_args ()
235235
236236 labels_to_stratify = self .train_tensors [- 1 ] if self .is_stratify else None
237+ kwargs = {}
238+ kwargs .update (
239+ random_state = self .random_state ,
240+ shuffle = self .shuffle_split ,
241+ indices = self ._get_indices (),
242+ labels_to_stratify = labels_to_stratify
243+ )
237244
238245 if isinstance (self .resampling_strategy , HoldoutValTypes ):
239246 val_share = self .resampling_strategy_args .get ('val_share' , None )
247+ return self .resampling_strategy (val_share = val_share , ** kwargs )
240248
241- return self .resampling_strategy (
242- random_state = self .random_state ,
243- val_share = val_share ,
244- shuffle = self .shuffle_split ,
245- indices = self ._get_indices (),
246- labels_to_stratify = labels_to_stratify
247- )
248249 elif isinstance (self .resampling_strategy , CrossValTypes ):
249250 num_splits = self .resampling_strategy_args .get ('num_splits' , None )
251+ return self .resampling_strategy (num_splits = num_splits , ** kwargs )
250252
251- return self .resampling_strategy (
252- random_state = self .random_state ,
253- num_splits = num_splits ,
254- shuffle = self .shuffle_split ,
255- indices = self ._get_indices (),
256- labels_to_stratify = labels_to_stratify
257- )
258253 else :
259254 raise ValueError (f"Unsupported resampling strategy={ self .resampling_strategy } " )
260255
0 commit comments