@@ -28,14 +28,15 @@ class KMeans:
28
28
Implementation detail: <https://en.wikipedia.org/wiki/K-means_clustering>
29
29
"""
30
30
31
- def __init__ (self , obs , k , initMode = 'kmeans++' , iters = 1000 , compare = 'sqeuclidean' ):
31
+ def __init__ (self , obs , k , initMode = 'kmeans++' , distance = 'sqeuclidean' , iters = 1000 ):
32
32
"""
33
33
Initializes the algorithm with the observations, the number of clusters k, the initialization method,
34
34
the distance measure and the maximum number of iterations.
35
35
The initialization method for the random cluster choice can be one of: forgy, uniform, random, plusplus
36
36
:param obs: genomic data / matrix
37
37
:param k: number of clusters
38
38
:param initMode: initialization method
39
+ :param distance: name of the distance measure used to compare observations with cluster means (default 'sqeuclidean')
39
40
:param iters: number of maximum iterations
40
41
:return:
41
42
"""
@@ -59,7 +60,7 @@ def __init__(self, obs, k, initMode='kmeans++', iters=1000, compare='sqeuclidean
59
60
# initialization method
60
61
self .__initMode = initMode
61
62
# distance measure passed to similarityMeasurement
62
- self .__compare = compare
63
+ self .__distance = distance
63
64
64
65
# ------------------------------------------------------------------------------------------------------------------
65
66
@@ -150,7 +151,7 @@ def __plusplusMethod(self):
150
151
probs .fill (maxValue )
151
152
# compute new probabilities, choose min of all distances
152
153
for j in range (0 , i ):
153
- dists = similarityMeasurement (self .__obs , self .__clusterMeans [j ], self .__compare )
154
+ dists = similarityMeasurement (self .__obs , self .__clusterMeans [j ], self .__distance )
154
155
# collect minimum distances to cluster centroids (squared only for the 'sqeuclidean' measure)
155
156
probs = np .minimum (probs , dists )
156
157
@@ -210,7 +211,7 @@ def __assignment(self):
210
211
value = self .__obs [i ]
211
212
212
213
# compute distances to each mean using the configured distance measure
213
- dists = similarityMeasurement (self .__clusterMeans , value , self .__compare )
214
+ dists = similarityMeasurement (self .__clusterMeans , value , self .__distance )
214
215
# nearest cluster
215
216
nearestID = np .argmin (dists )
216
217
@@ -347,12 +348,12 @@ def _plugin_initialize():
347
348
348
349
# ----------------------------------------------------------------------------------------------------------------------
349
350
350
- def create (data , k , initMethod ):
351
+ def create (data , k , initMethod , distance ):
351
352
"""
352
353
By convention, the module contains a factory called create that returns the extension implementation.
353
354
:return:
354
355
"""
355
- return KMeans (data , k , initMethod )
356
+ return KMeans (data , k , initMethod , distance )
356
357
357
358
########################################################################################################################
358
359
@@ -377,7 +378,7 @@ def create(data, k, initMethod):
377
378
378
379
for i in range (10 ):
379
380
s1 = timer ()
380
- kMeansPlus = KMeans (data , k , 'kmeans++' , 10 )
381
+ kMeansPlus = KMeans (data , k , 'kmeans++' , 'sqeuclidean' , 10 )
381
382
result1 = kMeansPlus .run ()
382
383
#print(result)
383
384
e1 = timer ()
0 commit comments