Skip to content

Commit 7c71688

Browse files
authored
Merge branch 'master' into feat/mps-support
2 parents 263e657 + e4f4f12 commit 7c71688

File tree

18 files changed

+156
-122
lines changed

18 files changed

+156
-122
lines changed

.github/workflows/ci.yml

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@ jobs:
2121
strategy:
2222
matrix:
2323
python-version: ["3.8", "3.9", "3.10", "3.11"]
24-
24+
include:
25+
# Default version
26+
- gymnasium-version: "1.0.0"
27+
# Add a new config to test gym<1.0
28+
- python-version: "3.10"
29+
gymnasium-version: "0.29.1"
2530
steps:
2631
- uses: actions/checkout@v3
2732
- name: Set up Python ${{ matrix.python-version }}
@@ -37,15 +42,14 @@ jobs:
3742
# See https://github.com/astral-sh/uv/issues/1497
3843
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
3944
40-
# Install Atari Roms
41-
uv pip install --system autorom
42-
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
43-
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
44-
AutoROM --accept-license --source-file Roms.tar.gz
45-
46-
uv pip install --system .[extra_no_roms,tests,docs]
45+
uv pip install --system .[extra,tests,docs]
4746
# Use headless version
4847
uv pip install --system opencv-python-headless
48+
- name: Install specific version of gym
49+
run: |
50+
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
51+
# Only run for python 3.10, downgrade gym to 0.29.1
52+
if: matrix.gymnasium-version != '1.0.0'
4953
- name: Lint with ruff
5054
run: |
5155
make lint

docs/conda_env.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ dependencies:
88
- python=3.11
99
- pytorch=2.5.0=py3.11_cpu_0
1010
- pip:
11-
- gymnasium>=0.28.1,<0.30
11+
- gymnasium>=0.29.1,<1.1.0
1212
- cloudpickle
1313
- opencv-python-headless
1414
- pandas
1515
- numpy>=1.20,<2.0
1616
- matplotlib
17-
- sphinx>=5,<8
17+
- sphinx>=5,<9
1818
- sphinx_rtd_theme>=1.3.0
1919
- sphinx_copybutton

docs/misc/changelog.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
Changelog
44
==========
55

6-
Release 2.4.0a10 (WIP)
6+
Release 2.4.0a11 (WIP)
77
--------------------------
88

9-
**New algorithm: CrossQ in SB3 Contrib**
9+
**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**
1010

1111
.. note::
1212

@@ -24,12 +24,14 @@ Release 2.4.0a10 (WIP)
2424

2525
Breaking Changes:
2626
^^^^^^^^^^^^^^^^^
27+
- Increase minimum required version of Gymnasium to 0.29.1
2728

2829
New Features:
2930
^^^^^^^^^^^^^
3031
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
3132
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
3233
- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
34+
- Added support for Gymnasium v1.0
3335

3436
Bug Fixes:
3537
^^^^^^^^^^
@@ -57,6 +59,7 @@ Bug Fixes:
5759
`SBX`_ (SB3 + Jax)
5860
^^^^^^^^^^^^^^^^^^
5961
- Added CNN support for DQN
62+
- Bug fix for SAC and related algorithms, optimize log of ent coeff to be consistent with SB3
6063

6164
Deprecations:
6265
^^^^^^^^^^^^^
@@ -69,6 +72,7 @@ Others:
6972
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
7073
- Switched to uv to download packages faster on GitHub CI
7174
- Updated dependencies for read the doc
75+
- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, remove the use of ordered dict and rename ``flatten_obs`` to ``stack_obs``
7276

7377
Bug Fixes:
7478
^^^^^^^^^^
@@ -77,6 +81,7 @@ Documentation:
7781
^^^^^^^^^^^^^^
7882
- Updated PPO doc to recommend using CPU with ``MlpPolicy``
7983
- Clarified documentation about planned features and citing software
84+
- Added a note about the fact we are optimizing log of ent coeff for SAC
8085

8186
Release 2.3.2 (2024-04-27)
8287
--------------------------

docs/modules/dqn.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Notes
2525

