diff --git a/config/atari/__init__.py b/config/atari/__init__.py
index a49c4d31..4e972ec6 100644
--- a/config/atari/__init__.py
+++ b/config/atari/__init__.py
@@ -19,7 +19,7 @@ def __init__(self):
             checkpoint_interval=100,
             target_model_interval=200,
             save_ckpt_interval=10000,
-            max_moves=108000,
+            max_moves=12000,
             test_max_moves=12000,
             history_length=400,
             discount=0.997,
@@ -84,9 +84,9 @@ def __init__(self):
 
     def visit_softmax_temperature_fn(self, num_moves, trained_steps):
         if self.change_temperature:
-            if trained_steps < 0.5 * (self.training_steps + self.last_steps):
+            if trained_steps < 0.5 * (self.training_steps):
                 return 1.0
-            elif trained_steps < 0.75 * (self.training_steps + self.last_steps):
+            elif trained_steps < 0.75 * (self.training_steps):
                 return 0.5
             else:
                 return 0.25
diff --git a/config/atari/env_wrapper.py b/config/atari/env_wrapper.py
index b3510d12..bb79fa2b 100644
--- a/config/atari/env_wrapper.py
+++ b/config/atari/env_wrapper.py
@@ -21,6 +21,9 @@ def __init__(self, env, discount: float, cvt_string=True):
     def legal_actions(self):
         return [_ for _ in range(self.env.action_space.n)]
 
+    def get_max_episode_steps(self):
+        return self.env.get_max_episode_steps()
+
     def step(self, action):
         observation, reward, done, info = self.env.step(action)
         observation = observation.astype(np.uint8)
diff --git a/core/reanalyze_worker.py b/core/reanalyze_worker.py
index a3f3c233..a32f57f3 100644
--- a/core/reanalyze_worker.py
+++ b/core/reanalyze_worker.py
@@ -344,12 +344,13 @@ def _prepare_reward_value(self, reward_value_context):
         value_lst = value_lst * np.array(value_mask)
         value_lst = value_lst.tolist()
 
-        horizon_id, value_index = 0, 0
+        value_index = 0
         for traj_len_non_re, reward_lst, state_index in zip(traj_lens, rewards_lst, state_index_lst):
             # traj_len = len(game)
             target_values = []
             target_value_prefixs = []
 
+            horizon_id = 0
             value_prefix = 0.0
             base_index = state_index
             for current_index in range(state_index, state_index + self.config.num_unroll_steps + 1):
diff --git a/core/replay_buffer.py b/core/replay_buffer.py
index c7577539..df7844e8 100644
--- a/core/replay_buffer.py
+++ b/core/replay_buffer.py
@@ -32,7 +32,8 @@ def save_pools(self, pools, gap_step):
         for (game, priorities) in pools:
             # Only append end game
             # if end_tag:
