Skip to content

Commit 81abceb

Browse files
committed
updated eval.py
1 parent 2e5a141 commit 81abceb

File tree

3 files changed

+44
-69
lines changed

3 files changed

+44
-69
lines changed

eval.py

+42-62
Original file line numberDiff line numberDiff line change
@@ -4,68 +4,44 @@
44

55
import numpy as np
66
import tensorflow as tf
7+
import time
78

89
from model import CNNModel
910
from model import restore_model
10-
from preprocessing import read_object_classes, image_to_np_array, labels_to_np_array, get_patch, save_labels_array
11-
12-
13-
def test_model(sess, model, images, labels, patch_size, output_dir=None, category_colors=None):
14-
"""
15-
Tests the model on the given images and labels
16-
:param patch_size:
17-
:param sess: The tensorflow session within which to run the model
18-
:param model: An rCNN model
19-
:param images: A series of image filenames
20-
:param labels: A series of label filenames (corresponding images/labels should be in the same order)
21-
:param output_dir: An (optional) directory in which to store predicted labels as images.
22-
:param category_colors: A mapping of category index to color, to create images
23-
"""
24-
for image_f, label_f in zip(images, labels):
25-
print "Testing on image %s..." % image_f
26-
image = image_to_np_array(image_f, float_cols=True)
27-
labels = labels_to_np_array(label_f)
11+
from preprocessing import read_object_classes, image_to_np_array, labels_to_np_array, get_patch, save_labels_array, \
12+
FROM_GAMES, DATASETS
13+
14+
15+
def test_model(sess, model, dataset_iter, color_map, output_dir):
16+
total_accuracy = 0
17+
class_correct_counts = np.zeros(model.num_classes)
18+
class_total_counts = np.zeros(model.num_classes)
19+
i = 0
20+
for image, labels in dataset_iter():
21+
i += 1
22+
start_time = time.time()
2823
h, w, _ = image.shape
29-
image = image[:h//2, :w//2, :]
30-
h, w, _ = image.shape
31-
labels = labels[:h, :w]
32-
predicted_labels = np.zeros([h, w], dtype=np.uint8)
33-
pixels_correct = 0
34-
error_for_image = 0
35-
i = 0
36-
37-
for y in range(patch_size, h - patch_size):
38-
# # for debug, only do first 10K
39-
# if i > 1e4:
40-
# break
41-
42-
for x in range(patch_size, w - patch_size):
43-
i += 1
44-
input_image = get_patch(image, (y, x), patch_size)
45-
input_image = np.append(input_image,
46-
np.zeros(shape=[patch_size, patch_size, model.num_classes], dtype=np.float32),
47-
axis=2)
48-
input_label = labels[y, x]
49-
feed_dict = {model.inpt: [input_image], model.output: input_label}
50-
51-
error, logits = sess.run([model.errors[1], model.logits[1]], feed_dict=feed_dict)
52-
error_for_image += error
53-
output_label = np.argmax(logits)
54-
if output_label == input_label:
55-
pixels_correct += 1
56-
predicted_labels[y, x] = output_label
57-
58-
if i % 1000 == 0:
59-
print "%d/%d pixels done..." % (i, (h - 2 * patch_size) * (w - 2 * patch_size))
60-
61-
# print "Tested on image %s: Accuracy is %.2f%%, error per pixel is %f." % (
62-
# image_f, (100.0 * pixels_correct) / i, error_for_image / i)
63-
if output_dir is not None:
64-
if category_colors is None:
65-
raise ValueError("Color index not provided, can't output images.")
66-
output_filename = os.path.join(output_dir, os.path.splitext(os.path.basename(label_f))[0] + '_test.png')
67-
print "output: ", output_filename
68-
save_labels_array(predicted_labels, output_filename=output_filename, colors=category_colors)
24+
25+
input_image = np.append(image, np.zeros(shape=[h, w, model.num_classes], dtype=np.float32), axis=2)
26+
feed_dict = {model.inpt: [input_image], model.output: [labels]}
27+
logits, error = sess.run([model.logits[1], model.loss], feed_dict=feed_dict)
28+
predicted_labels = np.argmax(logits[0], axis=2)
29+
true_labels = labels[::4, ::4]
30+
31+
correct_labels = np.equal(predicted_labels, true_labels)
32+
accuracy = np.mean(correct_labels)
33+
total_accuracy += accuracy
34+
35+
for c in range(model.num_classes):
36+
current_class_labels = np.equal(true_labels, c)
37+
class_total_counts[c] += np.sum(current_class_labels)
38+
class_correct_counts[c] += np.sum(np.equal(true_labels, c) * correct_labels)
39+
40+
print "Error: %f Accuracy: %f (time: %.1fs)" % (error, accuracy, time.time() - start_time)
41+
42+
print "%d Images, Total Accuracy: %f" % (i, total_accuracy/i)
43+
print "Per Class accuracy:", class_correct_counts / class_total_counts
44+
print np.sum(class_correct_counts / class_total_counts)
6945

7046

7147
def main():
@@ -74,8 +50,9 @@ def main():
7450
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7551
parser.add_argument('--model', type=str, help='Filename of saved model')
7652
parser.add_argument('--category_map', type=str, help='File that maps colors ')
77-
parser.add_argument('--images', type=str, nargs='+', help='Filename of test image')
78-
parser.add_argument('--labels', type=str, nargs='+', help='Filename of test labels')
53+
parser.add_argument('--dataset', type=str, default=FROM_GAMES, choices=DATASETS.keys(),
54+
help='Type of dataset to use. This determines the expected format of the data directory')
55+
parser.add_argument('--data_dir', type=str, help='Directory for image and label data')
7956
parser.add_argument('--output_dir', type=str, default=None,
8057
help='Directory to store model output. By default no output is generated.')
8158
parser.add_argument('--patch_size', type=int, default=67, help='Size of input patches')
@@ -84,14 +61,17 @@ def main():
8461
# load class labels
8562
category_colors, category_names, names_to_ids = read_object_classes(args.category_map)
8663
num_classes = len(category_names)
64+
65+
# load dataset
66+
def dataset_func(): return DATASETS[args.dataset](args.data_dir)
67+
8768
# TODO don't hardcode these (maybe store them in config file?)
8869
model = CNNModel(25, 50, 1, num_classes, 1e-4, num_layers=2)
8970

9071
sess = tf.Session()
9172
restore_model(sess, args.model)
9273

93-
test_model(sess, model, args.images, args.labels, patch_size=args.patch_size, output_dir=args.output_dir,
94-
category_colors=category_colors)
74+
test_model(sess, model, dataset_func)
9575

9676

9777
if __name__ == '__main__':

models/checkpoint

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
model_checkpoint_path: "saved_model.ckpt"
2-
all_model_checkpoint_paths: "saved_model.ckpt"
1+
model_checkpoint_path: "stanford_model_10.ckpt"
2+
all_model_checkpoint_paths: "stanford_model_10.ckpt"

train.py

-5
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@ def train(sess, model, dataset_iter, num_epochs, patch_size, patches_per_image=1
1414
for i in range(num_epochs):
1515
print 'Running epoch %d/%d...' % (i + 1, num_epochs)
1616
for image, labels in dataset_iter():
17-
for row in labels:
18-
for l in row:
19-
if l < 0 or l >= model.num_classes:
20-
print "INVALID label:", l
21-
2217
start_time = time.time()
2318
h, w, _ = image.shape
2419

0 commit comments

Comments
 (0)