2 changes: 1 addition & 1 deletion .isort.cfg
@@ -1,5 +1,5 @@
[settings]
known_third_party = cv2,gym,matplotlib,numpy,pandas,pytest,scipy,setuptools,torch
known_third_party = cv2,gym,matplotlib,numpy,pandas,pytest,scipy,setuptools,toml,torch
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -6,7 +6,7 @@ repos:
args: [--exclude=^((examples|docs)/.*)$]

- repo: https://github.com/timothycrosley/isort
rev: 4.3.2
rev: 5.4.2
hooks:
- id: isort

@@ -0,0 +1,72 @@
Using Shared Parameters in Actor Critic Agents in GenRL
=======================================================

Actor-critic agents use two networks: an actor network, which selects the action to take in the current state, and a
critic network, which estimates the value of the current state. There are two common ways to implement this
actor-critic architecture.

The first method uses independent actor and critic networks:

.. code-block:: none

              state
             /     \
    <actor network> <critic network>
           /               \
       action             value

The second method uses a set of shared parameters (a decoder) to extract a feature vector from the state. The actor
and the critic networks then act on this feature vector to select an action and estimate the value:

.. code-block:: none

              state
                |
            <decoder>
             /     \
    <actor network> <critic network>
           /               \
       action             value

GenRL provides support for incorporating this decoder network in all of its actor-critic agents through a
``shared_layers`` parameter. ``shared_layers`` takes the sizes of the MLP layers to be used as the decoder, or
``None`` if no decoder network is to be used.

As an example, in A2C:

.. code-block:: python

    # The imports
    from genrl.agents import A2C
    from genrl.environments import VectorEnv
    from genrl.trainers import OnPolicyTrainer

    # Initializing the environment
    env = VectorEnv("CartPole-v0", 1)

    # Initializing the agent to be used
    algo = A2C(
        "mlp",
        env,
        policy_layers=(128,),
        value_layers=(128,),
        shared_layers=(32, 64),
        rollout_size=128,
    )

    # Finally, initializing the trainer and training the agent
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()

The above example uses an MLP with layer sizes ``(32, 64)`` as the decoder, and can be visualised as follows:

.. code-block:: none

          state
            |
          <32>
            |
          <64>
          /    \
      <128>    <128>
        /         \
    action       value
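
If no decoder is needed, ``shared_layers`` can be left as ``None`` (its default, as set in the agent base class), and
the actor and critic remain fully independent networks. A minimal sketch of the same setup without a decoder, reusing
the same API as the example above:

.. code-block:: python

    from genrl.agents import A2C
    from genrl.environments import VectorEnv
    from genrl.trainers import OnPolicyTrainer

    env = VectorEnv("CartPole-v0", 1)

    # shared_layers=None (the default) means no decoder is used;
    # the actor and critic act directly on the state
    algo = A2C(
        "mlp",
        env,
        policy_layers=(128,),
        value_layers=(128,),
        shared_layers=None,
        rollout_size=128,
    )

    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()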
14 changes: 9 additions & 5 deletions genrl/agents/deep/a2c/a2c.py
@@ -66,27 +66,31 @@ def _create_model(self) -> None:
self.env, self.network
)
if isinstance(self.network, str):
self.ac = get_model("ac", self.network)(
arch_type = self.network
if self.shared_layers is not None:
arch_type += "s"
self.ac = get_model("ac", arch_type)(
state_dim,
action_dim,
shared_layers=self.shared_layers,
policy_layers=self.policy_layers,
value_layers=self.value_layers,
val_type="V",
discrete=discrete,
action_lim=action_lim,
).to(self.device)

else:
self.ac = self.network.to(self.device)

# action_dim = self.network.action_dim

if self.noise is not None:
self.noise = self.noise(
torch.zeros(action_dim), self.noise_std * torch.ones(action_dim)
)

self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy)
self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value)
actor_params, critic_params = self.ac.get_params()
self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy)
self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value)

def select_action(
self, state: torch.Tensor, deterministic: bool = False
2 changes: 2 additions & 0 deletions genrl/agents/deep/base/base.py
@@ -34,6 +34,7 @@ def __init__(
create_model: bool = True,
batch_size: int = 64,
gamma: float = 0.99,
shared_layers=None,
policy_layers: Tuple = (64, 64),
value_layers: Tuple = (64, 64),
lr_policy: float = 0.0001,
@@ -45,6 +46,7 @@
self.create_model = create_model
self.batch_size = batch_size
self.gamma = gamma
self.shared_layers = shared_layers
self.policy_layers = policy_layers
self.rewards = []
self.value_layers = value_layers
6 changes: 3 additions & 3 deletions genrl/agents/deep/base/offpolicy.py
@@ -1,11 +1,12 @@
import collections
from typing import List
from typing import List, Union

import torch
from torch.nn import functional as F

from genrl.agents.deep.base import BaseAgent
from genrl.core import (
HERWrapper,
PrioritizedBuffer,
PrioritizedReplayBufferSamples,
ReplayBuffer,
@@ -98,7 +99,7 @@ def sample_from_buffer(self, beta: float = None):
states, actions, rewards, next_states, dones = self._reshape_batch(batch)

# Convert every experience to a Named Tuple. Either Replay or Prioritized Replay samples.
if isinstance(self.replay_buffer, ReplayBuffer):
if isinstance(self.replay_buffer, (ReplayBuffer, HERWrapper)):
batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones])
elif isinstance(self.replay_buffer, PrioritizedBuffer):
indices, weights = batch[5], batch[6]
@@ -231,7 +232,6 @@ def get_target_q_values(
next_q_target_values = self.ac_target.get_value(
torch.cat([next_states, next_target_actions], dim=-1)
)

target_q_values = rewards + self.gamma * (1 - dones) * next_q_target_values

return target_q_values
11 changes: 8 additions & 3 deletions genrl/agents/deep/ddpg/ddpg.py
@@ -63,9 +63,13 @@ def _create_model(self) -> None:
)

if isinstance(self.network, str):
self.ac = get_model("ac", self.network)(
arch_type = self.network
if self.shared_layers is not None:
arch_type += "s"
self.ac = get_model("ac", arch_type)(
state_dim,
action_dim,
self.shared_layers,
self.policy_layers,
self.value_layers,
"Qsa",
@@ -74,10 +78,11 @@
else:
self.ac = self.network

actor_params, critic_params = self.ac.get_params()
self.ac_target = deepcopy(self.ac).to(self.device)

self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy)
self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value)
self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy)
self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value)

def update_params(self, update_interval: int) -> None:
"""Update parameters of the model
5 changes: 5 additions & 0 deletions genrl/agents/deep/dqn/base.py
@@ -153,6 +153,9 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Ten
q_values (:obj:`torch.Tensor`): Q values for the given states and actions
"""
q_values = self.model(states)
if len(q_values.shape) == 2:
q_values = q_values.unsqueeze(1)
actions = actions.unsqueeze(1)
q_values = q_values.gather(2, actions)
return q_values

@@ -171,6 +174,8 @@ def get_target_q_values(
target_q_values (:obj:`torch.Tensor`): Target Q values for the DQN
"""
# Next Q-values according to target model
if len(next_states.shape) == 2:
next_states = next_states.unsqueeze(1)
next_q_target_values = self.target_model(next_states)
# Maximum of next q_target values
max_next_q_target_values = next_q_target_values.max(2)[0]
11 changes: 8 additions & 3 deletions genrl/agents/deep/ppo1/ppo1.py
@@ -66,9 +66,13 @@ def _create_model(self):
self.env, self.network
)
if isinstance(self.network, str):
self.ac = get_model("ac", self.network)(
arch = self.network
if self.shared_layers is not None:
arch += "s"
self.ac = get_model("ac", arch)(
state_dim,
action_dim,
shared_layers=self.shared_layers,
policy_layers=self.policy_layers,
value_layers=self.value_layers,
val_type="V",
@@ -79,8 +83,9 @@
else:
self.ac = self.network.to(self.device)

self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy)
self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value)
actor_params, critic_params = self.ac.get_params()
self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy)
self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value)

def select_action(
self, state: torch.Tensor, deterministic: bool = False
16 changes: 7 additions & 9 deletions genrl/agents/deep/sac/sac.py
@@ -75,8 +75,10 @@ def _create_model(self, **kwargs) -> None:
state_dim, action_dim, discrete, _ = get_env_properties(
self.env, self.network
)

self.ac = get_model("ac", self.network + "12")(
arch = self.network + "12"
if self.shared_layers is not None:
arch += "s"
self.ac = get_model("ac", arch)(
state_dim,
action_dim,
policy_layers=self.policy_layers,
@@ -91,13 +93,9 @@
self.model = self.network

self.ac_target = deepcopy(self.ac)

self.critic_params = list(self.ac.critic1.parameters()) + list(
self.ac.critic2.parameters()
)

self.optimizer_value = opt.Adam(self.critic_params, self.lr_value)
self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), self.lr_policy)
actor_params, critic_params = self.ac.get_params()
self.optimizer_value = opt.Adam(critic_params, self.lr_value)
self.optimizer_policy = opt.Adam(actor_params, self.lr_policy)

if self.entropy_tuning:
self.target_entropy = -torch.prod(
18 changes: 8 additions & 10 deletions genrl/agents/deep/td3/td3.py
@@ -67,10 +67,13 @@ def _create_model(self) -> None:
)

if isinstance(self.network, str):
# Below, the "12" corresponds to the Single Actor, Double Critic network architecture
self.ac = get_model("ac", self.network + "12")(
arch = self.network + "12"
if self.shared_layers is not None:
arch += "s"
self.ac = get_model("ac", arch)(
state_dim,
action_dim,
shared_layers=self.shared_layers,
policy_layers=self.policy_layers,
value_layers=self.value_layers,
val_type="Qsa",
@@ -85,14 +88,9 @@
)

self.ac_target = deepcopy(self.ac)

self.critic_params = list(self.ac.critic1.parameters()) + list(
self.ac.critic2.parameters()
)
self.optimizer_value = torch.optim.Adam(self.critic_params, lr=self.lr_value)
self.optimizer_policy = torch.optim.Adam(
self.ac.actor.parameters(), lr=self.lr_policy
)
actor_params, critic_params = self.ac.get_params()
self.optimizer_value = torch.optim.Adam(critic_params, lr=self.lr_value)
self.optimizer_policy = torch.optim.Adam(actor_params, lr=self.lr_policy)

def update_params(self, update_interval: int) -> None:
"""Update parameters of the model
1 change: 1 addition & 0 deletions genrl/core/__init__.py
@@ -1,6 +1,7 @@
from genrl.core.actor_critic import MlpActorCritic, get_actor_critic_from_name # noqa
from genrl.core.bandit import Bandit, BanditAgent
from genrl.core.base import BaseActorCritic # noqa
from genrl.core.buffers import HERWrapper # noqa
from genrl.core.buffers import PrioritizedBuffer # noqa
from genrl.core.buffers import PrioritizedReplayBufferSamples # noqa
from genrl.core.buffers import ReplayBuffer # noqa