Skip to content

Commit 0ec37d8

Browse files
authored
Merge branch 'feat/mps-support' into feat/mps-support
2 parents 4c03a25 + 9489b1a commit 0ec37d8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+681
-618
lines changed

.github/workflows/ci.yml

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,13 @@ jobs:
2020
runs-on: ubuntu-latest
2121
strategy:
2222
matrix:
23-
python-version: ["3.8", "3.9", "3.10", "3.11"]
24-
23+
python-version: ["3.9", "3.10", "3.11", "3.12"]
24+
include:
25+
# Default version
26+
- gymnasium-version: "1.0.0"
27+
# Add a new config to test gym<1.0
28+
- python-version: "3.10"
29+
gymnasium-version: "0.29.1"
2530
steps:
2631
- uses: actions/checkout@v3
2732
- name: Set up Python ${{ matrix.python-version }}
@@ -37,15 +42,15 @@ jobs:
3742
# See https://github.com/astral-sh/uv/issues/1497
3843
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
3944
40-
# Install Atari Roms
41-
uv pip install --system autorom
42-
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
43-
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
44-
AutoROM --accept-license --source-file Roms.tar.gz
45-
46-
uv pip install --system .[extra_no_roms,tests,docs]
45+
uv pip install --system .[extra,tests,docs]
4746
# Use headless version
4847
uv pip install --system opencv-python-headless
48+
- name: Install specific version of gym
49+
run: |
50+
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
51+
uv pip install --system "numpy<2"
52+
# Only run for python 3.10, downgrade gym to 0.29.1, numpy<2
53+
if: matrix.gymnasium-version != '1.0.0'
4954
- name: Lint with ruff
5055
run: |
5156
make lint
@@ -58,8 +63,6 @@ jobs:
5863
- name: Type check
5964
run: |
6065
make type
61-
# Do not run for python 3.8 (mypy internal error)
62-
if: matrix.python-version != '3.8'
6366
- name: Test with pytest
6467
run: |
6568
make pytest

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ It provides a minimal number of features compared to SB3 but can be much faster
100100

101101
## Installation
102102

103-
**Note:** Stable-Baselines3 supports PyTorch >= 1.13
103+
**Note:** Stable-Baselines3 supports PyTorch >= 2.3
104104

105105
### Prerequisites
106-
Stable Baselines3 requires Python 3.8+.
106+
Stable Baselines3 requires Python 3.9+.
107107

108108
#### Windows
109109

docs/conda_env.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ dependencies:
88
- python=3.11
99
- pytorch=2.5.0=py3.11_cpu_0
1010
- pip:
11-
- gymnasium>=0.28.1,<0.30
11+
- gymnasium>=0.29.1,<1.1.0
1212
- cloudpickle
1313
- opencv-python-headless
1414
- pandas
15-
- numpy>=1.20,<2.0
15+
- numpy>=1.20,<3.0
1616
- matplotlib
17-
- sphinx>=5,<8
17+
- sphinx>=5,<9
1818
- sphinx_rtd_theme>=1.3.0
1919
- sphinx_copybutton

docs/conf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import datetime
1515
import os
1616
import sys
17-
from typing import Dict
1817

1918
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
2019
# PyEnchant.
@@ -151,7 +150,7 @@ def setup(app):
151150

152151
# -- Options for LaTeX output ------------------------------------------------
153152

