 FROM_GAMES, DATASETS


-def test_model(sess, model, dataset_iter, color_map, output_dir):
+def test_model(sess, model, dataset_iter, use_patches=False, color_map=None, output_dir=None):
     total_accuracy = 0
     class_correct_counts = np.zeros(model.num_classes)
     class_total_counts = np.zeros(model.num_classes)
     i = 0
-    for image, labels in dataset_iter():
-        i += 1
+    for image, labels, img_id in dataset_iter():
         start_time = time.time()
         h, w, _ = image.shape
-
-        input_image = np.append(image, np.zeros(shape=[h, w, model.num_classes], dtype=np.float32), axis=2)
-        feed_dict = {model.inpt: [input_image], model.output: [labels]}
-        logits, error = sess.run([model.logits[1], model.loss], feed_dict=feed_dict)
-        predicted_labels = np.argmax(logits[0], axis=2)
-        true_labels = labels[::4, ::4]
-
-        correct_labels = np.equal(predicted_labels, true_labels)
-        accuracy = np.mean(correct_labels)
-        total_accuracy += accuracy
-
-        for c in range(model.num_classes):
-            current_class_labels = np.equal(true_labels, c)
-            class_total_counts[c] += np.sum(current_class_labels)
-            class_correct_counts[c] += np.sum(np.equal(true_labels, c) * correct_labels)
-
-        print "Error: %f Accuracy: %f (time: %.1fs)" % (error, accuracy, time.time() - start_time)
-
-    print "%d Images, Total Accuracy: %f" % (i, total_accuracy/i)
+        if use_patches:
+            patch_size = CNNModel.PATCH_SIZE
+            for y in range(patch_size, h - patch_size):
+                for x in range(patch_size, w - patch_size):
+                    patch = get_patch(image, center=(y, x), patch_size=patch_size)
+                    patch_labels = get_patch(labels, center=(y, x), patch_size=patch_size)
+                    input_patch = np.append(patch, np.zeros(shape=[patch_size, patch_size, model.num_classes],
+                                                            dtype=np.float32), axis=2)
+                    feed_dict = {model.inpt: [input_patch], model.output: [patch_labels]}
+                    logits, error = sess.run([model.logits[1], model.loss], feed_dict=feed_dict)
+                    predicted_label = np.argmax(logits[patch_size/2, patch_size/2, :])
+                    true_label = patch_labels[patch_size/2, patch_size/2]
+
+                    class_total_counts[true_label] += 1
+                    if true_label == predicted_label:
+                        class_correct_counts[true_label] += 1
+        else:
+            i += 1
+            input_image = np.append(image, np.zeros(shape=[h, w, model.num_classes], dtype=np.float32), axis=2)
+            feed_dict = {model.inpt: [input_image], model.output: [labels]}
+            logits, error = sess.run([model.logits[1], model.loss], feed_dict=feed_dict)
+            predicted_labels = np.argmax(logits[0], axis=2)
+            true_labels = labels[::4, ::4]
+
+            correct_labels = np.equal(predicted_labels, true_labels)
+            accuracy = np.mean(correct_labels)
+            total_accuracy += accuracy
+
+            for c in range(model.num_classes):
+                current_class_labels = np.equal(true_labels, c)
+                class_total_counts[c] += np.sum(current_class_labels)
+                class_correct_counts[c] += np.sum(np.equal(true_labels, c) * correct_labels)
+
+            print "Image: %s Error: %f Accuracy: %f (time: %.1fs)" % (img_id, error, accuracy, time.time() - start_time)
+
+    print "%d Images, Total Accuracy: %f" % (i, total_accuracy / i)
     print "Per Class accuracy:", class_correct_counts / class_total_counts
     print np.sum(class_correct_counts / class_total_counts)

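The use_patches branch added above depends on a get_patch helper and the CNNModel.PATCH_SIZE constant, neither of which appears in this hunk. As a rough sketch only (the real helper is defined elsewhere in the repository and may differ), get_patch presumably crops a patch_size x patch_size window centred on (y, x) from an array whose first two axes are height and width:

def get_patch(arr, center, patch_size):
    # Hypothetical sketch of the helper assumed by test_model; the actual
    # implementation lives elsewhere in the repo.
    # Crops a patch_size x patch_size window centred on (y, x). The loops in
    # test_model only visit centres at least patch_size from the border, so
    # no bounds checking is done here; an even patch_size is assumed.
    y, x = center
    half = patch_size // 2
    return arr[y - half:y + half, x - half:x + half]

Under that reading, the centre pixel of the crop sits at index patch_size/2 on both axes, which is exactly where the new code reads predicted_label and true_label. With the new signature, the per-pixel patch evaluation is opted into via test_model(sess, model, dataset_iter, use_patches=True), while the default arguments keep the original whole-image path.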