-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
executable file
·216 lines (188 loc) · 7.1 KB
/
utils.py
File metadata and controls
executable file
·216 lines (188 loc) · 7.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
from lz import *
from datetime import datetime
from PIL import Image
import numpy as np
import io
from torchvision import transforms as trans
from torchvision import transforms
from data.data_pipe import de_preprocess
import torch
from models.model import l2_norm
import cv2
# Class-path fragments of "leaf" layers whose parameters are collected even
# when the generic model/container skip rules in separate_bn_paras match.
known_bottom = ['SuperKernel', 'batchnorm',
                'conv', 'activation', 'linear',
                'Conv2dSamePadding', ]


def separate_bn_paras(modules):
    """Split parameters into a normalization-like group and the rest.

    Args:
        modules: an ``nn.Module`` (optionally wrapped in ``DataParallel`` or
            ``DistributedDataParallel``) or an explicit list of modules.

    Returns:
        ``(paras_only_bn, paras_wo_bn)``: parameters of layers whose class
        path contains ``batchnorm``, ``inplaceabn`` or ``prelu``
        (case-insensitive), and parameters of every other collected layer.

    Modules whose class path contains ``'model'`` (classes defined in the
    project's model files) or ``'torch.nn.modules.container'`` are skipped,
    presumably because their ``parameters()`` would repeat their children's.
    A class path matching any entry of ``known_bottom`` is never skipped.
    """
    if isinstance(modules, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
        modules = modules.module
    paras_only_bn = []
    paras_wo_bn = []
    if not isinstance(modules, list):
        modules = [*modules.modules()]
    for layer in modules:
        cls_str = str(layer.__class__)
        # Decided afresh for each layer: carrying the flag over from the
        # previous iteration would drop unrelated layers (e.g. LayerNorm)
        # that merely follow a skipped container/model module.
        should_skip = ('model' in cls_str
                       or 'torch.nn.modules.container' in cls_str)
        if any(kb_ in cls_str for kb_ in known_bottom):
            should_skip = False
        if should_skip:
            logging.info(f'ignore {layer.__class__}')
            continue
        cls_nm = cls_str.lower()
        if 'batchnorm' in cls_nm or 'inplaceabn' in cls_nm or 'prelu' in cls_nm:
            paras_only_bn.extend(layer.parameters())
        else:
            paras_wo_bn.extend(layer.parameters())
    return paras_only_bn, paras_wo_bn
def seperate_last_params(modules, name='output_layer'):
    """Split parameters into those of layers matching ``name`` and the rest.

    Args:
        modules: an ``nn.Module`` (optionally wrapped in ``DataParallel`` or
            ``DistributedDataParallel``) or an explicit list of modules.
        name: substring searched for in each layer's class path
            (``str(layer.__class__)``).

    Returns:
        ``(matched_params, other_params)``.

    Modules whose class path contains ``'model'`` or ``'container'`` are
    skipped entirely, presumably because their ``parameters()`` would
    duplicate their children's.
    """
    # Unwrap DistributedDataParallel too, as separate_bn_paras does.
    if isinstance(modules, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
        modules = modules.module
    if not isinstance(modules, list):
        modules = [*modules.modules()]
    matched_params = []
    other_params = []
    for layer in modules:
        cls_str = str(layer.__class__)
        if 'model' in cls_str or 'container' in cls_str:
            continue
        target = matched_params if name in cls_str else other_params
        target.extend(layer.parameters())
    return matched_params, other_params
def prepare_facebank(conf, model, mtcnn, tta=True):
    """Build the face bank from one sub-directory per identity and save it.

    Every sub-directory of ``conf.facebank_path`` is one identity, labelled
    by the directory name. Each readable image file inside it is aligned
    with ``mtcnn.align`` unless it is already 112x112, then embedded with
    ``model``. The identity's embedding is the mean over its images.
    Identities with no usable image are skipped.

    Args:
        conf: config object providing ``facebank_path`` (a pathlib.Path-like
            directory), ``test_transform`` and ``device``.
        model: embedding network; it is switched to eval mode.
        mtcnn: aligner exposing ``align(img)``.
        tta: if True, each image embedding is
            ``l2_norm(model(img) + model(hflip(img)))``.

    Returns:
        ``(embeddings, names)``: the concatenated identity embeddings and a
        numpy array of labels. ``names[0]`` is ``'Unknown'``, so
        ``names[i + 1]`` labels ``embeddings[i]``.

    Raises:
        RuntimeError: if no identity directory yielded a usable image.

    Side effects:
        writes ``facebank.pth`` and ``names.npy`` into ``conf.facebank_path``.
    """
    model.eval()

    def _embed(im):
        # Transform one image, move it to the device, add a batch axis.
        return model(conf.test_transform(im).to(conf.device).unsqueeze(0))

    embeddings = []
    names = ['Unknown']
    for path in conf.facebank_path.iterdir():
        if path.is_file():
            continue
        embs = []
        for file in path.iterdir():
            if not file.is_file():
                continue
            try:
                img = Image.open(file)
            except Exception:
                # Best effort: skip files that are not readable images, while
                # still letting KeyboardInterrupt/SystemExit propagate.
                continue
            if img.size != (112, 112):
                img = mtcnn.align(img)
            with torch.no_grad():
                if tta:
                    mirror = trans.functional.hflip(img)
                    emb = _embed(img)
                    emb_mirror = _embed(mirror)
                    embs.append(l2_norm(emb + emb_mirror))
                else:
                    embs.append(_embed(img))
        if len(embs) == 0:
            continue
        embedding = torch.cat(embs).mean(0, keepdim=True)
        embeddings.append(embedding)
        names.append(path.name)
    if not embeddings:
        # torch.cat([]) would otherwise fail with an opaque message.
        raise RuntimeError(
            f'no usable face images found under {conf.facebank_path}')
    embeddings = torch.cat(embeddings)
    names = np.array(names)
    torch.save(embeddings, conf.facebank_path / 'facebank.pth')
    np.save(conf.facebank_path / 'names', names)
    return embeddings, names
def load_facebank(conf):
    """Load the face bank previously saved by ``prepare_facebank``.

    Reads ``facebank.pth`` (with ``torch.load``) and ``names.npy`` (with
    ``np.load``) from ``conf.facebank_path`` and returns them as
    ``(embeddings, names)``.
    """
    bank_dir = conf.facebank_path
    stored_embeddings = torch.load(bank_dir / 'facebank.pth')
    return stored_embeddings, np.load(bank_dir / 'names.npy')
def face_reader(conf, conn, flag, boxes_arr, result_arr, learner, mtcnn, targets, tta):
    """Worker loop: detect faces in frames received on ``conn``.

    For every received frame, faces are aligned with
    ``mtcnn.align_multi(image, limit=conf.face_limit)``. When any are found,
    they are passed to ``learner.infer(conf, faces, targets, tta)``. Results
    are written into the fixed-size buffers:

    * ``boxes_arr``: four box coordinates per face, flattened, padded with 0;
    * ``result_arr``: one ``infer`` result per face, padded with -1.

    ``flag.value`` is set to 0 after each frame has been handled. The loop
    returns once ``conn`` is closed (``recv`` raising EOFError/OSError).
    """
    while True:
        try:
            image = conn.recv()
        except (EOFError, OSError):
            # The other end of the pipe is gone; retrying would spin forever.
            return
        except Exception:
            continue
        try:
            bboxes, faces = mtcnn.align_multi(image, limit=conf.face_limit)
        except Exception:
            # Treat a detection failure as "no faces", and never reuse the
            # previous frame's faces.
            bboxes, faces = [], []
        if len(bboxes) > 0:
            results = learner.infer(conf, faces, targets, tta)
            print('bboxes in reader : {}'.format(bboxes))
            bboxes = bboxes[:, :-1]  # drop the last column (presumably the detection score)
            bboxes = bboxes.astype(int)
            bboxes = bboxes + [-1, -1, 1, 1]  # personal choice: widen each box by 1px
            assert bboxes.shape[0] == results.shape[0], 'bbox and faces number not same'
            bboxes = bboxes.reshape([-1])
            n_box = len(bboxes)
            n_res = len(results)
            for i in range(len(boxes_arr)):
                boxes_arr[i] = bboxes[i] if i < n_box else 0
            for i in range(len(result_arr)):
                result_arr[i] = results[i] if i < n_res else -1
        else:
            for i in range(len(boxes_arr)):
                boxes_arr[i] = 0  # by default, it's all 0
            for i in range(len(result_arr)):
                result_arr[i] = -1  # by default, it's all -1
        print('boxes_arr : {}'.format(boxes_arr[:4]))
        print('result_arr : {}'.format(result_arr[:4]))
        flag.value = 0
# Horizontal-flip pipeline for one image tensor. de_preprocess runs first
# (presumably undoing the normalization); the image is then flipped through
# PIL, converted back to a tensor and normalized with mean = std = 0.5 per
# channel. `transforms` and `trans` both name torchvision.transforms.
hflip = transforms.Compose([
    de_preprocess,
    transforms.ToPILImage(),
    transforms.functional.hflip,
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
def hflip_batch(imgs_tensor):
    """Run the module-level ``hflip`` pipeline on each image of a batch.

    The result is allocated with ``torch.empty_like(imgs_tensor)``, so it has
    the input's shape, dtype and device. Entry ``k`` holds
    ``hflip(imgs_tensor[k])``.
    """
    flipped = torch.empty_like(imgs_tensor)
    for position, single_image in enumerate(imgs_tensor):
        flipped[position] = hflip(single_image)
    return flipped
# Center-crop pipeline for one image tensor. de_preprocess runs first
# (presumably undoing the normalization); the image is then converted to
# PIL, resized, center-cropped to 112x112, converted back to a tensor and
# normalized with mean = std = 0.5 per channel.
# `trans` and `transforms` both name torchvision.transforms.
ccrop = trans.Compose([
    de_preprocess,
    trans.ToPILImage(),
    # A [h, w] size resizes to exactly 128x128 (aspect ratio not kept);
    # only an int size would resize just the smaller side.
    trans.Resize([128, 128]),
    trans.CenterCrop([112, 112]),
    trans.ToTensor(),
    trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
def ccrop_batch(imgs_tensor):
    """Run the module-level ``ccrop`` pipeline on each image of a batch.

    The result is allocated with ``torch.empty_like(imgs_tensor)``, so each
    input image's shape must be compatible with the 112x112 crop that
    ``ccrop`` produces for the per-item assignment to succeed.
    """
    cropped = torch.empty_like(imgs_tensor)
    for position, single_image in enumerate(imgs_tensor):
        cropped[position] = ccrop(single_image)
    return cropped
def get_time(now=None):
    """Return a filesystem-friendly timestamp truncated to the hour.

    The format is ``YYYY-MM-DD-HH``, e.g. ``'2024-01-02-03'``.

    Args:
        now: optional ``datetime`` to format; defaults to ``datetime.now()``.
    """
    if now is None:
        now = datetime.now()
    # strftime rather than slicing str(now): str() omits the ".ffffff" part
    # when microsecond == 0, so a fixed-length slice would cut into the date.
    return now.strftime('%Y-%m-%d-%H')
def gen_plot(fpr, tpr):
    """Plot an ROC curve and return it as an in-memory PNG.

    Args:
        fpr: false-positive rates (x axis).
        tpr: true-positive rates (y axis).

    Returns:
        An ``io.BytesIO`` rewound to position 0 that holds the PNG image.
    """
    fig = plt.figure()
    try:
        plt.xlabel("FPR", fontsize=14)
        plt.ylabel("TPR", fontsize=14)
        plt.title("ROC Curve", fontsize=14)
        plt.plot(fpr, tpr, linewidth=2)
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls don't accumulate open figures.
        plt.close(fig)
    return buf
def draw_box_name(bbox, name, frame):
    """Draw a box and a text label on ``frame``.

    The rectangle spans corners ``(bbox[0], bbox[1])`` and
    ``(bbox[2], bbox[3])``, with BGR colour (0, 0, 255) and thickness 6.
    ``name`` is written with its origin at ``(bbox[0], bbox[1])`` in BGR
    colour (0, 255, 0), font scale 2, thickness 3, anti-aliased.
    Returns the frame produced by the OpenCV drawing calls.
    """
    first_corner = (bbox[0], bbox[1])
    second_corner = (bbox[2], bbox[3])
    boxed = cv2.rectangle(frame, first_corner, second_corner, (0, 0, 255), 6)
    return cv2.putText(boxed, name, first_corner, cv2.FONT_HERSHEY_SIMPLEX,
                       2, (0, 255, 0), 3, cv2.LINE_AA)