diff --git a/genrl/agents/__init__.py b/genrl/agents/__init__.py index 3257caff..ac5d4a25 100644 --- a/genrl/agents/__init__.py +++ b/genrl/agents/__init__.py @@ -15,6 +15,7 @@ NeuralNoiseSamplingAgent, ) from genrl.agents.bandits.contextual.variational import VariationalAgent # noqa +from genrl.agents.bandits.multiarmed.base import MABAgent # noqa from genrl.agents.bandits.multiarmed.bayesian import BayesianUCBMABAgent # noqa from genrl.agents.bandits.multiarmed.bernoulli_mab import BernoulliMAB # noqa from genrl.agents.bandits.multiarmed.epsgreedy import EpsGreedyMABAgent # noqa @@ -41,5 +42,4 @@ from genrl.agents.deep.sac.sac import SAC # noqa from genrl.agents.deep.td3.td3 import TD3 # noqa from genrl.agents.deep.vpg.vpg import VPG # noqa - -from genrl.agents.bandits.multiarmed.base import MABAgent # noqa; noqa; noqa +from genrl.agents.offline.bcq.bcq import BCQ # noqa diff --git a/genrl/agents/deep/base/base.py b/genrl/agents/deep/base/base.py index cf3e40d3..33ce8876 100644 --- a/genrl/agents/deep/base/base.py +++ b/genrl/agents/deep/base/base.py @@ -17,8 +17,9 @@ class BaseAgent(ABC): create_model (bool): Whether the model of the algo should be created when initialised batch_size (int): Mini batch size for loading experiences gamma (float): The discount factor for rewards - layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network - of the Q-value function + policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy + value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics + shared_layers(:obj:`tuple` of :obj:`int`): Sizes of shared layers in Actor Critic if using lr_policy (float): Learning rate for the policy/actor lr_value (float): Learning rate for the Q-value function seed (int): Seed for randomness diff --git a/genrl/agents/deep/base/offpolicy.py b/genrl/agents/deep/base/offpolicy.py index fa632e8a..794f819a 100644 --- a/genrl/agents/deep/base/offpolicy.py +++ b/genrl/agents/deep/base/offpolicy.py @@ -5,11 +5,10 @@ from torch.nn import functional as F from genrl.agents.deep.base import BaseAgent -from genrl.core import ( + +from genrl.core import ( # PrioritizedReplayBufferSamples,; ReplayBufferSamples, PrioritizedBuffer, - PrioritizedReplayBufferSamples, ReplayBuffer, - ReplayBufferSamples, ) @@ -23,8 +22,9 @@ class OffPolicyAgent(BaseAgent): create_model (bool): Whether the model of the algo should be created when initialised batch_size (int): Mini batch size for loading experiences gamma (float): The discount factor for rewards - layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network - of the Q-value function + policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy + value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics + shared_layers(:obj:`tuple` of :obj:`int`): Sizes of shared layers in Actor Critic if using lr_policy (float): Learning rate for the policy/actor lr_value (float): Learning rate for the Q-value function replay_size (int): Capacity of the Replay Buffer @@ -67,19 +67,6 @@ def update_target_model(self) -> None: """ raise NotImplementedError - def _reshape_batch(self, batch: List): - """Function to reshape experiences - - Can be modified for individual algorithm usage - - Args: - batch (:obj:`list`): List of experiences that are being replayed - - Returns: - batch (:obj:`list`): Reshaped experiences for replay - """ - return [*batch] - def sample_from_buffer(self, beta: float = None): """Samples 
experiences from the buffer and converts them into usable formats @@ -95,18 +82,6 @@ def sample_from_buffer(self, beta: float = None): else: batch = self.replay_buffer.sample(self.batch_size) - states, actions, rewards, next_states, dones = self._reshape_batch(batch) - - # Convert every experience to a Named Tuple. Either Replay or Prioritized Replay samples. - if isinstance(self.replay_buffer, ReplayBuffer): - batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones]) - elif isinstance(self.replay_buffer, PrioritizedBuffer): - indices, weights = batch[5], batch[6] - batch = PrioritizedReplayBufferSamples( - *[states, actions, rewards, next_states, dones, indices, weights] - ) - else: - raise NotImplementedError return batch def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor: @@ -136,8 +111,9 @@ class OffPolicyAgentAC(OffPolicyAgent): create_model (bool): Whether the model of the algo should be created when initialised batch_size (int): Mini batch size for loading experiences gamma (float): The discount factor for rewards - layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network - of the Q-value function + policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy + value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics + shared_layers(:obj:`tuple` of :obj:`int`): Sizes of shared layers in Actor Critic if using lr_policy (float): Learning rate for the policy/actor lr_value (float): Learning rate for the Q-value function replay_size (int): Capacity of the Replay Buffer @@ -154,7 +130,7 @@ def __init__(self, *args, polyak=0.995, **kwargs): self.doublecritic = False def select_action( - self, state: torch.Tensor, deterministic: bool = True + self, state: torch.Tensor, deterministic: bool = True, noise: bool = True ) -> torch.Tensor: """Select action given state @@ -163,6 +139,7 @@ def select_action( Args: state (:obj:`torch.Tensor`): Current state of the environment deterministic (bool): Should the policy be deterministic or stochastic + noise (bool): Should noise be added to the agent Returns: action (:obj:`torch.Tensor`): Action taken by the agent @@ -171,7 +148,7 @@ def select_action( action = action.detach() # add noise to output from policy network - if self.noise is not None: + if noise and self.noise is not None: action += self.noise() return torch.clamp( @@ -210,7 +187,7 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Ten def get_target_q_values( self, next_states: torch.Tensor, rewards: List[float], dones: List[bool] ) -> torch.Tensor: - """Get target Q values for the TD3 + """Get target Q values Args: next_states (:obj:`torch.Tensor`): Next states for which target Q-values @@ -219,7 +196,7 @@ def get_target_q_values( dones (:obj:`list`): Game over status for each environment Returns: - target_q_values (:obj:`torch.Tensor`): Target Q values for the TD3 + target_q_values (:obj:`torch.Tensor`): Target Q values """ next_target_actions = self.ac_target.get_action(next_states, True)[0] @@ -265,7 +242,7 @@ def get_p_loss(self, states: torch.Tensor) -> torch.Tensor: Returns: loss (:obj:`torch.Tensor`): Calculated policy loss """ - next_best_actions = self.ac.get_action(states, True)[0] + next_best_actions = self.select_action(states, deterministic=True, noise=False) q_values = self.ac.get_value(torch.cat([states, next_best_actions], dim=-1)) policy_loss = -torch.mean(q_values) return policy_loss diff --git a/genrl/agents/deep/base/onpolicy.py 
b/genrl/agents/deep/base/onpolicy.py index 0c83cdfd..66cac380 100644 --- a/genrl/agents/deep/base/onpolicy.py +++ b/genrl/agents/deep/base/onpolicy.py @@ -36,7 +36,7 @@ def __init__( if buffer_type == "rollout": self.rollout = RolloutBuffer( - self.rollout_size, self.env, gae_lambda=gae_lambda + self.rollout_size, self.env, gae_lambda=gae_lambda, gamma=self.gamma ) else: raise NotImplementedError diff --git a/genrl/agents/deep/dqn/base.py b/genrl/agents/deep/dqn/base.py index 7642e541..09a149f8 100644 --- a/genrl/agents/deep/dqn/base.py +++ b/genrl/agents/deep/dqn/base.py @@ -56,7 +56,7 @@ def __init__( if self.create_model: self._create_model() - def _create_model(self, *args, **kwargs) -> None: + def _create_model(self, **kwargs) -> None: """Function to initialize Q-value model This will create the Q-value function of the agent. @@ -153,7 +153,7 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Ten q_values (:obj:`torch.Tensor`): Q values for the given states and actions """ q_values = self.model(states) - q_values = q_values.gather(2, actions) + q_values = q_values.gather(2, actions.unsqueeze(-1)) return q_values def get_target_q_values( diff --git a/genrl/agents/deep/dqn/utils.py b/genrl/agents/deep/dqn/utils.py index f1cb9015..739ddfde 100644 --- a/genrl/agents/deep/dqn/utils.py +++ b/genrl/agents/deep/dqn/utils.py @@ -102,8 +102,10 @@ def categorical_q_values(agent: DQN, states: torch.Tensor, actions: torch.Tensor # Size of q_value_dist should be [batch_size, n_envs, action_dim, num_atoms] here # To gather the q_values of the respective actions, actions must be of the shape: # [batch_size, n_envs, 1, num_atoms]. It's current shape is [batch_size, n_envs, 1] - actions = actions.unsqueeze(-1).expand( - agent.batch_size, agent.env.n_envs, 1, agent.num_atoms + actions = ( + actions.unsqueeze(-1) + .unsqueeze(-1) + .expand(agent.batch_size, agent.env.n_envs, 1, agent.num_atoms) ) # Now as we gather q_values from the action_dim dimension which is at index 2 q_values = q_value_dist.gather(2, actions) diff --git a/genrl/agents/offline/__init__.py b/genrl/agents/offline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/genrl/agents/offline/bcq/__init__.py b/genrl/agents/offline/bcq/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/genrl/agents/offline/bcq/bcq.py b/genrl/agents/offline/bcq/bcq.py new file mode 100644 index 00000000..a41574a5 --- /dev/null +++ b/genrl/agents/offline/bcq/bcq.py @@ -0,0 +1,252 @@ +import collections +from copy import deepcopy +from typing import Any, Dict, List, Tuple + +import torch +import torch.nn.functional as F +import torch.optim as opt + +from genrl.agents import OffPolicyAgentAC +from genrl.core.models import VAE +from genrl.core.noise import ActionNoise +from genrl.core.policies import MlpPolicy +from genrl.utils.utils import get_env_properties, get_model, safe_mean + + +class BCQ(OffPolicyAgentAC): + """Batch Constrained Q-Learning + + Paper: https://arxiv.org/abs/1812.02900 + + Attributes: + network (str): The network type of the Q-value function. 
+ Supported types: ["cnn", "mlp"] + env (Environment): The environment that the agent is supposed to act on + create_model (bool): Whether the model of the algo should be created when initialised + batch_size (int): Mini batch size for loading experiences + gamma (float): The discount factor for rewards + policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy + value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics + shared_layers (:obj:`tuple` of :obj:`int`): Sizes of shared layers in Actor Critic if using + vae_layers (:obj:`tuple` of :obj:`int`): Sizes of hidden layers in the VAE + lr_policy (float): Learning rate for the policy/actor + lr_value (float): Learning rate for the Q-value function + lr_vae (float): Learning rate for the VAE + replay_size (int): Capacity of the Replay Buffer + buffer_type (str): Choose the type of Buffer: ["push", "prioritized"] + noise (:obj:`ActionNoise`): Action Noise function added to aid in exploration + noise_std (float): Standard deviation of the action noise distribution + seed (int): Seed for randomness + render (bool): Should the env be rendered during training? + device (str): Hardware being used for training. Options: + ["cuda" -> GPU, "cpu" -> CPU] + """ + + def __init__( + self, + *args, + noise: ActionNoise = None, + noise_std: float = 0.2, + vae_layers: Tuple = (32, 32), + lr_vae: float = 0.001, + **kwargs + ): + super(BCQ, self).__init__(*args, **kwargs) + self.noise = noise + self.noise_std = noise_std + self.vae_layers = vae_layers + self.lr_vae = lr_vae + self.doublecritic = True + + self._create_model() + self.empty_logs() + + def _create_model(self) -> None: + """Function to initialize the BCQ Model + + This will create the BCQ Q-networks and the VAE + """ + state_dim, action_dim, discrete, action_lim = get_env_properties( + self.env, self.network + ) + if discrete: + raise Exception( + "Only continuous Environments are supported for the original BCQ. 
For discrete BCQ, use DiscreteBCQ instead" + ) + + if isinstance(self.network, str): + arch = self.network + "12" + if self.shared_layers is not None: + arch += "s" + self.ac = get_model("ac", arch)( + state_dim, + action_dim, + shared_layers=self.shared_layers, + policy_layers=self.policy_layers, + value_layers=self.value_layers, + val_type="Qsa", + discrete=False, + ) + self.ac.actor = MlpPolicy( + state_dim + action_dim, action_dim, self.policy_layers, discrete + ) + self.vae = VAE(state_dim, action_dim, action_lim, self.vae_layers) + else: + ( + self.ac, + self.vae, + ) = ( + self.network + ) # Network must be defined as a tuple of the Actor Critic Network and the VAE + + # Perturbation Model of the BCQ + if self.noise is not None: + self.noise = self.noise( + torch.zeros(action_dim), self.noise_std * torch.ones(action_dim) + ) + + self.ac_target = deepcopy(self.ac) + actor_params, critic_params = self.ac.get_params() + self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value) + self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy) + self.optimizer_vae = opt.Adam(self.vae.parameters(), lr=self.lr_vae) + + def select_action( + self, + state: torch.Tensor, + deterministic: bool = True, + noise: bool = True, + ) -> torch.Tensor: + """Select action given state + + Deterministic Action Selection with Noise + + Args: + state (:obj:`torch.Tensor`): Current state of the environment + deterministic (bool): Should the policy be deterministic or stochastic + noise (bool): Should noise be added to the selected action + + Returns: + action (:obj:`torch.Tensor`): Action taken by the agent + """ + action, _ = self.ac.get_action( + torch.cat([state, self.vae.decode(state)], dim=-1), deterministic + ) + action = action.detach() + + # add noise to output from policy network + if noise and self.noise is not None: + action += self.noise() + + return torch.clamp( + action, self.env.action_space.low[0], self.env.action_space.high[0] + ) + + def get_vae_loss(self) -> torch.Tensor: + """Function to calculate the loss of the VAE used in BCQ + + Returns: + loss (:obj:`torch.Tensor`): Calculated loss of the BCQ's VAE + """ + recon, mean, std = self.vae(self.batch.states, self.batch.actions) + recon_loss = F.mse_loss(recon, self.batch.actions) + KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean() + vae_loss = recon_loss + 0.5 * KL_loss + return vae_loss + + def get_target_q_values( + self, next_states: torch.Tensor, rewards: List[float], dones: List[bool] + ) -> torch.Tensor: + """Get target Q values for the BCQ + + Args: + next_states (:obj:`torch.Tensor`): Next states for which target Q-values + need to be found + rewards (:obj:`list`): Rewards at each timestep for each environment + dones (:obj:`list`): Game over status for each environment + + Returns: + target_q_values (:obj:`torch.Tensor`): Target Q values for the BCQ + """ + # next_states = torch.repeat_interleave(next_states, 10, 0) + next_state_actions = torch.cat( + [next_states, self.vae.decode(next_states)], dim=-1 + ) + next_target_actions = self.ac_target.get_action(next_state_actions, True)[0] + + next_q_target_values = self.ac_target.get_value( + torch.cat([next_states, next_target_actions], dim=-1), mode="min" + ) + target_q_values = rewards + self.gamma * (1 - dones) * next_q_target_values + + return target_q_values + + def update_params(self) -> None: + """Update parameters of the model""" + self.batch = self.sample_from_buffer() + + vae_loss = self.get_vae_loss() + + self.optimizer_vae.zero_grad() + vae_loss.backward() + self.optimizer_vae.step() + + value_loss = 
self.get_q_loss(self.batch) + + self.optimizer_value.zero_grad() + value_loss.backward() + self.optimizer_value.step() + + policy_loss = self.get_p_loss(self.batch.states) + + self.optimizer_policy.zero_grad() + policy_loss.backward() + self.optimizer_policy.step() + + self.logs["policy_loss"].append(policy_loss.item()) + self.logs["value_loss"].append(value_loss.item()) + self.logs["vae_loss"].append(vae_loss.item()) + + self.update_target_model() + + def get_hyperparams(self) -> Dict[str, Any]: + """Get relevant hyperparameters to save + + Returns: + hyperparams (:obj:`dict`): Hyperparameters to be saved + weights (:obj:`torch.Tensor`): Neural network weights + """ + hyperparams = { + "network": self.network, + "gamma": self.gamma, + "batch_size": self.batch_size, + "replay_size": self.replay_size, + "lr_policy": self.lr_policy, + "lr_value": self.lr_value, + "lr_vae": self.lr_vae, + "polyak": self.polyak, + "noise_std": self.noise_std, + } + + return hyperparams, self.ac.state_dict() + + def get_logging_params(self) -> Dict[str, Any]: + """Gets relevant parameters for logging + + Returns: + logs (:obj:`dict`): Logging parameters for monitoring training + """ + logs = { + "policy_loss": safe_mean(self.logs["policy_loss"]), + "value_loss": safe_mean(self.logs["value_loss"]), + "vae_loss": safe_mean(self.logs["vae_loss"]), + } + + self.empty_logs() + return logs + + def empty_logs(self): + """Empties logs""" + self.logs = {} + self.logs["policy_loss"] = [] + self.logs["value_loss"] = [] + self.logs["vae_loss"] = [] diff --git a/genrl/core/__init__.py b/genrl/core/__init__.py index 96824ac1..bf39b926 100644 --- a/genrl/core/__init__.py +++ b/genrl/core/__init__.py @@ -1,10 +1,7 @@ from genrl.core.actor_critic import MlpActorCritic, get_actor_critic_from_name # noqa -from genrl.core.bandit import Bandit, BanditAgent +from genrl.core.bandit import Bandit, BanditAgent # noqa from genrl.core.base import BaseActorCritic # noqa -from genrl.core.buffers import PrioritizedBuffer # noqa -from genrl.core.buffers import PrioritizedReplayBufferSamples # noqa -from genrl.core.buffers import ReplayBuffer # noqa -from genrl.core.buffers import ReplayBufferSamples # noqa +from genrl.core.buffers import PrioritizedBuffer, ReplayBuffer # noqa from genrl.core.noise import ActionNoise # noqa from genrl.core.noise import NoisyLinear # noqa from genrl.core.noise import NormalActionNoise # noqa @@ -15,7 +12,7 @@ MlpPolicy, get_policy_from_name, ) -from genrl.core.rollout_storage import RolloutBuffer # noqa +from genrl.core.rollouts import RolloutBuffer # noqa from genrl.core.values import ( # noqa BaseValue, CnnCategoricalValue, diff --git a/genrl/core/actor_critic.py b/genrl/core/actor_critic.py index 1a96e959..bb05969f 100644 --- a/genrl/core/actor_critic.py +++ b/genrl/core/actor_critic.py @@ -184,7 +184,7 @@ def __init__( action_dim: spaces.Space, policy_layers: Tuple = (32, 32), value_layers: Tuple = (32, 32), - val_type: str = "V", + val_type: str = "Qsa", discrete: bool = True, num_critics: int = 2, **kwargs, @@ -194,8 +194,8 @@ def __init__( self.num_critics = num_critics self.actor = MlpPolicy(state_dim, action_dim, policy_layers, discrete, **kwargs) - self.critic1 = MlpValue(state_dim, action_dim, "Qsa", value_layers, **kwargs) - self.critic2 = MlpValue(state_dim, action_dim, "Qsa", value_layers, **kwargs) + self.critic1 = MlpValue(state_dim, action_dim, val_type, value_layers, **kwargs) + self.critic2 = MlpValue(state_dim, action_dim, val_type, value_layers, **kwargs) self.action_scale = 
kwargs["action_scale"] if "action_scale" in kwargs else 1 self.action_bias = kwargs["action_bias"] if "action_bias" in kwargs else 0 diff --git a/genrl/core/buffers.py b/genrl/core/buffers.py index 0a5b6e7c..3aa73704 100644 --- a/genrl/core/buffers.py +++ b/genrl/core/buffers.py @@ -1,110 +1,210 @@ +import os import random from collections import deque from typing import NamedTuple, Tuple -import numpy as np import torch -class ReplayBufferSamples(NamedTuple): - states: torch.Tensor - actions: torch.Tensor - rewards: torch.Tensor - next_states: torch.Tensor - dones: torch.Tensor +class BaseBuffer(object): + """Base class that represents a buffer (rollout or replay) + + Attributes: + buffer_size (int): Max number of elements in the buffer + """ + def __init__(self, buffer_size: int): + super(BaseBuffer, self).__init__() + self.buffer_size = buffer_size + self.pos = 0 + self.full = False -class PrioritizedReplayBufferSamples(NamedTuple): + @staticmethod + def swap_and_flatten(arr: torch.Tensor) -> torch.Tensor: + """Swap and Flatten method + + Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) + to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) + to [n_steps * n_envs, ...] (which maintain the order) + + Args: + arr (:obj:`torch.Tensor`): Array to modify + + Returns: + new_arr (:obj:`torch.Tensor`): Modified Array + """ + shape = arr.shape + if len(shape) < 3: + arr = arr.unsqueeze(-1) + shape = shape + (1,) + + return arr.permute(1, 0, *(torch.arange(2, len(shape)))).reshape( + shape[0] * shape[1], *shape[2:] + ) + + def size(self) -> int: + """Returns size of the buffer + + Returns: + size (int): The current size of the buffer + """ + raise NotImplementedError + + def add(self, *args, **kwargs) -> None: + """Adds elements to the buffer""" + raise NotImplementedError + + def reset(self) -> None: + """Resets the buffer""" + self.pos = 0 + self.full = False + + def sample(self, batch_size: int) -> Tuple: + """Sample from the buffer + + Args: + batch_size (int): Number of element to sample + + Returns: + samples (:obj:`namedtuple`): Named tuple of the sampled experiences + """ + raise NotImplementedError + + def save(self, directory: str = None, run_num: int = None) -> None: + """Saves the buffer locally + + The buffers are saved locally so they can be used as a dataset for + Offline RL or for other purposes + + Args: + directory (string): Directory to save buffers in + run_num (int): The run number associated with the training run + """ + raise NotImplementedError + + def load(self, path: str) -> None: + """Loads the buffer from the file + + Args: + path (str): Path of the pickled file of the buffer replays/rollouts + """ + if not os.path.exists(path): + raise FileNotFoundError + + raise NotImplementedError + + +class ReplayBufferSamples(NamedTuple): states: torch.Tensor actions: torch.Tensor rewards: torch.Tensor next_states: torch.Tensor dones: torch.Tensor - indices: torch.Tensor - weights: torch.Tensor -class ReplayBuffer: - """ - Implements the basic Experience Replay Mechanism +class ReplayBuffer(BaseBuffer): + """Vanilla Experience Replay Buffer - :param capacity: Size of the replay buffer - :type capacity: int + Attributes: + buffer_size (int): Max number of element in the buffer + device (:obj:`torch.device` or str): PyTorch device to which the values will be converted """ - def __init__(self, capacity: int): - self.capacity = capacity - self.memory = deque([], maxlen=capacity) + def __init__(self, *args, **kwargs): + super(ReplayBuffer, 
self).__init__(*args, **kwargs) + self.buffer = deque([], maxlen=self.buffer_size) - def push(self, inp: Tuple) -> None: + def size(self) -> int: + """Returns size of the buffer + + Returns: + size (int): The current size of the buffer """ - Adds new experience to buffer + return len(self.buffer) - :param inp: Tuple containing state, action, reward, next_state and done - :type inp: tuple - :returns: None + def add(self, experience: Tuple) -> None: + """Adds elements to the buffer + + Args: + experience (:obj:`tuple`): Tuple containing state, action, reward, next_state and done """ - self.memory.append(inp) + self.buffer.append(experience) def sample( self, batch_size: int - ) -> (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): - """ - Returns randomly sampled experiences from replay memory + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Sample from the buffer - :param batch_size: Number of samples per batch - :type batch_size: int - :returns: (Tuple composing of `state`, `action`, `reward`, - `next_state` and `done`) + Args: + batch_size (int): Number of element to sample + + Returns: + samples (:obj:`namedtuple`): Named tuple of the sampled experiences """ - batch = random.sample(self.memory, batch_size) - state, action, reward, next_state, done = map(np.stack, zip(*batch)) - return [ - torch.from_numpy(v).float() - for v in [state, action, reward, next_state, done] - ] + batch = random.sample(self.buffer, batch_size) + states, actions, rewards, next_states, dones = map(torch.stack, zip(*batch)) + return ReplayBufferSamples(states, actions, rewards, next_states, dones) - def __len__(self) -> int: + def save(self, path: str = None) -> None: + """Saves the buffer locally + + The buffers are saved locally so they can be used as a dataset for + Offline RL or for other purposes + + Args: + path (string): Path to save buffers in """ - Gives number of experiences in buffer currently + torch.save(self.buffer, path) - :returns: Length of replay memory + def load(self, path: str) -> None: + """Loads the buffer from the file + + Args: + path (str): Path of the pickled file of the buffer replays/rollouts """ - return self.pos + self.buffer = torch.load(path) -class PrioritizedBuffer: - """ - Implements the Prioritized Experience Replay Mechanism +class PrioritizedReplayBufferSamples(NamedTuple): + states: torch.Tensor + actions: torch.Tensor + rewards: torch.Tensor + next_states: torch.Tensor + dones: torch.Tensor + indices: torch.Tensor + weights: torch.Tensor + - :param capacity: Size of the replay buffer - :param alpha: Level of prioritization - :type capacity: int - :type alpha: int +class PrioritizedBuffer(BaseBuffer): + """Prioritized Experience Replay Mechanism + + Attributes: + buffer_size (int): Max number of element in the buffer + alpha (float): Level of prioritization + beta (float): Bias factor used to correct IS Weights + device (:obj:`torch.device` or str): PyTorch device to which the values will be converted """ - def __init__(self, capacity: int, alpha: float = 0.6, beta: float = 0.4): + def __init__(self, *args, alpha: float = 0.6, beta: float = 0.4, **kwargs): + super(PrioritizedBuffer, self).__init__(*args, **kwargs) + self.alpha = alpha self.beta = beta - self.capacity = capacity - self.buffer = deque([], maxlen=capacity) - self.priorities = deque([], maxlen=capacity) + self.buffer = deque([], maxlen=self.buffer_size) + self.priorities = deque([], maxlen=self.buffer_size) - def push(self, inp: Tuple) -> None: 
- """ - Adds new experience to buffer + def add(self, experience: Tuple) -> None: + """Adds elements to the buffer - :param inp: (Tuple containing `state`, `action`, `reward`, - `next_state` and `done`) - :type inp: tuple - :returns: None + Args: + experience (:obj:`tuple`): Tuple containing state, action, reward, next_state and done """ + self.buffer.append(experience) max_priority = max(self.priorities) if self.priorities else 1.0 - self.buffer.append(inp) self.priorities.append(max_priority) def sample( - self, batch_size: int, beta: float = None + self, batch_size: int ) -> ( Tuple[ torch.Tensor, @@ -116,46 +216,35 @@ def sample( torch.Tensor, ] ): - """ - (Returns randomly sampled memories from replay memory along with their - respective indices and weights) - - :param batch_size: Number of samples per batch - :param beta: (Bias exponent used to correct - Importance Sampling (IS) weights) - :type batch_size: int - :type beta: float - :returns: (Tuple containing `states`, `actions`, `next_states`, - `rewards`, `dones`, `indices` and `weights`) - """ - if beta is None: - beta = self.beta + """Sample from the buffer + Args: + batch_size (int): Number of element to sample + + Returns: + samples (:obj:`namedtuple`): Named tuple of the sampled experiences + """ total = len(self.buffer) - priorities = np.asarray(self.priorities) + priorities = torch.FloatTensor(self.priorities) probabilities = priorities ** self.alpha probabilities /= probabilities.sum() - indices = np.random.choice(total, batch_size, p=probabilities) + indices = torch.multinomial(probabilities, batch_size) - weights = (total * probabilities[indices]) ** (-beta) + weights = (total * probabilities[indices]) ** (-self.beta) weights /= weights.max() - weights = np.asarray(weights, dtype=np.float32) samples = [self.buffer[i] for i in indices] - (states, actions, rewards, next_states, dones) = map(np.stack, zip(*samples)) - - return [ - torch.as_tensor(v, dtype=torch.float32) - for v in [ - states, - actions, - rewards, - next_states, - dones, - indices, - weights, - ] - ] + (states, actions, rewards, next_states, dones) = map(torch.stack, zip(*samples)) + + return PrioritizedReplayBufferSamples( + states, + actions, + rewards, + next_states, + dones, + indices, + weights, + ) def update_priorities(self, batch_indices: Tuple, batch_priorities: Tuple) -> None: """ @@ -170,14 +259,10 @@ def update_priorities(self, batch_indices: Tuple, batch_priorities: Tuple) -> No for idx, priority in zip(batch_indices, batch_priorities): self.priorities[int(idx)] = priority.mean() - def __len__(self) -> int: - """ - Gives number of experiences in buffer currently + def size(self) -> int: + """Returns size of the buffer - :returns: Length of replay memory + Returns: + size (int): The current size of the buffer """ return len(self.buffer) - - @property - def pos(self): - return len(self.buffer) diff --git a/genrl/core/models.py b/genrl/core/models.py new file mode 100644 index 00000000..faa15abf --- /dev/null +++ b/genrl/core/models.py @@ -0,0 +1,99 @@ +from typing import Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from genrl.utils.utils import mlp + + +class VAE(nn.Module): + """VAE model to be used + + Currently only used in BCQ + + Attributes: + state_dim (int): State dimensions of the environment + action_dim (int): Action space dimensions of the environment + action_lim (float): Maximum action that can be taken. 
Used to scale the decoder output to action space + hidden_layers (:obj:`list` or :obj:`tuple`): Hidden layers in the Encoder and Decoder + (will be reversed to use in the decoder) + latent_dim (int): Dimensions of the latent space for the VAE + activation (str): Activation function to be used. Can be either "tanh" or "relu" + """ + + def __init__( + self, + state_dim: int, + action_dim: int, + action_lim: float, + hidden_layers: Tuple = (32, 32), + latent_dim: int = None, + activation: str = "relu", + ): + super(VAE, self).__init__() + + self.latent_dim = latent_dim if latent_dim is not None else action_dim * 2 + self.max_action = action_lim + + self.encoder = mlp( + [state_dim + action_dim, hidden_layers], activation=activation + ) + + self.mean = nn.Linear(hidden_layers[-1], self.latent_dim) + self.log_std = nn.Linear(hidden_layers[-1], self.latent_dim) + + self.decoder = mlp( + [state_dim + self.latent_dim, hidden_layers[::-1], action_dim], + activation=activation, + ) + + def forward( + self, state: torch.Tensor, action: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass through VAE + + Passes a concatenated vector of the states and actions to an encoder. The latent space is then + modelled by the variable z. The latent space vector is then passed through the decoder to give + the VAE-suggested actions. + + Args: + state (:obj:`torch.Tensor`): State being observed by agent + action (:obj:`torch.Tensor`): Action being observed by the agent for the respective state + + Returns: + u (:obj:`torch.Tensor`): VAE-suggested action + mean (:obj:`torch.Tensor`): Mean of VAE latent space + std (:obj:`torch.Tensor`): Standard deviation of VAE latent space + """ + e = F.relu(self.encoder(torch.cat([state, action], dim=-1))) + + mean = self.mean(e) + log_std = self.log_std(e).clamp(-4, 15) + std = torch.exp(log_std) + z = mean + std * torch.randn_like(std) + + u = self.decode(state, z) + return u, mean, std + + def decode(self, state: torch.Tensor, z: torch.Tensor = None) -> torch.Tensor: + """Decoder output + + Decodes a given state to give an action + + Args: + state (:obj:`torch.Tensor`): State being observed by agent + z (:obj:`torch.Tensor`): Latent space vector + + Returns: + u (:obj:`torch.Tensor`): VAE-suggested action + """ + if z is None: + z = ( + torch.randn((*state.shape[:-1], self.latent_dim)) + .to(self.device) + .clamp(-0.5, 0.5) + ) + + d = F.tanh(self.decoder(torch.cat([state, z], dim=-1))) + return self.max_action * d diff --git a/genrl/core/rollout_storage.py b/genrl/core/rollout_storage.py deleted file mode 100644 index 16d1c721..00000000 --- a/genrl/core/rollout_storage.py +++ /dev/null @@ -1,259 +0,0 @@ -from typing import Generator, NamedTuple, Optional, Union - -import gym -import numpy as np -import torch - -from genrl.environments.vec_env import VecEnv - - -class RolloutBufferSamples(NamedTuple): - observations: torch.Tensor - actions: torch.Tensor - old_values: torch.Tensor - old_log_prob: torch.Tensor - advantages: torch.Tensor - returns: torch.Tensor - - -class ReplayBufferSamples(NamedTuple): - observations: torch.Tensor - actions: torch.Tensor - next_observations: torch.Tensor - dones: torch.Tensor - rewards: torch.Tensor - - -class RolloutReturn(NamedTuple): - episode_reward: float - episode_timesteps: int - n_episodes: int - continue_training: bool - - -class BaseBuffer(object): - """ - Base class that represent a buffer (rollout or replay) - :param buffer_size: (int) Max number of element in the buffer - :param env: (Environment) 
The environment being trained on - :param device: (Union[torch.device, str]) PyTorch device - to which the values will be converted - :param n_envs: (int) Number of parallel environments - """ - - def __init__( - self, - buffer_size: int, - env: Union[gym.Env, VecEnv], - device: Union[torch.device, str] = "cpu", - ): - super(BaseBuffer, self).__init__() - self.buffer_size = buffer_size - self.env = env - self.pos = 0 - self.full = False - self.device = device - - @staticmethod - def swap_and_flatten(arr: np.ndarray) -> np.ndarray: - """ - Swap and then flatten axes 0 (buffer_size) and 1 (n_envs) - to convert shape from [n_steps, n_envs, ...] (when ... is the shape of the features) - to [n_steps * n_envs, ...] (which maintain the order) - :param arr: (np.ndarray) - :return: (np.ndarray) - """ - shape = arr.shape - if len(shape) < 3: - arr = arr.unsqueeze(-1) - shape = shape + (1,) - - return arr.permute(1, 0, *(np.arange(2, len(shape)))).reshape( - shape[0] * shape[1], *shape[2:] - ) - - def size(self) -> int: - """ - :return: (int) The current size of the buffer - """ - if self.full: - return self.buffer_size - return self.pos - - def add(self, *args, **kwargs) -> None: - """ - Add elements to the buffer. - """ - raise NotImplementedError() - - def extend(self, *args, **kwargs) -> None: - """ - Add a new batch of transitions to the buffer - """ - # Do a for loop along the batch axis - for data in zip(*args): - self.add(*data) - - def reset(self) -> None: - """ - Reset the buffer. - """ - self.pos = 0 - self.full = False - - def sample( - self, - batch_size: int, - ): - """ - :param batch_size: (int) Number of element to sample - :return: (Union[RolloutBufferSamples, ReplayBufferSamples]) - """ - upper_bound = self.buffer_size if self.full else self.pos - batch_inds = np.random.randint(0, upper_bound, size=batch_size) - return self._get_samples(batch_inds) - - def _get_samples( - self, - batch_inds: np.ndarray, - ): - """ - :param batch_inds: (torch.Tensor) - :return: (Union[RolloutBufferSamples, ReplayBufferSamples]) - """ - raise NotImplementedError() - - def to_torch(self, array: np.ndarray, copy: bool = True) -> torch.Tensor: - """ - Convert a numpy array to a PyTorch tensor. - Note: it copies the data by default - :param array: (np.ndarray) - :param copy: (bool) Whether to copy or not the data - (may be useful to avoid changing things be reference) - :return: (torch.Tensor) - """ - if copy: - return array.detach().clone() - return array - - -class RolloutBuffer(BaseBuffer): - """ - Rollout buffer used in on-policy algorithms like A2C/PPO. - :param buffer_size: (int) Max number of element in the buffer - :param env: (Environment) The environment being trained on - :param device: (torch.device) - :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator - Equivalent to classic advantage when set to 1. 
- :param gamma: (float) Discount factor - :param n_envs: (int) Number of parallel environments - """ - - def __init__( - self, - buffer_size: int, - env: Union[gym.Env, VecEnv], - device: Union[torch.device, str] = "cpu", - gae_lambda: float = 1, - gamma: float = 0.99, - ): - - super(RolloutBuffer, self).__init__(buffer_size, env, device) - self.gae_lambda = gae_lambda - self.gamma = gamma - self.observations, self.actions, self.rewards, self.advantages = ( - None, - None, - None, - None, - ) - self.returns, self.dones, self.values, self.log_probs = None, None, None, None - self.generator_ready = False - self.reset() - - def reset(self) -> None: - self.observations = torch.zeros( - *(self.buffer_size, self.env.n_envs, *self.env.obs_shape) - ) - self.actions = torch.zeros( - *(self.buffer_size, self.env.n_envs, *self.env.action_shape) - ) - self.rewards = torch.zeros(self.buffer_size, self.env.n_envs) - self.returns = torch.zeros(self.buffer_size, self.env.n_envs) - self.dones = torch.zeros(self.buffer_size, self.env.n_envs) - self.values = torch.zeros(self.buffer_size, self.env.n_envs) - self.log_probs = torch.zeros(self.buffer_size, self.env.n_envs) - self.advantages = torch.zeros(self.buffer_size, self.env.n_envs) - self.generator_ready = False - super(RolloutBuffer, self).reset() - - def add( - self, - obs: torch.zeros, - action: torch.zeros, - reward: torch.zeros, - done: torch.zeros, - value: torch.Tensor, - log_prob: torch.Tensor, - ) -> None: - """ - :param obs: (torch.zeros) Observation - :param action: (torch.zeros) Action - :param reward: (torch.zeros) - :param done: (torch.zeros) End of episode signal. - :param value: (torch.Tensor) estimated value of the current state - following the current policy. - :param log_prob: (torch.Tensor) log probability of the action - following the current policy. 
- """ - if len(log_prob.shape) == 0: - # Reshape 0-d tensor to avoid error - log_prob = log_prob.reshape(-1, 1) - - self.observations[self.pos] = obs.detach().clone() - self.actions[self.pos] = action.detach().clone() - self.rewards[self.pos] = reward.detach().clone() - self.dones[self.pos] = done.detach().clone() - self.values[self.pos] = value.detach().clone().flatten() - self.log_probs[self.pos] = log_prob.detach().clone().flatten() - self.pos += 1 - if self.pos == self.buffer_size: - self.full = True - - def get( - self, batch_size: Optional[int] = None - ) -> Generator[RolloutBufferSamples, None, None]: - assert self.full, "" - indices = np.random.permutation(self.buffer_size * self.env.n_envs) - # Prepare the data - if not self.generator_ready: - for tensor in [ - "observations", - "actions", - "values", - "log_probs", - "advantages", - "returns", - ]: - self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) - self.generator_ready = True - - # Return everything, don't create minibatches - if batch_size is None: - batch_size = self.buffer_size * self.env.n_envs - - start_idx = 0 - while start_idx < self.buffer_size * self.env.n_envs: - yield self._get_samples(indices[start_idx : start_idx + batch_size]) - start_idx += batch_size - - def _get_samples(self, batch_inds: np.ndarray) -> RolloutBufferSamples: - data = ( - self.observations[batch_inds], - self.actions[batch_inds], - self.values[batch_inds].flatten(), - self.log_probs[batch_inds].flatten(), - self.advantages[batch_inds].flatten(), - self.returns[batch_inds].flatten(), - ) - return RolloutBufferSamples(*tuple(map(self.to_torch, data))) diff --git a/genrl/core/rollouts.py b/genrl/core/rollouts.py new file mode 100644 index 00000000..54838965 --- /dev/null +++ b/genrl/core/rollouts.py @@ -0,0 +1,142 @@ +from typing import Generator, NamedTuple, Optional, Union + +import gym +import torch + +from genrl.core.buffers import BaseBuffer +from genrl.environments.vec_env import VecEnv + + +class RolloutBufferSamples(NamedTuple): + observations: torch.Tensor + actions: torch.Tensor + old_values: torch.Tensor + old_log_prob: torch.Tensor + advantages: torch.Tensor + returns: torch.Tensor + + +class RolloutBuffer(BaseBuffer): + """Rollout buffer used in on-policy algorithms like A2C/PPO + + Attributes: + buffer_size (int): Max number of element in the buffer + env (Environment): The environment being trained on + device (:obj:`torch.device` or str): PyTorch device to which the values will be converted + gae_lambda (float): Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. 
+ gamma (float): Discount factor + """ + + def __init__( + self, + buffer_size: int, + env: Union[gym.Env, VecEnv], + device: Union[torch.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = None, + ): + super(RolloutBuffer, self).__init__(buffer_size) + self.device = device if type(device) is torch.device else torch.device(device) + self.env = env + self.gae_lambda = gae_lambda + self.gamma = gamma + self.observations, self.actions, self.rewards, self.advantages = ( + None, + None, + None, + None, + ) + self.returns, self.dones, self.values, self.log_probs = None, None, None, None + self.generator_ready = False + self.reset() + + def reset(self) -> None: + """Resets the buffer""" + self.observations = torch.zeros( + *(self.buffer_size, self.env.n_envs, *self.env.obs_shape) + ) + self.actions = torch.zeros( + *(self.buffer_size, self.env.n_envs, *self.env.action_shape) + ) + self.rewards = torch.zeros(self.buffer_size, self.env.n_envs) + self.returns = torch.zeros(self.buffer_size, self.env.n_envs) + self.dones = torch.zeros(self.buffer_size, self.env.n_envs) + self.values = torch.zeros(self.buffer_size, self.env.n_envs) + self.log_probs = torch.zeros(self.buffer_size, self.env.n_envs) + self.advantages = torch.zeros(self.buffer_size, self.env.n_envs) + self.generator_ready = False + super(RolloutBuffer, self).reset() + + def add( + self, + obs: torch.Tensor, + action: torch.Tensor, + reward: torch.Tensor, + done: torch.Tensor, + value: torch.Tensor, + log_prob: torch.Tensor, + ) -> None: + """Adds elements to the buffer + + Args: + obs (torch.Tensor): Observation + action (torch.Tensor): Action + reward (torch.Tensor): Reward + done (torch.Tensor): End of episode signal. + value (torch.Tensor): Estimated value of the current state + following the current policy. + log_prob (torch.Tensor): Log probability of the action + following the current policy. 
+ """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + self.observations[self.pos] = obs.detach().clone() + self.actions[self.pos] = action.detach().clone() + self.rewards[self.pos] = reward.detach().clone() + self.dones[self.pos] = done.detach().clone() + self.values[self.pos] = value.detach().clone().flatten() + self.log_probs[self.pos] = log_prob.detach().clone().flatten() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get( + self, batch_size: Optional[int] = None + ) -> Generator[RolloutBufferSamples, None, None]: + assert self.full, "" + indices = torch.randperm(self.buffer_size * self.env.n_envs) + # Prepare the data + if not self.generator_ready: + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.env.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.env.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: torch.Tensor) -> RolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten(), + self.log_probs[batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + ) + return RolloutBufferSamples(*data) diff --git a/genrl/trainers/__init__.py b/genrl/trainers/__init__.py index 7410831b..9c4557dc 100644 --- a/genrl/trainers/__init__.py +++ b/genrl/trainers/__init__.py @@ -1,5 +1,6 @@ from genrl.trainers.bandit import BanditTrainer, DCBTrainer, MABTrainer # noqa from genrl.trainers.base import Trainer # noqa from genrl.trainers.classical import ClassicalTrainer # noqa +from genrl.trainers.offline import OfflineTrainer # noqa from genrl.trainers.offpolicy import OffPolicyTrainer # noqa from genrl.trainers.onpolicy import OnPolicyTrainer # noqa diff --git a/genrl/trainers/base.py b/genrl/trainers/base.py index 0ce666ef..c7e952dc 100644 --- a/genrl/trainers/base.py +++ b/genrl/trainers/base.py @@ -31,6 +31,7 @@ class Trainer(ABC): run_num (int): A run number allotted to the save of parameters load_weights (str): Weights file load_hyperparams (str): File to load hyperparameters + load_buffer (str): File to load buffer from render (bool): True if environment is to be rendered during training, else False evaluate_episodes (int): Number of episodes to evaluate for seed (int): Set seed for reproducibility @@ -52,6 +53,7 @@ def __init__( run_num: int = None, load_weights: str = None, load_hyperparams: str = None, + load_buffer: str = None, render: bool = False, evaluate_episodes: int = 25, seed: Optional[int] = None, @@ -70,6 +72,7 @@ def __init__( self.run_num = run_num self.load_weights = load_weights self.load_hyperparams = load_hyperparams + self.load_buffer = load_buffer self.render = render self.evaluate_episodes = evaluate_episodes @@ -111,7 +114,7 @@ def evaluate(self, render: bool = False) -> None: for i, di in enumerate(done): if di: episode += 1 - episode_rewards.append(episode_reward[i].clone().detach()) + episode_rewards.append(episode_reward[i].detach().clone()) episode_reward[i] = 0 self.env.reset_single_env(i) if episode == self.evaluate_episodes: @@ -124,7 +127,7 @@ def evaluate(self, render: bool = False) 
-> None: ) return - def save(self, timestep: int) -> None: + def save(self, timestep: int, save_buffer: bool = False) -> None: """Function to save all relevant parameters of a given agent Args: @@ -153,12 +156,19 @@ def save(self, timestep: int) -> None: filename_hyperparams = "{}/{}-log-{}.toml".format(path, run_num, timestep) filename_weights = "{}/{}-log-{}.pt".format(path, run_num, timestep) + filename_buffer = "{}/{}-buffer-{}.pt".format(path, run_num, timestep) hyperparameters, weights = self.agent.get_hyperparams() with open(filename_hyperparams, mode="w") as f: toml.dump(hyperparameters, f) torch.save(weights, filename_weights) + if save_buffer: + if self.off_policy: + self.agent.replay_buffer.save(filename_buffer) + else: + self.agent.rollout.save(filename_buffer) + def load(self): """Function to load saved parameters of a given agent""" try: @@ -178,7 +188,18 @@ def load(self): except FileNotFoundError: raise Exception("Invalid weights File Name") - print("Loaded Pretrained Model weights and hyperparameters!") + if self.load_buffer is not None: + try: + if self.off_policy: + self.agent.replay_buffer.load(self.load_buffer) + else: + self.agent.rollout.load(self.load_buffer) + except FileNotFoundError: + raise Exception("Invalid buffer File Name") + else: + print("Not loading buffer as no File Name has been passed...") + + print("Loaded Pretrained Model weights, buffer and hyperparameters!") @property def n_envs(self) -> int: diff --git a/genrl/trainers/offline.py b/genrl/trainers/offline.py new file mode 100644 index 00000000..00ddf246 --- /dev/null +++ b/genrl/trainers/offline.py @@ -0,0 +1,118 @@ +from typing import List + +import numpy as np +import torch + +from genrl.trainers.offpolicy import OffPolicyTrainer + + +class OfflineTrainer(OffPolicyTrainer): + """Offline RL Trainer Class + + Trainer class for all the Offline RL Agents: BCQ (more to be added) + + Attributes: + agent (object): Agent algorithm object + env (object): Environment + buffer (object): Replay Buffer object + buffer_path (str): Path to the saved buffer file + max_ep_len (int): Maximum Episode length for training + max_timesteps (int): Maximum limit of timesteps to train for + warmup_steps (int): Number of warmup steps. (random actions are taken to add randomness to training) + start_update (int): Timesteps after which the agent networks should start updating + update_interval (int): Timesteps between target network updates + log_mode (:obj:`list` of str): List of different kinds of logging. Supported: ["csv", "stdout", "tensorboard"] + log_key (str): Key plotted on x_axis. Supported: ["timestep", "episode"] + log_interval (int): Timesteps between successive logging of parameters onto the console + logdir (str): Directory where log files should be saved. 
+ epochs (int): Total number of epochs to train for + off_policy (bool): True if the agent is an off policy agent, False if it is on policy + save_interval (int): Timesteps between successive saves of the agent's important hyperparameters + save_model (str): Directory where the checkpoints of agent parameters should be saved + run_num (int): A run number allotted to the save of parameters + load_weights (str): Weights file + load_hyperparams (str): File to load hyperparameters + load_buffer (str): File to load buffer from + render (bool): True if environment is to be rendered during training, else False + evaluate_episodes (int): Number of episodes to evaluate for + seed (int): Set seed for reproducibility + """ + + def __init__(self, *args, buffer_path: str = None, **kwargs): + super(OfflineTrainer, self).__init__( + *args, start_update=0, warmup_steps=0, update_interval=1, **kwargs + ) + self.buffer_path = buffer_path + + if self.buffer_path is None: + self.generate_buffer("random") + + def generate_buffer(self, generate_type: str = "random") -> None: + """Make a replay buffer from a specific kind of agent + + Args: + generate_type (str): Type of generation for the buffer. Can choose from ["random", "agent"]. + Buffer generation is not implemented yet. + """ + raise NotImplementedError + + def check_game_over_status(self, timestep: int) -> bool: + """Takes care of the game over status of envs + + Whenever a trajectory in any environment ends, the episode counter is incremented + + Args: + timestep (int): Timestep for which game over condition needs to be checked + + Returns: + game_over (bool): True if at least one environment was done, else False + """ + game_over = False + + for i, batch_done in enumerate(self.agent.batch.dones): + for j, done in enumerate(batch_done): + if done or timestep == self.max_ep_len: + self.episodes += 1 + game_over = True + + return game_over + + def log(self, timestep: int) -> None: + """Helper function to log + + Sends useful parameters to the logger. 
+ + Args: + timestep (int): Current timestep of training + """ + self.logger.write( + { + "timestep": timestep, + "Episode": self.episodes, + **self.agent.get_logging_params(), + }, + self.log_key, + ) + + def train(self) -> None: + """Main training method""" + self.buffer.load(self.buffer_path) + self.noise_reset() + + self.training_rewards = [] + self.episodes = 0 + + for timestep in range(0, self.max_timesteps): + self.agent.update_params() + + if timestep % self.log_interval == 0: + self.log(timestep) + + if self.episodes >= self.epochs: + break + + if self.save_interval != 0 and timestep % self.save_interval == 0: + self.save(timestep) + + self.env.close() + self.logger.close() diff --git a/genrl/trainers/offpolicy.py b/genrl/trainers/offpolicy.py index 7e0571c2..becb6b42 100644 --- a/genrl/trainers/offpolicy.py +++ b/genrl/trainers/offpolicy.py @@ -1,6 +1,7 @@ from typing import List, Type, Union import numpy as np +import torch from genrl.core import PrioritizedBuffer, ReplayBuffer from genrl.trainers import Trainer @@ -30,7 +31,9 @@ class OffPolicyTrainer(Trainer): save_interval (int): Timesteps between successive saves of the agent's important hyperparameters save_model (str): Directory where the checkpoints of agent parameters should be saved run_num (int): A run number allotted to the save of parameters - load_model (str): File to load saved parameter checkpoint from + load_weights (str): Weights file + load_hyperparams (str): File to load hyperparameters + load_buffer (str): File to load buffer from render (bool): True if environment is to be rendered during training, else False evaluate_episodes (int): Number of episodes to evaluate for seed (int): Set seed for reproducibility @@ -155,8 +158,10 @@ def train(self) -> None: # true_dones contains the "true" value of the dones (game over statuses). It is set # to False when the environment is not actually done but instead reaches the max # episode length. 
- true_dones = [info[i]["done"] for i in range(self.env.n_envs)] - self.buffer.push((state, action, reward, next_state, true_dones)) + true_dones = torch.FloatTensor( + [info[i]["done"] for i in range(self.env.n_envs)] + ) + self.buffer.add((state, action, reward, next_state, true_dones)) state = next_state.detach().clone() diff --git a/genrl/trainers/onpolicy.py b/genrl/trainers/onpolicy.py index 920caf38..227b7ea4 100644 --- a/genrl/trainers/onpolicy.py +++ b/genrl/trainers/onpolicy.py @@ -21,7 +21,9 @@ class OnPolicyTrainer(Trainer): save_interval (int): Timesteps between successive saves of the agent's important hyperparameters save_model (str): Directory where the checkpoints of agent parameters should be saved run_num (int): A run number allotted to the save of parameters - load_model (str): File to load saved parameter checkpoint from + load_weights (str): Weights file + load_hyperparams (str): File to load hyperparameters + load_buffer (str): File to load buffer from render (bool): True if environment is to be rendered during training, else False evaluate_episodes (int): Number of episodes to evaluate for seed (int): Set seed for reproducibility diff --git a/tester.py b/tester.py new file mode 100644 index 00000000..e461a1bb --- /dev/null +++ b/tester.py @@ -0,0 +1,36 @@ +from genrl.agents import BCQ, DDPG +from genrl.environments import VectorEnv +from genrl.trainers import OfflineTrainer, OffPolicyTrainer + +env = VectorEnv("Pendulum-v0") +# agent = DDPG( +# "mlp", +# env, +# replay_size=500, +# ) +# trainer = OffPolicyTrainer( +# agent, +# env, +# epochs=40, +# log_interval=5, +# max_timesteps=50000, +# save_interval=2000, +# ) +# trainer.train() +# trainer.evaluate() +# trainer.save(2000, True) + +agent = BCQ( + "mlp", + env, + replay_size=500, +) +trainer = OfflineTrainer( + agent, + env, + log_interval=50, + buffer_path="checkpoints/DDPG_Pendulum-v0/6-buffer-2000.pt", + max_timesteps=1000, +) +trainer.train() +trainer.evaluate()
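The end-to-end workflow implied by tester.py, with the data-collection stage uncommented, looks roughly like the sketch below: first collect and save a replay buffer with an online agent (DDPG), then train BCQ offline from that buffer. The buffer filename follows the "{run_num}-buffer-{timestep}.pt" pattern used in Trainer.save(), so the exact checkpoint directory and run number are only illustrative here.

from genrl.agents import BCQ, DDPG
from genrl.environments import VectorEnv
from genrl.trainers import OfflineTrainer, OffPolicyTrainer

env = VectorEnv("Pendulum-v0")

# Stage 1: collect experience online with DDPG and save the replay buffer.
collector = DDPG("mlp", env, replay_size=500)
collect_trainer = OffPolicyTrainer(
    collector, env, epochs=40, log_interval=5, max_timesteps=50000, save_interval=2000
)
collect_trainer.train()
collect_trainer.save(2000, save_buffer=True)  # writes <run_num>-buffer-2000.pt

# Stage 2: train BCQ purely from the saved buffer (no further environment interaction).
agent = BCQ("mlp", env, replay_size=500)
offline_trainer = OfflineTrainer(
    agent,
    env,
    log_interval=50,
    buffer_path="checkpoints/DDPG_Pendulum-v0/0-buffer-2000.pt",  # illustrative path; depends on run_num
    max_timesteps=1000,
)
offline_trainer.train()
offline_trainer.evaluate()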