-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbird_classify_pycoraltest.py
182 lines (148 loc) · 6.64 KB
/
bird_classify_pycoraltest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/python3
"""
Coral Smart Bird Feeder
Uses the PyCoral API (make_interpreter, run_inference, get_classes) to
classify animals in camera frames. Sounds a deterrent if a squirrel is
detected.
Users define model, labels file, storage path, deterrent sound, and
optionally can set this to training mode for collecting images for a custom
model.
"""
import argparse
import time
import re
import imp
import logging
import gstreamer
#import sys
#sys.path.append('/usr/lib/python3/dist-packages/edgetpu')
#from edgetpu.classification.engine import ClassificationEngine
import os
from common import avg_fps_counter, SVG
from PIL import Image
from playsound import playsound
import numpy as np
from pycoral.adapters import classify
from pycoral.adapters import common
from pycoral.utils.dataset import read_label_file
from pycoral.utils.edgetpu import make_interpreter
from pycoral.utils.edgetpu import run_inference
from pycoral.adapters.common import input_size
from pycoral.adapters.classify import get_classes
def save_data(image, results, path, ext='png'):
    """Write a camera frame to disk and log its inference results.

    The file is stored as ``<path>/img-<tag>.<ext>`` where ``<tag>`` is the
    monotonic clock in milliseconds, zero-padded to ten digits. The same tag
    is written to the log together with ``results``.
    """
    millis = int(time.monotonic() * 1000)
    tag = f'{millis:010d}'
    name = f'{path}/img-{tag}.{ext}'
    image.save(name)
    print('Frame saved as: %s' % name)
    logging.info('Image: %s Results: %s', tag, results)
def load_labels(path):
    """Parse a label file into an ``{id: label}`` dictionary.

    Every useful line consists of optional leading whitespace, a run of
    digits (the class id) and the label text. Lines that do not follow that
    shape, such as blank lines, are skipped rather than raising.

    Args:
        path: Path of the label file; it is read as UTF-8.

    Returns:
        A dict mapping each integer class id to its label text with
        surrounding whitespace removed.
    """
    pattern = re.compile(r'\s*(\d+)(.+)')
    labels = {}
    with open(path, 'r', encoding='utf-8') as f:
        # Iterate the file lazily instead of loading every line up front.
        for line in f:
            match = pattern.match(line)
            if match is None:
                # Blank or malformed line: nothing to record.
                continue
            num, text = match.groups()
            labels[int(num)] = text.strip()
    return labels
def print_results(start_time, last_time, end_time, results):
    """Print inference latency, frame rate and scores for debugging.

    Args:
        start_time: Timestamp (seconds) taken just before inference.
        last_time: Timestamp (seconds) at which the previous frame finished.
        end_time: Timestamp (seconds) taken just after inference.
        results: Iterable of ``(label, score)`` pairs to list.
    """
    inference_ms = (end_time - start_time) * 1000
    frame_interval = end_time - last_time
    # Two frames can carry the same timestamp on a coarse clock; report an
    # infinite rate instead of crashing with ZeroDivisionError.
    fps = 1.0 / frame_interval if frame_interval else float('inf')
    print('\nInference: %.2f ms, FPS: %.2f fps' % (inference_ms, fps))
    for label, score in results:
        print(' %s, score=%.2f' % (label, score))
def do_training(results, last_results, top_k):
    """Report whether the detected labels changed since the previous frame.

    Used in training mode to decide whether a frame is worth saving for a
    custom model.

    Args:
        results: ``(label, score)`` pairs for the current frame.
        last_results: ``(label, score)`` pairs for the previous frame.
        top_k: Maximum number of results per frame. Kept for interface
            compatibility; the comparison no longer depends on it, because
            score thresholding can leave fewer than ``top_k`` results and
            counting shared labels against ``top_k`` then flagged every
            frame as different.

    Returns:
        True if the set of labels differs between the two frames, otherwise
        False.
    """
    new_labels = {entry[0] for entry in results}
    old_labels = {entry[0] for entry in last_results}
    if new_labels != old_labels:
        print('Difference detected')
        return True
    return False
