
Commit 0ec37d8

Merge branch 'feat/mps-support' into feat/mps-support
2 parents: 4c03a25 + 9489b1a


70 files changed: +681 -618 lines changed

.github/workflows/ci.yml

+14 -11

@@ -20,8 +20,13 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
-
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        include:
+          # Default version
+          - gymnasium-version: "1.0.0"
+          # Add a new config to test gym<1.0
+          - python-version: "3.10"
+            gymnasium-version: "0.29.1"
     steps:
       - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
@@ -37,15 +42,15 @@ jobs:
           # See https://github.com/astral-sh/uv/issues/1497
           uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu

-          # Install Atari Roms
-          uv pip install --system autorom
-          wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-          base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-          AutoROM --accept-license --source-file Roms.tar.gz
-
-          uv pip install --system .[extra_no_roms,tests,docs]
+          uv pip install --system .[extra,tests,docs]
           # Use headless version
           uv pip install --system opencv-python-headless
+      - name: Install specific version of gym
+        run: |
+          uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
+          uv pip install --system "numpy<2"
+        # Only run for python 3.10, downgrade gym to 0.29.1, numpy<2
+        if: matrix.gymnasium-version != '1.0.0'
       - name: Lint with ruff
         run: |
           make lint
@@ -58,8 +63,6 @@ jobs:
       - name: Type check
         run: |
           make type
-        # Do not run for python 3.8 (mypy internal error)
-        if: matrix.python-version != '3.8'
       - name: Test with pytest
         run: |
           make pytest

README.md

+2 -2

@@ -100,10 +100,10 @@ It provides a minimal number of features compared to SB3 but can be much faster

 ## Installation

-**Note:** Stable-Baselines3 supports PyTorch >= 1.13
+**Note:** Stable-Baselines3 supports PyTorch >= 2.3

 ### Prerequisites
-Stable Baselines3 requires Python 3.8+.
+Stable Baselines3 requires Python 3.9+.

 #### Windows

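The bumped minimums above are easy to check before installing. A small sanity-check sketch (not from the README; the version parsing is illustrative only):

```python
import sys

import torch

# Stable-Baselines3 now requires Python 3.9+ and PyTorch >= 2.3.
assert sys.version_info >= (3, 9), "Python 3.9+ required"
torch_major, torch_minor = (int(p) for p in torch.__version__.split(".")[:2])
assert (torch_major, torch_minor) >= (2, 3), "PyTorch >= 2.3 required"
print(f"Python {sys.version_info.major}.{sys.version_info.minor}, torch {torch.__version__}: OK")
```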
docs/conda_env.yml

+3 -3

@@ -8,12 +8,12 @@ dependencies:
   - python=3.11
   - pytorch=2.5.0=py3.11_cpu_0
   - pip:
-    - gymnasium>=0.28.1,<0.30
+    - gymnasium>=0.29.1,<1.1.0
     - cloudpickle
     - opencv-python-headless
     - pandas
-    - numpy>=1.20,<2.0
+    - numpy>=1.20,<3.0
     - matplotlib
-    - sphinx>=5,<8
+    - sphinx>=5,<9
     - sphinx_rtd_theme>=1.3.0
     - sphinx_copybutton

docs/conf.py

+1 -2

@@ -14,7 +14,6 @@
 import datetime
 import os
 import sys
-from typing import Dict

 # We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
 # PyEnchant.
@@ -151,7 +150,7 @@ def setup(app):

 # -- Options for LaTeX output ------------------------------------------------

-latex_elements: Dict[str, str] = {
+latex_elements: dict[str, str] = {
     # The paper size ('letterpaper' or 'a4paper').
     #
     # 'papersize': 'letterpaper',

docs/guide/install.rst

+1 -1

@@ -7,7 +7,7 @@ Installation
 Prerequisites
 -------------

-Stable-Baselines3 requires python 3.8+ and PyTorch >= 1.13
+Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3

 Windows
 ~~~~~~~

docs/index.rst

+2 -0

@@ -20,6 +20,8 @@ RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for train

 SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

+SBX (SB3 + Jax): https://github.com/araffin/sbx
+
 Main Features
 --------------

docs/misc/changelog.rst

+44 -7

@@ -3,10 +3,45 @@
 Changelog
 ==========

-Release 2.4.0a10 (WIP)
+Release 2.5.0a0 (WIP)
 --------------------------

-**New algorithm: CrossQ in SB3 Contrib**
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+- Increased minimum required version of PyTorch to 2.3.0
+- Removed support for Python 3.8
+
+New Features:
+^^^^^^^^^^^^^
+- Added support for NumPy v2.0: ``VecNormalize`` now casts normalized rewards to float32; the bit flipping env was also updated to avoid overflow issues
+- Added official support for Python 3.12
+
+Bug Fixes:
+^^^^^^^^^^
+
+`SB3-Contrib`_
+^^^^^^^^^^^^^^
+
+`RL Zoo`_
+^^^^^^^^^
+
+`SBX`_ (SB3 + Jax)
+^^^^^^^^^^^^^^^^^^
+
+Deprecations:
+^^^^^^^^^^^^^
+
+Others:
+^^^^^^^
+
+Documentation:
+^^^^^^^^^^^^^^
+
+
+Release 2.4.0 (2024-11-18)
+--------------------------
+
+**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**

 .. note::

@@ -18,18 +53,20 @@ Release 2.4.0a10 (WIP)
 .. warning::

     Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024)
-    and PyTorch < 2.0.
-    We highly recommend you upgrade to Python >= 3.9 and PyTorch >= 2.0.
+    and PyTorch < 2.3.
+    We highly recommend you upgrade to Python >= 3.9 and PyTorch >= 2.3 (compatible with NumPy v2).


 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
+- Increased minimum required version of Gymnasium to 0.29.1

 New Features:
 ^^^^^^^^^^^^^
 - Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
 - Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
 - Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
+- Added support for Gymnasium v1.0

 Bug Fixes:
 ^^^^^^^^^^
@@ -57,6 +94,7 @@ Bug Fixes:
 `SBX`_ (SB3 + Jax)
 ^^^^^^^^^^^^^^^^^^
 - Added CNN support for DQN
+- Bug fix for SAC and related algorithms: optimize the log of the entropy coefficient to be consistent with SB3

 Deprecations:
 ^^^^^^^^^^^^^
@@ -69,14 +107,13 @@ Others:
 - Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
 - Switched to uv to download packages faster on GitHub CI
 - Updated dependencies for Read the Docs
-
-Bug Fixes:
-^^^^^^^^^^
+- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, removed the use of ordered dict and renamed ``flatten_obs`` to ``stack_obs``

 Documentation:
 ^^^^^^^^^^^^^^
 - Updated PPO doc to recommend using CPU with ``MlpPolicy``
 - Clarified documentation about planned features and citing software
+- Added a note about the fact that we optimize the log of the entropy coefficient for SAC

 Release 2.3.2 (2024-04-27)
 --------------------------

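The NumPy v2.0 entry in the 2.5.0a0 notes above concerns reward dtypes: `VecNormalize` keeps its running return statistics in float64, so dividing float32 rewards by the float64 standard deviation promotes the result, and the normalized reward is now cast back. A hedged illustration of the mechanism (not the actual SB3 code):

```python
import numpy as np

# Illustrative only: running statistics are float64 for numerical accuracy,
# so array/array division promotes the float32 rewards to float64.
rewards = np.ones(4, dtype=np.float32)
running_var = np.array(4.0, dtype=np.float64)  # running variance of returns
normalized = np.clip(rewards / np.sqrt(running_var + 1e-8), -10.0, 10.0)
print(normalized.dtype)  # float64: silently promoted
normalized = normalized.astype(np.float32)  # the fix: keep rewards float32
```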
docs/modules/dqn.rst

+1 -0

@@ -25,6 +25,7 @@ Notes

 - Original paper: https://arxiv.org/abs/1312.5602
 - Further reference: https://www.nature.com/articles/nature14236
+- Tutorial "From Tabular Q-Learning to DQN": https://github.com/araffin/rlss23-dqn-tutorial

 .. note::
     This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.

docs/modules/sac.rst

+3 -0

@@ -35,6 +35,9 @@ Notes
     which is the equivalent to the inverse of reward scale in the original SAC paper.
     The main reason is that it avoids having too high errors when updating the Q functions.

+.. note::
+    When automatically adjusting the temperature (alpha/entropy coefficient), we optimize the logarithm of the entropy coefficient instead of the entropy coefficient itself. This is consistent with the original implementation and has proven to be more stable
+    (see issues `GH#36 <https://github.com/DLR-RM/stable-baselines3/issues/36>`_, `#55 <https://github.com/araffin/sbx/issues/55>`_ and others).

 .. note::

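A minimal sketch of that temperature update (illustrative, not SB3's exact code; the target entropy and the batch of log-probabilities are placeholders). Optimizing `log(alpha)` keeps the coefficient strictly positive without any clamping, and gradient steps act multiplicatively on `alpha`, which is what makes it more stable:

```python
import torch as th

# Sketch of SAC's automatic temperature adjustment in log space.
log_ent_coef = th.zeros(1, requires_grad=True)  # alpha starts at exp(0) = 1
optimizer = th.optim.Adam([log_ent_coef], lr=3e-4)

target_entropy = -4.0  # usual heuristic: -dim(action_space); placeholder value
log_prob = th.tensor([-3.2, -4.5, -4.1])  # hypothetical batch of action log-probs

# Push alpha up when policy entropy is below target, down otherwise.
ent_coef_loss = -(log_ent_coef * (log_prob + target_entropy).detach()).mean()
optimizer.zero_grad()
ent_coef_loss.backward()
optimizer.step()

ent_coef = log_ent_coef.detach().exp()  # alpha used to scale the entropy bonus
```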
pyproject.toml

+2 -3

@@ -1,8 +1,8 @@
 [tool.ruff]
 # Same as Black.
 line-length = 127
-# Assume Python 3.8
-target-version = "py38"
+# Assume Python 3.9
+target-version = "py39"

 [tool.ruff.lint]
 # See https://beta.ruff.rs/docs/rules/
@@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"]
 # ClassVar, implicit optional check not needed for tests
 "./tests/*.py" = ["RUF012", "RUF013"]

-
 [tool.ruff.lint.mccabe]
 # Unlike Flake8, default to a complexity level of 10.
 max-complexity = 15

setup.py

+21 -32

@@ -70,39 +70,15 @@

 """  # noqa:E501

-# Atari Games download is sometimes problematic:
-# https://github.com/Farama-Foundation/AutoROM/issues/39
-# That's why we define extra packages without it.
-extra_no_roms = [
-    # For render
-    "opencv-python",
-    "pygame",
-    # Tensorboard support
-    "tensorboard>=2.9.1",
-    # Checking memory taken by replay buffer
-    "psutil",
-    # For progress bar callback
-    "tqdm",
-    "rich",
-    # For atari games,
-    "shimmy[atari]~=1.3.0",
-    "pillow",
-]
-
-extra_packages = extra_no_roms + [  # noqa: RUF005
-    # For atari roms,
-    "autorom[accept-rom-license]~=0.6.1",
-]
-

 setup(
     name="stable_baselines3",
     packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
     package_data={"stable_baselines3": ["py.typed", "version.txt"]},
     install_requires=[
-        "gymnasium>=0.28.1,<0.30",
-        "numpy>=1.20,<2.0",  # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
-        "torch>=1.13",
+        "gymnasium>=0.29.1,<1.1.0",
+        "numpy>=1.20,<3.0",
+        "torch>=2.3,<3.0",
         # For saving models
         "cloudpickle",
         # For reading logs
@@ -125,16 +101,29 @@
             "black>=24.2.0,<25",
         ],
         "docs": [
-            "sphinx>=5,<8",
+            "sphinx>=5,<9",
             "sphinx-autobuild",
             "sphinx-rtd-theme>=1.3.0",
             # For spelling
             "sphinxcontrib.spelling",
             # Copy button for code snippets
             "sphinx_copybutton",
         ],
-        "extra": extra_packages,
-        "extra_no_roms": extra_no_roms,
+        "extra": [
+            # For render
+            "opencv-python",
+            "pygame",
+            # Tensorboard support
+            "tensorboard>=2.9.1",
+            # Checking memory taken by replay buffer
+            "psutil",
+            # For progress bar callback
+            "tqdm",
+            "rich",
+            # For atari games,
+            "ale-py>=0.9.0",
+            "pillow",
+        ],
     },
     description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
     author="Antonin Raffin",
@@ -146,7 +135,7 @@
     long_description=long_description,
     long_description_content_type="text/markdown",
     version=__version__,
-    python_requires=">=3.8",
+    python_requires=">=3.9",
     # PyPI package information.
     project_urls={
         "Code": "https://github.com/DLR-RM/stable-baselines3",
@@ -158,10 +147,10 @@
     },
     classifiers=[
         "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
     ],
 )

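Since `ale-py>=0.9.0` bundles the Atari ROMs, the AutoROM download step (also removed from CI above) is no longer needed; installing the `extra` extras is enough. A hedged usage sketch, assuming Gymnasium v1.0 (where `gymnasium.register_envs` is available):

```python
import ale_py
import gymnasium as gym

# ale-py >= 0.9 ships the ROMs with the package; under Gymnasium v1.0 the
# ALE/* environment ids are registered explicitly from the plugin module.
gym.register_envs(ale_py)

env = gym.make("ALE/Breakout-v5")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```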
stable_baselines3/a2c/a2c.py

+6 -6

@@ -1,4 +1,4 @@
-from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union

 import torch as th
 from gymnasium import spaces
@@ -57,15 +57,15 @@ class A2C(OnPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """

-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": ActorCriticPolicy,
         "CnnPolicy": ActorCriticCnnPolicy,
         "MultiInputPolicy": MultiInputActorCriticPolicy,
     }

     def __init__(
         self,
-        policy: Union[str, Type[ActorCriticPolicy]],
+        policy: Union[str, type[ActorCriticPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 7e-4,
         n_steps: int = 5,
@@ -78,12 +78,12 @@ def __init__(
         use_rms_prop: bool = True,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
-        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
-        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
+        rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
         normalize_advantage: bool = False,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",

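The annotation changes here (and in the other touched modules) follow directly from dropping Python 3.8: since Python 3.9, PEP 585 makes builtin containers subscriptable, so `typing.Dict` and `typing.Type` can be replaced by plain `dict` and `type`. A small standalone illustration (not SB3 code):

```python
from typing import Optional

# PEP 585 (Python 3.9+): builtins work directly as generics in annotations,
# so `from typing import Dict, Type` is no longer needed.
def resolve_policy(aliases: dict[str, type], name: str, default: Optional[type] = None) -> Optional[type]:
    """Look up a class by its alias, falling back to a default."""
    return aliases.get(name, default)

registry: dict[str, type] = {"mlp": object}
assert resolve_policy(registry, "mlp") is object
assert resolve_policy(registry, "cnn") is None
```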
stable_baselines3/common/atari_wrappers.py

+2 -2

@@ -1,4 +1,4 @@
-from typing import Dict, SupportsFloat
+from typing import SupportsFloat

 import gymnasium as gym
 import numpy as np
@@ -64,7 +64,7 @@ def reset(self, **kwargs) -> AtariResetReturn:
         noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
         assert noops > 0
         obs = np.zeros(0)
-        info: Dict = {}
+        info: dict = {}
         for _ in range(noops):
             obs, _, terminated, truncated, info = self.env.step(self.noop_action)
             if terminated or truncated:
