4
4
5
5
import numpy as np
6
6
import tensorflow as tf
7
+ import time
7
8
8
9
from model import CNNModel
9
10
from model import restore_model
10
- from preprocessing import read_object_classes , image_to_np_array , labels_to_np_array , get_patch , save_labels_array
11
-
12
-
13
- def test_model (sess , model , images , labels , patch_size , output_dir = None , category_colors = None ):
14
- """
15
- Tests the model on the given images and labels
16
- :param patch_size:
17
- :param sess: The tensorflow session within which to run the model
18
- :param model: An rCNN model
19
- :param images: A series of image filenames
20
- :param labels: A series of label filenames (corresponding images/labels should be in the same order)
21
- :param output_dir: An (optional) directory in which to store predicted labels as images.
22
- :param category_colors: A mapping of category index to color, to create images
23
- """
24
- for image_f , label_f in zip (images , labels ):
25
- print "Testing on image %s..." % image_f
26
- image = image_to_np_array (image_f , float_cols = True )
27
- labels = labels_to_np_array (label_f )
11
+ from preprocessing import read_object_classes , image_to_np_array , labels_to_np_array , get_patch , save_labels_array , \
12
+ FROM_GAMES , DATASETS
13
+
14
+
15
+ def test_model (sess , model , dataset_iter , color_map , output_dir ):
16
+ total_accuracy = 0
17
+ class_correct_counts = np .zeros (model .num_classes )
18
+ class_total_counts = np .zeros (model .num_classes )
19
+ i = 0
20
+ for image , labels in dataset_iter ():
21
+ i += 1
22
+ start_time = time .time ()
28
23
h , w , _ = image .shape
29
- image = image [:h // 2 , :w // 2 , :]
30
- h , w , _ = image .shape
31
- labels = labels [:h , :w ]
32
- predicted_labels = np .zeros ([h , w ], dtype = np .uint8 )
33
- pixels_correct = 0
34
- error_for_image = 0
35
- i = 0
36
-
37
- for y in range (patch_size , h - patch_size ):
38
- # # for debug, only do first 10K
39
- # if i > 1e4:
40
- # break
41
-
42
- for x in range (patch_size , w - patch_size ):
43
- i += 1
44
- input_image = get_patch (image , (y , x ), patch_size )
45
- input_image = np .append (input_image ,
46
- np .zeros (shape = [patch_size , patch_size , model .num_classes ], dtype = np .float32 ),
47
- axis = 2 )
48
- input_label = labels [y , x ]
49
- feed_dict = {model .inpt : [input_image ], model .output : input_label }
50
-
51
- error , logits = sess .run ([model .errors [1 ], model .logits [1 ]], feed_dict = feed_dict )
52
- error_for_image += error
53
- output_label = np .argmax (logits )
54
- if output_label == input_label :
55
- pixels_correct += 1
56
- predicted_labels [y , x ] = output_label
57
-
58
- if i % 1000 == 0 :
59
- print "%d/%d pixels done..." % (i , (h - 2 * patch_size ) * (w - 2 * patch_size ))
60
-
61
- # print "Tested on image %s: Accuracy is %.2f%%, error per pixel is %f." % (
62
- # image_f, (100.0 * pixels_correct) / i, error_for_image / i)
63
- if output_dir is not None :
64
- if category_colors is None :
65
- raise ValueError ("Color index not provided, can't output images." )
66
- output_filename = os .path .join (output_dir , os .path .splitext (os .path .basename (label_f ))[0 ] + '_test.png' )
67
- print "output: " , output_filename
68
- save_labels_array (predicted_labels , output_filename = output_filename , colors = category_colors )
24
+
25
+ input_image = np .append (image , np .zeros (shape = [h , w , model .num_classes ], dtype = np .float32 ), axis = 2 )
26
+ feed_dict = {model .inpt : [input_image ], model .output : [labels ]}
27
+ logits , error = sess .run ([model .logits [1 ], model .loss ], feed_dict = feed_dict )
28
+ predicted_labels = np .argmax (logits [0 ], axis = 2 )
29
+ true_labels = labels [::4 , ::4 ]
30
+
31
+ correct_labels = np .equal (predicted_labels , true_labels )
32
+ accuracy = np .mean (correct_labels )
33
+ total_accuracy += accuracy
34
+
35
+ for c in range (model .num_classes ):
36
+ current_class_labels = np .equal (true_labels , c )
37
+ class_total_counts [c ] += np .sum (current_class_labels )
38
+ class_correct_counts [c ] += np .sum (np .equal (true_labels , c ) * correct_labels )
39
+
40
+ print "Error: %f Accuracy: %f (time: %.1fs)" % (error , accuracy , time .time () - start_time )
41
+
42
+ print "%d Images, Total Accuracy: %f" % (i , total_accuracy / i )
43
+ print "Per Class accuracy:" , class_correct_counts / class_total_counts
44
+ print np .sum (class_correct_counts / class_total_counts )
69
45
70
46
71
47
def main ():
@@ -74,8 +50,9 @@ def main():
74
50
formatter_class = argparse .ArgumentDefaultsHelpFormatter )
75
51
parser .add_argument ('--model' , type = str , help = 'Filename of saved model' )
76
52
parser .add_argument ('--category_map' , type = str , help = 'File that maps colors ' )
77
- parser .add_argument ('--images' , type = str , nargs = '+' , help = 'Filename of test image' )
78
- parser .add_argument ('--labels' , type = str , nargs = '+' , help = 'Filename of test labels' )
53
+ parser .add_argument ('--dataset' , type = str , default = FROM_GAMES , choices = DATASETS .keys (),
54
+ help = 'Type of dataset to use. This determines the expected format of the data directory' )
55
+ parser .add_argument ('--data_dir' , type = str , help = 'Directory for image and label data' )
79
56
parser .add_argument ('--output_dir' , type = str , default = None ,
80
57
help = 'Directory to store model output. By default no output is generated.' )
81
58
parser .add_argument ('--patch_size' , type = int , default = 67 , help = 'Size of input patches' )
@@ -84,14 +61,17 @@ def main():
84
61
# load class labels
85
62
category_colors , category_names , names_to_ids = read_object_classes (args .category_map )
86
63
num_classes = len (category_names )
64
+
65
+ # load dataset
66
+ def dataset_func (): return DATASETS [args .dataset ](args .data_dir )
67
+
87
68
# TODO don't hardcode these (maybe store them in config file?)
88
69
model = CNNModel (25 , 50 , 1 , num_classes , 1e-4 , num_layers = 2 )
89
70
90
71
sess = tf .Session ()
91
72
restore_model (sess , args .model )
92
73
93
- test_model (sess , model , args .images , args .labels , patch_size = args .patch_size , output_dir = args .output_dir ,
94
- category_colors = category_colors )
74
+ test_model (sess , model , dataset_func )
95
75
96
76
97
77
if __name__ == '__main__' :
0 commit comments