Skip to content

Commit 9489b1a

Browse files
authored
Merge branch 'master' into feat/mps-support
2 parents 7c71688 + daaebd0 commit 9489b1a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+534
-490
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
runs-on: ubuntu-latest
2121
strategy:
2222
matrix:
23-
python-version: ["3.8", "3.9", "3.10", "3.11"]
23+
python-version: ["3.9", "3.10", "3.11", "3.12"]
2424
include:
2525
# Default version
2626
- gymnasium-version: "1.0.0"
@@ -48,7 +48,8 @@ jobs:
4848
- name: Install specific version of gym
4949
run: |
5050
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
51-
# Only run for python 3.10, downgrade gym to 0.29.1
51+
uv pip install --system "numpy<2"
52+
# Only run for python 3.10, downgrade gym to 0.29.1, numpy<2
5253
if: matrix.gymnasium-version != '1.0.0'
5354
- name: Lint with ruff
5455
run: |
@@ -62,8 +63,6 @@ jobs:
6263
- name: Type check
6364
run: |
6465
make type
65-
# Do not run for python 3.8 (mypy internal error)
66-
if: matrix.python-version != '3.8'
6766
- name: Test with pytest
6867
run: |
6968
make pytest

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ It provides a minimal number of features compared to SB3 but can be much faster
100100

101101
## Installation
102102

103-
**Note:** Stable-Baselines3 supports PyTorch >= 1.13
103+
**Note:** Stable-Baselines3 supports PyTorch >= 2.3
104104

105105
### Prerequisites
106-
Stable Baselines3 requires Python 3.8+.
106+
Stable Baselines3 requires Python 3.9+.
107107

108108
#### Windows
109109

docs/conda_env.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- cloudpickle
1313
- opencv-python-headless
1414
- pandas
15-
- numpy>=1.20,<2.0
15+
- numpy>=1.20,<3.0
1616
- matplotlib
1717
- sphinx>=5,<9
1818
- sphinx_rtd_theme>=1.3.0

docs/conf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import datetime
1515
import os
1616
import sys
17-
from typing import Dict
1817

1918
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
2019
# PyEnchant.
@@ -151,7 +150,7 @@ def setup(app):
151150

152151
# -- Options for LaTeX output ------------------------------------------------
153152

