CEM #373

2 changes: 2 additions & 0 deletions genrl/agents/__init__.py
@@ -41,5 +41,7 @@
from genrl.agents.deep.sac.sac import SAC # noqa
from genrl.agents.deep.td3.td3 import TD3 # noqa
from genrl.agents.deep.vpg.vpg import VPG # noqa
from genrl.agents.modelbased.base import ModelBasedAgent # noqa
from genrl.agents.modelbased.cem.cem import CEM # noqa

from genrl.agents.bandits.multiarmed.base import MABAgent # noqa; noqa; noqa
Empty file.
65 changes: 65 additions & 0 deletions genrl/agents/modelbased/base.py
@@ -0,0 +1,65 @@
from abc import ABC

import torch


class Planner:
def __init__(self, initial_state, dynamics_model=None):
if dynamics_model is not None:
self.dynamics_model = dynamics_model
self.initial_state = initial_state

def _learn_dynamics_model(self, state):
raise NotImplementedError

def plan(self):
raise NotImplementedError

def execute_actions(self):
raise NotImplementedError


class ModelBasedAgent(ABC):
Member:
Can this inherit from the genrl/deep BaseAgent?

Member Author:
It can, and I think that's a better option (for now at least).
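
A rough sketch of what that refactor could look like; the BaseAgent import path and the arguments it forwards are assumptions, not something this PR pins down:

from genrl.agents.deep.base import BaseAgent  # assumed import path


class ModelBasedAgent(BaseAgent):
    """Model-based agent that defers common setup (env, device, etc.) to BaseAgent."""

    def __init__(self, *args, planner=None, **kwargs):
        super(ModelBasedAgent, self).__init__(*args, **kwargs)
        self.planner = planner

    def plan(self):
        if self.planner is None:
            raise ValueError("Provide a planner to plan for the environment")
        return self.planner.plan()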

def __init__(self, env, planner=None, render=False, device="cpu"):
self.env = env
self.planner = planner
self.render = render
self.device = torch.device(device)

def plan(self):
"""
To be used to plan out a sequence of actions
"""
        if self.planner is None:
            raise ValueError("Provide a planner to plan for the environment")
self.planner.plan()

def generate_data(self):
"""
To be used to generate synthetic data via a model (may be learnt or specified beforehand)
"""
raise NotImplementedError

def value_equivalence(self, state_space):
"""
To be used for approximate value estimation methods e.g. Value Iteration Networks
"""
raise NotImplementedError

def update_params(self):
"""
        Update the parameters of the learnt model and/or the policy being used
"""
raise NotImplementedError

    def get_hyperparams(self):
raise NotImplementedError

def get_logging_params(self):
raise NotImplementedError

def _load_weights(self, weights):
raise NotImplementedError

def empty_logs(self):
raise NotImplementedError
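
Purely as an illustration of the Planner interface above (not part of this PR), here is what a hypothetical subclass could look like; RandomShootingPlanner, its gym-style action_space argument, and the assumed dynamics_model signature (state, action) -> (next_state, reward) are all illustrative:

from genrl.agents.modelbased.base import Planner  # added in this PR


class RandomShootingPlanner(Planner):
    def __init__(self, initial_state, dynamics_model, action_space, n_candidates=64, horizon=10):
        super().__init__(initial_state, dynamics_model)
        self.action_space = action_space  # assumed to be a gym-style space
        self.n_candidates = n_candidates
        self.horizon = horizon

    def plan(self):
        # Sample candidate action sequences, roll each through the dynamics model,
        # and keep the sequence with the highest predicted return.
        best_return, best_actions = float("-inf"), None
        for _ in range(self.n_candidates):
            state, total_reward = self.initial_state, 0.0
            actions = [self.action_space.sample() for _ in range(self.horizon)]
            for action in actions:
                state, reward = self.dynamics_model(state, action)
                total_reward += reward
            if total_reward > best_return:
                best_return, best_actions = total_reward, actions
        return best_actions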
Empty file.
169 changes: 169 additions & 0 deletions genrl/agents/modelbased/cem/cem.py
@@ -0,0 +1,169 @@
import numpy as np
import torch
import torch.nn.functional as F

from genrl.agents import ModelBasedAgent
from genrl.core import RolloutBuffer
from genrl.utils import get_env_properties, get_model, safe_mean


class CEM(ModelBasedAgent):
def __init__(
self,
*args,
network: str = "mlp",
policy_layers: tuple = (100,),
lr_policy=1e-3,
percentile: int = 70,
rollout_size,
**kwargs
):
super(CEM, self).__init__(*args, **kwargs)
self.network = network
self.rollout_size = rollout_size
self.rollout = RolloutBuffer(self.rollout_size, self.env)
self.policy_layers = policy_layers
self.lr_policy = lr_policy
self.percentile = percentile

self._create_model()
self.empty_logs()

def _create_model(self):
self.state_dim, self.action_dim, discrete, action_lim = get_env_properties(
self.env, self.network
)
self.agent = get_model("p", self.network)(
self.state_dim,
self.action_dim,
self.policy_layers,
"V",
discrete,
action_lim,
)
self.optim = torch.optim.Adam(self.agent.parameters(), lr=self.lr_policy)

def plan(self):
state = self.env.reset()
self.rollout.reset()
states, actions = self.collect_rollouts(state)
return (states, actions, self.rewards[-1])

def select_elites(self, states_batch, actions_batch, rewards_batch):
reward_threshold = np.percentile(rewards_batch, self.percentile)
elite_states = [
s.unsqueeze(0).clone()
for i in range(len(states_batch))
if rewards_batch[i] >= reward_threshold
for s in states_batch[i]
]
elite_actions = [
a.unsqueeze(0).clone()
for i in range(len(actions_batch))
if rewards_batch[i] >= reward_threshold
for a in actions_batch[i]
]

return torch.cat(elite_states, dim=0), torch.cat(elite_actions, dim=0)

def select_action(self, state):
state = torch.as_tensor(state).float()
action, dist = self.agent.get_action(state)
return action, torch.zeros((1, self.env.n_envs)), dist.log_prob(action).cpu()

def update_params(self):
sess = [self.plan() for _ in range(100)]
batch_states, batch_actions, batch_rewards = zip(*sess)
elite_states, elite_actions = self.select_elites(
batch_states, batch_actions, batch_rewards
)
action_probs = self.agent.forward(elite_states.float())
loss = F.cross_entropy(
action_probs.view(-1, self.action_dim),
elite_actions.long().view(-1),
)
self.logs["crossentropy_loss"].append(loss.item())
        self.optim.zero_grad()  # clear gradients left over from the previous update
        loss.backward()
# torch.nn.utils.clip_grad_norm_(self.agent.parameters(), 0.5)
self.optim.step()

def get_traj_loss(self, values, dones):
# No need for this here
pass

def collect_rollouts(self, state: torch.Tensor):
Member:
Looks pretty similar to the OnPolicyAgent method. Shouldn't this return values and dones though? Not sure if this is a consequence of the algo.

"""Function to collect rollouts

Collects rollouts by playing the env like a human agent and inputs information into
the rollout buffer.

Args:
state (:obj:`torch.Tensor`): The starting state of the environment

Returns:
            states (:obj:`list` of :obj:`torch.Tensor`): States encountered during the rollout
            actions (:obj:`list` of :obj:`torch.Tensor`): Actions taken during the rollout
"""
states = []
actions = []
for i in range(self.rollout_size):
action, value, log_probs = self.select_action(state)
states.append(state)
actions.append(action)

next_state, reward, dones, _ = self.env.step(action)

if self.render:
self.env.render()

self.rollout.add(
state,
action.reshape(self.env.n_envs, 1),
reward,
dones,
value,
log_probs.detach(),
)

state = next_state

self.collect_rewards(dones, i)

if dones:
break

return states, actions

def collect_rewards(self, dones: torch.Tensor, timestep: int):
"""Helper function to collect rewards

Runs through all the envs and collects rewards accumulated during rollouts

Args:
dones (:obj:`torch.Tensor`): Game over statuses of each environment
timestep (int): Timestep during rollout
"""
for i, done in enumerate(dones):
if done or timestep == self.rollout_size - 1:
self.rewards.append(self.env.episode_reward[i].detach().clone())
# self.env.reset_single_env(i)
Member:
Why is this commented out? This is necessary to reset environments immediately as they are set to done. (Not a good practice to do env.step() if the env is already returning done = True)

Member Author:
Since I am breaking the action loop when env.step() returns done=True, and every plan session (the plan function) starts with env.reset(), I think this is redundant here, hence it's commented out.


def get_hyperparams(self):
hyperparams = {
"network": self.network,
"lr_policy": self.lr_policy,
"rollout_size": self.rollout_size,
}
return hyperparams

def get_logging_params(self):
logs = {
"crossentropy_loss": safe_mean(self.logs["crossentropy_loss"]),
"mean_reward": safe_mean(self.rewards),
}
return logs

def empty_logs(self):
self.logs = {}
self.logs["crossentropy_loss"] = []
self.rewards = []
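
To make the elite-selection step in select_elites concrete, here is a small self-contained example of the same percentile filtering on toy rollout data; the tensors and returns below are illustrative only:

import numpy as np
import torch

# Four toy episodes: 5 timesteps each, 4-dim states, binary discrete actions.
states_batch = [torch.randn(5, 4) for _ in range(4)]
actions_batch = [torch.randint(0, 2, (5,)) for _ in range(4)]
rewards_batch = [10.0, 50.0, 200.0, 120.0]

# Episodes whose return is at or above the 70th percentile are kept as "elites";
# their (state, action) pairs become the supervised targets for the policy update.
reward_threshold = np.percentile(rewards_batch, 70)  # ~128 for these returns
elite_states = torch.cat(
    [ep for ep, r in zip(states_batch, rewards_batch) if r >= reward_threshold], dim=0
)
elite_actions = torch.cat(
    [ep for ep, r in zip(actions_batch, rewards_batch) if r >= reward_threshold], dim=0
)
print(elite_states.shape, elite_actions.shape)  # torch.Size([5, 4]) torch.Size([5])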
10 changes: 10 additions & 0 deletions tests/test_deep/test_agents/test_cem.py
@@ -0,0 +1,10 @@
from genrl.agents import CEM
from genrl.environments import VectorEnv
from genrl.trainers import OnPolicyTrainer


def test_CEM():
Member:
Also please make this a class so the tests are easier to find/understand

env = VectorEnv("CartPole-v0", 1)
Member:
Why set it to 1? It does work with multiple envs right?

Member Author:
Yeah it does

algo = CEM(env, percentile=70, policy_layers=[100], rollout_size=100)
trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
trainer.train()
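
Following the review request above, a minimal sketch of the same test wrapped in a class; the TestCEM and test_cem_cartpole names are just one possible choice:

from genrl.agents import CEM
from genrl.environments import VectorEnv
from genrl.trainers import OnPolicyTrainer


class TestCEM:
    def test_cem_cartpole(self):
        env = VectorEnv("CartPole-v0", 1)
        algo = CEM(env, percentile=70, policy_layers=[100], rollout_size=100)
        trainer = OnPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1
        )
        trainer.train()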