6 changes: 3 additions & 3 deletions config/atari/__init__.py
@@ -19,7 +19,7 @@ def __init__(self):
checkpoint_interval=100,
target_model_interval=200,
save_ckpt_interval=10000,
- max_moves=108000,
+ max_moves=12000,
test_max_moves=12000,
history_length=400,
discount=0.997,
@@ -84,9 +84,9 @@ def __init__(self):

def visit_softmax_temperature_fn(self, num_moves, trained_steps):
if self.change_temperature:
- if trained_steps < 0.5 * (self.training_steps + self.last_steps):
+ if trained_steps < 0.5 * (self.training_steps):
return 1.0
- elif trained_steps < 0.75 * (self.training_steps + self.last_steps):
+ elif trained_steps < 0.75 * (self.training_steps):
return 0.5
else:
return 0.25
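A quick numeric check of the relaxed temperature schedule (the step counts below are illustrative, not taken from the config): dropping `last_steps` from the thresholds makes the temperature decay kick in earlier, relative to `training_steps` alone.

```python
training_steps, last_steps = 100_000, 20_000   # hypothetical values

old_drop_to_half = 0.5 * (training_steps + last_steps)      # 60_000 trained steps
new_drop_to_half = 0.5 * training_steps                     # 50_000 trained steps

old_drop_to_quarter = 0.75 * (training_steps + last_steps)  # 90_000 trained steps
new_drop_to_quarter = 0.75 * training_steps                 # 75_000 trained steps
```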
3 changes: 3 additions & 0 deletions config/atari/env_wrapper.py
@@ -21,6 +21,9 @@ def __init__(self, env, discount: float, cvt_string=True):
def legal_actions(self):
return [_ for _ in range(self.env.action_space.n)]

+ def get_max_episode_steps(self):
+     return self.env.get_max_episode_steps()

def step(self, action):
observation, reward, done, info = self.env.step(action)
observation = observation.astype(np.uint8)
3 changes: 2 additions & 1 deletion core/reanalyze_worker.py
@@ -344,12 +344,13 @@ def _prepare_reward_value(self, reward_value_context):
value_lst = value_lst * np.array(value_mask)
value_lst = value_lst.tolist()

- horizon_id, value_index = 0, 0
+ value_index = 0
for traj_len_non_re, reward_lst, state_index in zip(traj_lens, rewards_lst, state_index_lst):
# traj_len = len(game)
target_values = []
target_value_prefixs = []

+ horizon_id = 0
value_prefix = 0.0
base_index = state_index
for current_index in range(state_index, state_index + self.config.num_unroll_steps + 1):
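A toy illustration of the scoping fix above (simplified names and logic, not the worker's actual target computation): `value_index` keeps running across the whole flattened batch, while `horizon_id` has to restart for each trajectory so the periodic value-prefix reset stays aligned with that trajectory's own steps.

```python
def accumulate_prefixes(trajectory_rewards, horizon_len=5):
    """Toy sketch: value_index is global, horizon_id is scoped per trajectory."""
    value_index = 0                      # indexes the flattened list of targets
    all_prefixes = []
    for rewards in trajectory_rewards:
        horizon_id = 0                   # the fix: reset for every trajectory
        value_prefix = 0.0
        prefixes = []
        for r in rewards:
            if horizon_id % horizon_len == 0:
                value_prefix = 0.0       # periodic reset, as with the LSTM horizon
            value_prefix += r
            prefixes.append(value_prefix)
            horizon_id += 1
            value_index += 1
        all_prefixes.append(prefixes)
    return all_prefixes

print(accumulate_prefixes([[1, 1, 1, 1, 1, 1], [2, 2, 2]]))
# [[1, 2, 3, 4, 5, 1], [2, 4, 6]] -- the second trajectory starts a fresh horizon
```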
3 changes: 2 additions & 1 deletion core/replay_buffer.py
@@ -32,7 +32,8 @@ def save_pools(self, pools, gap_step):
for (game, priorities) in pools:
# Only append end game
# if end_tag:
- self.save_game(game, True, gap_step, priorities)
+ if len(game) > 0:
+     self.save_game(game, True, gap_step, priorities)

def save_game(self, game, end_tag, gap_steps, priorities=None):
"""Save a game history block
2 changes: 1 addition & 1 deletion core/selfplay_worker.py
@@ -109,7 +109,7 @@ def run(self):
model.eval()

start_training = False
- envs = [self.config.new_game(self.config.seed + self.rank * i) for i in range(env_nums)]
+ envs = [self.config.new_game(self.config.seed + (self.rank + 1) * i) for i in range(env_nums)]

def _get_max_entropy(action_space):
p = 1.0 / action_space
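For context on the seeding change (illustrative values only): with `seed + self.rank * i`, every environment on rank 0 collapses onto the same seed, which `seed + (self.rank + 1) * i` avoids.

```python
seed, env_nums = 0, 4

old = [[seed + rank * i for i in range(env_nums)] for rank in range(2)]
new = [[seed + (rank + 1) * i for i in range(env_nums)] for rank in range(2)]

print(old)  # [[0, 0, 0, 0], [0, 1, 2, 3]] -> rank 0 reuses one seed for all its envs
print(new)  # [[0, 1, 2, 3], [0, 2, 4, 6]] -> distinct seeds within each rank
```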
17 changes: 7 additions & 10 deletions core/test.py
@@ -28,7 +28,7 @@ def _test(config, shared_storage):
test_model.set_weights(ray.get(shared_storage.get_weights.remote()))
test_model.eval()

- test_score, _ = test(config, test_model, counter, config.test_episodes, config.device, False, save_video=False)
+ test_score, eval_steps, _ = test(config, test_model, counter, config.test_episodes, config.device, False, save_video=False)
mean_score = test_score.mean()
std_score = test_score.std()
print('Start evaluation at step {}.'.format(counter))
@@ -44,7 +44,7 @@ def _test(config, shared_storage):
}

shared_storage.add_test_log.remote(counter, test_log)
- print('Step {}, test scores: \n{}'.format(counter, test_score))
+ print('Training step {}, test scores: \n{} of {} eval steps.'.format(counter, test_score, eval_steps))

time.sleep(30)

@@ -74,20 +74,17 @@ def test(config, model, counter, test_episodes, device, render, save_video=False
model.eval()
save_path = os.path.join(config.exp_path, 'recordings', 'step_{}'.format(counter))

- if use_pb:
-     pb = tqdm(np.arange(config.max_moves), leave=True)

with torch.no_grad():
# new games
envs = [config.new_game(seed=i, save_video=save_video, save_path=save_path, test=True, final_test=final_test,
video_callable=lambda episode_id: True, uid=i) for i in range(test_episodes)]
+ max_episode_steps = envs[0].get_max_episode_steps()
+ if use_pb:
+     pb = tqdm(np.arange(max_episode_steps), leave=True)
# initializations
init_obses = [env.reset() for env in envs]
dones = np.array([False for _ in range(test_episodes)])
- game_histories = [
-     GameHistory(envs[_].env.action_space, max_length=config.max_moves, config=config) for
-     _ in
-     range(test_episodes)]
+ game_histories = [GameHistory(envs[_].env.action_space, max_length=max_episode_steps, config=config) for _ in range(test_episodes)]
for i in range(test_episodes):
game_histories[i].init([init_obses[i] for _ in range(config.stacked_observations)])

@@ -155,4 +152,4 @@ def test(config, model, counter, test_episodes, device, render, save_video=False
for env in envs:
env.close()

- return ep_ori_rewards, save_path
+ return ep_ori_rewards, step, save_path
12 changes: 6 additions & 6 deletions core/train.py
@@ -160,9 +160,9 @@ def update_weights(model, batch, optimizer, replay_buffer, config, scaler, vis_r
other_loss['consist_' + str(step_i + 1)] = temp_loss.mean().item()
consistency_loss += temp_loss

- policy_loss += -(torch.log_softmax(policy_logits, dim=1) * target_policy[:, step_i + 1]).sum(1)
- value_loss += config.scalar_value_loss(value, target_value_phi[:, step_i + 1])
- value_prefix_loss += config.scalar_reward_loss(value_prefix, target_value_prefix_phi[:, step_i])
+ policy_loss += -(torch.log_softmax(policy_logits, dim=1) * target_policy[:, step_i + 1]).sum(1) * mask_batch[:, step_i]
+ value_loss += config.scalar_value_loss(value, target_value_phi[:, step_i + 1]) * mask_batch[:, step_i]
+ value_prefix_loss += config.scalar_reward_loss(value_prefix, target_value_prefix_phi[:, step_i]) * mask_batch[:, step_i]
# Follow MuZero, set half gradient
hidden_state.register_hook(lambda grad: grad * 0.5)

@@ -215,9 +215,9 @@ def update_weights(model, batch, optimizer, replay_buffer, config, scaler, vis_r
other_loss['consist_' + str(step_i + 1)] = temp_loss.mean().item()
consistency_loss += temp_loss

- policy_loss += -(torch.log_softmax(policy_logits, dim=1) * target_policy[:, step_i + 1]).sum(1)
- value_loss += config.scalar_value_loss(value, target_value_phi[:, step_i + 1])
- value_prefix_loss += config.scalar_reward_loss(value_prefix, target_value_prefix_phi[:, step_i])
+ policy_loss += -(torch.log_softmax(policy_logits, dim=1) * target_policy[:, step_i + 1]).sum(1) * mask_batch[:, step_i]
+ value_loss += config.scalar_value_loss(value, target_value_phi[:, step_i + 1]) * mask_batch[:, step_i]
+ value_prefix_loss += config.scalar_reward_loss(value_prefix, target_value_prefix_phi[:, step_i]) * mask_batch[:, step_i]
# Follow MuZero, set half gradient
hidden_state.register_hook(lambda grad: grad * 0.5)

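A minimal, self-contained sketch of why the `* mask_batch[:, step_i]` factor matters (shapes and values are hypothetical): unroll steps that fall past an episode's end are padded, and the mask zeroes their contribution to the accumulated loss.

```python
import torch

batch, unroll_steps = 4, 5

# Hypothetical per-step losses and a 0/1 mask marking which unroll steps are real.
step_losses = torch.rand(batch, unroll_steps)   # stands in for policy/value/value-prefix terms
mask_batch = torch.ones(batch, unroll_steps)
mask_batch[2, 3:] = 0.0                         # sample 2's episode ended after 3 unroll steps

total_loss = torch.zeros(batch)
for step_i in range(unroll_steps):
    # Padded steps contribute nothing, mirroring the `* mask_batch[:, step_i]` factor in the diff.
    total_loss += step_losses[:, step_i] * mask_batch[:, step_i]

print(total_loss.mean())
```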
5 changes: 5 additions & 0 deletions core/utils.py
@@ -50,6 +50,9 @@ def step(self, ac):
info['TimeLimit.truncated'] = True
return observation, reward, done, info

+ def get_max_episode_steps(self):
+     return self._max_episode_steps

def reset(self, **kwargs):
self._elapsed_steps = 0
return self.env.reset(**kwargs)
@@ -291,6 +294,8 @@ def select_action(visit_counts, temperature=1, deterministic=True):
total_count = sum(action_probs)
action_probs = [x / total_count for x in action_probs]
if deterministic:
+ # best_actions = np.argwhere(visit_counts == np.amax(visit_counts)).flatten()
+ # action_pos = np.random.choice(best_actions)
action_pos = np.argmax([v for v in visit_counts])
else:
action_pos = np.random.choice(len(visit_counts), p=action_probs)
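Together with the `env_wrapper.py` change above, the new accessor is a small delegation chain; a stripped-down sketch (the class names here are illustrative stand-ins for the repo's wrappers):

```python
class TimeLimitWrapper:                  # stands in for the step limiter in core/utils.py
    def __init__(self, env, max_episode_steps):
        self.env = env
        self._max_episode_steps = max_episode_steps

    def get_max_episode_steps(self):
        return self._max_episode_steps


class AtariGameWrapper:                  # stands in for config/atari/env_wrapper.py
    def __init__(self, env):
        self.env = env

    def get_max_episode_steps(self):
        # Delegate to the wrapped env, so core/test.py can size its progress bar
        # and GameHistory buffers from the real episode cap instead of config.max_moves.
        return self.env.get_max_episode_steps()


env = AtariGameWrapper(TimeLimitWrapper(env=None, max_episode_steps=12000))
assert env.get_max_episode_steps() == 12000
```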
8 changes: 3 additions & 5 deletions main.py
@@ -32,7 +32,7 @@
help='Overrides past results (default: %(default)s)')
parser.add_argument('--cpu_actor', type=int, default=14, help='batch cpu actor')
parser.add_argument('--gpu_actor', type=int, default=20, help='batch bpu actor')
- parser.add_argument('--p_mcts_num', type=int, default=8, help='number of parallel mcts')
+ parser.add_argument('--p_mcts_num', type=int, default=4, help='number of parallel mcts')
parser.add_argument('--seed', type=int, default=0, help='seed (default: %(default)s)')
parser.add_argument('--num_gpus', type=int, default=4, help='gpus available')
parser.add_argument('--num_cpus', type=int, default=80, help='cpus available')
@@ -95,8 +95,7 @@
model, weights = train(game_config, summary_writer, model_path)
model.set_weights(weights)
total_steps = game_config.training_steps + game_config.last_steps
- test_score, test_path = test(game_config, model.to(device), total_steps, game_config.test_episodes,
-     device, render=False, save_video=args.save_video, final_test=True, use_pb=True)
+ test_score, _, test_path = test(game_config, model.to(device), total_steps, game_config.test_episodes, device, render=False, save_video=args.save_video, final_test=True, use_pb=True)
mean_score = test_score.mean()
std_score = test_score.std()

@@ -122,8 +121,7 @@

model = game_config.get_uniform_network().to(device)
model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
- test_score, test_path = test(game_config, model, 0, args.test_episodes, device=device, render=args.render,
-     save_video=args.save_video, final_test=True, use_pb=True)
+ test_score, _, test_path = test(game_config, model, 0, args.test_episodes, device=device, render=args.render, save_video=args.save_video, final_test=True, use_pb=True)
mean_score = test_score.mean()
std_score = test_score.std()
logging.getLogger('test').info('Test Mean Score: {} (max: {}, min: {})'.format(mean_score, test_score.max(), test_score.min()))
7 changes: 3 additions & 4 deletions requirements.txt
@@ -1,8 +1,7 @@
numpy==1.19.5
ray==1.0.0
- gym==0.15.7
- atari-py==0.2.6
+ gym[atari,roms,accept-rom-license]==0.15.7
cython==0.29.23
tensorboard
- opencv-python
- kornia
+ opencv-python==4.5.1.48
+ kornia==0.6.6
1 change: 1 addition & 0 deletions train.sh
@@ -5,6 +5,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
python main.py --env BreakoutNoFrameskip-v4 --case atari --opr train --force \
--num_gpus 4 --num_cpus 96 --cpu_actor 14 --gpu_actor 20 \
--seed 0 \
+ --p_mcts_num 4 \
--use_priority \
--use_max_priority \
--amp_type 'torch_amp' \