@@ -111,29 +111,36 @@ def __init__(self, classifiers, meta_classifier,
111111 self .stratify = stratify
112112 self .shuffle = shuffle
113113
114- def fit (self , X , y , groups = None ):
114+ def fit (self , X , y , groups = None , ** fit_params ):
115115 """ Fit ensemble classifiers and the meta-classifier.
116116
117117 Parameters
118118 ----------
119119 X : numpy array, shape = [n_samples, n_features]
120120 Training vectors, where n_samples is the number of samples and
121121 n_features is the number of features.
122-
123122 y : numpy array, shape = [n_samples]
124123 Target values.
125-
126124 groups : numpy array/None, shape = [n_samples]
127125 The group that each sample belongs to. This is used by specific
128126 folding strategies such as GroupKFold()
127+ fit_params : dict, optional
128+ Parameters to pass to the fit methods of `classifiers` and
129+ `meta_classifier`. Note that only fit parameters for `classifiers`
130+ that are the same for each cross-validation split are supported
131+ (e.g. `sample_weight` is not currently supported).
129132
130133 Returns
131134 -------
132135 self : object
133136
134137 """
135138 self .clfs_ = [clone (clf ) for clf in self .classifiers ]
139+ self .named_clfs_ = {key : value for key , value in
140+ _name_estimators (self .clfs_ )}
136141 self .meta_clf_ = clone (self .meta_classifier )
142+ self .named_meta_clf_ = {'meta-%s' % key : value for key , value in
143+ _name_estimators ([self .meta_clf_ ])}
137144 if self .verbose > 0 :
138145 print ("Fitting %d classifiers..." % (len (self .classifiers )))
139146
@@ -144,8 +151,23 @@ def fit(self, X, y, groups=None):
144151 final_cv .shuffle = self .shuffle
145152 skf = list (final_cv .split (X , y , groups ))
146153
154+ # Get fit_params for each classifier in self.named_clfs_
155+ named_clfs_fit_params = {}
156+ for name , clf in six .iteritems (self .named_clfs_ ):
157+ clf_fit_params = {}
158+ for key , value in six .iteritems (fit_params ):
159+ if name in key and 'meta-' not in key :
160+ clf_fit_params [key .replace (name + '__' , '' )] = value
161+ named_clfs_fit_params [name ] = clf_fit_params
162+ # Get fit_params for self.named_meta_clf_
163+ meta_fit_params = {}
164+ meta_clf_name = list (self .named_meta_clf_ .keys ())[0 ]
165+ for key , value in six .iteritems (fit_params ):
166+ if meta_clf_name in key and 'meta-' in meta_clf_name :
167+ meta_fit_params [key .replace (meta_clf_name + '__' , '' )] = value
168+
147169 all_model_predictions = np .array ([]).reshape (len (y ), 0 )
148- for model in self .clfs_ :
170+ for name , model in six . iteritems ( self .named_clfs_ ) :
149171
150172 if self .verbose > 0 :
151173 i = self .clfs_ .index (model ) + 1
@@ -172,7 +194,8 @@ def fit(self, X, y, groups=None):
172194 ((num + 1 ), final_cv .get_n_splits ()))
173195
174196 try :
175- model .fit (X [train_index ], y [train_index ])
197+ model .fit (X [train_index ], y [train_index ],
198+ ** named_clfs_fit_params [name ])
176199 except TypeError as e :
177200 raise TypeError (str (e ) + '\n Please check that X and y'
178201 'are NumPy arrays. If X and y are lists'
@@ -215,16 +238,17 @@ def fit(self, X, y, groups=None):
215238 X [test_index ]))
216239
217240 # Fit the base models correctly this time using ALL the training set
218- for model in self .clfs_ :
219- model .fit (X , y )
241+ for name , model in six . iteritems ( self .named_clfs_ ) :
242+ model .fit (X , y , ** named_clfs_fit_params [ name ] )
220243
221244 # Fit the secondary model
222245 if not self .use_features_in_secondary :
223- self .meta_clf_ .fit (all_model_predictions , reordered_labels )
246+ self .meta_clf_ .fit (all_model_predictions , reordered_labels ,
247+ ** meta_fit_params )
224248 else :
225249 self .meta_clf_ .fit (np .hstack ((reordered_features ,
226250 all_model_predictions )),
227- reordered_labels )
251+ reordered_labels , ** meta_fit_params )
228252
229253 return self
230254
0 commit comments