@@ -69,6 +69,47 @@ def _get_non_identity_op(tensor):
6969 return tensor
7070
7171
72+ def convert_saved_model (sm_dir : str , tflite_model_path : str ):
73+ """Convert a saved model to tflite.
74+
75+ Args:
76+ sm_dir: path to the saved model to convert
77+
78+ tflite_model_path: desired output file path. Directory structure will
79+ be created by this function, as needed.
80+ """
81+ tf .io .gfile .makedirs (os .path .dirname (tflite_model_path ))
82+ converter = tf .lite .TFLiteConverter .from_saved_model (sm_dir )
83+ converter .target_spec .supported_ops = [
84+ tf .lite .OpsSet .TFLITE_BUILTINS ,
85+ ]
86+ converter .allow_custom_ops = True
87+ tfl_model = converter .convert ()
88+ with tf .io .gfile .GFile (tflite_model_path , 'wb' ) as f :
89+ f .write (tfl_model )
90+
91+
92+ def convert_mlgo_model (mlgo_model_dir : str , tflite_model_dir : str ):
93+ """Convert a mlgo saved model to mlgo tflite.
94+
95+ Args:
96+ mlgo_model_dir: path to the mlgo saved model dir. It is expected to contain
97+ the saved model files (i.e. saved_model.pb, the variables dir) and the
98+ output_spec.json file
99+
100+ tflite_model_dir: path to a directory where the tflite model will be placed.
101+ The model will be named model.tflite. Alongside it will be placed a copy
102+ of the output_spec.json file.
103+ """
104+ tf .io .gfile .makedirs (tflite_model_dir )
105+ convert_saved_model (mlgo_model_dir ,
106+ os .path .join (tflite_model_dir , TFLITE_MODEL_NAME ))
107+
108+ src_json = os .path .join (mlgo_model_dir , OUTPUT_SIGNATURE )
109+ dest_json = os .path .join (tflite_model_dir , OUTPUT_SIGNATURE )
110+ tf .io .gfile .copy (src_json , dest_json )
111+
112+
72113class PolicySaver (object ):
73114 """Object that saves policy and model config file required by inference.
74115
@@ -157,46 +198,11 @@ def _write_output_signature(self, saver, path):
157198 def save (self , root_dir : str ):
158199 """Writes policy and model_binding.txt to root_dir/policy_name/."""
159200 for policy_name , (saver , _ ) in self ._policy_saver_dict .items ():
160- self ._save_policy (saver , os .path .join (root_dir , policy_name ))
161- self ._write_output_signature (saver , os .path .join (root_dir , policy_name ))
162-
163-
164- def convert_saved_model (sm_dir : str , tflite_model_path : str ):
165- """Convert a saved model to tflite.
166-
167- Args:
168- sm_dir: path to the saved model to convert
169-
170- tflite_model_path: desired output file path. Directory structure will
171- be created by this function, as needed.
172- """
173- tf .io .gfile .makedirs (os .path .dirname (tflite_model_path ))
174- converter = tf .lite .TFLiteConverter .from_saved_model (sm_dir )
175- converter .target_spec .supported_ops = [
176- tf .lite .OpsSet .TFLITE_BUILTINS ,
177- ]
178- tfl_model = converter .convert ()
179- with tf .io .gfile .GFile (tflite_model_path , 'wb' ) as f :
180- f .write (tfl_model )
181-
182-
183- def convert_mlgo_model (mlgo_model_dir : str , tflite_model_dir : str ):
184- """Convert a mlgo saved model to mlgo tflite.
185-
186- Args:
187- mlgo_model_dir: path to the mlgo saved model dir. It is expected to contain
188- the saved model files (i.e. saved_model.pb, the variables dir) and the
189- output_spec.json file
190-
191- tflite_model_dir: path to a directory where the tflite model will be placed.
192- The model will be named model.tflite. Alongside it will be placed a copy of
193- the output_spec.json file.
194- """
195- tf .io .gfile .makedirs (tflite_model_dir )
196- convert_saved_model (mlgo_model_dir ,
197- os .path .join (tflite_model_dir , TFLITE_MODEL_NAME ))
198-
199- json_file = 'output_spec.json'
200- src_json = os .path .join (mlgo_model_dir , json_file )
201- dest_json = os .path .join (tflite_model_dir , json_file )
202- tf .io .gfile .copy (src_json , dest_json )
201+ saved_model_dir = os .path .join (root_dir , policy_name )
202+ self ._save_policy (saver , saved_model_dir )
203+ self ._write_output_signature (saver , saved_model_dir )
204+ # This is not quite the most efficient way to do this - we save the model
205+ # just to load it again and save it as tflite - but it's the minimum,
206+ # temporary step so we can validate more thoroughly our use of tflite.
207+ convert_saved_model (saved_model_dir ,
208+ os .path .join (saved_model_dir , TFLITE_MODEL_NAME ))
0 commit comments