def _str_to_bool(value):
    """Interpret a command-line token as a boolean.

    Explicit "false" spellings (and the empty string) yield False; any other
    value yields True, so invocations that previously relied on a truthy
    string keep working.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


def user_selections(argv=None):
    """Parse the command-line options of the bird feeder.

    Args:
        argv: Optional list of argument strings. When None, argparse reads
            the arguments from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with the parsed options. ``print``,
        ``training`` and ``headless`` are real booleans: ``--print False``
        now disables printing instead of being read as a truthy string, and
        a bare ``--print`` enables it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True,
                        help='.tflite model path')
    parser.add_argument('--labels', required=True,
                        help='label file path')
    parser.add_argument('--top_k', type=int, default=3,
                        help='number of classes with highest score to display')
    parser.add_argument('--threshold', type=float, default=0.1,
                        help='class score threshold')
    parser.add_argument('--storage', required=True,
                        help='File path to store images and results')
    parser.add_argument('--sound', required=True,
                        help='File path to deterrent sound')
    # nargs='?' with const=True lets each flag be given bare or with a value.
    parser.add_argument('--print', default=False, required=False,
                        type=_str_to_bool, nargs='?', const=True,
                        help='Print inference results to terminal')
    parser.add_argument('--training', default=False, required=False,
                        type=_str_to_bool, nargs='?', const=True,
                        help='Training mode for image collection')
    parser.add_argument('--videosrc', help='Which video source to use. ',
                        default='/dev/video0')
    # type=bool would turn every non-empty string, including "False", into
    # True; the converter parses the text instead.
    parser.add_argument('--headless', help='Run without displaying the video.',
                        default=False, type=_str_to_bool, nargs='?',
                        const=True)
    parser.add_argument('--videofmt', help='Input video format.',
                        default='raw',
                        choices=['raw', 'h264', 'jpeg'])
    return parser.parse_args(argv)
def _current_frame(interpreter):
    """Return the interpreter's current input tensor as a PIL image."""
    # NOTE(review): assumes run_inference() has copied the camera frame into
    # a uint8 RGB input tensor -- confirm on the target device and model.
    # The tensor view is copied so the image does not alias interpreter
    # memory that the next inference overwrites.
    frame = np.array(common.input_tensor(interpreter), copy=True)
    return Image.fromarray(frame)


def main():
    """Classify camera frames and react to the detections.

    Loads the model and labels chosen on the command line, logs to
    ``results.log`` in the storage directory and hands a per-frame callback
    to ``gstreamer.run_pipeline``. In training mode a frame is saved whenever
    ``do_training`` reports a label change. Otherwise a frame is saved when
    the top label is not 'background', and the deterrent sound is played when
    the squirrel label is among the results.
    """
    args = user_selections()
    print("Loading %s with %s labels."%(args.model, args.labels))
    interpreter = make_interpreter(args.model)
    interpreter.allocate_tensors()
    labels = read_label_file(args.labels)
    storage_dir = args.storage
    # The log file and the saved frames both live here; create it up front so
    # logging setup does not fail on a missing directory.
    os.makedirs(storage_dir, exist_ok=True)
    # Initialize logging file
    logging.basicConfig(filename='%s/results.log'%storage_dir,
                        format='%(asctime)s-%(message)s',
                        level=logging.DEBUG)
    last_time = time.monotonic()
    last_results = [('label', 0)]
    # Custom model mode: this label can be changed to deter other animals.
    deterrent_label = 'fox squirrel, eastern fox squirrel, Sciurus niger'

    def user_callback(input_tensor, src_size, inference_box):
        """Run inference on one frame, then save, log or deter as configured."""
        nonlocal last_time
        nonlocal last_results
        start_time = time.monotonic()
        run_inference(interpreter, input_tensor)
        classes = get_classes(interpreter, args.top_k, args.threshold)
        end_time = time.monotonic()
        # Fall back to the numeric id when the label file has no entry for it.
        results = [(labels.get(i, i), score) for i, score in classes]
        if args.print:
            print_results(start_time, last_time, end_time, results)

        if args.training:
            if do_training(results, last_results, args.top_k):
                save_data(_current_frame(interpreter), results, storage_dir)
        elif results:
            # `results` is empty when no class reaches the threshold; there
            # is then nothing to save and nothing to deter.
            if results[0][0] != 'background':
                save_data(_current_frame(interpreter), results, storage_dir)
            # Compare against the label of each (label, score) pair; testing
            # the string against the list of tuples never matches.
            if any(label == deterrent_label for label, _ in results):
                playsound(args.sound)
                logging.info('Deterrent sounded')

        last_results = results
        last_time = end_time

    # NOTE(review): --videosrc, --videofmt and --headless are parsed but not
    # forwarded, and no src_size/appsink_size (e.g. input_size(interpreter))
    # is passed; confirm this matches the local gstreamer.run_pipeline
    # signature.
    gstreamer.run_pipeline(user_callback)
# Start the feeder only when this file is executed as a script.
if __name__ == "__main__":
    main()