Skip to content

Commit

Permalink
☀️ initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilienDupont committed Oct 30, 2018
1 parent eb4094f commit a76396b
Show file tree
Hide file tree
Showing 41 changed files with 2,442 additions and 1 deletion.
81 changes: 80 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,80 @@
# pixel-constrained-cnn-pytorch
# Probabilistic Semantic Inpainting with Pixel Constrained CNNs

Pytorch implementation of [Probabilistic Semantic Inpainting with Pixel Constrained CNNs](https://arxiv.org/abs/1804.00104) (2018).

This repo contains all code to reproduce the experiments in the paper as well as all the trained model weights.

## Examples

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/summary-figure.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/grid-progression-celeba-row.gif" width="500">

#### Samples sorted by their likelihood

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/eye-completions-likelihood.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/mnist-likelihood.png" width="400">

#### Pixel probabilities during sampling

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit-progression-1.gif" width="200">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit-progression-2.gif" width="200">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_1_from_1.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_3_from_3.png" width="400">

#### Architecture

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/architecture.png" width="400">

## Usage

To train a model, run `main.py config.json` (ensure that you have either the CelebA or MNIST dataset downloaded). To generate images using a trained model use `main_generate.py`. As an example, the following generates 64 completions for images 73 and 84 in the MNIST dataset by conditioning on the bottom 7 rows of those images. The model used to generate the completions is the trained MNIST model included in this repo and the results are saved to the `mnist_experiment` folder.

```
python main_generate.py -n mnist_experiment -m trained_models/mnist -t generation -i 73 84 -b 7 -ns 64
```

## Trained models

The trained models referenced in the paper are included in the `trained_models` folder. You can use the `main_generate.py` script to generate image completions (and other plots) with these models.

## Data sources

The MNIST dataset can be automatically downloaded using `torchvision`. All CelebA images were resized to be 32 by 32. Data can be found [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).

## Citing

If you find this work useful in your research, please cite using:

```
@article{dupont2018probabilistic,
title={Probabilistic Semantic Inpainting with Pixel Constrained CNNs},
author={Dupont, Emilien and Suresha, Suhas},
journal={arXiv preprint arXiv:1810.03728},
year={2018}
}
```

## More examples

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/blob-samples-celeba.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/blob-samples-celeba-2.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/blob-samples-mnist.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/bottom-samples-celeba.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/bottom-samples-celeba-2.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_6_from_6.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_5_from_0.png" width="400">

## License

MIT
21 changes: 21 additions & 0 deletions config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"name": "celeba_model",
"dataset": "celeba",
"resize": 32,
"crop": 32,
"grayscale": false,
"batch_size": 64,
"constrained": true,
"num_colors": 32,
"filter_size": 5,
"depth": 18,
"num_filters_cond": 66,
"num_filters_prior": 66,
"lr": 4e-4,
"epochs": 20,
"mask_descriptor": ["random_rect", [12, 12]],
"num_conds": 4,
"num_samples": 64,
"weight_cond_logits_loss": 1.0,
"weight_prior_logits_loss": 0.0
}
Binary file added imgs/architecture.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/blob-samples-celeba-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/blob-samples-celeba.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/blob-samples-mnist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/bottom-samples-celeba-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/bottom-samples-celeba.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/eye-completions-likelihood.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/eye-completions.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/grid-progression-celeba-row.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/logit-progression-1.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/logit-progression-2.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/logit_1_from_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/logit_2_from_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/logit_3_from_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/logit_5_from_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/logit_6_from_6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/mnist-likelihood.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/summary-figure.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
135 changes: 135 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import imageio
import json
import os
import sys
import time
import torch
from pixconcnn.training import Trainer, PixelConstrainedTrainer
from torchvision.utils import save_image
from utils.dataloaders import mnist, celeba
from utils.init_models import initialize_model
from utils.masks import batch_random_mask, get_repeated_conditional_pixels, MaskGenerator


# Set device
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# Get config file from command line arguments
if len(sys.argv) != 2:
raise(RuntimeError("Wrong arguments, use python main.py <path_to_config>"))
config_path = sys.argv[1]

# Open config file
with open(config_path) as config_file:
config = json.load(config_file)

