
Commit 13821fb

added robust support
1 parent d2f8457 commit 13821fb

3 files changed: +83 -8 lines changed

train.py

+22 -0 lines changed

@@ -72,6 +72,7 @@ def getOpt():
     parser.add_argument("--model", type=str, default="nyu_modelA", required = True, help="name of the model (nyu_modelA | nyu_modelB)")
     parser.add_argument("--dataset_path", type=str, default="/home/mdl/mzk591/dataset/data.nyuv2/disk3/", help="path to the dataset")
     parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
+    parser.add_argument('--robust', '-r', action='store_true', help="flag to enable robust training")
     parser.add_argument("--save_size", type=int, default=8, help="batch size for saved outputs")
     parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
     parser.add_argument("--b1", type=float, default=0.9, help="adam: decay of first order momentum of gradient")
@@ -217,6 +218,14 @@ def main():
     # snapshot_interval = round(total_train_batches/2)
     snapshot_interval = 30
 
+
+    if opt.robust:
+        # Finding noisy batches
+        train_rgb_noise, train_sparse_noise = send_noisy_batches(total_train_batches, train_flag=True)
+
+        logger.info("RGB noisy batches for training are: {}".format(train_rgb_noise))
+        logger.info("Sparse noisy batches for training are: {}".format(train_sparse_noise))
+
     # ----------
     # Training
     # ----------
@@ -238,6 +247,19 @@ def main():
             gt_temp = torch.unsqueeze(imgs["gt"], 1)
             rgb_temp = imgs["rgb"]
 
+            if opt.robust:
+                # noise injection is held off until the warm-up phase is over
+                rstart = False
+                if batches_done >= opt.warmup_batches:
+                    rstart = True
+
+                if (i in train_rgb_noise) and rstart:
+                    rgb_temp = torch.zeros(rgb_temp.size())  # it can be any other noise
+                    logger.info("Current batch {} is a noisy RGB sample!".format(batches_done))
+                elif (i in train_sparse_noise) and rstart:
+                    sparse_temp = torch.zeros(sparse_temp.size())  # it can be any other form of noise
+                    logger.info("Current batch {} is a noisy sparse sample!".format(batches_done))
+
             # Configure model input
             sparse_depth = Variable(sparse_temp.type(Tensor))
             gt_depth = Variable(gt_temp.type(Tensor))
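
Rough sketch of what the robust path in train.py now does once --robust is passed: the noisy batch indices are drawn once before the training loop, and a batch flagged as noisy has its RGB (or sparse-depth) tensor zeroed out, but only after the warm-up period. The helper below is illustrative only; the names send_noisy_batches, warmup_batches and batches_done come from the diff, everything else is an assumption.

# Illustrative only -- a condensed view of the noise-injection step added above.
import torch
from utils import send_noisy_batches  # helper introduced in this commit

def maybe_inject_noise(i, batches_done, warmup_batches,
                       rgb_temp, sparse_temp,
                       train_rgb_noise, train_sparse_noise):
    # hold off until warm-up is done, same role as the rstart flag in the diff
    if batches_done < warmup_batches:
        return rgb_temp, sparse_temp
    if i in train_rgb_noise:
        rgb_temp = torch.zeros(rgb_temp.size())        # stand-in for any other RGB noise
    elif i in train_sparse_noise:
        sparse_temp = torch.zeros(sparse_temp.size())  # stand-in for any other sparse noise
    return rgb_temp, sparse_temp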

utils.py

+37 -5 lines changed

@@ -6,6 +6,7 @@
 import os
 from torchvision.utils import make_grid
 from torchvision import transforms
+from numpy.random import default_rng
 
 import logging
 import cv2
@@ -36,7 +37,12 @@ def generate_depth_cmap(in_tensor):
     for img in range(depth_tensor.shape[0]):
         min_val = np.amin(depth_tensor[img])
         max_val = np.amax(depth_tensor[img])
-        gray = (depth_tensor[img]-min_val)/(max_val-min_val)
+
+        if (max_val - min_val) < 1e-4:  # values are nearly constant, skip normalization
+            gray = depth_tensor[img]
+        else:
+            gray = (depth_tensor[img]-min_val)/(max_val-min_val)
+
         # gray = depth_tensor[img]/255.0
         gray = np.clip(gray,0,1)
         heatmap = np.round(colormap(gray) * 255).astype(np.uint8)[:,:,:3]
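
Side note on the (max_val - min_val) < 1e-4 guard above: for a nearly constant depth map the old per-image normalization divides by roughly zero and fills the image with NaN/Inf before clipping. A tiny made-up example of the case the threshold now skips:

import numpy as np

img = np.full((4, 4), 0.5, dtype=np.float32)  # constant depth map
min_val, max_val = np.amin(img), np.amax(img)
# old behaviour: (img - min_val) / (max_val - min_val) -> 0/0 -> NaN everywhere
# new behaviour: max_val - min_val < 1e-4, so the image is passed through unchanged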
@@ -121,18 +127,44 @@ def save_my_image(image_array, fp) -> None:
     # print('image_array',image_array.shape)
     cv2.imwrite(fp, image_array)
 
-def save_sample_images(gt_depth, imgs_rgb, sparse_depth, gen_depth, image_save_path, image_id) -> None:
+def save_sample_images(gt_depth, imgs_rgb, sparse_depth, gen_depth, image_save_path, image_id, sn_flag=False, rn_flag=False) -> None:
 
     denorm_gt = denormalize_dense(gt_depth)
-    denorm_sparse = denormalize_sparse(sparse_depth)
+
+    if sn_flag:
+        denorm_sparse = sparse_depth
+    else:
+        denorm_sparse = denormalize_sparse(sparse_depth)
+
     denorm_pred = denormalize_dense(gen_depth)
 
     gt_depth = generate_depth_cmap(denorm_gt)
     sparse_depth = generate_depth_cmap(denorm_sparse)
     gen_depth = generate_depth_cmap(denorm_pred)
 
-    imgs_rgb = denormalize_rgb(imgs_rgb).permute(0,2,3,1).to('cpu').detach().numpy()
+    if rn_flag:
+        imgs_rgb = imgs_rgb.permute(0,2,3,1).to('cpu').detach().numpy()
+    else:
+        imgs_rgb = denormalize_rgb(imgs_rgb).permute(0,2,3,1).to('cpu').detach().numpy()
 
     img_grid = np.concatenate((gt_depth, imgs_rgb, sparse_depth, gen_depth), axis=2)
     saved_image_file = os.path.join(image_save_path,"%04d.png"%image_id)
-    save_my_image(img_grid, saved_image_file)
+    save_my_image(img_grid, saved_image_file)
+
+def send_noisy_batches(batches, train_flag=False, ratio=0.2):
+
+    n_noisy_batches = round(batches * ratio)
+    n_rgb = round(n_noisy_batches * 0.5)
+
+    rng = default_rng()
+
+    if train_flag:
+        allowed = np.arange(batches + 1)
+        selection = rng.choice(allowed, size=n_noisy_batches, replace=False)
+    else:
+        selection = rng.choice(batches+1, size=n_noisy_batches, replace=False)
+
+    rgb_batches = selection[:n_rgb]
+    sparse_batches = selection[n_rgb:]
+
+    return sorted(rgb_batches), sorted(sparse_batches)
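
Usage note on send_noisy_batches: it marks roughly ratio (20% by default) of the batch indices as noisy and splits them about half-and-half between RGB and sparse-depth noise. A minimal sketch of calling it (the batch count of 100 is made up):

# Hypothetical call; 100 is an arbitrary number of batches.
from utils import send_noisy_batches

rgb_noise, sparse_noise = send_noisy_batches(100, train_flag=True)
print(len(rgb_noise), len(sparse_noise))  # about 10 and 10 indices, drawn without replacement

Note that both branches sample from 0..batches inclusive, while enumerate(dataloader) only yields indices 0..batches-1, so a drawn index equal to batches simply never matches a real batch.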

validate.py

+24 -3 lines changed

@@ -63,6 +63,7 @@ def getOpt():
     parser.add_argument("--model", type=str, default="nyu_modelA", required = True, help="name of the model (nyu_modelA | nyu_modelB)")
     parser.add_argument("--dataset_path", type=str, default="/home/mdl/mzk591/dataset/data.nyuv2/disk3/", help="path to the dataset")
     parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
+    parser.add_argument('--robust', '-r', action='store_true', help="flag to enable robust training")
     parser.add_argument("--save_size", type=int, default=8, help="batch size for saved outputs")
     parser.add_argument("--n_cpu", type=int, default=16, help="number of cpu threads to use during batch generation")
     parser.add_argument("--channels", type=int, default=1, help="number of image channels")
@@ -79,9 +80,21 @@ def getOpt():
 def validate(generator, discriminator, opt, Tensor, val_dataloader, criterion_GAN, criterion_content, criterion_pixel, logger, val_image_save_path, writer, batches_done=0):
 
     total_val_batches = len(val_dataloader)
-
-    batch_to_be_saved = np.random.randint(total_val_batches, size=5)
-    # batch_to_be_saved = [1, 2, 3, 4] #it can be any numbers
+
+    if opt.robust:
+        # Finding noisy batches
+        val_rgb_noise, val_sparse_noise = send_noisy_batches(total_val_batches, train_flag=False)
+
+        logger.info("RGB noisy batches for validation are: {}".format(val_rgb_noise))
+        logger.info("Sparse noisy batches for validation are: {}".format(val_sparse_noise))
+
+        batch_to_be_saved = np.random.randint(total_val_batches, size=3)
+        batch_to_be_saved = set(batch_to_be_saved)
+        batch_to_be_saved.add(val_rgb_noise[0])
+        batch_to_be_saved.add(val_sparse_noise[0])
+    else:
+        batch_to_be_saved = np.random.randint(total_val_batches, size=5)
+        # batch_to_be_saved = [1, 2, 3, 4] #it can be any numbers
 
     val_sample_path = os.path.join(val_image_save_path,"%06d"%batches_done)
     os.makedirs(val_sample_path, exist_ok=True)
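
For context on the batch_to_be_saved change above: in robust mode the saved-image set is built so that at least one RGB-noisy and one sparse-noisy batch always appear among the saved validation samples. A standalone sketch of that selection (the batch count of 50 is made up):

import numpy as np
from utils import send_noisy_batches

total_val_batches = 50  # arbitrary, for illustration
val_rgb_noise, val_sparse_noise = send_noisy_batches(total_val_batches, train_flag=False)

batch_to_be_saved = set(np.random.randint(total_val_batches, size=3))  # three random batches
batch_to_be_saved.add(val_rgb_noise[0])      # guarantee one RGB-noisy batch is saved
batch_to_be_saved.add(val_sparse_noise[0])   # guarantee one sparse-noisy batch is saved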
@@ -94,6 +107,14 @@ def validate(generator, discriminator, opt, Tensor, val_dataloader, criterion_GA
         sparse_temp = torch.unsqueeze(imgs["sparse"], 1)
         gt_temp = torch.unsqueeze(imgs["gt"], 1)
         rgb_temp = imgs["rgb"]
+
+        if opt.robust:
+            if (i in val_rgb_noise):
+                rgb_temp = torch.zeros(rgb_temp.size())  # it can be any other noise
+                logger.info("Current batch {} is a noisy RGB sample!".format(i+1))
+            elif (i in val_sparse_noise):
+                sparse_temp = torch.zeros(sparse_temp.size())  # it can be any other form of noise
+                logger.info("Current batch {} is a noisy sparse sample!".format(i+1))
 
         # Configure model input
         sparse_depth = Variable(sparse_temp.type(Tensor))
