10 changes: 9 additions & 1 deletion pytorch/README.md
@@ -24,6 +24,14 @@ $ python bts_test.py arguments_test_nyu.txt
```
This will save results to `./result_bts_nyu_v2_pytorch_densenet161`. With a single RTX 2080 Ti, processing the 654 test images takes about 41 seconds.

**Test multiple images in a folder**

```bash
$ python bts_sequence.py --image_path <image_folder_path>
```

This will save the results to `./result_sequence_bts_nyu_v2_pytorch_densenet161`.
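
The script accepts the same model arguments as `bts_test.py`, so you can, for example, point it at a specific encoder and checkpoint (the values shown here are the script's defaults):

```bash
$ python bts_sequence.py --image_path <image_folder_path> \
    --model_name bts_nyu_v2_pytorch_densenet161 \
    --encoder densenet161_bts \
    --checkpoint_path ../models/bts_nyu_v2_pytorch_densenet161/model
```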

## Evaluation
The following command will evaluate the prediction results for NYU Depth V2.
```
@@ -139,4 +147,4 @@ Note: Modify arguments '--encoder', '--model_name', '--checkpoint_path' and '--p

## License
Copyright (C) 2019 Jin Han Lee, Myung-Kyu Han, Dong Wook Ko and Il Hong Suh \
This Software is licensed under GPL-3.0-or-later.
216 changes: 216 additions & 0 deletions pytorch/bts_sequence.py
@@ -0,0 +1,216 @@
# Copyright (C) 2019 Jin Han Lee
#
# author: @kHarshit (https://github.com/kHarshit)
#
# This file is a part of BTS.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>
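
"""Run BTS depth estimation on every image in a folder.

Saves 16-bit raw depth PNGs under ./result_sequence_<model_name>/raw and,
with --save_lpg, colormapped depth and LPG-layer visualizations as well.
"""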

from __future__ import absolute_import, division, print_function

import os
import argparse
import errno
import glob
import sys
import time

import numpy as np
import cv2
import PIL.Image
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
from bts_dataloader import *  # also brings torchvision's `transforms` into scope


def convert_arg_line_to_args(arg_line):
    """Yield each whitespace-separated token of an @argument-file line."""
    for arg in arg_line.split():
        if not arg.strip():
            continue
        yield arg


parser = argparse.ArgumentParser(description='BTS PyTorch implementation.', fromfile_prefix_chars='@')
parser.convert_arg_line_to_args = convert_arg_line_to_args

parser.add_argument('--model_name', type=str, help='model name', default='bts_nyu_v2_pytorch_densenet161')
parser.add_argument('--encoder', type=str, help='type of encoder, vgg or densenet121_bts or densenet161_bts',
                    default='densenet161_bts')
parser.add_argument('--image_path', type=str, help='path to the data', required=True)
parser.add_argument('--input_height', type=int, help='input height', default=480)
parser.add_argument('--input_width', type=int, help='input width', default=640)
parser.add_argument('--max_depth', type=float, help='maximum depth in estimation (10 for NYUV2, 80 for KITTI)', default=10)
parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='../models/bts_nyu_v2_pytorch_densenet161/model')
parser.add_argument('--dataset', type=str, help='dataset the model was trained on: nyu, kitti or kitti_benchmark', default='nyu')
parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
parser.add_argument('--save_lpg', help='if set, save outputs from lpg layers', action='store_true')
parser.add_argument('--bts_size', type=int, help='initial num_filters in bts', default=512)
parser.add_argument('--focal', type=float, help='focal length in pixels, required for datasets other than nyu or kitti', default=-1)

if len(sys.argv) == 2:
    arg_filename_with_prefix = '@' + sys.argv[1]
    args = parser.parse_args([arg_filename_with_prefix])
else:
    args = parser.parse_args()
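
# A minimal sketch of an @argument-file this parser accepts (filename and
# values are illustrative only); tokens may be grouped or split across lines:
#
#   --image_path ../dataset/nyu_depth_v2/test
#   --encoder densenet161_bts
#   --max_depth 10
#
# invoked as: python bts_sequence.py my_arguments.txt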

model_dir = os.path.dirname(args.checkpoint_path)
sys.path.append(model_dir)

# Import the model definition module (e.g. bts_nyu_v2_pytorch_densenet161.py)
# that sits next to the checkpoint, and pull its public names (notably
# BtsModel) into this module's namespace.
for key, val in vars(__import__(args.model_name)).items():
    if key.startswith('__') and key.endswith('__'):
        continue
    vars()[key] = val


def test(params):
"""Test function."""

    model = BtsModel(params=args)
    # the checkpoint was saved from a DataParallel model, so wrap before loading
    model = torch.nn.DataParallel(model)

    checkpoint = torch.load(args.checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    model.cuda()

    images = sorted(glob.glob(args.image_path + '/*'))  # deterministic order
num_test_samples = len(images)

print('Testing {} files with {}'.format(num_test_samples, args.checkpoint_path))

    if args.dataset == 'nyu':
        focal = 518.8579
    elif args.dataset == 'kitti':
        focal = 718.856
    elif args.focal == -1:
        print('Custom dataset needs to specify focal length with --focal')
        return
    else:
        focal = args.focal

    # ToTensor + ImageNet mean/std normalization, as expected by the pretrained encoder
    loader_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

pred_depths = []
pred_8x8s = []
pred_4x4s = []
pred_2x2s = []
pred_1x1s = []

start_time = time.time()
    with torch.no_grad():
        for image_path in tqdm(images):

            # force three channels so Normalize always sees an RGB image
            image = PIL.Image.open(image_path).convert('RGB')
            image = loader_transforms(image).float().cuda()
            # add fake batch dimension, i.e. [3, H, W] -> [1, 3, H, W]
            image = image.unsqueeze(0)

# Predict
lpg8x8, lpg4x4, lpg2x2, reduc1x1, depth_est = model(image, focal)
pred_depths.append(depth_est.cpu().numpy().squeeze())
pred_8x8s.append(lpg8x8[0].cpu().numpy().squeeze())
pred_4x4s.append(lpg4x4[0].cpu().numpy().squeeze())
pred_2x2s.append(lpg2x2[0].cpu().numpy().squeeze())
pred_1x1s.append(reduc1x1[0].cpu().numpy().squeeze())

elapsed_time = time.time() - start_time
    print('Elapsed time: %s' % str(elapsed_time))
print('Done.')

save_name = 'result_sequence_' + args.model_name

print('Saving result pngs..')
    if not os.path.exists(save_name):
        try:
            os.mkdir(save_name)
            os.mkdir(save_name + '/raw')
            os.mkdir(save_name + '/cmap')
            os.mkdir(save_name + '/rgb')
            os.mkdir(save_name + '/gt')
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    for s in tqdm(range(num_test_samples)):
        # output files reuse the input's basename; predictions are saved as .png
        filename = os.path.basename(images[s])
        filename_pred_png = save_name + '/raw/' + filename.replace('.jpg', '.png')
        filename_cmap_png = save_name + '/cmap/' + filename.replace('.jpg', '.png')
        filename_image_png = save_name + '/rgb/' + filename

image = cv2.imread(images[s])

pred_depth = pred_depths[s]
pred_8x8 = pred_8x8s[s]
pred_4x4 = pred_4x4s[s]
pred_2x2 = pred_2x2s[s]
pred_1x1 = pred_1x1s[s]

        if args.dataset == 'kitti' or args.dataset == 'kitti_benchmark':
            # KITTI convention: depth in meters * 256, stored as 16-bit PNG
            pred_depth_scaled = pred_depth * 256.0
        else:
            # NYU convention: depth in millimeters (meters * 1000)
            pred_depth_scaled = pred_depth * 1000.0

        pred_depth_scaled = pred_depth_scaled.astype(np.uint16)
cv2.imwrite(filename_pred_png, pred_depth_scaled, [cv2.IMWRITE_PNG_COMPRESSION, 0])

        if args.save_lpg:
            # trim a 10-pixel border from the visualization outputs
            cv2.imwrite(filename_image_png, image[10:-1 - 9, 10:-1 - 9, :])
            if args.dataset == 'nyu':
                pred_depth_cropped = pred_depth[10:-1 - 9, 10:-1 - 9]
plt.imsave(filename_cmap_png, np.log10(pred_depth_cropped))
pred_8x8_cropped = pred_8x8[10:-1 - 9, 10:-1 - 9]
filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_8x8.png')
plt.imsave(filename_lpg_cmap_png, np.log10(pred_8x8_cropped), cmap='Greys')
pred_4x4_cropped = pred_4x4[10:-1 - 9, 10:-1 - 9]
filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_4x4.png')
plt.imsave(filename_lpg_cmap_png, np.log10(pred_4x4_cropped), cmap='Greys')
pred_2x2_cropped = pred_2x2[10:-1 - 9, 10:-1 - 9]
filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_2x2.png')
plt.imsave(filename_lpg_cmap_png, np.log10(pred_2x2_cropped), cmap='Greys')
pred_1x1_cropped = pred_1x1[10:-1 - 9, 10:-1 - 9]
filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_1x1.png')
plt.imsave(filename_lpg_cmap_png, np.log10(pred_1x1_cropped), cmap='Greys')
else:
plt.imsave(filename_cmap_png, np.log10(pred_depth), cmap='Greys')
filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_8x8.png')
plt.imsave(filename_lpg_cmap_png, np.log10(pred_8x8), cmap='Greys')
filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_4x4.png')
plt.imsave(filename_lpg_cmap_png, np.log10(pred_4x4), cmap='Greys')
filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_2x2.png')
plt.imsave(filename_lpg_cmap_png, np.log10(pred_2x2), cmap='Greys')
filename_lpg_cmap_png = filename_cmap_png.replace('.png', '_1x1.png')
plt.imsave(filename_lpg_cmap_png, np.log10(pred_1x1), cmap='Greys')

return


if __name__ == '__main__':
test(args)