name = config['name']
constrained = config['constrained']
batch_size = config['batch_size']
lr = config['lr']
num_colors = config['num_colors']
epochs = config['epochs']
dataset = config['dataset']
resize = config['resize'] # Only relevant for celeba
crop = config['crop'] # Only relevant for celeba
grayscale = config["grayscale"] # Only relevant for celeba
num_conds = config['num_conds'] # Only relevant if constrained
num_samples = config['num_samples'] # Only relevant if constrained
filter_size = config['filter_size']
depth = config['depth']
num_filters_cond = config['num_filters_cond']
num_filters_prior = config['num_filters_prior']
mask_descriptor = config['mask_descriptor']
weight_cond_logits_loss = config['weight_cond_logits_loss']
weight_prior_logits_loss = config['weight_prior_logits_loss']

# Create a folder to store experiment results
timestamp = time.strftime("%Y-%m-%d_%H-%M")
directory = "{}_{}".format(timestamp, name)
if not os.path.exists(directory):
os.makedirs(directory)

# Save config file in experiment directory
with open(directory + '/config.json', 'w') as config_file:
json.dump(config, config_file)

# Get data
if dataset == 'mnist':
data_loader, _ = mnist(batch_size, num_colors=num_colors, size=resize)
img_size = (1, resize, resize)
elif dataset == 'celeba':
data_loader = celeba(batch_size, num_colors=num_colors, size=resize,
crop=crop, grayscale=grayscale)
if grayscale:
img_size = (1, resize, resize)
else:
img_size = (3, resize, resize)

# Initialize model weights and architecture
model = initialize_model(img_size,
num_colors,
depth,
filter_size,
constrained,
num_filters_prior,
num_filters_cond)
model.to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

if constrained:
mask_generator = MaskGenerator(img_size, mask_descriptor)
trainer = PixelConstrainedTrainer(model, optimizer, device, mask_generator,
weight_cond_logits_loss=weight_cond_logits_loss,
weight_prior_logits_loss=weight_prior_logits_loss)
# Train model
progress_imgs = trainer.train(data_loader, epochs, directory=directory)

# Get a random batch of images
for batch, _ in data_loader:
break

for i in range(num_conds):
mask = mask_generator.get_masks(batch_size)
print('Generating {}/{} conditionings'.format(i + 1, num_conds))
cond_pixels = get_repeated_conditional_pixels(batch[i:i+1], mask[i:i+1],
num_colors, num_samples)
# Save mask as tensor
torch.save(mask[i:i+1], directory + '/mask{}.pt'.format(i))
# Save image that gave rise to the conditioning as tensor
torch.save(batch[i:i+1], directory + '/source{}.pt'.format(i))
# Save conditional pixels as tensor and image
torch.save(cond_pixels[0:1], directory + '/cond_pixels{}.pt'.format(i))
save_image(cond_pixels[0:1], directory + '/cond_pixels{}.png'.format(i))

cond_pixels = cond_pixels.to(device)
samples = model.sample(cond_pixels)
# Save samples and mean sample as tensor and image
torch.save(samples, directory + '/samples_cond{}.pt'.format(i))
save_image(samples.float() / (num_colors - 1.),
directory + '/samples_cond{}.png'.format(i))
save_image(samples.float().mean(dim=0) / (num_colors - 1.),
directory + '/mean_cond{}.png'.format(i))
# Save conditional logits if image is binary
if num_colors == 2:
# Save conditional logits
logits, _, cond_logits = model(batch[i:i+1].float().to(device), cond_pixels[0:1])
# Second dimension corresponds to different pixel values, so select probs of it being 1
save_image(cond_logits[:, 1], directory + '/prob_of_one_cond{}.png'.format(i))
# Second dimension corresponds to different pixel values, so select probs of it being 1
save_image(logits[:, 1], directory + '/prob_of_one_logits{}.png'.format(i))
else:
trainer = Trainer(model, optimizer, device)
progress_imgs = trainer.train(data_loader, epochs, directory=directory)

# Save losses and plots of them
with open(directory + '/losses.json', 'w') as losses_file:
json.dump(trainer.losses, losses_file)

# Save model
torch.save(trainer.model.state_dict(), directory + '/model.pt')

# Save gif of progress
imageio.mimsave(directory + '/training.gif', progress_imgs, fps=24)
Loading

0 comments on commit a76396b

Please sign in to comment.