-            self.save_game(game, True, gap_step, priorities)
+            if len(game) > 0:
+                self.save_game(game, True, gap_step, priorities)
 
     def save_game(self, game, end_tag, gap_steps, priorities=None):
         """Save a game history block
diff --git a/core/selfplay_worker.py b/core/selfplay_worker.py
index 402b05fe..52810311 100644
--- a/core/selfplay_worker.py
+++ b/core/selfplay_worker.py
@@ -109,7 +109,7 @@ def run(self):
         model.eval()
 
         start_training = False
-        envs = [self.config.new_game(self.config.seed + self.rank * i) for i in range(env_nums)]
+        envs = [self.config.new_game(self.config.seed + (self.rank + 1) * i) for i in range(env_nums)]
 
         def _get_max_entropy(action_space):
             p = 1.0 / action_space
diff --git a/core/test.py b/core/test.py
index fb0f2602..fc1602c6 100644
--- a/core/test.py
+++ b/core/test.py
@@ -28,7 +28,7 @@ def _test(config, shared_storage):
             test_model.set_weights(ray.get(shared_storage.get_weights.remote()))
             test_model.eval()
 
-            test_score, _ = test(config, test_model, counter, config.test_episodes, config.device, False, save_video=False)
+            test_score, eval_steps, _ = test(config, test_model, counter, config.test_episodes, config.device, False, save_video=False)
             mean_score = test_score.mean()
             std_score = test_score.std()
             print('Start evaluation at step {}.'.format(counter))
@@ -44,7 +44,7 @@ def _test(config, shared_storage):
             }
             shared_storage.add_test_log.remote(counter, test_log)
 
-            print('Step {}, test scores: \n{}'.format(counter, test_score))
+            print('Training step {}, test scores: \n{} of {} eval steps.'.format(counter, test_score, eval_steps))
 
         time.sleep(30)
 
@@ -74,20 +74,17 @@ def test(config, model, counter, test_episodes, device, render, save_video=False
     model.eval()
     save_path = os.path.join(config.exp_path, 'recordings', 'step_{}'.format(counter))
 
-    if use_pb:
-        pb = tqdm(np.arange(config.max_moves), leave=True)
-
     with torch.no_grad():
         # new games
         envs = [config.new_game(seed=i, save_video=save_video, save_path=save_path, test=True, final_test=final_test, video_callable=lambda episode_id: True, uid=i) for i in range(test_episodes)]
+        max_episode_steps = envs[0].get_max_episode_steps()
+        if use_pb:
+            pb = tqdm(np.arange(max_episode_steps), leave=True)
 
         # initializations
        init_obses = [env.reset() for env in envs]
         dones = np.array([False for _ in range(test_episodes)])
-        game_histories = [
-            GameHistory(envs[_].env.action_space, max_length=config.max_moves, config=config) for
-            _ in
-            range(test_episodes)]
+        game_histories = [GameHistory(envs[_].env.action_space, max_length=max_episode_steps, config=config) for _ in range(test_episodes)]
         for i in range(test_episodes):
             game_histories[i].init([init_obses[i] for _ in range(config.stacked_observations)])
 
@@ -155,4 +152,4 @@ def test(config, model, counter, test_episodes, device, render, save_video=False
     for env in envs:
         env.close()
 
-    return ep_ori_rewards, save_path
+    return ep_ori_rewards, step, save_path
diff --git a/core/train.py b/core/train.py
index 0cb3a22d..a776a627 100644
--- a/core/train.py
+++ b/core/train.py
@@ -160,9 +160,9 @@ def update_weights(model, batch, optimizer, replay_buffer, config, scaler, vis_r
                 other_loss['consist_' + str(step_i + 1)] = temp_loss.mean().item()
                 consistency_loss += temp_loss
 
-            policy_loss += -(torch.log_softmax(policy_logits, dim=1) * target_policy[:, step_i + 1]).sum(1)
-            value_loss += config.scalar_value_loss(value, target_value_phi[:, step_i + 1])
-            value_prefix_loss += config.scalar_reward_loss(value_prefix, target_value_prefix_phi[:, step_i])
+            policy_loss += -(torch.log_softmax(policy_logits, dim=1) * target_policy[:, step_i + 1]).sum(1) * mask_batch[:, step_i]
+            value_loss += config.scalar_value_loss(value, target_value_phi[:, step_i + 1]) * mask_batch[:, step_i]
+            value_prefix_loss += config.scalar_reward_loss(value_prefix, target_value_prefix_phi[:, step_i]) * mask_batch[:, step_i]
             # Follow MuZero, set half gradient
             hidden_state.register_hook(lambda grad: grad * 0.5)
 
@@ -215,9 +215,9 @@ def update_weights(model, batch, optimizer, replay_buffer, config, scaler, vis_r
                 other_loss['consist_' + str(step_i + 1)] = temp_loss.mean().item()
                 consistency_loss += temp_loss
 
-            policy_loss += -(torch.log_softmax(policy_logits, dim=1) * target_policy[:, step_i + 1]).sum(1)
-            value_loss += config.scalar_value_loss(value, target_value_phi[:, step_i + 1])
-            value_prefix_loss += config.scalar_reward_loss(value_prefix, target_value_prefix_phi[:, step_i])
+            policy_loss += -(torch.log_softmax(policy_logits, dim=1) * target_policy[:, step_i + 1]).sum(1) * mask_batch[:, step_i]
+            value_loss += config.scalar_value_loss(value, target_value_phi[:, step_i + 1]) * mask_batch[:, step_i]
+            value_prefix_loss += config.scalar_reward_loss(value_prefix, target_value_prefix_phi[:, step_i]) * mask_batch[:, step_i]
             # Follow MuZero, set half gradient
             hidden_state.register_hook(lambda grad: grad * 0.5)
 
diff --git a/core/utils.py b/core/utils.py
index 9dcff956..de34e19a 100644
--- a/core/utils.py
+++ b/core/utils.py
@@ -50,6 +50,9 @@ def step(self, ac):
             info['TimeLimit.truncated'] = True
         return observation, reward, done, info
 
+    def get_max_episode_steps(self):
+        return self._max_episode_steps
+
     def reset(self, **kwargs):
         self._elapsed_steps = 0
         return self.env.reset(**kwargs)
@@ -291,6 +294,8 @@ def select_action(visit_counts, temperature=1, deterministic=True):
     total_count = sum(action_probs)
     action_probs = [x / total_count for x in action_probs]
     if deterministic:
+        # best_actions = np.argwhere(visit_counts == np.amax(visit_counts)).flatten()
+        # action_pos = np.random.choice(best_actions)
         action_pos = np.argmax([v for v in visit_counts])
     else:
         action_pos = np.random.choice(len(visit_counts), p=action_probs)
diff --git a/main.py b/main.py
index 429933c8..ffde00de 100644
--- a/main.py
+++ b/main.py
@@ -32,7 +32,7 @@
                         help='Overrides past results (default: %(default)s)')
     parser.add_argument('--cpu_actor', type=int, default=14, help='batch cpu actor')
     parser.add_argument('--gpu_actor', type=int, default=20, help='batch bpu actor')
-    parser.add_argument('--p_mcts_num', type=int, default=8, help='number of parallel mcts')
+    parser.add_argument('--p_mcts_num', type=int, default=4, help='number of parallel mcts')
    parser.add_argument('--seed', type=int, default=0, help='seed (default: %(default)s)')
     parser.add_argument('--num_gpus', type=int, default=4, help='gpus available')
     parser.add_argument('--num_cpus', type=int, default=80, help='cpus available')
@@ -95,8 +95,7 @@
         model, weights = train(game_config, summary_writer, model_path)
         model.set_weights(weights)
         total_steps = game_config.training_steps + game_config.last_steps
-        test_score, test_path = test(game_config, model.to(device), total_steps, game_config.test_episodes,
-                                     device, render=False, save_video=args.save_video, final_test=True, use_pb=True)
+        test_score, _, test_path = test(game_config, model.to(device), total_steps, game_config.test_episodes, device, render=False, save_video=args.save_video, final_test=True, use_pb=True)
         mean_score = test_score.mean()
         std_score = test_score.std()
 
@@ -122,8 +121,7 @@
 
         model = game_config.get_uniform_network().to(device)
         model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
-        test_score, test_path = test(game_config, model, 0, args.test_episodes, device=device, render=args.render,
-                                     save_video=args.save_video, final_test=True, use_pb=True)
+        test_score, _, test_path = test(game_config, model, 0, args.test_episodes, device=device, render=args.render, save_video=args.save_video, final_test=True, use_pb=True)
         mean_score = test_score.mean()
         std_score = test_score.std()
         logging.getLogger('test').info('Test Mean Score: {} (max: {}, min: {})'.format(mean_score, test_score.max(), test_score.min()))
diff --git a/requirements.txt b/requirements.txt
index 51695dfb..a5c89b7f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,7 @@
 numpy==1.19.5
 ray==1.0.0
-gym==0.15.7
-atari-py==0.2.6
+gym[atari,roms,accept-rom-license]==0.15.7
 cython==0.29.23
 tensorboard
-opencv-python
-kornia
+opencv-python==4.5.1.48
+kornia==0.6.6
diff --git a/train.sh b/train.sh
index 8b5d6801..b0c984db 100644
--- a/train.sh
+++ b/train.sh
@@ -5,6 +5,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
 python main.py --env BreakoutNoFrameskip-v4 --case atari --opr train --force \
   --num_gpus 4 --num_cpus 96 --cpu_actor 14 --gpu_actor 20 \
   --seed 0 \
+  --p_mcts_num 4 \
   --use_priority \
   --use_max_priority \
   --amp_type 'torch_amp' \