154-
latex_elements: Dict[str, str] = {
153+
latex_elements: dict[str, str] = {
155154
# The paper size ('letterpaper' or 'a4paper').
156155
#
157156
# 'papersize': 'letterpaper',

docs/guide/install.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Installation
77
Prerequisites
88
-------------
99

10-
Stable-Baselines3 requires python 3.8+ and PyTorch >= 1.13
10+
Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3
1111

1212
Windows
1313
~~~~~~~

docs/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for train
2020

2121
SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
2222

23+
SBX (SB3 + Jax): https://github.com/araffin/sbx
24+
2325

2426
Main Features
2527
--------------

docs/misc/changelog.rst

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,45 @@
33
Changelog
44
==========
55

6-
Release 2.4.0a10 (WIP)
6+
Release 2.5.0a0 (WIP)
77
--------------------------
88

9-
**New algorithm: CrossQ in SB3 Contrib**
9+
Breaking Changes:
10+
^^^^^^^^^^^^^^^^^
11+
- Increased minimum required version of PyTorch to 2.3.0
12+
- Removed support for Python 3.8
13+
14+
New Features:
15+
^^^^^^^^^^^^^
16+
- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too
17+
- Added official support for Python 3.12
18+
19+
Bug Fixes:
20+
^^^^^^^^^^
21+
22+
`SB3-Contrib`_
23+
^^^^^^^^^^^^^^
24+
25+
`RL Zoo`_
26+
^^^^^^^^^
27+
28+
`SBX`_ (SB3 + Jax)
29+
^^^^^^^^^^^^^^^^^^
30+
31+
Deprecations:
32+
^^^^^^^^^^^^^
33+
34+
Others:
35+
^^^^^^^
36+
37+
Documentation:
38+
^^^^^^^^^^^^^^
39+
40+
41+
Release 2.4.0 (2024-11-18)
42+
--------------------------
43+
44+
**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**
1045

1146
.. note::
1247

@@ -18,18 +53,20 @@ Release 2.4.0a10 (WIP)
1853
.. warning::
1954

2055
Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024)
21-
and PyTorch < 2.0.
22-
We highly recommended you to upgrade to Python >= 3.9 and PyTorch >= 2.0.
56+
and PyTorch < 2.3.
57+
We highly recommended you to upgrade to Python >= 3.9 and PyTorch >= 2.3 (compatible with NumPy v2).
2358

2459

2560
Breaking Changes:
2661
^^^^^^^^^^^^^^^^^
62+
- Increased minimum required version of Gymnasium to 0.29.1
2763

2864
New Features:
2965
^^^^^^^^^^^^^
3066
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
3167
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
3268
- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
69+
- Added support for Gymnasium v1.0
3370

3471
Bug Fixes:
3572
^^^^^^^^^^
@@ -57,6 +94,7 @@ Bug Fixes:
5794
`SBX`_ (SB3 + Jax)
5895
^^^^^^^^^^^^^^^^^^
5996
- Added CNN support for DQN
97+
- Bug fix for SAC and related algorithms, optimize log of ent coeff to be consistent with SB3
6098

6199
Deprecations:
62100
^^^^^^^^^^^^^
@@ -69,14 +107,13 @@ Others:
69107
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
70108
- Switched to uv to download packages faster on GitHub CI
71109
- Updated dependencies for read the doc
72-
73-
Bug Fixes:
74-
^^^^^^^^^^
110+
- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, remove the use of ordered dict and rename ``flatten_obs`` to ``stack_obs``
75111

76112
Documentation:
77113
^^^^^^^^^^^^^^
78114
- Updated PPO doc to recommend using CPU with ``MlpPolicy``
79115
- Clarified documentation about planned features and citing software
116+
- Added a note about the fact we are optimizing log of ent coeff for SAC
80117

81118
Release 2.3.2 (2024-04-27)
82119
--------------------------

docs/modules/dqn.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Notes
2525

2626
- Original paper: https://arxiv.org/abs/1312.5602
2727
- Further reference: https://www.nature.com/articles/nature14236
28+
- Tutorial "From Tabular Q-Learning to DQN": https://github.com/araffin/rlss23-dqn-tutorial
2829

2930
.. note::
3031
This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.

docs/modules/sac.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ Notes
3535
which is the equivalent to the inverse of reward scale in the original SAC paper.
3636
The main reason is that it avoids having too high errors when updating the Q functions.
3737

38+
.. note::
39+
When automatically adjusting the temperature (alpha/entropy coefficient), we optimize the logarithm of the entropy coefficient instead of the entropy coefficient itself. This is consistent with the original implementation and has proven to be more stable
40+
(see issues `GH#36 <https://github.com/DLR-RM/stable-baselines3/issues/36>`_, `#55 <https://github.com/araffin/sbx/issues/55>`_ and others).
3841

3942
.. note::
4043

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
[tool.ruff]
22
# Same as Black.
33
line-length = 127
4-
# Assume Python 3.8
5-
target-version = "py38"
4+
# Assume Python 3.9
5+
target-version = "py39"
66

77
[tool.ruff.lint]
88
# See https://beta.ruff.rs/docs/rules/
@@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"]
1818
# ClassVar, implicit optional check not needed for tests
1919
"./tests/*.py" = ["RUF012", "RUF013"]
2020

21-
2221
[tool.ruff.lint.mccabe]
2322
# Unlike Flake8, default to a complexity level of 10.
2423
max-complexity = 15

0 commit comments

Comments
 (0)