diff --git a/.gitignore b/.gitignore
index 126d97f..d906a8b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,5 @@ test_result_*/
 logs_*/
 dataset/*
 !dataset/README.org
-checkpoints_*/
\ No newline at end of file
+checkpoints_*/
+__pycache__/
\ No newline at end of file
diff --git a/README.md b/README.md
index 4836227..d66702c 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,151 @@
-# TSHRNet
\ No newline at end of file
+# TSHRNet
+
+**Towards High-Quality Specular Highlight Removal by Leveraging Large-scale Synthetic Data**
+
+Gang Fu, Qing Zhang, Lei Zhu, Chunxia Xiao, and Ping Li
+
+In ICCV 2023
+
+In this paper, our goal is to remove specular highlights from object-level images. To this end, we propose a three-stage network for specular highlight removal, consisting of (i) physics-based specular highlight removal, (ii) specular-free refinement, and (iii) tone correction. In addition, we present a large-scale synthetic dataset of object-level images, in which each input image has corresponding albedo, shading, specular residue, diffuse, and tone-corrected diffuse images.
+
+## Prerequisites of our implementation
+
+```
+conda create --yes --name TSHRNet python=3.9
+conda activate TSHRNet
+conda install --yes pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.6 -c pytorch -c nvidia
+conda install --yes tqdm matplotlib
+```
+
+Please see "dependencies_install.sh".
+
+## Datasets
+
+* Our SSHR dataset is available at [OneDrive](https://polyuit-my.sharepoint.com/:u:/g/personal/gangfu_polyu_edu_hk/ERVx4DV78jxGq-1HCPmRsssBOYHPvL_eYmKbGMrELxm8uw?e=tdDAeu) or [Google Drive](https://drive.google.com/file/d/1iBBYIvF5ujLuUe6l22eArFRxFPYAPLVR/view?usp=sharing) (~5G).
+* The SHIQ dataset can be found in the project [SHIQ](https://github.com/fu123456/SHIQ).
+* The PSD dataset can be found in the project [SpecularityNet-PSD](https://github.com/jianweiguo/SpecularityNet-PSD).
+
+## Training
+
+The bash shell script "train.sh" provides the command lines for training on different datasets.
+
+### Training on SSHR
+
+```
+python train_4_networks.py \
+    -trdd dataset \
+    -trdlf dataset/SSHR/train_6_tuples.lst \
+    -dn SSHR
+```
+
+### Training on SHIQ
+
+```
+python train_4_networks_mix.py \
+    -trdd dataset \
+    -trdlf dataset/SHIQ_data_10825/train.lst \
+    -dn SHIQ
+```
+
+### Training on PSD
+
+```
+python train_4_networks_mix.py \
+    -trdd dataset \
+    -trdlf dataset/M-PSD_Dataset/train_validation.lst \
+    -dn PSD_debug_1
+```
+
+### Training on the mixed data
+
+```
+cat dataset/SSHR/train_4_tuples.lst dataset/SHIQ_data_10825/train.lst dataset/M-PSD_Dataset/train.lst >> dataset/train_mix.lst
+python train_4_networks_mix.py \
+    -trdd dataset \
+    -trdlf dataset/train_mix.lst \
+    -dn mix_SSHR_SHIQ_PSD
+```
+
+## Testing
+
+The bash shell script "test.sh" provides the command lines for testing on different datasets.
+
+### Testing on SSHR
+
+Note that we split the SSHR test list into four parts for testing, to avoid running out of memory; a sketch for generating the part files is shown below.
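+
+If the four part files are not already included in your copy of the dataset, they can be generated with a small script along the lines of the following sketch (the name of the unsplit list, "test_6_tuples.lst", is an assumption; point `src` at whichever SSHR test list you use):
+
+```
+# split_test_list.py -- sketch: split an SSHR test list into four roughly equal parts
+from pathlib import Path
+
+src = Path("dataset/SSHR/test_6_tuples.lst")   # assumed name of the unsplit list
+lines = src.read_text().splitlines()
+n_parts = 4
+chunk = -(-len(lines) // n_parts)              # ceiling division
+
+for k in range(n_parts):
+    part = lines[k * chunk:(k + 1) * chunk]
+    out = src.with_name(f"{src.stem}_part{k + 1}.lst")  # e.g. test_6_tuples_part1.lst
+    out.write_text("\n".join(part) + "\n")
+    print(out, len(part), "lines")
+```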
+
+```
+num_checkpoint=60 # the index of the checkpoint to use
+model_name='SSHR' # find the checkpoints in "checkpoints_${model_name}", e.g. "checkpoints_SSHR"
+testing_data_name='SSHR' # testing dataset name
+# process the testing images
+python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part1.lst'
+python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part2.lst'
+python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part3.lst'
+python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_6_tuples_part4.lst'
+```
+
+### Testing on SHIQ
+
+```
+num_checkpoint=60
+model_name='SHIQ'
+testing_data_name='SHIQ'
+python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SHIQ_data_10825/test.lst'
+```
+
+### Testing on PSD
+
+```
+num_checkpoint=60
+model_name='PSD'
+testing_data_name='PSD'
+python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/M-PSD_Dataset/test.lst'
+```
+
+## Index structure of image groups
+
+Please put the SSHR, SHIQ, and PSD datasets into the "dataset" directory.
+
+For seven-tuple image groups (i.e. including additional albedo and shading images), the index structure has the following form:
+
+```
+SSHR/train/048721/0024_i.jpg SSHR/train/048721/0024_a.jpg SSHR/train/048721/0024_s.jpg SSHR/train/048721/0024_r.jpg SSHR/train/048721/0024_d.jpg SSHR/train/048721/0024_d_tc.jpg SSHR/train/048721/0024_m.jpg
+SSHR/train/048721/0078_i.jpg SSHR/train/048721/0078_a.jpg SSHR/train/048721/0078_s.jpg SSHR/train/048721/0078_r.jpg SSHR/train/048721/0078_d.jpg SSHR/train/048721/0078_d_tc.jpg SSHR/train/048721/0024_m.jpg
+... ...
+```
+
+From left to right, they are the input, albedo, shading, specular residue, diffuse, tone-corrected diffuse, and object mask images, respectively.
+
+For four-tuple image groups, the index structure has the following form (taking our SSHR as an example):
+
+```
+SSHR/train/048721/0044_i.jpg SSHR/train/048721/0044_r.jpg SSHR/train/048721/0044_d.jpg SSHR/train/048721/0044_d_tc.jpg
+SSHR/train/048721/0023_i.jpg SSHR/train/048721/0023_r.jpg SSHR/train/048721/0023_d.jpg SSHR/train/048721/0023_d_tc.jpg
+... ...
+```
+
+From left to right, they are the input, specular residue, diffuse, and tone-corrected diffuse images, respectively. The main reason for providing this form is that it allows the network to be trained together with the four-tuple image groups of the SHIQ and PSD datasets. Please download our SSHR dataset and see it for more details.
+
+For SHIQ, four-tuple image groups look like:
+
+```
+SHIQ_data_10825/train/00583_A.png SHIQ_data_10825/train/00583_S.png SHIQ_data_10825/train/00583_D.png SHIQ_data_10825/train/00583_D_tc.png
+SHIQ_data_10825/train/08766_A.png SHIQ_data_10825/train/08766_S.png SHIQ_data_10825/train/08766_D.png SHIQ_data_10825/train/08766_D_tc.png
+... ...
+```
+
+For PSD, the images can be arranged in the same form in a list file.
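+
+Before training or testing, it can help to verify that a list file and the dataset directory agree. The sketch below is not part of the released code; it mirrors how the data loaders resolve paths (each whitespace-separated field on a line is joined with the directory passed via "-trdd"/"-tedd") and reports any missing images:
+
+```
+# check_list.py -- sketch: verify that every path in a tuple list exists
+import os
+
+dataset_dir = "dataset"                             # value passed to -trdd / -tedd
+list_file = "dataset/SHIQ_data_10825/train.lst"     # any 4- or 7-tuple list file
+
+with open(list_file) as f:
+    for line_no, line in enumerate(f, 1):
+        fields = line.split()
+        missing = [p for p in fields if not os.path.isfile(os.path.join(dataset_dir, p))]
+        if missing:
+            print("line", line_no, "missing:", " ".join(missing))
+```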
+ +## Citation + +``` +@inproceedings{zhang-2017-stack, + author = {Fu, Gang and Zhang, Qing and Zhu, Lei and Xiao, Chunxia and Li, Ping}, + title = {Towards high-quality specular highlight removal by leveraging large-scale synthetic data}, + booktitle = {Proceedings of the IEEE International Conference on Computer Vision}, + year = {2023}, + pages = {To appear}, +} +``` diff --git a/dataset/README.org b/dataset/README.org new file mode 100644 index 0000000..e7c8519 --- /dev/null +++ b/dataset/README.org @@ -0,0 +1 @@ +Put the SSHR, SHIQ, and PSD datasets into this directory. diff --git a/dependencies_install.sh b/dependencies_install.sh new file mode 100755 index 0000000..f36d27e --- /dev/null +++ b/dependencies_install.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +## run with "bash -i dependencies_install.sh" +## if it does not work with errors, please run line by line in shell + +conda create --yes --name TSHRNet python=3.9 +conda activate TSHRNet +conda install --yes pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.6 -c pytorch -c nvidia +conda install --yes tqdm matplotlib diff --git a/models/UNet.py b/models/UNet.py new file mode 100755 index 0000000..b6f64c9 --- /dev/null +++ b/models/UNet.py @@ -0,0 +1,154 @@ +import torch.nn.functional as F +import torch.nn as nn +import torch + +def weights_init(init_type='gaussian'): + def init_fun(m): + classname = m.__class__.__name__ + if (classname.find('Conv') == 0 or classname.find( + 'Linear') == 0) and hasattr(m, 'weight'): + if init_type == 'gaussian': + nn.init.normal_(m.weight, 0.0, 0.02) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) + elif init_type == 'default': + pass + else: + assert 0, "Unsupported initialization: {}".format(init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias, 0.0) + + return init_fun + + +class Cvi(nn.Module): + def __init__(self, in_channels, out_channels, before=None, after=False, kernel_size=4, stride=2, + padding=1, dilation=1, groups=1, bias=False): + super(Cvi, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) + self.conv.apply(weights_init('gaussian')) + + if after=='BN': + self.after = nn.BatchNorm2d(out_channels) + elif after=='Tanh': + self.after = torch.tanh + elif after=='sigmoid': + self.after = torch.sigmoid + + if before=='ReLU': + self.before = nn.ReLU(inplace=True) + elif before=='LReLU': + self.before = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + + if hasattr(self, 'before'): + x = self.before(x) + + x = self.conv(x) + + if hasattr(self, 'after'): + x = self.after(x) + + return x + + +class CvTi(nn.Module): + def __init__(self, in_channels, out_channels, before=None, after=False, kernel_size=4, stride=2, + padding=1, dilation=1, groups=1, bias=False): + super(CvTi, self).__init__() + # with errors: TypeError: conv_transpose2d(): argument 'output_padding' (position 6) must be tuple of ints, not tuple + # self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias) + self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=True) + self.conv.apply(weights_init('gaussian')) + + if after=='BN': + self.after = nn.BatchNorm2d(out_channels) + elif after=='Tanh': + self.after = 
torch.tanh + elif after=='sigmoid': + self.after = torch.sigmoid + + if before=='ReLU': + self.before = nn.ReLU(inplace=True) + elif before=='LReLU': + self.before = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + + if hasattr(self, 'before'): + x = self.before(x) + + x = self.conv(x) + + if hasattr(self, 'after'): + x = self.after(x) + + return x + +class UNet(nn.Module): + def __init__(self, input_channels=3, output_channels=1): + super(UNet, self).__init__() + + self.Cv0 = Cvi(input_channels, 64) + + self.Cv1 = Cvi(64, 128, before='LReLU', after='BN', dilation=1) + + self.Cv2 = Cvi(128, 256, before='LReLU', after='BN', dilation=1) + + self.Cv3 = Cvi(256, 512, before='LReLU', after='BN', dilation=1) + + self.Cv4 = Cvi(512, 512, before='LReLU', after='BN', dilation=1) + + self.Cv5 = Cvi(512, 512, before='LReLU', dilation=1) + + self.CvT6 = CvTi(512, 512, before='ReLU', after='BN', dilation=1) + + self.CvT7 = CvTi(1024, 512, before='ReLU', after='BN', dilation=1) + + self.CvT8 = CvTi(1024, 256, before='ReLU', after='BN', dilation=1) + + self.CvT9 = CvTi(512, 128, before='ReLU', after='BN', dilation=1) + + self.CvT10 = CvTi(256, 64, before='ReLU', after='BN', dilation=1) + + self.CvT11 = CvTi(128, output_channels, before='ReLU', after='Tanh', dilation=1) + + def forward(self, input): + # encoder + x0 = self.Cv0(input) + x1 = self.Cv1(x0) + x2 = self.Cv2(x1) + x3 = self.Cv3(x2) + x4_1 = self.Cv4(x3) + x4_2 = self.Cv4(x4_1) + x4_3 = self.Cv4(x4_2) + x5 = self.Cv5(x4_3) + + # decoder + x6 = self.CvT6(x5) + + cat1_1 = torch.cat([x6, x4_3], dim=1) + x7_1 = self.CvT7(cat1_1) + cat1_2 = torch.cat([x7_1, x4_2], dim=1) + x7_2 = self.CvT7(cat1_2) + cat1_3 = torch.cat([x7_2, x4_1], dim=1) + x7_3 = self.CvT7(cat1_3) + + cat2 = torch.cat([x7_3, x3], dim=1) + x8 = self.CvT8(cat2) + + cat3 = torch.cat([x8, x2], dim=1) + x9 = self.CvT9(cat3) + + cat4 = torch.cat([x9, x1], dim=1) + x10 = self.CvT10(cat4) + + cat5 = torch.cat([x10, x0], dim=1) + out = self.CvT11(cat5) + + return out diff --git a/test.sh b/test.sh new file mode 100644 index 0000000..d6c8288 --- /dev/null +++ b/test.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +set -e + +# In default, we use the model trained on SSHR (or SHIQ or PSD) to process the testing images of SSHR (or SHIQ or PSD). +# The variable of "model_name" can be SSHR or SHIQ or PSD or mix_SSHR_SHIQ_PSD. 
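+# For each block below, the checkpoints are expected at "checkpoints_${model_name}/UNet1_${num_checkpoint}.pth" ... "UNet4_${num_checkpoint}.pth",
+# and the results are written under "test_result_${model_name}/${testing_data_name}_${model_name}_${num_checkpoint}/" (see test_4_networks.py).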
+ + +## >>> testing SSHR >>> +# due to out of memory, we split "test.lst" into four parts for testing +num_checkpoint=60 # the indexes of the used checkpoints +model_name='SSHR' # find the checkpoints in "checkpoints_${model_name}, like "checkpoints_SSHR" +testing_data_name='SSHR' # testing dataset name +python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part1.lst' +python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part2.lst' +python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part3.lst' +python test_4_networks.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SSHR/test_7_tuples_part4.lst' +## <<< testing SSHR <<< + + +# ## >>> testing SHIQ >>> +# num_checkpoint=60 +# model_name='SHIQ' +# testing_data_name='SHIQ' +# python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/SHIQ_data_10825/test.lst' +# ## <<< testing SHIQ <<< + + +# ## >>> testing PSD >>> +# num_checkpoint=60 +# model_name='PSD' +# testing_data_name='PSD' +# python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/PSD/test.lst' +# # python test_4_networks_mix.py -mn ${model_name} -l ${num_checkpoint} -tdn ${testing_data_name} -tedd 'dataset' -tedlf 'dataset/PSD/train.lst' +# ## <<< testing PSD <<< diff --git a/test_4_networks.py b/test_4_networks.py new file mode 100644 index 0000000..1d872f6 --- /dev/null +++ b/test_4_networks.py @@ -0,0 +1,199 @@ +from utils.data_loader_seven_tuples import ImageDataset, ImageTransform, generate_training_data_list, generate_testing_data_list + +from models.UNet import UNet +from torchvision.utils import make_grid +from torchvision.utils import save_image +from torchvision import models +from torchvision import transforms +from torch.autograd import Variable +from collections import OrderedDict +from PIL import Image +from tqdm import tqdm + +import matplotlib.pyplot as plt +import torch.optim as optim +import torch.nn as nn +import numpy as np +import argparse +import time +import torch +import os + +torch.manual_seed(1) + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('-l', '--load', type=str, default=None, help='the number of checkpoints') + parser.add_argument('-s', '--image_size', type=int, default=256) + parser.add_argument('-tedd', '--testing_data_dir', type=str, default='dataset') + parser.add_argument('-tedlf', '--testing_data_list_file', type=str, default='dataset/SSHR/test_7_tuples.lst') + parser.add_argument('-mn', '--model_name', type=str, default='SSHR') + parser.add_argument('-tdn', '--testing_data_name', type=str, default='SSHR') + return parser + +def fix_model_state_dict(state_dict): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k + if name.startswith('module.'): + name = name[7:] + new_state_dict[name] = v + return new_state_dict + +def unnormalize(x): + x = x.transpose(1, 3) + x = x * torch.Tensor((0.5, )) + torch.Tensor((0.5, )) + x = x.transpose(1, 3) + return x + +def test(UNet1, UNet2, UNet3, UNet4, model_name, test_dataset): + device = "cuda" if torch.cuda.is_available() else "cpu" + + UNet1.to(device) + UNet2.to(device) + UNet3.to(device) + 
UNet4.to(device) + + if device == 'cuda': + UNet1 = torch.nn.DataParallel(UNet1) + UNet2 = torch.nn.DataParallel(UNet2) + UNet3 = torch.nn.DataParallel(UNet3) + UNet4 = torch.nn.DataParallel(UNet4) + print("parallel mode") + + print("device:{}".format(device)) + + UNet1.eval() + UNet2.eval() + UNet3.eval() + UNet4.eval() + + # dirs for saving results + dir_t1 = './test_result_' + model_name + '/' + key_dir + '/' + 'estimated_albedo' + dir_t2 = './test_result_' + model_name + '/' + key_dir + '/' + 'estimated_shading' + dir_t3 = './test_result_' + model_name + '/' + key_dir + '/' + 'grid' + dir_t4 = './test_result_' + model_name + '/' + key_dir + '/' + 'estimated_diffuse' + dir_t5 = './test_result_' + model_name + '/' + key_dir + '/' + 'estimated_diffuse_tc' + + for n, (img, gt_albedo, gt_shading, gt_specular_residue, gt_diffuse, gt_diffuse_tc, object_mask) in enumerate([test_dataset[i] for i in range(test_dataset.__len__())]): + + img = torch.unsqueeze(img, dim=0) + gt_albedo = torch.unsqueeze(gt_albedo, dim=0) + gt_shading = torch.unsqueeze(gt_shading, dim=0) + gt_diffuse = torch.unsqueeze(gt_diffuse, dim=0) + gt_diffuse_tc = torch.unsqueeze(gt_diffuse_tc, dim=0) + object_mask = torch.unsqueeze(object_mask, dim=0) + + img = img.to(device) + object_mask = object_mask.to(device) + + with torch.no_grad(): + ## estimations in our three-stage network + # estimations in the first stage + estimated_albedo = UNet1(img) + estimated_shading = UNet2(img) + estimated_specular_residue = (img - estimated_albedo * estimated_shading) + + # estimation in the second stage + G3_input = torch.cat([estimated_albedo * estimated_shading, img], dim=1) + estimated_diffuse_refined = UNet3(G3_input) + + # estimation in the third stage + G4_input = torch.cat([estimated_diffuse_refined, estimated_specular_residue, img], dim=1) + estimated_diffuse_tc = UNet4(G4_input) + ## end + + # to cpu + estimated_albedo = estimated_albedo.to(torch.device('cpu')) + estimated_shading = estimated_shading.to(torch.device('cpu')) + estimated_diffuse_refined = estimated_diffuse_refined.to(torch.device('cpu')) + estimated_diffuse_tc = estimated_diffuse_tc.to(torch.device('cpu')) + img = img.to(torch.device('cpu')) + object_mask = object_mask.to(torch.device('cpu')) + + grid= make_grid(torch.cat((img,gt_diffuse,gt_diffuse_tc, estimated_diffuse_refined * object_mask, estimated_diffuse_tc * object_mask), dim=0)) + + temp = len(test_dataset.img_list['path_i'][n].split('/')) + r_subdir = test_dataset.img_list['path_i'][n].split('/')[temp-2] + basename = test_dataset.img_list['path_i'][n].split('/')[temp-1] + print(r_subdir) + print(basename) + subdir = os.path.join(dir_t3, r_subdir) + if not os.path.exists(subdir): + os.makedirs(subdir) + grid_name = os.path.join(subdir, basename) + save_image(grid, grid_name) + + # # save albedo image + # estimated_albedo = transforms.ToPILImage(mode='RGB')((estimated_albedo)[0, :, :, :] *(object_mask)[0, :, :, :]) + # subdir = os.path.join(dir_t1, r_subdir) + # if not os.path.exists(subdir): + # os.makedirs(subdir) + # detected_shadow_name = os.path.join(subdir, basename) + # estimated_albedo.save(detected_shadow_name) + + # # save shading image + # estimated_shading = transforms.ToPILImage(mode='RGB')((estimated_shading)[0, :, :, :] * (object_mask)[0, :, :, :]) + # subdir = os.path.join(dir_t2, r_subdir) + # if not os.path.exists(subdir): + # os.makedirs(subdir) + # shadow_removal_name = os.path.join(subdir, basename) + # estimated_shading.save(shadow_removal_name) + + # save specular-refined image + 
estimated_diffuse_refined = transforms.ToPILImage(mode='RGB')((estimated_diffuse_refined)[0, :, :, :] * (object_mask)[0, :, :, :]) + subdir = os.path.join(dir_t4, r_subdir) + if not os.path.exists(subdir): + os.makedirs(subdir) + estimated_diffuse_name = os.path.join(subdir, basename) + estimated_diffuse_refined.save(estimated_diffuse_name) + + # save tone-corrected diffuse image + estimated_diffuse_tc = transforms.ToPILImage(mode='RGB')((estimated_diffuse_tc)[0, :, :, :] * (object_mask)[0, :, :, :]) + subdir = os.path.join(dir_t5, r_subdir) + if not os.path.exists(subdir): + os.makedirs(subdir) + estimated_diffuse_tc_name = os.path.join(subdir, basename) + estimated_diffuse_tc.save(estimated_diffuse_tc_name) + + +def main(parser): + UNet1 = UNet(input_channels=3, output_channels=3) + UNet2 = UNet(input_channels=3, output_channels=3) + UNet3 = UNet(input_channels=6, output_channels=3) + UNet4 = UNet(input_channels=9, output_channels=3) + + if parser.load is not None: + print('load checkpoint ' + parser.load) + # load UNet weights + UNet1_weights = torch.load('./checkpoints_'+parser.model_name+'/UNet1_'+parser.load+'.pth') + UNet1.load_state_dict(fix_model_state_dict(UNet1_weights)) + UNet2_weights = torch.load('./checkpoints_'+parser.model_name+'/UNet2_'+parser.load+'.pth') + UNet2.load_state_dict(fix_model_state_dict(UNet2_weights)) + UNet3_weights = torch.load('./checkpoints_'+parser.model_name+'/UNet3_'+parser.load+'.pth') + UNet3.load_state_dict(fix_model_state_dict(UNet3_weights)) + UNet4_weights = torch.load('./checkpoints_'+parser.model_name+'/UNet4_'+parser.load+'.pth') + UNet4.load_state_dict(fix_model_state_dict(UNet4_weights)) + + mean = (0.5,) + std = (0.5,) + + size = parser.image_size + testing_data_dir = parser.testing_data_dir + testing_data_list_file = parser.testing_data_list_file + model_name = parser.model_name + + test_img_list = generate_testing_data_list(testing_data_dir, testing_data_list_file) + test_dataset = ImageDataset(img_list=test_img_list, img_transform=ImageTransform(size=size, mean=mean, std=std), phase='test') + test(UNet1, UNet2, UNet3, UNet4, model_name, test_dataset) + + +if __name__ == "__main__": + parser = get_parser().parse_args() + if parser.load is not None: + load_num = str(parser.load) + model_name = parser.model_name + testing_data_name = parser.testing_data_name + key_dir = testing_data_name + '_' + model_name + '_' + load_num # like SSHR_SSHR_60 + + main(parser) diff --git a/test_4_networks_mix.py b/test_4_networks_mix.py new file mode 100644 index 0000000..495eab6 --- /dev/null +++ b/test_4_networks_mix.py @@ -0,0 +1,176 @@ +from utils.data_loader_four_tuples_mix import ImageDataset, ImageTransform, generate_training_data_list, generate_testing_data_list + +from models.UNet import UNet +from torchvision.utils import make_grid +from torchvision.utils import save_image +from torchvision import models +from torchvision import transforms +from torch.autograd import Variable +from collections import OrderedDict +from PIL import Image +from tqdm import tqdm + +import matplotlib.pyplot as plt +import torch.optim as optim +import torch.nn as nn +import numpy as np +import argparse +import time +import torch +import os + +torch.manual_seed(1) + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('-l', '--load', type=str, default=None, help='the number of checkpoints') + parser.add_argument('-s', '--image_size', type=int, default=256) + parser.add_argument('-tedd', '--testing_data_dir', type=str, 
default='dataset/shapenet_specular_1500/testing_data') + parser.add_argument('-tedlf', '--testing_data_list_file', type=str, default='dataset/shapenet_specular_1500/test.lst') + parser.add_argument('-mn', '--model_name', type=str, default='SSHR') + parser.add_argument('-tdn', '--testing_data_name', type=str, default='SSHR') + return parser + + +def fix_model_state_dict(state_dict): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k + if name.startswith('module.'): + name = name[7:] + new_state_dict[name] = v + return new_state_dict + + +def unnormalize(x): + x = x.transpose(1, 3) + x = x * torch.Tensor((0.5, )) + torch.Tensor((0.5, )) + x = x.transpose(1, 3) + return x + + +def test(UNet1, UNet2, UNet3, UNet4, model_name, test_dataset): + device = "cuda" if torch.cuda.is_available() else "cpu" + + UNet1.to(device) + UNet2.to(device) + UNet3.to(device) + UNet4.to(device) + + if device == 'cuda': + UNet1 = torch.nn.DataParallel(UNet1) + UNet2 = torch.nn.DataParallel(UNet2) + UNet3 = torch.nn.DataParallel(UNet3) + UNet4 = torch.nn.DataParallel(UNet4) + print("parallel mode") + + print("device:{}".format(device)) + + UNet1.eval() + UNet2.eval() + UNet3.eval() + UNet4.eval() + + # dirs for saving results + dir_t1 = './test_result_' + model_name + '/' + key_dir + '/' + 'estimated_albedo' + dir_t2 = './test_result_' + model_name + '/' + key_dir + '/' + 'estimated_shading' + dir_t3 = './test_result_' + model_name + '/' + key_dir + '/' + 'grid' + dir_t4 = './test_result_' + model_name + '/' + key_dir + '/' + 'estimated_diffuse' + dir_t5 = './test_result_' + model_name + '/' + key_dir + '/' + 'estimated_diffuse_tc' + + for n, (input_img, gt_specular_residue, gt_diffuse, gt_diffuse_tc) in enumerate([test_dataset[i] for i in range(test_dataset.__len__())]): + + # print(test_dataset.img_list['path_i'][n].split('/')[4][:-4]) + + input_img = torch.unsqueeze(input_img, dim=0) + gt_diffuse = torch.unsqueeze(gt_diffuse, dim=0) + gt_diffuse_tc = torch.unsqueeze(gt_diffuse_tc, dim=0) + + with torch.no_grad(): + estimated_diffuse = UNet1(input_img.to(device)) + estimated_specular_residue = UNet2(input_img.to(device)) + # estimat diffuse + G3_data = torch.cat([estimated_diffuse, input_img.to(device)], dim=1) + estimated_diffuse_refined = UNet3(G3_data.to(device)) + + # the third stage (tone correction) + input_img = input_img.to(device) + G4_input = torch.cat([estimated_diffuse_refined, estimated_specular_residue, input_img], dim=1) + estimated_diffuse_tc = UNet4(G4_input.to(device)) + + # to cpu + estimated_diffuse = estimated_diffuse.to(torch.device('cpu')) + estimated_diffuse_refined = estimated_diffuse_refined.to(torch.device('cpu')) + estimated_diffuse_tc = estimated_diffuse_tc.to(torch.device('cpu')) + input_img = input_img.to(torch.device('cpu')) + + + grid= make_grid(torch.cat((unnormalize(input_img), unnormalize(gt_diffuse), unnormalize(estimated_diffuse_refined), unnormalize(estimated_diffuse_tc)), dim=0)) + + temp = len(test_dataset.img_list['path_i'][n].split('/')) + r_subdir = test_dataset.img_list['path_i'][n].split('/')[temp-2] + basename = test_dataset.img_list['path_i'][n].split('/')[temp-1] + print(r_subdir) + print(basename) + subdir = os.path.join(dir_t3, r_subdir) + if not os.path.exists(subdir): + os.makedirs(subdir) + grid_name = os.path.join(subdir, basename) + save_image(grid, grid_name) + + # save diffuse images + estimated_diffuse_refined = transforms.ToPILImage(mode='RGB')(unnormalize(estimated_diffuse_refined)[0, :, :, :]) + subdir = os.path.join(dir_t4, 
r_subdir) + if not os.path.exists(subdir): + os.makedirs(subdir) + estimated_diffuse_name = os.path.join(subdir, basename) + estimated_diffuse_refined.save(estimated_diffuse_name) + + # save tone-corrected diffuse images + estimated_diffuse_tc = transforms.ToPILImage(mode='RGB')(unnormalize(estimated_diffuse_tc)[0, :, :, :]) + subdir = os.path.join(dir_t5, r_subdir) + if not os.path.exists(subdir): + os.makedirs(subdir) + estimated_diffuse_tc_name = os.path.join(subdir, basename) + estimated_diffuse_tc.save(estimated_diffuse_tc_name) + + +def main(parser): + UNet1 = UNet(input_channels=3, output_channels=3) + UNet2 = UNet(input_channels=3, output_channels=3) + UNet3 = UNet(input_channels=6, output_channels=3) + UNet4 = UNet(input_channels=9, output_channels=3) + + if parser.load is not None: + print('load checkpoint ' + parser.load) + # load UNet weights + UNet1_weights = torch.load('./checkpoints_'+parser.model_name+'/UNet1_'+parser.load+'.pth') + UNet1.load_state_dict(fix_model_state_dict(UNet1_weights)) + UNet2_weights = torch.load('./checkpoints_'+parser.model_name+'/UNet2_'+parser.load+'.pth') + UNet2.load_state_dict(fix_model_state_dict(UNet2_weights)) + UNet3_weights = torch.load('./checkpoints_'+parser.model_name+'/UNet3_'+parser.load+'.pth') + UNet3.load_state_dict(fix_model_state_dict(UNet3_weights)) + UNet4_weights = torch.load('./checkpoints_'+parser.model_name+'/UNet4_'+parser.load+'.pth') + UNet4.load_state_dict(fix_model_state_dict(UNet4_weights)) + + mean = (0.5,) + std = (0.5,) + + size = parser.image_size + testing_data_dir = parser.testing_data_dir + testing_data_list_file = parser.testing_data_list_file + model_name = parser.model_name + + test_img_list = generate_testing_data_list(testing_data_dir, testing_data_list_file) + test_dataset = ImageDataset(img_list=test_img_list, img_transform=ImageTransform(size=size, mean=mean, std=std), phase='test') + test(UNet1, UNet2, UNet3, UNet4, model_name, test_dataset) + +if __name__ == "__main__": + parser = get_parser().parse_args() + if parser.load is not None: + load_num = str(parser.load) + model_name = parser.model_name + testing_data_name = parser.testing_data_name + key_dir = testing_data_name + '_' + model_name + '_' + load_num # like SSHR_SSHR_60 + + main(parser) diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..c13e9a3 --- /dev/null +++ b/train.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +set -e + +# to train on our SSHR dataset +python train_4_networks.py \ + -trdd dataset \ + -trdlf dataset/SSHR/train_7_tuples.lst \ + -dn SSHR + +# # to train on SHIQ dataset +# python train_4_networks_mix.py \ + # -trdd dataset \ + # -trdlf dataset/SHIQ_data_10825/train.lst \ + # -dn SHIQ + +# # to train on PSD dataset +# python train_4_networks_mix.py \ + # -trdd dataset\ + # -trdlf dataset/PSD/train.lst \ + # -dn PSD_debug_3_train + + +# # to train on the mix data of SSHR, SHIQ, and PSD, which could produce better results for real images +# # generate list file +# cat dataset/SSHR/train_4_tuples.lst dataset/SHIQ_data_10825/train.lst dataset/PSD/train.lst >> dataset/train_mix.lst +# shuf train_mix.lst -o train_mix.lst +# cat dataset/SSHR/test_4_tuples.lst dataset/SHIQ_data_10825/test.lst dataset/PSD/test.lst >> dataset/test_mix.lst +# python train_4_networks_mix.py \ + # -trdd dataset \ + # -trdlf dataset/train_mix.lst \ + # -dn mix_SSHR_SHIQ_PSD diff --git a/train_4_networks.py b/train_4_networks.py new file mode 100644 index 0000000..0ed4251 --- /dev/null +++ b/train_4_networks.py @@ -0,0 +1,213 @@ +from 
utils.data_loader_seven_tuples import ImageDataset, ImageTransform, generate_training_data_list, generate_testing_data_list +from utils.fg_tools import fix_model_state_dict, plot_log, check_dir, unnormalize + +from models.UNet import UNet +from tqdm import tqdm +from torchvision import transforms +import torch.optim as optim +import torch.nn as nn +import numpy as np +import argparse +import time +import torch +import os + +torch.manual_seed(1) + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('-ne', '--num_epoch', type=int, default=100, help='Number of epochs') + parser.add_argument('-bs', '--batch_size', type=int, default=16, help='Batch size') + parser.add_argument('-lne', '--load_num_epoch', type=str, default=None, help='the number of checkpoints') + parser.add_argument('-s', '--image_size', type=int, default=256) + parser.add_argument('-cs', '--crop_size', type=int, default=256) + parser.add_argument('-lr', '--lr', type=float, default=1e-4) + parser.add_argument('-pdn', '--pretrained_dataset_name', type=str, default=None, help='pretrained model name') + # settings for training and testing data, which can be from different datasets + # Here, using testing data to generate some temp results for observing variations in the results + # for "dataset_name" (e.g. SSHR_SSHR), the first and second "SSHR"s refer to training and testing dataset name, respectively + # this "dataset_name" can be used for mkdir specific dirs for saving results + parser.add_argument('-dn', '--dataset_name', type=str, default='SSHR') + parser.add_argument('-trdd', '--train_data_dir', type=str, default='dataset/shapenet_specular_1500/training_data') + parser.add_argument('-trdlf', '--train_data_list_file', type=str, default='dataset/shapenet_specular_1500/train.lst') + return parser + +def train_model(UNet1, UNet2, UNet3, UNet4, dataloader, load_num_epoch, num_epoch, lr, dataset_name, parser, save_model_name='model'): + + # ensure dirs for saving results + check_dir(dataset_name) + + device = "cuda" if torch.cuda.is_available() else "cpu" + Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor + + UNet1.to(device) + UNet2.to(device) + UNet3.to(device) + UNet4.to(device) + + # GPU in parallel + if device == 'cuda': + UNet1 = torch.nn.DataParallel(UNet1) + UNet2 = torch.nn.DataParallel(UNet2) + UNet3 = torch.nn.DataParallel(UNet3) + UNet4 = torch.nn.DataParallel(UNet4) + + print("device:{}".format(device)) + + beta1, beta2 = 0.5, 0.999 + + optimizerUNet = torch.optim.Adam([{'params': UNet1.parameters()}, + {'params': UNet2.parameters()}, + {'params': UNet3.parameters()}, + {'params': UNet4.parameters()}], + lr=lr, betas=(beta1, beta2)) + + loss_criterion = nn.MSELoss().to(device) + + torch.backends.cudnn.benchmark = True + + mini_batch_size = parser.batch_size + num_train_imgs = len(dataloader.dataset) + batch_size = dataloader.batch_size + + UNets_loss = [] + + if load_num_epoch is not None: + epoch_old = int(load_num_epoch) + else: + epoch_old = 0 + + for epoch in range(1, num_epoch+1): + UNet1.train() + UNet2.train() + UNet3.train() + UNet4.train() + + epoch = epoch + epoch_old + epoch_L_total = 0.0 + + print('Epoch {}/{}'.format(epoch, num_epoch+epoch_old)) + print('(train)') + + for img, gt_albedo, gt_shading, gt_specular_residue, gt_diffuse, gt_diffuse_tc, object_mask in tqdm(dataloader): + img = img.to(device) + gt_shading = gt_shading.to(device) + gt_albedo = gt_albedo.to(device) + gt_diffuse = gt_diffuse.to(device) + gt_specular_residue = 
gt_specular_residue.to(device) + gt_diffuse_tc = gt_diffuse_tc.to(device) + + ## estimations in our three-stage network + # estimations in the first stage + estimated_albedo = UNet1(img) + estimated_shading = UNet2(img) + estimated_specular_residue = (img - estimated_albedo * estimated_shading) + + # estimation in the second stage + G3_input = torch.cat([estimated_albedo * estimated_shading, img], dim=1) + estimated_diffuse_refined = UNet3(G3_input) + + # estimation in the third stage + G4_input = torch.cat([estimated_diffuse_refined, estimated_specular_residue, img], dim=1) + estimated_diffuse_tc = UNet4(G4_input) + ## end + + # Train networks + optimizerUNet.zero_grad() + + ## loss for our three-stage network + + # loss for the first stage (physics-based specular removal) + L_albedo = loss_criterion(estimated_albedo, gt_albedo) + L_shading = loss_criterion(estimated_shading, gt_shading) + L_specular_residue = loss_criterion(estimated_specular_residue, gt_specular_residue) + + # loss for the second stage (specular-free refinement) + L_diffuse_refined = loss_criterion(estimated_diffuse_refined, gt_diffuse) + + # loss for the thrid stage (tone correction) + L_diffuse_tc = loss_criterion(estimated_diffuse_tc, gt_diffuse_tc) + ## end + + # total loss + # If you want to obtain better albedo and shading images, + # it is better to use a lower weighting parameter for the loss of specular residue learning. + # This is mainly attributed to the data transfer from high + # dynamic range to low dynamic range (I is not equal to # A*S+R). + # Please, uncomment the following first line and + # comment the following second line + # L_total = L_albedo + L_shading + 0.01 * L_specular_residue + L_diffuse_refined + L_diffuse_tc + L_total = L_albedo + L_shading + L_specular_residue + L_diffuse_refined + L_diffuse_tc + + L_total.backward() + optimizerUNet.step() + + epoch_L_total += L_total.item() + + print('epoch {} || Epoch_Net_Loss:{:.4f}'.format(epoch, epoch_L_total/batch_size)) + + UNets_loss += [epoch_L_total/batch_size] + t_epoch_start = time.time() + plot_log({'UNets':UNets_loss}, dataset_name, save_model_name) + + if(epoch%10 == 0): + torch.save(UNet1.state_dict(), 'checkpoints_'+dataset_name+'/'+save_model_name+'1_'+str(epoch)+'.pth') + torch.save(UNet2.state_dict(), 'checkpoints_'+dataset_name+'/'+save_model_name+'2_'+str(epoch)+'.pth') + torch.save(UNet3.state_dict(), 'checkpoints_'+dataset_name+'/'+save_model_name+'3_'+str(epoch)+'.pth') + torch.save(UNet4.state_dict(), 'checkpoints_'+dataset_name+'/'+save_model_name+'4_'+str(epoch)+'.pth') + + # update learning rate + lr /= 10 + + return UNet1, UNet2, UNet3, UNet4 + +def main(parser): + UNet1 = UNet(input_channels=3, output_channels=3) + UNet2 = UNet(input_channels=3, output_channels=3) + UNet3 = UNet(input_channels=6, output_channels=3) + UNet4 = UNet(input_channels=9, output_channels=3) + + mean = (0.5,) + std = (0.5,) + + size = parser.image_size + crop_size = parser.crop_size + batch_size = parser.batch_size + num_epoch = parser.num_epoch + train_data_dir = parser.train_data_dir + train_data_list_file = parser.train_data_list_file + lr = parser.lr + dataset_name = parser.dataset_name + pretrained_dataset_name = parser.pretrained_dataset_name + load_num_epoch =parser.load_num_epoch + + if parser.load_num_epoch is not None: + print('load checkpoint ' + parser.load_num_epoch) + # load UNet weights + UNet1_weights = torch.load('./checkpoints_'+pretrained_dataset_name+'/UNet1_'+parser.load_num_epoch+'.pth') + 
UNet1.load_state_dict(fix_model_state_dict(UNet1_weights)) + UNet2_weights = torch.load('./checkpoints_'+pretrained_dataset_name+'/UNet2_'+parser.load_num_epoch+'.pth') + UNet2.load_state_dict(fix_model_state_dict(UNet2_weights)) + UNet3_weights = torch.load('./checkpoints_'+pretrained_dataset_name+'/UNet3_'+parser.load_num_epoch+'.pth') + UNet3.load_state_dict(fix_model_state_dict(UNet3_weights)) + UNet4_weights = torch.load('./checkpoints_'+pretrained_dataset_name+'/UNet4_'+parser.load_num_epoch+'.pth') + UNet4.load_state_dict(fix_model_state_dict(UNet4_weights)) + + train_img_list = generate_training_data_list(train_data_dir, train_data_list_file) + train_dataset = ImageDataset(img_list=train_img_list, + img_transform=ImageTransform(size=size, mean=mean, std=std), + phase='train') + + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + UNet1_ud, UNet2_ud, UNet3_ud, UNet4_ud = train_model(UNet1, UNet2, UNet3, UNet4, + dataloader=train_dataloader, + load_num_epoch=load_num_epoch, + num_epoch=num_epoch, + lr=lr, dataset_name=dataset_name, + parser=parser, save_model_name='UNet') + + +if __name__ == "__main__": + parser = get_parser().parse_args() + main(parser) diff --git a/train_4_networks_mix.py b/train_4_networks_mix.py new file mode 100644 index 0000000..9f07ce8 --- /dev/null +++ b/train_4_networks_mix.py @@ -0,0 +1,201 @@ +from utils.data_loader_four_tuples_mix import ImageDataset, ImageTransform, generate_training_data_list, generate_testing_data_list +from utils.fg_tools import fix_model_state_dict, plot_log, check_dir, unnormalize + +from models.UNet import UNet +from tqdm import tqdm +from torchvision import transforms +import torch.optim as optim +import torch.nn as nn +import numpy as np +import argparse +import time +import torch +import os + +torch.manual_seed(1) + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument('-ne', '--num_epoch', type=int, default=100, help='Number of epochs') + parser.add_argument('-bs', '--batch_size', type=int, default=16, help='Batch size') + parser.add_argument('-lne', '--load_num_epoch', type=str, default=None, help='the number of checkpoints') + parser.add_argument('-s', '--image_size', type=int, default=256) + parser.add_argument('-cs', '--crop_size', type=int, default=256) + parser.add_argument('-lr', '--lr', type=float, default=1e-4) + parser.add_argument('-pdn', '--pretrained_dataset_name', type=str, default=None, help='pretrained model name') + # settings for training and testing data, which can be from different datasets + # Here, using testing data to generate some temp results for observing variations in the results + # for "dataset_name" (e.g. 
SSHR_SSHR), the first and second "SSHR"s refer to training and testing dataset name, respectively + # this "dataset_name" can be used for mkdir specific dirs for saving results + parser.add_argument('-dn', '--dataset_name', type=str, default='SSHR') + parser.add_argument('-trdd', '--train_data_dir', type=str, default='dataset/shapenet_specular_1500/training_data') + parser.add_argument('-trdlf', '--train_data_list_file', type=str, default='dataset/shapenet_specular_1500/train.lst') + return parser + +def train_model(UNet1, UNet2, UNet3, UNet4, dataloader, load_num_epoch, num_epoch, lr, dataset_name, parser, save_model_name='model'): + + # ensure dirs for saving results + check_dir(dataset_name) + + device = "cuda" if torch.cuda.is_available() else "cpu" + Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor + + UNet1.to(device) + UNet2.to(device) + UNet3.to(device) + UNet4.to(device) + + # GPU in parallel + if device == 'cuda': + UNet1 = torch.nn.DataParallel(UNet1) + UNet2 = torch.nn.DataParallel(UNet2) + UNet3 = torch.nn.DataParallel(UNet3) + UNet4 = torch.nn.DataParallel(UNet4) + + print("device:{}".format(device)) + + beta1, beta2 = 0.5, 0.999 + + optimizerUNet = torch.optim.Adam([{'params': UNet1.parameters()}, + {'params': UNet2.parameters()}, + {'params': UNet3.parameters()}, + {'params': UNet4.parameters()}], + lr=lr, betas=(beta1, beta2)) + + loss_criterion = nn.MSELoss().to(device) + + torch.backends.cudnn.benchmark = True + + mini_batch_size = parser.batch_size + num_train_imgs = len(dataloader.dataset) + batch_size = dataloader.batch_size + + UNets_loss = [] + + if load_num_epoch is not None: + epoch_old = int(load_num_epoch) + else: + epoch_old = 0 + + for epoch in range(1, num_epoch+1): + UNet1.train() + UNet2.train() + UNet3.train() + UNet4.train() + + epoch = epoch + epoch_old + epoch_L_total = 0.0 + + print('Epoch {}/{}'.format(epoch, num_epoch+epoch_old)) + print('(train)') + + for input_img, gt_specular_residue, gt_diffuse, gt_diffuse_tc in tqdm(dataloader): + input_img = input_img.to(device) + gt_diffuse = gt_diffuse.to(device) + gt_specular_residue = gt_specular_residue.to(device) + gt_diffuse_tc = gt_diffuse_tc.to(device) + + ## estimations in our three-stage network + # estimations in the first stage + estimated_diffuse = UNet1(input_img) + estimated_specular_residue = UNet2(input_img) + + # estimation in the second stage + G3_input = torch.cat([estimated_diffuse, input_img], dim=1) + estimated_diffuse_refined = UNet3(G3_input) + + # estimation in the third stage + G4_input = torch.cat([estimated_diffuse_refined, estimated_specular_residue, input_img], dim=1) + estimated_diffuse_tc = UNet4(G4_input) + ## end + + # Train networks + optimizerUNet.zero_grad() + + ## loss for our three-stage network + # loss for the first stage (physics-based specular removal) + L_diffuse = loss_criterion(estimated_diffuse, gt_diffuse) + L_specular_residue = loss_criterion(estimated_specular_residue, gt_specular_residue) + + # loss for the second stage (specular-free refinement) + L_diffuse_refined = loss_criterion(estimated_diffuse_refined, gt_diffuse) + + # loss for the thrid stage (tone correction) + L_diffuse_tc = loss_criterion(estimated_diffuse_tc, gt_diffuse_tc) + ## end + + # total loss + L_total = L_diffuse + L_specular_residue + L_diffuse_refined + L_diffuse_tc + + L_total.backward() + optimizerUNet.step() + + epoch_L_total += L_total.item() + + print('epoch {} || Epoch_Net_Loss:{:.4f}'.format(epoch, epoch_L_total/batch_size)) + + UNets_loss += 
[epoch_L_total/batch_size] + t_epoch_start = time.time() + plot_log({'UNets':UNets_loss}, dataset_name, save_model_name) + + if(epoch%10 == 0): + torch.save(UNet1.state_dict(), 'checkpoints_'+dataset_name+'/'+save_model_name+'1_'+str(epoch)+'.pth') + torch.save(UNet2.state_dict(), 'checkpoints_'+dataset_name+'/'+save_model_name+'2_'+str(epoch)+'.pth') + torch.save(UNet3.state_dict(), 'checkpoints_'+dataset_name+'/'+save_model_name+'3_'+str(epoch)+'.pth') + torch.save(UNet4.state_dict(), 'checkpoints_'+dataset_name+'/'+save_model_name+'4_'+str(epoch)+'.pth') + + # update learning rate + lr /= 10 + + return UNet1, UNet2, UNet3, UNet4 + +def main(parser): + UNet1 = UNet(input_channels=3, output_channels=3) + UNet2 = UNet(input_channels=3, output_channels=3) + UNet3 = UNet(input_channels=6, output_channels=3) + UNet4 = UNet(input_channels=9, output_channels=3) + + mean = (0.5,) + std = (0.5,) + + size = parser.image_size + crop_size = parser.crop_size + batch_size = parser.batch_size + num_epoch = parser.num_epoch + train_data_dir = parser.train_data_dir + train_data_list_file = parser.train_data_list_file + lr = parser.lr + dataset_name = parser.dataset_name + pretrained_dataset_name = parser.pretrained_dataset_name + load_num_epoch =parser.load_num_epoch + + if parser.load_num_epoch is not None: + print('load checkpoint ' + parser.load_num_epoch) + # load UNet weights + UNet1_weights = torch.load('./checkpoints_'+pretrained_dataset_name+'/UNet1_'+parser.load_num_epoch+'.pth') + UNet1.load_state_dict(fix_model_state_dict(UNet1_weights)) + UNet2_weights = torch.load('./checkpoints_'+pretrained_dataset_name+'/UNet2_'+parser.load_num_epoch+'.pth') + UNet2.load_state_dict(fix_model_state_dict(UNet2_weights)) + UNet3_weights = torch.load('./checkpoints_'+pretrained_dataset_name+'/UNet3_'+parser.load_num_epoch+'.pth') + UNet3.load_state_dict(fix_model_state_dict(UNet3_weights)) + UNet4_weights = torch.load('./checkpoints_'+pretrained_dataset_name+'/UNet4_'+parser.load_num_epoch+'.pth') + UNet4.load_state_dict(fix_model_state_dict(UNet4_weights)) + + train_img_list = generate_training_data_list(train_data_dir, train_data_list_file) + train_dataset = ImageDataset(img_list=train_img_list, + img_transform=ImageTransform(size=size, mean=mean, std=std), + phase='train') + + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + UNet1_ud, UNet2_ud, UNet3_ud, UNet4_ud = train_model(UNet1, UNet2, UNet3, UNet4, + dataloader=train_dataloader, + load_num_epoch=load_num_epoch, + num_epoch=num_epoch, + lr=lr, dataset_name=dataset_name, + parser=parser, save_model_name='UNet') + + +if __name__ == "__main__": + parser = get_parser().parse_args() + main(parser) diff --git a/utils/data_loader_four_tuples_mix.py b/utils/data_loader_four_tuples_mix.py new file mode 100644 index 0000000..d2c8b39 --- /dev/null +++ b/utils/data_loader_four_tuples_mix.py @@ -0,0 +1,113 @@ +import os +import torch.utils.data as data +from . 
import four_tuple_data_processing +from PIL import Image +import random +from torchvision import transforms + + +def generate_training_data_list(training_data_dir, training_data_list_file): + # shapenet_specular training dataset + # training_data_dir = 'dataset/shapenet_specular_1500/training_data' + # training_data_list_file = 'dataset/shapenet_specular_1500/train_tc.lst' + + random.seed(1) + + path_i = [] # input + path_r = [] # specular residue + path_d = [] # diffuse + path_d_tc = [] # gamma correction version of diffuse + with open(training_data_list_file, 'r') as f: + image_list = [x.strip() for x in f.readlines()] + random.shuffle(image_list) + for name in image_list: + path_i.append(os.path.join(training_data_dir, name.split()[0])) # input + path_r.append(os.path.join(training_data_dir, name.split()[1])) # specular residue + path_d.append(os.path.join(training_data_dir, name.split()[2])) # diffuse + path_d_tc.append(os.path.join(training_data_dir, name.split()[3])) # gamma correction version of diffuse + + num = len(image_list) + path_i = path_i[:int(num)] + path_r = path_r[:int(num)] + path_d = path_d[:int(num)] + path_d_tc = path_d_tc[:int(num)] + + path_list = {'path_i': path_i, 'path_r': path_r, 'path_d': path_d, 'path_d_tc': path_d_tc} + return path_list + + +def generate_testing_data_list(data_dir, data_list_file): + # shapenet_specular testing data + # data_dir = 'dataset/shapenet_specular_1500/testing_data' + # data_list_file = 'dataset/shapenet_specular_1500/test_tc.lst' + + path_i = [] # input + path_r = [] # specular residue + path_d = [] # diffuse + path_d_tc = [] # gamma correction version of diffuse + with open(data_list_file, 'r') as f: + image_list = [x.strip() for x in f.readlines()] + image_list.sort() + for name in image_list: + path_i.append(os.path.join(data_dir, name.split()[0])) # input + path_r.append(os.path.join(data_dir, name.split()[1])) # specular residue + path_d.append(os.path.join(data_dir, name.split()[2])) # diffuse + path_d_tc.append(os.path.join(data_dir, name.split()[3])) # gamma correction version of diffuse + + num = len(image_list) + path_i = path_i[:int(num)] + path_r = path_r[:int(num)] + path_d = path_d[:int(num)] + path_d_tc = path_d_tc[:int(num)] + + path_list = {'path_i': path_i,'path_r': path_r, 'path_d': path_d, 'path_d_tc': path_d_tc} + + return path_list + + +class ImageTransformSingle(): + def __init__(self, size=256, mean=(0.5, ), std=(0.5, )): + self.data_transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize(mean, std)]) + + def __call__(self, img): + return self.data_transform(img) + + +class ImageTransform(): + def __init__(self, size=256, crop_size=256, mean=(0.5, ), std=(0.5, )): + self.data_transform = {'train': four_tuple_data_processing.Compose([four_tuple_data_processing.Scale(size=size), + four_tuple_data_processing.RandomCrop(size=crop_size), + four_tuple_data_processing.RandomHorizontalFlip(p=0.5), + four_tuple_data_processing.ToTensor(), + four_tuple_data_processing.Normalize(mean, std)]), + + 'test': four_tuple_data_processing.Compose([four_tuple_data_processing.Scale(size=size), + four_tuple_data_processing.RandomCrop(size=crop_size), + four_tuple_data_processing.RandomHorizontalFlip(p=0.5), + four_tuple_data_processing.ToTensor(), + four_tuple_data_processing.Normalize(mean, std)])} + + def __call__(self, phase, img): + return self.data_transform[phase](img) + + +class ImageDataset(data.Dataset): + def __init__(self, img_list, img_transform, phase): + self.img_list = img_list + 
self.img_transform = img_transform + self.phase = phase + + def __len__(self): + return len(self.img_list['path_i']) + + def __getitem__(self, index): + input = Image.open(self.img_list['path_i'][index]).convert('RGB') + gt_specular_residue = Image.open(self.img_list['path_r'][index]).convert('RGB') + gt_diffuse = Image.open(self.img_list['path_d'][index]).convert('RGB') + gt_diffuse_tc= Image.open(self.img_list['path_d_tc'][index]).convert('RGB') + + # data pre-processing + input, gt_specular_residue, gt_diffuse, gt_diffuse_tc = self.img_transform(self.phase, [input, gt_specular_residue, gt_diffuse, gt_diffuse_tc]) + + return input, gt_specular_residue, gt_diffuse, gt_diffuse_tc diff --git a/utils/data_loader_seven_tuples.py b/utils/data_loader_seven_tuples.py new file mode 100644 index 0000000..3756a2b --- /dev/null +++ b/utils/data_loader_seven_tuples.py @@ -0,0 +1,134 @@ +import os +import torch.utils.data as data +from . import seven_tuple_data_processing +from PIL import Image +import random +from torchvision import transforms + + +def generate_training_data_list(training_data_dir, training_data_list_file): + # shapenet_specular training dataset + # training_data_dir = 'dataset/shapenet_specular_1500/training_data' + # training_data_list_file = 'dataset/shapenet_specular_1500/train_tc.lst' + + random.seed(1) + + path_i = [] # input + path_a = [] # albedo + path_s = [] # shading + path_r = [] # specular residue + path_d = [] # diffuse + path_d_tc = [] # gamma correction version of diffuse + path_m = [] # mask + with open(training_data_list_file, 'r') as f: + image_list = [x.strip() for x in f.readlines()] + random.shuffle(image_list) + for name in image_list: + path_i.append(os.path.join(training_data_dir, name.split()[0])) # input + path_a.append(os.path.join(training_data_dir, name.split()[1])) # albedo + path_s.append(os.path.join(training_data_dir, name.split()[2])) # shading + path_r.append(os.path.join(training_data_dir, name.split()[3])) # specular residue + path_d.append(os.path.join(training_data_dir, name.split()[4])) # diffuse + path_d_tc.append(os.path.join(training_data_dir, name.split()[5])) # gamma correction version of diffuse + path_m.append(os.path.join(training_data_dir, name.split()[6])) # mask + + num = len(image_list) + path_i = path_i[:int(num)] + path_a = path_a[:int(num)] + path_s = path_s[:int(num)] + path_r = path_r[:int(num)] + path_d = path_d[:int(num)] + path_d_tc = path_d_tc[:int(num)] + path_m = path_m[:int(num)] + + path_list = {'path_i': path_i, 'path_a': path_a, 'path_s': path_s, 'path_r': path_r, 'path_d': path_d, 'path_d_tc': path_d_tc, 'path_m': path_m} + return path_list + + +def generate_testing_data_list(data_dir, data_list_file): + # shapenet_specular testing data + # data_dir = 'dataset/shapenet_specular_1500/testing_data' + # data_list_file = 'dataset/shapenet_specular_1500/test_tc.lst' + + path_i = [] # input + path_a = [] # albedo + path_s = [] # shading + path_r = [] # specular residue + path_d = [] # diffuse + path_d_tc = [] # gamma correction version of diffuse + path_m = [] # mask + with open(data_list_file, 'r') as f: + image_list = [x.strip() for x in f.readlines()] + image_list.sort() + for name in image_list: + path_i.append(os.path.join(data_dir, name.split()[0])) # input + path_a.append(os.path.join(data_dir, name.split()[1])) # albedo + path_s.append(os.path.join(data_dir, name.split()[2])) # shading + path_r.append(os.path.join(data_dir, name.split()[3])) # specular residue + path_d.append(os.path.join(data_dir, 
name.split()[4])) # diffuse + path_d_tc.append(os.path.join(data_dir, name.split()[5])) # gamma correction version of diffuse + path_m.append(os.path.join(data_dir, name.split()[6])) # mask + + num = len(image_list) + path_i = path_i[:int(num)] + path_a = path_a[:int(num)] + path_s = path_s[:int(num)] + path_r = path_r[:int(num)] + path_d = path_d[:int(num)] + path_d_tc = path_d_tc[:int(num)] + path_m = path_m[:int(num)] + + path_list = {'path_i': path_i, 'path_a': path_a, 'path_s': path_s, 'path_r': path_r, 'path_d': path_d, 'path_d_tc': path_d_tc, 'path_m': path_m} + + return path_list + + +class ImageTransformSingle(): + def __init__(self, size=256, mean=(0.5, ), std=(0.5, )): + self.data_transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize(mean, std)]) + + def __call__(self, img): + return self.data_transform(img) + + +class ImageTransform(): + def __init__(self, size=256, crop_size=256, mean=(0.5, ), std=(0.5, )): + self.data_transform = {'train': seven_tuple_data_processing.Compose([seven_tuple_data_processing.Scale(size=size), + seven_tuple_data_processing.RandomCrop(size=crop_size), + seven_tuple_data_processing.RandomHorizontalFlip(p=0.5), + seven_tuple_data_processing.ToTensor(), + seven_tuple_data_processing.Normalize(mean, std)]), + + 'test': seven_tuple_data_processing.Compose([seven_tuple_data_processing.Scale(size=size), + seven_tuple_data_processing.RandomCrop(size=crop_size), + seven_tuple_data_processing.RandomHorizontalFlip(p=0.5), + seven_tuple_data_processing.ToTensor(), + seven_tuple_data_processing.Normalize(mean, std)])} + + def __call__(self, phase, img): + return self.data_transform[phase](img) + + +class ImageDataset(data.Dataset): + def __init__(self, img_list, img_transform, phase): + self.img_list = img_list + self.img_transform = img_transform + self.phase = phase + + def __len__(self): + return len(self.img_list['path_i']) + + def __getitem__(self, index): + input = Image.open(self.img_list['path_i'][index]).convert('RGB') + gt_albedo = Image.open(self.img_list['path_a'][index]).convert('RGB') + gt_shading = Image.open(self.img_list['path_s'][index]).convert('RGB') + gt_specular_residue = Image.open(self.img_list['path_r'][index]).convert('RGB') + gt_diffuse = Image.open(self.img_list['path_d'][index]).convert('RGB') + gt_diffuse_tc= Image.open(self.img_list['path_d_tc'][index]).convert('RGB') + object_mask= Image.open(self.img_list['path_m'][index]).convert('RGB') + + # data pre-processing + input, gt_albedo, gt_shading, gt_specular_residue, gt_diffuse, gt_diffuse_tc, object_mask = self.img_transform(self.phase, [input, gt_albedo, gt_shading, gt_specular_residue, gt_diffuse, gt_diffuse_tc, object_mask]) + + return input, gt_albedo, gt_shading, gt_specular_residue, gt_diffuse, gt_diffuse_tc, object_mask diff --git a/utils/fg_tools.py b/utils/fg_tools.py new file mode 100644 index 0000000..dc7bfb3 --- /dev/null +++ b/utils/fg_tools.py @@ -0,0 +1,109 @@ +import os +import torch +from torchvision.utils import make_grid +from torchvision.utils import save_image +import matplotlib.pyplot as plt + + +def fix_model_state_dict(state_dict): + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k + if name.startswith('module.'): + name = name[7:] + new_state_dict[name] = v + return new_state_dict + + +def plot_log(data, dataset_name, save_model_name='model'): + plt.cla() + plt.plot(data['UNets'], label='L_total ') + plt.legend() + plt.xlabel('epoch') + plt.ylabel('loss') + plt.title('Loss') + 
plt.savefig('./logs_'+dataset_name+'/'+save_model_name+'.png') + + +def check_dir(dataset_name): + if not os.path.exists('./logs_' + dataset_name): + os.mkdir('./logs_' + dataset_name) + if not os.path.exists('./checkpoints_' + dataset_name): + os.mkdir('./checkpoints_' + dataset_name) + + +def unnormalize(x): + x = x.transpose(1, 3) + x = x * torch.Tensor((0.5, )) + torch.Tensor((0.5, )) + x = x.transpose(1, 3) + return x + + +def evaluate(UNet1, UNet2, UNet3, UNet4, dataset, device, filename): + img, gt_albedo, gt_shading, gt_specular_residue, gt_diffuse, gt_diffuse_tc, object_mask = zip(*[dataset[i] for i in range(8)]) + img = torch.stack(img) + gt_diffuse = torch.stack(gt_diffuse) + gt_diffuse_tc = torch.stack(gt_diffuse_tc) + object_mask = torch.stack(object_mask) + + + img = img.to(device) + object_mask = object_mask.to(device) + + with torch.no_grad(): + ## estimations in our three-stage network + # estimations in the first stage + estimated_albedo = UNet1(img) + estimated_shading = UNet2(img) + estimated_specular_residue = (img - estimated_albedo * estimated_shading) + + # estimation in the second stage + G3_input = torch.cat([estimated_albedo * estimated_shading * object_mask, img], dim=1) + estimated_diffuse_refined = UNet3(G3_input) + + # estimation in the third stage + G4_input = torch.cat([estimated_diffuse_refined * object_mask, estimated_specular_residue * object_mask, img], dim=1) + estimated_diffuse_tc = UNet4(G4_input) + ## end + + # to cpu + estimated_albedo = estimated_albedo.to(torch.device('cpu')) + estimated_shading = estimated_shading.to(torch.device('cpu')) + estimated_diffuse_refined = estimated_diffuse_refined.to(torch.device('cpu')) + estimated_diffuse_tc = estimated_diffuse_tc.to(torch.device('cpu')) + img = img.to(torch.device('cpu')) + object_mask = object_mask.to(torch.device('cpu')) + + grid_removal = make_grid(torch.cat((img,gt_diffuse,estimated_diffuse_refined * object_mask, estimated_diffuse_tc * object_mask), dim=0)) + save_image(grid_removal, filename+'_overview.jpg') + + +def evaluate_mix(UNet1, UNet2, UNet3, UNet4, dataset, device, filename): + input_img, gt_specular_residue, gt_diffuse, gt_diffuse_tc = zip(*[dataset[i] for i in range(16)]) # 8 in default + input_img = torch.stack(input_img) + gt_diffuse = torch.stack(gt_diffuse) + gt_diffuse_tc = torch.stack(gt_diffuse_tc) + + with torch.no_grad(): + # first stage (physics-based specular highlight removal) + estimated_diffuse = UNet1(input_img.to(device)) + estimated_specular_residue = UNet2(input_img.to(device)) + + # second stage (specular-free refinement) + G3_data = torch.cat([estimated_diffuse, input_img.to(device)], dim=1) + estimated_diffuse_refined = UNet3(G3_data.to(device)) + + # third stage (tone correction) + input_img = input_img.to(device) + G4_input = torch.cat([estimated_diffuse_refined, estimated_specular_residue, input_img], dim=1) + estimated_diffuse_tc = UNet4(G4_input.to(device)) + + # to cpu + estimated_diffuse = estimated_diffuse.to(torch.device('cpu')) + estimated_specular_residue = estimated_specular_residue.to(torch.device('cpu')) + estimated_diffuse_refined = estimated_diffuse_refined.to(torch.device('cpu')) + estimated_diffuse_tc = estimated_diffuse_tc.to(torch.device('cpu')) + input_img = input_img.to(torch.device('cpu')) + + grid_removal = make_grid(torch.cat((unnormalize(input_img), unnormalize(gt_diffuse), unnormalize(estimated_diffuse_refined), unnormalize(estimated_diffuse_tc)), dim=0)) + save_image(grid_removal, filename+'_overview.jpg') diff --git 
a/utils/four_tuple_data_processing.py b/utils/four_tuple_data_processing.py new file mode 100644 index 0000000..0574813 --- /dev/null +++ b/utils/four_tuple_data_processing.py @@ -0,0 +1,139 @@ +import math +import numbers +import random +import warnings +from collections.abc import Sequence +from typing import Tuple, List, Optional + +import torch +from PIL import Image +from torch import Tensor +import torchvision.transforms.functional as F + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img): + for t in self.transforms: + img = t(img) + return img + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class ToTensor(object): + def __call__(self, pic): + return F.to_tensor(pic[0]), F.to_tensor(pic[1]), F.to_tensor(pic[2]), F.to_tensor(pic[3]) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class Scale(object): + def __init__(self, size, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, imgs): + output = [] + for img in imgs: + w, h = img.size + if (w <= h and w == self.size) or (h <= w and h == self.size): + output.append(img) + continue + if w < h: + ow = self.size + oh = int(self.size * h / w) + output.append(img.resize((ow, oh), self.interpolation)) + continue + else: + oh = self.size + ow = int(self.size * w / h) + output.append(img.resize((ow, oh), self.interpolation)) + return output[0], output[1], output[2], output[3] + + +class Normalize(object): + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, tensor): + return F.normalize(tensor[0], self.mean, self.std, self.inplace), F.normalize(tensor[1], self.mean, self.std, self.inplace), F.normalize(tensor[2], self.mean, self.std, self.inplace), F.normalize(tensor[3], self.mean, self.std, self.inplace) + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + + +class RandomCrop(torch.nn.Module): + @staticmethod + def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: + w, h = img.size + th, tw = output_size + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1, )).item() + j = torch.randint(0, w - tw + 1, size=(1, )).item() + return i, j, th, tw + + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): + super().__init__() + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + elif isinstance(size, Sequence) and len(size) == 1: + self.size = (size[0], size[0]) + else: + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + + # cast to tuple for torchscript + self.size = tuple(size) + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + def forward(self, img): + if self.padding is not None: + img[0] = F.pad(img[0], self.padding, self.fill, self.padding_mode) + + width, height = img[0].size + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img[0] = F.pad(img[0], padding, self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, 
self.size[0] - height] + img[0] = F.pad(img[0], padding, self.fill, self.padding_mode) + + i, j, h, w = self.get_params(img[0], self.size) + + return F.crop(img[0], i, j, h, w), F.crop(img[1], i, j, h, w), F.crop(img[2], i, j, h, w), F.crop(img[3], i, j, h, w) + + def __repr__(self): + return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) + + +class RandomHorizontalFlip(torch.nn.Module): + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, img): + if torch.rand(1) < self.p: + return F.hflip(img[0]), F.hflip(img[1]), F.hflip(img[2]), F.hflip(img[3]) + return img[0], img[1], img[2], img[3] + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) diff --git a/utils/seven_tuple_data_processing.py b/utils/seven_tuple_data_processing.py new file mode 100644 index 0000000..40b950b --- /dev/null +++ b/utils/seven_tuple_data_processing.py @@ -0,0 +1,139 @@ +import math +import numbers +import random +import warnings +from collections.abc import Sequence +from typing import Tuple, List, Optional + +import torch +from PIL import Image +from torch import Tensor +import torchvision.transforms.functional as F + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img): + for t in self.transforms: + img = t(img) + return img + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class ToTensor(object): + def __call__(self, pic): + return F.to_tensor(pic[0]), F.to_tensor(pic[1]), F.to_tensor(pic[2]), F.to_tensor(pic[3]), F.to_tensor(pic[4]), F.to_tensor(pic[5]), F.to_tensor(pic[6]) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class Scale(object): + def __init__(self, size, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, imgs): + output = [] + for img in imgs: + w, h = img.size + if (w <= h and w == self.size) or (h <= w and h == self.size): + output.append(img) + continue + if w < h: + ow = self.size + oh = int(self.size * h / w) + output.append(img.resize((ow, oh), self.interpolation)) + continue + else: + oh = self.size + ow = int(self.size * w / h) + output.append(img.resize((ow, oh), self.interpolation)) + return output[0], output[1], output[2], output[3], output[4], output[5], output[6] + + +class Normalize(object): + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, tensor): + return F.normalize(tensor[0], self.mean, self.std, self.inplace), F.normalize(tensor[1], self.mean, self.std, self.inplace), F.normalize(tensor[2], self.mean, self.std, self.inplace), F.normalize(tensor[3], self.mean, self.std, self.inplace), F.normalize(tensor[4], self.mean, self.std, self.inplace), F.normalize(tensor[5], self.mean, self.std, self.inplace), F.normalize(tensor[6], self.mean, self.std, self.inplace) + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + + +class RandomCrop(torch.nn.Module): + @staticmethod + def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int, int]: + w, h = img.size + th, tw = output_size + if w == tw and h == th: + return 0, 0, h, w + + i = torch.randint(0, h - th + 1, size=(1, )).item() + j = torch.randint(0, w - tw + 1, size=(1, 
)).item() + return i, j, th, tw + + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): + super().__init__() + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + elif isinstance(size, Sequence) and len(size) == 1: + self.size = (size[0], size[0]) + else: + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + + # cast to tuple for torchscript + self.size = tuple(size) + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + def forward(self, img): + if self.padding is not None: + img[0] = F.pad(img[0], self.padding, self.fill, self.padding_mode) + + width, height = img[0].size + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img[0] = F.pad(img[0], padding, self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + img[0] = F.pad(img[0], padding, self.fill, self.padding_mode) + + i, j, h, w = self.get_params(img[0], self.size) + + return F.crop(img[0], i, j, h, w), F.crop(img[1], i, j, h, w), F.crop(img[2], i, j, h, w), F.crop(img[3], i, j, h, w), F.crop(img[4], i, j, h, w), F.crop(img[5], i, j, h, w), F.crop(img[6], i, j, h, w) + + def __repr__(self): + return self.__class__.__name__ + "(size={0}, padding={1})".format(self.size, self.padding) + + +class RandomHorizontalFlip(torch.nn.Module): + def __init__(self, p=0.5): + super().__init__() + self.p = p + + def forward(self, img): + if torch.rand(1) < self.p: + return F.hflip(img[0]), F.hflip(img[1]), F.hflip(img[2]), F.hflip(img[3]), F.hflip(img[4]), F.hflip(img[5]), F.hflip(img[6]) + return img[0], img[1], img[2], img[3], img[4], img[5], img[6] + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p)
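
For readers who want to exercise these utilities outside the provided train/test scripts, below is a minimal, illustrative sketch of how the seven-tuple pipeline could be driven. It is not code from this diff: the module path `utils.image_data_loader`, the sample file paths, and the hand-built `path_list` are placeholder assumptions; in the repository, `path_list` is built by the list-file parser defined earlier in this change from an index file such as `dataset/SSHR/train_6_tuples.lst`.

```
from torch.utils.data import DataLoader

# placeholder import path -- point it at the file in this diff that defines these classes
from utils.image_data_loader import ImageTransform, ImageDataset

# a hand-built index with a single placeholder seven-tuple entry, for a quick smoke test;
# the real path_list is produced from the *_6_tuples.lst index files
path_list = {
    'path_i':    ['dataset/SSHR/train/000000/0000_i.jpg'],
    'path_a':    ['dataset/SSHR/train/000000/0000_a.jpg'],
    'path_s':    ['dataset/SSHR/train/000000/0000_s.jpg'],
    'path_r':    ['dataset/SSHR/train/000000/0000_r.jpg'],
    'path_d':    ['dataset/SSHR/train/000000/0000_d.jpg'],
    'path_d_tc': ['dataset/SSHR/train/000000/0000_d_tc.jpg'],
    'path_m':    ['dataset/SSHR/train/000000/0000_m.jpg'],
}

train_set = ImageDataset(img_list=path_list,
                         img_transform=ImageTransform(size=256, crop_size=256),
                         phase='train')
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)

# each batch yields the seven normalized tensors in the fixed order returned by __getitem__
for img, albedo, shading, residue, diffuse, diffuse_tc, mask in train_loader:
    print(img.shape)  # torch.Size([batch, 3, 256, 256]) after Scale + RandomCrop
    break
```

Because every transform in `seven_tuple_data_processing` computes its random crop window and flip decision once and applies them to all seven images, the ground-truth maps stay pixel-aligned with the input throughout augmentation.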