-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathknn.py
37 lines (33 loc) · 1.08 KB
/
knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np
class KNN:
    """k-nearest-neighbours classifier that predicts by majority vote.

    Votes are counted with ``np.bincount``, so class labels must be
    non-negative integers (labels are cast with ``astype(int)`` before
    counting). Ties are broken in favour of the smallest label.
    """

    def fit(self, X, y, n_neighbors, distance):
        '''
        Store the training set and the neighbour-search settings.

        Parameters
        ----------
        X : shape (n_samples, n_features)
            Training data
        y : shape (n_samples,)
            Target class labels (non-negative integers)
        n_neighbors : Number of neighbors (must be >= 1). If it exceeds
            n_samples, every training sample takes part in the vote.
        distance : Distance algorithm, see also distance.py. Called as
            ``distance(x, X)`` with a single sample ``x`` and the stored
            training data ``X``; expected to return one distance per
            training sample.

        Raises
        ------
        ValueError
            If n_neighbors < 1, X is empty, or X and y differ in length.
        '''
        # Convert up front so that list inputs support fancy indexing
        # (self.__y[nearest_items]) later on.
        X = np.asarray(X)
        y = np.asarray(y)
        if n_neighbors < 1:
            raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}")
        if len(X) == 0:
            raise ValueError("X must contain at least one training sample")
        if len(X) != len(y):
            raise ValueError(
                f"X and y have inconsistent lengths: {len(X)} != {len(y)}"
            )
        self.__X = X
        self.__y = y
        self.__n_neighbors = n_neighbors
        self.__distance = distance

    def __predict(self, x):
        # Predict the label of one sample by majority vote of its neighbours.
        distances = self.__distance(x, self.__X)
        # Clamp k so argpartition's kth index stays in bounds when
        # n_neighbors is larger than the training set.
        k = min(self.__n_neighbors, len(distances))
        # argpartition places the k smallest distances first without a full
        # sort; their relative order does not matter for voting.
        nearest_items = np.argpartition(distances, k - 1)[:k]
        # bincount index == label; argmax returns the smallest label on ties.
        return np.argmax(np.bincount(self.__y[nearest_items].astype(int)))

    def predict(self, X):
        '''
        Predict a class label for each row of X.

        Parameters
        ----------
        X : shape (n_samples, n_features)
            Predicting data
        Returns
        -------
        y : shape (n_samples,)
            Predicted class label per sample (an empty array when X has
            no rows).
        '''
        X = np.asarray(X)
        if X.ndim >= 1 and X.shape[0] == 0:
            # np.apply_along_axis raises when the iteration axis has length 0.
            return np.empty(0, dtype=np.intp)
        return np.apply_along_axis(self.__predict, 1, X)