Skip to content

Commit

Permalink
🌴 code review fixes and update README
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilienDupont committed Dec 14, 2018
1 parent abb7873 commit 62373a6
Show file tree
Hide file tree
Showing 18 changed files with 95 additions and 46 deletions.
58 changes: 25 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,63 @@

Pytorch implementation of [Probabilistic Semantic Inpainting with Pixel Constrained CNNs](https://arxiv.org/abs/1804.00104) (2018).

This repo contains all code to reproduce the experiments in the paper as well as all the trained model weights.
This repo contains an implementation of Pixel Constrained CNN, a framework for performing probabilistic inpainting of images with arbitrary occlusions. It also includes all code to reproduce the experiments in the paper as well as the weights of the trained models.

## Examples

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/summary-figure.png" width="400">
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/summary-gif.gif" width="500">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/grid-progression-celeba-row.gif" width="500">
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/new-top-celeba.png" width="400">

#### Samples sorted by their likelihood
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/new-left-celeba.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/eye-completions-likelihood.png" width="400">
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/new-small-missing-celeba.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/mnist-likelihood.png" width="400">
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/new-random-celeba.png" width="400">

#### Pixel probabilities during sampling
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/new-bottom-celeba.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit-progression-1.gif" width="200">
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/new-blob-samples.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit-progression-2.gif" width="200">
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/new-bottom-seven-nine.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_1_from_1.png" width="400">
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/new-blob-samples-2.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_3_from_3.png" width="400">
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/int-dark-side.png" width="300">

#### Architecture
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/int-eye-color.png" width="300">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/architecture.png" width="400">
<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/int-male-female.png" width="300">

## Usage

To train a model, run (ensure that you have either the CelebA or MNIST dataset downloaded)
### Training

The attributes of the model can be set in the `config.json` file. To train the model, run

```
python main.py config.json
```

To generate images with a trained model use `main_generate.py`. As an example, the following command generates 64 completions for images 73 and 84 in the MNIST dataset by conditioning on the bottom 7 rows. The model used to generate the completions is the trained MNIST model included in this repo and the results are saved to the `mnist_experiment` folder.
This will also save the trained model and log various information as training progresses.

### Inpainting

To generate images with a trained model use `main_generate.py`. As an example, the following command generates 64 completions for images 73 and 84 in the CelebA dataset by conditioning on the top 16 rows. The model used to generate the completions is the trained CelebA model included in this repo and the results are saved to the `celeba_experiment` folder.

```
python main_generate.py -n mnist_experiment -m trained_models/mnist -t generation -i 73 84 -b 7 -ns 64
python main_generate.py -n celeba_experiment -m trained_models/celeba -t generation -i 73 84 -to 16 -ns 64
```

For a full list of options, run `python main_generate.py --help`.

## Trained models

The trained models referenced in the paper are included in the `trained_models` folder. You can use the `main_generate.py` script to generate image completions (and other plots) with these models.

## Data sources

The MNIST dataset can be automatically downloaded using `torchvision`. All CelebA images were resized to be 32 by 32. Data can be found [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
The MNIST dataset can be automatically downloaded using `torchvision`. All CelebA images were resized to 32 by 32 and can be found [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).

## Citing

Expand All @@ -65,22 +73,6 @@ If you find this work useful in your research, please cite using:
}
```

## More examples

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/blob-samples-celeba.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/blob-samples-celeba-2.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/blob-samples-mnist.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/bottom-samples-celeba.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/bottom-samples-celeba-2.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_6_from_6.png" width="400">

<img src="https://github.com/Schlumberger/pixel-constrained-cnn/raw/master/open-source/imgs/logit_5_from_0.png" width="400">

## License

MIT
Binary file added imgs/int-dark-side.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/int-eye-color.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/int-male-female.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/new-blob-samples-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/new-blob-samples.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/new-bottom-celeba.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/new-bottom-seven-nine.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/new-left-celeba.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/new-random-celeba.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/new-small-missing-celeba.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/new-top-celeba.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/summary-gif.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 8 additions & 0 deletions main_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@
type=int, help='Number of random pixels to keep unmasked.')
parser.add_argument('-b', '--bottom', dest='bottom_attribute', default=None,
type=int, help='Number of bottom pixels to keep unmasked.')
parser.add_argument('-to', '--top', dest='top_attribute', default=None,
type=int, help='Number of top pixels to keep unmasked.')
parser.add_argument('-c', '--center', dest='center_attribute', default=None,
type=int, help='Number of central pixels to keep unmasked.')
parser.add_argument('-e', '--edge', dest='edge_attribute', default=None,
type=int, help='Number of edge pixels to keep unmasked.')
parser.add_argument('-l', '--left', dest='left_attribute', default=None,
type=int, help='Number of left pixels to keep unmasked.')
parser.add_argument('-ri', '--right', dest='right_attribute', default=None,
type=int, help='Number of right pixels to keep unmasked.')
parser.add_argument('-rb', '--random-blob', dest='blob_attribute', default=None,
type=int, nargs='+', help='First int should be maximum number of blobs, second lower bound on num_iters and third upper bound on num_iters.')
parser.add_argument('-mf', '--mask-folder', dest='folder_attribute', default=None,
Expand Down Expand Up @@ -75,12 +79,16 @@
mask_descriptors.append(('random', args.random_attribute))
if args.bottom_attribute is not None:
mask_descriptors.append(('bottom', args.bottom_attribute))
if args.top_attribute is not None:
mask_descriptors.append(('top', args.top_attribute))
if args.center_attribute is not None:
mask_descriptors.append(('center', args.center_attribute))
if args.edge_attribute is not None:
mask_descriptors.append(('edge', args.edge_attribute))
if args.left_attribute is not None:
mask_descriptors.append(('left', args.left_attribute))
if args.right_attribute is not None:
mask_descriptors.append(('right', args.right_attribute))
if args.blob_attribute is not None:
max_num_blobs, lower_iter, upper_iter = args.blob_attribute
mask_descriptors.append(('random_blob', (max_num_blobs, (lower_iter, upper_iter), 0.5)))
Expand Down
12 changes: 6 additions & 6 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
matplotlib
numpy
torch
torchvision
imageio
Pillow
imageio==2.4.1
matplotlib==2.2.3
numpy==1.11.2
Pillow==5.2.0
torch==0.4.1
torchvision==0.2.1
2 changes: 1 addition & 1 deletion utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def mnist(batch_size=128, num_colors=256, size=28,

def celeba(batch_size=128, num_colors=256, size=64, crop=64, grayscale=False,
shuffle=True, path_to_data='../celeba_64'):
"""MNIST dataloader with (64, 64) images.
"""CelebA dataloader with (64, 64) images.
Parameters
----------
Expand Down
59 changes: 53 additions & 6 deletions utils/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,30 @@ class MaskGenerator():
2. ('bottom', int): Generates masks where only the bottom pixels are
visible. The int determines the number of rows of the image to
keep visible at the bottom.
3. ('center', int): Generates masks where only the central pixels
3. ('top', int): Generates masks where only the top pixels are
visible. The int determines the number of rows of the image to
keep visible at the top.
4. ('center', int): Generates masks where only the central pixels
are visible. The int determines the size in pixels of the sides
of the square of visible pixels of the image.
4. ('edge', int): Generates masks where only the edge pixels of the
5. ('edge', int): Generates masks where only the edge pixels of the
image are visible. The int determines the thickness of the edges
in pixels.
5. ('left', int): Generates masks where only the left pixels of the
6. ('left', int): Generates masks where only the left pixels of the
image are visible. The int determines the number of columns
in pixels which are visible.
6. ('random_rect', (int, int)): Generates random rectangular masks
7. ('right', int): Generates masks where only the right pixels of
the image are visible. The int determines the number of columns
in pixels which are visible.
8. ('random_rect', (int, int)): Generates random rectangular masks
where the maximum height and width of the rectangles are
determined by the two ints.
7. ('random_blob', (int, (int, int), float)): Generates random
9. ('random_blob', (int, (int, int), float)): Generates random
blobs, where the number of blobs is determined by the first int,
the range of iterations (see function definition) is determined
by the tuple of ints and the threshold for making pixels visible
is determined by the float.
8. ('random_blob_cache', (str, int)): Loads pregenerated random masks
10. ('random_blob_cache', (str, int)): Loads pregenerated random masks
from a folder given by the string, using a batch_size given by
the int.
"""
Expand Down Expand Up @@ -81,12 +87,16 @@ def get_masks(self, batch_size):
return batch_random_mask(self.img_size, num_visibles, batch_size)
elif self.mask_type == 'bottom':
return batch_bottom_mask(self.img_size, self.mask_attribute, batch_size)
elif self.mask_type == 'top':
return batch_top_mask(self.img_size, self.mask_attribute, batch_size)
elif self.mask_type == 'center':
return batch_center_mask(self.img_size, self.mask_attribute, batch_size)
elif self.mask_type == 'edge':
return batch_edge_mask(self.img_size, self.mask_attribute, batch_size)
elif self.mask_type == 'left':
return batch_left_mask(self.img_size, self.mask_attribute, batch_size)
elif self.mask_type == 'right':
return batch_right_mask(self.img_size, self.mask_attribute, batch_size)
elif self.mask_type == 'random_rect':
return batch_random_rect_mask(self.img_size, self.mask_attribute[0],
self.mask_attribute[1], batch_size)
Expand Down Expand Up @@ -185,6 +195,25 @@ def batch_bottom_mask(img_size, num_rows, batch_size):
return mask


def batch_top_mask(img_size, num_rows, batch_size):
"""Masks all the output except the |num_rows| highest rows (in the height
dimension).
Parameters
----------
img_size : see single_random_mask
num_rows : int
Number of rows from top which will be visible.
batch_size : int
Number of masks to create.
"""
mask = torch.zeros(batch_size, 1, *img_size[1:])
mask[:, :, :num_rows, :] = 1.
return mask


def batch_center_mask(img_size, num_pixels, batch_size):
"""Masks all the output except the num_pixels by num_pixels central square
of the image.
Expand Down Expand Up @@ -249,6 +278,24 @@ def batch_left_mask(img_size, num_cols, batch_size):
return mask


def batch_right_mask(img_size, num_cols, batch_size):
"""Masks all the pixels except the right side of the image.
Parameters
----------
img_size : see single_random_mask
num_cols : int
Number of columns of the right side of the image to remain visible.
batch_size : int
Number of masks to create.
"""
mask = torch.zeros(batch_size, 1, *img_size[1:])
mask[:, :, :, -num_cols:] = 1.
return mask


def random_rect_mask(img_size, max_height, max_width):
"""Returns a mask with a random rectangle of visible pixels.
Expand Down
2 changes: 2 additions & 0 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ def probs_and_conditional_plot(img, probs, mask, cmap='plasma'):
the original image overlayed. Note this function only works for binary
images.
Parameters
----------
img : torch.Tensor
Shape (1, H, W)
Expand Down

0 comments on commit 62373a6

Please sign in to comment.