-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfast_approximate_evaluation.py
100 lines (75 loc) · 3 KB
/
fast_approximate_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# from PIL import Image
import os
import argparse
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import mean_squared_error as mse
import torch
from torch.autograd import Variable
import data_loader as data_loader
from models.vae import BetaVAE
from models.ae import AE
import utilities
# Seed the CPU and CUDA RNGs so repeated evaluation runs are reproducible.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Command-line options. Only --data_dir, --model_dir and --weights_dir are
# read by this script.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', default='/scratch/image_datasets/3_65x65/ready',
                    help="Directory containing the dataset")
parser.add_argument('--model_dir', default='models/',
                    help="Directory containing params.json")
# The checkpoint to evaluate is loaded from this directory, not written to it.
parser.add_argument('--weights_dir', default='/scratch/image_datasets/3_65x65/ready/weights',
                    help="Directory containing the trained weights to evaluate")
# Kept for command-line compatibility with the training script; not read here.
# Implicit string concatenation avoids the backslash continuation, which embeds
# the continuation line's indentation in the help text.
parser.add_argument('--restore_file', default=None,
                    help="Optional, name of the file in --model_dir containing weights to reload "
                         "before training (unused by this evaluation script)")  # 'best' or 'train'
args = parser.parse_args()
# Checkpoint and hyper-parameter locations; the run folder name is hard-coded
# for this particular experiment.
weights_path = os.path.join(args.weights_dir, 'weights_20210304_140125_vae/best.pth.tar')
json_path = os.path.join(args.model_dir, 'params.json')

# Load params and detect the device *before* loading the checkpoint so the
# weights can be mapped onto the right device.
params = utilities.Params(json_path)
params.cuda = torch.cuda.is_available()

# Without map_location, a checkpoint saved from GPU tensors cannot be
# deserialized on a CPU-only machine; remap it to the CPU in that case.
map_location = None if params.cuda else torch.device('cpu')

model = BetaVAE(32)  # AE() is the alternative model; 32 is presumably the latent size — confirm
model.load_state_dict(torch.load(weights_path, map_location=map_location)['state_dict'])
model.eval()  # inference mode for layers such as dropout/batch-norm
if params.cuda:
    model = model.cuda()

# Only the test split is needed for evaluation.
dataloaders = data_loader.fetch_dataloader(['test'], args.data_dir, params, batch_size=32)
test_dl = dataloaders['test']
# Fast, approximate evaluation: score reconstructions on the first
# `max_eval_batches` test batches only, and display the first
# input/reconstruction pair of every evaluated batch.
max_eval_batches = 10
counter = 0        # number of images whose metrics were accumulated
counter_vis = 0    # number of batches evaluated so far
diff_mse_cum = 0
diff_ssim_cum = 0
diff_psnr_cum = 0

# No gradients are needed for evaluation; no_grad saves memory and time and
# lets the outputs be converted to NumPy without .detach().
with torch.no_grad():
    for data_batch in test_dl:
        # Stop *before* touching the next batch. Previously that batch was run,
        # plotted and added to `counter` without being scored, which biased
        # every average downward.
        if counter_vis >= max_eval_batches:
            break
        if params.cuda:
            data_batch = data_batch.cuda(non_blocking=True)
        # BetaVAE returns a 3-tuple; only the reconstruction is used here.
        output_batch, _, _ = model(data_batch)
        data_batch = data_batch.cpu().numpy()
        output_batch = output_batch.cpu().numpy()

        plt.imshow(data_batch[0][0], cmap='gray')
        plt.show()
        plt.imshow(output_batch[0][0], cmap='gray')
        plt.show()

        for original, reconstruction in zip(data_batch, output_batch):
            # NOTE(review): MSE covers every channel while SSIM/PSNR use
            # channel 0 only; these agree for single-channel (grayscale)
            # data — confirm the dataset is 1-channel.
            diff_mse_cum += mse(original, reconstruction)
            dr_max = max(original.max(), reconstruction.max())
            dr_min = min(original.min(), reconstruction.min())
            # A zero joint range means both images are the same constant.
            # data_range=0 would make SSIM 0/0 = NaN, so use 1.0 instead
            # (SSIM is then exactly 1). PSNR of identical images is inf
            # regardless of data_range.
            data_range = dr_max - dr_min if dr_max > dr_min else 1.0
            diff_ssim_cum += ssim(original[0], reconstruction[0], data_range=data_range)
            diff_psnr_cum += psnr(original[0], reconstruction[0], data_range=data_range)

        # Count images only once they have actually been scored.
        counter += data_batch.shape[0]
        counter_vis += 1

if counter == 0:
    raise SystemExit("No test images were evaluated; check --data_dir.")

diff_mse_average = diff_mse_cum / counter
diff_ssim_average = diff_ssim_cum / counter
diff_psnr_average = diff_psnr_cum / counter
print(diff_mse_average)
print(diff_ssim_average)
print(diff_psnr_average)
# print("MSEs", diff_mse_average)
# print("SSIMs", diff_ssim_average)
# print("PSNRs", diff_psnr_average)