-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresize_training_subset.py
64 lines (52 loc) · 2.02 KB
/
resize_training_subset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
import os
from tqdm import tqdm
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import save_image
# Sample N training images, resize, and write each image to file.


def _parse_args() -> argparse.Namespace:
    """Build and parse this script's command-line arguments."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Sample N training images, resize, and write each image to file")
    parser.add_argument('-d_in', '--dir_input', type=str, default='./data/CelebA/')
    parser.add_argument('-d_out', '--dir_output', type=str, default='./data/CelebA_sample_resized/')
    parser.add_argument('-s', '--size', type=int, default=64)
    parser.add_argument('-N', type=int, default=100)
    parser.add_argument('--sample_grid_fname', type=str, default='')
    return parser.parse_args()


def main() -> None:
    """Load the first N images, resize/center-crop them, and write each to a file.

    Optionally also saves a 4-image preview grid under ./results when
    --sample_grid_fname is given.
    """
    opt = _parse_args()

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(opt.dir_output, exist_ok=True)

    dataset = dset.ImageFolder(
        root=opt.dir_input,
        transform=transforms.Compose([
            transforms.Resize(opt.size),
            transforms.CenterCrop(opt.size),
            transforms.ToTensor(),
            # Maps [0, 1] tensors to [-1, 1]; save_image(normalize=True)
            # below rescales to a displayable range when writing.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))

    # batch_size=N so a single batch holds the whole requested sample;
    # shuffle=False keeps the subset deterministic across runs.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.N,
        shuffle=False,
        num_workers=1)
    images, _labels = next(iter(dataloader))

    for i, img in enumerate(tqdm(images, desc=f'Writing resized images to {opt.dir_output}')):
        save_image(img, os.path.join(opt.dir_output, f"{i}.jpg"), normalize=True)

    # Write out a sample grid to the results folder for the paper.
    if opt.sample_grid_fname != '':
        results_dir = './results'
        os.makedirs(results_dir, exist_ok=True)
        # make_grid returns CHW; transpose to HWC for matplotlib.
        grid = np.transpose(
            vutils.make_grid(images[0:4], padding=2, normalize=True),
            (1, 2, 0))
        plt.imshow(grid)
        plt.axis('off')
        plt.savefig(os.path.join(results_dir, opt.sample_grid_fname),
                    bbox_inches='tight', transparent=True, pad_inches=.1)


if __name__ == '__main__':
    # All work (including argv parsing) happens inside main(), so importing
    # this module no longer has side effects.
    main()