🐳 change readme and config so default model for training and generation is mnist. fix celeba config and dataloader so they expect the original non-resized, non-cropped celeba images.
EmilienDupont committed Dec 19, 2018
1 parent 947ac0e commit 3c1ced6
Showing 4 changed files with 27 additions and 25 deletions.
10 changes: 5 additions & 5 deletions README.md
@@ -27,25 +27,25 @@ The attributes of the model can be set in the `config.json` file. To train the model
```
python main.py config.json
```

This will also save the trained model and log various information as training progresses.
This will also save the trained model and log various information as training progresses. Examples of `config.json` files are available in the `trained_models` directory.

### Inpainting

To generate images with a trained model use `main_generate.py`. As an example, the following command generates 64 completions for images 73 and 84 in the CelebA dataset by conditioning on the top 16 rows. The model used to generate the completions is the trained CelebA model included in this repo and the results are saved to the `celeba_experiment` folder.
To generate images with a trained model use `main_generate.py`. As an example, the following command generates 64 completions for images 73 and 84 in the MNIST dataset by conditioning on the top 14 rows. The model used to generate the completions is the trained MNIST model included in this repo and the results are saved to the `mnist_inpaintings` folder.

```
python main_generate.py -n celeba_experiment -m trained_models/celeba -t generation -i 73 84 -to 16 -ns 64
python main_generate.py -n mnist_inpaintings -m trained_models/mnist -t generation -i 73 84 -to 14 -ns 64
```

For a full list of options, run `python main_generate.py --help`.
For a full list of options, run `python main_generate.py --help`. Note that if you do not have the MNIST dataset on your machine it will be automatically downloaded when running the above command. The CelebA dataset will have to be manually downloaded (see the Data sources section). If you already have the datasets downloaded, you can change the paths in `utils/dataloaders.py` to point to the correct folders on your machine.
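
The snippet below is only an illustrative sketch of such a call, assuming the `mnist` and `celeba` functions keep the keyword arguments shown in the `utils/dataloaders.py` diff further down; the paths are placeholders for folders on your machine.

```
# Illustrative sketch only: point the dataloaders at local copies of the data.
# Keyword names follow the signatures shown in utils/dataloaders.py; the paths
# are placeholders.
from utils.dataloaders import mnist, celeba

# MNIST is downloaded automatically by torchvision if it is not already
# present at path_to_data.
train_loader, test_loader = mnist(batch_size=64, num_colors=2, size=28,
                                  path_to_data='/data/mnist')

# CelebA has to be downloaded manually; path_to_data should point at the
# folder containing the original (non-resized, non-cropped) images.
celeba_loaders = celeba(batch_size=64, num_colors=32, size=178, crop=178,
                        path_to_data='/data/celeba')
```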

## Trained models

The trained models referenced in the paper are included in the `trained_models` folder. You can use the `main_generate.py` script to generate image completions (and other plots) with these models.

## Data sources

The MNIST dataset can be automatically downloaded using `torchvision`. All CelebA images were resized to 32 by 32 and can be found [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
The MNIST dataset can be automatically downloaded using `torchvision`. The CelebA dataset can be found [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).

## Citing

22 changes: 11 additions & 11 deletions config.json
@@ -1,19 +1,19 @@
{
"name": "celeba_model",
"dataset": "celeba",
"resize": 32,
"crop": 32,
"grayscale": false,
"name": "mnist_model",
"dataset": "mnist",
"resize": 28,
"crop": 28,
"grayscale": true,
"batch_size": 64,
"constrained": true,
"num_colors": 32,
"num_colors": 2,
"filter_size": 5,
"depth": 18,
"num_filters_cond": 66,
"num_filters_prior": 66,
"depth": 16,
"num_filters_cond": 32,
"num_filters_prior": 32,
"lr": 4e-4,
"epochs": 20,
"mask_descriptor": ["random_rect", [12, 12]],
"epochs": 50,
"mask_descriptor": ["random_rect", [10, 10]],
"num_conds": 4,
"num_samples": 64,
"weight_cond_logits_loss": 1.0,
2 changes: 1 addition & 1 deletion trained_models/celeba/config.json
@@ -2,7 +2,7 @@
"name": "0",
"dataset": "celeba",
"resize": 32,
"crop": 32,
"crop": 89,
"grayscale": false,
"batch_size": 64,
"constrained": true,
18 changes: 10 additions & 8 deletions utils/dataloaders.py
@@ -30,7 +30,7 @@ def mnist(batch_size=128, num_colors=256, size=28,
transforms.Lambda(lambda x: quantize(x))
])

train_data = datasets.MNIST(path_to_data, train=True, download=False,
train_data = datasets.MNIST(path_to_data, train=True, download=True,
transform=all_transforms)
test_data = datasets.MNIST(path_to_data, train=False,
transform=all_transforms)
@@ -41,9 +41,11 @@ def mnist(batch_size=128, num_colors=256, size=28,
return train_loader, test_loader


def celeba(batch_size=128, num_colors=256, size=64, crop=64, grayscale=False,
shuffle=True, path_to_data='../celeba_64'):
"""CelebA dataloader with (64, 64) images.
def celeba(batch_size=128, num_colors=256, size=178, crop=178, grayscale=False,
shuffle=True, path_to_data='../celeba_data'):
"""CelebA dataloader with square images. Note original CelebA images have
shape (218, 178), this dataloader center crops these images to be (178, 178)
by default.
Parameters
----------
@@ -54,7 +56,7 @@ def celeba(batch_size=128, num_colors=256, size=64, crop=64, grayscale=False,
lower for e.g. binary images.
size : int
Size (height and width) of each image. Default is 64 for no resizing.
Size (height and width) of each image.
crop : int
Size of center crop. This crop happens *before* the resizing.
@@ -66,7 +68,7 @@ def celeba(batch_size=128, num_colors=256, size=64, crop=64, grayscale=False,
If True shuffles images.
path_to_data : string
Path to 64 by 64 CelebA data files.
Path to CelebA image files.
"""
quantize = get_quantize_func(num_colors)

@@ -93,12 +95,12 @@ def celeba(batch_size=128, num_colors=256, size=64, crop=64, grayscale=False,


class CelebADataset(Dataset):
"""CelebA dataset with 64 by 64 images.
"""CelebA dataset.
Parameters
----------
path_to_data : string
Path to 64 by 64 CelebA data files.
Path to CelebA images.
subsample : int
Only load every |subsample| number of images.
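
The full transform pipeline inside `celeba` is not visible in these hunks, but the docstring states that the center crop happens before the resizing. The sketch below reproduces that ordering with standard `torchvision` transforms and is an assumption, not the repository's code; the real dataloader also quantizes pixel values to `num_colors` levels via `get_quantize_func`.

```
from torchvision import transforms

# Illustrative sketch only: the crop-before-resize ordering described in the
# celeba docstring. The original 218 x 178 images are center cropped to
# (crop, crop) first and only then resized to (size, size).
def celeba_transforms(size=178, crop=178, grayscale=False):
    steps = [transforms.CenterCrop(crop), transforms.Resize(size)]
    if grayscale:
        steps.append(transforms.Grayscale())
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)

# Example: celeba_transforms(size=32, crop=178) would first crop to 178 x 178
# and then downscale to 32 x 32.
```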
