diff --git a/README.md b/README.md
index 4a4a821..41c685d 100644
--- a/README.md
+++ b/README.md
@@ -19,19 +19,20 @@ and it should be good to go.
 
 To train a policy, run the following command:
 
-``python run.py --env_config data/envs/dm_cheetah.yaml --agent_config a2/dm_cheetah_cem_agent.yaml --mode train --log_file output/log.txt --out_model_file output/model.pt --visualize``
+``python run.py --env_config data/envs/dm_cheetah.yaml --agent_config a2/dm_cheetah_cem_agent.yaml --mode train --log_file output/log.txt --out_model_file output/model.pt --visualize --num_workers 1``
 
 - `--env_config` specifies the configuration file for the environment.
 - `--agent_config` specifies the configuration file for the agent.
 - `--visualize` enables visualization. Rendering should be disabled for faster training.
 - `--log_file` specifies the output log file, which will keep track of statistics during training.
 - `--out_model_file` specifies the output model file, which contains the model parameters.
+- `--num_workers` specifies the number of worker processes used to parallelize training.
 
 ## Test Models
 
 To load a trained model, run the following command:
 
-``python run.py --env_config data/envs/dm_cheetah_env.yaml --agent_config a2/dm_cheetah_cem_agent.yaml --mode test --model_file data/models/dm_cheetah_ppo_model.pt --visualize``
+``python run.py --env_config data/envs/dm_cheetah_env.yaml --agent_config a2/dm_cheetah_cem_agent.yaml --mode test --model_file data/models/dm_cheetah_ppo_model.pt --visualize --num_workers 1``
 
 - `--model_file` specifies the `.pt` file that contains parameters for the trained model. Pretrained models are available in `data/models/`.
diff --git a/a1/bc_agent.py b/a1/bc_agent.py
index d68d35c..e75ba2a 100644
--- a/a1/bc_agent.py
+++ b/a1/bc_agent.py
@@ -4,6 +4,7 @@ import learning.agent_builder as agent_builder
 import learning.base_agent as base_agent
 import learning.bc_model as bc_model
+import util.mp_util as mp_util
 import util.torch_util as torch_util
 
 class BCAgent(base_agent.BaseAgent):
@@ -15,11 +16,16 @@ def __init__(self, config, env, device):
     def _load_params(self, config):
         super()._load_params(config)
-
-        buffer_size = config["exp_buffer_size"]
-        self._exp_buffer_length = max(buffer_size, self._steps_per_iter)
+
+        num_procs = mp_util.get_num_procs()
 
         self._batch_size = config["batch_size"]
+        self._batch_size = int(np.ceil(self._batch_size / num_procs))
+
+        buffer_size = config["exp_buffer_size"]
+        self._exp_buffer_length = int(np.ceil(buffer_size / num_procs))
+        self._exp_buffer_length = max(self._exp_buffer_length, self._steps_per_iter)
+
         self._update_epochs = config["update_epochs"]
         return
@@ -81,10 +87,8 @@ def _update_model(self):
             batch = self._exp_buffer.sample(batch_size)
             loss_info = self._compute_loss(batch)
 
-            self._optimizer.zero_grad()
             loss = loss_info["loss"]
-            loss.backward()
-            self._optimizer.step()
+            self._optimizer.step(loss)
 
             torch_util.add_torch_dict(loss_info, train_info)
diff --git a/a1/dm_cheetah_bc_agent.yaml b/a1/dm_cheetah_bc_agent.yaml
index a65d100..efb6aa9 100644
--- a/a1/dm_cheetah_bc_agent.yaml
+++ b/a1/dm_cheetah_bc_agent.yaml
@@ -6,6 +6,10 @@ model:
   actor_std_type: "FIXED"
   action_std: 0.2
 
+optimizer:
+  type: "SGD"
+  learning_rate: 5e-4
+
 expert_config: "data/agents/bc_expert_agent.yaml"
 expert_model_file: "data/models/dm_cheetah_expert_model.pt"
 
@@ -14,7 +18,6 @@ steps_per_iter: 100
 iters_per_output: 10
 test_episodes: 10
 
-learning_rate: 5e-4
 update_epochs: 20
 
 exp_buffer_size: 1000000
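Note on the scaling above: the per-worker batch size and buffer length are plain ceiling divisions of the configured values, so the global totals stay roughly constant as workers are added. A minimal sketch of that arithmetic (the worker counts below are made-up examples, not values from this change):

```python
import numpy as np

def per_worker(value, num_procs):
    # Ceiling division: every worker gets a share, and the global total
    # (num_procs * per-worker value) never drops below the configured value.
    return int(np.ceil(value / num_procs))

assert per_worker(512, 4) == 128         # e.g. batch_size 512 on 4 workers
assert per_worker(1000000, 3) == 333334  # e.g. exp_buffer_size 1000000 on 3 workers
```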
diff --git a/a1/dm_walker_bc_agent.yaml b/a1/dm_walker_bc_agent.yaml
index 910d130..87a6647 100644
--- a/a1/dm_walker_bc_agent.yaml
+++ b/a1/dm_walker_bc_agent.yaml
@@ -5,7 +5,11 @@ model:
   actor_init_output_scale: 0.01
   actor_std_type: "FIXED"
   action_std: 0.2
-
+
+optimizer:
+  type: "SGD"
+  learning_rate: 5e-4
+
 expert_config: "data/agents/bc_expert_agent.yaml"
 expert_model_file: "data/models/dm_walker_run_expert_model.pt"
 
@@ -14,7 +18,6 @@ steps_per_iter: 100
 iters_per_output: 10
 test_episodes: 10
 
-learning_rate: 5e-4
 update_epochs: 20
 
 exp_buffer_size: 1000000
diff --git a/a2/cem_agent.py b/a2/cem_agent.py
index 8823bc3..cb1ba00 100644
--- a/a2/cem_agent.py
+++ b/a2/cem_agent.py
@@ -3,6 +3,7 @@
 
 import learning.base_agent as base_agent
 import learning.cem_model as cem_model
+import util.mp_util as mp_util
 
 class CEMAgent(base_agent.BaseAgent):
     NAME = "CEM"
@@ -57,10 +58,16 @@ def _update_sample_count(self):
         return self._sample_count
 
     def _train_iter(self):
-        candidates = self._sample_candidates(self._population_size)
+        num_procs = mp_util.get_num_procs()
+        num_candidates = int(np.ceil(self._population_size / num_procs))
+        candidates = self._sample_candidates(num_candidates)
         rets, ep_lens = self._eval_candidates(candidates)
-
+        candidates, rets, ep_lens = self._gather_candidates(candidates, rets, ep_lens)
+
+        rets = rets.cpu().numpy()
+        ep_lens = ep_lens.cpu().numpy()
+
         curr_best_idx = np.argmax(rets)
         curr_best_ret = rets[curr_best_idx]
@@ -82,6 +89,8 @@ def _train_iter(self):
         num_eps = self._population_size * self._eps_per_candidate
         mean_param_std = torch.mean(new_std)
 
+        assert(self._check_synced()), "Network parameters desynchronized"
+
         train_info = {
             "mean_return": train_return,
             "mean_ep_len": train_ep_len,
@@ -90,6 +99,29 @@
         }
         return train_info
 
+    def _gather_candidates(self, candidates, rets, ep_lens):
+        candidates = mp_util.all_gather(candidates)
+        rets = mp_util.all_gather(rets)
+        ep_lens = mp_util.all_gather(ep_lens)
+
+        candidates = torch.cat(candidates, dim=0)
+        rets = torch.cat(rets, dim=0)
+        ep_lens = torch.cat(ep_lens, dim=0)
+
+        return candidates, rets, ep_lens
+
+    def _check_synced(self):
+        global_param_mean = mp_util.broadcast(self._param_mean)
+        global_param_std = mp_util.broadcast(self._param_std)
+        global_best_params = mp_util.broadcast(self._best_params)
+
+        synced = torch.equal(global_param_mean, self._param_mean)
+        synced &= torch.equal(global_param_std, self._param_std)
+        synced &= torch.equal(global_best_params, self._best_params)
+
+        return synced
+
+
     def _sample_candidates(self, n):
         '''
@@ -118,8 +150,8 @@ def _eval_candidates(self, candidates):
         n = candidates.shape[0]
 
         # placeholder
-        rets = np.zeros(n)
-        ep_lens = np.zeros(n)
+        rets = torch.zeros(n, device=self._device)
+        ep_lens = torch.zeros(n, device=self._device)
 
         return rets, ep_lens
diff --git a/a2/dm_cheetah_pg_agent.yaml b/a2/dm_cheetah_pg_agent.yaml
index 4225d3c..048c25f 100644
--- a/a2/dm_cheetah_pg_agent.yaml
+++ b/a2/dm_cheetah_pg_agent.yaml
@@ -8,17 +8,20 @@ model:
 
   critic_net: "fc_2layers_128units"
 
-
+actor_optimizer:
+  type: "SGD"
+  learning_rate: 2e-3
+
+critic_optimizer:
+  type: "SGD"
+  learning_rate: 2e-3
+
 discount: 0.99
 steps_per_iter: 4096
 iters_per_output: 50
 test_episodes: 32
 
 critic_update_epoch: 5
-# optimizer parameters
-actor_learning_rate: 2e-3
-critic_learning_rate: 2e-3
-
 batch_size: 512
 norm_adv_clip: 4.0
 action_bound_weight: 10.0
\ No newline at end of file
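The CEM change above has each worker sample and evaluate only its share of the population, then rebuild the full population before the update. A rough standalone sketch of the gather-and-concatenate step, using a plain list of per-worker tensors in place of the real `mp_util.all_gather` collective:

```python
import torch

def gather_candidates(per_worker_candidates, per_worker_rets, per_worker_ep_lens):
    # Stand-in for the all_gather + torch.cat pattern in _gather_candidates:
    # each worker contributes one entry, and every worker ends up with the full population.
    candidates = torch.cat(per_worker_candidates, dim=0)
    rets = torch.cat(per_worker_rets, dim=0)
    ep_lens = torch.cat(per_worker_ep_lens, dim=0)
    return candidates, rets, ep_lens

# Two hypothetical workers, each evaluating 3 candidates with 5 parameters.
cands = [torch.randn(3, 5), torch.randn(3, 5)]
rets = [torch.randn(3), torch.randn(3)]
lens = [torch.ones(3), torch.ones(3)]
all_c, all_r, all_l = gather_candidates(cands, rets, lens)
print(all_c.shape, all_r.shape)  # torch.Size([6, 5]) torch.Size([6])
```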
diff --git a/a2/pg_agent.py b/a2/pg_agent.py
index 311b6dd..4058eb6 100644
--- a/a2/pg_agent.py
+++ b/a2/pg_agent.py
@@ -3,7 +3,9 @@
 import envs.base_env as base_env
 import learning.base_agent as base_agent
+import learning.mp_optimizer as mp_optimizer
 import learning.pg_model as pg_model
+import util.mp_util as mp_util
 import util.torch_util as torch_util
 
 class PGAgent(base_agent.BaseAgent):
@@ -16,7 +18,11 @@ def __init__(self, config, env, device):
     def _load_params(self, config):
         super()._load_params(config)
 
+        num_procs = mp_util.get_num_procs()
+
         self._batch_size = config["batch_size"]
+        self._batch_size = int(np.ceil(self._batch_size / num_procs))
+
         self._critic_update_epoch = config["critic_update_epoch"]
         self._norm_adv_clip = config["norm_adv_clip"]
         self._action_bound_weight = config["action_bound_weight"]
@@ -56,16 +62,15 @@ def _init_iter(self):
         return
 
     def _build_optimizer(self, config):
-        actor_lr = float(config["actor_learning_rate"])
-        critic_lr = float(config["critic_learning_rate"])
-
+        actor_opt_config = config["actor_optimizer"]
         actor_params = list(self._model._actor_layers.parameters())+list(self._model._action_dist.parameters())
         actor_params_grad = [p for p in actor_params if p.requires_grad]
-        self._actor_optimizer = torch.optim.SGD(actor_params_grad, actor_lr, momentum=0.9)
-
+        self._actor_optimizer = mp_optimizer.MPOptimizer(actor_opt_config, actor_params_grad)
+
+        critic_opt_config = config["critic_optimizer"]
         critic_params = list(self._model._critic_layers.parameters())+list(self._model._critic_out.parameters())
         critic_params_grad = [p for p in critic_params if p.requires_grad]
-        self._critic_optimizer = torch.optim.SGD(critic_params_grad, critic_lr, momentum=0.9)
+        self._critic_optimizer = mp_optimizer.MPOptimizer(critic_opt_config, critic_params_grad)
 
         return
@@ -152,9 +157,7 @@ def _update_critic(self, batch):
 
         loss = self._calc_critic_loss(norm_obs, tar_val)
 
-        self._critic_optimizer.zero_grad()
-        loss.backward()
-        self._critic_optimizer.step()
+        self._critic_optimizer.step(loss)
 
         info = {
             "critic_loss": loss
@@ -179,13 +182,13 @@ def _update_actor(self, batch):
             action_bound_loss = torch.mean(action_bound_loss)
             loss += self._action_bound_weight * action_bound_loss
             info["action_bound_loss"] = action_bound_loss.detach()
-
-        self._actor_optimizer.zero_grad()
-        loss.backward()
-        self._actor_optimizer.step()
+
+        self._actor_optimizer.step(loss)
 
         return info
+
+
     def _calc_return(self, r, done):
         '''
         TODO 2.1: Given a tensor of per-timestep rewards (r), and a tensor (done)
diff --git a/a3/atari_breakout_dqn_agent.yaml b/a3/atari_breakout_dqn_agent.yaml
index f72f615..0b78493 100644
--- a/a3/atari_breakout_dqn_agent.yaml
+++ b/a3/atari_breakout_dqn_agent.yaml
@@ -10,8 +10,9 @@ iters_per_output: 50000
 test_episodes: 10
 normalizer_samples: 0
 
-# optimizer parameters
-learning_rate: 5e-4
+optimizer:
+  type: "SGD"
+  learning_rate: 5e-4
 
 exp_buffer_size: 200000
 updates_per_iter: 1
diff --git a/a3/atari_pong_dqn_agent.yaml b/a3/atari_pong_dqn_agent.yaml
index f72f615..0b78493 100644
--- a/a3/atari_pong_dqn_agent.yaml
+++ b/a3/atari_pong_dqn_agent.yaml
@@ -10,8 +10,9 @@ iters_per_output: 50000
 test_episodes: 10
 normalizer_samples: 0
 
-# optimizer parameters
-learning_rate: 5e-4
+optimizer:
+  type: "SGD"
+  learning_rate: 5e-4
 
 exp_buffer_size: 200000
 updates_per_iter: 1
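The recurring edit in the agents is collapsing the zero_grad/backward/step sequence into a single `optimizer.step(loss)` call. A toy illustration of that wrapper pattern (this class is illustrative only, not the MPOptimizer defined later in `learning/mp_optimizer.py`):

```python
import torch

class LossSteppingOptimizer:
    """Toy wrapper that owns the zero_grad/backward/step sequence."""

    def __init__(self, params, lr=1e-3):
        self._opt = torch.optim.SGD(params, lr, momentum=0.9)

    def step(self, loss):
        self._opt.zero_grad()
        loss.backward()
        # A distributed version would average gradients across workers here.
        self._opt.step()

model = torch.nn.Linear(4, 1)
opt = LossSteppingOptimizer(model.parameters())
loss = model(torch.randn(8, 4)).pow(2).mean()
opt.step(loss)
```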
diff --git a/a3/dqn_agent.py b/a3/dqn_agent.py
index 66d1028..aa1c6c0 100644
--- a/a3/dqn_agent.py
+++ b/a3/dqn_agent.py
@@ -5,6 +5,7 @@ import envs.base_env as base_env
 import learning.base_agent as base_agent
 import learning.dqn_model as dqn_model
+import util.mp_util as mp_util
 import util.torch_util as torch_util
 
 class DQNAgent(base_agent.BaseAgent):
@@ -12,18 +13,26 @@ class DQNAgent(base_agent.BaseAgent):
 
     def __init__(self, config, env, device):
         super().__init__(config, env, device)
+
+        self._sync_tar_model()
         return
 
     def _load_params(self, config):
         super()._load_params(config)
+
+        num_procs = mp_util.get_num_procs()
 
         buffer_size = config["exp_buffer_size"]
-        self._exp_buffer_length = int(buffer_size)
+        self._exp_buffer_length = int(np.ceil(buffer_size / (num_procs)))
         self._exp_buffer_length = max(self._exp_buffer_length, self._steps_per_iter)
 
-        self._updates_per_iter = config["updates_per_iter"]
         self._batch_size = config["batch_size"]
+        self._batch_size = int(np.ceil(self._batch_size / num_procs))
+
+        self._init_samples = config["init_samples"]
+        self._init_samples = int(np.ceil(self._init_samples / num_procs))
+
+        self._updates_per_iter = config["updates_per_iter"]
         self._tar_net_update_iters = config["tar_net_update_iters"]
 
         self._exp_anneal_samples = config.get("exp_anneal_samples", np.inf)
@@ -40,7 +49,6 @@ def _build_model(self, config):
         for param in self._tar_model.parameters():
             param.requires_grad = False
 
-        self._sync_tar_model()
         return
 
     def _get_exp_buffer_length(self):
@@ -80,10 +88,8 @@ def _update_model(self):
             batch = self._exp_buffer.sample(self._batch_size)
             loss_info = self._compute_loss(batch)
 
-            self._optimizer.zero_grad()
             loss = loss_info["loss"]
-            loss.backward()
-            self._optimizer.step()
+            self._optimizer.step(loss)
 
             torch_util.add_torch_dict(loss_info, train_info)
diff --git a/data/agents/bc_expert_agent.yaml b/data/agents/bc_expert_agent.yaml
index 50a3a2f..7063fdd 100644
--- a/data/agents/bc_expert_agent.yaml
+++ b/data/agents/bc_expert_agent.yaml
@@ -8,11 +8,11 @@ model:
 
   critic_net: "fc_2layers_128units"
 
-
+optimizer:
+  type: "SGD"
+  learning_rate: 1e-3
+
 discount: 0.99
 steps_per_iter: 4096
 iters_per_output: 10
-test_episodes: 32
-
-# optimizer parameters
-learning_rate: 1e-3
\ No newline at end of file
+test_episodes: 32
\ No newline at end of file
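The yaml files now describe the optimizer as a small config block with `type` and `learning_rate` keys. A sketch of how such a block could be mapped onto a torch optimizer, mirroring the SGD/Adam handling that MPOptimizer uses below (the `build_optimizer` helper here is hypothetical):

```python
import torch

def build_optimizer(config, params):
    # Mirrors the type / learning_rate keys used in the yaml blocks above.
    lr = float(config["learning_rate"])
    opt_type = config["type"]
    if opt_type == "SGD":
        return torch.optim.SGD(params, lr, momentum=0.9)
    elif opt_type == "Adam":
        return torch.optim.Adam(params, lr)
    raise ValueError("Unsupported optimizer type: " + opt_type)

model = torch.nn.Linear(3, 2)
opt = build_optimizer({"type": "SGD", "learning_rate": "5e-4"}, model.parameters())
```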
diff --git a/learning/base_agent.py b/learning/base_agent.py
index 74a1fd1..e94c909 100644
--- a/learning/base_agent.py
+++ b/learning/base_agent.py
@@ -8,8 +8,10 @@ import envs.base_env as base_env
 
 import learning.experience_buffer as experience_buffer
+import learning.mp_optimizer as mp_optimizer
 import learning.normalizer as normalizer
 import learning.return_tracker as return_tracker
+import util.mp_util as mp_util
 import util.tb_logger as tb_logger
 import util.torch_util as torch_util
 
@@ -79,9 +81,12 @@ def train_model(self, max_samples, out_model_file, int_output_dir, log_file):
     def test_model(self, num_episodes):
         self.eval()
         self.set_mode(AgentMode.TEST)
+
+        num_procs = mp_util.get_num_procs()
+        num_eps_proc = int(np.ceil(num_episodes / num_procs))
 
         self._curr_obs, self._curr_info = self._env.reset()
-        test_info = self._rollout_test(num_episodes)
+        test_info = self._rollout_test(num_eps_proc)
 
         return test_info
 
@@ -106,8 +111,9 @@ def set_mode(self, mode):
         return
 
     def save(self, out_file):
-        state_dict = self.state_dict()
-        torch.save(state_dict, out_file)
+        if (mp_util.is_root_proc()):
+            state_dict = self.state_dict()
+            torch.save(state_dict, out_file)
         return
 
     def load(self, in_file):
@@ -116,11 +122,16 @@ def load(self, in_file):
         return
 
     def _load_params(self, config):
+        num_procs = mp_util.get_num_procs()
+
         self._discount = config["discount"]
-        self._steps_per_iter = config["steps_per_iter"]
         self._iters_per_output = config["iters_per_output"]
-        self._test_episodes = config["test_episodes"]
         self._normalizer_samples = config.get("normalizer_samples", np.inf)
+        self._test_episodes = config["test_episodes"]
+
+        self._steps_per_iter = config["steps_per_iter"]
+        self._steps_per_iter = int(np.ceil(self._steps_per_iter / num_procs))
+
         return
 
     @abc.abstractmethod
@@ -153,10 +164,10 @@ def _build_action_normalizer(self):
         return a_norm
 
     def _build_optimizer(self, config):
-        lr = float(config["learning_rate"])
+        opt_config = config["optimizer"]
         params = list(self.parameters())
         params = [p for p in params if p.requires_grad]
-        self._optimizer = torch.optim.SGD(params, lr, momentum=0.9)
+        self._optimizer = mp_optimizer.MPOptimizer(opt_config, params)
         return
 
     def _build_exp_buffer(self, config):
@@ -196,11 +207,14 @@ def _get_exp_buffer_length(self):
     def _build_logger(self, log_file):
         log = tb_logger.TBLogger()
         log.set_step_key("Samples")
-        log.configure_output_file(log_file)
+        if (mp_util.is_root_proc()):
+            log.configure_output_file(log_file)
         return log
 
     def _update_sample_count(self):
-        return self._exp_buffer.get_total_samples()
+        sample_count = self._exp_buffer.get_total_samples()
+        sample_count = mp_util.reduce_sum(sample_count)
+        return sample_count
 
     def _init_train(self):
         self._iter = 0
@@ -352,6 +366,8 @@ def _log_train_info(self, train_info, test_info, start_time):
         test_return = test_info["mean_return"]
         test_ep_len = test_info["mean_ep_len"]
         test_eps = test_info["episodes"]
+        test_eps = mp_util.reduce_sum(test_eps)
+
         self._logger.log("Test_Return", test_return, collection="0_Main")
         self._logger.log("Test_Episode_Length", test_ep_len, collection="0_Main", quiet=True)
         self._logger.log("Test_Episodes", test_eps, collection="1_Info", quiet=True)
@@ -359,6 +375,8 @@
         train_return = train_info.pop("mean_return")
         train_ep_len = train_info.pop("mean_ep_len")
         train_eps = train_info.pop("episodes")
+        train_eps = mp_util.reduce_sum(train_eps)
+
         self._logger.log("Train_Return", train_return, collection="0_Main")
         self._logger.log("Train_Episode_Length", train_ep_len, collection="0_Main", quiet=True)
         self._logger.log("Train_Episodes", train_eps, collection="1_Info", quiet=True)
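With the base agent splitting work per process, test episodes are divided with a ceiling division and the episode and sample counts are summed back together for logging, so the reported total can slightly exceed the configured value. A small numeric sketch of that round trip (the worker count is chosen arbitrarily):

```python
import numpy as np

num_procs = 4        # assumed worker count, for illustration
test_episodes = 10   # from the agent config

# Each worker runs its own share of the test episodes...
eps_per_proc = int(np.ceil(test_episodes / num_procs))   # 3

# ...and the logged count is what reduce_sum recovers from all workers,
# so slightly more episodes than configured may be reported (4 * 3 = 12).
total_eps = num_procs * eps_per_proc
print(eps_per_proc, total_eps)
```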
diff --git a/learning/mp_optimizer.py b/learning/mp_optimizer.py
new file mode 100644
index 0000000..54495ba
--- /dev/null
+++ b/learning/mp_optimizer.py
@@ -0,0 +1,82 @@
+import torch
+
+import util.mp_util as mp_util
+
+class MPOptimizer():
+    CHECK_SYNC_STEPS = 1000
+
+    def __init__(self, config, param_list):
+        self._param_list = param_list
+        self._grad_list = None
+        self._optimizer = self._build_optimizer(config, param_list)
+        self._steps = 0
+
+        if (mp_util.enable_mp()):
+            self._param_buffer = self._build_param_buffer()
+
+        self.sync()
+        return
+
+    def step(self, loss):
+        self._optimizer.zero_grad()
+        loss.backward()
+
+        if (mp_util.enable_mp()):
+            self._aggregate_mp_grads()
+
+        self._optimizer.step()
+
+        if (mp_util.enable_mp() and (self.get_steps() % self.CHECK_SYNC_STEPS == 0)):
+            assert(self._check_synced()), "Network parameters desynchronized"
+
+        self._steps += 1
+        return
+
+    def get_steps(self):
+        return self._steps
+
+    def sync(self):
+        with torch.no_grad():
+            for param in self._param_list:
+                global_param = mp_util.broadcast(param)
+                param.copy_(global_param)
+        return
+
+    def _build_optimizer(self, config, param_list):
+        lr = float(config["learning_rate"])
+        optimizer_type = config["type"]
+        if optimizer_type == "SGD":
+            optimizer = torch.optim.SGD(param_list, lr, momentum=0.9)
+        elif optimizer_type == "Adam":
+            optimizer = torch.optim.Adam(param_list, lr)
+        else:
+            assert(False), "Unsupported optimizer type: " + optimizer_type
+        return optimizer
+
+    def _build_param_buffer(self):
+        buffer = torch.nn.utils.parameters_to_vector(self._param_list).clone().detach()
+        return buffer
+
+    def _check_synced(self):
+        synced = True
+        for param in self._param_list:
+            global_param = mp_util.broadcast(param)
+            param_synced = torch.equal(param, global_param)
+            if (not param_synced):
+                synced = False
+
+        device = self._param_list[0].device
+        buffer = torch.tensor([synced], dtype=torch.int, device=device)
+        buffer = mp_util.reduce_min(buffer)
+        synced = buffer.item() != 0
+
+        return synced
+
+    def _aggregate_mp_grads(self):
+        if (self._grad_list is None):
+            self._grad_list = [p.grad for p in self._param_list]
+
+        self._param_buffer[:] = torch.nn.utils.parameters_to_vector(self._grad_list)
+        mp_util.reduce_inplace_mean(self._param_buffer)
+        torch.nn.utils.vector_to_parameters(self._param_buffer, self._grad_list)
+        return
diff --git a/learning/normalizer.py b/learning/normalizer.py
index 4f13165..51f0f24 100644
--- a/learning/normalizer.py
+++ b/learning/normalizer.py
@@ -1,6 +1,9 @@
 import numpy as np
 import torch
 
+import util.mp_util as mp_util
+from util.logger import Logger
+
 class Normalizer(torch.nn.Module):
     def __init__(self, shape, device, init_mean=None, init_std=None, eps=1e-4, clip=np.inf, dtype=torch.float):
         super().__init__()
@@ -25,7 +28,11 @@ def record(self, x):
     def update(self):
         if (self._mean_sq is None):
             self._mean_sq = self._calc_mean_sq(self._mean, self._std)
-
+
+        self._new_count = mp_util.reduce_sum(self._new_count)
+        mp_util.reduce_inplace_sum(self._new_sum)
+        mp_util.reduce_inplace_sum(self._new_sum_sq)
+
         new_count = self._new_count
         new_mean = self._new_sum / new_count
         new_mean_sq = self._new_sum_sq / new_count
@@ -59,10 +66,9 @@ def get_std(self):
 
     def set_mean_std(self, mean, std):
         shape = self.get_shape()
-        is_array = isinstance(mean, np.ndarray) and isinstance(std, np.ndarray)
 
         assert mean.shape == shape and std.shape == shape, \
-            print("Normalizer shape mismatch, expecting size {:d}, but got {:d} and {:d}".format(shape, mean.shape, std.shape))
+            Logger.print("Normalizer shape mismatch, expecting size {:d}, but got {:d} and {:d}".format(shape, mean.shape, std.shape))
 
         self._mean[:] = mean
         self._std[:] = std
@@ -80,14 +86,13 @@ def unnormalize(self, norm_x):
 
     def _calc_std(self, mean, mean_sq):
         var = mean_sq - torch.square(mean)
-        # some time floating point errors can lead to small negative numbers
-        var = torch.clamp_min(var, 0.0)
+        var = torch.clamp_min(var, 1e-8)
         std = torch.sqrt(var)
         std = std.type(self.dtype)
         return std
 
     def _calc_mean_sq(self, mean, std):
-        mean_sq = torch.square(std) + torch.square(self._mean)
+        mean_sq = torch.square(std) + torch.square(mean)
         mean_sq = mean_sq.type(self.dtype)
         return mean_sq
 
@@ -98,16 +103,16 @@ def _build_params(self, shape, device, init_mean, init_std):
 
         if init_mean is not None:
             assert init_mean.shape == shape, \
-                print('Normalizer init mean shape mismatch, expecting {:d}, but got {:d}'.shape(size, init_mean.shape))
+                Logger.print('Normalizer init mean shape mismatch, expecting {:d}, but got {:d}'.format(shape, init_mean.shape))
             self._mean[:] = init_mean
 
         if init_std is not None:
             assert init_std.shape == shape, \
-                print('Normalizer init std shape mismatch, expecting {:d}, but got {:d}'.format(shape, init_std.shape))
+                Logger.print('Normalizer init std shape mismatch, expecting {:d}, but got {:d}'.format(shape, init_std.shape))
             self._std[:] = init_std
 
         self._mean_sq = None
-
+
         self._new_count = 0
         self._new_sum = torch.zeros_like(self._mean)
         self._new_sum_sq = torch.zeros_like(self._mean)
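The normalizer update now folds the per-worker sums together before recomputing the running statistics, so every worker ends up with identical mean and std. A compact sketch of that aggregation, with explicit per-worker tensors standing in for the mp_util reductions:

```python
import torch

# Per-worker accumulators for a 3-dimensional observation (toy values).
new_sums = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([2.0, 0.0, 1.0])]
new_sum_sqs = [torch.tensor([1.0, 4.0, 9.0]), torch.tensor([4.0, 0.0, 1.0])]
new_counts = [5, 7]

# What the reduce_sum / reduce_inplace_sum calls would leave on every worker.
total_count = sum(new_counts)
total_sum = torch.stack(new_sums).sum(dim=0)
total_sum_sq = torch.stack(new_sum_sqs).sum(dim=0)

# Identical statistics on every worker.
new_mean = total_sum / total_count
new_mean_sq = total_sum_sq / total_count
var = torch.clamp_min(new_mean_sq - torch.square(new_mean), 1e-8)
new_std = torch.sqrt(var)
```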
diff --git a/run.py b/run.py
index f5f14a6..0cfe8c5 100644
--- a/run.py
+++ b/run.py
@@ -2,9 +2,12 @@
 import numpy as np
 import os
 import sys
-import yaml
+import time
+import torch
+
 import envs.env_builder as env_builder
 import learning.agent_builder as agent_builder
+import util.mp_util as mp_util
 import util.util as util
 
 def set_np_formatting():
@@ -28,12 +31,11 @@ def load_args(argv):
     parser.add_argument("--model_file", dest="model_file", type=str, default="")
     parser.add_argument("--max_samples", dest="max_samples", type=np.int64, default=np.iinfo(np.int64).max)
     parser.add_argument("--test_episodes", dest="test_episodes", type=np.int64, default=np.iinfo(np.int64).max)
+    parser.add_argument("--master_port", dest="master_port", type=int, default=None)
+    parser.add_argument("--num_workers", dest="num_workers", type=int, default=1)
 
     args = parser.parse_args()
 
-    if (args.rand_seed is not None):
-        util.set_rand_seed(args.rand_seed)
-
     return args
 
 def build_env(args, device, visualize):
@@ -59,19 +61,27 @@ def test(agent, test_episodes):
     return result
 
 def create_output_dirs(out_model_file, int_output_dir):
-    output_dir = os.path.dirname(out_model_file)
-    if (output_dir != "" and (not os.path.exists(output_dir))):
-        os.makedirs(output_dir, exist_ok=True)
+    if (mp_util.is_root_proc()):
+        output_dir = os.path.dirname(out_model_file)
+        if (output_dir != "" and (not os.path.exists(output_dir))):
+            os.makedirs(output_dir, exist_ok=True)
 
-    if (int_output_dir != "" and (not os.path.exists(int_output_dir))):
-        os.makedirs(int_output_dir, exist_ok=True)
+        if (int_output_dir != "" and (not os.path.exists(int_output_dir))):
+            os.makedirs(int_output_dir, exist_ok=True)
     return
 
-def main(argv):
-    set_np_formatting()
+def set_rand_seed(args):
+    rand_seed = args.rand_seed
 
-    args = load_args(argv)
+    if (rand_seed is None):
+        rand_seed = np.uint64(time.time() * 256)
+
+    rand_seed += np.uint64(41 * mp_util.get_proc_rank())
 
+    print("Setting seed: {}".format(rand_seed))
+    util.set_rand_seed(rand_seed)
+    return
 
+def run(rank, num_procs, master_port, args):
     mode = args.mode
     device = args.device
     visualize = args.visualize
@@ -79,6 +89,11 @@
     out_model_file = args.out_model_file
     int_output_dir = args.int_output_dir
     model_file = args.model_file
+
+    mp_util.init(rank, num_procs, device, master_port)
+
+    set_rand_seed(args)
+    set_np_formatting()
 
     create_output_dirs(out_model_file, int_output_dir)
@@ -97,6 +112,34 @@
         test(agent=agent, test_episodes=test_episodes)
     else:
         assert(False), "Unsupported mode: {}".format(mode)
+
+    return
+
+
+def main(argv):
+    args = load_args(argv)
+    master_port = args.master_port
+    num_workers = args.num_workers
+    assert(num_workers > 0)
+
+    # if master port is not specified, then pick a random one
+    if (master_port is None):
+        master_port = np.random.randint(6000, 7000)
+
+    torch.multiprocessing.set_start_method("spawn")
+
+    processes = []
+    for i in range(num_workers - 1):
+        rank = i + 1
+        proc = torch.multiprocessing.Process(target=run, args=[rank, num_workers, master_port, args])
+        proc.start()
+        processes.append(proc)
+
+    run(0, num_workers, master_port, args)
+
+    for proc in processes:
+        proc.join()
+
     return
 
 if __name__ == "__main__":
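With the spawn-based launcher above, rank 0 runs in the main process and ranks 1 through N-1 are spawned as workers, so a multi-worker training run only needs the new flags. For example (the worker count and port below are arbitrary):

``python run.py --env_config data/envs/dm_cheetah.yaml --agent_config a2/dm_cheetah_cem_agent.yaml --mode train --log_file output/log.txt --out_model_file output/model.pt --num_workers 4 --master_port 6123``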
diff --git a/util/logger.py b/util/logger.py
index d66bdf5..194ba0c 100644
--- a/util/logger.py
+++ b/util/logger.py
@@ -1,4 +1,9 @@
-import os.path as osp, shutil, time, atexit, os, subprocess
+import os
+import time
+import atexit
+import torch
+
+import util.mp_util as mp_util
 
 class Logger:
     class Entry:
@@ -7,6 +12,14 @@ def __init__(self, val, quiet=False):
             self.quiet = quiet
             return
 
+    def print(str, end=None):
+        if (Logger.is_root()):
+            print(str, end=end)
+        return
+
+    def is_root():
+        return mp_util.is_root_proc()
+
     def __init__(self):
         self.output_file = None
         self.log_headers = []
@@ -14,12 +27,17 @@ def __init__(self):
         self._dump_str_template = ""
         self._max_key_len = 0
         self._row_count = 0
+        self._need_update = True
+        self._data_buffer = None
         return
 
     def reset(self):
         self._row_count = 0
         self.log_headers = []
        self.log_current_row = {}
+        self._need_update = True
+        self._data_buffer = None
+
         if self.output_file is not None:
             self.output_file = open(output_path, 'w')
         return
@@ -36,14 +54,17 @@ def configure_output_file(self, filename=None):
 
         out_dir = os.path.dirname(output_path)
-        if (not os.path.exists(out_dir)):
-            os.makedirs(out_dir, exist_ok=True)
+        is_root = Logger.is_root()
+        if (is_root):
+            if (not os.path.exists(out_dir)):
+                os.makedirs(out_dir, exist_ok=True)
 
-        self.output_file = open(output_path, 'w')
-        assert osp.exists(output_path)
-        atexit.register(self.output_file.close)
+        if (Logger.is_root()):
+            self.output_file = open(output_path, 'w')
+            assert os.path.exists(output_path)
+            atexit.register(self.output_file.close)
 
-        print("Logging data to " + self.output_file.name)
+        Logger.print("Logging data to " + self.output_file.name)
 
         return
@@ -58,6 +79,7 @@ def log(self, key, val, quiet=False):
         else:
             assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration"%key
         self.log_current_row[key] = Logger.Entry(val, quiet)
+        self._need_update = True
         return
 
     def get_num_keys(self):
@@ -68,49 +90,58 @@ def print_log(self):
         """
         Print all of the diagnostics from the current iteration
         """
+        if (mp_util.enable_mp() and self._need_update):
+            self._mp_aggregate()
+
         key_spacing = self._max_key_len
         format_str = "| %" + str(key_spacing) + "s | %15s |"
 
-        vals = []
-        print("-" * (22 + key_spacing))
-        for key in self.log_headers:
-            entry = self.log_current_row.get(key, "")
-            if not (entry.quiet):
-                val = entry.val
-
-                if isinstance(val, float):
-                    valstr = "%8.3g"%val
-                elif isinstance(val, int):
-                    valstr = str(val)
-                else:
-                    valstr = val
-
-                print(format_str%(key, valstr))
-                vals.append(val)
-        print("-" * (22 + key_spacing))
+        if (Logger.is_root()):
+            vals = []
+            Logger.print("-" * (22 + key_spacing))
+            for key in self.log_headers:
+                entry = self.log_current_row.get(key, "")
+                if not (entry.quiet):
+                    val = entry.val
+
+                    if isinstance(val, float):
+                        valstr = "%8.3g"%val
+                    elif isinstance(val, int):
+                        valstr = str(val)
+                    else:
+                        valstr = val
+
+                    Logger.print(format_str%(key, valstr))
+                    vals.append(val)
+            Logger.print("-" * (22 + key_spacing))
 
         return
 
     def write_log(self):
         """
         Write all of the diagnostics from the current iteration
         """
-        if (self._row_count == 0):
-            self._dump_str_template = self._build_str_template()
+
+        if (self._need_update):
+            self._mp_aggregate()
 
-        vals = []
-        for key in self.log_headers:
-            entry = self.log_current_row.get(key, "")
-            val = entry.val
-            vals.append(val)
-
-        if self.output_file is not None:
+        if (Logger.is_root()):
             if (self._row_count == 0):
-                header_str = self._dump_str_template.format(*self.log_headers)
-                self.output_file.write(header_str + "\r")
+                self._dump_str_template = self._build_str_template()
 
-            val_str = self._dump_str_template.format(*map(str,vals))
-            self.output_file.write(val_str + "\r")
-            self.output_file.flush()
+            vals = []
+            for key in self.log_headers:
+                entry = self.log_current_row.get(key, "")
+                val = entry.val
+                vals.append(val)
+
+            if self.output_file is not None:
+                if (self._row_count == 0):
+                    header_str = self._dump_str_template.format(*self.log_headers)
+                    self.output_file.write(header_str + "\r")
+
+                val_str = self._dump_str_template.format(*map(str,vals))
+                self.output_file.write(val_str + "\r")
+                self.output_file.flush()
 
         self._row_count += 1
         return
@@ -128,4 +159,26 @@ def get_current_val(self, key):
 
     def _build_str_template(self):
         num_keys = self.get_num_keys()
         template = "{:<25}" * num_keys
-        return template
\ No newline at end of file
+        return template
+
+    def _mp_aggregate(self):
+        if (self._data_buffer is None):
+            n = len(self.log_headers)
+            self._data_buffer = torch.zeros(n, dtype=torch.float64, device=mp_util.get_device())
+
+        for i, key in enumerate(self.log_headers):
+            entry = self.log_current_row[key]
+            val = entry.val
+            self._data_buffer[i] = val
+
+        mp_util.reduce_inplace_mean(self._data_buffer)
+
+        for i, key in enumerate(self.log_headers):
+            entry = self.log_current_row[key]
+            val = self._data_buffer[i].item()
+            if (isinstance(entry.val, int)):
+                val = int(val)
+            entry.val = val
+
+        self._need_update = False
+        return
\ No newline at end of file
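Logger._mp_aggregate above averages every logged scalar across workers before anything is printed or written. A toy version of the same averaging, with the collective call replaced by an explicit mean over per-worker buffers:

```python
import torch

# Per-worker values for the same log keys, in header order (made-up numbers).
worker_rows = [
    torch.tensor([250.3, 412.0, 32.0], dtype=torch.float64),  # worker 0
    torch.tensor([261.1, 398.0, 32.0], dtype=torch.float64),  # worker 1
]

# reduce_inplace_mean leaves every worker holding the element-wise mean.
aggregated = torch.stack(worker_rows).mean(dim=0)
print(aggregated)  # ~[255.7, 405.0, 32.0]
```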
diff --git a/util/mp_util.py b/util/mp_util.py
new file mode 100644
index 0000000..b93702f
--- /dev/null
+++ b/util/mp_util.py
@@ -0,0 +1,137 @@
+import os
+import platform
+import torch
+
+ROOT_PROC_RANK = 0
+
+global_mp_device = None
+
+def init(rank, num_procs, device, master_port):
+    global global_mp_device
+
+    assert(global_mp_device is None)
+    global_mp_device = device
+
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(master_port)
+    print("Using master port: {:d}".format(master_port))
+
+    if (device == "cpu"):
+        backend = "gloo"
+    elif ("cuda" in device):
+        backend = "nccl"
+    else:
+        assert False, "Unsupported multiprocessing device {:s}".format(device)
+
+    os_platform = platform.system()
+    if (backend == "nccl" and os_platform == "Windows"):
+        print("Pytorch doesn't support NCCL on Windows, defaulting to gloo backend")
+        backend = "gloo"
+
+    torch.distributed.init_process_group(backend, rank=rank, world_size=num_procs)
+
+    return
+
+def get_num_procs():
+    try:
+        num_procs = torch.distributed.get_world_size()
+    except:
+        num_procs = 1
+    return num_procs
+
+def get_proc_rank():
+    try:
+        proc_rank = torch.distributed.get_rank()
+    except:
+        proc_rank = 0
+    return proc_rank
+
+def is_root_proc():
+    rank = get_proc_rank()
+    return rank == ROOT_PROC_RANK
+
+def enable_mp():
+    num_procs = get_num_procs()
+    return num_procs > 1
+
+def get_device():
+    return global_mp_device
+
+def broadcast(x):
+    if (enable_mp()):
+        data = x.clone()
+        torch.distributed.broadcast(data, src=ROOT_PROC_RANK)
+    else:
+        data = x
+    return data
+
+def all_gather(x):
+    n = get_num_procs()
+    if (enable_mp()):
+        data = [torch.empty_like(x) for i in range(n)]
+        torch.distributed.all_gather(data, x)
+    else:
+        data = [x]
+    return data
+
+def reduce_sum(x):
+    return reduce_all(x, torch.distributed.ReduceOp.SUM)
+
+def reduce_prod(x):
+    return reduce_all(x, torch.distributed.ReduceOp.PROD)
+
+def reduce_mean(x):
+    n = get_num_procs()
+    sum_x = reduce_sum(x)
+    mean_x = sum_x / n
+    return mean_x
+
+def reduce_min(x):
+    return reduce_all(x, torch.distributed.ReduceOp.MIN)
+
+def reduce_max(x):
+    return reduce_all(x, torch.distributed.ReduceOp.MAX)
+
+def reduce_all(x, op):
+    if (enable_mp()):
+        is_tensor = torch.is_tensor(x)
+        if (is_tensor):
+            buffer = x.clone()
+        else:
+            buffer = torch.tensor(x, device=get_device())
+
+        torch.distributed.all_reduce(buffer, op=op)
+
+        if (not is_tensor):
+            buffer = buffer.item()
+    else:
+        buffer = x
+
+    return buffer
+
+def reduce_inplace_sum(x):
+    reduce_inplace_all(x, torch.distributed.ReduceOp.SUM)
+    return
+
+def reduce_inplace_prod(x):
+    reduce_inplace_all(x, torch.distributed.ReduceOp.PROD)
+    return
+
+def reduce_inplace_mean(x):
+    n = get_num_procs()
+    reduce_inplace_sum(x)
+    x /= n
+    return
+
+def reduce_inplace_min(x):
+    reduce_inplace_all(x, torch.distributed.ReduceOp.MIN)
+    return
+
+def reduce_inplace_max(x):
+    reduce_inplace_all(x, torch.distributed.ReduceOp.MAX)
+    return
+
+def reduce_inplace_all(x, op):
+    if (enable_mp()):
+        torch.distributed.all_reduce(x, op=op)
+    return
diff --git a/util/util.py b/util/util.py
index 2950211..0f6d345 100644
--- a/util/util.py
+++ b/util/util.py
@@ -3,10 +3,8 @@ import torch
 
 def set_rand_seed(seed):
-    print("Setting seed: {}".format(seed))
-    random.seed(seed)
-    np.random.seed(seed)
+    np.random.seed(np.uint64(seed % (2**32)))
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
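For reference, a compact sketch of how the reduce helpers above are meant to be called from training code. The tensors and values are illustrative, and with a single process (no process group initialized) the helpers simply fall back to returning their inputs unchanged, as in the code above:

```python
import torch
import util.mp_util as mp_util

x = torch.ones(3)

# With a single process these are identity operations:
total = mp_util.reduce_sum(4)      # 4
parts = mp_util.all_gather(x)      # [x]
same = mp_util.broadcast(x)        # x

# With N workers, reduce_sum returns the global sum on every worker,
# all_gather returns a list with one tensor per worker, and broadcast
# returns a copy of the root worker's tensor.
print(total, len(parts), torch.equal(same, x))
```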