Showing 16 changed files with 1,807 additions and 2 deletions.
.gitignore
@@ -2,4 +2,5 @@ test_result_*/
logs_*/
dataset/*
!dataset/README.org
checkpoints_*/
__pycache__/
README.md
@@ -1 +1,151 @@
# TSHRNet
**Towards High-Quality Specular Highlight Removal by Leveraging Large-scale Synthetic Data**

Gang Fu, Qing Zhang, Lei Zhu, Chunxia Xiao, and Ping Li

In ICCV 2023

In this paper, our goal is to remove specular highlights from object-level images. We propose a three-stage network for specular highlight removal, consisting of (i) physics-based specular highlight removal, (ii) specular-free refinement, and (iii) tone correction. In addition, we present a large-scale synthetic dataset of object-level images, in which each input image has corresponding albedo, shading, specular residue, diffuse, and tone-corrected diffuse images.
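As a purely illustrative sketch (not the authors' actual training code), the three stages described above could be chained roughly as follows; the class and argument names are assumptions, and each stage could be an image-to-image network such as the UNet included later in this commit.

```
import torch.nn as nn

class ThreeStagePipeline(nn.Module):
    """Illustrative only: chains three sub-networks in the order described above."""
    def __init__(self, removal_net, refinement_net, tone_net):
        super().__init__()
        self.removal_net = removal_net        # (i) physics-based specular highlight removal
        self.refinement_net = refinement_net  # (ii) specular-free refinement
        self.tone_net = tone_net              # (iii) tone correction

    def forward(self, image):
        coarse_diffuse = self.removal_net(image)               # rough diffuse estimate
        refined_diffuse = self.refinement_net(coarse_diffuse)  # cleaned-up diffuse
        tone_corrected = self.tone_net(refined_diffuse)        # final tone-corrected diffuse
        return coarse_diffuse, refined_diffuse, tone_corrected

# e.g. ThreeStagePipeline(UNet(3, 3), UNet(3, 3), UNet(3, 3)) -- an assumption for illustration only
```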
## Prerequisites of our implementation

```
conda create --yes --name TSHRNet python=3.9
conda activate TSHRNet
conda install --yes pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.6 -c pytorch -c nvidia
conda install --yes tqdm matplotlib
```

Please see "dependencies_install.sh".
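As an optional sanity check (ours, not part of the original instructions), you can confirm that the CUDA build of PyTorch installed correctly:

```
import torch

# Quick check of the environment created above.
print(torch.__version__)          # expected: 1.13.0
print(torch.cuda.is_available())  # True if the CUDA 11.6 build can see a GPU
```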
## Datasets

* Our SSHR dataset is available at [OneDrive](https://polyuit-my.sharepoint.com/:u:/g/personal/gangfu_polyu_edu_hk/ERVx4DV78jxGq-1HCPmRsssBOYHPvL_eYmKbGMrELxm8uw?e=tdDAeu) or [Google Drive](https://drive.google.com/file/d/1iBBYIvF5ujLuUe6l22eArFRxFPYAPLVR/view?usp=sharing) (~5G).
* The SHIQ dataset can be found in the project [SHIQ](https://github.com/fu123456/SHIQ).
* The PSD dataset can be found in the project [SpecularityNet-PSD](https://github.com/jianweiguo/SpecularityNet-PSD).
## Training

The bash shell script "train.sh" provides the command lines for training on the different datasets.
### Training on SSHR

```
python train_4_networks.py \
    -trdd dataset \
    -trdlf dataset/SSHR/train_6_tuples.lst \
    -dn SSHR
```
### Training on SHIQ

```
python train_4_networks_mix.py \
    -trdd dataset \
    -trdlf dataset/SHIQ_data_10825/train.lst \
    -dn SHIQ
```
### Training on PSD

```
python train_4_networks_mix.py \
    -trdd dataset \
    -trdlf dataset/M-PSD_Dataset/train_validation.lst \
    -dn PSD_debug_1
```
### Training on the mixed data

```
cat dataset/SSHR/train_4_tuples.lst dataset/SHIQ_data_10825/train.lst dataset/M-PSD_Dataset/train.lst >> dataset/train_mix.lst
python train_4_networks_mix.py \
    -trdd dataset \
    -trdlf dataset/train_mix.lst \
    -dn mix_SSHR_SHIQ_PSD
```
## Testing

The bash shell script "test.sh" provides the command lines for testing on the different datasets.
### Testing on SSHR

Note that we split "test.lst" into four parts for testing, to avoid running out of memory.

```
num_checkpoint=60 # the index of the checkpoint to use
model_name='SSHR' # find the checkpoints in "checkpoints_${model_name}", like "checkpoints_SSHR"
testing_data_name='SSHR' # testing dataset name
# process the testing images
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part1.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part2.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part3.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part4.lst'
```
### Testing on SHIQ

```
num_checkpoint=60
model_name='SHIQ'
testing_data_name='SHIQ'
python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SHIQ_data_10825/test.lst'
```
### Testing on PSD

```
num_checkpoint=60
model_name='PSD'
testing_data_name='PSD'
python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/M-PSD_Dataset/test.lst'
```
## Index structure of image groups

Please put the SSHR, SHIQ, and PSD datasets into the directory "dataset".
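A small optional check (ours, not from the repository) that the expected dataset folders are in place; the directory names are taken from the list-file paths used in the training and testing commands above:

```
from pathlib import Path

# Directory names as they appear in the .lst paths used by train.sh / test.sh.
for name in ("SSHR", "SHIQ_data_10825", "M-PSD_Dataset"):
    path = Path("dataset") / name
    print(f"{path}: {'found' if path.is_dir() else 'missing'}")
```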
For seven-tuple image groups (i.e., with additional albedo and shading images), the index structure has the following form:
```
SSHR/train/048721/0024_i.jpg SSHR/train/048721/0024_a.jpg SSHR/train/048721/0024_s.jpg SSHR/train/048721/0024_r.jpg SSHR/train/048721/0024_d.jpg SSHR/train/048721/0024_d_tc.jpg SSHR/train/048721/0024_m.jpg
SSHR/train/048721/0078_i.jpg SSHR/train/048721/0078_a.jpg SSHR/train/048721/0078_s.jpg SSHR/train/048721/0078_r.jpg SSHR/train/048721/0078_d.jpg SSHR/train/048721/0078_d_tc.jpg SSHR/train/048721/0024_m.jpg
... ...
```

From left to right, they are the input, albedo, shading, specular residue, diffuse, tone-corrected diffuse, and object mask images, respectively.
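Purely as an illustration (this helper is not part of the repository), one way to read such a seven-tuple index line in Python might be:

```
from pathlib import Path

# Hypothetical helper: split one seven-tuple index line into named paths.
# The field order follows the description above; the function name is ours.
def parse_seven_tuple(line, root="dataset"):
    names = ["input", "albedo", "shading", "specular_residue",
             "diffuse", "diffuse_tone_corrected", "mask"]
    paths = line.split()
    assert len(paths) == len(names), "expected a seven-tuple line"
    return {name: Path(root) / p for name, p in zip(names, paths)}

example = ("SSHR/train/048721/0024_i.jpg SSHR/train/048721/0024_a.jpg "
           "SSHR/train/048721/0024_s.jpg SSHR/train/048721/0024_r.jpg "
           "SSHR/train/048721/0024_d.jpg SSHR/train/048721/0024_d_tc.jpg "
           "SSHR/train/048721/0024_m.jpg")
print(parse_seven_tuple(example)["diffuse"])  # dataset/SSHR/train/048721/0024_d.jpg
```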
For four-tuple image groups, the index structure has the following form (taking our SSHR as an example):

```
SSHR/train/048721/0044_i.jpg SSHR/train/048721/0044_r.jpg SSHR/train/048721/0044_d.jpg SSHR/train/048721/0044_d_tc.jpg
SSHR/train/048721/0023_i.jpg SSHR/train/048721/0023_r.jpg SSHR/train/048721/0023_d.jpg SSHR/train/048721/0023_d_tc.jpg
... ...
```
From left to right, they are the input, specular residue, diffuse, and tone-corrected diffuse images, respectively. The main reason for this form is that it allows the network to be trained jointly with the four-tuple image groups of the SHIQ and PSD datasets. Please download our SSHR dataset and see it for more details.

For SHIQ, four-tuple image groups look like:
```
SHIQ_data_10825/train/00583_A.png SHIQ_data_10825/train/00583_S.png SHIQ_data_10825/train/00583_D.png SHIQ_data_10825/train/00583_D_tc.png
SHIQ_data_10825/train/08766_A.png SHIQ_data_10825/train/08766_S.png SHIQ_data_10825/train/08766_D.png SHIQ_data_10825/train/08766_D_tc.png
... ...
```

For PSD, a list file of its images can be constructed in the same form as above.
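A purely illustrative way to build such a four-tuple list file: the script below is ours, not part of the repository, and assumes SHIQ-style suffixes (_A, _S, _D, _D_tc).

```
from pathlib import Path

# Hypothetical helper: write a four-tuple .lst file for images that follow the
# SHIQ-style naming shown above (xxxxx_A.png, xxxxx_S.png, xxxxx_D.png, xxxxx_D_tc.png).
def build_four_tuple_list(image_dir, output_lst, root="dataset"):
    image_dir = Path(image_dir)
    root = Path(root)
    with open(output_lst, "w") as f:
        for input_img in sorted(image_dir.glob("*_A.png")):
            stem = input_img.name[:-len("_A.png")]
            row = [str((image_dir / f"{stem}{suffix}.png").relative_to(root))
                   for suffix in ("_A", "_S", "_D", "_D_tc")]
            f.write(" ".join(row) + "\n")

# Example call (paths are placeholders):
# build_four_tuple_list("dataset/SHIQ_data_10825/train", "dataset/SHIQ_data_10825/train.lst")
```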
## Citation

```
@inproceedings{zhang-2017-stack,
  author    = {Fu, Gang and Zhang, Qing and Zhu, Lei and Xiao, Chunxia and Li, Ping},
  title     = {Towards high-quality specular highlight removal by leveraging large-scale synthetic data},
  booktitle = {Proceedings of the IEEE International Conference on Computer Vision},
  year      = {2023},
  pages     = {To appear},
}
```
dataset/README.org
@@ -0,0 +1 @@
Put the SSHR, SHIQ, and PSD datasets into this directory.
dependencies_install.sh
@@ -0,0 +1,9 @@
#!/bin/bash

## run with "bash -i dependencies_install.sh"
## if it fails with errors, please run the lines one by one in the shell

conda create --yes --name TSHRNet python=3.9
conda activate TSHRNet
conda install --yes pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.6 -c pytorch -c nvidia
conda install --yes tqdm matplotlib
UNet model definition (Python)
@@ -0,0 +1,154 @@
import math  # needed by the 'xavier' and 'orthogonal' branches of weights_init
import torch.nn.functional as F
import torch.nn as nn
import torch


def weights_init(init_type='gaussian'):
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                assert 0, "Unsupported initialization: {}".format(init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    return init_fun


class Cvi(nn.Module):
    """Strided convolution block with optional pre-activation and post-normalization."""

    def __init__(self, in_channels, out_channels, before=None, after=False, kernel_size=4, stride=2,
                 padding=1, dilation=1, groups=1, bias=False):
        super(Cvi, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.conv.apply(weights_init('gaussian'))

        if after == 'BN':
            self.after = nn.BatchNorm2d(out_channels)
        elif after == 'Tanh':
            self.after = torch.tanh
        elif after == 'sigmoid':
            self.after = torch.sigmoid

        if before == 'ReLU':
            self.before = nn.ReLU(inplace=True)
        elif before == 'LReLU':
            self.before = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if hasattr(self, 'before'):
            x = self.before(x)

        x = self.conv(x)

        if hasattr(self, 'after'):
            x = self.after(x)

        return x


class CvTi(nn.Module):
    """Transposed-convolution block with optional pre-activation and post-normalization."""

    def __init__(self, in_channels, out_channels, before=None, after=False, kernel_size=4, stride=2,
                 padding=1, dilation=1, groups=1, bias=False):
        super(CvTi, self).__init__()
        # Passing "bias" positionally would land in the "output_padding" slot and raise:
        # TypeError: conv_transpose2d(): argument 'output_padding' (position 6) must be tuple of ints, not tuple
        # self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias)
        # Note: dilation and groups are not forwarded here.
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=True)
        self.conv.apply(weights_init('gaussian'))

        if after == 'BN':
            self.after = nn.BatchNorm2d(out_channels)
        elif after == 'Tanh':
            self.after = torch.tanh
        elif after == 'sigmoid':
            self.after = torch.sigmoid

        if before == 'ReLU':
            self.before = nn.ReLU(inplace=True)
        elif before == 'LReLU':
            self.before = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if hasattr(self, 'before'):
            x = self.before(x)

        x = self.conv(x)

        if hasattr(self, 'after'):
            x = self.after(x)

        return x


class UNet(nn.Module):
    def __init__(self, input_channels=3, output_channels=1):
        super(UNet, self).__init__()

        # encoder: strided 4x4 convolutions, each halving the spatial resolution
        self.Cv0 = Cvi(input_channels, 64)
        self.Cv1 = Cvi(64, 128, before='LReLU', after='BN', dilation=1)
        self.Cv2 = Cvi(128, 256, before='LReLU', after='BN', dilation=1)
        self.Cv3 = Cvi(256, 512, before='LReLU', after='BN', dilation=1)
        self.Cv4 = Cvi(512, 512, before='LReLU', after='BN', dilation=1)
        self.Cv5 = Cvi(512, 512, before='LReLU', dilation=1)

        # decoder: transposed convolutions; skip connections double the input channels via torch.cat
        self.CvT6 = CvTi(512, 512, before='ReLU', after='BN', dilation=1)
        self.CvT7 = CvTi(1024, 512, before='ReLU', after='BN', dilation=1)
        self.CvT8 = CvTi(1024, 256, before='ReLU', after='BN', dilation=1)
        self.CvT9 = CvTi(512, 128, before='ReLU', after='BN', dilation=1)
        self.CvT10 = CvTi(256, 64, before='ReLU', after='BN', dilation=1)
        self.CvT11 = CvTi(128, output_channels, before='ReLU', after='Tanh', dilation=1)

    def forward(self, input):
        # encoder (Cv4 is applied three times, sharing its weights)
        x0 = self.Cv0(input)
        x1 = self.Cv1(x0)
        x2 = self.Cv2(x1)
        x3 = self.Cv3(x2)
        x4_1 = self.Cv4(x3)
        x4_2 = self.Cv4(x4_1)
        x4_3 = self.Cv4(x4_2)
        x5 = self.Cv5(x4_3)

        # decoder (CvT7 is likewise applied three times with shared weights)
        x6 = self.CvT6(x5)

        cat1_1 = torch.cat([x6, x4_3], dim=1)
        x7_1 = self.CvT7(cat1_1)
        cat1_2 = torch.cat([x7_1, x4_2], dim=1)
        x7_2 = self.CvT7(cat1_2)
        cat1_3 = torch.cat([x7_2, x4_1], dim=1)
        x7_3 = self.CvT7(cat1_3)

        cat2 = torch.cat([x7_3, x3], dim=1)
        x8 = self.CvT8(cat2)

        cat3 = torch.cat([x8, x2], dim=1)
        x9 = self.CvT9(cat3)

        cat4 = torch.cat([x9, x1], dim=1)
        x10 = self.CvT10(cat4)

        cat5 = torch.cat([x10, x0], dim=1)
        out = self.CvT11(cat5)

        return out
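As a quick illustrative smoke test (ours, not part of the repository), the model can be instantiated and run on a dummy batch; the import path is an assumption since the file name is not shown in this view.

```
import torch

# Assumes the UNet class above has been imported from wherever this file lives in the repo.
net = UNet(input_channels=3, output_channels=3)
net.eval()
dummy = torch.randn(1, 3, 256, 256)  # 256x256 works: the encoder halves the resolution eight times
with torch.no_grad():
    out = net(dummy)
print(out.shape)  # torch.Size([1, 3, 256, 256])
```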
test.sh
@@ -0,0 +1,35 @@
#!/bin/bash

set -e

# By default, we use the model trained on SSHR (or SHIQ or PSD) to process the testing images of SSHR (or SHIQ or PSD).
# The variable "model_name" can be SSHR, SHIQ, PSD, or mix_SSHR_SHIQ_PSD.


## >>> testing SSHR >>>
# to avoid running out of memory, we split "test.lst" into four parts for testing
num_checkpoint=60 # the index of the checkpoint to use
model_name='SSHR' # find the checkpoints in "checkpoints_${model_name}", like "checkpoints_SSHR"
testing_data_name='SSHR' # testing dataset name
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part1.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part2.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part3.lst'
python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part4.lst'
## <<< testing SSHR <<<


# ## >>> testing SHIQ >>>
# num_checkpoint=60
# model_name='SHIQ'
# testing_data_name='SHIQ'
# python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SHIQ_data_10825/test.lst'
# ## <<< testing SHIQ <<<


# ## >>> testing PSD >>>
# num_checkpoint=60
# model_name='PSD'
# testing_data_name='PSD'
# python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/PSD/test.lst'
# # python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/PSD/train.lst'
# ## <<< testing PSD <<<