diff --git a/pytorch/README.md b/pytorch/README.md
index 7b08cdc..236b4c4 100644
--- a/pytorch/README.md
+++ b/pytorch/README.md
@@ -24,6 +24,14 @@ $ python bts_test.py arguments_test_nyu.txt
 ```
 This will save results to ./result_bts_nyu_v2_pytorch_densenet161. With a single RTX 2080 Ti it takes about 41 seconds for processing 654 testing images.
 
+**Test multiple images in a folder**
+
+```bash
+$ python bts_sequence.py --image_path /path/to/image/folder
+```
+
+This will save results to `./result_sequence_bts_nyu_v2_pytorch_densenet161`.
+
 ## Evaluation
 Following command will evaluate the prediction results for NYU Depth V2.
 ```
@@ -139,4 +147,4 @@ Note: Modify arguments '--encoder', '--model_name', '--checkpoint_path' and '--p
 
 ## License
 Copyright (C) 2019 Jin Han Lee, Myung-Kyu Han, Dong Wook Ko and Il Hong Suh \
-This Software is licensed under GPL-3.0-or-later.
\ No newline at end of file
+This Software is licensed under GPL-3.0-or-later.
diff --git a/pytorch/bts_sequence.py b/pytorch/bts_sequence.py
new file mode 100644
index 0000000..1d5e4b2
--- /dev/null
+++ b/pytorch/bts_sequence.py
@@ -0,0 +1,216 @@
+# Copyright (C) 2019 Jin Han Lee
+#
+# author: @kHarshit (https://github.com/kHarshit)
+#
+# This file is a part of BTS.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
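+
+"""Run BTS depth estimation on every image in a folder.
+
+Example:
+    $ python bts_sequence.py --image_path /path/to/image/folder
+
+Writes 16-bit depth PNGs to ./result_sequence_<model_name>/raw; with
+--save_lpg it also saves colormapped depth and LPG outputs to cmap/ and the
+cropped input image to rgb/.
+"""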
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import argparse
+import time
+import numpy as np
+import cv2
+import sys
+import glob
+import PIL.Image
+
+import torch
+import torch.nn as nn
+from bts_dataloader import *
+
+import errno
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+
+
+def convert_arg_line_to_args(arg_line):
+    for arg in arg_line.split():
+        if not arg.strip():
+            continue
+        yield arg
+
+
+parser = argparse.ArgumentParser(description='BTS PyTorch implementation.', fromfile_prefix_chars='@')
+parser.convert_arg_line_to_args = convert_arg_line_to_args
+
+parser.add_argument('--model_name', type=str, help='model name', default='bts_nyu_v2_pytorch_densenet161')
+parser.add_argument('--encoder', type=str, help='type of encoder, vgg or densenet121_bts or densenet161_bts',
+                    default='densenet161_bts')
+parser.add_argument('--image_path', type=str, help='path to the folder of input images', required=True)
+parser.add_argument('--focal', type=float, help='focal length in pixels (required for custom datasets)', default=-1)
+parser.add_argument('--input_height', type=int, help='input height', default=480)
+parser.add_argument('--input_width', type=int, help='input width', default=640)
+parser.add_argument('--max_depth', type=float, help='maximum depth in estimation (10 for NYUV2, 80 for KITTI)', default=10)
+parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='../models/bts_nyu_v2_pytorch_densenet161/model')
+parser.add_argument('--dataset', type=str, help='dataset the model was trained on, nyu or kitti', default='nyu')
+parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
+parser.add_argument('--save_lpg', help='if set, save outputs from lpg layers', action='store_true')
+parser.add_argument('--bts_size', type=int, help='initial num_filters in bts', default=512)
+
+if sys.argv.__len__() == 2:
+    arg_filename_with_prefix = '@' + sys.argv[1]
+    args = parser.parse_args([arg_filename_with_prefix])
+else:
+    args = parser.parse_args()
+
+model_dir = os.path.dirname(args.checkpoint_path)
+sys.path.append(model_dir)
+
+# Import the model definition module saved next to the checkpoint.
+for key, val in vars(__import__(args.model_name)).items():
+    if key.startswith('__') and key.endswith('__'):
+        continue
+    vars()[key] = val
+
+
+def test(params):
+    """Predict depth for every image found under args.image_path."""
+
+    model = BtsModel(params=args)
+    model = torch.nn.DataParallel(model)
+
+    checkpoint = torch.load(args.checkpoint_path)
+    model.load_state_dict(checkpoint['model'])
+    model.eval()
+    model.cuda()
+
+    images = glob.glob(args.image_path + '/*')
+    num_test_samples = len(images)
+
+    print('Testing {} files with {}'.format(num_test_samples, args.checkpoint_path))
+
+    if args.dataset == 'nyu':
+        focal = 518.8579
+    elif args.dataset == 'kitti':
+        focal = 718.856
+    elif args.focal != -1:
+        focal = args.focal
+    else:
+        print('Custom dataset needs to specify focal length with --focal')
+        return
+
+    # Apply the same ImageNet normalization the encoder was trained with.
+    loader_transforms = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+    pred_depths = []
+    pred_8x8s = []
+    pred_4x4s = []
+    pred_2x2s = []
+    pred_1x1s = []
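+
+    # Run inference one image at a time under torch.no_grad(); each prediction
+    # is copied back to the CPU right away, so GPU memory does not grow with
+    # the number of images in the sequence.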
+    start_time = time.time()
+    with torch.no_grad():
+        for image_path in tqdm(images):
+
+            # Convert to RGB so grayscale or RGBA inputs don't break the
+            # 3-channel normalization.
+            image = PIL.Image.open(image_path).convert('RGB')
+            image = loader_transforms(image).float().cuda()
+            # add fake batch dimension, i.e. [3, 480, 640] -> [1, 3, 480, 640]
+            image = image.unsqueeze(0)
+
+            # Predict
+            lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)
+            pred_depths.append(depth_est.cpu().numpy().squeeze())
+            pred_8x8s.append(lpg8x8[0].cpu().numpy().squeeze())
+            pred_4x4s.append(lpg4x4[0].cpu().numpy().squeeze())
+            pred_2x2s.append(lpg2x2[0].cpu().numpy().squeeze())
+            pred_1x1s.append(reduc1x1[0].cpu().numpy().squeeze())
+
+    elapsed_time = time.time() - start_time
+    print('Elapsed time: %s' % str(elapsed_time))
+    print('Done.')
+
+    save_name = 'result_sequence_' + args.model_name
+
+    print('Saving result pngs..')
+    if not os.path.exists(save_name):
+        try:
+            os.mkdir(save_name)
+            os.mkdir(save_name + '/raw')
+            os.mkdir(save_name + '/cmap')
+            os.mkdir(save_name + '/rgb')
+            os.mkdir(save_name + '/gt')
+        except OSError as e:
+            if e.errno != errno.EEXIST:
+                raise
+
+    for s in tqdm(range(num_test_samples)):
+        # The output layout is the same for every dataset; use the source
+        # file's basename so nested input paths don't leak into the output
+        # directories.
+        base_name = images[s].split('/')[-1]
+        filename_pred_png = save_name + '/raw/' + base_name.replace('.jpg', '.png')
+        filename_cmap_png = save_name + '/cmap/' + base_name.replace('.jpg', '.png')
+        filename_image_png = save_name + '/rgb/' + base_name
+
+        image = cv2.imread(images[s])
+
+        pred_depth = pred_depths[s]
+        pred_8x8 = pred_8x8s[s]
+        pred_4x4 = pred_4x4s[s]
+        pred_2x2 = pred_2x2s[s]
+        pred_1x1 = pred_1x1s[s]
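+
+        # Save the raw depth map as a 16-bit PNG. The KITTI benchmark
+        # convention stores depth in 1/256 m steps; NYU and other datasets
+        # use millimetres.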
+        if args.dataset == 'kitti' or args.dataset == 'kitti_benchmark':
+            pred_depth_scaled = pred_depth * 256.0
+        else:
+            pred_depth_scaled = pred_depth * 1000.0
+
+        pred_depth_scaled = pred_depth_scaled.astype(np.uint16)
+        cv2.imwrite(filename_pred_png, pred_depth_scaled, [cv2.IMWRITE_PNG_COMPRESSION, 0])
+
+        if args.save_lpg:
+            cv2.imwrite(filename_image_png, image[10:-1 - 9, 10:-1 - 9, :])
+            if args.dataset == 'nyu':
+                pred_depth_cropped = pred_depth[10:-1 - 9, 10:-1 - 9]
+                plt.imsave(filename_cmap_png, np.log10(pred_depth_cropped))
+                pred_8x8_cropped = pred_8x8[10:-1 - 9, 10:-1 - 9]
+                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_8x8.png')
+                plt.imsave(filename_lpg_cmap_png, np.log10(pred_8x8_cropped), cmap='Greys')
+                pred_4x4_cropped = pred_4x4[10:-1 - 9, 10:-1 - 9]
+                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_4x4.png')
+                plt.imsave(filename_lpg_cmap_png, np.log10(pred_4x4_cropped), cmap='Greys')
+                pred_2x2_cropped = pred_2x2[10:-1 - 9, 10:-1 - 9]
+                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_2x2.png')
+                plt.imsave(filename_lpg_cmap_png, np.log10(pred_2x2_cropped), cmap='Greys')
+                pred_1x1_cropped = pred_1x1[10:-1 - 9, 10:-1 - 9]
+                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_1x1.png')
+                plt.imsave(filename_lpg_cmap_png, np.log10(pred_1x1_cropped), cmap='Greys')
+            else:
+                plt.imsave(filename_cmap_png, np.log10(pred_depth), cmap='Greys')
+                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_8x8.png')
+                plt.imsave(filename_lpg_cmap_png, np.log10(pred_8x8), cmap='Greys')
+                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_4x4.png')
+                plt.imsave(filename_lpg_cmap_png, np.log10(pred_4x4), cmap='Greys')
+                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_2x2.png')
+                plt.imsave(filename_lpg_cmap_png, np.log10(pred_2x2), cmap='Greys')
+                filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_1x1.png')
+                plt.imsave(filename_lpg_cmap_png, np.log10(pred_1x1), cmap='Greys')
+
+    return
+
+
+if __name__ == '__main__':
+    test(args)