1 change: 1 addition & 0 deletions .github/workflows/codecov.yml
@@ -32,6 +32,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pytest-cov codecov
pip install git+https://github.com/eleurent/highway-env
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

- name: Tests
2 changes: 1 addition & 1 deletion .isort.cfg
@@ -1,5 +1,5 @@
[settings]
known_third_party = cv2,gym,matplotlib,numpy,pandas,pytest,scipy,setuptools,toml,torch
known_third_party = cv2,gym,highway_env,matplotlib,numpy,pandas,pytest,scipy,setuptools,toml,torch
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
1 change: 1 addition & 0 deletions .scripts/unix_cpu_build.sh
@@ -3,3 +3,4 @@
python -m pip install --upgrade pip
pip install torch==1.4.0 --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade
pip install -r requirements.txt
pip install git+https://github.com/eleurent/highway-env
1 change: 1 addition & 0 deletions .scripts/windows_cpu_build.ps1
@@ -4,3 +4,4 @@ conda install -c conda-forge swig
python -m pip install --upgrade pip
pip install torch==1.4.0 --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade
pip install -r requirements.txt
pip install git+https://github.com/eleurent/highway-env
5 changes: 3 additions & 2 deletions genrl/agents/deep/base/offpolicy.py
@@ -1,11 +1,12 @@
import collections
from typing import List
from typing import List, Union

import torch
from torch.nn import functional as F

from genrl.agents.deep.base import BaseAgent
from genrl.core import (
HERWrapper,
PrioritizedBuffer,
PrioritizedReplayBufferSamples,
ReplayBuffer,
@@ -98,7 +99,7 @@ def sample_from_buffer(self, beta: float = None):
states, actions, rewards, next_states, dones = self._reshape_batch(batch)

# Convert every experience to a Named Tuple. Either Replay or Prioritized Replay samples.
if isinstance(self.replay_buffer, ReplayBuffer):
if isinstance(self.replay_buffer, (ReplayBuffer, HERWrapper)):
batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones])
elif isinstance(self.replay_buffer, PrioritizedBuffer):
indices, weights = batch[5], batch[6]
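For context: HERWrapper.sample (added in genrl/core/buffers.py further down) simply delegates to the wrapped ReplayBuffer, so both buffer types yield the same (states, actions, rewards, next_states, dones) layout and can share the ReplayBufferSamples branch. A minimal sketch of that dispatch, assuming these names are exported from genrl.core as in the import block above:

from genrl.core import HERWrapper, ReplayBuffer, ReplayBufferSamples

def pack_batch(replay_buffer, states, actions, rewards, next_states, dones):
    # HER-wrapped buffers delegate sampling to the plain buffer, so both map to
    # ReplayBufferSamples; prioritized buffers keep their own branch (see the diff above).
    if isinstance(replay_buffer, (ReplayBuffer, HERWrapper)):
        return ReplayBufferSamples(states, actions, rewards, next_states, dones)
    raise NotImplementedError("prioritized buffers are handled separately")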
8 changes: 7 additions & 1 deletion genrl/agents/deep/dqn/base.py
@@ -71,7 +71,6 @@ def _create_model(self, *args, **kwargs) -> None:
)
else:
self.model = self.network

self.target_model = deepcopy(self.model)

self.optimizer = opt.Adam(self.model.parameters(), lr=self.lr_value)
@@ -104,6 +103,8 @@ def get_greedy_action(self, state: torch.Tensor) -> torch.Tensor:
Returns:
action (:obj:`torch.Tensor`): Action taken by the agent
"""
if not isinstance(state, torch.Tensor):
state = torch.as_tensor(state).float()
q_values = self.model(state.unsqueeze(0))
action = torch.argmax(q_values.squeeze(), dim=-1)
return action
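The added guard lets get_greedy_action accept raw observations (numpy arrays or lists straight from the environment) as well as tensors. A tiny standalone sketch of the same conversion, using a hypothetical observation:

import numpy as np
import torch

def to_float_tensor(state):
    # Same guard as in the diff above: leave tensors alone, convert anything else.
    if not isinstance(state, torch.Tensor):
        state = torch.as_tensor(state).float()
    return state

obs = np.array([0.1, -0.2, 0.05, 0.0])   # hypothetical env observation (float64)
print(to_float_tensor(obs).dtype)        # torch.float32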
@@ -152,6 +153,9 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
Returns:
q_values (:obj:`torch.Tensor`): Q values for the given states and actions
"""
if len(states.shape) < 3:
states = states.unsqueeze(1)
actions = actions.unsqueeze(1)
q_values = self.model(states)
q_values = q_values.gather(2, actions)
return q_values
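These reshapes assume batches shaped [batch_size, n_envs, ...]; a 2-D batch (for example from the flat observations stored by the HER-wrapped buffer) gets a singleton env dimension so that gather(2, actions) still indexes the action axis. A small shape sketch under that assumption, with made-up sizes and a plain linear layer standing in for the Q-network:

import torch

batch_size, n_actions, state_dim = 32, 4, 8
q_net = torch.nn.Linear(state_dim, n_actions)            # stand-in for self.model

states = torch.randn(batch_size, state_dim)              # 2-D batch from the buffer
actions = torch.randint(0, n_actions, (batch_size, 1))   # chosen action indices

if len(states.shape) < 3:
    states = states.unsqueeze(1)                          # [batch, 1, state_dim]
    actions = actions.unsqueeze(1)                        # [batch, 1, 1]

q_values = q_net(states)                                  # [batch, 1, n_actions]
q_values = q_values.gather(2, actions)                    # [batch, 1, 1]
print(q_values.shape)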
@@ -170,6 +174,8 @@ def get_target_q_values(
Returns:
target_q_values (:obj:`torch.Tensor`): Target Q values for the DQN
"""
if len(next_states.shape) < 3:
next_states = next_states.unsqueeze(1)
# Next Q-values according to target model
next_q_target_values = self.target_model(next_states)
# Maximum of next q_target values
1 change: 1 addition & 0 deletions genrl/core/__init__.py
@@ -1,6 +1,7 @@
from genrl.core.actor_critic import MlpActorCritic, get_actor_critic_from_name # noqa
from genrl.core.bandit import Bandit, BanditAgent
from genrl.core.base import BaseActorCritic # noqa
from genrl.core.buffers import HERWrapper # noqa
from genrl.core.buffers import PrioritizedBuffer # noqa
from genrl.core.buffers import PrioritizedReplayBufferSamples # noqa
from genrl.core.buffers import ReplayBuffer # noqa
17 changes: 7 additions & 10 deletions genrl/core/actor_critic.py
@@ -225,7 +225,6 @@ def get_action(self, state: torch.Tensor, deterministic: bool = False):
(None if deterministic)
"""
state = torch.as_tensor(state).float()

if self.actor.sac:
mean, log_std = self.actor(state)
std = log_std.exp()
@@ -270,7 +269,7 @@ def get_value(self, state: torch.Tensor, mode="first") -> torch.Tensor:
values = self.forward(state)
elif mode == "min":
values = self.forward(state)
values = torch.min(*values).squeeze(-1)
values = torch.min(*values)
elif mode == "first":
values = self.critic1(state)
else:
@@ -340,6 +339,7 @@ def get_features(self, state: torch.Tensor):
Returns:
features (:obj:`torch.Tensor`): The feature(s) extracted from the state
"""
state = torch.as_tensor(state).float()
features = self.shared_network(state)
return features

@@ -373,15 +373,12 @@ def get_value(self, state: torch.Tensor, mode="first"):
values (:obj:`list`): List of values as estimated by each individual critic
"""
state = torch.as_tensor(state).float()
# state shape = [batch_size, number of vec envs, (state_dim + action_dim)]

# extract shared features for just the state
# state[:, :, :-action_dim] -> [batch_size, number of vec envs, state_dim]
x = self.get_features(state[:, :, : -self.action_dim])

# concatenate the actions to the extracted shared features
# state[:, :, -action_dim:] -> [batch_size, number of vec envs, action_dim]
state = torch.cat([x, state[:, :, -self.action_dim :]], dim=-1)
state_shape = state.shape
temp = state.reshape(-1, state_shape[-1])[:, : -self.action_dim]
x = self.get_features(temp.reshape(list(state_shape[:-1]) + [-1]))
temp = state.reshape(-1, state_shape[-1])[:, -self.action_dim :]
state = torch.cat([x, temp.reshape(list(state_shape[:-1]) + [-1])], dim=-1)
return super(MlpSharedSingleActorTwoCritic, self).get_value(state, mode)
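The rewritten slicing flattens all leading dimensions before splitting off the action part and then re-concatenates the actions to the extracted features, so the same code path handles inputs with or without an env dimension. A standalone shape sketch of that idea, with hypothetical sizes and an identity function standing in for the shared feature extractor:

import torch

batch, n_envs, state_dim, action_dim = 16, 2, 10, 3
state = torch.randn(batch, n_envs, state_dim + action_dim)  # concatenated [state, action]

def get_features(x):
    # Stand-in for the shared feature extractor (identity here, so shapes are easy to follow).
    return x

state_shape = state.shape
flat = state.reshape(-1, state_shape[-1])                    # [batch * n_envs, state+action]
temp = flat[:, :-action_dim]                                 # state part only
x = get_features(temp.reshape(list(state_shape[:-1]) + [-1]))
temp = flat[:, -action_dim:]                                 # action part only
out = torch.cat([x, temp.reshape(list(state_shape[:-1]) + [-1])], dim=-1)
print(out.shape)                                             # torch.Size([16, 2, 13])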


111 changes: 109 additions & 2 deletions genrl/core/buffers.py
@@ -1,3 +1,4 @@
import copy
import random
from collections import deque
from typing import NamedTuple, Tuple
@@ -70,7 +71,7 @@ def __len__(self) -> int:

:returns: Length of replay memory
"""
return self.pos
return len(self.memory)


class PrioritizedBuffer:
@@ -172,7 +173,7 @@ def update_priorities(self, batch_indices: Tuple, batch_priorities: Tuple) -> None:

def __len__(self) -> int:
"""
Gives number of experiences in buffer currently
Gives number of experiences in buffer currently

:returns: Length of replay memory
"""
@@ -181,3 +182,109 @@ def __len__(self) -> int:
@property
def pos(self):
return len(self.buffer)


class HERWrapper:
"""
A wrapper class to convert a replay buffer into a HER-style buffer

Args:
replay_buffer (ReplayBuffer): An instance of the replay buffer to be converted to a HER style buffer
n_sampled_goal (int): The number of artificial transitions to generate for each actual transition
goal_selection_strategy (str): The strategy used to generate goals for the artificial transitions
env (HERGoalEnvWrapper): The goal env, wrapped using HERGoalEnvWrapper
"""

def __init__(self, replay_buffer, n_sampled_goal, goal_selection_strategy, env):

self.n_sampled_goal = n_sampled_goal
self.goal_selection_strategy = goal_selection_strategy
self.replay_buffer = replay_buffer
self.transitions = []
self.allowed_strategies = ["future", "final", "episode", "random"]
self.env = env

def push(self, inp: Tuple):
state, action, reward, next_state, done, info = inp
if isinstance(state, dict):
state = self.env.convert_dict_to_obs(state)
next_state = self.env.convert_dict_to_obs(next_state)

self.transitions.append((state, action, reward, next_state, done, info))
self.replay_buffer.push((state, action, reward, next_state, done))

if inp[-1]:
self._store_episode()
self.transitions = []

def sample(self, batch_size):
return self.replay_buffer.sample(batch_size)

def _sample_achieved_goal(self, ep_transitions, transition_idx):
if self.goal_selection_strategy == "future":
# Sample a goal that was observed in the future
selected_idx = np.random.choice(
np.arange(transition_idx + 1, len(ep_transitions))
)
selected_transition = ep_transitions[selected_idx]
elif self.goal_selection_strategy == "final":
# Sample the goal that was finally achieved during the episode
selected_transition = ep_transitions[-1]
elif self.goal_selection_strategy == "episode":
# Sample a goal that was observed in the episode
selected_idx = np.random.choice(np.arange(len(ep_transitions)))
selected_transition = ep_transitions[selected_idx]
elif self.goal_selection_strategy == "random":
# Sample a random goal from the entire replay buffer
selected_idx = np.random.choice(len(self.replay_buffer))
selected_transition = self.replay_buffer.memory[selected_idx]
else:
raise ValueError(
f"Goal selection strategy must be one of {self.allowed_strategies}"
)

return self.env.convert_obs_to_dict(selected_transition[0])["achieved_goal"]

def _sample_batch_goals(self, ep_transitions, transition_idx):
return [
self._sample_achieved_goal(ep_transitions, transition_idx)
for _ in range(self.n_sampled_goal)
]

def _store_episode(self):
for transition_idx, transition in enumerate(self.transitions):

# We cannot sample from the future on the last step
if (
transition_idx == len(self.transitions) - 1
and self.goal_selection_strategy == "future"
):
break

sampled_goals = self._sample_batch_goals(self.transitions, transition_idx)

for goal in sampled_goals:
state, action, reward, next_state, done, info = copy.deepcopy(
transition
)

# Convert concatenated obs to dict, so we can update the goals
state_dict = self.env.convert_obs_to_dict(state)
next_state_dict = self.env.convert_obs_to_dict(next_state)

# Update the desired goals in the transition
state_dict["desired_goal"] = goal
next_state_dict["desired_goal"] = goal

# Update the reward according to the new desired goal
reward = self.env.compute_reward(
next_state_dict["achieved_goal"], goal, info
)

# Store the newly created transition in the replay buffer
state = self.env.convert_dict_to_obs(state_dict)
next_state = self.env.convert_dict_to_obs(next_state_dict)
self.replay_buffer.push((state, action, reward, next_state, done))

def __len__(self):
return len(self.replay_buffer)
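A hedged end-to-end sketch of how the wrapper is meant to be used. It assumes HERGoalEnvWrapper (added to genrl.environments in this PR) behaves like a standard gym wrapper with reset/step/action_space, that ReplayBuffer takes a capacity as its first argument, and that sample() returns the usual 5-tuple; it uses highway-env's parking-v0 goal env, which the CI and build-script changes above install:

import gym
import highway_env  # noqa: F401  (registers parking-v0; installed by the CI changes above)

from genrl.core import HERWrapper, ReplayBuffer
from genrl.environments import HERGoalEnvWrapper

env = HERGoalEnvWrapper(gym.make("parking-v0"))  # goal env with dict observations
buffer = HERWrapper(
    ReplayBuffer(5000),                          # capacity-style constructor assumed
    n_sampled_goal=4,
    goal_selection_strategy="future",
    env=env,
)

state = env.reset()
for _ in range(100):
    action = env.action_space.sample()
    next_state, reward, done, info = env.step(action)
    # When done is True, push() relabels the finished episode with achieved goals
    # and stores the extra transitions in the wrapped ReplayBuffer.
    buffer.push((state, action, reward, next_state, done, info))
    state = env.reset() if done else next_state

# Assuming sample() returns (states, actions, rewards, next_states, dones):
states, actions, rewards, next_states, dones = buffer.sample(batch_size=32)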
1 change: 1 addition & 0 deletions genrl/environments/__init__.py
@@ -4,6 +4,7 @@
from genrl.environments.base_wrapper import BaseWrapper # noqa
from genrl.environments.frame_stack import FrameStack # noqa
from genrl.environments.gym_wrapper import GymWrapper # noqa
from genrl.environments.her_wrapper import HERGoalEnvWrapper # noqa
from genrl.environments.suite import AtariEnv, GymEnv, VectorEnv # noqa
from genrl.environments.time_limit import AtariTimeLimit, TimeLimit # noqa
from genrl.environments.vec_env import VecEnv, VecNormalize # noqa