Commit a76396b (1 parent: eb4094f). Showing 41 changed files with 2,442 additions and 1 deletion.
README.md
@@ -1 +1,80 @@
# pixel-constrained-cnn-pytorch
# Probabilistic Semantic Inpainting with Pixel Constrained CNNs

PyTorch implementation of [Probabilistic Semantic Inpainting with Pixel Constrained CNNs](https://arxiv.org/abs/1804.00104) (2018).

This repo contains all the code needed to reproduce the experiments in the paper, as well as the trained model weights.
## Examples

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/summary-figure.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/grid-progression-celeba-row.gif" width="500">

#### Samples sorted by their likelihood

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/eye-completions-likelihood.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/mnist-likelihood.png" width="400">

#### Pixel probabilities during sampling

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit-progression-1.gif" width="200">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit-progression-2.gif" width="200">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_1_from_1.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_3_from_3.png" width="400">

#### Architecture

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/architecture.png" width="400">
## Usage

To train a model, run `python main.py config.json` (make sure you have the CelebA or MNIST dataset downloaded). To generate images using a trained model, use `main_generate.py`. For example, the following command generates 64 completions for images 73 and 84 of the MNIST dataset by conditioning on the bottom 7 rows of those images. The completions are produced with the trained MNIST model included in this repo, and the results are saved to the `mnist_experiment` folder.

```
python main_generate.py -n mnist_experiment -m trained_models/mnist -t generation -i 73 84 -b 7 -ns 64
```
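The config passed to `main.py` is a JSON file of hyperparameters; the CelebA config included in this commit appears further down the page. Purely as an illustration, an MNIST-style config might look like the sketch below. The values shown here are assumptions for the example, not the settings used for the trained models in this repo.

```
{
  "name": "mnist_model",
  "dataset": "mnist",
  "resize": 28,
  "crop": 28,
  "grayscale": true,
  "batch_size": 64,
  "constrained": true,
  "num_colors": 2,
  "filter_size": 5,
  "depth": 18,
  "num_filters_cond": 66,
  "num_filters_prior": 66,
  "lr": 4e-4,
  "epochs": 20,
  "mask_descriptor": ["random_rect", [12, 12]],
  "num_conds": 4,
  "num_samples": 64,
  "weight_cond_logits_loss": 1.0,
  "weight_prior_logits_loss": 0.0
}
```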
## Trained models

The trained models referenced in the paper are included in the `trained_models` folder. You can use the `main_generate.py` script to generate image completions (and other plots) with these models.
## Data sources

The MNIST dataset can be downloaded automatically using `torchvision`. All CelebA images were resized to 32 by 32; the data can be found [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
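For orientation, the sketch below shows the kind of preprocessing described above using plain `torchvision`. The repo's actual loaders live in `utils/dataloaders.py` (imported by `main.py`), so treat this only as an assumed illustration; in particular, the resize/crop pipeline for CelebA is a guess based on the config values, not the repo's code.

```
# Illustration only: the repo's own loaders are in utils/dataloaders.py.
import torch
from torchvision import datasets, transforms

# MNIST downloads automatically through torchvision.
mnist_data = datasets.MNIST(root='./data', train=True, download=True,
                            transform=transforms.ToTensor())
mnist_loader = torch.utils.data.DataLoader(mnist_data, batch_size=64, shuffle=True)

# CelebA images (downloaded manually from the link above) resized to 32 by 32.
celeba_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
])
# ImageFolder expects the images inside a subdirectory of root,
# e.g. ./celeba/img_align_celeba/*.jpg.
celeba_data = datasets.ImageFolder(root='./celeba', transform=celeba_transform)
celeba_loader = torch.utils.data.DataLoader(celeba_data, batch_size=64, shuffle=True)
```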
## Citing

If you find this work useful in your research, please cite using:

```
@article{dupont2018probabilistic,
  title={Probabilistic Semantic Inpainting with Pixel Constrained CNNs},
  author={Dupont, Emilien and Suresha, Suhas},
  journal={arXiv preprint arXiv:1810.03728},
  year={2018}
}
```
## More examples

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/blob-samples-celeba.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/blob-samples-celeba-2.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/blob-samples-mnist.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/bottom-samples-celeba.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/bottom-samples-celeba-2.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_6_from_6.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_5_from_0.png" width="400">
## License

MIT
@@ -0,0 +1,21 @@
{
  "name": "celeba_model",
  "dataset": "celeba",
  "resize": 32,
  "crop": 32,
  "grayscale": false,
  "batch_size": 64,
  "constrained": true,
  "num_colors": 32,
  "filter_size": 5,
  "depth": 18,
  "num_filters_cond": 66,
  "num_filters_prior": 66,
  "lr": 4e-4,
  "epochs": 20,
  "mask_descriptor": ["random_rect", [12, 12]],
  "num_conds": 4,
  "num_samples": 64,
  "weight_cond_logits_loss": 1.0,
  "weight_prior_logits_loss": 0.0
}
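The `mask_descriptor` entry pairs a mask type with its parameters and is handed to the `MaskGenerator` constructed in `main.py` below; `["random_rect", [12, 12]]` presumably requests randomly placed rectangles of at most 12 by 12 pixels. The repo's real implementation lives in `utils/masks.py`. The function below is only a hypothetical sketch of such a mask, and whether the rectangle marks the hidden or the visible pixels is an assumption here, not something stated in this commit.

```
# Hypothetical sketch of a "random_rect" mask, not the repo's utils/masks.py code:
# a binary mask containing one randomly placed rectangle of at most 12 x 12 pixels.
import torch

def random_rect_mask(img_size=(3, 32, 32), max_height=12, max_width=12):
    _, height, width = img_size
    mask = torch.zeros(1, height, width)
    rect_h = torch.randint(1, max_height + 1, (1,)).item()
    rect_w = torch.randint(1, max_width + 1, (1,)).item()
    top = torch.randint(0, height - rect_h + 1, (1,)).item()
    left = torch.randint(0, width - rect_w + 1, (1,)).item()
    mask[:, top:top + rect_h, left:left + rect_w] = 1.0  # mark the rectangle
    return mask
```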
[18 additional files in this commit could not be displayed.]
main.py
@@ -0,0 +1,135 @@
import imageio
import json
import os
import sys
import time
import torch
from pixconcnn.training import Trainer, PixelConstrainedTrainer
from torchvision.utils import save_image
from utils.dataloaders import mnist, celeba
from utils.init_models import initialize_model
from utils.masks import batch_random_mask, get_repeated_conditional_pixels, MaskGenerator


# Set device
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
# Get config file from command line arguments
if len(sys.argv) != 2:
    raise RuntimeError("Wrong arguments, use python main.py <path_to_config>")
config_path = sys.argv[1]

# Open config file
with open(config_path) as config_file:
    config = json.load(config_file)
name = config['name']
constrained = config['constrained']
batch_size = config['batch_size']
lr = config['lr']
num_colors = config['num_colors']
epochs = config['epochs']
dataset = config['dataset']
resize = config['resize']  # Image size (used for both datasets)
crop = config['crop']  # Only relevant for celeba
grayscale = config["grayscale"]  # Only relevant for celeba
num_conds = config['num_conds']  # Only relevant if constrained
num_samples = config['num_samples']  # Only relevant if constrained
filter_size = config['filter_size']
depth = config['depth']
num_filters_cond = config['num_filters_cond']
num_filters_prior = config['num_filters_prior']
mask_descriptor = config['mask_descriptor']
weight_cond_logits_loss = config['weight_cond_logits_loss']
weight_prior_logits_loss = config['weight_prior_logits_loss']

# Create a folder to store experiment results
timestamp = time.strftime("%Y-%m-%d_%H-%M")
directory = "{}_{}".format(timestamp, name)
if not os.path.exists(directory):
    os.makedirs(directory)
# Save config file in experiment directory
with open(directory + '/config.json', 'w') as config_file:
    json.dump(config, config_file)

# Get data
if dataset == 'mnist':
    data_loader, _ = mnist(batch_size, num_colors=num_colors, size=resize)
    img_size = (1, resize, resize)
elif dataset == 'celeba':
    data_loader = celeba(batch_size, num_colors=num_colors, size=resize,
                         crop=crop, grayscale=grayscale)
    if grayscale:
        img_size = (1, resize, resize)
    else:
        img_size = (3, resize, resize)
# Initialize model weights and architecture
model = initialize_model(img_size,
                         num_colors,
                         depth,
                         filter_size,
                         constrained,
                         num_filters_prior,
                         num_filters_cond)
model.to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
if constrained:
    mask_generator = MaskGenerator(img_size, mask_descriptor)
    trainer = PixelConstrainedTrainer(model, optimizer, device, mask_generator,
                                      weight_cond_logits_loss=weight_cond_logits_loss,
                                      weight_prior_logits_loss=weight_prior_logits_loss)
    # Train model
    progress_imgs = trainer.train(data_loader, epochs, directory=directory)

    # Get a random batch of images
    for batch, _ in data_loader:
        break

    for i in range(num_conds):
        mask = mask_generator.get_masks(batch_size)
        print('Generating {}/{} conditionings'.format(i + 1, num_conds))
        cond_pixels = get_repeated_conditional_pixels(batch[i:i+1], mask[i:i+1],
                                                      num_colors, num_samples)
        # Save mask as tensor
        torch.save(mask[i:i+1], directory + '/mask{}.pt'.format(i))
        # Save image that gave rise to the conditioning as tensor
        torch.save(batch[i:i+1], directory + '/source{}.pt'.format(i))
        # Save conditional pixels as tensor and image
        torch.save(cond_pixels[0:1], directory + '/cond_pixels{}.pt'.format(i))
        save_image(cond_pixels[0:1], directory + '/cond_pixels{}.png'.format(i))

        cond_pixels = cond_pixels.to(device)
        samples = model.sample(cond_pixels)
        # Save samples and mean sample as tensor and image
        torch.save(samples, directory + '/samples_cond{}.pt'.format(i))
        save_image(samples.float() / (num_colors - 1.),
                   directory + '/samples_cond{}.png'.format(i))
        save_image(samples.float().mean(dim=0) / (num_colors - 1.),
                   directory + '/mean_cond{}.png'.format(i))
        # Save conditional logits if image is binary
        if num_colors == 2:
            logits, _, cond_logits = model(batch[i:i+1].float().to(device), cond_pixels[0:1])
            # Second dimension corresponds to different pixel values, so select probs of it being 1
            save_image(cond_logits[:, 1], directory + '/prob_of_one_cond{}.png'.format(i))
            save_image(logits[:, 1], directory + '/prob_of_one_logits{}.png'.format(i))
else:
    trainer = Trainer(model, optimizer, device)
    progress_imgs = trainer.train(data_loader, epochs, directory=directory)
# Save losses
with open(directory + '/losses.json', 'w') as losses_file:
    json.dump(trainer.losses, losses_file)

# Save model
torch.save(trainer.model.state_dict(), directory + '/model.pt')

# Save gif of training progress
imageio.mimsave(directory + '/training.gif', progress_imgs, fps=24)