@@ -65,7 +65,7 @@ def __init__(self, regressors, meta_regressor, verbose=0):
6565 _name_estimators ([meta_regressor ])}
6666 self .verbose = verbose
6767
68- def fit (self , X , y ):
68+ def fit (self , X , y , ** fit_params ):
6969 """Learn weight coefficients from training data for each regressor.
7070
7171 Parameters
@@ -75,18 +75,25 @@ def fit(self, X, y):
7575 n_features is the number of features.
7676 y : array-like, shape = [n_samples] or [n_samples, n_targets]
7777 Target values.
78+ fit_params : dict, optional
79+ Parameters to pass to the fit methods of `regressors` and
80+ `meta_regressor`.
7881
7982 Returns
8083 -------
8184 self : object
8285
8386 """
8487 self .regr_ = [clone (regr ) for regr in self .regressors ]
88+ self .named_regr_ = {key : value for key , value in
89+ _name_estimators (self .regr_ )}
8590 self .meta_regr_ = clone (self .meta_regressor )
91+ self .named_meta_regr_ = {'meta-%s' % key : value for key , value in
92+ _name_estimators ([self .meta_regr_ ])}
8693 if self .verbose > 0 :
8794 print ("Fitting %d regressors..." % (len (self .regressors )))
8895
89- for regr in self .regr_ :
96+ for name , regr in six . iteritems ( self .named_regr_ ) :
9097
9198 if self .verbose > 0 :
9299 i = self .regr_ .index (regr ) + 1
@@ -100,10 +107,23 @@ def fit(self, X, y):
100107 if self .verbose > 1 :
101108 print (_name_estimators ((regr ,))[0 ][1 ])
102109
103- regr .fit (X , y )
110+ # Extract fit_params for regr
111+ regr_fit_params = {}
112+ for key , value in six .iteritems (fit_params ):
113+ if name in key and 'meta-' not in key :
114+ regr_fit_params [key .replace (name + '__' , '' )] = value
115+
116+ regr .fit (X , y , ** regr_fit_params )
104117
105118 meta_features = self ._predict_meta_features (X )
106- self .meta_regr_ .fit (meta_features , y )
119+ # Extract fit_params for meta_regr_
120+ meta_fit_params = {}
121+ meta_regr_name = list (self .named_meta_regr_ .keys ())[0 ]
122+ for key , value in six .iteritems (fit_params ):
123+ if meta_regr_name in key and 'meta-' in meta_regr_name :
124+ meta_fit_params [key .replace (meta_regr_name + '__' , '' )] = value
125+ self .meta_regr_ .fit (meta_features , y , ** meta_fit_params )
126+
107127 return self
108128
109129 @property
0 commit comments