
Commit bed0249

added readMe
1 parent 13821fb commit bed0249

File tree: 5 files changed, +92 −7 lines changed

README.md

+85 −1

@@ -1,2 +1,86 @@

# Robust Multimodal Fusion GAN

-Codebase for ACM MM'22 paper titled "Robust Multimodal Depth Estimation using Transformer based Generative Adversarial Networks"

This repo is the PyTorch implementation of our ACM Multimedia'22 paper, [Robust Multimodal Depth Estimation using Transformer based Generative Adversarial Networks](https://dl.acm.org/doi/abs/10.1145/3503161.3548418).
<p align="center">
  <img src="misc/teapot.png" alt="photo not available">
</p>

## Requirements

The base environment (Python 3.6) consists of:
```
pytorch == 1.10.2
torchvision == 0.11.3
tensorboard == 1.15
py-opencv == 4.5.5
pillow == 8.4.0
numpy == 1.17.4
typing == 3.6.4
```
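
If you want a quick sanity check that your environment roughly matches the versions above, a small optional snippet (the import names are the standard ones for these packages, not anything specific to this repo):

```python
# Optional sanity check: print installed versions and compare against the list above.
import torch, torchvision, cv2, PIL, numpy

print("pytorch     :", torch.__version__)        # expect 1.10.2
print("torchvision :", torchvision.__version__)  # expect 0.11.3
print("opencv      :", cv2.__version__)          # expect 4.5.5
print("pillow      :", PIL.__version__)          # expect 8.4.0
print("numpy       :", numpy.__version__)        # expect 1.17.4
```
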
## Dataset

Two datasets are primarily used: [ShapeNet](https://shapenet.org/) and [NYU_v2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html).
## Training

```bash
python train.py --model nyu_modelA --gpus=0,1 --batch_size=40 --n_epochs=27 --decay_epoch=15 --lr_gap=3 -p chkpts/nyu_modelA.pth -n nyu_modelA_train
```

1. `-n` gives a name to the run.
2. Modify the val dataloader path to point to the appropriate data directory.
3. Typically the data directory has the following structure:

```
data.nyu_v2
├── train
│   ├── sparse_depth
│   ├── depth_gt
│   ├── image_rgb
│   └── meta_info.txt
├── val
│   ├── sparse_depth
│   ├── depth_gt
│   ├── image_rgb
│   └── meta_info.txt
└── sample
    ├── sparse_depth
    ├── depth_gt
    ├── image_rgb
    └── meta_info.txt
```

4. The "depth_gt" and "sparse_depth" folders contain the dense and sparse depth maps, respectively.
5. meta_info.txt lists the file names in these folders. Refer to the misc/ folder for a sample meta_info file, and see the sketch after this list for one way to generate it.
6. The "sample" folder contains a few sparse samples used to track the model's learning visually. It is optional.
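
The meta_info format is `<file id> (height, width, channels)`, as in misc/meta_info.txt. As a rough sketch (not part of this repo), one way to generate such a file from a depth_gt folder, assuming single-channel PNG depth maps:

```python
# Hypothetical helper (not part of this repo): write meta_info.txt as
# "<file id> (H, W, C)" for every depth map in <root>/depth_gt.
import os
import cv2  # py-opencv

root = "data.nyu_v2/train"                     # assumed dataset split directory
gt_folder = os.path.join(root, "depth_gt")

with open(os.path.join(root, "meta_info.txt"), "w") as f:
    for name in sorted(os.listdir(gt_folder)):
        if not name.lower().endswith(".png"):  # assumed file extension
            continue
        depth = cv2.imread(os.path.join(gt_folder, name), cv2.IMREAD_UNCHANGED)
        h, w = depth.shape[:2]
        c = 1 if depth.ndim == 2 else depth.shape[2]
        f.write("{} ({}, {}, {})\n".format(os.path.splitext(name)[0], h, w, c))
```
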
## Validation

You can run standalone validation if you have a trained model. The checkpoint path must contain two files named generator_best.pth and discriminator_best.pth. Invoke the validation script with:

```bash
python validate.py --model nyu_modelA --gpus=0 --batch_size=16 --checkpoint_model=./logdir/nyu_train/saved_models/ -n nyu_test
```
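
A tiny optional pre-flight check (illustrative only, not part of this repo) to confirm the checkpoint folder contains the two expected files before launching validation:

```python
# Illustrative pre-flight check: validate.py expects these two files in --checkpoint_model.
import os

ckpt_dir = "./logdir/nyu_train/saved_models/"   # example path from the command above
for name in ("generator_best.pth", "discriminator_best.pth"):
    path = os.path.join(ckpt_dir, name)
    print("{:24s} {}".format(name, "found" if os.path.isfile(path) else "MISSING"))
```
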
## Misc

For convenience, some helper files are provided in the misc/ folder:

```
├── meta_info.txt  # example meta_info file
```

## Citation

If you find this repository helpful, please cite:

```
@inproceedings{10.1145/3503161.3548418,
  author = {Khan, Md Fahim Faysal and Devulapally, Anusha and Advani, Siddharth and Narayanan, Vijaykrishnan},
  title = {Robust Multimodal Depth Estimation Using Transformer Based Generative Adversarial Networks},
  year = {2022},
  isbn = {9781450392037},
  publisher = {Association for Computing Machinery},
  address = {New York, NY, USA},
  url = {https://doi.org/10.1145/3503161.3548418},
  doi = {10.1145/3503161.3548418},
  booktitle = {Proceedings of the 30th ACM International Conference on Multimedia},
  pages = {3559–3568},
  numpages = {10},
  keywords = {sensor fusion, depth completion, generative adversarial networks (gan), multimodal sensing, robustness, sensor failure},
  location = {Lisboa, Portugal},
  series = {MM '22}
}
```
## Acknowledgement

This work was supported in part by the National Science Foundation (NSF) SOPHIA (CCF-1822923) and by the Center for Brain-inspired Computing (C-BRIC) and the Center for Research in Intelligent Storage and Processing in Memory (CRISP), two of the six centers in JUMP, a Semiconductor Research Corporation (SRC) program sponsored by DARPA.

datasets.py

+1 −1

```diff
@@ -190,7 +190,7 @@ def __init__(self, root, opt, hr_shape):
         # assumption is that the sparse depth is in "lidar" folder
         # ground truth depth is in "depth_gt" folder
         # and rgb image is in "image_rgb" folder
-        self.gt_folder, self.lq_folder, self.rgb_folder = os.path.join(root,'depth_gt'), os.path.join(root,'lidar_5p'), os.path.join(root,'image_rgb')
+        self.gt_folder, self.lq_folder, self.rgb_folder = os.path.join(root,'depth_gt'), os.path.join(root,'sparse_depth'), os.path.join(root,'image_rgb')

         self.filename_tmpl = '{}'
```

misc/meta_info.txt

+3

```diff
@@ -0,0 +1,3 @@
+047550 (228, 304, 1)
+047551 (228, 304, 1)
+047552 (228, 304, 1)
```

misc/teapot.png

304 KB

train.py

+3 −5

```diff
@@ -70,7 +70,7 @@ def getOpt():
     parser.add_argument("--n_epochs", type=int, default=10, help="number of epochs of training")
     parser.add_argument("--dataset", type=str, default="nyu_v2", help="name of the dataset (shapeNet or nyu_v2)")
     parser.add_argument("--model", type=str, default="nyu_modelA", required = True, help="name of the model (nyu_modelA | nyu_modelB)")
-    parser.add_argument("--dataset_path", type=str, default="/home/mdl/mzk591/dataset/data.nyuv2/disk3/", help="path to the dataset")
+    parser.add_argument("--dataset_path", type=str, default="/home/dataset/nyu_v2/", help="path to the dataset")
     parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
     parser.add_argument('--robust', '-r', action='store_true', help="flag to enable robust training")
     parser.add_argument("--save_size", type=int, default=8, help="batch size for saved outputs")
@@ -83,7 +83,7 @@ def getOpt():
     parser.add_argument("--hr_width", type=int, default=304, help="dense depth width")
     parser.add_argument("--channels", type=int, default=1, help="depth image has only 1 channel")
     parser.add_argument("--sample_interval", type=int, default=20, help="interval between saving image samples")
-    parser.add_argument("--warmup_batches", type=int, default=15, help="number of batches with pixel-wise loss only")
+    parser.add_argument("--warmup_batches", type=int, default=250, help="number of batches with pixel-wise loss only")
     parser.add_argument("--lambda_adv", type=float, default=5e-3, help="adversarial loss weight")
     parser.add_argument("--lambda_pixel", type=float, default=1e-2, help="pixel-wise loss weight")
     parser.add_argument("--gpus", metavar='DEV_ID', default=None,
@@ -215,9 +215,7 @@ def main():
     milestones = [opt.decay_epoch, opt.decay_epoch + opt.lr_gap, opt.decay_epoch + opt.lr_gap*2, opt.decay_epoch + opt.lr_gap*3]

     total_train_batches = len(train_dataloader)
-    # snapshot_interval = round(total_train_batches/2)
-    snapshot_interval = 30
-
+    snapshot_interval = round(total_train_batches/2)

     if opt.robust:
         # Finding noisy batches
```
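
For context on the `--warmup_batches` change above: the flag controls how many initial batches are trained with the pixel-wise loss alone before the adversarial term is added. A self-contained sketch of that gating logic (illustrative only; the combination formula is an assumption, while the lambda defaults mirror the argparse defaults above):

```python
# Illustrative sketch, not the repo's training loop: gate the adversarial term
# behind a warm-up period measured in batches.
def generator_loss(loss_pixel, loss_adv, batches_done, warmup_batches=250,
                   lambda_pixel=1e-2, lambda_adv=5e-3):
    """Return the generator objective for the current batch."""
    if batches_done < warmup_batches:
        return loss_pixel                                   # warm-up: pixel-wise loss only
    return lambda_pixel * loss_pixel + lambda_adv * loss_adv

# Toy usage with floats standing in for tensor losses.
print(generator_loss(0.8, 0.5, batches_done=100))   # 0.8    (still warming up)
print(generator_loss(0.8, 0.5, batches_done=300))   # 0.0105 (adversarial term active)
```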
