diff --git a/.gitignore b/.gitignore
index 0bee8b0d..5f4613a6 100755
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ results/
 *.zip
 *.pkl
 *.pyc
+.ipynb_checkpoints/
diff --git a/data/pix2pix_dataset.py b/data/pix2pix_dataset.py
index d32d4cfa..b366be42 100644
--- a/data/pix2pix_dataset.py
+++ b/data/pix2pix_dataset.py
@@ -4,6 +4,7 @@
 """

 from data.base_dataset import BaseDataset, get_params, get_transform
+from data.image_folder import make_dataset
 from PIL import Image
 import util.util as util
 import os
@@ -43,10 +44,22 @@ def initialize(self, opt):
         self.dataset_size = size

     def get_paths(self, opt):
-        label_paths = []
-        image_paths = []
-        instance_paths = []
-        assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
+        phase = 'train' if opt.isTrain else 'test'
+        label_dir = os.path.join(opt.dataroot, f'{phase}_A')
+        label_paths = make_dataset(label_dir, recursive=False, read_cache=True)
+
+        image_dir = os.path.join(opt.dataroot, f'{phase}_B')
+        image_paths = make_dataset(image_dir, recursive=False, read_cache=True)
+
+        # label_paths = image_paths = list(set(label_paths) & set(image_paths))
+
+        if opt.label_nc > 0:
+            instance_dir = os.path.join(opt.dataroot, f'{phase}_inst')
+            instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True)
+        else:
+            instance_paths = []
+
+        assert len(label_paths) == len(image_paths), "The #images in %s and %s do not match. Is there something wrong?" % (label_dir, image_dir)
         return label_paths, image_paths, instance_paths

     def paths_match(self, path1, path2):
@@ -58,10 +71,13 @@ def __getitem__(self, index):
         # Label Image
         label_path = self.label_paths[index]
         label = Image.open(label_path)
+        label = label.convert('RGB')
         params = get_params(self.opt, label.size)
-        transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
-        label_tensor = transform_label(label) * 255.0
-        label_tensor[label_tensor == 255] = self.opt.label_nc  # 'unknown' is opt.label_nc
+        transform_label = get_transform(self.opt, params)
+        # transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
+        # label_tensor = transform_label(label) * 255.0
+        # label_tensor[label_tensor == 255] = self.opt.label_nc  # 'unknown' is opt.label_nc
+        label_tensor = transform_label(label)

         # input image (real images)
         image_path = self.image_paths[index]
@@ -102,3 +118,4 @@ def postprocess(self, input_dict):

     def __len__(self):
         return self.dataset_size
+
diff --git a/models/networks/discriminator.py b/models/networks/discriminator.py
index 5207bbcc..c02f8fb7 100644
--- a/models/networks/discriminator.py
+++ b/models/networks/discriminator.py
@@ -100,7 +100,10 @@ def __init__(self, opt):
             self.add_module('model' + str(n), nn.Sequential(*sequence[n]))

     def compute_D_input_nc(self, opt):
-        input_nc = opt.label_nc + opt.output_nc
+        if opt.label_nc == 0:
+            input_nc = opt.input_nc * 2
+        else:
+            input_nc = opt.label_nc + opt.output_nc
         if opt.contain_dontcare_label:
             input_nc += 1
         if not opt.no_instance:
diff --git a/models/networks/loss.py b/models/networks/loss.py
index b2485d77..d9a3916a 100644
--- a/models/networks/loss.py
+++ b/models/networks/loss.py
@@ -102,7 +102,10 @@ def __call__(self, input, target_is_real, for_discriminator=True):

 class VGGLoss(nn.Module):
     def __init__(self, gpu_ids):
         super(VGGLoss, self).__init__()
-        self.vgg = VGG19().cuda()
+        if len(gpu_ids):
+            self.vgg = VGG19().cuda()
+        else:
+            self.vgg = VGG19()
         self.criterion = nn.L1Loss()
         self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
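Note: the rewritten get_paths above removes the need for per-dataset subclasses; it reads pix2pix-style paired folders straight from --dataroot. A sketch of the layout it assumes (the root name is illustrative; the suffixes come from the f'{phase}_A' / f'{phase}_B' / f'{phase}_inst' strings above):

    datasets/mydata/
        train_A/      # input images; paired with train_B positionally after sorting
        train_B/      # target (real) images
        train_inst/   # instance maps, read only when --label_nc > 0
        test_A/
        test_B/

The assert only guards against mismatched counts, not mismatched names, so the A/B folders must contain files that correspond one-to-one.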
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 784cb15a..508ae2af 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -106,13 +106,18 @@ def initialize_networks(self, opt):

     # |data|: dictionary of the input data
     def preprocess_input(self, data):
+        if self.opt.label_nc == 0:
+            if self.use_gpu():
+                return data['label'].cuda(), data['image'].cuda()
+            else:
+                return data['label'], data['image']
         # move to GPU and change data types
         data['label'] = data['label'].long()
         if self.use_gpu():
             data['label'] = data['label'].cuda()
             data['instance'] = data['instance'].cuda()
             data['image'] = data['image'].cuda()
-
+
         # create one-hot label map
         label_map = data['label']
         bs, _, h, w = label_map.size()
diff --git a/options/base_options.py b/options/base_options.py
index 939a26dc..ab60b234 100755
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -31,12 +31,15 @@ def initialize(self, parser):

         # input/output sizes
         parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
-        parser.add_argument('--preprocess_mode', type=str, default='scale_width_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none"))
-        parser.add_argument('--load_size', type=int, default=1024, help='Scale images to this size. The final image will be cropped to --crop_size.')
-        parser.add_argument('--crop_size', type=int, default=512, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
+        parser.add_argument('--preprocess_mode', type=str, default='resize_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none"))
+        parser.add_argument('--load_size', type=int, default=320, help='Scale images to this size. The final image will be cropped to --crop_size.')
+        parser.add_argument('--crop_size', type=int, default=320, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
         parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio')
         parser.add_argument('--label_nc', type=int, default=182, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.')
         parser.add_argument('--contain_dontcare_label', action='store_true', help='if the label map contains dontcare label (dontcare=255)')
+        parser.add_argument(
+            "--input_nc", type=int, default=3, help="# of input image channels"
+        )
         parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')

         # for setting inputs
@@ -156,9 +159,12 @@ def parse(self, save=False):

         # Set semantic_nc based on the option.
         # This will be convenient in many places
-        opt.semantic_nc = opt.label_nc + \
-            (1 if opt.contain_dontcare_label else 0) + \
-            (0 if opt.no_instance else 1)
+        if opt.label_nc == 0:
+            opt.semantic_nc = opt.input_nc
+        else:
+            opt.semantic_nc = opt.label_nc + \
+                (1 if opt.contain_dontcare_label else 0) + \
+                (0 if opt.no_instance else 1)

         # set gpu ids
         str_ids = opt.gpu_ids.split(',')
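Note: the label_nc == 0 branches here and in discriminator.py repurpose the label pathway as a plain RGB condition image, i.e. a paired image-to-image setup. A minimal sketch of the resulting channel arithmetic (the helper name is hypothetical; the logic mirrors parse() and compute_D_input_nc above):

    def semantic_nc(label_nc, input_nc, contain_dontcare_label, no_instance):
        if label_nc == 0:
            # RGB mode: the "semantic" input is the condition image itself
            return input_nc
        # label mode: one-hot classes plus optional dontcare and instance-edge channels
        return label_nc + (1 if contain_dontcare_label else 0) + (0 if no_instance else 1)

    # e.g. --label_nc 0 with the new --input_nc default of 3:
    # semantic_nc(0, 3, False, True) == 3, and with --no_instance the
    # discriminator input is input_nc * 2 == 6 channels (condition + real/fake)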
diff --git a/test.py b/test.py
index cb8e7ff1..f78c8cb5 100755
--- a/test.py
+++ b/test.py
@@ -11,6 +11,8 @@
 from models.pix2pix_model import Pix2PixModel
 from util.visualizer import Visualizer
 from util import html
+from util.util import tensor2im
+from PIL import Image

 opt = TestOptions().parse()

@@ -34,12 +36,16 @@
         break

     generated = model(data_i, mode='inference')
-
+    synthesized_image = tensor2im(generated)
     img_path = data_i['path']
-    for b in range(generated.shape[0]):
+    for b in range(synthesized_image.shape[0]):
         print('process image... %s' % img_path[b])
         visuals = OrderedDict([('input_label', data_i['label'][b]),
                                ('synthesized_image', generated[b])])
         visualizer.save_images(webpage, visuals, img_path[b:b + 1])
+        save_image_path = os.path.join(
+            opt.results_dir, os.path.basename(img_path[b]))
+        Image.fromarray(synthesized_image[b]).save(save_image_path)

 webpage.save()
+
diff --git a/train.py b/train.py
index 09e18124..7f0f1b2a 100755
--- a/train.py
+++ b/train.py
@@ -4,6 +4,7 @@
 """

 import sys
+import numpy as np
 from collections import OrderedDict
 from options.train_options import TrainOptions
 import data
@@ -24,14 +25,23 @@
 trainer = Pix2PixTrainer(opt)

 # create tool for counting iterations
-iter_counter = IterationCounter(opt, len(dataloader))
+iter_counter = IterationCounter(opt, len(dataloader) * opt.batchSize)

 # create tool for visualization
 visualizer = Visualizer(opt)

+
+clear_iter = False
 for epoch in iter_counter.training_epochs():
-    iter_counter.record_epoch_start(epoch)
-    for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
+    iter_counter.record_epoch_start(epoch, clear_iter)
+    clear_iter = True
+
+    start_batch_idx = iter_counter.epoch_iter // opt.batchSize
+
+    for i, data_i in enumerate(dataloader):
+        if i < start_batch_idx:
+            continue
+
         iter_counter.record_one_iteration()

         # Training
@@ -60,7 +70,7 @@
               (epoch, iter_counter.total_steps_so_far))
         trainer.save('latest')
         iter_counter.record_current_iter()
-
+
     trainer.update_learning_rate(epoch)
     iter_counter.record_epoch_end()

@@ -70,5 +80,6 @@
           (epoch, iter_counter.total_steps_so_far))
     trainer.save('latest')
     trainer.save(epoch)
-
+
 print('Training was successfully finished.')
+
diff --git a/util/iter_counter.py b/util/iter_counter.py
index 1a0182fa..19138419 100644
--- a/util/iter_counter.py
+++ b/util/iter_counter.py
@@ -33,9 +33,10 @@ def __init__(self, opt, dataset_size):

     def training_epochs(self):
         return range(self.first_epoch, self.total_epochs + 1)

-    def record_epoch_start(self, epoch):
+    def record_epoch_start(self, epoch, clear_iter=True):
         self.epoch_start_time = time.time()
-        self.epoch_iter = 0
+        if clear_iter:
+            self.epoch_iter = 0
         self.last_iter_time = time.time()
         self.current_epoch = epoch
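Note: the train.py and iter_counter.py changes together enable mid-epoch resume: epoch_iter is no longer reset when a run is continued, and it is converted back into a batch offset that the loop skips past. Illustrative arithmetic (the numbers are made up):

    batchSize = 4
    epoch_iter = 120                           # images already seen this epoch, restored by IterationCounter
    start_batch_idx = epoch_iter // batchSize  # 120 // 4 == 30, so batches 0..29 are skipped

The skipped batches only correspond to the previously seen samples if the dataloader iterates in the same order (shuffling disabled or seeded identically); otherwise the resume point is approximate.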
diff --git a/util/visualizer.py b/util/visualizer.py
index 463c7bd4..ef1dce56 100755
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -129,7 +129,10 @@ def convert_visuals_to_numpy(self, visuals):
         for key, t in visuals.items():
             tile = self.opt.batchSize > 8
             if 'input_label' == key:
-                t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
+                if self.opt.label_nc == 0:
+                    t = util.tensor2im(t, tile=tile)
+                else:
+                    t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
             else:
                 t = util.tensor2im(t, tile=tile)
             visuals[key] = t
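Taken together, the patch adds an RGB-conditioned mode (selected with --label_nc 0), smaller default load/crop sizes, a CPU fallback for the VGG loss, mid-epoch resume, and direct saving of synthesized images under --results_dir. A hypothetical end-to-end invocation (the experiment name and dataroot are placeholders; the flags themselves are defined in options/):

    python train.py --name rgb_pairs --dataroot ./datasets/mydata --label_nc 0 --no_instance
    python test.py --name rgb_pairs --dataroot ./datasets/mydata --label_nc 0 --no_instance

--no_instance is needed in this mode because get_paths leaves instance_paths empty when label_nc == 0.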