diff --git a/genrl/agents/multiagent/base/offpolicy.py b/genrl/agents/multiagent/base/offpolicy.py
new file mode 100644
index 00000000..d755b6f5
--- /dev/null
+++ b/genrl/agents/multiagent/base/offpolicy.py
@@ -0,0 +1,34 @@
+import collections
+from abc import ABC
+
+import torch
+import torch.nn as nn
+import torch.optim as opt
+
+from genrl.core import MultiAgentReplayBuffer
+from genrl.utils import PettingZooInterface
+
+
+class MultiAgentOffPolicy(ABC):
+    """Base class for multiagent algorithms with OffPolicy agents
+
+    Attributes:
+        network (str): The network type of the Q-value function.
+            Supported types: ["cnn", "mlp"]
+        env (Environment): The environment that the agent is supposed to act on
+        agents (list): A list of all the agents to be used
+        create_model (bool): Whether the model of the algo should be created when initialised
+        batch_size (int): Mini batch size for loading experiences
+        gamma (float): The discount factor for rewards
+        layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
+            of the Q-value function
+        lr_policy (float): Learning rate for the policy/actor
+        lr_value (float): Learning rate for the Q-value function
+        replay_size (int): Capacity of the Replay Buffer
+        seed (int): Seed for randomness
+        render (bool): Should the env be rendered during training?
+        device (str): Hardware being used for training. Options:
+            ["cuda" -> GPU, "cpu" -> CPU]
+    """
+
+    def __init__(self, *args, **kwargs):
+        raise NotImplementedError
diff --git a/genrl/agents/multiagent/maddpg/__init__.py b/genrl/agents/multiagent/maddpg/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/genrl/agents/multiagent/maddpg/maddpg.py b/genrl/agents/multiagent/maddpg/maddpg.py
new file mode 100644
index 00000000..a592cf55
--- /dev/null
+++ b/genrl/agents/multiagent/maddpg/maddpg.py
@@ -0,0 +1,274 @@
+from abc import ABC
+from copy import deepcopy
+from typing import Any, Tuple
+
+import numpy as np
+import torch
+import torch.optim as opt
+from torch.nn import functional as F
+
+from genrl.agents import DDPG
+from genrl.core import ActionNoise, MultiAgentReplayBuffer, ReplayBufferSamples
+from genrl.utils import PettingZooInterface, get_model
+
+
+class MADDPG(ABC):
+    """MultiAgent Controller using the MADDPG algorithm
+
+    Attributes:
+        network (str): The network type of the Q-value function.
+            Supported types: ["mlp"]
+        env (Environment): The environment that the agent is supposed to act on
+        create_model (bool): Whether the model of the algo should be created when initialised
+        batch_size (int): Mini batch size for loading experiences
+        gamma (float): The discount factor for rewards
+        shared_layers (:obj:`tuple` of :obj:`int`): Sizes of shared layers in Actor Critic if using
+        policy_layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
+            of the policy/actor
+        value_layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
+            of the Q-value function
+        lr_policy (float): Learning rate for the policy/actor
+        lr_value (float): Learning rate for the critic
+        replay_size (int): Capacity of the Replay Buffer
+        polyak (float): Target model update parameter (1 for hard update)
+        noise (:obj:`ActionNoise`): Action Noise function added to aid in exploration
+        noise_std (float): Standard deviation of the action noise distribution
+        max_ep_len (int): Maximum Episode length for training
+        max_timesteps (int): Maximum limit of timesteps to train for
+        warmup_steps (int): Number of warmup steps (random actions are taken to add randomness to the training)
+        start_update (int): Timesteps after which the agent networks should start updating
+        update_interval (int): Timesteps between target network updates
+        seed (int): Seed for randomness
+        render (bool): Should the env be rendered during training?
+        device (str): Hardware being used for training. Options:
+            ["cuda" -> GPU, "cpu" -> CPU]
+    """
+
+    def __init__(
+        self,
+        network: Any,
+        env: Any,
+        batch_size: int = 64,
+        gamma: float = 0.99,
+        shared_layers=None,
+        policy_layers: Tuple = (64, 64),
+        value_layers: Tuple = (64, 64),
+        lr_policy: float = 0.0001,
+        lr_value: float = 0.001,
+        replay_size: int = int(1e6),
+        polyak: float = 0.995,
+        noise: ActionNoise = None,
+        noise_std: float = 0.2,
+        max_ep_len: int = 200,
+        max_timesteps: int = 5000,
+        warmup_steps: int = 1000,
+        start_update: int = 1000,
+        update_interval: int = 50,
+        **kwargs,
+    ):
+        self.noise = noise
+        self.doublecritic = False
+        self.noise_std = noise_std
+        self.gamma = gamma
+        self.env = env
+        self.network = network
+        self.batch_size = batch_size
+        self.lr_value = lr_value
+        self.num_agents = self.env.num_agents
+        self.replay_buffer = MultiAgentReplayBuffer(self.num_agents, replay_size)
+        self.polyak = polyak
+        self.render = kwargs.get("render", False)
+        self.device = kwargs.get("device", "cpu")
+        self.warmup_steps = warmup_steps
+        self.start_update = start_update
+        self.update_interval = update_interval
+        self.shared_layers = shared_layers
+        self.policy_layers = policy_layers
+        self.value_layers = value_layers
+        self.max_ep_len = max_ep_len
+        self.max_timesteps = max_timesteps
+        # The interface is needed by `_create_model`, so it is created first and
+        # the agents are registered on it once they have been built
+        self.EnvInterface = PettingZooInterface(self.env, [])
+        ac = self._create_model()
+        self.agents = [
+            DDPG(
+                network=deepcopy(ac),
+                env=env,
+                lr_policy=lr_policy,
+                lr_value=lr_value,
+                gamma=gamma,
+            )
+            for agent in self.env.agents
+        ]
+        self.EnvInterface.agents_list = self.agents
+
+    def _create_model(self):
+        state_dim, action_dim, discrete, _ = self.EnvInterface.get_env_properties()
+        if discrete:
+            raise Exception(
+                "Discrete Environments not supported for {}.".format(__class__.__name__)
+            )
+
+        if self.noise is not None:
+            self.noise = self.noise(
+                torch.zeros(action_dim), self.noise_std * torch.ones(action_dim)
+            )
+
+        if isinstance(self.network, str):
+            arch_type = self.network
+            arch_type += "c"
+            if self.shared_layers is not None:
+                raise NotImplementedError
+            ac = get_model("ac", arch_type)(
+                state_dim,
+                action_dim,
+                self.num_agents,
+                self.shared_layers,
+                self.policy_layers,
+                self.value_layers,
+                "Qsa",
+                False,
+            ).to(self.device)
+        else:
+            ac = self.network
+
+        return ac
+
+    def get_target_q_values(self, agent, agent_batch, segmented_batch):
+        global_next_actions = [
+            other_agent.ac_target.get_action(
+                segmented_batch[3][:, i, :], deterministic=True
+            )[0].detach()
+            for i, other_agent in enumerate(self.agents)
+        ]
+        global_next_actions = torch.cat(global_next_actions, dim=1)
+        global_next_actions = global_next_actions.float()
+
+        if self.doublecritic:
+            next_q_target_values = agent.ac_target.get_value(
+                torch.cat([agent_batch.next_states, global_next_actions], dim=-1),
+                mode="min",
+            )
+        else:
+            next_q_target_values = agent.ac_target.get_value(
+                torch.cat([agent_batch.next_states, global_next_actions], dim=-1)
+            )
+
+        target_q_values = (
+            agent_batch.rewards
+            + self.gamma * (1 - agent_batch.dones) * next_q_target_values
+        )
+
+        return target_q_values
+
+    def get_q_loss(self, agent, agent_batch, segmented_batch):
+        q_values = agent.get_q_values(agent_batch.states, agent_batch.actions)
+        target_q_values = self.get_target_q_values(agent, agent_batch, segmented_batch)
+
+        if self.doublecritic:
+            loss = F.mse_loss(q_values[0], target_q_values) + F.mse_loss(
+                q_values[1], target_q_values
+            )
+        else:
+            loss = F.mse_loss(q_values, target_q_values)
+
+        return loss
+
+    def get_p_loss(self, agent, global_state_batch, segmented_states_batch):
+        global_next_best_actions = [
+            other_agent.ac.get_action(
+                segmented_states_batch[:, i, :], deterministic=True
+            )[0]
+            for i, other_agent in enumerate(self.agents)
+        ]
+        global_next_best_actions = torch.cat(global_next_best_actions, dim=1)
+        global_next_best_actions = global_next_best_actions.float()
+
+        q_values = agent.ac.get_value(
+            torch.cat([global_state_batch, global_next_best_actions], dim=-1)
+        )
+        policy_loss = -torch.mean(q_values)
+        return policy_loss
+
+    def update(self):
+        segmented_batch, global_batch = self.replay_buffer.sample(self.batch_size)
+
+        # Convert every per-agent dict in the segmented transitions into an
+        # array with a leading agent axis
+        for transition in segmented_batch:
+            for i, _ in enumerate(transition):
+                transition[i] = self.EnvInterface.flatten(transition[i])
+
+        (
+            segmented_states,
+            segmented_actions,
+            segmented_rewards,
+            segmented_next_states,
+            segmented_dones,
+        ) = map(np.stack, zip(*segmented_batch))
+        segmented_batch = [
+            torch.from_numpy(v).float()
+            for v in [
+                segmented_states,
+                segmented_actions,
+                segmented_rewards,
+                segmented_next_states,
+                segmented_dones,
+            ]
+        ]
+
+        for i, agent in enumerate(self.agents):
+            agent_rewards_v = torch.reshape(global_batch[2][:, i], (self.batch_size, 1))
+            agent_dones_v = torch.reshape(global_batch[4][:, i], (self.batch_size, 1))
+            agent_batch_v = ReplayBufferSamples(
+                *[
+                    global_batch[0],
+                    global_batch[1],
+                    agent_rewards_v,
+                    global_batch[3],
+                    agent_dones_v,
+                ]
+            )
+            value_loss = self.get_q_loss(
+                agent=agent, agent_batch=agent_batch_v, segmented_batch=segmented_batch
+            )
+
+            agent.optimizer_value.zero_grad()
+            value_loss.backward()
+            agent.optimizer_value.step()
+
+            policy_loss = self.get_p_loss(agent, global_batch[0], segmented_batch[0])
+
+            agent.optimizer_policy.zero_grad()
+            policy_loss.backward()
+            agent.optimizer_policy.step()
+
+        for agent in self.agents:
+            agent.update_target_model()
+
+    def train(self):
+        episode_rewards = []
+        timestep = 0
+        for episode in range(self.max_timesteps // self.max_ep_len):
+            states = self.env.reset()
+            episode_reward = 0
+            for step in range(self.max_ep_len):
+                if self.render:
+                    self.env.render(mode="human")
+
+                actions = self.EnvInterface.get_actions(
+                    states,
+                    timestep,
+                    self.warmup_steps,
+                    type="offpolicy",
+                    deterministic=True,
+                    noise=self.noise,
+                )
+                next_states, rewards, dones, _ = self.env.step(actions)
+                timestep += 1
+                step_rewards = self.EnvInterface.flatten(rewards)
+                episode_reward += np.mean(step_rewards)
+                step_dones = self.EnvInterface.flatten(dones)
+                if all(step_dones) or step == self.max_ep_len - 1:
+                    dones = {agent: True for agent in self.env.agents}
+                    self.replay_buffer.push(
+                        [states, actions, rewards, next_states, dones]
+                    )
+                    episode_rewards.append(episode_reward)
+                    print(
+                        f"Episode: {episode + 1} | Steps Taken: {step + 1} | Reward: {episode_reward}"
+                    )
+                    break
+                else:
+                    dones = {agent: False for agent in self.env.agents}
+
+                    self.replay_buffer.push(
+                        [states, actions, rewards, next_states, dones]
+                    )
+                    states = next_states
+
+                if timestep >= self.start_update and timestep % self.update_interval == 0:
+                    self.update()
diff --git a/genrl/core/__init__.py b/genrl/core/__init__.py
index 96824ac1..ca118516 100644
--- a/genrl/core/__init__.py
+++ b/genrl/core/__init__.py
@@ -5,6 +5,7 @@
 from genrl.core.buffers import PrioritizedReplayBufferSamples  # noqa
 from genrl.core.buffers import ReplayBuffer  # noqa
 from genrl.core.buffers import ReplayBufferSamples  # noqa
+from genrl.core.buffers import MultiAgentReplayBuffer  # noqa
 from genrl.core.noise import ActionNoise  # noqa
 from genrl.core.noise import NoisyLinear  # noqa
 from genrl.core.noise import NormalActionNoise  # noqa
@@ -16,6 +17,7 @@
     get_policy_from_name,
 )
 from genrl.core.rollout_storage import RolloutBuffer  # noqa
+from genrl.core.rollout_storage import MultiAgentRolloutBuffer  # noqa
 from genrl.core.values import (  # noqa
     BaseValue,
     CnnCategoricalValue,
diff --git a/genrl/core/actor_critic.py b/genrl/core/actor_critic.py
index 1a96e959..e5132278 100644
--- a/genrl/core/actor_critic.py
+++ b/genrl/core/actor_critic.py
@@ -469,12 +469,53 @@ def get_value(self, inp: torch.Tensor) -> torch.Tensor:
         return value
 
 
+class MlpActorCentralCritic(BaseActorCritic):
+    """MLP Actor Central Critic
+
+    Attributes:
+        state_dim (int): State dimensions of a single agent in the environment
+        action_dim (int): Action space dimensions of a single agent in the environment
+        n_agents (int): Number of agents in the environment
+        shared_layers (:obj:`list` or :obj:`tuple`): Hidden layers shared between actor and critic (currently unused)
+        policy_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the policy MLP
+        value_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the value MLP
+        val_type (str): Value type of the critic network
+        discrete (bool): True if the action space is discrete, else False
+        sac (bool): True if a SAC-like network is needed, else False
+        activation (str): Activation function to be used.
Can be either "tanh" or "relu" + """ + + def __init__( + self, + state_dim: spaces.Space, + action_dim: spaces.Space, + n_agents: int, + shared_layers: Tuple = (32, 32), + policy_layers: Tuple = (32, 32), + value_layers: Tuple = (32, 32), + val_type: str = "V", + discrete: bool = True, + **kwargs, + ): + super(MlpActorCentralCritic, self).__init__() + + self.actor = MlpPolicy(state_dim, action_dim, policy_layers, discrete, **kwargs) + self.critic = MlpValue( + n_agents * state_dim, n_agent * action_dim, val_type, value_layers, **kwargs + ) + + def get_params(self): + actor_params = self.actor.parameters() + critic_params = self.critic.parameters() + return actor_params, critic_params + + actor_critic_registry = { "mlp": MlpActorCritic, "cnn": CNNActorCritic, "mlp12": MlpSingleActorTwoCritic, "mlps": MlpSharedActorCritic, "mlp12s": MlpSharedSingleActorTwoCritic, + "mlpc": MlpActorCentralCritic, } diff --git a/genrl/core/buffers.py b/genrl/core/buffers.py index 0a5b6e7c..e5074892 100644 --- a/genrl/core/buffers.py +++ b/genrl/core/buffers.py @@ -146,15 +146,7 @@ def sample( return [ torch.as_tensor(v, dtype=torch.float32) - for v in [ - states, - actions, - rewards, - next_states, - dones, - indices, - weights, - ] + for v in [states, actions, rewards, next_states, dones, indices, weights,] ] def update_priorities(self, batch_indices: Tuple, batch_priorities: Tuple) -> None: @@ -181,3 +173,87 @@ def __len__(self) -> int: @property def pos(self): return len(self.buffer) + + +class MultiAgentReplayBuffer: + """ + Implements the basic Experience Replay Mechanism for MultiAgents + by feeding in global states, global actions, global rewards, + global next_states, global dones + :param capacity: Size of the replay buffer + :type capacity: int + :param num_agents: Number of agents in the environment + :type num_agents: int + """ + + def __init__(self, num_agents: int, capacity: int): + """ + Initialising the buffer + :param num_agents: number of agents in the environment + :type num_agents: int + :param capacity: Max buffer size + :type capacity: int + """ + self.capacity = capacity + self.num_agents = num_agents + self.memory = deque(maxlen=self.capacity) + + def push(self, inp: list) -> None: + """ + Adds new experience to buffer + :param inp: (Tuple containing `state`, `action`, `reward`, + `next_state` and `done`) + :type inp: tuple + :returns: None + """ + self.memory.append(inp) + + def sample(self, batch_size: int): + """Returns randomly sampled experiences from the replay memory + + Args: + batch_size (int): Number of samples per batch + """ + batch = random.sample(self.memory, batch_size) + segmented_batch = map(np.stack, zip(*batch)) + + for transition in batch: + for i, _ in enumerate(transition): + if i == 0 or i == 1 or i == 3: + transition[i] = np.concatenate( + np.array( + [transition[i][agent] for agent in transition[i].keys()] + ) + ) + else: + transition[i] = np.array( + [transition[i][agent] for agent in transition[i].keys()] + ) + + ( + global_states, + global_actions, + global_rewards, + global_next_states, + global_dones, + ) = map(np.stack, zip(*batch)) + + global_batch = [ + torch.from_numpy(v).float() + for v in [ + global_states, + global_actions, + global_rewards, + global_next_states, + global_dones, + ] + ] + + return segmented_batch, global_batch + + def __len__(self): + """ + Gives number of experiences in buffer currently + :returns: Length of replay memory + """ + return len(self.buffer) diff --git a/genrl/core/rollout_storage.py b/genrl/core/rollout_storage.py 
index 16d1c721..616434a2 100644
--- a/genrl/core/rollout_storage.py
+++ b/genrl/core/rollout_storage.py
@@ -102,8 +102,7 @@ def reset(self) -> None:
         self.full = False
 
     def sample(
-        self,
-        batch_size: int,
+        self, batch_size: int,
     ):
         """
         :param batch_size: (int) Number of element to sample
@@ -114,8 +113,7 @@
         return self._get_samples(batch_inds)
 
     def _get_samples(
-        self,
-        batch_inds: np.ndarray,
+        self, batch_inds: np.ndarray,
     ):
         """
         :param batch_inds: (torch.Tensor)
@@ -257,3 +255,202 @@ def _get_samples(self, batch_inds: np.ndarray) -> RolloutBufferSamples:
             self.returns[batch_inds].flatten(),
         )
         return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
+
+
+class MultiAgentRolloutBuffer(BaseBuffer):
+    """
+    Rollout buffer used in on-policy algorithms like MAA2C/MAA3C.
+
+    :param num_agents: (int) Max number of agents in the environment
+    :param buffer_size: (int) Max number of element in the buffer
+    :param env: (Environment) The environment being trained on
+    :param device: (torch.device)
+    :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
+        Equivalent to classic advantage when set to 1.
+    :param gamma: (float) Discount factor
+    """
+
+    def __init__(
+        self,
+        num_agents: int,
+        buffer_size: int,
+        env,
+        device: Union[torch.device, str] = "cpu",
+        gae_lambda: float = 1,
+        gamma: float = 0.99,
+    ):
+        super(MultiAgentRolloutBuffer, self).__init__(buffer_size, env, device)
+
+        self.buffer_size = buffer_size
+        self.num_agents = num_agents
+        self.env = env
+        self.device = device
+        self.gae_lambda = gae_lambda
+        self.gamma = gamma
+
+        self.observations, self.actions, self.rewards, self.advantages = (
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+        )
+        self.returns, self.dones, self.values, self.log_probs = (
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+        )
+        self.generator_ready = False
+        self.reset()
+
+    def reset(self) -> None:
+        self.observations = torch.zeros(
+            *(self.buffer_size, self.env.n_envs, self.num_agents, *self.env.obs_shape)
+        )
+        self.actions = torch.zeros(
+            *(
+                self.buffer_size,
+                self.env.n_envs,
+                self.num_agents,
+                *self.env.action_shape,
+            )
+        )
+        self.rewards = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents)
+        self.returns = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents)
+        self.dones = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents)
+        self.values = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents)
+        self.log_probs = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents)
+        self.advantages = torch.zeros(
+            self.buffer_size, self.env.n_envs, self.num_agents
+        )
+        self.generator_ready = False
+        super(MultiAgentRolloutBuffer, self).reset()
+
+    def add(
+        self,
+        obs: torch.Tensor,
+        action: torch.Tensor,
+        reward: torch.Tensor,
+        done: torch.Tensor,
+        value: torch.Tensor,
+        log_prob: torch.Tensor,
+    ) -> None:
+        """
+        :param obs: (torch.Tensor) Observation
+        :param action: (torch.Tensor) Action
+        :param reward: (torch.Tensor)
+        :param done: (torch.Tensor) End of episode signal.
+        :param value: (torch.Tensor) estimated value of the current state
+            following the current policy.
+        :param log_prob: (torch.Tensor) log probability of the action
+            following the current policy.
+ """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + self.observations[self.pos] = obs.detach().clone() + self.actions[self.pos] = action.squeeze().detach().clone() + self.rewards[self.pos] = reward.detach().clone() + self.dones[self.pos] = done.detach().clone() + self.values[self.pos] = ( + value.detach().clone().flatten().reshape(-1, self.num_agents) + ) + self.log_probs[self.pos] = ( + log_prob.detach().clone().flatten().reshape(-1, self.num_agents) + ) + self.pos += 1 + + if self.pos == self.buffer_size: + self.full = True + + def get( + self, batch_size: Optional[int] = None + ) -> Generator[RolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.env.n_envs) + # Prepare the data + if not self.generator_ready: + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.env.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.env.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray) -> RolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten().reshape(-1, self.num_agents), + self.log_probs[batch_inds].flatten().reshape(-1, self.num_agents), + self.advantages[batch_inds].flatten().reshape(-1, self.num_agents), + self.returns[batch_inds].flatten().reshape(-1, self.num_agents), + ) + return RolloutBufferSamples(*tuple(map(self.to_torch, data))) + + def compute_returns_and_advantage( + self, last_value: torch.Tensor, dones: torch.zeros, use_gae: bool = False + ) -> None: + """ + Post-processing step: compute the returns (sum of discounted rewards) + and advantage (A(s) = R - V(S)). + Adapted from Stable-Baselines PPO2. + :param last_value: (torch.Tensor) + :param dones: (torch.zeros) + :param use_gae: (bool) Whether to use Generalized Advantage Estimation + or normal advantage for advantage computation. 
+ """ + last_value = last_value.flatten().reshape(-1, self.num_agents) + + if use_gae: + last_gae_lam = 0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - dones + next_value = last_value + else: + next_non_terminal = 1.0 - self.dones[step + 1] + next_value = self.values[step + 1] + delta = ( + self.rewards[step] + + self.gamma * next_value * next_non_terminal + - self.values[step] + ) + last_gae_lam = ( + delta + + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam + ) + self.advantages[step] = last_gae_lam + self.returns = self.advantages + self.values + else: + # Discounted return with value bootstrap + # Note: this is equivalent to GAE computation + # with gae_lambda = 1.0 + last_return = 0.0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - dones + next_value = last_value + last_return = self.rewards[step] + next_non_terminal * next_value + else: + next_non_terminal = 1.0 - self.dones[step + 1] + last_return = ( + self.rewards[step] + + self.gamma * last_return * next_non_terminal + ) + self.returns[step] = last_return + self.advantages = self.returns - self.values diff --git a/genrl/utils/__init__.py b/genrl/utils/__init__.py index b7f4070d..4eed6ff9 100644 --- a/genrl/utils/__init__.py +++ b/genrl/utils/__init__.py @@ -20,4 +20,6 @@ noisy_mlp, safe_mean, set_seeds, + onehot_from_logits ) +from genrl.utils.pettingzoo_interface import PettingZooInterface # noqa diff --git a/genrl/utils/pettingzoo_interface.py b/genrl/utils/pettingzoo_interface.py new file mode 100644 index 00000000..586feafd --- /dev/null +++ b/genrl/utils/pettingzoo_interface.py @@ -0,0 +1,85 @@ +from abc import ABC +from typing import Any, Dict, Tuple + +import gym +import numpy as np + + +class PettingZooInterface(ABC): + """ + An interface between the PettingZoo API and agents defined in GenRL + + Attributes: + + env (PettingZoo Environment) : The environments in which the agents are acting + agents_list (list) : A list containing all the agent objects present in the environment + """ + + def __init__(self, env: Any, agents_list: list): + self.env = env + self.agents_list = agents_list + + def get_env_properties(self, network: str): + state_dim = list(self.env.observation_spaces.values())[0].shape[0] + if isinstance(list(self.env.action_spaces.vales())[0], gym.spaces.Discrete): + discrete = True + action_dim = list(self.env.action_spaces.values())[0].n + action_lim = None + elif isinstance(list(self.env.action_spaces.values())[0], gym.spaces.Box): + discrete = False + action_dim = list(self.env.action_spaces.values())[0].shape[0] + action_lim = list(self.env.action_spaces.values())[0].high[0] + else: + raise NotImplementedError + + return state_dim, action_dim, discrete, action_lim + + def select_offpolicy_action( + self, state: np.ndarray, agent, noise, deterministic: bool = False + ): + action, _ = agent.ac.get_action(torch.tensor(state), deterministic) + action = action.detach() + + if noise is not None: + action += noise + + return torch.clamp( + action, + list(self.env.action_spaces.values())[0].low[0], + list(self.env.action_spaces.values())[0].high[0], + ).numpy() + + def select_onpolicy_action( + self, state: np.ndarray, agent, deterministic: bool = False + ): + raise NotImplementedError + + def get_actions( + self, + states: Dict[str, np.ndarray], + steps: int, + warmup_steps: int, + type: str, + deterministic: bool = False, + ): + if steps < warmup_steps: + actions = 
+            actions = {
+                agent: self.env.action_spaces[agent].sample() for agent in states
+            }
+        else:
+            if type == "offpolicy":
+                actions = {
+                    agent: self.select_offpolicy_action(
+                        states[agent], self.agents_list[i], noise, deterministic
+                    )
+                    for i, agent in enumerate(states)
+                }
+            elif type == "onpolicy":
+                raise NotImplementedError
+            else:
+                raise NotImplementedError
+
+        return actions
+
+    def flatten(self, obj: Dict):
+        flattened_object = np.array([obj[agent] for agent in self.env.agents])
+
+        return flattened_object
diff --git a/genrl/utils/utils.py b/genrl/utils/utils.py
index 89e53337..9d9123e7 100644
--- a/genrl/utils/utils.py
+++ b/genrl/utils/utils.py
@@ -37,9 +37,7 @@ def get_model(type_: str, name_: str) -> Union:
 
 
 def mlp(
-    sizes: Tuple,
-    activation: str = "relu",
-    sac: bool = False,
+    sizes: Tuple, activation: str = "relu", sac: bool = False,
 ):
     """
     Generates an MLP model given sizes of each layer
@@ -199,3 +197,21 @@ def safe_mean(log: Union[torch.Tensor, List[int]]):
     else:
         func = np.mean
     return func(log)
+
+
+def onehot_from_logits(logits, eps=0.0):
+    """Converts a batch of logits of shape (batch_size, num_actions) into
+    one-hot actions, acting greedily with probability (1 - eps)"""
+    # get best (according to current policy) actions in one-hot form
+    argmax_acs = (logits == logits.max(1, keepdim=True)[0]).float()
+    if eps == 0.0:
+        return argmax_acs
+    # get random actions in one-hot form
+    rand_acs = torch.eye(logits.shape[1])[
+        np.random.choice(range(logits.shape[1]), size=logits.shape[0])
+    ]
+    # chooses between best and random actions using epsilon greedy
+    return torch.stack(
+        [
+            argmax_acs[i] if r > eps else rand_acs[i]
+            for i, r in enumerate(torch.rand(logits.shape[0]))
+        ]
+    )
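
Usage note (not part of the diff): a minimal sketch of how the MADDPG controller added above might be driven end to end. The environment (simple_spread_v2), its keyword arguments, and the hyperparameters are illustrative assumptions, as is the deep import path (it presumes the intermediate multiagent packages ship __init__.py files); only network="mlp" is supported, and the environment must expose the dict-style observation_spaces/action_spaces and 4-tuple parallel step API that PettingZooInterface expects.

# Hypothetical usage sketch -- not part of the diff above.
# Assumes an older PettingZoo release whose parallel envs expose
# `observation_spaces` / `action_spaces` dicts and a 4-tuple `step()`.
from pettingzoo.mpe import simple_spread_v2

from genrl.agents.multiagent.maddpg.maddpg import MADDPG

env = simple_spread_v2.parallel_env(continuous_actions=True)
env.reset()  # populates env.agents / env.num_agents before MADDPG reads them

trainer = MADDPG(
    network="mlp",        # resolved to the "mlpc" actor-central-critic entry
    env=env,
    batch_size=64,
    replay_size=int(1e5),
    max_ep_len=25,        # MPE episodes run 25 steps by default
    max_timesteps=10000,
    warmup_steps=500,
)
trainer.train()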
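
A small, self-contained sanity check of the onehot_from_logits helper appended to genrl/utils/utils.py (assuming the corrected module-level signature onehot_from_logits(logits, eps=0.0) with batch-first logits):

# Quick check of onehot_from_logits; assumes the corrected module-level
# signature above and logits of shape (batch_size, num_actions).
import torch

from genrl.utils.utils import onehot_from_logits

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])

greedy = onehot_from_logits(logits)  # eps=0.0 -> pure argmax, row-wise one-hot
assert greedy.tolist() == [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]

mixed = onehot_from_logits(logits, eps=0.3)  # epsilon-greedy mixture
assert mixed.shape == logits.shape
assert torch.all(mixed.sum(dim=1) == 1.0)  # every row is still one-hot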