Skip to content

Commit

Permalink
Fixed sample.py and deleted images
Browse files Browse the repository at this point in the history
  • Loading branch information
anordertoreclaim committed Aug 6, 2019
1 parent e56aaf4 commit f9f5376
Show file tree
Hide file tree
Showing 23 changed files with 13 additions and 6 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
.data
__pycache__/
venv
wandb
wandb
train_samples
Binary file removed image.pth
Binary file not shown.
7 changes: 5 additions & 2 deletions pixelcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,14 @@ def forward(self, image, label):

return out

def sample(self, shape, count, device='cuda'):
def sample(self, shape, count, label=None, device='cuda'):
channels, height, width = shape

samples = torch.zeros(count, *shape).to(device)
labels = torch.randint(high=10, size=(count,)).to(device)
if label is None:
labels = torch.randint(high=10, size=(count,)).to(device)
else:
labels = (label*torch.ones(count)).to(device)

with torch.no_grad():
for i in range(height):
Expand Down
3 changes: 2 additions & 1 deletion sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def main():

model.load_state_dict(torch.load(cfg.model_path))

samples = model.sample((cfg.data_channels, cfg.height, cfg.width), cfg.count, device=device)
label = None if cfg.label == -1 else cfg.label
samples = model.sample((cfg.data_channels, cfg.height, cfg.width), cfg.count, label=label, device=device)
save_samples(samples, OUTPUT_DIRNAME, OUTPUT_FILENAME)


Expand Down
6 changes: 4 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
MODEL_PARAMS_OUTPUT_FILENAME = 'params.pth'

TRAIN_SAMPLES_DIR = 'train_samples'
TRAIN_SAMPLES_COUNT = 9


def train(cfg, model, device, train_loader, optimizer, scheduler, epoch):
Expand Down Expand Up @@ -65,7 +64,7 @@ def test_and_sample(cfg, model, device, test_loader, height, width, epoch):
})
print("Average test loss: {}".format(test_loss))

samples = model.sample((cfg.data_channels, height, width), TRAIN_SAMPLES_COUNT, device=device)
samples = model.sample((cfg.data_channels, height, width), cfg.epoch_samples, device=device)
save_samples(samples, TRAIN_SAMPLES_DIR, 'epoch{}_samples.png'.format(epoch + 1))


Expand Down Expand Up @@ -103,6 +102,9 @@ def main():
parser.add_argument('--max-norm', type=float, default=1.,
help='Max norm of the gradients after clipping')

parser.add_argument('--epoch-samples', type=int, default=25,
help='Number of images to sample each epoch')

parser.add_argument('--cuda', type=str2bool, default=True,
help='Flag indicating whether CUDA should be used')

Expand Down
Binary file removed train_samples/epoch10_samples.png
Binary file not shown.
Binary file removed train_samples/epoch11_samples.png
Binary file not shown.
Binary file removed train_samples/epoch12_samples.png
Binary file not shown.
Binary file removed train_samples/epoch13_samples.png
Binary file not shown.
Binary file removed train_samples/epoch14_samples.png
Binary file not shown.
Binary file removed train_samples/epoch15_samples.png
Binary file not shown.
Binary file removed train_samples/epoch16_samples.png
Binary file not shown.
Binary file removed train_samples/epoch17_samples.png
Binary file not shown.
Binary file removed train_samples/epoch18_samples.png
Binary file not shown.
Binary file removed train_samples/epoch1_samples.png
Binary file not shown.
Binary file removed train_samples/epoch2_samples.png
Binary file not shown.
Binary file removed train_samples/epoch3_samples.png
Binary file not shown.
Binary file removed train_samples/epoch4_samples.png
Binary file not shown.
Binary file removed train_samples/epoch5_samples.png
Binary file not shown.
Binary file removed train_samples/epoch6_samples.png
Binary file not shown.
Binary file removed train_samples/epoch7_samples.png
Binary file not shown.
Binary file removed train_samples/epoch8_samples.png
Binary file not shown.
Binary file removed train_samples/epoch9_samples.png
Binary file not shown.

0 comments on commit f9f5376

Please sign in to comment.