-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathvoting.py
30 lines (25 loc) · 847 Bytes
/
voting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from ..base import BaseClassifier,BaseRegressor
import numpy as np
class BaseVoting:
    """Shared plumbing for ensembles that combine several models' predictions.

    Each element of ``models`` is expected to provide ``fit(X, y)`` and
    ``predict(X)``; subclasses decide how the per-model predictions are
    combined into one.
    """

    def __init__(self, models):
        """Store the member models.

        Args:
            models: sequence of estimators. It is iterated on every ``fit``
                and ``predict_all`` call, so it should be re-iterable
                (e.g. a list, not a generator).
        """
        # NOTE(review): super().__init__() is not called; the subclasses in
        # this module also inherit BaseClassifier/BaseRegressor, whose
        # __init__ (if any) is therefore skipped — confirm that is intended.
        self.models = models

    def fit(self, X, y):
        """Fit every member model on the same training data.

        Returns:
            ``self``, so calls can be chained (``ens.fit(X, y).predict(X)``).
        """
        for model in self.models:
            model.fit(X, y)
        return self

    def predict_all(self, X):
        """Return a list with one ``predict(X)`` result per model, in model order."""
        return [model.predict(X) for model in self.models]
class VotingClassifier(BaseVoting, BaseClassifier):
    """Hard-voting classifier: each sample gets its most frequently predicted label."""

    def predict(self, X):
        """Predict by majority vote over the member models.

        Args:
            X: input accepted by every member model's ``predict``.

        Returns:
            1-D ``np.ndarray`` with one voted label per sample. Ties go to the
            smallest label, because ``np.unique`` returns labels sorted and
            ``argmax`` picks the first maximal count.

        Raises:
            ValueError: if the ensemble has no models.
        """
        # Row m holds model m's predictions; column i holds every model's
        # prediction for sample i.
        res = np.array(self.predict_all(X))
        if res.shape[0] == 0:
            raise ValueError("VotingClassifier has no models to vote with")
        y_hat = []
        # Take the sample count from the predictions themselves, so X does
        # not need a ``.shape`` attribute (e.g. plain lists are accepted).
        for i in range(res.shape[1]):
            labels, count = np.unique(res[:, i], return_counts=True)
            y_hat.append(labels[count.argmax()])
        return np.array(y_hat)
class VotingRegressor(BaseVoting, BaseRegressor):
    """Averaging regressor: each sample gets the mean of the models' predictions."""

    def predict(self, X):
        """Predict by averaging the member models' outputs element-wise.

        Args:
            X: input accepted by every member model's ``predict``.

        Returns:
            ``np.ndarray``: mean over models (axis 0) of the stacked predictions.

        Raises:
            ValueError: if the ensemble has no models (``np.mean`` would
                otherwise silently return ``nan`` with a RuntimeWarning).
        """
        res = self.predict_all(X)
        if len(res) == 0:
            raise ValueError("VotingRegressor has no models to average")
        return np.mean(res, axis=0)