3434from multiprocessing .managers import BaseManager
3535import subprocess
3636from rl_coach .graph_managers .graph_manager import HumanPlayScheduleParameters , GraphManager
37- from rl_coach .utils import list_all_presets , short_dynamic_import , get_open_port , SharedMemoryScratchPad , get_base_dir
37+ from rl_coach .utils import list_all_presets , short_dynamic_import , get_open_port , SharedMemoryScratchPad , \
38+ get_base_dir , start_multi_threaded_learning
3839from rl_coach .graph_managers .basic_rl_graph_manager import BasicRLGraphManager
3940from rl_coach .environments .environment import SingleLevelSelection
4041from rl_coach .memories .backend .redis import RedisPubSubMemoryBackendParameters
@@ -87,6 +88,17 @@ def start_graph(graph_manager: 'GraphManager', task_parameters: 'TaskParameters'
8788
8889def handle_distributed_coach_tasks (graph_manager , args , task_parameters ):
8990 ckpt_inside_container = "/checkpoint"
91+ non_dist_task_parameters = TaskParameters (
92+ framework_type = args .framework ,
93+ evaluate_only = args .evaluate ,
94+ experiment_path = args .experiment_path ,
95+ seed = args .seed ,
96+ use_cpu = args .use_cpu ,
97+ checkpoint_save_secs = args .checkpoint_save_secs ,
98+ checkpoint_save_dir = args .checkpoint_save_dir ,
99+ export_onnx_graph = args .export_onnx_graph ,
100+ apply_stop_condition = args .apply_stop_condition
101+ )
90102
91103 memory_backend_params = None
92104 if args .memory_backend_params :
@@ -102,15 +114,18 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
102114 graph_manager .data_store_params = data_store_params
103115
104116 if args .distributed_coach_run_type == RunType .TRAINER :
117+ if not args .distributed_training :
118+ task_parameters = non_dist_task_parameters
105119 task_parameters .checkpoint_save_dir = ckpt_inside_container
106120 training_worker (
107121 graph_manager = graph_manager ,
108122 task_parameters = task_parameters ,
123+ args = args ,
109124 is_multi_node_test = args .is_multi_node_test
110125 )
111126
112127 if args .distributed_coach_run_type == RunType .ROLLOUT_WORKER :
113- task_parameters .checkpoint_restore_dir = ckpt_inside_container
128+ non_dist_task_parameters .checkpoint_restore_dir = ckpt_inside_container
114129
115130 data_store = None
116131 if args .data_store_params :
@@ -120,7 +135,7 @@ def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
120135 graph_manager = graph_manager ,
121136 data_store = data_store ,
122137 num_workers = args .num_workers ,
123- task_parameters = task_parameters
138+ task_parameters = non_dist_task_parameters
124139 )
125140
126141
@@ -552,6 +567,11 @@ def get_argument_parser(self) -> argparse.ArgumentParser:
552567 parser .add_argument ('-dc' , '--distributed_coach' ,
553568 help = "(flag) Use distributed Coach." ,
554569 action = 'store_true' )
570+ parser .add_argument ('-dt' , '--distributed_training' ,
571+ help = "(flag) Use distributed training with Coach."
572+ "Used only with --distributed_coach flag."
573+ "Ignored if --distributed_coach flag is not used." ,
574+ action = 'store_true' )
555575 parser .add_argument ('-dcp' , '--distributed_coach_config_path' ,
556576 help = "(string) Path to config file when using distributed rollout workers."
557577 "Only distributed Coach parameters should be provided through this config file."
@@ -607,18 +627,31 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namesp
607627 atexit .register (logger .summarize_experiment )
608628 screen .change_terminal_title (args .experiment_name )
609629
610- task_parameters = TaskParameters (
611- framework_type = args .framework ,
612- evaluate_only = args .evaluate ,
613- experiment_path = args .experiment_path ,
614- seed = args .seed ,
615- use_cpu = args .use_cpu ,
616- checkpoint_save_secs = args .checkpoint_save_secs ,
617- checkpoint_restore_dir = args .checkpoint_restore_dir ,
618- checkpoint_save_dir = args .checkpoint_save_dir ,
619- export_onnx_graph = args .export_onnx_graph ,
620- apply_stop_condition = args .apply_stop_condition
621- )
630+ if args .num_workers == 1 :
631+ task_parameters = TaskParameters (
632+ framework_type = args .framework ,
633+ evaluate_only = args .evaluate ,
634+ experiment_path = args .experiment_path ,
635+ seed = args .seed ,
636+ use_cpu = args .use_cpu ,
637+ checkpoint_save_secs = args .checkpoint_save_secs ,
638+ checkpoint_restore_dir = args .checkpoint_restore_dir ,
639+ checkpoint_save_dir = args .checkpoint_save_dir ,
640+ export_onnx_graph = args .export_onnx_graph ,
641+ apply_stop_condition = args .apply_stop_condition
642+ )
643+ else :
644+ task_parameters = DistributedTaskParameters (
645+ framework_type = args .framework ,
646+ use_cpu = args .use_cpu ,
647+ num_training_tasks = args .num_workers ,
648+ experiment_path = args .experiment_path ,
649+ checkpoint_save_secs = args .checkpoint_save_secs ,
650+ checkpoint_restore_dir = args .checkpoint_restore_dir ,
651+ checkpoint_save_dir = args .checkpoint_save_dir ,
652+ export_onnx_graph = args .export_onnx_graph ,
653+ apply_stop_condition = args .apply_stop_condition
654+ )
622655
623656 # open dashboard
624657 if args .open_dashboard :
@@ -633,78 +666,16 @@ def run_graph_manager(self, graph_manager: 'GraphManager', args: argparse.Namesp
633666
634667 # Single-threaded runs
635668 if args .num_workers == 1 :
636- self .start_single_threaded (task_parameters , graph_manager , args )
669+ self .start_single_threaded_learning (task_parameters , graph_manager , args )
637670 else :
638- self .start_multi_threaded (graph_manager , args )
671+ global start_graph
672+ start_multi_threaded_learning (start_graph , (graph_manager , task_parameters ),
673+ task_parameters , graph_manager , args )
639674
640- def start_single_threaded (self , task_parameters , graph_manager : 'GraphManager' , args : argparse .Namespace ):
675+ def start_single_threaded_learning (self , task_parameters , graph_manager : 'GraphManager' , args : argparse .Namespace ):
641676 # Start the training or evaluation
642677 start_graph (graph_manager = graph_manager , task_parameters = task_parameters )
643678
    def start_multi_threaded(self, graph_manager: 'GraphManager', args: argparse.Namespace):
        """Launch a multi-process distributed run on the local machine.

        Spawns one parameter-server process, ``args.num_workers`` training
        worker processes, and (optionally) one evaluation worker, each running
        the module-level ``start_graph`` with its own DistributedTaskParameters.
        All processes rendezvous via localhost host:port lists and share state
        through a multiprocessing BaseManager scratchpad.
        """
        total_tasks = args.num_workers
        # The evaluation worker occupies one extra task slot beyond the trainers.
        if args.evaluation_worker:
            total_tasks += 1

        # One ps endpoint plus one endpoint per task, all on localhost with
        # dynamically chosen free ports.
        ps_hosts = "localhost:{}".format(get_open_port())
        worker_hosts = ",".join(["localhost:{}".format(get_open_port()) for i in range(total_tasks)])

        # Shared memory
        class CommManager(BaseManager):
            pass
        CommManager.register('SharedMemoryScratchPad', SharedMemoryScratchPad, exposed=['add', 'get', 'internal_call'])
        comm_manager = CommManager()
        comm_manager.start()
        # Proxy object shared by every spawned process for cross-process state.
        shared_memory_scratchpad = comm_manager.SharedMemoryScratchPad()

        def start_distributed_task(job_type, task_index, evaluation_worker=False,
                                   shared_memory_scratchpad=shared_memory_scratchpad):
            # Build per-process parameters and fork a process running start_graph.
            task_parameters = DistributedTaskParameters(
                framework_type=args.framework,
                parameters_server_hosts=ps_hosts,
                worker_hosts=worker_hosts,
                job_type=job_type,
                task_index=task_index,
                evaluate_only=0 if evaluation_worker else None,  # 0 value for evaluation worker as it should run infinitely
                use_cpu=args.use_cpu,
                num_tasks=total_tasks,  # training tasks + 1 evaluation task
                num_training_tasks=args.num_workers,
                experiment_path=args.experiment_path,
                shared_memory_scratchpad=shared_memory_scratchpad,
                seed=args.seed + task_index if args.seed is not None else None,  # each worker gets a different seed
                checkpoint_save_secs=args.checkpoint_save_secs,
                checkpoint_restore_dir=args.checkpoint_restore_dir,
                checkpoint_save_dir=args.checkpoint_save_dir,
                export_onnx_graph=args.export_onnx_graph,
                apply_stop_condition=args.apply_stop_condition
            )
            # we assume that only the evaluation workers are rendering
            graph_manager.visualization_parameters.render = args.render and evaluation_worker
            p = Process(target=start_graph, args=(graph_manager, task_parameters))
            # p.daemon = True
            p.start()
            return p

        # parameter server
        parameter_server = start_distributed_task("ps", 0)

        # training workers
        # wait a bit before spawning the non chief workers in order to make sure the session is already created
        workers = []
        workers.append(start_distributed_task("worker", 0))
        time.sleep(2)
        for task_index in range(1, args.num_workers):
            workers.append(start_distributed_task("worker", task_index))

        # evaluation worker
        # NOTE(review): spawned when either --evaluation_worker or --render is
        # set, but only terminated below when --evaluation_worker is set — a
        # render-only evaluation process is presumably left to exit on its own.
        if args.evaluation_worker or args.render:
            evaluation_worker = start_distributed_task("worker", args.num_workers, evaluation_worker=True)

        # wait for all workers
        [w.join() for w in workers]
        if args.evaluation_worker:
            evaluation_worker.terminate()
708679
709680def main ():
710681 launcher = CoachLauncher ()
0 commit comments