4 changes: 2 additions & 2 deletions genrl/agents/__init__.py
@@ -15,6 +15,7 @@
NeuralNoiseSamplingAgent,
)
from genrl.agents.bandits.contextual.variational import VariationalAgent # noqa
from genrl.agents.bandits.multiarmed.base import MABAgent # noqa
from genrl.agents.bandits.multiarmed.bayesian import BayesianUCBMABAgent # noqa
from genrl.agents.bandits.multiarmed.bernoulli_mab import BernoulliMAB # noqa
from genrl.agents.bandits.multiarmed.epsgreedy import EpsGreedyMABAgent # noqa
@@ -41,5 +42,4 @@
from genrl.agents.deep.sac.sac import SAC # noqa
from genrl.agents.deep.td3.td3 import TD3 # noqa
from genrl.agents.deep.vpg.vpg import VPG # noqa

from genrl.agents.bandits.multiarmed.base import MABAgent # noqa; noqa; noqa
from genrl.agents.offline.bcq.bcq import BCQ # noqa
5 changes: 3 additions & 2 deletions genrl/agents/deep/base/base.py
@@ -17,8 +17,9 @@ class BaseAgent(ABC):
create_model (bool): Whether the model of the algo should be created when initialised
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy
value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics
shared_layers (:obj:`tuple` of :obj:`int`): Sizes of shared layers in the actor-critic, if used
lr_policy (float): Learning rate for the policy/actor
lr_value (float): Learning rate for the Q-value function
seed (int): Seed for randomness
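For context, a rough sketch of how the renamed layer arguments might be passed when constructing an agent. Only the keyword names policy_layers, value_layers, lr_policy, and lr_value come from the docstring above; the agent choice, environment wrapper, and exact constructor signature are assumptions about typical GenRL usage:

    # Illustrative only: keyword names follow the docstring above; VectorEnv and the
    # TD3("mlp", env, ...) signature are assumed and may differ from the actual API.
    from genrl.agents import TD3
    from genrl.environments import VectorEnv

    env = VectorEnv("Pendulum-v0")
    agent = TD3(
        "mlp",
        env,
        policy_layers=(64, 64),  # hidden sizes of the actor network
        value_layers=(64, 64),   # hidden sizes of each critic network
        lr_policy=1e-3,
        lr_value=1e-3,
    )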
51 changes: 14 additions & 37 deletions genrl/agents/deep/base/offpolicy.py
@@ -5,11 +5,10 @@
from torch.nn import functional as F

from genrl.agents.deep.base import BaseAgent
from genrl.core import (

from genrl.core import ( # PrioritizedReplayBufferSamples,; ReplayBufferSamples,
PrioritizedBuffer,
PrioritizedReplayBufferSamples,
ReplayBuffer,
ReplayBufferSamples,
)


@@ -23,8 +22,9 @@ class OffPolicyAgent(BaseAgent):
create_model (bool): Whether the model of the algo should be created when initialised
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy
value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics
shared_layers (:obj:`tuple` of :obj:`int`): Sizes of shared layers in the actor-critic, if used
lr_policy (float): Learning rate for the policy/actor
lr_value (float): Learning rate for the Q-value function
replay_size (int): Capacity of the Replay Buffer
@@ -67,19 +67,6 @@ def update_target_model(self) -> None:
"""
raise NotImplementedError

def _reshape_batch(self, batch: List):
"""Function to reshape experiences

Can be modified for individual algorithm usage

Args:
batch (:obj:`list`): List of experiences that are being replayed

Returns:
batch (:obj:`list`): Reshaped experiences for replay
"""
return [*batch]

def sample_from_buffer(self, beta: float = None):
"""Samples experiences from the buffer and converts them into usable formats

@@ -95,18 +82,6 @@ def sample_from_buffer(self, beta: float = None):
else:
batch = self.replay_buffer.sample(self.batch_size)

states, actions, rewards, next_states, dones = self._reshape_batch(batch)

# Convert every experience to a Named Tuple. Either Replay or Prioritized Replay samples.
if isinstance(self.replay_buffer, ReplayBuffer):
batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones])
elif isinstance(self.replay_buffer, PrioritizedBuffer):
indices, weights = batch[5], batch[6]
batch = PrioritizedReplayBufferSamples(
*[states, actions, rewards, next_states, dones, indices, weights]
)
else:
raise NotImplementedError
return batch

def get_q_loss(self, batch: collections.namedtuple) -> torch.Tensor:
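The deleted reshaping above implies that the buffers' sample() methods now return named-tuple samples directly, so the caller only forwards the batch. A self-contained sketch of that pattern; the field names mirror the tuples used in this module, while the storage and sampling details are illustrative and not GenRL's implementation:

    import random
    from collections import namedtuple

    import torch

    ReplayBufferSamples = namedtuple(
        "ReplayBufferSamples", ["states", "actions", "rewards", "next_states", "dones"]
    )

    class TinyReplayBuffer:
        """Toy buffer whose sample() already packages transitions into a named tuple."""

        def __init__(self, capacity: int):
            self.capacity = capacity
            self.memory = []  # list of (state, action, reward, next_state, done) tensor tuples

        def push(self, transition):
            self.memory.append(transition)
            self.memory = self.memory[-self.capacity:]

        def sample(self, batch_size: int) -> ReplayBufferSamples:
            batch = random.sample(self.memory, batch_size)
            states, actions, rewards, next_states, dones = map(torch.stack, zip(*batch))
            # Because the buffer returns a named tuple, sample_from_buffer() above can
            # simply return whatever sample() hands back.
            return ReplayBufferSamples(states, actions, rewards, next_states, dones)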
@@ -136,8 +111,9 @@ class OffPolicyAgentAC(OffPolicyAgent):
create_model (bool): Whether the model of the algo should be created when initialised
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
policy_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the policy
value_layers (:obj:`tuple` of :obj:`int`): Neural network layer dimensions for the critics
shared_layers (:obj:`tuple` of :obj:`int`): Sizes of shared layers in the actor-critic, if used
lr_policy (float): Learning rate for the policy/actor
lr_value (float): Learning rate for the Q-value function
replay_size (int): Capacity of the Replay Buffer
@@ -154,7 +130,7 @@ def __init__(self, *args, polyak=0.995, **kwargs):
self.doublecritic = False

def select_action(
self, state: torch.Tensor, deterministic: bool = True
self, state: torch.Tensor, deterministic: bool = True, noise: bool = True
) -> torch.Tensor:
"""Select action given state

@@ -163,6 +139,7 @@ def select_action(
Args:
state (:obj:`torch.Tensor`): Current state of the environment
deterministic (bool): Should the policy be deterministic or stochastic
noise (bool): Should exploration noise be added to the action

Returns:
action (:obj:`torch.Tensor`): Action taken by the agent
@@ -171,7 +148,7 @@ def select_action(
action = action.detach()

# add noise to output from policy network
if self.noise is not None:
if noise and self.noise is not None:
action += self.noise()

return torch.clamp(
@@ -210,7 +187,7 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Ten
def get_target_q_values(
self, next_states: torch.Tensor, rewards: List[float], dones: List[bool]
) -> torch.Tensor:
"""Get target Q values for the TD3
"""Get target Q values

Args:
next_states (:obj:`torch.Tensor`): Next states for which target Q-values
@@ -219,7 +196,7 @@ def get_target_q_values(
dones (:obj:`list`): Game over status for each environment

Returns:
target_q_values (:obj:`torch.Tensor`): Target Q values for the TD3
target_q_values (:obj:`torch.Tensor`): Target Q values
"""
next_target_actions = self.ac_target.get_action(next_states, True)[0]

@@ -265,7 +242,7 @@ def get_p_loss(self, states: torch.Tensor) -> torch.Tensor:
Returns:
loss (:obj:`torch.Tensor`): Calculated policy loss
"""
next_best_actions = self.ac.get_action(states, True)[0]
next_best_actions = self.select_action(states, deterministic=True, noise=False)
q_values = self.ac.get_value(torch.cat([states, next_best_actions], dim=-1))
policy_loss = -torch.mean(q_values)
return policy_loss
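To make the effect of the new noise flag concrete, here is a condensed, standalone rendering of the action-selection path shown in this hunk; the clamping bounds and the noise callable are stand-ins for the agent's actual attributes:

    import torch

    def select_action(agent, state: torch.Tensor, deterministic: bool = True, noise: bool = True) -> torch.Tensor:
        # Condensed version of OffPolicyAgentAC.select_action for illustration.
        action = agent.ac.get_action(state, deterministic)[0]
        action = action.detach()

        # Exploration noise is applied only when the caller asks for it AND the
        # agent was configured with a noise process (agent.noise is not None).
        if noise and agent.noise is not None:
            action += agent.noise()

        # Keep the (possibly noisy) action inside the environment's action bounds;
        # the exact bounds used by GenRL are an assumption here.
        low, high = agent.env.action_space.low[0], agent.env.action_space.high[0]
        return torch.clamp(action, low, high)

    # Rollouts keep exploration noise:       select_action(agent, state)
    # get_p_loss() asks for a clean action:  select_action(agent, state, deterministic=True, noise=False)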
2 changes: 1 addition & 1 deletion genrl/agents/deep/base/onpolicy.py
@@ -36,7 +36,7 @@ def __init__(

if buffer_type == "rollout":
self.rollout = RolloutBuffer(
self.rollout_size, self.env, gae_lambda=gae_lambda
self.rollout_size, self.env, gae_lambda=gae_lambda, gamma=self.gamma
)
else:
raise NotImplementedError
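Passing gamma through matters because the rollout buffer presumably needs both gamma and gae_lambda to compute returns and advantages. A standalone sketch of the generalized advantage estimation recursion (illustrative, not GenRL's RolloutBuffer code):

    import torch

    def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
        """Generalized Advantage Estimation over a single rollout (illustrative only)."""
        advantages = torch.zeros_like(rewards)
        gae = 0.0
        next_value = last_value
        for t in reversed(range(len(rewards))):
            not_done = 1.0 - dones[t]
            # The one-step TD error uses gamma; the exponential averaging of TD
            # errors uses gamma * gae_lambda, which is why the buffer needs both.
            delta = rewards[t] + gamma * next_value * not_done - values[t]
            gae = delta + gamma * gae_lambda * not_done * gae
            advantages[t] = gae
            next_value = values[t]
        returns = advantages + values
        return advantages, returns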
4 changes: 2 additions & 2 deletions genrl/agents/deep/dqn/base.py
@@ -56,7 +56,7 @@ def __init__(
if self.create_model:
self._create_model()

def _create_model(self, *args, **kwargs) -> None:
def _create_model(self, **kwargs) -> None:
"""Function to initialize Q-value model

This will create the Q-value function of the agent.
@@ -153,7 +153,7 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Ten
q_values (:obj:`torch.Tensor`): Q values for the given states and actions
"""
q_values = self.model(states)
q_values = q_values.gather(2, actions)
q_values = q_values.gather(2, actions.unsqueeze(-1))
return q_values

def get_target_q_values(
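The extra unsqueeze is required because torch.gather expects the index tensor to have the same number of dimensions as the source. A small, self-contained shape check that mirrors the fixed call (the sizes are arbitrary):

    import torch

    batch_size, n_envs, action_dim = 4, 2, 6
    q_values = torch.randn(batch_size, n_envs, action_dim)        # [4, 2, 6]
    actions = torch.randint(action_dim, (batch_size, n_envs))     # [4, 2], int64

    # Gathering along the action dimension (dim=2) needs a 3-D index tensor, so the
    # discrete actions get a trailing singleton dimension before the gather.
    chosen_q = q_values.gather(2, actions.unsqueeze(-1))          # [4, 2, 1]
    assert chosen_q.shape == (batch_size, n_envs, 1)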
6 changes: 4 additions & 2 deletions genrl/agents/deep/dqn/utils.py
@@ -102,8 +102,10 @@ def categorical_q_values(agent: DQN, states: torch.Tensor, actions: torch.Tensor
# Size of q_value_dist should be [batch_size, n_envs, action_dim, num_atoms] here
# To gather the q_values of the respective actions, actions must be of the shape:
# [batch_size, n_envs, 1, num_atoms]. Its current shape is [batch_size, n_envs]
actions = actions.unsqueeze(-1).expand(
agent.batch_size, agent.env.n_envs, 1, agent.num_atoms
actions = (
actions.unsqueeze(-1)
.unsqueeze(-1)
.expand(agent.batch_size, agent.env.n_envs, 1, agent.num_atoms)
)
# Now as we gather q_values from the action_dim dimension which is at index 2
q_values = q_value_dist.gather(2, actions)
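The distributional case follows the same gather rule one dimension higher: the per-action index has to be broadcast across the atom dimension before gathering along the action dimension. A standalone shape walk-through (sizes are arbitrary, num_atoms=51 as in C51):

    import torch

    batch_size, n_envs, action_dim, num_atoms = 4, 2, 6, 51
    q_value_dist = torch.randn(batch_size, n_envs, action_dim, num_atoms)  # [4, 2, 6, 51]
    actions = torch.randint(action_dim, (batch_size, n_envs))              # [4, 2]

    # [4, 2] -> [4, 2, 1] -> [4, 2, 1, 1] -> expanded over atoms -> [4, 2, 1, 51]
    index = actions.unsqueeze(-1).unsqueeze(-1).expand(batch_size, n_envs, 1, num_atoms)

    # Gather along the action dimension (index 2), keeping the whole atom
    # distribution of the chosen action for each (batch, env) pair.
    chosen_dist = q_value_dist.gather(2, index)                            # [4, 2, 1, 51]
    assert chosen_dist.shape == (batch_size, n_envs, 1, num_atoms)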