diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 350e4d23..ecbbbc8e 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -32,6 +32,7 @@ jobs: run: | python -m pip install --upgrade pip pip install pytest-cov codecov + pip install git+https://github.com/eleurent/highway-env if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Tests diff --git a/.isort.cfg b/.isort.cfg index 4b0feff5..e0f65cb3 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,5 +1,5 @@ [settings] -known_third_party = cv2,gym,matplotlib,numpy,pandas,pytest,scipy,setuptools,toml,torch +known_third_party = cv2,gym,highway_env,matplotlib,numpy,pandas,pytest,scipy,setuptools,toml,torch multi_line_output=3 include_trailing_comma=True force_grid_wrap=0 diff --git a/.scripts/unix_cpu_build.sh b/.scripts/unix_cpu_build.sh index 8bb5b134..ff747208 100755 --- a/.scripts/unix_cpu_build.sh +++ b/.scripts/unix_cpu_build.sh @@ -3,3 +3,4 @@ python -m pip install --upgrade pip pip install torch==1.4.0 --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade pip install -r requirements.txt +pip install git+https://github.com/eleurent/highway-env diff --git a/.scripts/windows_cpu_build.ps1 b/.scripts/windows_cpu_build.ps1 index cd724a12..5d08ca5d 100644 --- a/.scripts/windows_cpu_build.ps1 +++ b/.scripts/windows_cpu_build.ps1 @@ -4,3 +4,4 @@ conda install -c conda-forge swig python -m pip install --upgrade pip pip install torch==1.4.0 --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade pip install -r requirements.txt +pip install git+https://github.com/eleurent/highway-env diff --git a/genrl/agents/deep/base/offpolicy.py b/genrl/agents/deep/base/offpolicy.py index fa632e8a..de59010d 100644 --- a/genrl/agents/deep/base/offpolicy.py +++ b/genrl/agents/deep/base/offpolicy.py @@ -1,11 +1,12 @@ import collections -from typing import List +from typing import List, Union import torch from torch.nn import functional as F from genrl.agents.deep.base import BaseAgent from genrl.core import ( + HERWrapper, PrioritizedBuffer, PrioritizedReplayBufferSamples, ReplayBuffer, @@ -98,7 +99,7 @@ def sample_from_buffer(self, beta: float = None): states, actions, rewards, next_states, dones = self._reshape_batch(batch) # Convert every experience to a Named Tuple. Either Replay or Prioritized Replay samples. 
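Note for orientation: the remainder of this diff adds three cooperating pieces — a HER buffer wrapper in genrl/core/buffers.py, a goal-env flattening wrapper in genrl/environments/her_wrapper.py, and a trainer in genrl/trainers/her_trainer.py. A rough wiring sketch, based on the tests added at the end of this diff (the class names and keyword arguments come from those tests; the hyperparameter values are placeholders only):

    import gym
    import highway_env  # noqa: F401 -- importing registers "parking-v0" with gym

    from genrl.agents import DDPG
    from genrl.core import HERWrapper, ReplayBuffer
    from genrl.environments import HERGoalEnvWrapper
    from genrl.trainers import HERTrainer

    # Flatten the goal env's dict observations into a single vector.
    env = HERGoalEnvWrapper(gym.make("parking-v0"))
    agent = DDPG("mlp", env, batch_size=64)

    # Wrap a plain ReplayBuffer: 4 relabelled goals per real transition, "future" strategy.
    buffer = HERWrapper(ReplayBuffer(int(1e5)), 4, "future", env)

    trainer = HERTrainer(agent, env, buffer=buffer, max_ep_len=100, epochs=10,
                         warmup_steps=100, start_update=100)
    trainer.train()

HERWrapper is not a ReplayBuffer subclass but delegates push/sample/__len__ to the wrapped buffer, which is why the isinstance check below only needs to accept both types.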
- if isinstance(self.replay_buffer, ReplayBuffer): + if isinstance(self.replay_buffer, (ReplayBuffer, HERWrapper)): batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones]) elif isinstance(self.replay_buffer, PrioritizedBuffer): indices, weights = batch[5], batch[6] diff --git a/genrl/agents/deep/dqn/base.py b/genrl/agents/deep/dqn/base.py index 7642e541..a2d83dbf 100644 --- a/genrl/agents/deep/dqn/base.py +++ b/genrl/agents/deep/dqn/base.py @@ -71,7 +71,6 @@ def _create_model(self, *args, **kwargs) -> None: ) else: self.model = self.network - self.target_model = deepcopy(self.model) self.optimizer = opt.Adam(self.model.parameters(), lr=self.lr_value) @@ -104,6 +103,8 @@ def get_greedy_action(self, state: torch.Tensor) -> torch.Tensor: Returns: action (:obj:`torch.Tensor`): Action taken by the agent """ + if not isinstance(state, torch.Tensor): + state = torch.as_tensor(state).float() q_values = self.model(state.unsqueeze(0)) action = torch.argmax(q_values.squeeze(), dim=-1) return action @@ -152,6 +153,9 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Ten Returns: q_values (:obj:`torch.Tensor`): Q values for the given states and actions """ + if len(states.shape) < 3: + states = states.unsqueeze(1) + actions = actions.unsqueeze(1) q_values = self.model(states) q_values = q_values.gather(2, actions) return q_values @@ -170,6 +174,8 @@ def get_target_q_values( Returns: target_q_values (:obj:`torch.Tensor`): Target Q values for the DQN """ + if len(next_states.shape) < 3: + next_states = next_states.unsqueeze(1) # Next Q-values according to target model next_q_target_values = self.target_model(next_states) # Maximum of next q_target values diff --git a/genrl/core/__init__.py b/genrl/core/__init__.py index 96824ac1..ab71a27f 100644 --- a/genrl/core/__init__.py +++ b/genrl/core/__init__.py @@ -1,6 +1,7 @@ from genrl.core.actor_critic import MlpActorCritic, get_actor_critic_from_name # noqa from genrl.core.bandit import Bandit, BanditAgent from genrl.core.base import BaseActorCritic # noqa +from genrl.core.buffers import HERWrapper # noqa from genrl.core.buffers import PrioritizedBuffer # noqa from genrl.core.buffers import PrioritizedReplayBufferSamples # noqa from genrl.core.buffers import ReplayBuffer # noqa diff --git a/genrl/core/actor_critic.py b/genrl/core/actor_critic.py index 1a96e959..c9afc809 100644 --- a/genrl/core/actor_critic.py +++ b/genrl/core/actor_critic.py @@ -225,7 +225,6 @@ def get_action(self, state: torch.Tensor, deterministic: bool = False): (None if determinist """ state = torch.as_tensor(state).float() - if self.actor.sac: mean, log_std = self.actor(state) std = log_std.exp() @@ -270,7 +269,7 @@ def get_value(self, state: torch.Tensor, mode="first") -> torch.Tensor: values = self.forward(state) elif mode == "min": values = self.forward(state) - values = torch.min(*values).squeeze(-1) + values = torch.min(*values) elif mode == "first": values = self.critic1(state) else: @@ -340,6 +339,7 @@ def get_features(self, state: torch.Tensor): Returns: features (:obj:`torch.Tensor`): The feature(s) extracted from the state """ + state = torch.as_tensor(state).float() features = self.shared_network(state) return features @@ -373,15 +373,12 @@ def get_value(self, state: torch.Tensor, mode="first"): values (:obj:`list`): List of values as estimated by each individual critic """ state = torch.as_tensor(state).float() - # state shape = [batch_size, number of vec envs, (state_dim + action_dim)] - - # extract shard features for just 
the state - # state[:, :, :-action_dim] -> [batch_size, number of vec envs, state_dim] - x = self.get_features(state[:, :, : -self.action_dim]) - # concatenate the actions to the extracted shared features - # state[:, :, -action_dim:] -> [batch_size, number of vec envs, action_dim] - state = torch.cat([x, state[:, :, -self.action_dim :]], dim=-1) + state_shape = state.shape + temp = state.reshape(-1, state_shape[-1])[:, : -self.action_dim] + x = self.get_features(temp.reshape(list(state_shape[:-1]) + [-1])) + temp = state.reshape(-1, state_shape[-1])[:, -self.action_dim :] + state = torch.cat([x, temp.reshape(list(state_shape[:-1]) + [-1])], dim=-1) return super(MlpSharedSingleActorTwoCritic, self).get_value(state, mode) diff --git a/genrl/core/buffers.py b/genrl/core/buffers.py index 0a5b6e7c..0c9fab59 100644 --- a/genrl/core/buffers.py +++ b/genrl/core/buffers.py @@ -1,3 +1,4 @@ +import copy import random from collections import deque from typing import NamedTuple, Tuple @@ -70,7 +71,7 @@ def __len__(self) -> int: :returns: Length of replay memory """ - return self.pos + return len(self.memory) class PrioritizedBuffer: @@ -172,7 +173,7 @@ def update_priorities(self, batch_indices: Tuple, batch_priorities: Tuple) -> No def __len__(self) -> int: """ - Gives number of experiences in buffer currently + Gives number of expesampleriences in buffer currently :returns: Length of replay memory """ @@ -181,3 +182,109 @@ def __len__(self) -> int: @property def pos(self): return len(self.buffer) + + +class HERWrapper: + """ + A wrapper class to convert a replay buffer to a HER Style Buffer + + Args: + replay_buffer (ReplayBuffer): An instance of the replay buffer to be converted to a HER style buffer + n_sampled_goals (int): The number of artificial transitions to generate for each actual transition + goal_selection_strategy (str): The strategy to be used to generate goals for the artificial transitions + env (HerGoalEnvWrapper): The goal env, wrapped using HERGoalEnvWrapper + """ + + def __init__(self, replay_buffer, n_sampled_goal, goal_selection_strategy, env): + + self.n_sampled_goal = n_sampled_goal + self.goal_selection_strategy = goal_selection_strategy + self.replay_buffer = replay_buffer + self.transitions = [] + self.allowed_strategies = ["future", "final", "episode", "random"] + self.env = env + + def push(self, inp: Tuple): + state, action, reward, next_state, done, info = inp + if isinstance(state, dict): + state = self.env.convert_dict_to_obs(state) + next_state = self.env.convert_dict_to_obs(next_state) + + self.transitions.append((state, action, reward, next_state, done, info)) + self.replay_buffer.push((state, action, reward, next_state, done)) + + if inp[-1]: + self._store_episode() + self.transitions = [] + + def sample(self, batch_size): + return self.replay_buffer.sample(batch_size) + + def _sample_achieved_goal(self, ep_transitions, transition_idx): + if self.goal_selection_strategy == "future": + # Sample a goal that was observed in the future + selected_idx = np.random.choice( + np.arange(transition_idx + 1, len(ep_transitions)) + ) + selected_transition = ep_transitions[selected_idx] + elif self.goal_selection_strategy == "final": + # Sample the goal that was finally achieved during the episode + selected_transition = ep_transitions[-1] + elif self.goal_selection_strategy == "episode": + # Sample a goal that was observed in the episode + selected_idx = np.random.choice(np.arange(len(ep_transitions))) + selected_transition = ep_transitions[selected_idx] + elif 
self.goal_selection_strategy == "random": + # Sample a random goal from the entire replay buffer + selected_idx = np.random.choice(len(self.replay_buffer)) + selected_transition = self.replay_buffer.memory[selected_idx] + else: + raise ValueError( + f"Goal selection strategy must be one of {self.allowed_strategies}" + ) + + return self.env.convert_obs_to_dict(selected_transition[0])["achieved_goal"] + + def _sample_batch_goals(self, ep_transitions, transition_idx): + return [ + self._sample_achieved_goal(ep_transitions, transition_idx) + for _ in range(self.n_sampled_goal) + ] + + def _store_episode(self): + for transition_idx, transition in enumerate(self.transitions): + + # We cannot sample from the future on the last step + if ( + transition_idx == len(self.transitions) - 1 + and self.goal_selection_strategy == "future" + ): + break + + sampled_goals = self._sample_batch_goals(self.transitions, transition_idx) + + for goal in sampled_goals: + state, action, reward, next_state, done, info = copy.deepcopy( + transition + ) + + # Convert concatenated obs to dict, so we can update the goals + state_dict = self.env.convert_obs_to_dict(state) + next_state_dict = self.env.convert_obs_to_dict(next_state) + + # Update the desired goals in the transition + state_dict["desired_goal"] = goal + next_state_dict["desired_goal"] = goal + + # Update the reward according to the new desired goal + reward = self.env.compute_reward( + next_state_dict["achieved_goal"], goal, info + ) + + # Store the newly created transition in the replay buffer + state = self.env.convert_dict_to_obs(state_dict) + next_state = self.env.convert_dict_to_obs(next_state_dict) + self.replay_buffer.push((state, action, reward, next_state, done)) + + def __len__(self): + return len(self.replay_buffer) diff --git a/genrl/environments/__init__.py b/genrl/environments/__init__.py index a25fa0d9..49c33112 100644 --- a/genrl/environments/__init__.py +++ b/genrl/environments/__init__.py @@ -4,6 +4,7 @@ from genrl.environments.base_wrapper import BaseWrapper # noqa from genrl.environments.frame_stack import FrameStack # noqa from genrl.environments.gym_wrapper import GymWrapper # noqa +from genrl.environments.her_wrapper import HERGoalEnvWrapper # noqa from genrl.environments.suite import AtariEnv, GymEnv, VectorEnv # noqa from genrl.environments.time_limit import AtariTimeLimit, TimeLimit # noqa from genrl.environments.vec_env import VecEnv, VecNormalize # noqa diff --git a/genrl/environments/custom_envs/BitFlipEnv.py b/genrl/environments/custom_envs/BitFlipEnv.py new file mode 100644 index 00000000..e4f4310d --- /dev/null +++ b/genrl/environments/custom_envs/BitFlipEnv.py @@ -0,0 +1,145 @@ +from collections import OrderedDict + +import numpy as np +import torch +from gym import GoalEnv, spaces + + +class BitFlippingEnv(GoalEnv): + """ + Simple bit flipping env, useful to test HER. + The goal is to flip all the bits to get a vector of ones. + In the continuous variant, if the ith action component has a value > 0, + then the ith bit will be flipped. 
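To make the flipping rule just described concrete (the same indexing that BitFlippingEnv.step uses further down; plain numpy, no env involved):

    import numpy as np

    state = np.array([0, 1, 0, 0, 1])

    # Discrete variant: the action is the index of the bit to flip.
    action = 2
    state[action] = 1 - state[action]            # -> [0, 1, 1, 0, 1]

    # Continuous variant: every component with a positive value flips its bit.
    action = np.array([0.3, -0.7, 0.1, -0.2, 0.9])
    state[action > 0] = 1 - state[action > 0]    # bits 0, 2, 4 toggle -> [1, 1, 0, 0, 0]
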
+ :param n_bits: (int) Number of bits to flip + :param continuous: (bool) Whether to use the continuous actions version or not, + by default, it uses the discrete one + :param max_steps: (int) Max number of steps, by default, equal to n_bits + :param discrete_obs_space: (bool) Whether to use the discrete observation + version or not, by default, it uses the MultiBinary one + + Adopted from Stable Baselines + """ + + def __init__( + self, n_bits=10, continuous=False, max_steps=None, discrete_obs_space=False + ): + super(BitFlippingEnv, self).__init__() + # The achieved goal is determined by the current state + # here, it is a special where they are equal + if discrete_obs_space: + # In the discrete case, the agent act on the binary + # representation of the observation replay + self.observation_space = spaces.Dict( + { + "observation": spaces.Discrete(2 ** n_bits - 1), + "achieved_goal": spaces.Discrete(2 ** n_bits - 1), + "desired_goal": spaces.Discrete(2 ** n_bits - 1), + } + ) + else: + self.observation_space = spaces.Dict( + { + "observation": spaces.MultiBinary(n_bits), + "achieved_goal": spaces.MultiBinary(n_bits), + "desired_goal": spaces.MultiBinary(n_bits), + } + ) + + self.obs_space = spaces.MultiBinary(n_bits) + self.n_envs = 3 + self.episode_reward = torch.zeros(1) + self.obs_dim = n_bits + self.goal_dim = n_bits + + if continuous: + self.action_space = spaces.Box(-1, 1, shape=(n_bits,), dtype=np.float32) + else: + self.action_space = spaces.Discrete(n_bits) + self.continuous = continuous + self.discrete_obs_space = discrete_obs_space + self.state = None + self.desired_goal = np.ones((n_bits,)) + if max_steps is None: + max_steps = n_bits + self.max_steps = max_steps + self.current_step = 0 + self.reset() + self.keys = ["observation", "achieved_goal", "desired_goal"] + + def convert_if_needed(self, state): + """ + Convert to discrete space if needed. + :param state: (np.ndarray) + :return: (np.ndarray or int) + """ + if self.discrete_obs_space: + # The internal state is the binary representation of the + # observed one + return int(sum([state[i] * 2 ** i for i in range(len(state))])) + return state + + def _get_obs(self): + """ + Helper to create the observation. 
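When the discrete observation space is used, convert_if_needed packs the bit vector into a single integer, least-significant bit first. A worked example with four bits:

    import numpy as np

    state = np.array([1, 0, 1, 1])  # b0=1, b1=0, b2=1, b3=1
    value = int(sum(state[i] * 2 ** i for i in range(len(state))))
    # 1*1 + 0*2 + 1*4 + 1*8 = 13
    assert value == 13
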
+ :return: (OrderedDict) + """ + return OrderedDict( + [ + ("observation", self.convert_if_needed(self.state.copy())), + ("achieved_goal", self.convert_if_needed(self.state.copy())), + ("desired_goal", self.convert_if_needed(self.desired_goal.copy())), + ] + ) + + def convert_dict_to_obs(self, obs_dict): + return torch.as_tensor( + np.concatenate([obs_dict[key] for key in self.keys]), dtype=torch.float32 + ) + + def convert_obs_to_dict(self, obs): + return OrderedDict( + [ + ("observation", obs[: self.obs_dim]), + ("achieved_goal", obs[self.obs_dim : self.obs_dim + self.goal_dim]), + ("desired_goal", obs[self.obs_dim + self.goal_dim :]), + ] + ) + + def reset(self): + self.current_step = 0 + self.state = self.obs_space.sample() + return self._get_obs() + + def sample(self): + return torch.as_tensor(self.action_space.sample()) + + def step(self, action): + if self.continuous: + self.state[action > 0] = 1 - self.state[action > 0] + else: + self.state[action] = 1 - self.state[action] + obs = self._get_obs() + reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"]) + done = reward == 0 + self.current_step += 1 + # Episode terminate when we reached the goal or the max number of steps + info = {"done": done} + done = done or self.current_step >= self.max_steps + return obs, reward, done, info + + def compute_reward( + self, achieved_goal: np.ndarray, desired_goal: np.ndarray + ) -> float: + # Deceptive reward: it is positive only when the goal is achieved + if self.discrete_obs_space: + return 0.0 if achieved_goal == desired_goal else -1.0 + return 0.0 if (achieved_goal == desired_goal).all() else -1.0 + + def render(self, mode="human"): + if mode == "rgb_array": + return self.state.copy() + print(self.state) + + def close(self): + pass diff --git a/genrl/environments/custom_envs/__init__.py b/genrl/environments/custom_envs/__init__.py new file mode 100644 index 00000000..21ebda29 --- /dev/null +++ b/genrl/environments/custom_envs/__init__.py @@ -0,0 +1 @@ +from genrl.environments.custom_envs.BitFlipEnv import BitFlippingEnv # noqa diff --git a/genrl/environments/her_wrapper.py b/genrl/environments/her_wrapper.py new file mode 100644 index 00000000..37cf74a8 --- /dev/null +++ b/genrl/environments/her_wrapper.py @@ -0,0 +1,94 @@ +import collections +from collections import OrderedDict + +import numpy as np +import torch +from gym import GoalEnv, spaces + +from genrl.environments.torch import TorchWrapper +from genrl.environments.vec_env import SubProcessVecEnv + + +class HERGoalEnvWrapper(GoalEnv): + def __init__(self, env): + self.env = env + self.action_space = env.action_space + self.spaces = list(env.observation_space.spaces.values()) + + if isinstance(self.spaces[0], spaces.Discrete): + self.obs_dim = 1 + self.goal_dim = 1 + else: + goal_space_shape = env.observation_space.spaces["achieved_goal"].shape + self.obs_dim = env.observation_space.spaces["observation"].shape[0] + self.goal_dim = goal_space_shape[0] + + if len(goal_space_shape) == 2: + assert ( + goal_space_shape[1] == 1 + ), "Only 1D observation spaces are supported yet" + else: + assert ( + len(goal_space_shape) == 1 + ), "Only 1D observation spaces are supported yet" + + if isinstance(self.spaces[0], spaces.MultiBinary): + total_dim = self.obs_dim + 2 * self.goal_dim + self.observation_space = spaces.MultiBinary(total_dim) + + elif isinstance(self.spaces[0], spaces.Box): + lows = np.concatenate([space.low for space in self.spaces]) + highs = np.concatenate([space.high for space in self.spaces]) + 
self.observation_space = spaces.Box(lows, highs, dtype=np.float32) + + elif isinstance(self.spaces[0], spaces.Discrete): + dimensions = [env.observation_space.spaces[key].n for key in KEY_ORDER] + self.observation_space = spaces.MultiDiscrete(dimensions) + + else: + raise NotImplementedError(f"{type(self.spaces[0])} space is not supported") + + self.keys = ["observation", "achieved_goal", "desired_goal"] + self.episode_reward = torch.zeros(1) + + def convert_dict_to_obs(self, obs_dict): + return np.concatenate([obs_dict[key] for key in self.keys]) + + def convert_obs_to_dict(self, obs): + return OrderedDict( + [ + ("observation", obs[: self.obs_dim]), + ( + "achieved_goal", + obs[self.obs_dim : self.obs_dim + self.goal_dim], + ), + ("desired_goal", obs[self.obs_dim + self.goal_dim :]), + ] + ) + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self.episode_reward += reward + + if "done" not in info.keys(): + info["done"] = done + return self.convert_dict_to_obs(obs), reward, done, info + + def sample(self): + return self.env.action_space.sample() + + def seed(self, seed=None): + return self.env.seed(seed) + + def reset(self): + self.episode_reward = torch.zeros(1) + return self.convert_dict_to_obs(self.env.reset()) + + def compute_reward(self, achieved_goal, desired_goal, info): + return self.env.compute_reward(achieved_goal, desired_goal, info) + + def render(self, mode="human"): + return self.env.render(mode) + + def close(self): + return self.env.close() diff --git a/genrl/trainers/__init__.py b/genrl/trainers/__init__.py index 7410831b..eb9b74b7 100644 --- a/genrl/trainers/__init__.py +++ b/genrl/trainers/__init__.py @@ -1,5 +1,6 @@ from genrl.trainers.bandit import BanditTrainer, DCBTrainer, MABTrainer # noqa from genrl.trainers.base import Trainer # noqa from genrl.trainers.classical import ClassicalTrainer # noqa +from genrl.trainers.her_trainer import HERTrainer # noqa from genrl.trainers.offpolicy import OffPolicyTrainer # noqa from genrl.trainers.onpolicy import OnPolicyTrainer # noqa diff --git a/genrl/trainers/her_trainer.py b/genrl/trainers/her_trainer.py new file mode 100644 index 00000000..7b8650d2 --- /dev/null +++ b/genrl/trainers/her_trainer.py @@ -0,0 +1,85 @@ +from typing import Type, Union + +import numpy as np + +from genrl.trainers.offpolicy import OffPolicyTrainer + + +class HERTrainer(OffPolicyTrainer): + def __init__(self, *args, **kwargs): + super(HERTrainer, self).__init__(*args, **kwargs) + + def _check_state(self, state): + if isinstance(state, dict): + return self.env.convert_dict_to_obs(state) + return state + + def get_action(self, state, timestep): + if timestep < self.warmup_steps: + action = self.env.sample() + else: + action = self.agent.select_action(self._check_state(state)) + return action + + def check_game_over_status(self, timestep: int, dones: bool): + game_over = False + + if dones: + self.training_rewards.append(self.env.episode_reward.detach().clone()) + self.env.reset() + self.episodes += 1 + game_over = True + + return game_over + + def train(self) -> None: + """Main training method""" + if self.load_weights is not None or self.load_hyperparams is not None: + self.load() + + state = self.env.reset() + self.noise_reset() + + self.training_rewards = [] + self.episodes = 0 + + for timestep in range(0, self.max_timesteps): + self.agent.update_params_before_select_action(timestep) + action = self.get_action(np.expand_dims(state, axis=0), timestep) + + if ( + not isinstance(action, int) + and action.shape != 
self.env.action_space.shape + ): + action = action.squeeze(0) + next_state, reward, done, info = self.env.step(action) + + if self.render: + self.env.render() + + # true_dones contains the "true" value of the dones (game over statuses). It is set + # to False when the environment is not actually done but instead reaches the max + # episode length. + true_dones = info["done"] + self.buffer.push((state, action, reward, next_state, true_dones, info)) + + state = next_state + + if self.check_game_over_status(timestep, done): + self.noise_reset() + + if self.episodes % self.log_interval == 0: + self.log(timestep) + + if timestep >= self.start_update and timestep % self.update_interval == 0: + self.agent.update_params(self.update_interval) + + if ( + timestep >= self.start_update + and self.save_interval != 0 + and timestep % self.save_interval == 0 + ): + self.save(timestep) + + self.env.close() + self.logger.close() diff --git a/genrl/utils/utils.py b/genrl/utils/utils.py index 89e53337..fa003986 100644 --- a/genrl/utils/utils.py +++ b/genrl/utils/utils.py @@ -144,16 +144,17 @@ def get_env_properties( discreteness of action space and action limit (highest action value) :rtype: int, float, ...; int, float, ...; bool; int, float, ... """ - if network == "cnn": - state_dim = env.framestack - elif network == "mlp": - state_dim = env.observation_space.shape[0] - elif isinstance(network, (BasePolicy, BaseValue)): - state_dim = network.state_dim - elif isinstance(network, BaseActorCritic): - state_dim = network.actor.state_dim - else: - raise TypeError + if isinstance(env, (VecEnv, gym.Env)): + if network == "cnn": + state_dim = env.framestack + elif network == "mlp": + state_dim = env.observation_space.shape[0] + elif isinstance(network, (BasePolicy, BaseValue)): + state_dim = network.state_dim + elif isinstance(network, BaseActorCritic): + state_dim = network.actor.state_dim + else: + raise TypeError if isinstance(env.action_space, gym.spaces.Discrete): action_dim = env.action_space.n diff --git a/tests/test_deep/test_agents/test_her.py b/tests/test_deep/test_agents/test_her.py new file mode 100644 index 00000000..99fc9a86 --- /dev/null +++ b/tests/test_deep/test_agents/test_her.py @@ -0,0 +1,40 @@ +import shutil + +import gym +import highway_env +from gym.core import Wrapper + +from genrl.agents import DDPG, SAC, TD3 +from genrl.core import HERWrapper, ReplayBuffer +from genrl.environments import HERGoalEnvWrapper, VectorEnv +from genrl.trainers import HERTrainer + + +class TestHER: + def _test_agent(self, agent): + env = gym.make("parking-v0") + env = HERGoalEnvWrapper(env) + algo = agent("mlp", env, batch_size=10, policy_layers=[1], value_layers=[1]) + buffer = HERWrapper(ReplayBuffer(10), 1, "future", env) + trainer = HERTrainer( + algo, + env, + buffer=buffer, + log_mode=["csv"], + logdir="./logs", + max_ep_len=10, + epochs=int(1), + warmup_steps=1, + start_update=1, + ) + trainer.train() + shutil.rmtree("./logs") + + def test_DDPG(self): + self._test_agent(DDPG) + + def test_SAC(self): + self._test_agent(SAC) + + def test_TD3(self): + self._test_agent(TD3) diff --git a/tests/test_environments/test_custom_env.py b/tests/test_environments/test_custom_env.py new file mode 100644 index 00000000..483e7c80 --- /dev/null +++ b/tests/test_environments/test_custom_env.py @@ -0,0 +1,28 @@ +import shutil + +from genrl.agents import DQN +from genrl.core import HERWrapper, ReplayBuffer +from genrl.environments import HERGoalEnvWrapper +from genrl.environments.custom_envs import BitFlipEnv, BitFlippingEnv 
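Both HERGoalEnvWrapper and BitFlippingEnv above flatten the goal-env dict into one vector laid out as [observation | achieved_goal | desired_goal], and convert_obs_to_dict recovers the pieces by slicing. A small sketch with made-up dimensions (obs_dim=4, goal_dim=2):

    import numpy as np

    obs_dim, goal_dim = 4, 2
    flat = np.arange(obs_dim + 2 * goal_dim, dtype=np.float32)   # [0., 1., ..., 7.]

    observation = flat[:obs_dim]                        # [0., 1., 2., 3.]
    achieved_goal = flat[obs_dim:obs_dim + goal_dim]    # [4., 5.]
    desired_goal = flat[obs_dim + goal_dim:]            # [6., 7.]

HER relabelling only ever rewrites the desired_goal slice, which is why the two wrappers agree on this layout (self.keys is the same ["observation", "achieved_goal", "desired_goal"] list in both classes).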
+from genrl.trainers import HERTrainer + + +def test_her(): + env = BitFlippingEnv() + env = HERGoalEnvWrapper(env) + algo = DQN("mlp", env, batch_size=5, replay_size=10, value_layers=[1, 1]) + buffer = HERWrapper(ReplayBuffer(1000), 1, "future", env) + print(isinstance(buffer, ReplayBuffer)) + trainer = HERTrainer( + algo, + env, + buffer=buffer, + log_mode=["csv"], + logdir="./logs", + max_ep_len=200, + epochs=100, + warmup_steps=10, + start_update=10, + ) + trainer.train() + shutil.rmtree("./logs")
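For reference, the relabelling performed by HERWrapper._store_episode reduces to the loop below. This is a condensed sketch of the "future" strategy only, with `episode` standing for the list of (state, action, reward, next_state, done, info) tuples accumulated by push, and the dict/array conversions delegated to a HERGoalEnvWrapper as in the diff; relabel_future is an illustrative name, not part of the patch:

    import copy

    import numpy as np


    def relabel_future(episode, env, n_sampled_goal):
        """Yield HER transitions: same actions, desired goals swapped for later achieved goals."""
        for t, transition in enumerate(episode[:-1]):  # the last step has no future to sample from
            for _ in range(n_sampled_goal):
                # Pick a transition strictly after t and reuse its achieved goal.
                future = episode[np.random.choice(np.arange(t + 1, len(episode)))]
                goal = env.convert_obs_to_dict(future[0])["achieved_goal"]

                state, action, reward, next_state, done, info = copy.deepcopy(transition)
                state_dict = env.convert_obs_to_dict(state)
                next_state_dict = env.convert_obs_to_dict(next_state)
                state_dict["desired_goal"] = goal
                next_state_dict["desired_goal"] = goal

                # The reward is recomputed against the substituted goal.
                reward = env.compute_reward(next_state_dict["achieved_goal"], goal, info)
                yield (env.convert_dict_to_obs(state_dict), action, reward,
                       env.convert_dict_to_obs(next_state_dict), done)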