diff --git a/.gitignore b/.gitignore index 7ea1c96..735c5e5 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ .data __pycache__/ venv -wandb \ No newline at end of file +wandb +train_samples \ No newline at end of file diff --git a/image.pth b/image.pth deleted file mode 100644 index f25c27d..0000000 Binary files a/image.pth and /dev/null differ diff --git a/pixelcnn/model.py b/pixelcnn/model.py index cea032f..9d3dd60 100644 --- a/pixelcnn/model.py +++ b/pixelcnn/model.py @@ -168,11 +168,14 @@ def forward(self, image, label): return out - def sample(self, shape, count, device='cuda'): + def sample(self, shape, count, label=None, device='cuda'): channels, height, width = shape samples = torch.zeros(count, *shape).to(device) - labels = torch.randint(high=10, size=(count,)).to(device) + if label is None: + labels = torch.randint(high=10, size=(count,)).to(device) + else: + labels = (label*torch.ones(count)).to(device) with torch.no_grad(): for i in range(height): diff --git a/sample.py b/sample.py index e53fca0..6ebd4b2 100644 --- a/sample.py +++ b/sample.py @@ -49,7 +49,8 @@ def main(): model.load_state_dict(torch.load(cfg.model_path)) - samples = model.sample((cfg.data_channels, cfg.height, cfg.width), cfg.count, device=device) + label = None if cfg.label == -1 else cfg.label + samples = model.sample((cfg.data_channels, cfg.height, cfg.width), cfg.count, label=label, device=device) save_samples(samples, OUTPUT_DIRNAME, OUTPUT_FILENAME) diff --git a/train.py b/train.py index 7c0cb91..d7564ab 100644 --- a/train.py +++ b/train.py @@ -19,7 +19,6 @@ MODEL_PARAMS_OUTPUT_FILENAME = 'params.pth' TRAIN_SAMPLES_DIR = 'train_samples' -TRAIN_SAMPLES_COUNT = 9 def train(cfg, model, device, train_loader, optimizer, scheduler, epoch): @@ -65,7 +64,7 @@ def test_and_sample(cfg, model, device, test_loader, height, width, epoch): }) print("Average test loss: {}".format(test_loss)) - samples = model.sample((cfg.data_channels, height, width), TRAIN_SAMPLES_COUNT, device=device) + samples = model.sample((cfg.data_channels, height, width), cfg.epoch_samples, device=device) save_samples(samples, TRAIN_SAMPLES_DIR, 'epoch{}_samples.png'.format(epoch + 1)) @@ -103,6 +102,9 @@ def main(): parser.add_argument('--max-norm', type=float, default=1., help='Max norm of the gradients after clipping') + parser.add_argument('--epoch-samples', type=int, default=25, + help='Number of images to sample each epoch') + parser.add_argument('--cuda', type=str2bool, default=True, help='Flag indicating whether CUDA should be used') diff --git a/train_samples/epoch10_samples.png b/train_samples/epoch10_samples.png deleted file mode 100644 index ac692aa..0000000 Binary files a/train_samples/epoch10_samples.png and /dev/null differ diff --git a/train_samples/epoch11_samples.png b/train_samples/epoch11_samples.png deleted file mode 100644 index b538c3e..0000000 Binary files a/train_samples/epoch11_samples.png and /dev/null differ diff --git a/train_samples/epoch12_samples.png b/train_samples/epoch12_samples.png deleted file mode 100644 index d0126f1..0000000 Binary files a/train_samples/epoch12_samples.png and /dev/null differ diff --git a/train_samples/epoch13_samples.png b/train_samples/epoch13_samples.png deleted file mode 100644 index 615dadf..0000000 Binary files a/train_samples/epoch13_samples.png and /dev/null differ diff --git a/train_samples/epoch14_samples.png b/train_samples/epoch14_samples.png deleted file mode 100644 index 303ff47..0000000 Binary files a/train_samples/epoch14_samples.png and /dev/null differ diff --git a/train_samples/epoch15_samples.png b/train_samples/epoch15_samples.png deleted file mode 100644 index bd3f1e5..0000000 Binary files a/train_samples/epoch15_samples.png and /dev/null differ diff --git a/train_samples/epoch16_samples.png b/train_samples/epoch16_samples.png deleted file mode 100644 index cedd890..0000000 Binary files a/train_samples/epoch16_samples.png and /dev/null differ diff --git a/train_samples/epoch17_samples.png b/train_samples/epoch17_samples.png deleted file mode 100644 index 558517c..0000000 Binary files a/train_samples/epoch17_samples.png and /dev/null differ diff --git a/train_samples/epoch18_samples.png b/train_samples/epoch18_samples.png deleted file mode 100644 index eb06c0d..0000000 Binary files a/train_samples/epoch18_samples.png and /dev/null differ diff --git a/train_samples/epoch1_samples.png b/train_samples/epoch1_samples.png deleted file mode 100644 index 5a5ff5b..0000000 Binary files a/train_samples/epoch1_samples.png and /dev/null differ diff --git a/train_samples/epoch2_samples.png b/train_samples/epoch2_samples.png deleted file mode 100644 index 3094bac..0000000 Binary files a/train_samples/epoch2_samples.png and /dev/null differ diff --git a/train_samples/epoch3_samples.png b/train_samples/epoch3_samples.png deleted file mode 100644 index 8075fec..0000000 Binary files a/train_samples/epoch3_samples.png and /dev/null differ diff --git a/train_samples/epoch4_samples.png b/train_samples/epoch4_samples.png deleted file mode 100644 index 3683b23..0000000 Binary files a/train_samples/epoch4_samples.png and /dev/null differ diff --git a/train_samples/epoch5_samples.png b/train_samples/epoch5_samples.png deleted file mode 100644 index fab7817..0000000 Binary files a/train_samples/epoch5_samples.png and /dev/null differ diff --git a/train_samples/epoch6_samples.png b/train_samples/epoch6_samples.png deleted file mode 100644 index 02c5607..0000000 Binary files a/train_samples/epoch6_samples.png and /dev/null differ diff --git a/train_samples/epoch7_samples.png b/train_samples/epoch7_samples.png deleted file mode 100644 index f5643f1..0000000 Binary files a/train_samples/epoch7_samples.png and /dev/null differ diff --git a/train_samples/epoch8_samples.png b/train_samples/epoch8_samples.png deleted file mode 100644 index e2ce645..0000000 Binary files a/train_samples/epoch8_samples.png and /dev/null differ diff --git a/train_samples/epoch9_samples.png b/train_samples/epoch9_samples.png deleted file mode 100644 index 7f2b618..0000000 Binary files a/train_samples/epoch9_samples.png and /dev/null differ