@@ -63,6 +63,7 @@ def getOpt():
63
63
parser .add_argument ("--model" , type = str , default = "nyu_modelA" , required = True , help = "name of the model (nyu_modelA | nyu_modelB)" )
64
64
parser .add_argument ("--dataset_path" , type = str , default = "/home/mdl/mzk591/dataset/data.nyuv2/disk3/" , help = "path to the dataset" )
65
65
parser .add_argument ("--batch_size" , type = int , default = 4 , help = "size of the batches" )
66
+ parser .add_argument ('--robust' , '-r' , action = 'store_true' , help = "flag to enable robust training" )
66
67
parser .add_argument ("--save_size" , type = int , default = 8 , help = "batch size for saved outputs" )
67
68
parser .add_argument ("--n_cpu" , type = int , default = 16 , help = "number of cpu threads to use during batch generation" )
68
69
parser .add_argument ("--channels" , type = int , default = 1 , help = "number of image channels" )
@@ -79,9 +80,21 @@ def getOpt():
79
80
def validate (generator , discriminator , opt , Tensor , val_dataloader , criterion_GAN , criterion_content , criterion_pixel , logger , val_image_save_path , writer , batches_done = 0 ):
80
81
81
82
total_val_batches = len (val_dataloader )
82
-
83
- batch_to_be_saved = np .random .randint (total_val_batches , size = 5 )
84
- # batch_to_be_saved = [1, 2, 3, 4] #it can be any numbers
83
+
84
+ if opt .robust :
85
+ # Finding noisy batches
86
+ val_rgb_noise , val_sparse_noise = send_noisy_batches (total_val_batches , train_flag = False )
87
+
88
+ logger .info ("RGB noisy batches for validation are: {}" .format (val_rgb_noise ))
89
+ logger .info ("Sparse noisy batches for validation are: {}" .format (val_sparse_noise ))
90
+
91
+ batch_to_be_saved = np .random .randint (total_val_batches , size = 3 )
92
+ batch_to_be_saved = set (batch_to_be_saved )
93
+ batch_to_be_saved .add (val_rgb_noise [0 ])
94
+ batch_to_be_saved .add (val_sparse_noise [0 ])
95
+ else :
96
+ batch_to_be_saved = np .random .randint (total_val_batches , size = 5 )
97
+ # batch_to_be_saved = [1, 2, 3, 4] #it can be any numbers
85
98
86
99
val_sample_path = os .path .join (val_image_save_path ,"%06d" % batches_done )
87
100
os .makedirs (val_sample_path , exist_ok = True )
@@ -94,6 +107,14 @@ def validate(generator, discriminator, opt, Tensor, val_dataloader, criterion_GA
94
107
sparse_temp = torch .unsqueeze (imgs ["sparse" ], 1 )
95
108
gt_temp = torch .unsqueeze (imgs ["gt" ], 1 )
96
109
rgb_temp = imgs ["rgb" ]
110
+
111
+ if opt .robust :
112
+ if (i in val_rgb_noise ):
113
+ rgb_temp = torch .zeros (rgb_temp .size ()) # it can be any other noise
114
+ logger .info ("Current batch {} is a noisy RGB sample!" .format (i + 1 ))
115
+ elif (i in val_sparse_noise ):
116
+ sparse_temp = torch .zeros (sparse_temp .size ()) # it can be any other form of noise
117
+ logger .info ("Current batch {} is a noisy sparse sample!" .format (i + 1 ))
97
118
98
119
# Configure model input
99
120
sparse_depth = Variable (sparse_temp .type (Tensor ))
0 commit comments