@@ -219,6 +219,7 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
219219 ]
220220 object .__setattr__ (self , 'length' , sum (lengths ))
221221
222+ # TODO: is it necessary to return keys AND reward_stats (which already contains the keys)?
222223 assert (len (self .serialized_sequence_examples ) == len (self .rewards ) ==
223224 (len (self .keys )))
224225 assert set (self .keys ) == set (self .reward_stats .keys ())
@@ -228,6 +229,14 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
228229class CompilationRunnerStub (metaclass = abc .ABCMeta ):
229230 """The interface of a stub to CompilationRunner, for type checkers."""
230231
232+ @abc .abstractmethod
233+ def collect_results (self ,
234+ module_spec : corpus .ModuleSpec ,
235+ tf_policy_path : str ,
236+ collect_default_result : bool ,
237+ reward_only : bool = False ) -> Tuple [Dict , Dict ]:
238+ raise NotImplementedError ()
239+
231240 @abc .abstractmethod
232241 def collect_data (
233242 self , module_spec : corpus .ModuleSpec , tf_policy_path : str ,
@@ -275,6 +284,47 @@ def enable(self):
275284 def cancel_all_work (self ):
276285 self ._cancellation_manager .kill_all_processes ()
277286
287+ @staticmethod
288+ def get_rewards (result : Dict ) -> List [float ]:
289+ if len (result ) == 0 :
290+ return []
291+ return [v [1 ] for v in result .values ()]
292+
293+ def collect_results (self ,
294+ module_spec : corpus .ModuleSpec ,
295+ tf_policy_path : str ,
296+ collect_default_result : bool ,
297+ reward_only : bool = False ) -> Tuple [Dict , Dict ]:
298+ """Collect data for the given IR file and policy.
299+
300+ Args:
301+ module_spec: a ModuleSpec.
302+ tf_policy_path: path to the tensorflow policy.
303+ collect_default_result: whether to get the default result as well.
304+ reward_only: whether to only collect the rewards in the results.
305+
306+ Returns:
307+ A tuple of the default result and policy result.
308+ """
309+ default_result = None
310+ policy_result = None
311+ if collect_default_result :
312+ default_result = self ._compile_fn (
313+ module_spec ,
314+ tf_policy_path = '' ,
315+ reward_only = bool (tf_policy_path ) or reward_only ,
316+ cancellation_manager = self ._cancellation_manager )
317+ policy_result = default_result
318+
319+ if tf_policy_path :
320+ policy_result = self ._compile_fn (
321+ module_spec ,
322+ tf_policy_path ,
323+ reward_only = reward_only ,
324+ cancellation_manager = self ._cancellation_manager )
325+
326+ return default_result , policy_result
327+
278328 def collect_data (
279329 self , module_spec : corpus .ModuleSpec , tf_policy_path : str ,
280330 reward_stat : Optional [Dict [str , RewardStat ]]) -> CompilationResult :
@@ -284,8 +334,6 @@ def collect_data(
284334 module_spec: a ModuleSpec.
285335 tf_policy_path: path to the tensorflow policy.
286336 reward_stat: reward stat of this module, None if unknown.
287- cancellation_token: a CancellationToken through which workers may be
288- signaled early termination
289337
290338 Returns:
291339 A CompilationResult. In particular:
@@ -297,25 +345,18 @@ def collect_data(
297345 compilation_runner.ProcessKilledException is passed through.
298346 ValueError if example under default policy and ml policy does not match.
299347 """
348+ default_result , policy_result = self .collect_results (
349+ module_spec ,
350+ tf_policy_path ,
351+ collect_default_result = reward_stat is None ,
352+ reward_only = False )
300353 if reward_stat is None :
301- default_result = self ._compile_fn (
302- module_spec ,
303- tf_policy_path = '' ,
304- reward_only = bool (tf_policy_path ),
305- cancellation_manager = self ._cancellation_manager )
354+ # TODO: Add structure to default_result and policy_result.
355+ # get_rewards above should be updated/removed when this is resolved.
306356 reward_stat = {
307357 k : RewardStat (v [1 ], v [1 ]) for (k , v ) in default_result .items ()
308358 }
309359
310- if tf_policy_path :
311- policy_result = self ._compile_fn (
312- module_spec ,
313- tf_policy_path ,
314- reward_only = False ,
315- cancellation_manager = self ._cancellation_manager )
316- else :
317- policy_result = default_result
318-
319360 sequence_example_list = []
320361 rewards = []
321362 keys = []
0 commit comments