Skip to content

Commit e457746

Browse files
authored
updating sklearn cross validation syntax to newer version (hyperopt#78)
* updating sklearn cross validation syntax to newer version * fixing parameter name in newer version of KFold
1 parent e0c277c commit e457746

File tree

1 file changed

+53
-15
lines changed

1 file changed

+53
-15
lines changed

hpsklearn/estimator.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@
66
from multiprocessing import Process, Pipe
77
import time
88
from sklearn.base import BaseEstimator
9-
from sklearn.cross_validation import KFold, StratifiedKFold, LeaveOneOut, \
10-
ShuffleSplit, StratifiedShuffleSplit, \
11-
PredefinedSplit
129
from sklearn.metrics import accuracy_score, r2_score
1310
from sklearn.decomposition import PCA
11+
try:
12+
from sklearn.model_selection import KFold, StratifiedKFold, LeaveOneOut, \
13+
ShuffleSplit, StratifiedShuffleSplit, \
14+
PredefinedSplit
15+
except ImportError:
16+
# sklearn.cross_validation is deprecated in version 0.18 of sklearn
17+
from sklearn.cross_validation import KFold, StratifiedKFold, LeaveOneOut, \
18+
ShuffleSplit, StratifiedShuffleSplit, \
19+
PredefinedSplit
1420

1521
# For backwards compatibility with older versions of hyperopt.fmin
1622
import inspect
@@ -215,37 +221,69 @@ def _cost_fn(argd, X, y, EX_list, valid_size, n_folds, shuffle, random_state,
215221
if n_folds is not None:
216222
if n_folds == -1:
217223
info('Will use leave-one-out CV')
218-
cv_iter = LeaveOneOut(len(y))
224+
try:
225+
cv_iter = LeaveOneOut().split(X)
226+
except TypeError:
227+
# Older syntax before sklearn version 0.18
228+
cv_iter = LeaveOneOut(len(y))
219229
elif is_classif:
220230
info('Will use stratified K-fold CV with K:', n_folds,
221231
'and Shuffle:', shuffle)
222-
cv_iter = StratifiedKFold(y, n_folds=n_folds,
223-
shuffle=shuffle,
224-
random_state=random_state)
232+
try:
233+
cv_iter = StratifiedKFold(n_splits=n_folds,
234+
shuffle=shuffle,
235+
random_state=random_state
236+
).split(X, y)
237+
except TypeError:
238+
# Older syntax before sklearn version 0.18
239+
cv_iter = StratifiedKFold(y, n_folds=n_folds,
240+
shuffle=shuffle,
241+
random_state=random_state)
225242
else:
226243
info('Will use K-fold CV with K:', n_folds,
227244
'and Shuffle:', shuffle)
228-
cv_iter = KFold(len(y), n_folds=n_folds,
229-
shuffle=shuffle,
230-
random_state=random_state)
245+
try:
246+
cv_iter = KFold(n_splits=n_folds,
247+
shuffle=shuffle,
248+
random_state=random_state).split(X)
249+
except TypeError:
250+
# Older syntax before sklearn version 0.18
251+
cv_iter = KFold(len(y), n_folds=n_folds,
252+
shuffle=shuffle,
253+
random_state=random_state)
231254
else:
232255
if not shuffle: # always choose the last samples.
233256
info('Will use the last', valid_size,
234257
'portion of samples for validation')
235258
n_train = int(len(y) * (1 - valid_size))
236259
valid_fold = np.ones(len(y), dtype=np.int)
237260
valid_fold[:n_train] = -1 # "-1" indicates train fold.
238-
cv_iter = PredefinedSplit(valid_fold)
261+
try:
262+
cv_iter = PredefinedSplit(valid_fold).split()
263+
except TypeError:
264+
# Older syntax before sklearn version 0.18
265+
cv_iter = PredefinedSplit(valid_fold)
239266
elif is_classif:
240267
info('Will use stratified shuffle-and-split with validation \
241268
portion:', valid_size)
242-
cv_iter = StratifiedShuffleSplit(y, 1, test_size=valid_size,
243-
random_state=random_state)
269+
try:
270+
cv_iter = StratifiedShuffleSplit(1, test_size=valid_size,
271+
random_state=random_state
272+
).split(X, y)
273+
except TypeError:
274+
# Older syntax before sklearn version 0.18
275+
cv_iter = StratifiedShuffleSplit(y, 1, test_size=valid_size,
276+
random_state=random_state)
244277
else:
245278
info('Will use shuffle-and-split with validation portion:',
246279
valid_size)
247-
cv_iter = ShuffleSplit(len(y), 1, test_size=valid_size,
248-
random_state=random_state)
280+
try:
281+
cv_iter = ShuffleSplit(n_splits=1, test_size=valid_size,
282+
random_state=random_state).split(X)
283+
except TypeError:
284+
# Older syntax before sklearn version 0.18
285+
cv_iter = ShuffleSplit(len(y), 1, test_size=valid_size,
286+
random_state=random_state)
249287

250288
# Use the above iterator for cross-validation prediction.
251289
cv_y_pool = np.array([])

0 commit comments

Comments
 (0)