#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 4 12:31:42 2019
Video Demo
@author: AIRocker
"""
import sys
import os
sys.path.append(os.path.join(sys.path[0], 'MTCNN'))
import argparse
import torch
from torchvision import transforms as trans
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from utils.util import *
from utils.align_trans import *
from MTCNN import create_mtcnn_net
from face_model import MobileFaceNet, l2_norm
from facebank import load_facebank, prepare_facebank
import cv2
import time
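
# Pipeline overview: MTCNN detects faces and landmarks in each frame, the
# detected faces are aligned, MobileFaceNet maps each aligned face to a
# 512-d embedding, and each embedding is matched against the facebank by
# smallest squared L2 distance.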

def resize_image(img, scale):
    """
    Resize an image by a uniform scale factor.
    """
    height, width, channel = img.shape
    new_height = int(height * scale)   # resized new height
    new_width = int(width * scale)     # resized new width
    new_dim = (new_width, new_height)  # cv2.resize expects (width, height)
    img_resized = cv2.resize(img, new_dim, interpolation=cv2.INTER_LINEAR)
    return img_resized

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='face detection and recognition demo')
    parser.add_argument('-th', '--threshold', help='score threshold for declaring a match; faces scoring below it are labeled unknown',
                        default=50, type=float)
    parser.add_argument('-u', '--update', help='whether to update the facebank',
                        action='store_true', default=False)
    parser.add_argument('-tta', '--tta', help='whether to use test-time augmentation',
                        action='store_true', default=False)
    parser.add_argument('-c', '--score', help='whether to show the confidence score',
                        action='store_true', default=False)
    parser.add_argument('--scale', dest='scale', help='factor by which input frames are downscaled before detection to speed it up',
                        default=0.5, type=float)
    parser.add_argument('--mini_face', dest='mini_face',
                        help='minimum face size to be detected; decrease to increase accuracy, increase to increase speed',
                        default=20, type=int)
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    detect_model = MobileFaceNet(512).to(device)  # embedding size is 512 (feature vector)
    detect_model.load_state_dict(torch.load('Weights/MobileFace_Net', map_location=lambda storage, loc: storage))
    print('MobileFaceNet face recognition model loaded')
    detect_model.eval()
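    # eval() puts the network in inference mode (disables dropout and
    # freezes batch-norm statistics) so the embeddings are deterministic.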

    if args.update:
        targets, names = prepare_facebank(detect_model, path='facebank', tta=args.tta)
        print('facebank updated')
    else:
        targets, names = load_facebank(path='facebank')
        print('facebank loaded')
    # targets: number of candidates x 512
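    # The facebank holds one 512-d reference embedding per enrolled identity.
    # The matching step below sets the index to -1 for unmatched faces, and
    # names is indexed with that result plus one, so names[0] is presumably
    # reserved for the unknown label.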

    cap = cv2.VideoCapture('images/TheBigBangTheory.mp4')
    _, frame = cap.read()  # read one frame up front to size the output writer
    out = cv2.VideoWriter('output.avi', cv2.VideoWriter_fourcc(*'MJPG'), 20.0,
                          (frame.shape[1], frame.shape[0]), isColor=True)

    while True:
        isSuccess, frame = cap.read()
        if isSuccess:
            try:
                start_time = time.time()
                input_img = resize_image(frame, args.scale)  # detect on a downscaled copy for speed
                bboxes, landmarks = create_mtcnn_net(input_img, args.mini_face, device,
                                                     p_model_path='MTCNN/weights/pnet_Weights',
                                                     r_model_path='MTCNN/weights/rnet_Weights',
                                                     o_model_path='MTCNN/weights/onet_Weights')
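                # bboxes and landmarks are in the coordinates of the
                # downscaled copy, so they are mapped back to the full-size
                # frame before alignment.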
                if len(bboxes) > 0:
                    bboxes = bboxes / args.scale
                    landmarks = landmarks / args.scale
                    faces = Face_alignment(frame, default_square=True, landmarks=landmarks)

                    embs = []
                    test_transform = trans.Compose([
                        trans.ToTensor(),
                        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
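                    # ToTensor scales pixels to [0, 1]; Normalize with mean
                    # 0.5 and std 0.5 then maps them to [-1, 1] before they
                    # are fed to the embedding network.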
                    for img in faces:
                        if args.tta:
                            mirror = cv2.flip(img, 1)  # horizontal flip
                            emb = detect_model(test_transform(img).to(device).unsqueeze(0))
                            emb_mirror = detect_model(test_transform(mirror).to(device).unsqueeze(0))
                            embs.append(l2_norm(emb + emb_mirror))
                        else:
                            embs.append(detect_model(test_transform(img).to(device).unsqueeze(0)))
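                    # Test-time augmentation sums the embeddings of the face
                    # and its mirror image and re-normalizes the result,
                    # which tends to give a more pose-robust embedding.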

                    source_embs = torch.cat(embs)  # number of detected faces x 512
                    diff = source_embs.unsqueeze(-1) - targets.transpose(1, 0).unsqueeze(0)  # e.g. 3 x 512 x 1 - 1 x 512 x 2 = 3 x 512 x 2
                    dist = torch.sum(torch.pow(diff, 2), dim=1)  # number of detected faces x number of target faces
                    minimum, min_idx = torch.min(dist, dim=1)  # smallest distance and its index for each detected face
                    min_idx[minimum > ((args.threshold - 156) / (-80))] = -1  # if no match, set idx to -1
                    score = minimum
                    results = min_idx
                    # convert distance to score: dist (0.7, 1.2) maps to score (100, 60)
                    score_100 = torch.clamp(score * -80 + 156, 0, 100)
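                    # The linear map score = -80 * dist + 156 sends dist 0.7
                    # to score 100 and dist 1.2 to score 60; the thresholding
                    # above simply inverts it, so --threshold 50 rejects any
                    # match with squared distance above (50 - 156) / (-80) = 1.325.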

                    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    draw = ImageDraw.Draw(image)
                    font = ImageFont.truetype('utils/simkai.ttf', 30)
                    FPS = 1.0 / (time.time() - start_time)
                    draw.text((10, 10), 'FPS: {:.1f}'.format(FPS), fill=(0, 255, 0), font=font)
                    for i, b in enumerate(bboxes):
                        draw.rectangle([(b[0], b[1]), (b[2], b[3])], outline='blue', width=2)
                        if args.score:
                            draw.text((int(b[0]), int(b[1] - 25)),
                                      names[results[i] + 1] + ' score:{:.0f}'.format(score_100[i]),
                                      fill=(255, 255, 0), font=font)
                        else:
                            draw.text((int(b[0]), int(b[1] - 25)), names[results[i] + 1],
                                      fill=(255, 255, 0), font=font)
                        # print(names[results[i] + 1])

                    # for p in landmarks:
                    #     for i in range(5):
                    #         draw.ellipse([(p[i] - 2.0, p[i + 5] - 2.0), (p[i] + 2.0, p[i + 5] + 2.0)], outline='blue')
                    frame = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)

                out.write(frame)
            except Exception as e:
                print('detect error:', e)
        else:
            break  # end of stream or read failure

        cv2.imshow('video', frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break

    cap.release()
    out.release()
    cv2.destroyAllWindows()
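
# Example usage (the video path and weight locations are hard-coded above):
#   python Video_demo.py                  # run with defaults
#   python Video_demo.py -u -tta          # rebuild the facebank with TTA first
#   python Video_demo.py --scale 0.25 -c  # faster detection, show scores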