Skip to content

Commit 3800469

Browse files
committed
updated eval.py to have correct signature for test
1 parent 81abceb commit 3800469

File tree

4 files changed

+44
-25
lines changed

4 files changed

+44
-25
lines changed

eval.py

+38-22
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,50 @@
1212
FROM_GAMES, DATASETS
1313

1414

15-
def test_model(sess, model, dataset_iter, color_map, output_dir):
15+
def test_model(sess, model, dataset_iter, use_patches=False, color_map=None, output_dir=None):
1616
total_accuracy = 0
1717
class_correct_counts = np.zeros(model.num_classes)
1818
class_total_counts = np.zeros(model.num_classes)
1919
i = 0
20-
for image, labels in dataset_iter():
21-
i += 1
20+
for image, labels, img_id in dataset_iter():
2221
start_time = time.time()
2322
h, w, _ = image.shape
24-
25-
input_image = np.append(image, np.zeros(shape=[h, w, model.num_classes], dtype=np.float32), axis=2)
26-
feed_dict = {model.inpt: [input_image], model.output: [labels]}
27-
logits, error = sess.run([model.logits[1], model.loss], feed_dict=feed_dict)
28-
predicted_labels = np.argmax(logits[0], axis=2)
29-
true_labels = labels[::4, ::4]
30-
31-
correct_labels = np.equal(predicted_labels, true_labels)
32-
accuracy = np.mean(correct_labels)
33-
total_accuracy += accuracy
34-
35-
for c in range(model.num_classes):
36-
current_class_labels = np.equal(true_labels, c)
37-
class_total_counts[c] += np.sum(current_class_labels)
38-
class_correct_counts[c] += np.sum(np.equal(true_labels, c) * correct_labels)
39-
40-
print "Error: %f Accuracy: %f (time: %.1fs)" % (error, accuracy, time.time() - start_time)
41-
42-
print "%d Images, Total Accuracy: %f" % (i, total_accuracy/i)
23+
if use_patches:
24+
patch_size = CNNModel.PATCH_SIZE
25+
for y in range(patch_size, h - patch_size):
26+
for x in range(patch_size, w - patch_size):
27+
patch = get_patch(image, center=(y, x), patch_size=patch_size)
28+
patch_labels = get_patch(labels, center=(y, x), patch_size=patch_size)
29+
input_patch = np.append(patch, np.zeros(shape=[patch_size, patch_size, model.num_classes],
30+
dtype=np.float32), axis=2)
31+
feed_dict = {model.inpt: [input_patch], model.output: [patch_labels]}
32+
logits, error = sess.run([model.logits[1], model.loss], feed_dict=feed_dict)
33+
predicted_label = np.argmax(logits[patch_size/2, patch_size/2, :])
34+
true_label = patch_labels[patch_size/2, patch_size/2]
35+
36+
class_total_counts[true_label] += 1
37+
if true_label == predicted_label:
38+
class_correct_counts[true_label] += 1
39+
else:
40+
i += 1
41+
input_image = np.append(image, np.zeros(shape=[h, w, model.num_classes], dtype=np.float32), axis=2)
42+
feed_dict = {model.inpt: [input_image], model.output: [labels]}
43+
logits, error = sess.run([model.logits[1], model.loss], feed_dict=feed_dict)
44+
predicted_labels = np.argmax(logits[0], axis=2)
45+
true_labels = labels[::4, ::4]
46+
47+
correct_labels = np.equal(predicted_labels, true_labels)
48+
accuracy = np.mean(correct_labels)
49+
total_accuracy += accuracy
50+
51+
for c in range(model.num_classes):
52+
current_class_labels = np.equal(true_labels, c)
53+
class_total_counts[c] += np.sum(current_class_labels)
54+
class_correct_counts[c] += np.sum(np.equal(true_labels, c) * correct_labels)
55+
56+
print "Image: %s Error: %f Accuracy: %f (time: %.1fs)" % (img_id, error, accuracy, time.time() - start_time)
57+
58+
print "%d Images, Total Accuracy: %f" % (i, total_accuracy / i)
4359
print "Per Class accuracy:", class_correct_counts / class_total_counts
4460
print np.sum(class_correct_counts / class_total_counts)
4561

model.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77

88
class CNNModel:
9+
PATCH_SIZE = 67
10+
911
def __init__(self, hidden_size_1, hidden_size_2, batch_size, num_classes, learning_rate, num_layers):
1012
# TODO fix padding
1113
self.hidden_size_1 = hidden_size_1

preprocessing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,10 @@ def stanford_bgrounds_dataset(data_dir, train_fraction=None, num_train=None):
195195
for label_f, image_f in train_files:
196196
if os.path.basename(label_f).split('.')[0] != os.path.basename(image_f).split('.')[0]:
197197
print "UNEQUAL IMAGE NAMES!", label_f, image_f
198+
img_id = os.path.basename(label_f).split('.')[0]
198199
image = image_to_np_array(image_f)
199200
labels = text_labels_to_np_array(label_f)
200-
yield image, labels
201+
yield image, labels, img_id
201202

202203

203204
# list of datasets for which we have iterators

train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
def train(sess, model, dataset_iter, num_epochs, patch_size, patches_per_image=1000, save_path=None):
1414
for i in range(num_epochs):
1515
print 'Running epoch %d/%d...' % (i + 1, num_epochs)
16-
for image, labels in dataset_iter():
16+
for image, labels, img_id in dataset_iter():
1717
start_time = time.time()
1818
h, w, _ = image.shape
1919

2020
input_image = np.append(image, np.zeros(shape=[h, w, model.num_classes], dtype=np.float32), axis=2)
2121
feed_dict = {model.inpt: [input_image], model.output: [labels]}
2222
loss, _ = sess.run([model.loss, model.train_step], feed_dict=feed_dict)
23-
print "Average error for this image: %f (time: %.1fs)" % (loss, time.time() - start_time)
23+
print "Average error for this image (%s): %f (time: %.1fs)" % (img_id, loss, time.time() - start_time)
2424

2525
if save_path is not None:
2626
print "Epoch %i finished, saving trained model to %s..." % (i + 1, save_path)

0 commit comments

Comments
 (0)