forked from openai/baselines
-
Notifications
You must be signed in to change notification settings - Fork 724
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add Gym Env checker * Test common failures * Declare param as unused * Update tests/test_envs.py Co-Authored-By: Adam Gleave <[email protected]> * Update docs/guide/rl_tips.rst Co-Authored-By: Adam Gleave <[email protected]> * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Update docs/guide/rl_tips.rst Co-Authored-By: Adam Gleave <[email protected]> * Split checks * Split tests * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Update stable_baselines/common/env_checker.py Co-Authored-By: Adam Gleave <[email protected]> * Reformat files
- Loading branch information
Showing
10 changed files
with
404 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
.. _env_checker: | ||
|
||
Gym Environment Checker | ||
======================== | ||
|
||
.. automodule:: stable_baselines.common.env_checker | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -101,3 +101,8 @@ Piecewise | |
csv | ||
nvidia | ||
visdom | ||
tensorboard | ||
preprocessed | ||
namespace | ||
sklearn | ||
GoalEnv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
import warnings | ||
from typing import Union | ||
|
||
import gym | ||
from gym import spaces | ||
import numpy as np | ||
|
||
from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan | ||
|
||
|
||
def _enforce_array_obs(observation_space: spaces.Space) -> bool: | ||
""" | ||
Whether to check that the returned observation is a numpy array | ||
it is not mandatory for `Dict` and `Tuple` spaces. | ||
""" | ||
return not isinstance(observation_space, (spaces.Dict, spaces.Tuple)) | ||
|
||
|
||
def _check_image_input(observation_space: spaces.Box) -> None: | ||
""" | ||
Check that the input will be compatible with Stable-Baselines | ||
when the observation is apparently an image. | ||
""" | ||
if observation_space.dtype != np.uint8: | ||
warnings.warn("It seems that your observation is an image but the `dtype` " | ||
"of your observation_space is not `np.uint8`. " | ||
"If your observation is not an image, we recommend you to flatten the observation " | ||
"to have only a 1D vector") | ||
|
||
if np.any(observation_space.low != 0) or np.any(observation_space.high != 255): | ||
warnings.warn("It seems that your observation space is an image but the " | ||
"upper and lower bounds are not in [0, 255]. " | ||
"Because the CNN policy normalize automatically the observation " | ||
"you may encounter issue if the values are not in that range." | ||
) | ||
|
||
if observation_space.shape[0] < 36 or observation_space.shape[1] < 36: | ||
warnings.warn("The minimal resolution for an image is 36x36 for the default CnnPolicy. " | ||
"You might need to use a custom `cnn_extractor` " | ||
"cf https://stable-baselines.readthedocs.io/en/master/guide/custom_policy.html") | ||
|
||
|
||
def _check_unsupported_obs_spaces(env: gym.Env, observation_space: spaces.Space) -> None: | ||
"""Emit warnings when the observation space used is not supported by Stable-Baselines.""" | ||
|
||
if isinstance(observation_space, spaces.Dict) and not isinstance(env, gym.GoalEnv): | ||
warnings.warn("The observation space is a Dict but the environment is not a gym.GoalEnv " | ||
"(cf https://github.com/openai/gym/blob/master/gym/core.py), " | ||
"this is currently not supported by Stable Baselines " | ||
"(cf https://github.com/hill-a/stable-baselines/issues/133), " | ||
"you will need to use a custom policy. " | ||
) | ||
|
||
if isinstance(observation_space, spaces.Tuple): | ||
warnings.warn("The observation space is a Tuple," | ||
"this is currently not supported by Stable Baselines " | ||
"(cf https://github.com/hill-a/stable-baselines/issues/133), " | ||
"you will need to flatten the observation and maybe use a custom policy. " | ||
) | ||
|
||
|
||
def _check_nan(env: gym.Env) -> None: | ||
"""Check for Inf and NaN using the VecWrapper.""" | ||
vec_env = VecCheckNan(DummyVecEnv([lambda: env])) | ||
for _ in range(10): | ||
action = [env.action_space.sample()] | ||
_, _, _, _ = vec_env.step(action) | ||
|
||
|
||
def _check_obs(obs: Union[tuple, dict, np.ndarray, int], | ||
observation_space: spaces.Space, | ||
method_name: str) -> None: | ||
""" | ||
Check that the observation returned by the environment | ||
correspond to the declared one. | ||
""" | ||
if not isinstance(observation_space, spaces.Tuple): | ||
assert not isinstance(obs, tuple), ("The observation returned by the `{}()` " | ||
"method should be a single value, not a tuple".format(method_name)) | ||
|
||
# The check for a GoalEnv is done by the base class | ||
if isinstance(observation_space, spaces.Discrete): | ||
assert isinstance(obs, int), "The observation returned by `{}()` method must be an int".format(method_name) | ||
elif _enforce_array_obs(observation_space): | ||
assert isinstance(obs, np.ndarray), ("The observation returned by `{}()` " | ||
"method must be a numpy array".format(method_name)) | ||
|
||
assert observation_space.contains(obs), ("The observation returned by the `{}()` " | ||
"method does not match the given observation space".format(method_name)) | ||
|
||
|
||
def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None: | ||
""" | ||
Check the returned values by the env when calling `.reset()` or `.step()` methods. | ||
""" | ||
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists | ||
obs = env.reset() | ||
|
||
_check_obs(obs, observation_space, 'reset') | ||
|
||
# Sample a random action | ||
action = action_space.sample() | ||
data = env.step(action) | ||
|
||
assert len(data) == 4, "The `step()` method must return four values: obs, reward, done, info" | ||
|
||
# Unpack | ||
obs, reward, done, info = data | ||
|
||
_check_obs(obs, observation_space, 'step') | ||
|
||
# We also allow int because the reward will be cast to float | ||
assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float" | ||
assert isinstance(done, bool), "The `done` signal must be a boolean" | ||
assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary" | ||
|
||
if isinstance(env, gym.GoalEnv): | ||
# For a GoalEnv, the keys are checked at reset | ||
assert reward == env.compute_reward(obs['achieved_goal'], obs['desired_goal'], info) | ||
|
||
|
||
def _check_spaces(env: gym.Env) -> None: | ||
""" | ||
Check that the observation and action spaces are defined | ||
and inherit from gym.spaces.Space. | ||
""" | ||
# Helper to link to the code, because gym has no proper documentation | ||
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/" | ||
|
||
assert hasattr(env, 'observation_space'), "You must specify an observation space (cf gym.spaces)" + gym_spaces | ||
assert hasattr(env, 'action_space'), "You must specify an action space (cf gym.spaces)" + gym_spaces | ||
|
||
assert isinstance(env.observation_space, | ||
spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces | ||
assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces | ||
|
||
|
||
def _check_render(env: gym.Env, warn=True, headless=False) -> None: | ||
""" | ||
Check the declared render modes and the `render()`/`close()` | ||
method of the environment. | ||
:param env: (gym.Env) The environment to check | ||
:param warn: (bool) Whether to output additional warnings | ||
:param headless: (bool) Whether to disable render modes | ||
that require a graphical interface. False by default. | ||
""" | ||
render_modes = env.metadata.get('render.modes') | ||
if render_modes is None: | ||
if warn: | ||
warnings.warn("No render modes was declared in the environment " | ||
" (env.metadata['render.modes'] is None or not defined), " | ||
"you may have trouble when calling `.render()`") | ||
|
||
else: | ||
# Don't check render mode that require a | ||
# graphical interface (useful for CI) | ||
if headless and 'human' in render_modes: | ||
render_modes.remove('human') | ||
# Check all declared render modes | ||
for render_mode in render_modes: | ||
env.render(mode=render_mode) | ||
env.close() | ||
|
||
|
||
def check_env(env: gym.Env, warn=True, skip_render_check=True) -> None: | ||
""" | ||
Check that an environment follows Gym API. | ||
This is particularly useful when using a custom environment. | ||
Please take a look at https://github.com/openai/gym/blob/master/gym/core.py | ||
for more information about the API. | ||
It also optionally check that the environment is compatible with Stable-Baselines. | ||
:param env: (gym.Env) The Gym environment that will be checked | ||
:param warn: (bool) Whether to output additional warnings | ||
mainly related to the interaction with Stable Baselines | ||
:param skip_render_check: (bool) Whether to skip the checks for the render method. | ||
True by default (useful for the CI) | ||
""" | ||
assert isinstance(env, gym.Env), ("You environment must inherit from gym.Env class " | ||
" cf https://github.com/openai/gym/blob/master/gym/core.py") | ||
|
||
# ============= Check the spaces (observation and action) ================ | ||
_check_spaces(env) | ||
|
||
# Define aliases for convenience | ||
observation_space = env.observation_space | ||
action_space = env.action_space | ||
|
||
# Warn the user if needed. | ||
# A warning means that the environment may run but not work properly with Stable Baselines algorithms | ||
if warn: | ||
_check_unsupported_obs_spaces(env, observation_space) | ||
|
||
# If image, check the low and high values, the type and the number of channels | ||
# and the shape (minimal value) | ||
if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3: | ||
_check_image_input(observation_space) | ||
|
||
if isinstance(observation_space, spaces.Box) and len(observation_space.shape) not in [1, 3]: | ||
warnings.warn("Your observation has an unconventional shape (neither an image, nor a 1D vector). " | ||
"We recommend you to flatten the observation " | ||
"to have only a 1D vector") | ||
|
||
# Check for the action space, it may lead to hard-to-debug issues | ||
if (isinstance(action_space, spaces.Box) and | ||
(np.abs(action_space.low) != np.abs(action_space.high) | ||
or np.abs(action_space.low) > 1 or np.abs(action_space.high) > 1)): | ||
warnings.warn("We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) " | ||
"cf https://stable-baselines.readthedocs.io/en/master/guide/rl_tips.html") | ||
|
||
# ============ Check the returned values =============== | ||
_check_returned_values(env, observation_space, action_space) | ||
|
||
# ==== Check the render method and the declared render modes ==== | ||
if not skip_render_check: | ||
_check_render(env, warn=warn) | ||
|
||
# The check only works with numpy arrays | ||
if _enforce_array_obs(observation_space): | ||
_check_nan(env) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.