-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 724c624
Showing
7 changed files
with
400 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
*.ipynb | ||
.ipynb_checkpoints | ||
.idea | ||
__pycache__/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from pixelcnn.model import PixelCNN | ||
|
||
# Names exported by `from pixelcnn import *`: only the model class imported above.
__all__ = ['PixelCNN']
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
import numpy as np | ||
|
||
|
||
class CroppedConv2d(nn.Conv2d):
    """Conv2d whose forward pass returns two row-cropped views of its output.

    Constructor arguments are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Convolve ``x`` and return ``(res, shifted_up_res)``.

        With ``kh`` the kernel height, ``res`` drops the first row and the
        last ``kh`` rows of the convolution output, while ``shifted_up_res``
        keeps the first row and drops the last ``kh + 1`` rows.  Row ``r`` of
        ``shifted_up_res`` therefore equals row ``r - 1`` of ``res``.
        """
        conv_out = super().forward(x)
        kh = self.kernel_size[0]

        # Both crops act on the height axis (dim 2) only.
        cropped = conv_out[:, :, 1:-kh]
        cropped_one_row_earlier = conv_out[:, :, :-(kh + 1)]
        return cropped, cropped_one_row_earlier
|
||
|
||
class MaskedConv2d(nn.Conv2d):
    """Conv2d whose weights are multiplied by a fixed causal mask.

    The mask keeps every kernel row above the centre row and, in the centre
    row, every column up to and including the centre.  At the centre tap the
    connection between an output channel and an input channel is additionally
    removed according to the colour group each channel belongs to:

    * mask type ``'B'``: removed when the input group comes after the output
      group;
    * mask type ``'A'``: as ``'B'``, and also removed when both groups are
      equal.

    Args:
        *args, **kwargs: forwarded to ``nn.Conv2d``.
        mask_type: ``'A'`` or ``'B'``.
        data_channels: number of colour groups (must be >= 1).
        in_spread: if true, input channel ``k`` belongs to group
            ``k % data_channels``; otherwise the input channels are cut into
            ``data_channels`` contiguous chunks of
            ``ceil(in_channels / data_channels)`` channels.
        out_spread: same choice for the output channels.

    Raises:
        AssertionError: if ``mask_type`` is neither ``'A'`` nor ``'B'``.
        ValueError: if ``data_channels`` is smaller than 1.
    """

    def __init__(self, *args, mask_type, data_channels, in_spread=True, out_spread=True, **kwargs):
        super(MaskedConv2d, self).__init__(*args, **kwargs)

        assert mask_type in ['A', 'B'], 'Invalid mask type.'
        if data_channels < 1:
            raise ValueError('data_channels must be at least 1, got {}'.format(data_channels))

        out_channels, in_channels, height, width = self.weight.size()
        yc, xc = height // 2, width // 2

        mask = np.zeros((out_channels, in_channels, height, width), dtype=np.float32)
        mask[:, :, :yc, :] = 1
        mask[:, :, yc, :xc + 1] = 1

        def colour_groups(n_channels, spread):
            # Colour group index of every channel, for either layout.
            idx = np.arange(n_channels)
            if spread:
                return idx % data_channels
            # Contiguous layout: chunk size is derived from data_channels
            # (the chunk count), not from a fixed 3.
            chunk = max(-(-n_channels // data_channels), 1)
            return idx // chunk

        out_group = colour_groups(out_channels, out_spread)[:, None]
        in_group = colour_groups(in_channels, in_spread)[None, :]

        if mask_type == 'A':
            blocked = in_group >= out_group
        else:
            blocked = in_group > out_group
        mask[blocked, yc, xc] = 0

        # A buffer follows the module through .to()/.cuda()/.half(), so the
        # mask always lives next to the weight.  It is non-persistent so that
        # state_dict() keeps exactly the keys it had before.
        self.register_buffer('mask', torch.from_numpy(mask), persistent=False)

    def forward(self, x):
        """Zero the masked weights in place, then run the convolution."""
        # .data keeps the masking itself out of the autograd graph.
        self.weight.data *= self.mask
        return super(MaskedConv2d, self).forward(x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from .conv_layers import MaskedConv2d, CroppedConv2d | ||
|
||
|
||
class GatedBlock(nn.Module):
    """Gated layer with a vertical and a horizontal convolution stack.

    ``forward`` consumes and returns a dict ``{0: v, 1: h, 2: skip}`` so that
    blocks can be chained inside ``nn.Sequential``.

    Args:
        in_channels: channels of the incoming ``v`` and ``h`` tensors.
        out_channels: channels of the returned ``v``, ``h`` and ``skip``.
        kernel_size: width of the vertical and horizontal kernels.
        mask_type: ``'A'`` or ``'B'``; applied to the horizontal convolution.
        data_channels: number of colour channels, forwarded to the masked
            convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, mask_type, data_channels):
        super(GatedBlock, self).__init__()
        self.split_index = out_channels

        # The residual shortcut passes h_in through unchanged.  In a type-A
        # block h_in is the raw input, so the shortcut would hand the value
        # being predicted to every later layer (and its channel count need
        # not match out_channels).  Only type-B blocks get the shortcut.
        self.use_residual = mask_type == 'B'

        self.v_conv = CroppedConv2d(in_channels,
                                    2 * out_channels,
                                    (kernel_size // 2 + 1, kernel_size),
                                    padding=(kernel_size // 2 + 1, kernel_size // 2))
        self.v_fc = nn.Conv2d(in_channels,
                              2 * out_channels,
                              (1, 1))
        # Only h_conv looks at the block input, so only h_conv needs the
        # block's own mask type.  The 1x1 convolutions below operate on
        # features that are already masked; giving them type 'A' would cut
        # every same-colour connection (the whole weight when there is a
        # single data channel), so they always use type 'B'.
        self.v_to_h = MaskedConv2d(2 * out_channels,
                                   2 * out_channels,
                                   (1, 1),
                                   mask_type='B',
                                   data_channels=data_channels)
        self.h_conv = MaskedConv2d(in_channels,
                                   2 * out_channels,
                                   (1, kernel_size),
                                   mask_type=mask_type,
                                   data_channels=data_channels,
                                   padding=(0, kernel_size // 2))
        self.h_fc = MaskedConv2d(out_channels,
                                 out_channels,
                                 (1, 1),
                                 mask_type='B',
                                 data_channels=data_channels)
        self.h_skip = MaskedConv2d(out_channels,
                                   out_channels,
                                   (1, 1),
                                   mask_type='B',
                                   data_channels=data_channels)

    def _gate(self, t):
        # tanh of the first out_channels channels times sigmoid of the rest.
        # NOTE(review): this pairs channel k with channel k + out_channels;
        # with the interleaved colour layout of MaskedConv2d both belong to
        # the same colour only if out_channels is a multiple of
        # data_channels - confirm the configured sizes.
        return torch.tanh(t[:, :self.split_index]) * torch.sigmoid(t[:, self.split_index:])

    def forward(self, x):
        """Run one gated step.

        Args:
            x: mapping with keys ``0`` (vertical input), ``1`` (horizontal
                input) and ``2`` (running skip sum).

        Returns:
            dict with the same three keys holding the updated tensors.
        """
        v_in, h_in, skip = x[0], x[1], x[2]

        v_pre, v_shifted = self.v_conv(v_in)
        v_out = self._gate(v_pre + self.v_fc(v_in))

        # The horizontal stack receives the row-shifted vertical features.
        h_out = self._gate(self.h_conv(h_in) + self.v_to_h(v_shifted))

        skip = skip + self.h_skip(h_out)

        h_out = self.h_fc(h_out)
        if self.use_residual:
            h_out = h_out + h_in

        return {0: v_out, 1: h_out, 2: skip}
|
||
|
||
class PixelCNN(nn.Module):
    """Autoregressive image model built from gated, masked convolutions.

    Args:
        cfg: object exposing ``causal_ksize``, ``hidden_ksize``,
            ``data_channels``, ``hidden_fmaps``, ``out_hidden_fmaps``,
            ``hidden_layers`` and ``color_levels`` attributes.
    """

    def __init__(self, cfg):
        super(PixelCNN, self).__init__()
        self.causal_ksize = cfg.causal_ksize
        self.hidden_ksize = cfg.hidden_ksize

        self.data_channels = cfg.data_channels
        self.hidden_fmaps = cfg.hidden_fmaps
        self.out_hidden_fmaps = cfg.out_hidden_fmaps

        self.hidden_layers = cfg.hidden_layers

        self.color_levels = cfg.color_levels

        self.causal_conv = GatedBlock(self.data_channels,
                                      self.hidden_fmaps,
                                      self.causal_ksize,
                                      mask_type='A',
                                      data_channels=self.data_channels)

        self.hidden_conv = nn.Sequential(
            *[GatedBlock(self.hidden_fmaps,
                         self.hidden_fmaps,
                         self.hidden_ksize,
                         mask_type='B',
                         data_channels=self.data_channels) for _ in range(self.hidden_layers)]
        )

        self.out_hidden_conv = MaskedConv2d(self.hidden_fmaps,
                                            self.out_hidden_fmaps,
                                            (1, 1),
                                            mask_type='B',
                                            data_channels=self.data_channels)

        # forward() reshapes the logits to (levels, data_channels), i.e.
        # output channel k belongs to colour k % data_channels.  That is the
        # default ("spread") layout of MaskedConv2d, so out_spread must not
        # be switched off here; otherwise the logits of one colour would be
        # taken from channels masked for another colour.
        self.out_conv = MaskedConv2d(self.out_hidden_fmaps,
                                     self.data_channels * self.color_levels,
                                     (1, 1),
                                     mask_type='B',
                                     data_channels=self.data_channels)

    def forward(self, x):
        """Return unnormalised logits for every level of every sub-pixel.

        Args:
            x: tensor of size ``(count, channels, height, width)``.

        Returns:
            Tensor of size
            ``(count, color_levels, data_channels, height, width)``.
        """
        count, _, height, width = x.size()
        skip_shape = (count, self.hidden_fmaps, height, width)

        # The skip accumulators start at zero on the same device/dtype as the
        # input; gradients reach the result through the convolutions, so the
        # zeros themselves need no requires_grad.
        v, h, _ = self.causal_conv({0: x, 1: x, 2: x.new_zeros(skip_shape)}).values()

        _, _, out = self.hidden_conv({0: v, 1: h, 2: x.new_zeros(skip_shape)}).values()

        out = F.relu(out)
        out = F.relu(self.out_hidden_conv(out))
        out = self.out_conv(out)

        return out.view(count, self.color_levels, self.data_channels, height, width)

    def sample(self, shape, count):
        """Draw ``count`` images one sub-pixel at a time.

        Args:
            shape: ``(channels, height, width)`` of a single image.
            count: number of images to generate.

        Returns:
            CPU float tensor of size ``(count, channels, height, width)``
            holding the sampled level indices ``0 .. color_levels - 1``.
        """
        channels, height, width = shape

        device = next(self.parameters()).device
        samples = torch.zeros(count, *shape, device=device)

        # The training loop feeds the network level / (color_levels - 1), so
        # the partially generated image is scaled the same way before each
        # forward pass.  The returned tensor keeps the raw level indices.
        scale = max(self.color_levels - 1, 1)

        with torch.no_grad():
            for i in range(height):
                for j in range(width):
                    for c in range(channels):
                        unnormalized_probs = self.forward(samples / scale)
                        pixel_probs = torch.softmax(unnormalized_probs[:, :, c, i, j], dim=1)
                        # squeeze(1) keeps the batch dimension when count == 1.
                        sampled_levels = torch.multinomial(pixel_probs, 1).squeeze(1)
                        samples[:, c, i, j] = sampled_levels.to(samples.dtype)

        return samples.cpu()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import torch | ||
|
||
from pixelcnn import PixelCNN | ||
|
||
import argparse | ||
from utils import str2bool, save_samples | ||
|
||
|
||
def main():
    """Parse the command line, restore a trained PixelCNN and save samples."""
    import os

    parser = argparse.ArgumentParser(description='PixelCNN')

    parser.add_argument('--causal-ksize', type=int, default=7,
                        help='Kernel size of causal convolution')
    parser.add_argument('--hidden-ksize', type=int, default=3,
                        help='Kernel size of hidden layers convolutions')

    # NOTE(review): the training script in this commit defaults
    # --data-channels to 1; a checkpoint trained with those defaults needs
    # --data-channels 1 here.
    parser.add_argument('--data-channels', type=int, default=3,
                        help='Number of data channels')
    parser.add_argument('--color-levels', type=int, default=7,
                        help='Number of levels to quantisize value of each channel of each pixel into')

    parser.add_argument('--hidden-fmaps', type=int, default=128,
                        help='Number of feature maps in hidden layer')
    parser.add_argument('--out-hidden-fmaps', type=int, default=32,
                        help='Number of feature maps in outer hidden layer')
    parser.add_argument('--hidden-layers', type=int, default=10,
                        help='Number of layers of gated convolutions with mask of type "B"')

    parser.add_argument('--cuda', type=str2bool, default=True,
                        help='Flag indicating whether CUDA should be used')
    parser.add_argument('--model-path', '-m', default='',
                        help="Path to model's saved parameters")
    parser.add_argument('--output-fname', '-o', type=str, default='samples/samples.jpg',
                        help='Output filename')

    parser.add_argument('--count', '-c', type=int, default=10,
                        help='Number of images to generate '
                             '(is rounded to the nearest integer square)')
    parser.add_argument('--height', type=int, default=28, help='Output image height')
    parser.add_argument('--width', type=int, default=28, help='Output image width')

    cfg = parser.parse_args()

    # Without a checkpoint there is nothing to sample from; say so up front
    # instead of failing inside torch.load('').
    if not cfg.model_path:
        parser.error('--model-path is required')

    device = torch.device("cuda" if torch.cuda.is_available() and cfg.cuda else "cpu")

    model = PixelCNN(cfg=cfg)
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine (and vice versa).
    model.load_state_dict(torch.load(cfg.model_path, map_location=device))
    model.to(device)
    model.eval()

    # Create the destination directory if the output path names one.
    output_dir = os.path.dirname(cfg.output_fname)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    samples = model.sample((cfg.data_channels, cfg.height, cfg.width), cfg.count)
    save_samples(samples, cfg.output_fname)


if __name__ == '__main__':
    main()
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
|
||
from torchvision import datasets, transforms | ||
from torch.utils.data import DataLoader | ||
|
||
import argparse | ||
import os | ||
from utils import str2bool, quantisize, save_samples | ||
from tqdm import tqdm | ||
|
||
from pixelcnn import PixelCNN | ||
|
||
DATASET_ROOT = "data/"
TRAIN_SAMPLES_PATH = "train_samples"
TRAIN_SAMPLES_COUNT = 16  # must be square


def main():
    """Train a PixelCNN, saving sample images after every epoch.

    The model parameters are written to ``--model-output-path`` once all
    epochs have finished.
    """
    parser = argparse.ArgumentParser(description='PixelCNN')

    parser.add_argument('--epochs', type=int, default=30,
                        help='Number of epochs to train model for')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Number of images per mini-batch')
    parser.add_argument('--dataset', type=str, default='mnist',
                        help='Dataset to train model on. Either mnist, fashionmnist or cifar.')

    parser.add_argument('--causal-ksize', type=int, default=7,
                        help='Kernel size of causal convolution')
    parser.add_argument('--hidden-ksize', type=int, default=3,
                        help='Kernel size of hidden layers convolutions')

    parser.add_argument('--data-channels', type=int, default=1,
                        help='Number of data channels')
    parser.add_argument('--color-levels', type=int, default=7,
                        help='Number of levels to quantisize value of each channel of each pixel into')

    parser.add_argument('--hidden-fmaps', type=int, default=128,
                        help='Number of feature maps in hidden layer')
    parser.add_argument('--out-hidden-fmaps', type=int, default=32,
                        help='Number of feature maps in outer hidden layer')
    parser.add_argument('--hidden-layers', type=int, default=10,
                        help='Number of layers of gated convolutions with mask of type "B"')

    parser.add_argument('--cuda', type=str2bool, default=True,
                        help='Flag indicating whether CUDA should be used')
    parser.add_argument('--model-output-path', '-m', default='',
                        help="Output path for model's parameters")
    # The default is the directory the script has always written to, so runs
    # that do not pass the option keep their output location.
    parser.add_argument('--samples-folder', '-o', type=str, default=TRAIN_SAMPLES_PATH,
                        help='Path where sampled images will be saved')

    cfg = parser.parse_args()
    levels = cfg.color_levels
    model_path = cfg.model_output_path
    samples_dir = cfg.samples_folder

    # Dataset class and image size (height, width) for every supported name.
    dataset_specs = {
        'mnist': (datasets.MNIST, (28, 28)),
        'fashionmnist': (datasets.FashionMNIST, (28, 28)),
        'cifar': (datasets.CIFAR10, (32, 32)),
    }

    # Reject unusable options before any expensive work starts.
    if cfg.dataset not in dataset_specs:
        parser.error('--dataset must be one of: {}'.format(', '.join(sorted(dataset_specs))))
    if levels < 2:
        parser.error('--color-levels must be at least 2')
    if not model_path:
        parser.error('--model-output-path is required')

    model = PixelCNN(cfg=cfg)

    device = torch.device("cuda" if torch.cuda.is_available() and cfg.cuda else "cpu")
    model.to(device)

    transform = transforms.Compose([
        transforms.Lambda(lambda image: quantisize(image, levels)),
        transforms.ToTensor()
    ])
    dataset_cls, (height, width) = dataset_specs[cfg.dataset]
    dataset = dataset_cls(root=DATASET_ROOT, train=True, download=True, transform=transform)

    data_loader = DataLoader(dataset, batch_size=cfg.batch_size)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    os.makedirs(samples_dir, exist_ok=True)

    for epoch in tqdm(range(cfg.epochs)):
        # Every supported dataset yields (images, labels); labels are unused.
        for images, _ in data_loader:
            images = images.to(device)
            optimizer.zero_grad()

            # The batch is used as the loss target and, scaled by
            # (levels - 1), as the network input.
            # NOTE(review): presumably quantisize yields integer level
            # indices - confirm against utils.
            normalized_images = images.float() / (levels - 1)

            outputs = model(normalized_images)
            loss = loss_fn(outputs, images)
            loss.backward()
            optimizer.step()

        model.eval()
        samples = model.sample((cfg.data_channels, height, width), TRAIN_SAMPLES_COUNT)
        save_samples(samples, os.path.join(samples_dir, 'epoch{}_samples.jpg'.format(epoch)))
        model.train()

    torch.save(model.state_dict(), model_path)


if __name__ == '__main__':
    main()
Oops, something went wrong.