Commit 724c624: Initial commit
anordertoreclaim committed Aug 1, 2019

Showing 7 changed files with 400 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
*.ipynb
.ipynb_checkpoints
.idea
__pycache__/
3 changes: 3 additions & 0 deletions pixelcnn/__init__.py
@@ -0,0 +1,3 @@
from pixelcnn.model import PixelCNN

__all__ = ['PixelCNN']
64 changes: 64 additions & 0 deletions pixelcnn/conv_layers.py
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn

import numpy as np


class CroppedConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super(CroppedConv2d, self).__init__(*args, **kwargs)

    def forward(self, x):
        x = super(CroppedConv2d, self).forward(x)

        # Crop the vertically padded output back to the input height:
        # `res` is the vertical-stack output aligned with each row, while
        # `shifted_up_res` is the same map shifted by one row so it can feed
        # the horizontal stack without leaking the current row.
        kernel_height, _ = self.kernel_size
        res = x[:, :, 1:-kernel_height, :]
        shifted_up_res = x[:, :, :-kernel_height - 1, :]

        return res, shifted_up_res


class MaskedConv2d(nn.Conv2d):
    def __init__(self, *args, mask_type, data_channels, in_spread=True, out_spread=True, **kwargs):
        super(MaskedConv2d, self).__init__(*args, **kwargs)

        assert mask_type in ['A', 'B'], 'Invalid mask type.'

        out_channels, in_channels, height, width = self.weight.size()
        yc, xc = height // 2, width // 2

        # Allow everything strictly above the centre row, plus the centre row
        # up to and including the centre pixel; channel-level causality at the
        # centre pixel is refined below.
        mask = np.zeros(self.weight.size(), dtype=np.float32)
        mask[:, :, :yc, :] = 1
        mask[:, :, yc, :xc + 1] = 1

        def cmask(out_c, in_c):
            # Boolean matrix selecting the weights that connect data channel
            # `in_c` on the input to data channel `out_c` on the output,
            # either interleaved across feature maps or in contiguous blocks.
            if out_spread:
                a = (np.arange(out_channels) % data_channels == out_c)[:, None]
            else:
                split = np.ceil(out_channels / data_channels)
                lbound = out_c * split
                ubound = (out_c + 1) * split
                a = ((lbound <= np.arange(out_channels)) * (np.arange(out_channels) < ubound))[:, None]
            if in_spread:
                b = (np.arange(in_channels) % data_channels == in_c)[None, :]
            else:
                split = np.ceil(in_channels / data_channels)
                lbound = in_c * split
                ubound = (in_c + 1) * split
                b = ((lbound <= np.arange(in_channels)) * (np.arange(in_channels) < ubound))[None, :]
            return a * b

        # At the centre pixel, an output channel may never read input data
        # channels that come after it in the channel ordering...
        for o in range(data_channels):
            for i in range(o + 1, data_channels):
                mask[cmask(o, i), yc, xc] = 0

        # ...and with mask 'A' (the first layer) it may not read itself either.
        if mask_type == 'A':
            for c in range(data_channels):
                mask[cmask(c, c), yc, xc] = 0

        # Register as a buffer so the mask moves with the module to the GPU.
        self.register_buffer('mask', torch.from_numpy(mask))

    def forward(self, x):
        # Re-apply the causality mask before every convolution.
        self.weight.data *= self.mask
        x = super(MaskedConv2d, self).forward(x)
        return x
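
For intuition, here is a minimal sketch (my own, not part of the commit) that inspects the centre-pixel mask and checks the cropped shapes, assuming the package is importable as pixelcnn:

import torch
from pixelcnn.conv_layers import MaskedConv2d, CroppedConv2d

# Type 'A' mask: at the centre pixel an output channel may only read input
# data channels strictly before it, so the first layer never sees the pixel
# it is predicting.
conv_a = MaskedConv2d(3, 6, (3, 3), mask_type='A', data_channels=3, padding=1)
print(conv_a.mask[:, :, 1, 1])  # 6x3 matrix; 1 exactly where out_c % 3 > in_c % 3

# CroppedConv2d with the kernel/padding GatedBlock uses for kernel_size=7:
# both returned maps keep the input's spatial size.
crop = CroppedConv2d(1, 2, (4, 7), padding=(4, 3))
res, shifted = crop(torch.zeros(1, 1, 28, 28))
print(res.shape, shifted.shape)  # both torch.Size([1, 2, 28, 28])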
134 changes: 134 additions & 0 deletions pixelcnn/model.py
@@ -0,0 +1,134 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .conv_layers import MaskedConv2d, CroppedConv2d


class GatedBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, mask_type, data_channels):
super(GatedBlock, self).__init__()
self.split_index = out_channels

self.v_conv = CroppedConv2d(in_channels,
2 * out_channels,
(kernel_size // 2 + 1, kernel_size),
padding=(kernel_size // 2 + 1, kernel_size // 2))
self.v_fc = nn.Conv2d(in_channels,
2 * out_channels,
(1, 1))
        self.v_to_h = MaskedConv2d(2 * out_channels,
                                   2 * out_channels,
                                   (1, 1),
                                   mask_type=mask_type,
                                   data_channels=data_channels)
self.h_conv = MaskedConv2d(in_channels,
2 * out_channels,
(1, kernel_size),
mask_type=mask_type,
data_channels=data_channels,
padding=(0, kernel_size // 2))
self.h_fc = MaskedConv2d(out_channels,
out_channels,
(1, 1),
mask_type=mask_type,
data_channels=data_channels)
self.h_skip = MaskedConv2d(out_channels,
out_channels,
(1, 1),
mask_type=mask_type,
data_channels=data_channels)

    def forward(self, x):
        # The three streams are packed into a dict so that nn.Sequential can
        # thread them through a stack of GatedBlocks as a single argument.
        v_in, h_in, skip = x[0], x[1], x[2]

        # Vertical stack: cropped convolution plus a 1x1 term, then a gated
        # activation (tanh * sigmoid over the two halves of the channels).
        v_out, v_shifted = self.v_conv(v_in)
        v_out += self.v_fc(v_in)
        v_out = torch.tanh(v_out[:, :self.split_index]) * torch.sigmoid(v_out[:, self.split_index:])

        # Horizontal stack: masked convolution plus the shifted vertical
        # features, so each pixel also conditions on the rows above it.
        h_out = self.h_conv(h_in)
        v_shifted = self.v_to_h(v_shifted)
        h_out += v_shifted
        h_out = torch.tanh(h_out[:, :self.split_index]) * torch.sigmoid(h_out[:, self.split_index:])

        # Accumulate the skip connection, then apply the residual connection.
        skip = skip + self.h_skip(h_out)

        h_out = self.h_fc(h_out) + h_in

        return {0: v_out, 1: h_out, 2: skip}


class PixelCNN(nn.Module):
def __init__(self, cfg):
super(PixelCNN, self).__init__()
self.causal_ksize = cfg.causal_ksize
self.hidden_ksize = cfg.hidden_ksize

self.data_channels = cfg.data_channels
self.hidden_fmaps = cfg.hidden_fmaps
self.out_hidden_fmaps = cfg.out_hidden_fmaps

self.hidden_layers = cfg.hidden_layers

self.color_levels = cfg.color_levels

self.causal_conv = GatedBlock(self.data_channels,
self.hidden_fmaps,
self.causal_ksize,
mask_type='A',
data_channels=self.data_channels)

self.hidden_conv = nn.Sequential(
*[GatedBlock(self.hidden_fmaps,
self.hidden_fmaps,
self.hidden_ksize,
mask_type='B',
data_channels=self.data_channels) for _ in range(self.hidden_layers)]
)

self.out_hidden_conv = MaskedConv2d(self.hidden_fmaps,
self.out_hidden_fmaps,
(1, 1),
mask_type='B',
data_channels=self.data_channels)

self.out_conv = MaskedConv2d(self.out_hidden_fmaps,
self.data_channels * self.color_levels,
(1, 1),
mask_type='B',
data_channels=self.data_channels,
out_spread=False)

    def forward(self, x):
        count, _, height, width = x.size()
        out_shape = (count, self.hidden_fmaps, height, width)

        # The skip accumulators must live on the same device as the input.
        v, h, _ = self.causal_conv({0: x, 1: x, 2: torch.zeros(out_shape, device=x.device)}).values()

        _, _, out = self.hidden_conv({0: v, 1: h, 2: torch.zeros(out_shape, device=x.device)}).values()

        out = F.relu(out)
        out = F.relu(self.out_hidden_conv(out))
        out = self.out_conv(out)

        # Reshape logits to (batch, colour levels, data channels, H, W) so
        # CrossEntropyLoss can treat the level dimension as classes.
        batch_size, _, height, width = out.size()
        out = out.view(batch_size, self.color_levels, self.data_channels, height, width)

        return out

    def sample(self, shape, count):
        channels, height, width = shape

        # Generate on whatever device the model lives on.
        device = next(self.parameters()).device
        samples = torch.zeros(count, *shape, device=device)

        with torch.no_grad():
            # Raster-scan, channel-by-channel ancestral sampling: each pixel
            # is drawn from the softmax over colour levels, normalized back
            # to [0, 1] as in training, and fed into the next forward pass.
            for i in range(height):
                for j in range(width):
                    for c in range(channels):
                        unnormalized_probs = self.forward(samples)
                        pixel_probs = torch.softmax(unnormalized_probs[:, :, c, i, j], dim=1)
                        sampled_levels = torch.multinomial(pixel_probs, 1).squeeze()
                        samples[:, c, i, j] = sampled_levels.float() / (self.color_levels - 1)

        return samples
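
A quick shape check for the model (again my own sketch, not in the commit); argparse.Namespace stands in for the parsed config:

from argparse import Namespace
import torch
from pixelcnn import PixelCNN

cfg = Namespace(causal_ksize=7, hidden_ksize=3, data_channels=1,
                hidden_fmaps=16, out_hidden_fmaps=8, hidden_layers=2,
                color_levels=4)
model = PixelCNN(cfg)

# two 28x28 grayscale images, already normalized to [0, 1]
out = model(torch.rand(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 4, 1, 28, 28]): (batch, levels, channels, H, W)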
58 changes: 58 additions & 0 deletions sample.py
@@ -0,0 +1,58 @@
import torch

from pixelcnn import PixelCNN

import argparse
from utils import str2bool, save_samples


def main():
parser = argparse.ArgumentParser(description='PixelCNN')

parser.add_argument('--causal-ksize', type=int, default=7,
help='Kernel size of causal convolution')
parser.add_argument('--hidden-ksize', type=int, default=3,
help='Kernel size of hidden layers convolutions')

parser.add_argument('--data-channels', type=int, default=3,
help='Number of data channels')
    parser.add_argument('--color-levels', type=int, default=7,
                        help='Number of levels to quantize each channel of each pixel into')

parser.add_argument('--hidden-fmaps', type=int, default=128,
help='Number of feature maps in hidden layer')
parser.add_argument('--out-hidden-fmaps', type=int, default=32,
help='Number of feature maps in outer hidden layer')
parser.add_argument('--hidden-layers', type=int, default=10,
help='Number of layers of gated convolutions with mask of type "B"')

parser.add_argument('--cuda', type=str2bool, default=True,
help='Flag indicating whether CUDA should be used')
parser.add_argument('--model-path', '-m', default='',
help="Path to model's saved parameters")
parser.add_argument('--output-fname', '-o', type=str, default='samples/samples.jpg',
help='Output filename')

    parser.add_argument('--count', '-c', type=int, default=10,
                        help='Number of images to generate (rounded to the nearest integer square)')
parser.add_argument('--height', type=int, default=28, help='Output image height')
parser.add_argument('--width', type=int, default=28, help='Output image width')

cfg = parser.parse_args()

    model = PixelCNN(cfg=cfg)

    device = torch.device("cuda" if torch.cuda.is_available() and cfg.cuda else "cpu")
    model.to(device)

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    model.load_state_dict(torch.load(cfg.model_path, map_location=device))
    model.eval()

samples = model.sample((cfg.data_channels, cfg.height, cfg.width), cfg.count)
save_samples(samples, cfg.output_fname)


if __name__ == '__main__':
main()
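
The rounding mentioned in the --count help presumably happens inside save_samples, whose source is not in this commit; the arithmetic would look like:

import math

count = 10
side = round(math.sqrt(count))  # 3
print(side * side)              # 9 images, arranged in a 3x3 grid (assumed)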

106 changes: 106 additions & 0 deletions train.py
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import argparse
import os
from utils import str2bool, quantisize, save_samples
from tqdm import tqdm

from pixelcnn import PixelCNN

DATASET_ROOT = "data/"
TRAIN_SAMPLES_COUNT = 16  # must be a perfect square


def main():
parser = argparse.ArgumentParser(description='PixelCNN')

parser.add_argument('--epochs', type=int, default=30,
help='Number of epochs to train model for')
parser.add_argument('--batch-size', type=int, default=32,
help='Number of images per mini-batch')
parser.add_argument('--dataset', type=str, default='mnist',
help='Dataset to train model on. Either mnist, fashionmnist or cifar.')

parser.add_argument('--causal-ksize', type=int, default=7,
help='Kernel size of causal convolution')
parser.add_argument('--hidden-ksize', type=int, default=3,
help='Kernel size of hidden layers convolutions')

parser.add_argument('--data-channels', type=int, default=1,
help='Number of data channels')
    parser.add_argument('--color-levels', type=int, default=7,
                        help='Number of levels to quantize each channel of each pixel into')

parser.add_argument('--hidden-fmaps', type=int, default=128,
help='Number of feature maps in hidden layer')
parser.add_argument('--out-hidden-fmaps', type=int, default=32,
help='Number of feature maps in outer hidden layer')
parser.add_argument('--hidden-layers', type=int, default=10,
help='Number of layers of gated convolutions with mask of type "B"')

parser.add_argument('--cuda', type=str2bool, default=True,
help='Flag indicating whether CUDA should be used')
parser.add_argument('--model-output-path', '-m', default='',
help="Output path for model's parameters")
parser.add_argument('--samples-folder', '-o', type=str, default='train-samples/',
help='Path where sampled images will be saved')

cfg = parser.parse_args()
LEVELS = cfg.color_levels
MODEL_PATH = cfg.model_output_path

model = PixelCNN(cfg=cfg)

device = torch.device("cuda" if torch.cuda.is_available() and cfg.cuda else "cpu")
model.to(device)

transform = transforms.Compose([
transforms.Lambda(lambda image: quantisize(image, LEVELS)),
transforms.ToTensor()
])
    if cfg.dataset == "mnist":
        dataset = datasets.MNIST(root=DATASET_ROOT, train=True, download=True, transform=transform)
        HEIGHT, WIDTH = 28, 28
    elif cfg.dataset == "fashionmnist":
        dataset = datasets.FashionMNIST(root=DATASET_ROOT, train=True, download=True, transform=transform)
        HEIGHT, WIDTH = 28, 28
    elif cfg.dataset == "cifar":
        dataset = datasets.CIFAR10(root=DATASET_ROOT, train=True, download=True, transform=transform)
        # CIFAR-10 images are 32x32
        HEIGHT, WIDTH = 32, 32
    else:
        raise ValueError('Unknown dataset: {}'.format(cfg.dataset))

    data_loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

    for epoch in tqdm(range(cfg.epochs)):
        for i, images in enumerate(data_loader):
            optimizer.zero_grad()

            # the datasets yield (image, label) tuples; drop the labels and
            # move the batch to the training device
            images = images[0].to(device)

            # inputs are scaled to [0, 1]; targets stay as integer levels
            normalized_images = images.float() / (LEVELS - 1)

            outputs = model(normalized_images)
            loss = loss_fn(outputs, images.long())
            loss.backward()
            optimizer.step()

        model.eval()
        samples = model.sample((cfg.data_channels, HEIGHT, WIDTH), TRAIN_SAMPLES_COUNT)
        os.makedirs(cfg.samples_folder, exist_ok=True)
        save_samples(samples, os.path.join(cfg.samples_folder, 'epoch{}_samples.jpg'.format(epoch)))
        model.train()

torch.save(model.state_dict(), MODEL_PATH)


if __name__ == '__main__':
main()
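
quantisize lives in utils and is not shown in this commit; a plausible stand-in, just to make the input/target pairing concrete (with --color-levels 7, targets are integer levels 0..6 and the network input is the level divided by 6):

import numpy as np

def quantisize_sketch(image, levels):
    # hypothetical stand-in for utils.quantisize: bin 8-bit intensities
    # into `levels` integer classes
    return np.clip(np.asarray(image) // (256 // levels), 0, levels - 1)

pixel = 200
level = int(np.clip(pixel // (256 // 7), 0, 6))  # 5: the CrossEntropyLoss target
normalized = level / (7 - 1)                     # ~0.83: the value fed to the model
print(level, normalized)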