Skip to content

Commit 9019b22

Browse files
mingxingtancopybara-github
authored andcommitted
Open source run_tflite script.
PiperOrigin-RevId: 373293969
1 parent df09af0 commit 9019b22

File tree

2 files changed

+114
-7
lines changed

2 files changed

+114
-7
lines changed

tensorflow_examples/lite/model_maker/third_party/efficientdet/README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,13 @@ In addition, the following table includes a list of models trained with fixed 64
9191
We have also provided a list of mobile-size lite models:
9292

9393
| Model | mAP (float) | Quantized mAP (int8) | Prameters | Mobile latency |
94-
| ------ | :------: | :------: | :------: | :------: |:------: |
95-
| lite0, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite0.tgz) | 26.41 | 26.10 | 4.3M | 36ms |
96-
| lite1, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite1.tgz) | 31.50 | 31.12 | 5.8M | 49ms |
97-
| lite2, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite2.tgz) | 35.06 | 34.69 | 7.2M | 69ms |
98-
| lite3, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite3.tgz) | 38.77 | 38.42 | 11M | 116ms |
99-
| lite3x, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite3x..gz) | 42.64 | 41.87 | 12M | 208ms |
100-
| lite4, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite4.tgz) | 43.18 | 42.83 | 20M | 260ms |
94+
| ------ | :------: | :------: | :------: | :------: |
95+
| EfficientDet-lite0, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite0.tgz) | 26.41 | 26.10 | 4.3M | 36ms |
96+
| EfficientDet-lite1, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite1.tgz) | 31.50 | 31.12 | 5.8M | 49ms |
97+
| EfficientDet-lite2, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite2.tgz) | 35.06 | 34.69 | 7.2M | 69ms |
98+
| EfficientDet-lite3, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite3.tgz) | 38.77 | 38.42 | 11M | 116ms |
99+
| EfficientDet-lite3x, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite3x..gz) | 42.64 | 41.87 | 12M | 208ms |
100+
| EfficientDet-lite4, [ckpt](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/efficientdet-lite4.tgz) | 43.18 | 42.83 | 20M | 260ms |
101101

102102

103103
## 3. Export SavedModel, frozen graph, tensort models, or tflite.
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2021 Google Research. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
r"""Run TF Lite model."""
16+
from absl import app
17+
from absl import flags
18+
19+
from PIL import Image
20+
import tensorflow as tf
21+
22+
from tensorflow_examples.lite.model_maker.third_party.efficientdet import inference
23+
24+
FLAGS = flags.FLAGS
25+
26+
27+
def define_flags():
28+
"""Define flags."""
29+
flags.DEFINE_string('tflite_path', None, 'Path of tflite file.')
30+
flags.DEFINE_string('sample_image', None, 'Sample image path')
31+
flags.DEFINE_string('output_image', None, 'Output image path')
32+
flags.DEFINE_string('image_size', '512x512', 'Image size "WxH".')
33+
34+
35+
def load_image(image_path, image_size):
36+
"""Loads an image, and returns numpy.ndarray.
37+
38+
Args:
39+
image_path: str, path to image.
40+
image_size: list of int, representing [width, height].
41+
42+
Returns:
43+
image_batch: numpy.ndarray of shape [1, H, W, C].
44+
"""
45+
input_data = tf.io.gfile.GFile(image_path, 'rb').read()
46+
image = tf.io.decode_image(input_data, channels=3, dtype=tf.uint8)
47+
image = tf.image.resize(
48+
image, image_size, method='bilinear', antialias=True)
49+
return tf.expand_dims(tf.cast(image, tf.uint8), 0).numpy()
50+
51+
52+
def save_visualized_image(image, prediction, output_path):
53+
"""Saves the visualized image with prediction.
54+
55+
Args:
56+
image: numpy.ndarray of shape [H, W, C].
57+
prediction: numpy.ndarray of shape [num_predictions, 7].
58+
output_path: str, output image path.
59+
"""
60+
output_image = inference.visualize_image_prediction(
61+
image,
62+
prediction,
63+
label_map='coco')
64+
Image.fromarray(output_image).save(output_path)
65+
66+
67+
class TFLiteRunner:
68+
"""Wrapper to run TFLite model."""
69+
70+
def __init__(self, model_path):
71+
"""Init.
72+
73+
Args:
74+
model_path: str, path to tflite model.
75+
"""
76+
self.interpreter = tf.lite.Interpreter(model_path=model_path)
77+
self.interpreter.allocate_tensors()
78+
self.input_index = self.interpreter.get_input_details()[0]['index']
79+
self.output_index = self.interpreter.get_output_details()[0]['index']
80+
81+
def run(self, image):
82+
"""Run inference on a single images.
83+
84+
Args:
85+
image: numpy.ndarray of shape [1, H, W, C].
86+
87+
Returns:
88+
prediction: numpy.ndarray of shape [1, num_detections, 7].
89+
"""
90+
self.interpreter.set_tensor(self.input_index, image)
91+
self.interpreter.invoke()
92+
return self.interpreter.get_tensor(self.output_index)
93+
94+
95+
def main(_):
96+
image_size = [int(dim) for dim in FLAGS.image_size.split('x')]
97+
image = load_image(FLAGS.sample_image, image_size)
98+
99+
runner = TFLiteRunner(FLAGS.tflite_path)
100+
prediction = runner.run(image)
101+
102+
save_visualized_image(image[0], prediction[0], FLAGS.output_image)
103+
104+
105+
if __name__ == '__main__':
106+
define_flags()
107+
app.run(main)

0 commit comments

Comments
 (0)