forked from s17472/VRS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_on_batch.py
95 lines (73 loc) · 2.28 KB
/
predict_on_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import time
import cv2
import numpy as np
from config import FRAMES_NO
from fgn_data_transformation import (get_optical_flow, normalize_respectively,
reshape, set_optical_flow)
from keras.models import load_model
def reshape_frames(frames):
    """
    Apply ``reshape`` (from fgn_data_transformation) to every frame.

    Args:
        frames: iterable of frames to reshape
    Returns:
        list of reshaped frames, in the same order as the input
    """
    return [reshape(single_frame) for single_frame in frames]
def transform_frames(frames):
    """
    Build a single model-input batch from a sequence of frames.

    The frames are stacked into one numpy array. Optical flow is computed
    on that array and merged into the data via ``set_optical_flow``. The
    result is cast to float32, normalized with ``normalize_respectively``,
    and wrapped so that it has a leading batch axis of length 1.

    Args:
        frames: sequence of frames (in the ``__main__`` block these have
            already been passed through ``reshape_frames``)
    Returns:
        numpy array containing exactly one sample along its first axis
    """
    stacked = np.array(frames)
    optical_flow = get_optical_flow(stacked)
    merged = set_optical_flow(stacked, optical_flow)
    normalized = normalize_respectively(np.float32(merged))
    # Wrap in a list so the model receives a batch holding one sample.
    return np.array([normalized])
def get_prediction(data, net=None):
    """
    Classify the input and return the model's fight score as a percentage.

    Args:
        data: model input batch (e.g. the output of ``transform_frames``)
        net: optional object exposing ``predict(data)``. When omitted, the
            module-level ``model`` (loaded in the ``__main__`` block) is used,
            which keeps the original call signature working.
    Returns:
        the first output of the first sample, multiplied by 100 and rounded
        to two decimal places
    Raises:
        NameError: if ``net`` is not given and no module-level ``model``
            has been loaded
    """
    if net is None:
        # Fall back to the global that only exists when this file is run as
        # a script; give an actionable message instead of a bare NameError.
        try:
            net = model
        except NameError:
            raise NameError(
                "no model available: pass `net` or load a module-level "
                "`model` first"
            ) from None
    predict = net.predict(data)[0][0]
    return round(predict * 100, 2)
def get_frames(path, n_frames=None):
    """
    Read the first ``n_frames`` frames of a video.

    Args:
        path: source of the video (anything ``cv2.VideoCapture`` accepts)
        n_frames: number of frames to read; defaults to ``FRAMES_NO``
            (resolved at call time)
    Returns:
        list of exactly ``n_frames`` frames, as returned by ``cap.read()``
    Raises:
        ValueError: if the source cannot be opened, or if it yields fewer
            than ``n_frames`` readable frames
    """
    if n_frames is None:
        n_frames = FRAMES_NO
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise ValueError("cannot open video source: %r" % (path,))
        frames = []
        # ``<`` rather than ``!=`` so a non-positive count cannot loop forever.
        while len(frames) < n_frames:
            ok, frame = cap.read()
            # A failed read (end of stream, decode error) yields frame=None.
            # Fail fast here instead of handing None on to the
            # transformation step.
            if not ok:
                raise ValueError(
                    "video source %r has only %d readable frames, %d required"
                    % (path, len(frames), n_frames)
                )
            frames.append(frame)
        return frames
    finally:
        # Release the underlying file/device handle on success and on error.
        cap.release()
if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("-m", "--model", required=True, help="path of the model")
    ap.add_argument("-s", "--source", required=False, default="testset/fight.avi", help="source of video")
    args = vars(ap.parse_args())

    # Load the trained network into the module-level name ``model``;
    # get_prediction() reads it from there when no model is passed.
    model = load_model(args["model"])

    frames = get_frames(args["source"])
    frames = reshape_frames(frames)
    data = transform_frames(frames)

    # Time only the inference call. perf_counter is monotonic and
    # high-resolution, unlike time.time(), which follows the wall clock and
    # can jump when the system time is adjusted.
    start_time = time.perf_counter()
    prediction = get_prediction(data)
    elapsed = time.perf_counter() - start_time

    print("Fight:", prediction)
    print("--- %s seconds ---" % elapsed)