-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathresnet152_test.py
117 lines (94 loc) · 3.59 KB
/
resnet152_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import urllib  # NOTE(review): not referenced in the code visible here
import io  # NOTE(review): not referenced in the code visible here
import skimage.transform
import matplotlib
# Select the non-interactive Agg backend before pyplot is imported, so the
# script does not need a display.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Default figure size; nothing in the visible code draws a figure.
plt.rcParams['figure.figsize'] = 8, 6
import argparse
import time
import pickle  # NOTE(review): not referenced in the code visible here
import numpy as np
import theano
import lasagne
# Command-line interface.
# Arguments are parsed only when this file runs as a script. Importing the
# module (for example to reuse get_feature_extractor) then does not read, or
# exit on, the importer's sys.argv; `args` is None in that case.
parser = argparse.ArgumentParser(description="Getting top 5 classes of images")
add_arg = parser.add_argument
# Both options are mandatory: every code path below uses them. Failing here
# gives a clear usage message instead of a crash deep inside model loading.
add_arg("-i", "--input_image", required=True, help="Input image")
add_arg("-m", "--model_file", required=True, help="Model pickle file")
args = parser.parse_args() if __name__ == "__main__" else None
import resnet50
def prep_image(fname, mean_values):
    """Load an image file and turn it into a single-image network input batch.

    Steps:
    - scale the image so its shorter side is 256 pixels;
    - centre-crop it to 224x224;
    - reorder it from HWC to CHW and reverse the channel order;
    - subtract ``mean_values``.

    Args:
        fname: path of the image; its extension is passed to ``plt.imread``
            as the format hint.
        mean_values: value(s) subtracted from the (3, 224, 224) channel-first
            image; must broadcast against that shape.

    Returns:
        ``(rawim, batch)``:
        - ``rawim`` is the cropped image as a uint8 (224, 224, 3) array in
          the channel order it was loaded in.
        - ``batch`` is the preprocessed float32 array of shape
          (1, 3, 224, 224).
    """
    t0 = time.time()
    ext = fname.split('.')[-1]
    im = plt.imread(fname, ext)
    # matplotlib returns PNG images as floats in [0, 1]. The uint8 copy and
    # the mean subtraction below both expect the 0-255 range.
    if im.dtype.kind == 'f':
        im = im * 255.0
    # Normalise to 3 channels: replicate greyscale images, and drop the
    # alpha channel of RGBA images.
    if im.ndim == 2:
        im = np.stack([im] * 3, axis=-1)
    im = im[:, :, :3]
    h, w = im.shape[:2]
    # Scale the shorter side to 256. Floor division keeps the output shape
    # integral on Python 3 and matches the original Python 2 result.
    if h < w:
        im = skimage.transform.resize(im, (256, w * 256 // h), preserve_range=True)
    else:
        im = skimage.transform.resize(im, (h * 256 // w, 256), preserve_range=True)
    h, w = im.shape[:2]
    im = im[h // 2 - 112:h // 2 + 112, w // 2 - 112:w // 2 + 112]
    rawim = np.copy(im).astype('uint8')
    # HWC -> CHW, then reverse the channel axis (RGB as loaded -> BGR).
    # NOTE(review): presumably mean_values is in BGR order too; confirm
    # against resnet50.load_model.
    im = np.transpose(im, (2, 0, 1))[::-1, :, :]
    im = im - mean_values
    t1 = time.time()
    print("Time taken in preparing the image : {}".format(t1 - t0))
    return rawim, im[np.newaxis].astype('float32')
def get_net_fun(pkl_model):
    """Load the model and compile a class-probability function for it.

    Args:
        pkl_model: path of the model pickle, forwarded to
            ``resnet50.load_model``.

    Returns:
        ``(get_class_prob, print_top5)``:
        - ``get_class_prob`` is the compiled Theano function mapping an input
          batch to the deterministic output of the network's ``'prob'`` layer.
        - ``print_top5`` preprocesses an image file with ``prep_image`` and
          prints its most probable labels.
    """
    net, mean_img, synset_words = resnet50.load_model(pkl_model)
    get_class_prob = theano.function(
        [net['input'].input_var],
        lasagne.layers.get_output(net['prob'], deterministic=True))

    def print_top5(im_path, k=5):
        """Print the ``k`` (default 5) highest-probability labels for an image."""
        _, im = prep_image(im_path, mean_img)
        prob = get_class_prob(im)[0]
        ranked = sorted(zip(synset_words, prob), key=lambda t: t[1], reverse=True)
        for label, p in ranked[:k]:
            # Print a single pre-formatted string so the output is identical
            # on Python 2 and 3. It reproduces `print ' ', c, p` exactly, using
            # str() of each value; that statement is a SyntaxError on Python 3.
            print('  {!s} {!s}'.format(label, p))

    return get_class_prob, print_top5
def get_feature_extractor(pkl_model, layer_name):
    """Build a function that returns one layer's activations for an image file.

    Args:
        pkl_model: path of the model pickle, forwarded to
            ``resnet50.load_model``.
        layer_name: key of the layer, in the loaded network mapping, whose
            deterministic output is extracted.

    Returns:
        A callable that takes an image path, preprocesses the image with
        ``prep_image``, and returns the layer's output for it (first element
        of the batch).
    """
    network, mean_image, _ = resnet50.load_model(pkl_model)
    input_var = network['input'].input_var
    layer_expr = lasagne.layers.get_output(network[layer_name], deterministic=True)
    compute_layer = theano.function([input_var], layer_expr)

    def feature_extractor(im_path):
        batch = prep_image(im_path, mean_image)[1]
        return compute_layer(batch)[0]

    return feature_extractor
if __name__ == "__main__":
    # Top-5 classification of the input image, timed.
    print("Compiling functions...")
    get_prob, print_top5 = get_net_fun(args.model_file)
    t0 = time.time()
    print_top5(args.input_image)
    t1 = time.time()
    print("Total time taken {:.4f}".format(t1 - t0))

    # Extract and time the activations of several intermediate layers.
    # Each entry is (layer name, label shown in the progress message).
    for layer_name, label in (("conv1", "conv1"),
                              ("res2c", "res2c"),
                              ("res3d", "res3d"),
                              ("res4f", "conv res4f"),
                              ("res5c", "conv res5c")):
        print("Compiling function for getting {} ....".format(label))
        feature_extractor = get_feature_extractor(args.model_file, layer_name)
        t0 = time.time()
        print(feature_extractor(args.input_image).shape)
        t1 = time.time()
        print("Total time taken {:.4f}".format(t1 - t0))