-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathTestModel.py
56 lines (49 loc) · 2.11 KB
/
TestModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from Train import MyDataset
from imports.ParametersManager import *
from Mantra_Net import *
from matplotlib import pyplot as plt
import matplotlib
import torchvision.transforms as transforms
# ---- Load the pre-trained ManTra-Net checkpoint -------------------------
# Folder holding the saved *.pt checkpoints.
DIR = './Pre_TrainedModel/'
# ===== Change the checkpoint file name here to load a different model =====
ModelName = DIR + 'MantraNet on NIST16_model.pt'

# The manager is built for the 'cuda' device and then pointed at the
# checkpoint file; presumably loadFromFile deserialises the stored weights
# and training bookkeeping (it exposes EpochDone below) -- verify in
# imports/ParametersManager.
parManager = ParametersManager('cuda')
parManager.loadFromFile(ModelName)
print(f"This model has done : {parManager.EpochDone} Epochs.")

# Build the network, move it to the GPU, then hand it to the manager so the
# loaded parameters can be applied to it.
model = ManTraNet()
model.cuda()
parManager.setModelParameters(model)
# ---- Choose which split to visualise -------------------------------------
# CSV index files for the two NIST2016 splits. Pass either one to MyDataset
# to run the visualisation on a different dataset.
TrainSetDIR = './NIST2016/Train.csv'
TestSetDIR = './NIST2016/Test.csv'

# NOTE(review): the *training* split is used here, so the predictions shown
# are presumably on images seen during training; switch to TestSetDIR for
# a held-out view.
data = MyDataset(TrainSetDIR)
# ---- Run the model and visualise each prediction ------------------------
# Inference only: put the network in evaluation mode (affects layers such as
# dropout / batch-norm) and disable gradient tracking below.
model.eval()
with torch.no_grad():
    # DataLoader is referenced through torch.utils.data so the script does not
    # rely on a star import happening to re-export it. shuffle=True visits
    # every sample in a fresh random order, which is what a
    # SubsetRandomSampler over range(len(data)) did.
    Loader = torch.utils.data.DataLoader(data, pin_memory=True, batch_size=1, shuffle=True)
    trans = transforms.ToPILImage()
    for (x, label) in Loader:
        out = model(x.cuda())

        # Drop the batch dimension (batch_size=1) and convert each tensor to a
        # PIL image for display. The ground-truth mask is shown unmodified.
        x = trans(torch.squeeze(x, 0))
        y = trans(torch.squeeze(label, 0))
        z = trans(torch.squeeze(out.cpu(), 0))
        # Prediction binarised at a 0.5 threshold.
        q = trans(torch.squeeze((out > 0.5).float().cpu(), 0))

        # The masks are drawn with an explicit [0, 255] colour range. Without
        # it, imshow autoscales to each image's own min/max, and a constant
        # image (e.g. a mask that is 255 everywhere) has vmin == vmax, which
        # matplotlib renders entirely at the bottom of the colormap, i.e. black.
        # The input image keeps imshow's default normalisation.
        panels = (
            (x, 'Input', None),
            (y, 'Ground truth', matplotlib.colors.Normalize(0, 255)),
            (z, 'Prediction', matplotlib.colors.Normalize(0, 255)),
            (q, 'Prediction > 0.5', matplotlib.colors.Normalize(0, 255)),
        )
        for idx, (img, title, norm) in enumerate(panels, start=1):
            plt.subplot(1, 4, idx)
            plt.imshow(img, cmap='gray', norm=norm)
            plt.title(title)
        plt.show()
        plt.close()