predict_on_images.py
import argparse
import json
import time
import datetime
import numpy as np
import code
import os
import cPickle as pickle
import math
import scipy.io
from imagernn.solver import Solver
from imagernn.imagernn_utils import decodeGenerator, eval_split
"""
This script is used to predict sentences for arbitrary images
that are located in a folder we call root_folder. It is assumed that
the root_folder contains:
- the raw images
- a file tasks.txt that lists the images you'd like to use
- a file vgg_feats.mat that contains the CNN features.
You'll need to use the Matlab script I provided and point it at the
root folder and its tasks.txt file to save the features.
Then point this script at the folder and at a checkpoint model you'd
like to evaluate.
"""
def main(params):

  # load the checkpoint
  checkpoint_path = params['checkpoint_path']
  print 'loading checkpoint %s' % (checkpoint_path, )
  checkpoint = pickle.load(open(checkpoint_path, 'rb'))
  checkpoint_params = checkpoint['params']
  dataset = checkpoint_params['dataset']
  model = checkpoint['model']
  misc = {}
  misc['wordtoix'] = checkpoint['wordtoix']
  ixtoword = checkpoint['ixtoword']

  # output blob which we will dump to JSON for visualizing the results
  blob = {}
  blob['params'] = params
  blob['checkpoint_params'] = checkpoint_params
  blob['imgblobs'] = []

  # load the tasks.txt file
  # TODO FIND EASY WAY TO CALL FILE WITH PROPER root
  root_path = params['root_path']
  img_names = open(os.path.join(root_path, 'tasks.txt'), 'r').read().splitlines()

  # load the features for all images
  features_path = os.path.join(root_path, 'vgg_feats.mat')
  features_struct = scipy.io.loadmat(features_path)
  features = features_struct['feats'] # this is a 4096 x N numpy array of features
  D, N = features.shape

  # iterate over all images and predict sentences
  BatchGenerator = decodeGenerator(checkpoint_params)
  for n in xrange(N):
    print 'image %d/%d:' % (n, N)

    # encode the image
    img = {}
    img['feat'] = features[:, n]
    img['local_file_path'] = img_names[n]
    print img['local_file_path']

    # perform the work. heavy lifting happens inside
    kwparams = { 'beam_size' : params['beam_size'] }
    Ys = BatchGenerator.predict([{'image': img}], model, checkpoint_params, **kwparams)

    # build up the output
    img_blob = {}
    img_blob['img_path'] = img['local_file_path']

    # encode all candidate predictions for this image, not just the top one
    # TODO: time this. It seems fast enough in practice.
    img_blob['candidates'] = []
    top_predictions = Ys[0] # take predictions for the first (and only) image we passed in
    print 'number of candidates: %d' % (len(top_predictions), )
    for i in xrange(len(top_predictions)):
      top_prediction = top_predictions[i] # these are sorted with the highest log-probability on top
      candidate = ' '.join([ixtoword[ix] for ix in top_prediction[1] if ix > 0]) # ix 0 is the END token, skip it
      print 'PRED: (%f) %s' % (top_prediction[0], candidate)
      img_blob['candidates'].append({'text': candidate, 'logprob': top_prediction[0]})

    # append this image's results to the output blob (must happen once per image)
    blob['imgblobs'].append(img_blob)
  # dump result struct to a JSON file
  save_file = os.path.join(root_path, 'result_struct.json')
  print 'writing predictions to %s...' % (save_file, )
  json.dump(blob, open(save_file, 'w'))

  # dump a simple HTML visualization of the results
  html = ''
  for img in blob['imgblobs']:
    html += '<img src="%s" height="400"><br>' % (img['img_path'], )
    for i in xrange(len(img['candidates'])):
      html += '(%f) %s <br><br>' % (img['candidates'][i]['logprob'], img['candidates'][i]['text'])
  html_file = os.path.join(root_path, 'result.html')
  print 'writing html result file to %s...' % (html_file, )
  open(html_file, 'w').write(html)
if __name__ == "__main__":

  parser = argparse.ArgumentParser()
  parser.add_argument('checkpoint_path', type=str, help='the input checkpoint')
  parser.add_argument('-r', '--root_path', default='example_images', type=str, help='folder with the images, the tasks.txt file, and the corresponding vgg_feats.mat file')
  parser.add_argument('-b', '--beam_size', type=int, default=1, help='beam size for inference. 1 means a greedy per-word argmax; around 20 is a good value, and larger generally works better (but is slower).')

  args = parser.parse_args()
  params = vars(args) # convert to an ordinary dict
  print 'parsed parameters:'
  print json.dumps(params, indent=2)
  main(params)
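# A minimal follow-up sketch (kept as a comment so it does not affect the script
# itself) of how the result_struct.json written above can be consumed later;
# 'example_images' below is just the default root folder:
#
#   import json, os
#   blob = json.load(open(os.path.join('example_images', 'result_struct.json')))
#   for img in blob['imgblobs']:
#     best = img['candidates'][0] # candidates are stored best-first
#     print '%s (%.2f): %s' % (img['img_path'], best['logprob'], best['text'])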