-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathdemo.py
118 lines (104 loc) · 4.39 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Clustering example using LFW data:
import os
import pandas as pd
from matplotlib import pyplot as plt
import argparse
import json
import scipy.io as sio
from clustering import cluster
from evaluation import calculate_pairwise_pr
def plot_histogram(lfw_dir):
    """
    Plot the distribution of cluster (folder) sizes in LFW.

    Every sub-directory found under ``lfw_dir`` by ``os.walk`` (at any depth)
    is counted as one person, and its size is the number of entries
    ``os.listdir`` reports for it. The function prints the number of people
    and how many folders hold exactly one or two entries, then shows a
    histogram of folder sizes. If no sub-directories are found, nothing is
    plotted.

    Input:
        lfw_dir: str
            Root directory of the LFW images (presumably one folder per
            person -- confirm against your copy of the dataset).
    """
    filecount_dict = {}
    for root, dirs, _ in os.walk(lfw_dir):
        for dirname in dirs:
            n_photos = len(os.listdir(os.path.join(root, dirname)))
            filecount_dict[dirname] = n_photos
    print("No of unique people: {}".format(len(filecount_dict)))
    # list(...) keeps this working on pandas versions that reject dict views.
    df = pd.DataFrame(list(filecount_dict.items()), columns=['Name', 'Count'])
    print("Singletons : {}\nTwo :{}\n".format((df['Count'] == 1).sum(),
                                              (df['Count'] == 2).sum()))
    if df.empty:
        # max() of an empty column would raise; there is nothing to plot.
        return
    # plt.hist needs at least one bin, even if every folder is empty.
    plt.hist(df['Count'], bins=max(1, int(df['Count'].max())))
    plt.title('Cluster Sizes')
    plt.xlabel('No of images in folder')
    plt.ylabel('No of folders')
    plt.show()
def approximate_rank_order_clustering(vectors, n_neighbors=200, thresh=None):
    """
    Cluster the input vectors with ``clustering.cluster``.

    Input:
        vectors: array-like
            Feature vectors to cluster, passed straight to ``cluster``
            (presumably one row per face -- confirm against ``clustering``).
        n_neighbors: int
            Neighbour count forwarded to ``cluster``. Defaults to 200, the
            value this demo always used.
        thresh: iterable of float or None
            Thresholds forwarded to ``cluster`` as a list. ``None`` means
            ``[1.1]``, the demo's original single threshold.
    Output:
        Whatever ``cluster`` returns. The demo script indexes it as a list
        with one entry per threshold, each a dict with ``"clusters"`` and
        ``"threshold"`` keys.
    """
    # None instead of a list default: a mutable default would be shared
    # across calls.
    if thresh is None:
        thresh = [1.1]
    return cluster(vectors, n_neighbors=n_neighbors, thresh=list(thresh))
def evaluate_clusters(clusters, labels_lookup):
    """
    Compute, print and return pairwise clustering quality for ``clusters``.

    Input:
        clusters: iterable of clusters
            Each cluster is a collection of integer row indices, each of which
            identifies one image in the LFW dataset.
        labels_lookup: dict
            Maps row index -> identity label (e.g. the int labels produced by
            ``create_labels_lookup``).
    Output:
        f1_score: float
            Harmonic mean of the pairwise precision and recall returned by
            ``calculate_pairwise_pr``, or 0.0 when both are zero. Precision
            and recall are printed, not returned. As described by the author:
              - pairwise precision: fraction of pairs of samples within a
                cluster that belong to one identity;
              - pairwise recall: fraction of same-identity pairs in the
                dataset that are placed in the same cluster (confirm the
                exact definition in ``calculate_pairwise_pr``).
    """
    precision, recall = calculate_pairwise_pr(clusters, labels_lookup)
    denominator = precision + recall
    # Both metrics at 0 would make the harmonic mean 0/0; report 0 instead.
    f1_score = 2 * precision * recall / denominator if denominator else 0.0
    print("Precision : {}\nRecall : {}\nf1_score : {}".format(precision,
                                                              recall,
                                                              f1_score
                                                              ))
    print("---------------------------------------------------------")
    return f1_score
def create_labels_lookup(labels):
    """
    Map every row number of ``labels`` to its integer label.

    Input:
        labels: sequence
            Entry ``i`` describes row ``i``. The label value is read from
            ``labels[i][0]`` and converted with ``int``.
    Output:
        dict mapping row index (int) -> label (int).
    """
    return {row: int(entry[0][:]) for row, entry in enumerate(labels)}
def main():
    """
    Command-line entry point for the clustering demo.

    The demo proceeds in three steps:
      1. Load ``features`` and ``labels_original`` from the ``.mat`` file
         given by ``--vector_file``.
      2. Cluster the features, then write the clusters found at the first
         threshold to ``data/clusters.json`` as
         ``{cluster_index: [row indices]}``.
      3. Print pairwise precision/recall/F1 for every threshold.
    """
    parser = argparse.ArgumentParser('Approximate Rank Order Clustering Demo')
    parser.add_argument('--lfw_path', required=True,
                        help='Enter the directory where LFW images are saved.')
    parser.add_argument('-v', '--vector_file', required=False,
                        help="Path to where the vectors to be clustered are saved.")
    args = vars(parser.parse_args())
    # Uncomment to inspect the distribution of folder sizes in LFW.
    # plot_histogram(args['lfw_path'])
    if not args['vector_file']:
        # Nothing to cluster without a vector file; exit with a usage error.
        parser.error("--vector_file is required to run the clustering demo.")

    mat = sio.loadmat(args['vector_file'])
    vectors = mat['features']
    labels = mat['labels_original'][0]

    clusters_thresholds = approximate_rank_order_clustering(vectors)

    # Persist the clusters at the first threshold. The loop variable is not
    # called `cluster`, so it cannot shadow the imported clustering function.
    clusters_at_th = clusters_thresholds[0]
    clusters_to_be_saved = {
        i: [int(x) for x in members]
        for i, members in enumerate(clusters_at_th["clusters"])
    }
    out_path = "data/clusters.json"
    out_dir = os.path.dirname(out_path)
    if out_dir and not os.path.isdir(out_dir):
        # open() below fails if the output directory does not exist yet.
        os.makedirs(out_dir)
    with open(out_path, "w") as out_file:
        json.dump(clusters_to_be_saved, out_file)

    labels_lookup = create_labels_lookup(labels)
    for clusters in clusters_thresholds:
        print("No of clusters: {}".format(len(clusters['clusters'])))
        print("Threshold : {}".format(clusters['threshold']))
        evaluate_clusters(clusters['clusters'], labels_lookup)


if __name__ == '__main__':
    main()