2626
- Original paper: https://arxiv.org/abs/1312.5602
2727
- Further reference: https://www.nature.com/articles/nature14236
28+
- Tutorial "From Tabular Q-Learning to DQN": https://github.com/araffin/rlss23-dqn-tutorial
2829

2930
.. note::
3031
This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.

docs/modules/sac.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ Notes
3535
which is the equivalent to the inverse of reward scale in the original SAC paper.
3636
The main reason is that it avoids having too high errors when updating the Q functions.
3737

38+
.. note::
39+
When automatically adjusting the temperature (alpha/entropy coefficient), we optimize the logarithm of the entropy coefficient instead of the entropy coefficient itself. This is consistent with the original implementation and has proven to be more stable
40+
(see issues `GH#36 <https://github.com/DLR-RM/stable-baselines3/issues/36>`_, `#55 <https://github.com/araffin/sbx/issues/55>`_ and others).
3841

3942
.. note::
4043

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"]
1818
# ClassVar, implicit optional check not needed for tests
1919
"./tests/*.py" = ["RUF012", "RUF013"]
2020

21-
2221
[tool.ruff.lint.mccabe]
2322
# Unlike Flake8, default to a complexity level of 10.
2423
max-complexity = 15

setup.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -70,37 +70,13 @@
7070
7171
""" # noqa:E501
7272

73-
# Atari Games download is sometimes problematic:
74-
# https://github.com/Farama-Foundation/AutoROM/issues/39
75-
# That's why we define extra packages without it.
76-
extra_no_roms = [
77-
# For render
78-
"opencv-python",
79-
"pygame",
80-
# Tensorboard support
81-
"tensorboard>=2.9.1",
82-
# Checking memory taken by replay buffer
83-
"psutil",
84-
# For progress bar callback
85-
"tqdm",
86-
"rich",
87-
# For atari games,
88-
"shimmy[atari]~=1.3.0",
89-
"pillow",
90-
]
91-
92-
extra_packages = extra_no_roms + [ # noqa: RUF005
93-
# For atari roms,
94-
"autorom[accept-rom-license]~=0.6.1",
95-
]
96-
9773

9874
setup(
9975
name="stable_baselines3",
10076
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
10177
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
10278
install_requires=[
103-
"gymnasium>=0.28.1,<0.30",
79+
"gymnasium>=0.29.1,<1.1.0",
10480
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
10581
"torch>=1.13",
10682
# For saving models
@@ -125,16 +101,29 @@
125101
"black>=24.2.0,<25",
126102
],
127103
"docs": [
128-
"sphinx>=5,<8",
104+
"sphinx>=5,<9",
129105
"sphinx-autobuild",
130106
"sphinx-rtd-theme>=1.3.0",
131107
# For spelling
132108
"sphinxcontrib.spelling",
133109
# Copy button for code snippets
134110
"sphinx_copybutton",
135111
],
136-
"extra": extra_packages,
137-
"extra_no_roms": extra_no_roms,
112+
"extra": [
113+
# For render
114+
"opencv-python",
115+
"pygame",
116+
# Tensorboard support
117+
"tensorboard>=2.9.1",
118+
# Checking memory taken by replay buffer
119+
"psutil",
120+
# For progress bar callback
121+
"tqdm",
122+
"rich",
123+
# For atari games,
124+
"ale-py>=0.9.0",
125+
"pillow",
126+
],
138127
},
139128
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
140129
author="Antonin Raffin",

stable_baselines3/common/vec_env/dummy_vec_env.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
1010
from stable_baselines3.common.vec_env.patch_gym import _patch_env
11-
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
11+
from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info
1212

1313

1414
class DummyVecEnv(VecEnv):
@@ -110,12 +110,12 @@ def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
110110
self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload]
111111

112112
def _obs_from_buf(self) -> VecEnvObs:
113-
return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
113+
return dict_to_obs(self.observation_space, deepcopy(self.buf_obs))
114114

115115
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
116116
"""Return attribute from vectorized environment (see base class)."""
117117
target_envs = self._get_target_envs(indices)
118-
return [getattr(env_i, attr_name) for env_i in target_envs]
118+
return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs]
119119

120120
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
121121
"""Set attribute inside vectorized environments (see base class)."""
@@ -126,7 +126,7 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) ->
126126
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
127127
"""Call instance methods of vectorized environments."""
128128
target_envs = self._get_target_envs(indices)
129-
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
129+
return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs]
130130

131131
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
132132
"""Check if worker environments are wrapped with a given wrapper"""

stable_baselines3/common/vec_env/patch_gym.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma
4343
"Missing shimmy installation. You provided an OpenAI Gym environment. "
4444
"Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
4545
"In order to use OpenAI Gym environments with SB3, you need to "
46-
"install shimmy (`pip install 'shimmy>=0.2.1'`)."
46+
"install shimmy (`pip install 'shimmy>=2.0'`)."
4747
) from e
4848

4949
warnings.warn(

stable_baselines3/common/vec_env/subproc_vec_env.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import multiprocessing as mp
22
import warnings
3-
from collections import OrderedDict
43
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
54

65
import gymnasium as gym
@@ -54,10 +53,10 @@ def _worker(
5453
elif cmd == "get_spaces":
5554
remote.send((env.observation_space, env.action_space))
5655
elif cmd == "env_method":
57-
method = getattr(env, data[0])
56+
method = env.get_wrapper_attr(data[0])
5857
remote.send(method(*data[1], **data[2]))
5958
elif cmd == "get_attr":
60-
remote.send(getattr(env, data))
59+
remote.send(env.get_wrapper_attr(data))
6160
elif cmd == "set_attr":
6261
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
6362
elif cmd == "is_wrapped":
@@ -129,7 +128,7 @@ def step_wait(self) -> VecEnvStepReturn:
129128
results = [remote.recv() for remote in self.remotes]
130129
self.waiting = False
131130
obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment]
132-
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
131+
return _stack_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
133132

134133
def reset(self) -> VecEnvObs:
135134
for env_idx, remote in enumerate(self.remotes):
@@ -139,7 +138,7 @@ def reset(self) -> VecEnvObs:
139138
# Seeds and options are only used once
140139
self._reset_seeds()
141140
self._reset_options()
142-
return _flatten_obs(obs, self.observation_space)
141+
return _stack_obs(obs, self.observation_space)
143142

144143
def close(self) -> None:
145144
if self.closed:
@@ -206,27 +205,28 @@ def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
206205
return [self.remotes[i] for i in indices]
207206

208207

209-
def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
208+
def _stack_obs(obs_list: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
210209
"""
211-
Flatten observations, depending on the observation space.
210+
Stack observations (convert from a list of single env obs to a stack of obs),
211+
depending on the observation space.
212212
213213
:param obs: observations.
214214
A list or tuple of observations, one per environment.
215215
Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
216-
:return: flattened observations.
217-
A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
216+
:return: Concatenated observations.
217+
A NumPy array or a dict or tuple of stacked numpy arrays.
218218
Each NumPy array has the environment index as its first axis.
219219
"""
220-
assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
221-
assert len(obs) > 0, "need observations from at least one environment"
220+
assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment"
221+
assert len(obs_list) > 0, "need observations from at least one environment"
222222

223223
if isinstance(space, spaces.Dict):
224-
assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
225-
assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
226-
return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
224+
assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces"
225+
assert isinstance(obs_list[0], dict), "non-dict observation for environment with Dict observation space"
226+
return {key: np.stack([single_obs[key] for single_obs in obs_list]) for key in space.spaces.keys()} # type: ignore[call-overload]
227227
elif isinstance(space, spaces.Tuple):
228-
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
228+
assert isinstance(obs_list[0], tuple), "non-tuple observation for environment with Tuple observation space"
229229
obs_len = len(space.spaces)
230-
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index]
230+
return tuple(np.stack([single_obs[i] for single_obs in obs_list]) for i in range(obs_len)) # type: ignore[index]
231231
else:
232-
return np.stack(obs) # type: ignore[arg-type]
232+
return np.stack(obs_list) # type: ignore[arg-type]

0 commit comments

Comments
 (0)