@@ -57,6 +57,50 @@ def test_load_and_set(dummy_config, use_discrete):
5757 np .testing .assert_array_equal (w , lw )
5858
5959
60+ def test_resume (dummy_config , tmp_path ):
61+ mock_specs = mb .setup_test_behavior_specs (
62+ True , False , vector_action_space = [2 ], vector_obs_space = 1
63+ )
64+ behavior_id_team0 = "test_brain?team=0"
65+ behavior_id_team1 = "test_brain?team=1"
66+ brain_name = BehaviorIdentifiers .from_name_behavior_id (behavior_id_team0 ).brain_name
67+ tmp_path = tmp_path .as_posix ()
68+ ppo_trainer = PPOTrainer (brain_name , 0 , dummy_config , True , False , 0 , tmp_path )
69+ controller = GhostController (100 )
70+ trainer = GhostTrainer (
71+ ppo_trainer , brain_name , controller , 0 , dummy_config , True , tmp_path
72+ )
73+
74+ parsed_behavior_id0 = BehaviorIdentifiers .from_name_behavior_id (behavior_id_team0 )
75+ policy = trainer .create_policy (parsed_behavior_id0 , mock_specs )
76+ trainer .add_policy (parsed_behavior_id0 , policy )
77+
78+ parsed_behavior_id1 = BehaviorIdentifiers .from_name_behavior_id (behavior_id_team1 )
79+ policy = trainer .create_policy (parsed_behavior_id1 , mock_specs )
80+ trainer .add_policy (parsed_behavior_id1 , policy )
81+
82+ trainer .save_model ()
83+
84+ # Make a new trainer, check that the policies are the same
85+ ppo_trainer2 = PPOTrainer (brain_name , 0 , dummy_config , True , True , 0 , tmp_path )
86+ trainer2 = GhostTrainer (
87+ ppo_trainer2 , brain_name , controller , 0 , dummy_config , True , tmp_path
88+ )
89+ policy = trainer2 .create_policy (parsed_behavior_id0 , mock_specs )
90+ trainer2 .add_policy (parsed_behavior_id0 , policy )
91+
92+ policy = trainer2 .create_policy (parsed_behavior_id1 , mock_specs )
93+ trainer2 .add_policy (parsed_behavior_id1 , policy )
94+
95+ trainer1_policy = trainer .get_policy (parsed_behavior_id1 .behavior_id )
96+ trainer2_policy = trainer2 .get_policy (parsed_behavior_id1 .behavior_id )
97+ weights = trainer1_policy .get_weights ()
98+ weights2 = trainer2_policy .get_weights ()
99+
100+ for w , lw in zip (weights , weights2 ):
101+ np .testing .assert_array_equal (w , lw )
102+
103+
60104def test_process_trajectory (dummy_config ):
61105 mock_specs = mb .setup_test_behavior_specs (
62106 True , False , vector_action_space = [2 ], vector_obs_space = 1
0 commit comments