CEM #373
base: master
New file (+65 lines):
```python
from abc import ABC

import torch


class Planner:
    def __init__(self, initial_state, dynamics_model=None):
        if dynamics_model is not None:
            self.dynamics_model = dynamics_model
        self.initial_state = initial_state

    def _learn_dynamics_model(self, state):
        raise NotImplementedError

    def plan(self):
        raise NotImplementedError

    def execute_actions(self):
        raise NotImplementedError


class ModelBasedAgent(ABC):
    def __init__(self, env, planner=None, render=False, device="cpu"):
        self.env = env
        self.planner = planner
        self.render = render
        self.device = torch.device(device)

    def plan(self):
        """
        Plan out a sequence of actions using the provided planner
        """
        if self.planner is None:
            raise ValueError("Provide a planner to plan for the environment")
        self.planner.plan()

    def generate_data(self):
        """
        Generate synthetic data via a model (may be learnt or specified beforehand)
        """
        raise NotImplementedError

    def value_equivalence(self, state_space):
        """
        For approximate value estimation methods, e.g. Value Iteration Networks
        """
        raise NotImplementedError

    def update_params(self):
        """
        Update the parameters of the learnt model and/or of the policy being used
        """
        raise NotImplementedError

    def get_hyperparams(self):
        raise NotImplementedError

    def get_logging_params(self):
        raise NotImplementedError

    def _load_weights(self, weights):
        raise NotImplementedError

    def empty_logs(self):
        raise NotImplementedError
```
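For context, a minimal sketch of how a concrete agent might build on this base class; the `ToyModelBasedAgent` name and its trivial method bodies are illustrative assumptions, not part of this PR:

```python
from genrl.agents import ModelBasedAgent  # as exported by this PR


# Illustrative sketch only: a toy subclass that fills in the hooks the
# trainer-facing API expects. The class name and method bodies are made up.
class ToyModelBasedAgent(ModelBasedAgent):
    def __init__(self, env, **kwargs):
        super(ToyModelBasedAgent, self).__init__(env, **kwargs)
        self.empty_logs()

    def update_params(self):
        # a real agent would fit its dynamics model and/or policy here
        pass

    def get_hyperparams(self):
        return {"device": str(self.device)}

    def get_logging_params(self):
        return {}

    def empty_logs(self):
        self.logs = {}
```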
New file (+169 lines):
```python
import numpy as np
import torch
import torch.nn.functional as F

from genrl.agents import ModelBasedAgent
from genrl.core import RolloutBuffer
from genrl.utils import get_env_properties, get_model, safe_mean


class CEM(ModelBasedAgent):
    def __init__(
        self,
        *args,
        network: str = "mlp",
        policy_layers: tuple = (100,),
        lr_policy: float = 1e-3,
        percentile: int = 70,
        rollout_size: int,
        **kwargs
    ):
        super(CEM, self).__init__(*args, **kwargs)
        self.network = network
        self.rollout_size = rollout_size
        self.rollout = RolloutBuffer(self.rollout_size, self.env)
        self.policy_layers = policy_layers
        self.lr_policy = lr_policy
        self.percentile = percentile

        self._create_model()
        self.empty_logs()

    def _create_model(self):
        self.state_dim, self.action_dim, discrete, action_lim = get_env_properties(
            self.env, self.network
        )
        self.agent = get_model("p", self.network)(
            self.state_dim,
            self.action_dim,
            self.policy_layers,
            "V",
            discrete,
            action_lim,
        )
        self.optim = torch.optim.Adam(self.agent.parameters(), lr=self.lr_policy)

    def plan(self):
        state = self.env.reset()
        self.rollout.reset()
        states, actions = self.collect_rollouts(state)
        return (states, actions, self.rewards[-1])

    def select_elites(self, states_batch, actions_batch, rewards_batch):
        reward_threshold = np.percentile(rewards_batch, self.percentile)
        elite_states = [
            s.unsqueeze(0).clone()
            for i in range(len(states_batch))
            if rewards_batch[i] >= reward_threshold
            for s in states_batch[i]
        ]
        elite_actions = [
            a.unsqueeze(0).clone()
            for i in range(len(actions_batch))
            if rewards_batch[i] >= reward_threshold
            for a in actions_batch[i]
        ]

        return torch.cat(elite_states, dim=0), torch.cat(elite_actions, dim=0)

    def select_action(self, state):
        state = torch.as_tensor(state).float()
        action, dist = self.agent.get_action(state)
        # dummy value estimate: CEM does not learn a value function
        return action, torch.zeros((1, self.env.n_envs)), dist.log_prob(action).cpu()

    def update_params(self):
        # collect 100 planning sessions, keep the elite ones, and fit the policy to them
        sess = [self.plan() for _ in range(100)]
        batch_states, batch_actions, batch_rewards = zip(*sess)
        elite_states, elite_actions = self.select_elites(
            batch_states, batch_actions, batch_rewards
        )
        action_probs = self.agent.forward(elite_states.float())
        loss = F.cross_entropy(
            action_probs.view(-1, self.action_dim),
            elite_actions.long().view(-1),
        )
        self.logs["crossentropy_loss"].append(loss.item())
        self.optim.zero_grad()  # clear gradients left over from the previous update
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.agent.parameters(), 0.5)
        self.optim.step()

    def get_traj_loss(self, values, dones):
        # Not needed for CEM
        pass

    def collect_rollouts(self, state: torch.Tensor):
        """Function to collect rollouts

        Collects rollouts by playing the env like a human agent and inputs information into
        the rollout buffer.

        Args:
            state (:obj:`torch.Tensor`): The starting state of the environment

        Returns:
            states (:obj:`list`): States encountered during the rollout
            actions (:obj:`list`): Actions taken during the rollout
        """
        states = []
        actions = []
        for i in range(self.rollout_size):
            action, value, log_probs = self.select_action(state)
            states.append(state)
            actions.append(action)

            next_state, reward, dones, _ = self.env.step(action)

            if self.render:
                self.env.render()

            self.rollout.add(
                state,
                action.reshape(self.env.n_envs, 1),
                reward,
                dones,
                value,
                log_probs.detach(),
            )

            state = next_state

            self.collect_rewards(dones, i)

            if dones:
                break

        return states, actions

    def collect_rewards(self, dones: torch.Tensor, timestep: int):
        """Helper function to collect rewards

        Runs through all the envs and collects rewards accumulated during rollouts

        Args:
            dones (:obj:`torch.Tensor`): Game over statuses of each environment
            timestep (int): Timestep during rollout
        """
        for i, done in enumerate(dones):
            if done or timestep == self.rollout_size - 1:
                self.rewards.append(self.env.episode_reward[i].detach().clone())
                # self.env.reset_single_env(i)

    def get_hyperparams(self):
        hyperparams = {
            "network": self.network,
            "lr_policy": self.lr_policy,
            "rollout_size": self.rollout_size,
        }
        return hyperparams

    def get_logging_params(self):
        logs = {
            "crossentropy_loss": safe_mean(self.logs["crossentropy_loss"]),
            "mean_reward": safe_mean(self.rewards),
        }
        return logs

    def empty_logs(self):
        self.logs = {}
        self.logs["crossentropy_loss"] = []
        self.rewards = []
```

Member review comment on `collect_rollouts`: Looks pretty similar to the …
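For intuition, here is a small self-contained sketch of the percentile-based elite selection that `select_elites` performs; the rollout data below is invented purely for illustration:

```python
import numpy as np
import torch

# Standalone sketch of the elite-selection step used by CEM.select_elites above.
# The toy rollouts here are made up for illustration.
percentile = 70
states_batch = [torch.randn(5, 4) for _ in range(3)]            # 3 rollouts x 5 steps x 4-dim states
actions_batch = [torch.randint(0, 2, (5,)) for _ in range(3)]   # discrete actions for each rollout
rewards_batch = [12.0, 45.0, 30.0]                               # total reward per rollout

threshold = np.percentile(rewards_batch, percentile)             # 36.0 for this toy data
elite_states = torch.cat(
    [s.unsqueeze(0) for i in range(len(states_batch))
     if rewards_batch[i] >= threshold for s in states_batch[i]],
    dim=0,
)
elite_actions = torch.cat(
    [a.unsqueeze(0) for i in range(len(actions_batch))
     if rewards_batch[i] >= threshold for a in actions_batch[i]],
    dim=0,
)

# Only the 45.0-reward rollout clears the 70th-percentile threshold, so the
# policy is then fit (via cross-entropy) on its 5 state-action pairs only.
print(elite_states.shape, elite_actions.shape)  # torch.Size([5, 4]) torch.Size([5])
```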
New file (+10 lines):
```python
from genrl.agents import CEM
from genrl.environments import VectorEnv
from genrl.trainers import OnPolicyTrainer


def test_CEM():
    env = VectorEnv("CartPole-v0", 1)
    algo = CEM(env, percentile=70, policy_layers=[100], rollout_size=100)
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
```
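As a further illustration of the logging hooks, a `CEM` instance built the same way as in the test could be queried directly; this is a hypothetical usage sketch and not part of the PR:

```python
from genrl.agents import CEM
from genrl.environments import VectorEnv

# Hypothetical usage sketch, mirroring the test setup above.
env = VectorEnv("CartPole-v0", 1)
algo = CEM(env, percentile=70, policy_layers=[100], rollout_size=100)

# Hyperparameters as exposed to the trainer/logger:
print(algo.get_hyperparams())   # {'network': 'mlp', 'lr_policy': 0.001, 'rollout_size': 100}

# After at least one update, the logged metrics are available too:
algo.update_params()
print(algo.get_logging_params())  # mean cross-entropy loss and mean rollout reward
```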
Review comment: Can this inherit from the `genrl/deep` `BaseAgent`?

Reply: It can, and I think that's a better option (for now at least).