154-
latex_elements: Dict[str, str] = {
153+
latex_elements: dict[str, str] = {
155154
# The paper size ('letterpaper' or 'a4paper').
156155
#
157156
# 'papersize': 'letterpaper',

docs/guide/install.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Installation
77
Prerequisites
88
-------------
99

10-
Stable-Baselines3 requires python 3.8+ and PyTorch >= 1.13
10+
Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3
1111

1212
Windows
1313
~~~~~~~

docs/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for train
2020

2121
SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
2222

23+
SBX (SB3 + Jax): https://github.com/araffin/sbx
24+
2325

2426
Main Features
2527
--------------

docs/misc/changelog.rst

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,42 @@
33
Changelog
44
==========
55

6-
Release 2.4.0a11 (WIP)
6+
Release 2.5.0a0 (WIP)
7+
--------------------------
8+
9+
Breaking Changes:
10+
^^^^^^^^^^^^^^^^^
11+
- Increased minimum required version of PyTorch to 2.3.0
12+
- Removed support for Python 3.8
13+
14+
New Features:
15+
^^^^^^^^^^^^^
16+
- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too
17+
- Added official support for Python 3.12
18+
19+
Bug Fixes:
20+
^^^^^^^^^^
21+
22+
`SB3-Contrib`_
23+
^^^^^^^^^^^^^^
24+
25+
`RL Zoo`_
26+
^^^^^^^^^
27+
28+
`SBX`_ (SB3 + Jax)
29+
^^^^^^^^^^^^^^^^^^
30+
31+
Deprecations:
32+
^^^^^^^^^^^^^
33+
34+
Others:
35+
^^^^^^^
36+
37+
Documentation:
38+
^^^^^^^^^^^^^^
39+
40+
41+
Release 2.4.0 (2024-11-18)
742
--------------------------
843

944
**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**
@@ -18,13 +53,13 @@ Release 2.4.0a11 (WIP)
1853
.. warning::
1954

2055
Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024)
21-
and PyTorch < 2.0.
22-
We highly recommended you to upgrade to Python >= 3.9 and PyTorch >= 2.0.
56+
and PyTorch < 2.3.
57+
We highly recommended you to upgrade to Python >= 3.9 and PyTorch >= 2.3 (compatible with NumPy v2).
2358

2459

2560
Breaking Changes:
2661
^^^^^^^^^^^^^^^^^
27-
- Increase minimum required version of Gymnasium to 0.29.1
62+
- Increased minimum required version of Gymnasium to 0.29.1
2863

2964
New Features:
3065
^^^^^^^^^^^^^
@@ -74,9 +109,6 @@ Others:
74109
- Updated dependencies for read the doc
75110
- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, remove the use of ordered dict and rename ``flatten_obs`` to ``stack_obs``
76111

77-
Bug Fixes:
78-
^^^^^^^^^^
79-
80112
Documentation:
81113
^^^^^^^^^^^^^^
82114
- Updated PPO doc to recommend using CPU with ``MlpPolicy``

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
[tool.ruff]
22
# Same as Black.
33
line-length = 127
4-
# Assume Python 3.8
5-
target-version = "py38"
4+
# Assume Python 3.9
5+
target-version = "py39"
66

77
[tool.ruff.lint]
88
# See https://beta.ruff.rs/docs/rules/

setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@
7777
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
7878
install_requires=[
7979
"gymnasium>=0.29.1,<1.1.0",
80-
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
81-
"torch>=1.13",
80+
"numpy>=1.20,<3.0",
81+
"torch>=2.3,<3.0",
8282
# For saving models
8383
"cloudpickle",
8484
# For reading logs
@@ -135,7 +135,7 @@
135135
long_description=long_description,
136136
long_description_content_type="text/markdown",
137137
version=__version__,
138-
python_requires=">=3.8",
138+
python_requires=">=3.9",
139139
# PyPI package information.
140140
project_urls={
141141
"Code": "https://github.com/DLR-RM/stable-baselines3",
@@ -147,10 +147,10 @@
147147
},
148148
classifiers=[
149149
"Programming Language :: Python :: 3",
150-
"Programming Language :: Python :: 3.8",
151150
"Programming Language :: Python :: 3.9",
152151
"Programming Language :: Python :: 3.10",
153152
"Programming Language :: Python :: 3.11",
153+
"Programming Language :: Python :: 3.12",
154154
],
155155
)
156156

stable_baselines3/a2c/a2c.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
1+
from typing import Any, ClassVar, Optional, TypeVar, Union
22

33
import torch as th
44
from gymnasium import spaces
@@ -57,15 +57,15 @@ class A2C(OnPolicyAlgorithm):
5757
:param _init_setup_model: Whether or not to build the network at the creation of the instance
5858
"""
5959

60-
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
60+
policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
6161
"MlpPolicy": ActorCriticPolicy,
6262
"CnnPolicy": ActorCriticCnnPolicy,
6363
"MultiInputPolicy": MultiInputActorCriticPolicy,
6464
}
6565

6666
def __init__(
6767
self,
68-
policy: Union[str, Type[ActorCriticPolicy]],
68+
policy: Union[str, type[ActorCriticPolicy]],
6969
env: Union[GymEnv, str],
7070
learning_rate: Union[float, Schedule] = 7e-4,
7171
n_steps: int = 5,
@@ -78,12 +78,12 @@ def __init__(
7878
use_rms_prop: bool = True,
7979
use_sde: bool = False,
8080
sde_sample_freq: int = -1,
81-
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
82-
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
81+
rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
82+
rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
8383
normalize_advantage: bool = False,
8484
stats_window_size: int = 100,
8585
tensorboard_log: Optional[str] = None,
86-
policy_kwargs: Optional[Dict[str, Any]] = None,
86+
policy_kwargs: Optional[dict[str, Any]] = None,
8787
verbose: int = 0,
8888
seed: Optional[int] = None,
8989
device: Union[th.device, str] = "auto",

0 commit comments

Comments
 (0)