This repository has been archived by the owner on Jan 18, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathinference.py
92 lines (79 loc) · 2.92 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python3
#
import argparse
import csv
import os
import os.path as osp
import gluoncvth as gcv
import numpy as np
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
def getargs(argv=None):
    """Parse the command-line options for the inference script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for testing). ``None`` keeps the
            default behaviour of reading the real command line.

    Returns:
        argparse.Namespace with ``inputdir``, ``model``, ``outputdir`` and
        ``aux`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--inputdir', type=str, required=True,
                        help="input data dir")
    parser.add_argument('--model', type=str, required=True,
                        help="model file")
    # The script always builds output paths under this directory with
    # os.path.join, so a missing value only surfaced later as a TypeError.
    # Requiring it lets argparse report the problem up front instead.
    parser.add_argument('--outputdir', type=str, required=True,
                        help='output dir')
    parser.add_argument('--aux', action='store_true',
                        help='use aux layer')
    return parser.parse_args(argv)
def get_gleason_grade(segmentation):
    """Derive a single score from the label frequencies of a segmentation map.

    Labels are ranked by pixel count (ties go to the larger label value).
    The returned score is:

      * twice the most frequent label when only one label occurs,
      * the sum of the two labels when exactly two labels occur,
      * the most frequent label plus the largest label value when three or
        more labels occur.

    Args:
        segmentation: array-like of integer labels, any shape (it is
            flattened before counting).

    Returns:
        int: the score described above.

    Raises:
        ValueError: if ``segmentation`` contains no elements.
    """
    labels, counts = np.unique(np.asarray(segmentation).ravel(),
                               return_counts=True)
    if labels.size == 0:
        raise ValueError("segmentation is empty")
    # np.unique returns labels in ascending order; a stable sort by count
    # therefore breaks ties deterministically in favour of the larger label.
    # Convert to Python ints because the caller passes a uint8 map, and
    # uint8 scalar arithmetic wraps around (a label of 200 would score 144).
    by_count = [int(v) for v in labels[np.argsort(counts, kind='stable')]]
    primary = by_count[-1]
    if len(by_count) == 1:
        return primary * 2
    if len(by_count) == 2:
        return primary + by_count[-2]
    # NOTE(review): with three or more labels the second term is the largest
    # label value present, regardless of its frequency. This is presumably a
    # "worst grade present" rule; confirm which label values the caller can
    # produce, since non-grade labels would also take part here.
    return primary + int(labels[-1])
if __name__ == '__main__':
    # Run the segmentation model over every file in --inputdir. Write one
    # label mask per image to <outputdir>/task1 and one score per image to
    # <outputdir>/task2/task2.csv.
    args = getargs()
    os.makedirs(osp.join(args.outputdir, 'task1'), exist_ok=True)
    os.makedirs(osp.join(args.outputdir, 'task2'), exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Swap the final 1x1 convolutions of both PSPNet heads so they emit
    # num_classes channels. The strict load_state_dict below presumably
    # relies on the checkpoint having been trained with this layout.
    num_classes = 6
    model = gcv.models.get_psp_resnet101_ade(pretrained=True)
    model.auxlayer.conv5[-1] = nn.Conv2d(256, num_classes, kernel_size=1, stride=1)
    model.head.conv5[-1] = nn.Conv2d(512, num_classes, kernel_size=1, stride=1)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model_data = torch.load(args.model, map_location='cpu')
    model.load_state_dict(model_data['model'])
    model = model.to(device)
    model.eval()

    # Resize the shorter side to 800 px, then normalise with ImageNet mean/std.
    tf = transforms.Compose([
        transforms.Resize(800),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    grade = {}
    with torch.no_grad():
        # Sorted order makes the CSV row order reproducible across runs.
        for imgfile in sorted(os.listdir(args.inputdir)):
            path = osp.join(args.inputdir, imgfile)
            if not osp.isfile(path):
                continue  # skip sub-directories and other non-files
            with Image.open(path) as img:
                w, h = img.size
                # Normalize expects 3 channels, so grayscale, RGBA and
                # palette images are converted to RGB first.
                data = tf(img.convert('RGB'))
            data = data.to(device).unsqueeze(0)
            y, y_aux = model(data)
            if args.aux:
                # --aux blends in the auxiliary head at a lower weight.
                y = y + 0.5 * y_aux
            y = y.argmax(dim=1).cpu().squeeze().numpy().astype(np.uint8)
            # NOTE(review): predicted class 2 is written out as label 6,
            # presumably to match the dataset's label encoding; confirm.
            y[y == 2] = 6
            # Nearest-neighbour keeps every pixel a valid class ID. Bilinear
            # interpolation would blend neighbouring labels into values that
            # are not real classes.
            result = Image.fromarray(y).resize((w, h), Image.NEAREST)
            # NOTE(review): the mask keeps the input filename, so a .jpg input
            # produces a lossy JPEG mask that may alter label values; consider
            # a lossless format if downstream tooling allows it.
            result.save(osp.join(args.outputdir, 'task1', imgfile))
            grade[osp.splitext(imgfile)[0]] = get_gleason_grade(y)

    # newline='' is what the csv module requires; without it, rows get extra
    # blank lines on Windows.
    with open(osp.join(args.outputdir, 'task2', 'task2.csv'), 'w', newline='') as f:
        writer = csv.writer(f)
        for key, value in grade.items():
            writer.writerow([key, value])
    print('Done')