
Commit 03a9308

Custom Mean Shift
1 parent 1b335ae commit 03a9308

File tree

1 file changed (+63 −69 lines changed)


Clustering/Mean-Shift/Custom_Mean-Shift.py

+63-69
@@ -3,94 +3,88 @@
 style.use('ggplot')
 import numpy as np
 
-class K_Means:
-    def __init__(self, k=2, tol=0.001, max_iter=300):
-        self.k = k
-        self.tol = tol
-        self.max_iter = max_iter
+class Mean_Shift:
+    def __init__(self, radius=4):
+        self.radius = radius
 
-    def fit(self,data):
+    def fit(self, data):
+        centroids = {}
+        ## make an id for each value
+        for i in range(len(data)):
+            centroids[i] = data[i]
+        #print(centroids)
 
-        self.centroids = {}  # self.centroids holds the cluster center points
+
+        # Make all datapoints centroids.
+        # Take the mean of all featuresets within a centroid's radius, setting this mean as the new centroid.
+        # Repeat step #2 until convergence.
+        while True:
+            new_centroids = []
+            for i in centroids:
+                in_bandwidth = []
+                centroid = centroids[i]
+                for featureset in data:
+                    ## if the distance between the featureset and the centroid is less than the radius,
+                    ## add the featureset to the bandwidth list
+                    if np.linalg.norm(featureset-centroid) < self.radius:
+                        in_bandwidth.append(featureset)
 
-        # add the first k points to the dictionary ... just the starting k centroids
-        for i in range(self.k):
-            self.centroids[i] = data[i]
-
+                # take the average of the values in the bandwidth list
+                new_centroid = np.average(in_bandwidth, axis=0)
+                new_centroids.append(tuple(new_centroid))
 
-        for i in range(self.max_iter):
-            self.classifications = {}
-            ## create the classification dictionary with an empty list for each group from 0 to k
-            for i in range(self.k):
-                self.classifications[i] = []
-
-            for featureset in data:
-                ## get the distances between the featureset and each center point
-                distances = [np.linalg.norm(featureset-self.centroids[centroid]) for centroid in self.centroids]
-                ## the index of the minimum distance decides the group: 0, 1, ... up to k
-                classification = distances.index(min(distances))
-                ## add the featureset to its group
-                # e.g. with k = 2 there are two groups,
-                # so the point joins group 0 or group 1
-                self.classifications[classification].append(featureset)
-
-            ## remember the centroids before taking the averages
-            prev_centroids = dict(self.centroids)
-            for classification in self.classifications:
-                ## average the points of each group and store the result in the centroids dictionary
-                self.centroids[classification] = np.average(self.classifications[classification], axis=0)
+            uniques = sorted(list(set(new_centroids)))  # sort and remove duplicate values
+            print(new_centroids)
+            prev_centroids = dict(centroids)
 
+            centroids = {}
+            for i in range(len(uniques)):
+                centroids[i] = np.array(uniques[i])  ## add the uniques to centroids
 
-            ## compare original_centroid and current_centroid
             optimized = True
-            for c in self.centroids:
-                original_centroid = prev_centroids[c]
-                current_centroid = self.centroids[c]
-                ## if every current_centroid is within the required tolerance of its original_centroid, this is good;
-                ## otherwise optimized = False and the outer loop for i in range(self.max_iter): runs again
-                if np.sum((current_centroid-original_centroid)/original_centroid*100.0) > self.tol:
-                    #print(c, np.sum((current_centroid-original_centroid)/original_centroid*100.0))
+            '''
+            Here we note the previous centroids before we reset the "current" or "new" centroids
+            by setting them to the uniques. Then we compare the previous centroids to the new ones and measure movement.
+            If any of the centroids have moved, we don't yet have full convergence and optimization,
+            and we want to run another cycle.
+            If we are optimized, we break, and finally set the centroids attribute to the final centroids we came up with.
+            '''
+
+            for i in centroids:
+                if not np.array_equal(centroids[i], prev_centroids[i]):
                     optimized = False
-
+                if not optimized:
+                    break
+
             if optimized:
                 break
+
+        self.centroids = centroids
+
     def predict(self,data):
-        ## get the distances between the data point and each centroid
-        distances = [np.linalg.norm(data-self.centroids[centroid]) for centroid in self.centroids]
-        ## the index of the minimum distance decides the group: 0, 1, ... up to k
-        classification = distances.index(min(distances))
-        return classification
+        pass
 
 
 X = np.array([[1, 2],
               [1.5, 1.8],
               [5, 8],
               [8, 8],
               [1, 0.6],
-              [9, 11]])
+              [9, 11],
+              [8, 2],
+              [10, 2],
+              [9, 3],])
+
 colors = 10*["g","r","c","b","k"]
-clf = K_Means(k=2)
+
+clf = Mean_Shift()
 clf.fit(X)
 
-## scatter the center point of each group
-for centroid in clf.centroids:
-    plt.scatter(clf.centroids[centroid][0], clf.centroids[centroid][1],
-                marker="o", color="k", s=150, linewidths=5)
-
-## scatter the classified points of each group
-for classification in clf.classifications:
-    color = colors[classification]
-    for featureset in clf.classifications[classification]:
-        plt.scatter(featureset[0], featureset[1], marker="x", color=color, s=150, linewidths=5)
-
-
-## predict new features
-new_features = np.array([[1, 3],
-                         [8, 9],
-                         [0, 3],
-                         [5, 4],
-                         [6, 4],])
-for feature in new_features:
-    classification = clf.predict(feature)
-    plt.scatter(feature[0], feature[1], marker="*", color=colors[classification], s=150, linewidths=5)
+centroids = clf.centroids
+
+plt.scatter(X[:,0], X[:,1], s=150)
+
+for c in centroids:
+    plt.scatter(centroids[c][0], centroids[c][1], color='k', marker='*', s=150)
+
 plt.show()
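Note that the new Mean_Shift.predict is left as a stub (pass). A minimal sketch of how it could be completed, reusing the nearest-centroid rule from the removed K_Means.predict; this completion is an assumption, not part of the commit:

# Hypothetical completion of Mean_Shift.predict -- not part of this commit.
# Assigns a sample to the index of the closest learned centroid.
def predict(self, data):
    ## distance from the sample to every centroid found by fit()
    distances = [np.linalg.norm(data - self.centroids[centroid]) for centroid in self.centroids]
    ## the index of the smallest distance is the predicted cluster
    classification = distances.index(min(distances))
    return classification

With the fit above, clf.predict(np.array([8, 9])) would then return the index of the centroid nearest to [8, 9].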
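As an optional rough cross-check of the custom class (a sketch, not part of the commit), scikit-learn's MeanShift can be fit on the same X; its bandwidth plays a role similar to this class's radius, though the two stopping rules differ, so the centers need not match exactly:

# Optional sanity check against scikit-learn's MeanShift (assumes scikit-learn is installed).
from sklearn.cluster import MeanShift

ms = MeanShift(bandwidth=4)    # bandwidth here mirrors the custom radius=4
ms.fit(X)
print(ms.cluster_centers_)     # compare against clf.centroids
print(ms.labels_)              # cluster assignment for each sample in X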
