from tf_agents.trajectories import trajectory

from compiler_opt.distributed import worker
from compiler_opt.distributed.local import buffered_scheduler
from compiler_opt.rl import compilation_runner
from compiler_opt.rl import corpus
from compiler_opt.rl import data_collector
@@ -55,7 +56,7 @@ def __init__(
5556 # with the training phase - i.e. whatever happens between successive data
5657 # collection calls. Subsequent runs will wait for these to finish.
5758 self ._reset_workers : Optional [concurrent .futures .Future ] = None
58- self ._current_work : List [Tuple [ corpus . ModuleSpec , worker .WorkerFuture ] ] = []
59+ self ._current_futures : List [worker .WorkerFuture ] = []
5960 self ._pool = concurrent .futures .ThreadPoolExecutor ()
6061
6162 def close_pool (self ):
@@ -85,12 +86,15 @@ def _schedule_jobs(
8586 jobs = [(module_spec , policy_path , self ._reward_stat_map [module_spec .name ])
8687 for module_spec in sampled_modules ]
8788
88- # TODO: Issue #91. Naive load balancing.
89- ret = []
90- for i in range (len (jobs )):
91- ret .append (self ._worker_pool [i % len (self ._worker_pool )].collect_data (
92- * (jobs [i ])))
93- return ret
89+ def work_factory (job ):
90+
91+ def work (w ):
92+ return w .collect_data (* job )
93+
94+ return work
95+
96+ work = [work_factory (job ) for job in jobs ]
97+ return buffered_scheduler .schedule (work , self ._worker_pool , buffer = 10 )
9498
9599 def collect_data (
96100 self , policy_path : str
@@ -108,22 +112,20 @@ def collect_data(
108112 information is viewable in TensorBoard.
109113 """
110114 sampled_modules = self ._corpus .sample (k = self ._num_modules , sort = False )
111- results = self ._schedule_jobs (policy_path , sampled_modules )
115+ self . _current_futures = self ._schedule_jobs (policy_path , sampled_modules )
112116
113117 def wait_for_termination ():
114118 early_exit = self ._exit_checker_ctor (num_modules = self ._num_modules )
115119
116120 def get_num_finished_work ():
117- finished_work = sum (res .done () for res in results )
121+ finished_work = sum (res .done () for res in self . _current_futures )
118122 return finished_work
119123
120124 return early_exit .wait (get_num_finished_work )
121125
122126 wait_seconds = wait_for_termination ()
123- self ._current_work = list (zip (sampled_modules , results ))
124- finished_work = [
125- (spec , res ) for spec , res in self ._current_work if res .done ()
126- ]
127+ current_work = list (zip (sampled_modules , self ._current_futures ))
128+ finished_work = [(spec , res ) for spec , res in current_work if res .done ()]
127129 successful_work = [(spec , res .result ())
128130 for spec , res in finished_work
129131 if not worker .get_exception (res )]
@@ -139,7 +141,7 @@ def wrapup():
139141 # now that the workers killed pending compilations, make sure the workers
140142 # drained their working queues first - they should all complete quickly
141143 # since the cancellation manager is killing immediately any process starts
142- worker .wait_for (results )
144+ worker .wait_for (self . _current_futures )
143145 worker .wait_for ([wkr .enable () for wkr in self ._worker_pool ])
144146
145147 self ._reset_workers = self ._pool .submit (wrapup )
0 commit comments