Commit cb1efaf

our network

fu123456 committed Sep 12, 2023
1 parent 0d573ca commit cb1efaf
Showing 16 changed files with 1,807 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -2,4 +2,5 @@ test_result_*/
logs_*/
dataset/*
!dataset/README.org
checkpoints_*/
__pycache__/
152 changes: 151 additions & 1 deletion README.md
@@ -1 +1,151 @@
# TSHRNet

**Towards High-Quality Specular Highlight Removal by Leveraging Large-scale Synthetic Data**

Gang Fu, Qing Zhang, Lei Zhu, Chunxia Xiao, and Ping Li

In ICCV 2023

In this paper, our goal is to remove specular highlights from object-level images. We propose a three-stage network for specular highlight removal, consisting of (i) physics-based specular highlight removal, (ii) specular-free refinement, and (iii) tone correction. In addition, we present a large-scale synthetic dataset of object-level images, in which each input image has corresponding albedo, shading, specular residue, diffuse, and tone-corrected diffuse images.
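
For intuition, the sketch below chains the three stages sequentially at inference time. The class name, constructor arguments, and the exact input of each stage are illustrative assumptions rather than the repository's actual interfaces (see train_4_networks.py for the real wiring):

```
import torch.nn as nn

class ThreeStagePipeline(nn.Module):
    # Hypothetical wrapper: the three sub-networks stand in for
    # (i) physics-based highlight removal, (ii) specular-free refinement,
    # and (iii) tone correction.
    def __init__(self, removal_net, refinement_net, tone_net):
        super().__init__()
        self.removal_net = removal_net        # stage (i): input -> initial diffuse estimate
        self.refinement_net = refinement_net  # stage (ii): initial diffuse -> refined diffuse
        self.tone_net = tone_net              # stage (iii): refined diffuse -> tone-corrected diffuse

    def forward(self, x):
        coarse_diffuse = self.removal_net(x)
        refined_diffuse = self.refinement_net(coarse_diffuse)
        tone_corrected = self.tone_net(refined_diffuse)
        return coarse_diffuse, refined_diffuse, tone_corrected
```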

## Prerequisites of our implementation

```
conda create --yes --name TSHRNet python=3.9
conda activate TSHRNet
conda install --yes pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.6 -c pytorch -c nvidia
conda install --yes tqdm matplotlib
```

Please see "dependencies_install.sh".

## Datasets

* Our SSHR dataset is available at [OneDrive](https://polyuit-my.sharepoint.com/:u:/g/personal/gangfu_polyu_edu_hk/ERVx4DV78jxGq-1HCPmRsssBOYHPvL_eYmKbGMrELxm8uw?e=tdDAeu) or [Google Drive](https://drive.google.com/file/d/1iBBYIvF5ujLuUe6l22eArFRxFPYAPLVR/view?usp=sharing) (~5G).
* The SHIQ dataset can be found in the project [SHIQ](https://github.com/fu123456/SHIQ).
* The PSD dataset can be found in the project [SpecularityNet-PSD](https://github.com/jianweiguo/SpecularityNet-PSD).

## Training

The bash shell script file of "train.sh" provides the command lines for traning on different datasets.

### Training on SSHR

```
python train_4_networks.py \
-trdd dataset \
-trdlf dataset/SSHR/train_6_tuples.lst \
-dn SSHR
```

### Training on SHIQ

```
python train_4_networks_mix.py \
-trdd dataset \
-trdlf dataset/SHIQ_data_10825/train.lst \
-dn SHIQ
```

### Training on PSD

```
python train_4_networks_mix.py \
-trdd dataset \
-trdlf dataset/M-PSD_Dataset/train_validation.lst \
-dn PSD_debug_1
```

### Training on the mixed data

```
cat dataset/SSHR/train_4_tuples.lst dataset/SHIQ_data_10825/train.lst dataset/M-PSD_Dataset/train.lst >> dataset/train_mix.lst
python train_4_networks_mix.py \
-trdd dataset \
-trdlf dataset/train_mix.lst \
-dn mix_SSHR_SHIQ_PSD
```

## Testing

The bash shell script file of "test.sh" provides the command lines for testing on different datasets.

### Testing on SSHR

Note that we split "test.lst" into four parts for testing, to avoid running out of memory.
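
If the four part files ever need to be regenerated from a single list, a small helper like the one below (hypothetical, not part of this commit; the full list's file name is an assumption) would do:

```
# hypothetical helper: split a full SSHR test list into four roughly equal parts
with open('dataset/SSHR/test_6_tuples.lst') as f:
    lines = f.read().splitlines()

chunk = (len(lines) + 3) // 4  # ceiling division so no line is dropped
for i in range(4):
    part = lines[i * chunk:(i + 1) * chunk]
    with open(f'dataset/SSHR/test_6_tuples_part{i + 1}.lst', 'w') as out:
        out.write('\n'.join(part) + '\n')
```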

```
num_checkpoint=60 # the indexes of the used checkpoints
model_name='SSHR' # find the checkpoints in "checkpoints_${model_name}", e.g., "checkpoints_SSHR"
testing_data_name='SSHR' # testing dataset name
# processing testing images
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part1.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part2.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part3.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part4.lst'
```

### Testing on SHIQ

```
num_checkpoint=60
model_name='SHIQ'
testing_data_name='SHIQ'
python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SHIQ_data_10825/test.lst'
```

### Testing on PSD

```
num_checkpoint=60
model_name='PSD'
testing_data_name='PSD'
python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/M-PSD_Dataset/test.lst'
```

## Index structure of image groups

Please put the SSHR, SHIQ, and PSD datasets into the "dataset" directory.
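
For reference, the commands above assume a layout roughly like the following (directory and list-file names are taken from the commands in this README; anything beyond that may differ):

```
dataset/
├── SSHR/               # train_6_tuples.lst, train_4_tuples.lst, test_6_tuples_part*.lst, train/, ...
├── SHIQ_data_10825/    # train.lst, test.lst, train/, test/, ...
└── M-PSD_Dataset/      # train.lst, train_validation.lst, test.lst, ...
```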

For seven-tuple image groups (i.e., those with additional albedo and shading images), the index structure has the following form:

```
SSHR/train/048721/0024_i.jpg SSHR/train/048721/0024_a.jpg SSHR/train/048721/0024_s.jpg SSHR/train/048721/0024_r.jpg SSHR/train/048721/0024_d.jpg SSHR/train/048721/0024_d_tc.jpg SSHR/train/048721/0024_m.jpg
SSHR/train/048721/0078_i.jpg SSHR/train/048721/0078_a.jpg SSHR/train/048721/0078_s.jpg SSHR/train/048721/0078_r.jpg SSHR/train/048721/0078_d.jpg SSHR/train/048721/0078_d_tc.jpg SSHR/train/048721/0078_m.jpg
... ...
```
From left to right, they are input, albedo, shading, specular residue, diffuse, tone-corrected diffuse, and object mask images, respectively.
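
In other words, each line is a whitespace-separated record. A minimal parsing sketch (not code from the repository):

```
# illustrative only: map one seven-tuple line to named fields
line = ("SSHR/train/048721/0024_i.jpg SSHR/train/048721/0024_a.jpg "
        "SSHR/train/048721/0024_s.jpg SSHR/train/048721/0024_r.jpg "
        "SSHR/train/048721/0024_d.jpg SSHR/train/048721/0024_d_tc.jpg "
        "SSHR/train/048721/0024_m.jpg")
keys = ['input', 'albedo', 'shading', 'specular_residue',
        'diffuse', 'tone_corrected_diffuse', 'mask']
group = dict(zip(keys, line.split()))
print(group['diffuse'])  # SSHR/train/048721/0024_d.jpg
```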

For four-tuple image groups, the index structure has the following form (taking our SSHR as an example):

```
SSHR/train/048721/0044_i.jpg SSHR/train/048721/0044_r.jpg SSHR/train/048721/0044_d.jpg SSHR/train/048721/0044_d_tc.jpg
SSHR/train/048721/0023_i.jpg SSHR/train/048721/0023_r.jpg SSHR/train/048721/0023_d.jpg SSHR/train/048721/0023_d_tc.jpg
... ...
```

From left to right, they are input, specular residue, diffuse, and tone-corrected diffuse images, respectively. The main reason for this four-tuple form is that it allows our network to be trained together with the four-tuple image groups of the SHIQ and PSD datasets. Please download our SSHR dataset and see it for more details.


For SHIQ, four-tuples image groups are like:

```
SHIQ_data_10825/train/00583_A.png SHIQ_data_10825/train/00583_S.png SHIQ_data_10825/train/00583_D.png SHIQ_data_10825/train/00583_D_tc.png
SHIQ_data_10825/train/08766_A.png SHIQ_data_10825/train/08766_S.png SHIQ_data_10825/train/08766_D.png SHIQ_data_10825/train/08766_D_tc.png
... ...
```

For PSD, a list file of its images can be constructed in the same form as above.
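
Such a list file can be produced with a few lines of Python; the path lists below are placeholders, since the PSD file layout is not described here:

```
# hypothetical example: write a four-column list file (input, specular residue,
# diffuse, tone-corrected diffuse) from four parallel path lists
inputs = ['M-PSD_Dataset/train/0001_i.png']           # placeholder paths
residues = ['M-PSD_Dataset/train/0001_r.png']
diffuses = ['M-PSD_Dataset/train/0001_d.png']
tone_corrected = ['M-PSD_Dataset/train/0001_d_tc.png']

with open('dataset/M-PSD_Dataset/train.lst', 'w') as f:
    for row in zip(inputs, residues, diffuses, tone_corrected):
        f.write(' '.join(row) + '\n')
```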

## Citation

```
@inproceedings{fu-2023-towards,
author = {Fu, Gang and Zhang, Qing and Zhu, Lei and Xiao, Chunxia and Li, Ping},
title = {Towards high-quality specular highlight removal by leveraging large-scale synthetic data},
booktitle = {Proceedings of the IEEE International Conference on Computer Vision},
year = {2023},
pages = {To appear},
}
```
1 change: 1 addition & 0 deletions dataset/README.org
@@ -0,0 +1 @@
Put the SSHR, SHIQ, and PSD datasets into this directory.
9 changes: 9 additions & 0 deletions dependencies_install.sh
@@ -0,0 +1,9 @@
#!/bin/bash

## run with "bash -i dependencies_install.sh"
## if it fails with errors, please run the commands line by line in your shell

conda create --yes --name TSHRNet python=3.9
conda activate TSHRNet
conda install --yes pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.6 -c pytorch -c nvidia
conda install --yes tqdm matplotlib
154 changes: 154 additions & 0 deletions models/UNet.py
@@ -0,0 +1,154 @@
import math

import torch.nn.functional as F
import torch.nn as nn
import torch


def weights_init(init_type='gaussian'):
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find(
                'Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                assert 0, "Unsupported initialization: {}".format(init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    return init_fun


class Cvi(nn.Module):
    # Strided convolution block with an optional pre-activation ("before")
    # and an optional post-operation ("after": batch norm, tanh, or sigmoid).
    def __init__(self, in_channels, out_channels, before=None, after=False, kernel_size=4, stride=2,
                 padding=1, dilation=1, groups=1, bias=False):
        super(Cvi, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.conv.apply(weights_init('gaussian'))

        if after == 'BN':
            self.after = nn.BatchNorm2d(out_channels)
        elif after == 'Tanh':
            self.after = torch.tanh
        elif after == 'sigmoid':
            self.after = torch.sigmoid

        if before == 'ReLU':
            self.before = nn.ReLU(inplace=True)
        elif before == 'LReLU':
            self.before = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if hasattr(self, 'before'):
            x = self.before(x)

        x = self.conv(x)

        if hasattr(self, 'after'):
            x = self.after(x)

        return x


class CvTi(nn.Module):
    # Transposed-convolution block mirroring Cvi, used in the decoder.
    def __init__(self, in_channels, out_channels, before=None, after=False, kernel_size=4, stride=2,
                 padding=1, dilation=1, groups=1, bias=False):
        super(CvTi, self).__init__()
        # Passing bias positionally raised:
        # TypeError: conv_transpose2d(): argument 'output_padding' (position 6) must be tuple of ints, not tuple
        # so bias is passed as a keyword argument instead.
        # self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias)
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=True)
        self.conv.apply(weights_init('gaussian'))

        if after == 'BN':
            self.after = nn.BatchNorm2d(out_channels)
        elif after == 'Tanh':
            self.after = torch.tanh
        elif after == 'sigmoid':
            self.after = torch.sigmoid

        if before == 'ReLU':
            self.before = nn.ReLU(inplace=True)
        elif before == 'LReLU':
            self.before = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if hasattr(self, 'before'):
            x = self.before(x)

        x = self.conv(x)

        if hasattr(self, 'after'):
            x = self.after(x)

        return x


class UNet(nn.Module):
    def __init__(self, input_channels=3, output_channels=1):
        super(UNet, self).__init__()

        # encoder
        self.Cv0 = Cvi(input_channels, 64)
        self.Cv1 = Cvi(64, 128, before='LReLU', after='BN', dilation=1)
        self.Cv2 = Cvi(128, 256, before='LReLU', after='BN', dilation=1)
        self.Cv3 = Cvi(256, 512, before='LReLU', after='BN', dilation=1)
        self.Cv4 = Cvi(512, 512, before='LReLU', after='BN', dilation=1)
        self.Cv5 = Cvi(512, 512, before='LReLU', dilation=1)

        # decoder
        self.CvT6 = CvTi(512, 512, before='ReLU', after='BN', dilation=1)
        self.CvT7 = CvTi(1024, 512, before='ReLU', after='BN', dilation=1)
        self.CvT8 = CvTi(1024, 256, before='ReLU', after='BN', dilation=1)
        self.CvT9 = CvTi(512, 128, before='ReLU', after='BN', dilation=1)
        self.CvT10 = CvTi(256, 64, before='ReLU', after='BN', dilation=1)
        self.CvT11 = CvTi(128, output_channels, before='ReLU', after='Tanh', dilation=1)

    def forward(self, input):
        # encoder (Cv4 is applied three times, so its weights are shared across those calls)
        x0 = self.Cv0(input)
        x1 = self.Cv1(x0)
        x2 = self.Cv2(x1)
        x3 = self.Cv3(x2)
        x4_1 = self.Cv4(x3)
        x4_2 = self.Cv4(x4_1)
        x4_3 = self.Cv4(x4_2)
        x5 = self.Cv5(x4_3)

        # decoder with skip connections (CvT7 is likewise reused three times)
        x6 = self.CvT6(x5)

        cat1_1 = torch.cat([x6, x4_3], dim=1)
        x7_1 = self.CvT7(cat1_1)
        cat1_2 = torch.cat([x7_1, x4_2], dim=1)
        x7_2 = self.CvT7(cat1_2)
        cat1_3 = torch.cat([x7_2, x4_1], dim=1)
        x7_3 = self.CvT7(cat1_3)

        cat2 = torch.cat([x7_3, x3], dim=1)
        x8 = self.CvT8(cat2)

        cat3 = torch.cat([x8, x2], dim=1)
        x9 = self.CvT9(cat3)

        cat4 = torch.cat([x9, x1], dim=1)
        x10 = self.CvT10(cat4)

        cat5 = torch.cat([x10, x0], dim=1)
        out = self.CvT11(cat5)

        return out
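

# A minimal smoke test (not part of the original commit): build the network and
# run a dummy forward pass. A 256x256 RGB input is assumed so that the eight
# stride-2 encoder convolutions reduce the spatial size to 1x1 at the bottleneck.
if __name__ == '__main__':
    net = UNet(input_channels=3, output_channels=3)
    net.eval()
    with torch.no_grad():
        out = net(torch.randn(1, 3, 256, 256))
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])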
35 changes: 35 additions & 0 deletions test.sh
@@ -0,0 +1,35 @@
#!/bin/bash

set -e

# By default, we use the model trained on SSHR (or SHIQ, or PSD) to process the testing images of SSHR (or SHIQ, or PSD, respectively).
# The variable of "model_name" can be SSHR or SHIQ or PSD or mix_SSHR_SHIQ_PSD.


## >>> testing SSHR >>>
# to avoid running out of memory, we split "test.lst" into four parts for testing
num_checkpoint=60 # the indexes of the used checkpoints
model_name='SSHR' # find the checkpoints in "checkpoints_${model_name}", e.g., "checkpoints_SSHR"
testing_data_name='SSHR' # testing dataset name
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part1.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part2.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part3.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part4.lst'
## <<< testing SSHR <<<


# ## >>> testing SHIQ >>>
# num_checkpoint=60
# model_name='SHIQ'
# testing_data_name='SHIQ'
# python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SHIQ_data_10825/test.lst'
# ## <<< testing SHIQ <<<


# ## >>> testing PSD >>>
# num_checkpoint=60
# model_name='PSD'
# testing_data_name='PSD'
# python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/PSD/test.lst'
# # python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/PSD/train.lst'
# ## <<< testing PSD <<<