Skip to content

Commit

Permalink
Fixed mistakes
Browse files Browse the repository at this point in the history
  • Loading branch information
HiGal committed Aug 5, 2019
1 parent 342efaa commit e56aaf4
Show file tree
Hide file tree
Showing 22 changed files with 25 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Here is what the whole architecture looks like:

![PixelCNN architecture](https://github.com/anordertoreclaim/PixelCNN/blob/master/.images/architecture.png?raw=true)

Causal block is the same as gated block, except that it has neither residual nor skip connections, its input is image instead of a tensor with depth of *hidden_fmaps* and it uses mask of type A instead of B of a usual gated block.
Causal block is the same as gated block, except that it has neither residual nor skip connections, its input is image instead of a tensor with depth of *hidden_fmaps*, it uses mask of type A instead of B of a usual gated block and it doesn't incorporate label bias.

Skip results are summed and ran through a ReLu – 1x1 Conv – ReLu block. Then the final convolutional layer is applied, which outputs a tensor that represents unnormalized probabilities of each color level for each color channel of each pixel in the image.

Expand Down
Binary file added image.pth
Binary file not shown.
33 changes: 20 additions & 13 deletions pixelcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def __init__(self, in_channels, out_channels, kernel_size, data_channels):
mask_type='A',
data_channels=data_channels)

def forward(self, v_in, h_in):
v_out, v_shifted = self.v_conv(v_in)
v_out += self.v_fc(v_in)
def forward(self, image):
v_out, v_shifted = self.v_conv(image)
v_out += self.v_fc(image)
v_out_tanh, v_out_sigmoid = torch.split(v_out, self.split_size, dim=1)
v_out = torch.tanh(v_out_tanh) * torch.sigmoid(v_out_sigmoid)

h_out = self.h_conv(h_in)
h_out = self.h_conv(image)
v_shifted = self.v_to_h(v_shifted)
h_out += v_shifted
h_out_tanh, h_out_sigmoid = torch.split(h_out, self.split_size, dim=1)
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(self, in_channels, out_channels, kernel_size, data_channels):
def forward(self, x):
v_in, h_in, skip, label = x[0], x[1], x[2], x[3]

label_embedded = self.label_embeddings(label).unsqueeze(2).unsqueeze(3)
label_embedded = self.label_embedding(label).unsqueeze(2).unsqueeze(3)

v_out, v_shifted = self.v_conv(v_in)
v_out += self.v_fc(v_in)
Expand Down Expand Up @@ -128,9 +128,11 @@ def __init__(self, cfg):
cfg.causal_ksize,
data_channels=cfg.data_channels)

self.hidden_conv = nn.Sequential([
self.hidden_conv = nn.Sequential(
*[GatedBlock(cfg.hidden_fmaps, cfg.hidden_fmaps, cfg.hidden_ksize, cfg.data_channels) for _ in range(cfg.hidden_layers)]
])
)

self.label_embedding = nn.Embedding(10, self.hidden_fmaps)

self.out_hidden_conv = MaskedConv2d(cfg.hidden_fmaps,
cfg.out_hidden_fmaps,
Expand All @@ -147,13 +149,17 @@ def __init__(self, cfg):
def forward(self, image, label):
count, data_channels, height, width = image.size()

v, h, _ = self.causal_conv({0: image, 1: image}).values()
v, h = self.causal_conv(image)

_, _, out, _ = self.hidden_conv({0: v,
1: h,
2: image.new_zeros((count, self.hidden_fmaps, height, width), requires_grad=True),
3: label}).values()

_, _, out = self.hidden_conv({0: v,
1: h,
2: image.new_zeros((count, self.hidden_fmaps, height, width), requires_grad=True),
3: label}).values()
label_embedded = self.label_embedding(label).unsqueeze(2).unsqueeze(3)

# add label bias
out += label_embedded
out = F.relu(out)
out = F.relu(self.out_hidden_conv(out))
out = self.out_conv(out)
Expand All @@ -166,12 +172,13 @@ def sample(self, shape, count, device='cuda'):
channels, height, width = shape

samples = torch.zeros(count, *shape).to(device)
labels = torch.randint(high=10, size=(count,)).to(device)

with torch.no_grad():
for i in range(height):
for j in range(width):
for c in range(channels):
unnormalized_probs = self.forward(samples)
unnormalized_probs = self.forward(samples, labels)
pixel_probs = torch.softmax(unnormalized_probs[:, :, c, i, j], dim=1)
sampled_levels = torch.multinomial(pixel_probs, 1).squeeze().float() / (self.color_levels - 1)
samples[:, c, i, j] = sampled_levels
Expand Down
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def train(cfg, model, device, train_loader, optimizer, scheduler, epoch):
for images, labels in tqdm(train_loader, desc='Epoch {}/{}'.format(epoch + 1, cfg.epochs)):
optimizer.zero_grad()

images.to(device, non_blocking=True)
labels.to(device, non_blocking=True)
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)

normalized_images = images.float() / (cfg.color_levels - 1)

Expand All @@ -50,8 +50,8 @@ def test_and_sample(cfg, model, device, test_loader, height, width, epoch):
model.eval()
with torch.no_grad():
for images, labels in test_loader:
images.to(device, non_blocking=True)
labels.to(device, non_blocking=True)
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)

normalized_images = images.float() / (cfg.color_levels - 1)
outputs = model(normalized_images, labels)
Expand Down
Binary file added train_samples/epoch10_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch11_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch12_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch13_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch14_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch15_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch16_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch17_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch18_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch1_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch2_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch3_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch4_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch5_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch6_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch7_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch8_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added train_samples/epoch9_samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit e56aaf4

Please sign in to comment.