Skip to content

Commit 8f0b488

Browse files
Update Gymnasium to v1.0.0 (DLR-RM#1837)
* Update Gymnasium to v1.0.0a1 * Comment out `gymnasium.wrappers.monitor` (todo update to VideoRecord) * Fix ruff warnings * Register Atari envs * Update `getattr` to `Env.get_wrapper_attr` * Reorder imports * Fix `seed` order * Fix collecting `max_steps` * Copy and paste video recorder to prevent the need to rewrite the vec vide recorder wrapper * Use `typing.List` rather than list * Fix env attribute forwarding * Separate out env attribute collection from its utilisation * Update for Gymnasium alpha 2 * Remove assert for OrderedDict * Update setup.py * Add type: ignore * Test with Gymnasium main * Remove `gymnasium.logger.debug/info` * Fix github CI yaml * Run gym 0.29.1 on python 3.10 * Update lower bounds * Integrate video recorder * Remove ordered dict * Update changelog --------- Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent dd3d0ac commit 8f0b488

File tree

16 files changed

+148
-120
lines changed

16 files changed

+148
-120
lines changed

.github/workflows/ci.yml

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@ jobs:
2121
strategy:
2222
matrix:
2323
python-version: ["3.8", "3.9", "3.10", "3.11"]
24-
24+
include:
25+
# Default version
26+
- gymnasium-version: "1.0.0"
27+
# Add a new config to test gym<1.0
28+
- python-version: "3.10"
29+
gymnasium-version: "0.29.1"
2530
steps:
2631
- uses: actions/checkout@v3
2732
- name: Set up Python ${{ matrix.python-version }}
@@ -37,15 +42,14 @@ jobs:
3742
# See https://github.com/astral-sh/uv/issues/1497
3843
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
3944
40-
# Install Atari Roms
41-
uv pip install --system autorom
42-
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
43-
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
44-
AutoROM --accept-license --source-file Roms.tar.gz
45-
46-
uv pip install --system .[extra_no_roms,tests,docs]
45+
uv pip install --system .[extra,tests,docs]
4746
# Use headless version
4847
uv pip install --system opencv-python-headless
48+
- name: Install specific version of gym
49+
run: |
50+
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
51+
# Only run for python 3.10, downgrade gym to 0.29.1
52+
if: matrix.gymnasium-version != '1.0.0'
4953
- name: Lint with ruff
5054
run: |
5155
make lint

docs/conda_env.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ dependencies:
88
- python=3.11
99
- pytorch=2.5.0=py3.11_cpu_0
1010
- pip:
11-
- gymnasium>=0.28.1,<0.30
11+
- gymnasium>=0.29.1,<1.1.0
1212
- cloudpickle
1313
- opencv-python-headless
1414
- pandas

docs/misc/changelog.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
Changelog
44
==========
55

6-
Release 2.4.0a10 (WIP)
6+
Release 2.4.0a11 (WIP)
77
--------------------------
88

9-
**New algorithm: CrossQ in SB3 Contrib**
9+
**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**
1010

1111
.. note::
1212

@@ -24,12 +24,14 @@ Release 2.4.0a10 (WIP)
2424

2525
Breaking Changes:
2626
^^^^^^^^^^^^^^^^^
27+
- Increase minimum required version of Gymnasium to 0.29.1
2728

2829
New Features:
2930
^^^^^^^^^^^^^
3031
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
3132
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
3233
- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
34+
- Added support for Gymnasium v1.0
3335

3436
Bug Fixes:
3537
^^^^^^^^^^
@@ -69,6 +71,7 @@ Others:
6971
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
7072
- Switched to uv to download packages faster on GitHub CI
7173
- Updated dependencies for read the doc
74+
- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, remove the use of ordered dict and rename ``flatten_obs`` to ``stack_obs``
7275

7376
Bug Fixes:
7477
^^^^^^^^^^

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"]
1818
# ClassVar, implicit optional check not needed for tests
1919
"./tests/*.py" = ["RUF012", "RUF013"]
2020

21-
2221
[tool.ruff.lint.mccabe]
2322
# Unlike Flake8, default to a complexity level of 10.
2423
max-complexity = 15

setup.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -70,37 +70,13 @@
7070
7171
""" # noqa:E501
7272

73-
# Atari Games download is sometimes problematic:
74-
# https://github.com/Farama-Foundation/AutoROM/issues/39
75-
# That's why we define extra packages without it.
76-
extra_no_roms = [
77-
# For render
78-
"opencv-python",
79-
"pygame",
80-
# Tensorboard support
81-
"tensorboard>=2.9.1",
82-
# Checking memory taken by replay buffer
83-
"psutil",
84-
# For progress bar callback
85-
"tqdm",
86-
"rich",
87-
# For atari games,
88-
"shimmy[atari]~=1.3.0",
89-
"pillow",
90-
]
91-
92-
extra_packages = extra_no_roms + [ # noqa: RUF005
93-
# For atari roms,
94-
"autorom[accept-rom-license]~=0.6.1",
95-
]
96-
9773

9874
setup(
9975
name="stable_baselines3",
10076
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
10177
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
10278
install_requires=[
103-
"gymnasium>=0.28.1,<0.30",
79+
"gymnasium>=0.29.1,<1.1.0",
10480
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
10581
"torch>=1.13",
10682
# For saving models
@@ -133,8 +109,21 @@
133109
# Copy button for code snippets
134110
"sphinx_copybutton",
135111
],
136-
"extra": extra_packages,
137-
"extra_no_roms": extra_no_roms,
112+
"extra": [
113+
# For render
114+
"opencv-python",
115+
"pygame",
116+
# Tensorboard support
117+
"tensorboard>=2.9.1",
118+
# Checking memory taken by replay buffer
119+
"psutil",
120+
# For progress bar callback
121+
"tqdm",
122+
"rich",
123+
# For atari games,
124+
"ale-py>=0.9.0",
125+
"pillow",
126+
],
138127
},
139128
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
140129
author="Antonin Raffin",

stable_baselines3/common/vec_env/dummy_vec_env.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
1010
from stable_baselines3.common.vec_env.patch_gym import _patch_env
11-
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
11+
from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info
1212

1313

1414
class DummyVecEnv(VecEnv):
@@ -110,12 +110,12 @@ def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
110110
self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload]
111111

112112
def _obs_from_buf(self) -> VecEnvObs:
113-
return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
113+
return dict_to_obs(self.observation_space, deepcopy(self.buf_obs))
114114

115115
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
116116
"""Return attribute from vectorized environment (see base class)."""
117117
target_envs = self._get_target_envs(indices)
118-
return [getattr(env_i, attr_name) for env_i in target_envs]
118+
return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs]
119119

120120
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
121121
"""Set attribute inside vectorized environments (see base class)."""
@@ -126,7 +126,7 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) ->
126126
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
127127
"""Call instance methods of vectorized environments."""
128128
target_envs = self._get_target_envs(indices)
129-
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
129+
return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs]
130130

131131
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
132132
"""Check if worker environments are wrapped with a given wrapper"""

stable_baselines3/common/vec_env/patch_gym.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma
4343
"Missing shimmy installation. You provided an OpenAI Gym environment. "
4444
"Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
4545
"In order to use OpenAI Gym environments with SB3, you need to "
46-
"install shimmy (`pip install 'shimmy>=0.2.1'`)."
46+
"install shimmy (`pip install 'shimmy>=2.0'`)."
4747
) from e
4848

4949
warnings.warn(

stable_baselines3/common/vec_env/subproc_vec_env.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import multiprocessing as mp
22
import warnings
3-
from collections import OrderedDict
43
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
54

65
import gymnasium as gym
@@ -54,10 +53,10 @@ def _worker(
5453
elif cmd == "get_spaces":
5554
remote.send((env.observation_space, env.action_space))
5655
elif cmd == "env_method":
57-
method = getattr(env, data[0])
56+
method = env.get_wrapper_attr(data[0])
5857
remote.send(method(*data[1], **data[2]))
5958
elif cmd == "get_attr":
60-
remote.send(getattr(env, data))
59+
remote.send(env.get_wrapper_attr(data))
6160
elif cmd == "set_attr":
6261
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
6362
elif cmd == "is_wrapped":
@@ -129,7 +128,7 @@ def step_wait(self) -> VecEnvStepReturn:
129128
results = [remote.recv() for remote in self.remotes]
130129
self.waiting = False
131130
obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment]
132-
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
131+
return _stack_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
133132

134133
def reset(self) -> VecEnvObs:
135134
for env_idx, remote in enumerate(self.remotes):
@@ -139,7 +138,7 @@ def reset(self) -> VecEnvObs:
139138
# Seeds and options are only used once
140139
self._reset_seeds()
141140
self._reset_options()
142-
return _flatten_obs(obs, self.observation_space)
141+
return _stack_obs(obs, self.observation_space)
143142

144143
def close(self) -> None:
145144
if self.closed:
@@ -206,27 +205,28 @@ def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
206205
return [self.remotes[i] for i in indices]
207206

208207

209-
def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
208+
def _stack_obs(obs_list: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
210209
"""
211-
Flatten observations, depending on the observation space.
210+
Stack observations (convert from a list of single env obs to a stack of obs),
211+
depending on the observation space.
212212
213213
:param obs: observations.
214214
A list or tuple of observations, one per environment.
215215
Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
216-
:return: flattened observations.
217-
A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
216+
:return: Concatenated observations.
217+
A NumPy array or a dict or tuple of stacked numpy arrays.
218218
Each NumPy array has the environment index as its first axis.
219219
"""
220-
assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
221-
assert len(obs) > 0, "need observations from at least one environment"
220+
assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment"
221+
assert len(obs_list) > 0, "need observations from at least one environment"
222222

223223
if isinstance(space, spaces.Dict):
224-
assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
225-
assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
226-
return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
224+
assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces"
225+
assert isinstance(obs_list[0], dict), "non-dict observation for environment with Dict observation space"
226+
return {key: np.stack([single_obs[key] for single_obs in obs_list]) for key in space.spaces.keys()} # type: ignore[call-overload]
227227
elif isinstance(space, spaces.Tuple):
228-
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
228+
assert isinstance(obs_list[0], tuple), "non-tuple observation for environment with Tuple observation space"
229229
obs_len = len(space.spaces)
230-
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index]
230+
return tuple(np.stack([single_obs[i] for single_obs in obs_list]) for i in range(obs_len)) # type: ignore[index]
231231
else:
232-
return np.stack(obs) # type: ignore[arg-type]
232+
return np.stack(obs_list) # type: ignore[arg-type]

stable_baselines3/common/vec_env/util.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Helpers for dealing with vectorized environments.
33
"""
44

5-
from collections import OrderedDict
65
from typing import Any, Dict, List, Tuple
76

87
import numpy as np
@@ -12,17 +11,6 @@
1211
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
1312

1413

15-
def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
16-
"""
17-
Deep-copy a dict of numpy arrays.
18-
19-
:param obs: a dict of numpy arrays.
20-
:return: a dict of copied numpy arrays.
21-
"""
22-
assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'"
23-
return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])
24-
25-
2614
def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
2715
"""
2816
Convert an internal representation raw_obs into the appropriate type
@@ -60,18 +48,18 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[
6048
"""
6149
check_for_nested_spaces(obs_space)
6250
if isinstance(obs_space, spaces.Dict):
63-
assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
51+
assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces"
6452
subspaces = obs_space.spaces
6553
elif isinstance(obs_space, spaces.Tuple):
66-
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment]
54+
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment,misc]
6755
else:
6856
assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
69-
subspaces = {None: obs_space} # type: ignore[assignment]
57+
subspaces = {None: obs_space} # type: ignore[assignment,dict-item]
7058
keys = []
7159
shapes = {}
7260
dtypes = {}
7361
for key, box in subspaces.items():
7462
keys.append(key)
7563
shapes[key] = box.shape
7664
dtypes[key] = box.dtype
77-
return keys, shapes, dtypes
65+
return keys, shapes, dtypes # type: ignore[return-value]

0 commit comments

Comments
 (0)