diff --git a/sample.py b/sample.py index 08a45d6..6c3a183 100644 --- a/sample.py +++ b/sample.py @@ -53,7 +53,7 @@ def main(): model.load_state_dict(torch.load(cfg.model_path)) label = None if cfg.label == -1 else cfg.label - samples = model.sample((cfg.data_channels, cfg.height, cfg.width), cfg.count, label=label, device=device) + samples = model.sample((3, cfg.height, cfg.width), cfg.count, label=label, device=device) save_samples(samples, OUTPUT_DIRNAME, OUTPUT_FILENAME) diff --git a/train.py b/train.py index 7205647..57f762f 100644 --- a/train.py +++ b/train.py @@ -68,7 +68,7 @@ def test_and_sample(cfg, model, device, test_loader, height, width, losses, para losses.append(test_loss) params.append(model.state_dict()) - samples = model.sample((cfg.data_channels, height, width), cfg.epoch_samples, device=device) + samples = model.sample((3, height, width), cfg.epoch_samples, device=device) save_samples(samples, TRAIN_SAMPLES_DIR, 'epoch{}_samples.png'.format(epoch + 1))