Skip to content

Commit e6c93ff

Browse files
committed
Initial commit
1 parent 9b5fdda commit e6c93ff

File tree

3 files changed

+136
-0
lines changed

3 files changed

+136
-0
lines changed

app.py

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# app.py
2+
#
3+
# A simple example of hosting a TensorFlow model as a Flask service
4+
#
5+
# Copyright 2017 ActiveState Software Inc.
6+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
7+
#
8+
# Licensed under the Apache License, Version 2.0 (the "License");
9+
# you may not use this file except in compliance with the License.
10+
# You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing, software
15+
# distributed under the License is distributed on an "AS IS" BASIS,
16+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
# See the License for the specific language governing permissions and
18+
# limitations under the License.
19+
# ==============================================================================
20+
21+
from __future__ import absolute_import
22+
from __future__ import division
23+
from __future__ import print_function
24+
25+
import logging
26+
import random
27+
import time
28+
29+
from flask import Flask, jsonify, request
30+
from scipy.misc import imread, imresize
31+
32+
import argparse
33+
import sys
34+
35+
import numpy as np
36+
import tensorflow as tf
37+
38+
app = Flask(__name__)
39+
40+
def load_graph(model_file):
41+
graph = tf.Graph()
42+
graph_def = tf.GraphDef()
43+
44+
with open(model_file, "rb") as f:
45+
graph_def.ParseFromString(f.read())
46+
with graph.as_default():
47+
tf.import_graph_def(graph_def)
48+
49+
return graph
50+
51+
def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
52+
input_mean=0, input_std=255):
53+
input_name = "file_reader"
54+
output_name = "normalized"
55+
file_reader = tf.read_file(file_name, input_name)
56+
if file_name.endswith(".png"):
57+
image_reader = tf.image.decode_png(file_reader, channels = 3,
58+
name='png_reader')
59+
elif file_name.endswith(".gif"):
60+
image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
61+
name='gif_reader'))
62+
elif file_name.endswith(".bmp"):
63+
image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
64+
else:
65+
image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
66+
name='jpeg_reader')
67+
float_caster = tf.cast(image_reader, tf.float32)
68+
dims_expander = tf.expand_dims(float_caster, 0);
69+
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
70+
normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
71+
sess = tf.Session()
72+
result = sess.run(normalized)
73+
74+
return result
75+
76+
def load_labels(label_file):
77+
label = []
78+
proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
79+
for l in proto_as_ascii_lines:
80+
label.append(l.rstrip())
81+
return label
82+
83+
@app.route('/')
84+
def classify():
85+
file_name = request.args['file']
86+
87+
t = read_tensor_from_image_file(file_name,
88+
input_height=input_height,
89+
input_width=input_width,
90+
input_mean=input_mean,
91+
input_std=input_std)
92+
93+
with tf.Session(graph=graph) as sess:
94+
start = time.time()
95+
results = sess.run(output_operation.outputs[0],
96+
{input_operation.outputs[0]: t})
97+
end=time.time()
98+
results = np.squeeze(results)
99+
100+
top_k = results.argsort()[-5:][::-1]
101+
labels = load_labels(label_file)
102+
103+
print('\nEvaluation time (1-image): {:.3f}s\n'.format(end-start))
104+
105+
for i in top_k:
106+
print(labels[i], results[i])
107+
108+
return jsonify(labels,results.tolist())
109+
110+
if __name__ == '__main__':
111+
# TensorFlow configuration/initialization
112+
model_file = "retrained_graph.pb"
113+
label_file = "retrained_labels.txt"
114+
input_height = 224
115+
input_width = 224
116+
input_mean = 128
117+
input_std = 128
118+
input_layer = "input"
119+
output_layer = "final_result"
120+
121+
# Load TensorFlow Graph from disk
122+
graph = load_graph(model_file)
123+
124+
# Grab the Input/Output operations
125+
input_name = "import/" + input_layer
126+
output_name = "import/" + output_layer
127+
input_operation = graph.get_operation_by_name(input_name);
128+
output_operation = graph.get_operation_by_name(output_name);
129+
130+
# Initialize the Flask Service
131+
# Obviously, disable Debug in actual Production
132+
app.run(debug=True, port=8000)
133+

retrained_graph.pb

5.23 MB
Binary file not shown.

retrained_labels.txt

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
poodle
2+
pug
3+
dachshund

0 commit comments

Comments
 (0)