-
Notifications
You must be signed in to change notification settings - Fork 1
/
run.py
150 lines (115 loc) · 4.79 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import json
from argparse import ArgumentParser, Namespace
from os import listdir, makedirs, walk
from os.path import abspath, dirname, isdir, isfile, join, relpath
from sys import exit as sys_exit, stderr, stdout

import torch
import torchvision.models as models
from PIL import Image, UnidentifiedImageError
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Device the model and each input batch are moved to before inference;
# fixed to the CPU backend in this script.
DEVICE = torch.device("cpu")
class BlurDetectionResNet(nn.Module):
    """ResNet-50 whose classifier head is replaced by a single sigmoid output.

    The backbone is built with ``models.resnet50()`` default arguments. Its
    final ``fc`` layer is swapped for a one-output ``nn.Linear``, so the model
    produces one value per input sample, squashed into (0, 1) by ``forward``.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50()
        # Keep the backbone's feature width; only the head's output size changes.
        backbone.fc = nn.Linear(backbone.fc.in_features, 1)
        self.resnet = backbone

    def forward(self, x) -> torch.Tensor:
        """Run the backbone and map its single logit per sample through a sigmoid."""
        logits = self.resnet(x)
        return torch.sigmoid(logits)
class RunnerDataset(torch.utils.data.Dataset):
    """Dataset of the readable image files found under ``root_dir``.

    Each entry of ``samples`` is a ``(name, None)`` tuple where ``name`` is the
    image path relative to ``root_dir``. For images directly inside
    ``root_dir`` this is just the file name. Files that PIL cannot open are
    reported on stdout and skipped.

    Args:
        root_dir: Directory to scan.
        transform: Optional callable applied to each RGB image in
            ``__getitem__``.
        recursive: When true, also scan every subdirectory of ``root_dir``.
    """

    def __init__(self, root_dir, transform=None, recursive=False):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.recursive = recursive
        if recursive:
            self.load_images_recursively(root_dir)
        else:
            self.load_images(root_dir)

    def load_images_recursively(self, root_dir):
        """Scan ``root_dir`` and every directory below it for images."""
        for dirpath, dirnames, _ in walk(root_dir):
            # Sorting in place makes os.walk descend in a deterministic order.
            dirnames.sort()
            stdout.write(f"{dirpath}\n")
            self.load_images(dirpath)

    def load_images(self, root_dir):
        """Append every readable image file located directly inside ``root_dir``.

        Names are stored relative to ``self.root_dir``, not as bare file
        names. This lets ``__getitem__`` rebuild the correct path for files
        found in subdirectories during a recursive scan.
        """
        for image_file in sorted(listdir(root_dir)):
            img_path = join(root_dir, image_file)
            if not isfile(img_path):
                continue
            try:
                # Opening is only a validity check; pixels are read in __getitem__.
                with Image.open(img_path):
                    pass
            except UnidentifiedImageError:
                stdout.write(f"Cannot identify image file '{img_path}', skipping.\n")
                continue
            except Exception as e:
                # Best effort: one unreadable file must not abort the whole scan.
                stdout.write(f"Error loading image file '{img_path}': {e}, skipping.\n")
                continue
            self.samples.append((relpath(img_path, self.root_dir), None))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as an RGB image, passed through ``transform`` if set."""
        img_path = join(self.root_dir, self.samples[idx][0])
        # Context manager so the file handle is closed once the pixels are converted.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
def setup_argparse(argv=None) -> Namespace:
    """Parse the command-line arguments of the blur detector.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        The parsed namespace with ``target``, ``output`` and ``recursive``.
    """
    default_output_path = join(dirname(abspath(__file__)), "predictions.json")
    parser = ArgumentParser(
        prog="blurwarp",
        description="Detection of blurry images using ResNet50 AI model",
        epilog="If you encounter any problem please submit an issue here: https://github.com/MidKnightXI/BlurWarp")
    # Positional arguments are always required. argparse rejects an explicit
    # ``required=`` on a positional with a TypeError, so none is passed here.
    parser.add_argument("target",
                        type=str,
                        help="Define in which directory the model will analyze the images")
    parser.add_argument("-o", "--output",
                        default=default_output_path,
                        type=str,
                        help="Define the path of the output file eg: ./out/pred.json")
    parser.add_argument("-r", "--recursive",
                        action='store_true',
                        help="Recursively search for images in subdirectories")
    return parser.parse_args(argv)
def setup_model() -> BlurDetectionResNet:
    """Build the blur classifier and load its trained weights.

    Weights are read from ``blur_detection_model.tch`` in this script's
    directory. The model is then moved to ``DEVICE`` and switched to eval mode.
    """
    model_path = join(dirname(abspath(__file__)), "blur_detection_model.tch")
    model = BlurDetectionResNet()
    # map_location remaps tensors saved on another device (e.g. a GPU) onto
    # DEVICE. weights_only limits unpickling to tensors and plain containers,
    # which is all a state_dict requires.
    state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    stdout.write("Model loaded\n")
    return model
def dump_predictions(path: str, predictions: list) -> None:
    """Write ``predictions`` to ``path`` as indented JSON and report where.

    Missing parent directories are created first, so an output path such as
    ``./out/pred.json`` works even when ``./out`` does not exist yet.
    """
    makedirs(dirname(abspath(path)), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2)
    stdout.write(f"Results saved to {path}\n")
def run_model(path: str, output_path: str, recursive: bool, *, batch_size: int = 1) -> None:
    """Score every image under ``path`` and write the results as JSON.

    Each prediction holds the image name as recorded in the dataset's
    ``samples``, the model score rounded to two decimals, and ``status``,
    which is true when that rounded score exceeds 0.9.

    Args:
        path: Directory containing the images.
        output_path: Destination of the JSON predictions file.
        recursive: Also scan subdirectories of ``path``.
        batch_size: Number of images per forward pass (default 1).
    """
    model = setup_model()
    transform = transforms.Compose([
        transforms.Resize((256, 256), antialias=True),
        transforms.ToTensor(),
    ])
    dataset = RunnerDataset(root_dir=path, transform=transform, recursive=recursive)
    # shuffle=False keeps batches in dataset.samples order, which is what
    # allows scores to be paired with names by position below.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    scores = []
    with torch.no_grad():
        for batch in loader:
            # The model emits one value per image; flatten to a plain list.
            scores.extend(model(batch.to(DEVICE)).flatten().tolist())

    predictions = []
    for (name, _), raw in zip(dataset.samples, scores):
        score = round(raw, 2)
        predictions.append({
            "status": score > 0.9,
            "filename": name,
            "score": score,
        })
    dump_predictions(output_path, predictions)
if __name__ == "__main__":
    args = setup_argparse()
    # Validate the target before paying for model loading inside run_model.
    if not isdir(args.target):
        # Errors go to stderr so they are not mixed into normal output.
        stderr.write("Please specify a proper path: path/to/directory\n")
        sys_exit(1)
    stdout.write(f"Using - {DEVICE} - backend to run the model\n")
    run_model(args.target, args.output, args.recursive)