From 724c624896938cfb6cff2c7248771584650dbc4c Mon Sep 17 00:00:00 2001
From: Mark Zakharov
Date: Thu, 1 Aug 2019 15:13:20 +0300
Subject: [PATCH] Initial commit

---
 .gitignore              |   4 ++
 pixelcnn/__init__.py    |   3 +
 pixelcnn/conv_layers.py |  64 +++++++++++++++++++
 pixelcnn/model.py       | 134 ++++++++++++++++++++++++++++++++++++++++
 sample.py               |  58 +++++++++++++++++
 train.py                | 106 +++++++++++++++++++++++++++++++
 utils.py                |  31 ++++++++++
 7 files changed, 400 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 pixelcnn/__init__.py
 create mode 100644 pixelcnn/conv_layers.py
 create mode 100644 pixelcnn/model.py
 create mode 100644 sample.py
 create mode 100644 train.py
 create mode 100644 utils.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..eaf0c76
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.ipynb
+.ipynb_checkpoints
+.idea
+__pycache__/
\ No newline at end of file
diff --git a/pixelcnn/__init__.py b/pixelcnn/__init__.py
new file mode 100644
index 0000000..240ecec
--- /dev/null
+++ b/pixelcnn/__init__.py
@@ -0,0 +1,3 @@
+from pixelcnn.model import PixelCNN
+
+__all__ = ['PixelCNN']
\ No newline at end of file
diff --git a/pixelcnn/conv_layers.py b/pixelcnn/conv_layers.py
new file mode 100644
index 0000000..1111139
--- /dev/null
+++ b/pixelcnn/conv_layers.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+
+import numpy as np
+
+
+class CroppedConv2d(nn.Conv2d):
+    def __init__(self, *args, **kwargs):
+        super(CroppedConv2d, self).__init__(*args, **kwargs)
+
+    def forward(self, x):
+        x = super(CroppedConv2d, self).forward(x)
+
+        kernel_height, _ = self.kernel_size
+        res = x[:, :, 1:-kernel_height, :]  # crop the over-padded output back to the input height
+        shifted_up_res = x[:, :, :-kernel_height - 1, :]  # same map shifted up one row, fed to the horizontal stack
+
+        return res, shifted_up_res
+
+
+class MaskedConv2d(nn.Conv2d):
+    def __init__(self, *args, mask_type, data_channels, in_spread=True, out_spread=True, **kwargs):
+        super(MaskedConv2d, self).__init__(*args, **kwargs)
+
+        assert mask_type in ['A', 'B'], 'Invalid mask type.'
+
+        out_channels, in_channels, height, width = self.weight.size()
+        yc, xc = height // 2, width // 2
+
+        mask = np.zeros(self.weight.size(), dtype=np.float32)
+        mask[:, :, :yc, :] = 1       # rows above the center are fully visible
+        mask[:, :, yc, :xc + 1] = 1  # current row up to and including the center pixel
+
+        def cmask(out_c, in_c):
+            if out_spread:
+                a = (np.arange(out_channels) % data_channels == out_c)[:, None]
+            else:
+                split = np.ceil(out_channels / data_channels)
+                lbound = out_c * split
+                ubound = (out_c + 1) * split
+                a = ((lbound <= np.arange(out_channels)) * (np.arange(out_channels) < ubound))[:, None]
+            if in_spread:
+                b = (np.arange(in_channels) % data_channels == in_c)[None, :]
+            else:
+                split = np.ceil(in_channels / data_channels)
+                lbound = in_c * split
+                ubound = (in_c + 1) * split
+                b = ((lbound <= np.arange(in_channels)) * (np.arange(in_channels) < ubound))[None, :]
+            return a * b
+
+        for o in range(data_channels):
+            for i in range(o + 1, data_channels):
+                mask[cmask(o, i), yc, xc] = 0  # channel o may not see later data channels at the center
+
+        if mask_type == 'A':
+            for c in range(data_channels):
+                mask[cmask(c, c), yc, xc] = 0  # mask 'A' also hides the current channel itself
+
+        self.register_buffer('mask', torch.from_numpy(mask))  # buffer, so the mask follows the module to the right device
+
+    def forward(self, x):
+        self.weight.data *= self.mask  # zero out connections to 'future' pixels
+        x = super(MaskedConv2d, self).forward(x)
+        return x
diff --git a/pixelcnn/model.py b/pixelcnn/model.py
new file mode 100644
index 0000000..8aff080
--- /dev/null
+++ b/pixelcnn/model.py
@@ -0,0 +1,134 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .conv_layers import MaskedConv2d, CroppedConv2d
+
+
+class GatedBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, mask_type, data_channels):
+        super(GatedBlock, self).__init__()
+        self.split_index = out_channels
+
+        self.v_conv = CroppedConv2d(in_channels,
+                                    2 * out_channels,
+                                    (kernel_size // 2 + 1, kernel_size),
+                                    padding=(kernel_size // 2 + 1, kernel_size // 2))
+        self.v_fc = nn.Conv2d(in_channels,
+                              2 * out_channels,
+                              (1, 1))
+        self.v_to_h = MaskedConv2d(2 * out_channels,
+                                   2 * out_channels,
+                                   (1, 1),
+                                   mask_type=mask_type,
+                                   data_channels=data_channels)
+        self.h_conv = MaskedConv2d(in_channels,
+                                   2 * out_channels,
+                                   (1, kernel_size),
+                                   mask_type=mask_type,
+                                   data_channels=data_channels,
+                                   padding=(0, kernel_size // 2))
+        self.h_fc = MaskedConv2d(out_channels,
+                                 out_channels,
+                                 (1, 1),
+                                 mask_type=mask_type,
+                                 data_channels=data_channels)
+        self.h_skip = MaskedConv2d(out_channels,
+                                   out_channels,
+                                   (1, 1),
+                                   mask_type=mask_type,
+                                   data_channels=data_channels)
+
+    def forward(self, x):
+        v_in, h_in, skip = x[0], x[1], x[2]
+
+        v_out, v_shifted = self.v_conv(v_in)  # vertical stack
+        v_out += self.v_fc(v_in)
+        v_out = torch.tanh(v_out[:, :self.split_index]) * torch.sigmoid(v_out[:, self.split_index:])
+
+        h_out = self.h_conv(h_in)  # horizontal stack
+        v_shifted = self.v_to_h(v_shifted)
+        h_out += v_shifted
+        h_out = torch.tanh(h_out[:, :self.split_index]) * torch.sigmoid(h_out[:, self.split_index:])
+
+        skip = skip + self.h_skip(h_out)
+
+        h_out = self.h_fc(h_out) + h_in  # residual connection
+
+        return {0: v_out, 1: h_out, 2: skip}  # dict, so the triple can flow through nn.Sequential
+
+
+class PixelCNN(nn.Module):
+    def __init__(self, cfg):
+        super(PixelCNN, self).__init__()
+        self.causal_ksize = cfg.causal_ksize
+        self.hidden_ksize = cfg.hidden_ksize
+
+        self.data_channels = cfg.data_channels
+        self.hidden_fmaps = cfg.hidden_fmaps
+        self.out_hidden_fmaps = cfg.out_hidden_fmaps
+
+        self.hidden_layers = cfg.hidden_layers
+
+        self.color_levels = cfg.color_levels
+
+        self.causal_conv = GatedBlock(self.data_channels,
+                                      self.hidden_fmaps,
+                                      self.causal_ksize,
+                                      mask_type='A',
+                                      data_channels=self.data_channels)
+
+        self.hidden_conv = nn.Sequential(
+            *[GatedBlock(self.hidden_fmaps,
+                         self.hidden_fmaps,
+                         self.hidden_ksize,
+                         mask_type='B',
+                         data_channels=self.data_channels) for _ in range(self.hidden_layers)]
+        )
+
+        self.out_hidden_conv = MaskedConv2d(self.hidden_fmaps,
+                                            self.out_hidden_fmaps,
+                                            (1, 1),
+                                            mask_type='B',
+                                            data_channels=self.data_channels)
+
+        self.out_conv = MaskedConv2d(self.out_hidden_fmaps,
+                                     self.data_channels * self.color_levels,
+                                     (1, 1),
+                                     mask_type='B',
+                                     data_channels=self.data_channels,
+                                     out_spread=False)
+
+    def forward(self, x):
+        count, _, height, width = x.size()
+        out_shape = (count, self.hidden_fmaps, height, width)
+
+        v, h, _ = self.causal_conv({0: x, 1: x, 2: torch.zeros(out_shape, requires_grad=True, device=x.device)}).values()
+
+        _, _, out = self.hidden_conv({0: v, 1: h, 2: torch.zeros(out_shape, requires_grad=True, device=x.device)}).values()
+
+        assert out.requires_grad
+
+        out = F.relu(out)
+        out = F.relu(self.out_hidden_conv(out))
+        out = self.out_conv(out)
+
+        batch_size, _, height, width = out.size()
+        out = out.view(batch_size, self.color_levels, self.data_channels, height, width)
+
+        return out
+
+    def sample(self, shape, count):
+        channels, height, width = shape
+
+        samples = torch.zeros(count, *shape, device=next(self.parameters()).device)
+
+        with torch.no_grad():  # pixels are generated one position and one channel at a time
+            for i in range(height):
+                for j in range(width):
+                    for c in range(channels):
+                        unnormalized_probs = self.forward(samples / (self.color_levels - 1))
+                        pixel_probs = torch.softmax(unnormalized_probs[:, :, c, i, j], dim=1)
+                        sampled_levels = torch.multinomial(pixel_probs, 1).squeeze()
+                        samples[:, c, i, j] = sampled_levels
+
+        return samples
diff --git a/sample.py b/sample.py
new file mode 100644
index 0000000..4bad944
--- /dev/null
+++ b/sample.py
@@ -0,0 +1,58 @@
+import torch
+
+from pixelcnn import PixelCNN
+
+import argparse
+from utils import str2bool, save_samples, nearest_square
+
+
+def main():
+    parser = argparse.ArgumentParser(description='PixelCNN')
+
+    parser.add_argument('--causal-ksize', type=int, default=7,
+                        help='Kernel size of causal convolution')
+    parser.add_argument('--hidden-ksize', type=int, default=3,
+                        help='Kernel size of hidden layers convolutions')
+
+    parser.add_argument('--data-channels', type=int, default=3,
+                        help='Number of data channels')
+    parser.add_argument('--color-levels', type=int, default=7,
+                        help='Number of levels to quantisize value of each channel of each pixel into')
+
+    parser.add_argument('--hidden-fmaps', type=int, default=128,
+                        help='Number of feature maps in hidden layer')
+    parser.add_argument('--out-hidden-fmaps', type=int, default=32,
+                        help='Number of feature maps in outer hidden layer')
+    parser.add_argument('--hidden-layers', type=int, default=10,
+                        help='Number of layers of gated convolutions with mask of type "B"')
+
+    parser.add_argument('--cuda', type=str2bool, default=True,
+                        help='Flag indicating whether CUDA should be used')
+    parser.add_argument('--model-path', '-m', default='',
+                        help="Path to model's saved parameters")
+    parser.add_argument('--output-fname', '-o', type=str, default='samples/samples.jpg',
+                        help='Output filename')
+
+    parser.add_argument('--count', '-c', type=int, default=10,
+                        help='Number of images to generate \
+                        (is rounded to the nearest integer square)')
+    parser.add_argument('--height', type=int, default=28, help='Output image height')
+    parser.add_argument('--width', type=int, default=28, help='Output image width')
+
+    cfg = parser.parse_args()
+
+    model = PixelCNN(cfg=cfg)
+    model.eval()
+
+    device = torch.device("cuda" if torch.cuda.is_available() and cfg.cuda else "cpu")
+    model.to(device)
+
+    model.load_state_dict(torch.load(cfg.model_path, map_location=device))
+
+    samples = model.sample((cfg.data_channels, cfg.height, cfg.width), nearest_square(cfg.count))
+    save_samples(samples / (cfg.color_levels - 1), cfg.output_fname)  # rescale levels to [0, 1] before saving
+
+
+if __name__ == '__main__':
+    main()
+
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..0e715eb
--- /dev/null
+++ b/train.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader
+
+import argparse
+import os
+from utils import str2bool, quantisize, save_samples
+from tqdm import tqdm
+
+from pixelcnn import PixelCNN
+
+DATASET_ROOT = "data/"
+TRAIN_SAMPLES_PATH = "train_samples"
+TRAIN_SAMPLES_COUNT = 16  # must be a perfect square
+
+
+def main():
+    parser = argparse.ArgumentParser(description='PixelCNN')
+
+    parser.add_argument('--epochs', type=int, default=30,
+                        help='Number of epochs to train model for')
+    parser.add_argument('--batch-size', type=int, default=32,
+                        help='Number of images per mini-batch')
+    parser.add_argument('--dataset', type=str, default='mnist',
+                        help='Dataset to train model on. Either mnist, fashionmnist or cifar.')
+
+    parser.add_argument('--causal-ksize', type=int, default=7,
+                        help='Kernel size of causal convolution')
+    parser.add_argument('--hidden-ksize', type=int, default=3,
+                        help='Kernel size of hidden layers convolutions')
+
+    parser.add_argument('--data-channels', type=int, default=1,
+                        help='Number of data channels')
+    parser.add_argument('--color-levels', type=int, default=7,
+                        help='Number of levels to quantisize value of each channel of each pixel into')
+
+    parser.add_argument('--hidden-fmaps', type=int, default=128,
+                        help='Number of feature maps in hidden layer')
+    parser.add_argument('--out-hidden-fmaps', type=int, default=32,
+                        help='Number of feature maps in outer hidden layer')
+    parser.add_argument('--hidden-layers', type=int, default=10,
+                        help='Number of layers of gated convolutions with mask of type "B"')
+
+    parser.add_argument('--cuda', type=str2bool, default=True,
+                        help='Flag indicating whether CUDA should be used')
+    parser.add_argument('--model-output-path', '-m', default='',
+                        help="Output path for model's parameters")
+    parser.add_argument('--samples-folder', '-o', type=str, default='train-samples/',
+                        help='Path where sampled images will be saved')
+
+    cfg = parser.parse_args()
+    LEVELS = cfg.color_levels
+    MODEL_PATH = cfg.model_output_path
+
+    model = PixelCNN(cfg=cfg)
+
+    device = torch.device("cuda" if torch.cuda.is_available() and cfg.cuda else "cpu")
+    model.to(device)
+
+    transform = transforms.Compose([
+        transforms.Lambda(lambda image: quantisize(image, LEVELS)),
+        transforms.ToTensor()
+    ])
+    if cfg.dataset == "mnist":
+        dataset = datasets.MNIST(root=DATASET_ROOT, train=True, download=True, transform=transform)
+        HEIGHT, WIDTH = 28, 28
+    elif cfg.dataset == "fashionmnist":
+        dataset = datasets.FashionMNIST(root=DATASET_ROOT, train=True, download=True, transform=transform)
+        HEIGHT, WIDTH = 28, 28
+    elif cfg.dataset == "cifar":
+        dataset = datasets.CIFAR10(root=DATASET_ROOT, train=True, download=True, transform=transform)
+        HEIGHT, WIDTH = 32, 32  # CIFAR-10 images are 32x32 RGB (pass --data-channels 3)
+
+    data_loader = DataLoader(dataset, batch_size=cfg.batch_size)
+
+    loss_fn = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(model.parameters())
+
+    for epoch in tqdm(range(cfg.epochs)):
+        for i, images in enumerate(data_loader):
+            optimizer.zero_grad()
+
+            if cfg.dataset in ['mnist', 'fashionmnist', 'cifar']:
+                # remove labels and move the batch to the training device
+                images = images[0].to(device)
+
+            normalized_images = images.float() / (LEVELS - 1)
+
+            outputs = model(normalized_images)
+            loss = loss_fn(outputs, images.long())
+            loss.backward()
+            optimizer.step()
+
+        model.eval()
+        samples = model.sample((cfg.data_channels, HEIGHT, WIDTH), TRAIN_SAMPLES_COUNT)
+        save_samples(samples / (LEVELS - 1), os.path.join(TRAIN_SAMPLES_PATH, 'epoch{}_samples.jpg'.format(epoch)))
+        model.train()
+
+    torch.save(model.state_dict(), MODEL_PATH)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..6a7a2a0
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,31 @@
+import numpy as np
+import argparse
+
+from PIL import Image
+
+
+def quantisize(image, levels):
+    return np.digitize(np.asarray(image) / 255, np.arange(levels) / levels) - 1  # bin 8-bit pixel values into [0, levels)
+
+
+def str2bool(s):
+    if isinstance(s, bool):
+        return s
+    if s.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif s.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected')
+
+
+def nearest_square(num):
+    return int(round(np.sqrt(num))) ** 2
+
+
+def save_samples(samples, filename):
+    count, channels, height, width = samples.size()
+    side = int(count ** 0.5)  # assumes count is a perfect square
+    samples = samples.view(side, side, channels, height, width).permute(1, 3, 0, 4, 2)
+    grid = (samples.reshape(side * height, side * width, channels) * 255).cpu().numpy().astype(np.uint8)
+    Image.fromarray(grid.squeeze() if channels == 1 else grid).save(filename)