diff --git a/.github/workflows/test_linux.yml b/.github/workflows/test_linux.yml
index cad13ccf..b7ae5ac6 100644
--- a/.github/workflows/test_linux.yml
+++ b/.github/workflows/test_linux.yml
@@ -52,7 +52,7 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
 
-      - uses: goanpeca/setup-miniconda@v1
+      - uses: conda-incubator/setup-miniconda@v2
         with:
           auto-update-conda: true
           python-version: ${{ matrix.python-version }}
diff --git a/.github/workflows/test_macos.yml b/.github/workflows/test_macos.yml
index 539cacc9..af7f4907 100644
--- a/.github/workflows/test_macos.yml
+++ b/.github/workflows/test_macos.yml
@@ -52,7 +52,7 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
 
-      - uses: goanpeca/setup-miniconda@v1
+      - uses: conda-incubator/setup-miniconda@v2
         with:
           auto-update-conda: true
           python-version: ${{ matrix.python-version }}
diff --git a/.github/workflows/test_windows.yml b/.github/workflows/test_windows.yml
index 77be36c8..9e01bc2d 100644
--- a/.github/workflows/test_windows.yml
+++ b/.github/workflows/test_windows.yml
@@ -52,7 +52,7 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
 
-      - uses: goanpeca/setup-miniconda@v1
+      - uses: conda-incubator/setup-miniconda@v2
         with:
           auto-update-conda: true
           python-version: ${{ matrix.python-version }}
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1e772ecc..0fcfb45d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
     rev: 20.8b1
     hooks:
       - id: black
-        language_version: python3.7
+        language_version: python3.8
 
   - repo: https://gitlab.com/pycqa/flake8
     rev: 3.8.3
diff --git a/genrl/agents/deep/a2c/a2c.py b/genrl/agents/deep/a2c/a2c.py
index 1f94992e..df258d00 100644
--- a/genrl/agents/deep/a2c/a2c.py
+++ b/genrl/agents/deep/a2c/a2c.py
@@ -118,7 +118,7 @@ def select_action(
         action, dist = self.ac.get_action(state, deterministic=deterministic)
         value = self.ac.get_value(state)
 
-        return action.detach(), value, dist.log_prob(action).cpu()
+        return action.detach(), value, dist.log_prob(action)
 
     def get_traj_loss(self, values: torch.Tensor, dones: torch.Tensor) -> None:
         """Get loss from trajectory traversed by agent during rollouts
@@ -129,8 +129,8 @@ def get_traj_loss(self, values: torch.Tensor, dones: torch.Tensor) -> None:
             values (:obj:`torch.Tensor`): Values of states encountered during the rollout
             dones (:obj:`list` of bool): Game over statuses of each environment
         """
-        compute_returns_and_advantage(
-            self.rollout, values.detach().cpu().numpy(), dones.cpu().numpy()
+        self.rollout.returns, self.rollout.advantages = compute_returns_and_advantage(
+            self.rollout, values.detach(), dones.to(self.device)
         )
 
     def evaluate_actions(self, states: torch.Tensor, actions: torch.Tensor):
@@ -150,7 +150,7 @@ def evaluate_actions(self, states: torch.Tensor, actions: torch.Tensor):
         states, actions = states.to(self.device), actions.to(self.device)
         _, dist = self.ac.get_action(states, deterministic=False)
         values = self.ac.get_value(states)
-        return values, dist.log_prob(actions).cpu(), dist.entropy().cpu()
+        return values, dist.log_prob(actions), dist.entropy()
 
     def update_params(self) -> None:
         """Updates the the A2C network
@@ -171,7 +171,7 @@ def update_params(self) -> None:
             policy_loss = -torch.mean(policy_loss)
             self.logs["policy_loss"].append(policy_loss.item())
 
-            value_loss = self.value_coeff * F.mse_loss(rollout.returns, values.cpu())
+            value_loss = self.value_coeff * F.mse_loss(rollout.returns, values)
             self.logs["value_loss"].append(torch.mean(value_loss).item())
 
             entropy_loss = -torch.mean(entropy)  # Change this to entropy
diff --git a/genrl/agents/deep/base/offpolicy.py b/genrl/agents/deep/base/offpolicy.py
index fa632e8a..0a473a07 100644
--- a/genrl/agents/deep/base/offpolicy.py
+++ b/genrl/agents/deep/base/offpolicy.py
@@ -47,6 +47,7 @@ def __init__(
             self.replay_buffer = PrioritizedBuffer(self.replay_size)
         else:
             raise NotImplementedError
+        # self.replay_buffer = self.replay_buffer.to(self.device)
 
     def update_params_before_select_action(self, timestep: int) -> None:
         """Update any parameters before selecting action like epsilon for decaying epsilon greedy
@@ -107,6 +108,7 @@ def sample_from_buffer(self, beta: float = None):
             )
         else:
             raise NotImplementedError
+        # print(batch.device)
         return batch
 
     def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor:
@@ -118,9 +120,13 @@
         Returns:
             loss (:obj:`torch.Tensor`): Calculated loss of the Q-function
         """
-        q_values = self.get_q_values(batch.states, batch.actions)
+        q_values = self.get_q_values(
+            batch.states.to(self.device), batch.actions.to(self.device)
+        )
         target_q_values = self.get_target_q_values(
-            batch.next_states, batch.rewards, batch.dones
+            batch.next_states.to(self.device),
+            batch.rewards.to(self.device),
+            batch.dones.to(self.device),
         )
         loss = F.mse_loss(q_values, target_q_values)
         return loss
@@ -167,15 +173,16 @@
         Returns:
             action (:obj:`torch.Tensor`): Action taken by the agent
         """
+        state = state.to(self.device)
         action, _ = self.ac.get_action(state, deterministic)
         action = action.detach()
 
         # add noise to output from policy network
         if self.noise is not None:
-            action += self.noise()
+            action += self.noise().to(self.device)
 
         return torch.clamp(
-            action, self.env.action_space.low[0], self.env.action_space.high[0]
+            action.cpu(), self.env.action_space.low[0], self.env.action_space.high[0]
         )
 
     def update_target_model(self) -> None:
@@ -199,6 +206,7 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Ten
         Returns:
             q_values (:obj:`torch.Tensor`): Q values for the given states and actions
         """
+        states, actions = states.to(self.device), actions.to(self.device)
         if self.doublecritic:
             q_values = self.ac.get_value(
                 torch.cat([states, actions], dim=-1), mode="both"
@@ -221,6 +229,7 @@ def get_target_q_values(
         Returns:
             target_q_values (:obj:`torch.Tensor`): Target Q values for the TD3
         """
+        next_states = next_states.to(self.device)
         next_target_actions = self.ac_target.get_action(next_states, True)[0]
 
         if self.doublecritic:
@@ -231,7 +240,10 @@
             next_q_target_values = self.ac_target.get_value(
                 torch.cat([next_states, next_target_actions], dim=-1)
             )
-        target_q_values = rewards + self.gamma * (1 - dones) * next_q_target_values
+        target_q_values = (
+            rewards.to(self.device)
+            + self.gamma * (1 - dones.to(self.device)) * next_q_target_values
+        )
 
         return target_q_values
 
@@ -265,6 +277,7 @@ def get_p_loss(self, states: torch.Tensor) -> torch.Tensor:
         Returns:
             loss (:obj:`torch.Tensor`): Calculated policy loss
         """
+        states = states.to(self.device)
         next_best_actions = self.ac.get_action(states, True)[0]
         q_values = self.ac.get_value(torch.cat([states, next_best_actions], dim=-1))
         policy_loss = -torch.mean(q_values)
diff --git a/genrl/agents/deep/base/onpolicy.py b/genrl/agents/deep/base/onpolicy.py
index 0c83cdfd..2eccdbec 100644
--- a/genrl/agents/deep/base/onpolicy.py
+++ b/genrl/agents/deep/base/onpolicy.py
@@ -36,7 +36,7 @@ def __init__(
 
         if buffer_type == "rollout":
             self.rollout = RolloutBuffer(
-                self.rollout_size, self.env, gae_lambda=gae_lambda
+                self.rollout_size, self.env, gae_lambda=gae_lambda, device=self.device
             )
         else:
             raise NotImplementedError
@@ -73,7 +73,7 @@ def collect_rollouts(self, state: torch.Tensor):
             dones (:obj:`torch.Tensor`): Game over statuses of each environment
         """
         for i in range(self.rollout_size):
-            action, values, old_log_probs = self.select_action(state)
+            action, values, old_log_probs = self.select_action(state.to(self.device))
 
             next_state, reward, dones, _ = self.env.step(action)
 
diff --git a/genrl/agents/deep/ppo1/ppo1.py b/genrl/agents/deep/ppo1/ppo1.py
index 9d573069..4d721321 100644
--- a/genrl/agents/deep/ppo1/ppo1.py
+++ b/genrl/agents/deep/ppo1/ppo1.py
@@ -113,7 +113,7 @@ def select_action(
         action, dist = self.ac.get_action(state, deterministic=deterministic)
         value = self.ac.get_value(state)
 
-        return action.detach(), value, dist.log_prob(action).cpu()
+        return action.detach(), value, dist.log_prob(action)
 
     def evaluate_actions(self, states: torch.Tensor, actions: torch.Tensor):
         """Evaluates actions taken by actor
@@ -132,7 +132,7 @@ def evaluate_actions(self, states: torch.Tensor, actions: torch.Tensor):
         states, actions = states.to(self.device), actions.to(self.device)
         _, dist = self.ac.get_action(states, deterministic=False)
         values = self.ac.get_value(states)
-        return values, dist.log_prob(actions).cpu(), dist.entropy().cpu()
+        return values, dist.log_prob(actions), dist.entropy()
 
     def get_traj_loss(self, values, dones):
         """Get loss from trajectory traversed by agent during rollouts
@@ -143,10 +143,10 @@ def get_traj_loss(self, values, dones):
             values (:obj:`torch.Tensor`): Values of states encountered during the rollout
             dones (:obj:`list` of bool): Game over statuses of each environment
         """
-        compute_returns_and_advantage(
+        self.rollout.returns, self.rollout.advantages = compute_returns_and_advantage(
             self.rollout,
-            values.detach().cpu().numpy(),
-            dones.cpu().numpy(),
+            values.detach(),
+            dones.to(self.device),
             use_gae=True,
         )
 
@@ -180,7 +180,7 @@ def update_params(self):
                 values = values.flatten()
 
                 value_loss = self.value_coeff * nn.functional.mse_loss(
-                    rollout.returns, values.cpu()
+                    rollout.returns, values
                 )
                 self.logs["value_loss"].append(torch.mean(value_loss).item())
diff --git a/genrl/agents/deep/sac/sac.py b/genrl/agents/deep/sac/sac.py
index ff93eeaa..5dc4253f 100644
--- a/genrl/agents/deep/sac/sac.py
+++ b/genrl/agents/deep/sac/sac.py
@@ -67,10 +67,10 @@ def _create_model(self, **kwargs) -> None:
         else:
             self.action_scale = torch.FloatTensor(
                 (self.env.action_space.high - self.env.action_space.low) / 2.0
-            )
+            ).to(self.device)
             self.action_bias = torch.FloatTensor(
                 (self.env.action_space.high + self.env.action_space.low) / 2.0
-            )
+            ).to(self.device)
 
         if isinstance(self.network, str):
             state_dim, action_dim, discrete, _ = get_env_properties(
@@ -89,7 +89,7 @@ def _create_model(self, **kwargs) -> None:
                 sac=True,
                 action_scale=self.action_scale,
                 action_bias=self.action_bias,
-            )
+            ).to(self.device)
         else:
             self.model = self.network
 
@@ -102,7 +102,7 @@ def _create_model(self, **kwargs) -> None:
             self.target_entropy = -torch.prod(
                 torch.Tensor(self.env.action_space.shape)
             ).item()
-            self.log_alpha = torch.zeros(1, requires_grad=True)
+            self.log_alpha = torch.zeros(1, device=self.device, requires_grad=True)
             self.optimizer_alpha = opt.Adam([self.log_alpha], lr=self.lr_policy)
 
     def select_action(
@@ -119,8 +119,9 @@
         Returns:
             action (:obj:`np.ndarray`): Action taken by the agent
         """
-        action, _, _ = self.ac.get_action(state, deterministic)
-        return action.detach()
+        state = state.to(self.device)
+        action, _, _ = self.ac.get_action(state.to(self.device), deterministic)
+        return action.detach().cpu()
 
     def update_target_model(self) -> None:
         """Function to update the target Q model
@@ -147,11 +148,15 @@ def get_target_q_values(
         Returns:
             target_q_values (:obj:`torch.Tensor`): Target Q values for the SAC
         """
+        next_states = next_states.to(self.device)
         next_target_actions, next_log_probs, _ = self.ac.get_action(next_states)
         next_q_target_values = self.ac_target.get_value(
             torch.cat([next_states, next_target_actions], dim=-1), mode="min"
         ).squeeze() - self.alpha * next_log_probs.squeeze(1)
-        target_q_values = rewards + self.gamma * (1 - dones) * next_q_target_values
+        target_q_values = (
+            rewards.to(self.device)
+            + self.gamma * (1 - dones.to(self.device)) * next_q_target_values
+        )
         return target_q_values
 
     def get_p_loss(self, states: torch.Tensor) -> torch.Tensor:
@@ -163,6 +168,7 @@ def get_p_loss(self, states: torch.Tensor) -> torch.Tensor:
         Returns:
             loss (:obj:`torch.Tensor`): Calculated policy loss
         """
+        states = states.to(self.device)
        next_best_actions, log_probs, _ = self.ac.get_action(states)
         q_values = self.ac.get_value(
             torch.cat([states, next_best_actions], dim=-1), mode="min"
diff --git a/genrl/agents/deep/td3/td3.py b/genrl/agents/deep/td3/td3.py
index 88f8c074..5f3f1585 100644
--- a/genrl/agents/deep/td3/td3.py
+++ b/genrl/agents/deep/td3/td3.py
@@ -79,9 +79,9 @@ def _create_model(self) -> None:
                 value_layers=self.value_layers,
                 val_type="Qsa",
                 discrete=False,
-            )
+            ).to(self.device)
         else:
-            self.ac = self.network
+            self.ac = self.network.to(self.device)
 
         if self.noise is not None:
             self.noise = self.noise(
diff --git a/genrl/agents/deep/vpg/vpg.py b/genrl/agents/deep/vpg/vpg.py
index 2d6a6ef9..a236880f 100644
--- a/genrl/agents/deep/vpg/vpg.py
+++ b/genrl/agents/deep/vpg/vpg.py
@@ -86,8 +86,8 @@ def select_action(
 
         return (
             action.detach(),
-            torch.zeros((1, self.env.n_envs)),
-            dist.log_prob(action).cpu(),
+            torch.zeros((1, self.env.n_envs), device=self.device),
+            dist.log_prob(action),
         )
 
     def get_log_probs(self, states: torch.Tensor, actions: torch.Tensor):
@@ -105,7 +105,7 @@ def get_log_probs(self, states: torch.Tensor, actions: torch.Tensor):
         """
         states, actions = states.to(self.device), actions.to(self.device)
         _, dist = self.actor.get_action(states, deterministic=False)
-        return dist.log_prob(actions).cpu()
+        return dist.log_prob(actions)
 
     def get_traj_loss(self, values, dones):
         """Get loss from trajectory traversed by agent during rollouts
@@ -116,8 +116,8 @@ def get_traj_loss(self, values, dones):
             values (:obj:`torch.Tensor`): Values of states encountered during the rollout
             dones (:obj:`list` of bool): Game over statuses of each environment
         """
-        compute_returns_and_advantage(
-            self.rollout, values.detach().cpu().numpy(), dones.cpu().numpy()
+        self.rollout.returns, self.rollout.advantages = compute_returns_and_advantage(
+            self.rollout, values.detach().to(self.device), dones.to(self.device)
         )
 
     def update_params(self) -> None:
diff --git a/genrl/core/buffers.py b/genrl/core/buffers.py
index 0a5b6e7c..8c330487 100644
--- a/genrl/core/buffers.py
+++ b/genrl/core/buffers.py
@@ -32,9 +32,10 @@ class ReplayBuffer:
     :type capacity: int
     """
 
-    def __init__(self, capacity: int):
+    def __init__(self, capacity: int, device="cpu"):
         self.capacity = capacity
         self.memory = deque([], maxlen=capacity)
+        self.device = device
 
     def push(self, inp: Tuple) -> None:
         """
@@ -60,7 +61,7 @@ def sample(
         batch = random.sample(self.memory, batch_size)
         state, action, reward, next_state, done = map(np.stack, zip(*batch))
         return [
-            torch.from_numpy(v).float()
+            torch.from_numpy(v).float().to(self.device)
             for v in [state, action, reward, next_state, done]
         ]
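Note (reviewer addition, not part of the patch): a minimal usage sketch of the device-aware ReplayBuffer changed above. The constructor, push(), and sample() signatures are taken from this diff; the capacity, batch size, and dummy transition values are illustrative assumptions only.

import numpy as np
import torch

from genrl.core.buffers import ReplayBuffer

# "cuda" is only used if available; any capacity/batch size works.
device = "cuda" if torch.cuda.is_available() else "cpu"
buffer = ReplayBuffer(10000, device=device)

# Push a few dummy (state, action, reward, next_state, done) transitions.
for _ in range(64):
    state = np.random.randn(4).astype(np.float32)
    action = np.array([0.0], dtype=np.float32)
    next_state = np.random.randn(4).astype(np.float32)
    buffer.push((state, action, 1.0, next_state, False))

# sample() now returns tensors that already live on `device`, so the Q-loss
# code path no longer needs per-batch .to(self.device) calls.
states, actions, rewards, next_states, dones = buffer.sample(32)
print(states.device, rewards.device)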
diff --git a/genrl/core/rollout_storage.py b/genrl/core/rollout_storage.py
index 16d1c721..9d13ecfd 100644
--- a/genrl/core/rollout_storage.py
+++ b/genrl/core/rollout_storage.py
@@ -133,8 +133,8 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> torch.Tensor:
         :return: (torch.Tensor)
         """
         if copy:
-            return array.detach().clone()
-        return array
+            return array.detach().clone().to(self.device)
+        return array.to(self.device)
 
 
 class RolloutBuffer(BaseBuffer):
@@ -173,17 +173,27 @@ def __init__(
 
     def reset(self) -> None:
         self.observations = torch.zeros(
-            *(self.buffer_size, self.env.n_envs, *self.env.obs_shape)
+            *(self.buffer_size, self.env.n_envs, *self.env.obs_shape),
+            device=self.device
         )
         self.actions = torch.zeros(
-            *(self.buffer_size, self.env.n_envs, *self.env.action_shape)
+            *(self.buffer_size, self.env.n_envs, *self.env.action_shape),
+            device=self.device
+        )
+        self.rewards = torch.zeros(
+            self.buffer_size, self.env.n_envs, device=self.device
+        )
+        self.returns = torch.zeros(
+            self.buffer_size, self.env.n_envs, device=self.device
+        )
+        self.dones = torch.zeros(self.buffer_size, self.env.n_envs, device=self.device)
+        self.values = torch.zeros(self.buffer_size, self.env.n_envs, device=self.device)
+        self.log_probs = torch.zeros(
+            self.buffer_size, self.env.n_envs, device=self.device
+        )
+        self.advantages = torch.zeros(
+            self.buffer_size, self.env.n_envs, device=self.device
         )
-        self.rewards = torch.zeros(self.buffer_size, self.env.n_envs)
-        self.returns = torch.zeros(self.buffer_size, self.env.n_envs)
-        self.dones = torch.zeros(self.buffer_size, self.env.n_envs)
-        self.values = torch.zeros(self.buffer_size, self.env.n_envs)
-        self.log_probs = torch.zeros(self.buffer_size, self.env.n_envs)
-        self.advantages = torch.zeros(self.buffer_size, self.env.n_envs)
         self.generator_ready = False
         super(RolloutBuffer, self).reset()
 
@@ -210,12 +220,12 @@ def add(
             # Reshape 0-d tensor to avoid error
             log_prob = log_prob.reshape(-1, 1)
 
-        self.observations[self.pos] = obs.detach().clone()
-        self.actions[self.pos] = action.detach().clone()
-        self.rewards[self.pos] = reward.detach().clone()
-        self.dones[self.pos] = done.detach().clone()
-        self.values[self.pos] = value.detach().clone().flatten()
-        self.log_probs[self.pos] = log_prob.detach().clone().flatten()
+        self.observations[self.pos] = obs.detach().clone().to(self.device)
+        self.actions[self.pos] = action.detach().clone().to(self.device)
+        self.rewards[self.pos] = reward.detach().clone().to(self.device)
+        self.dones[self.pos] = done.detach().clone().to(self.device)
+        self.values[self.pos] = value.detach().clone().flatten().to(self.device)
+        self.log_probs[self.pos] = log_prob.detach().clone().flatten().to(self.device)
         self.pos += 1
         if self.pos == self.buffer_size:
             self.full = True
diff --git a/genrl/trainers/offpolicy.py b/genrl/trainers/offpolicy.py
index 7e0571c2..a95324c2 100644
--- a/genrl/trainers/offpolicy.py
+++ b/genrl/trainers/offpolicy.py
@@ -65,7 +65,9 @@ def __init__(
         self.buffer = self.agent.replay_buffer
 
     def noise_reset(self) -> None:
-        """Resets the agent's action noise functions"""
+        """
+        Resets the agent's action noise functions
+        """
         if "noise" in self.agent.__dict__ and self.agent.noise is not None:
             self.agent.noise.reset()
 
diff --git a/genrl/utils/discount.py b/genrl/utils/discount.py
index 54d8754d..aeec56ba 100644
--- a/genrl/utils/discount.py
+++ b/genrl/utils/discount.py
@@ -29,23 +29,26 @@ def compute_returns_and_advantage(
     if use_gae:
         gae_lambda = rollout_buffer.gae_lambda
     else:
-        gae_lambda = 1
+        gae_lambda = 1.0
 
-    next_values = last_value
-    next_non_terminal = 1 - dones
+    next_value = last_value
+    next_non_terminal = 1.0 - dones
     running_advantage = 0.0
 
     for step in reversed(range(rollout_buffer.buffer_size)):
         delta = (
             rollout_buffer.rewards[step]
-            + rollout_buffer.gamma * next_non_terminal * next_values
+            + rollout_buffer.gamma * next_value * next_non_terminal
             - rollout_buffer.values[step]
         )
         running_advantage = (
-            delta + rollout_buffer.gamma * gae_lambda * running_advantage
+            delta
+            + rollout_buffer.gamma * gae_lambda * next_non_terminal * running_advantage
         )
         next_non_terminal = 1 - rollout_buffer.dones[step]
-        next_values = rollout_buffer.values[step]
+        next_value = rollout_buffer.values[step]
         rollout_buffer.advantages[step] = running_advantage
 
     rollout_buffer.returns = rollout_buffer.advantages + rollout_buffer.values
+
+    return rollout_buffer.returns, rollout_buffer.advantages
diff --git a/tests/test_agents/test_bandit/__init__.py b/tests/test_agents/test_bandit/__init__.py
index 4411dff3..8faedc3d 100644
--- a/tests/test_agents/test_bandit/__init__.py
+++ b/tests/test_agents/test_bandit/__init__.py
@@ -1,6 +1,6 @@
 from tests.test_agents.test_bandit.test_cb_agents import TestCBAgent  # noqa
 from tests.test_agents.test_bandit.test_data_bandits import TestDataBandit  # noqa
 from tests.test_agents.test_bandit.test_mab_agents import TestMABAgent  # noqa
-from tests.test_agents.test_bandit.test_multi_armed_bandits import (
-    TestMultiArmedBandit,  # noqa
+from tests.test_agents.test_bandit.test_multi_armed_bandits import (  # noqa
+    TestMultiArmedBandit,
 )
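
Note (reviewer addition, not part of the patch): a standalone sketch of the GAE recursion that genrl/utils/discount.py computes after this change, written against plain (T, n_envs) tensors instead of a RolloutBuffer, and returning (returns, advantages) as the patched function now does. Variable names mirror the diff; the helper name `gae`, its argument list, and the toy inputs are illustrative assumptions.

import torch


def gae(rewards, values, dones, last_value, last_dones, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation, mirroring compute_returns_and_advantage."""
    advantages = torch.zeros_like(rewards)
    running_advantage = 0.0
    next_value = last_value  # bootstrap value for the step after the rollout
    next_non_terminal = 1.0 - last_dones
    for step in reversed(range(rewards.shape[0])):
        delta = (
            rewards[step]
            + gamma * next_value * next_non_terminal
            - values[step]
        )
        running_advantage = (
            delta + gamma * gae_lambda * next_non_terminal * running_advantage
        )
        next_non_terminal = 1.0 - dones[step]
        next_value = values[step]
        advantages[step] = running_advantage
    returns = advantages + values
    return returns, advantages


# Toy check with a 5-step, 2-env rollout.
T, n_envs = 5, 2
rewards = torch.rand(T, n_envs)
values = torch.rand(T, n_envs)
dones = torch.zeros(T, n_envs)
returns, advantages = gae(
    rewards, values, dones, torch.rand(n_envs), torch.zeros(n_envs)
)
print(returns.shape, advantages.shape)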