Commit 9489b1a

Merge branch 'master' into feat/mps-support
2 parents 7c71688 + daaebd0 commit 9489b1a


66 files changed (+534, -490 lines)

.github/workflows/ci.yml

+3, -4

@@ -20,7 +20,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
       include:
         # Default version
         - gymnasium-version: "1.0.0"
@@ -48,7 +48,8 @@ jobs:
       - name: Install specific version of gym
         run: |
           uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
-        # Only run for python 3.10, downgrade gym to 0.29.1
+          uv pip install --system "numpy<2"
+        # Only run for python 3.10, downgrade gym to 0.29.1, numpy<2
         if: matrix.gymnasium-version != '1.0.0'
       - name: Lint with ruff
         run: |
@@ -62,8 +63,6 @@ jobs:
       - name: Type check
         run: |
           make type
-        # Do not run for python 3.8 (mypy internal error)
-        if: matrix.python-version != '3.8'
       - name: Test with pytest
         run: |
           make pytest

README.md

+2, -2

@@ -100,10 +100,10 @@ It provides a minimal number of features compared to SB3 but can be much faster
 
 ## Installation
 
-**Note:** Stable-Baselines3 supports PyTorch >= 1.13
+**Note:** Stable-Baselines3 supports PyTorch >= 2.3
 
 ### Prerequisites
-Stable Baselines3 requires Python 3.8+.
+Stable Baselines3 requires Python 3.9+.
 
 #### Windows
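
The sketch below (not part of the commit) checks an existing environment against the new minimums stated in the README note, Python 3.9+ and PyTorch >= 2.3:

# Sketch only: verify the new minimums before upgrading Stable-Baselines3.
import sys

import torch

assert sys.version_info >= (3, 9), "SB3 now requires Python 3.9+"
torch_major, torch_minor = (int(v) for v in torch.__version__.split(".")[:2])
assert (torch_major, torch_minor) >= (2, 3), "SB3 now requires PyTorch >= 2.3"
print("Environment satisfies the new SB3 requirements")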

docs/conda_env.yml

+1, -1

@@ -12,7 +12,7 @@ dependencies:
   - cloudpickle
   - opencv-python-headless
   - pandas
-  - numpy>=1.20,<2.0
+  - numpy>=1.20,<3.0
   - matplotlib
   - sphinx>=5,<9
   - sphinx_rtd_theme>=1.3.0

docs/conf.py

+1, -2

@@ -14,7 +14,6 @@
 import datetime
 import os
 import sys
-from typing import Dict
 
 # We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
 # PyEnchant.
@@ -151,7 +150,7 @@ def setup(app):
 
 # -- Options for LaTeX output ------------------------------------------------
 
-latex_elements: Dict[str, str] = {
+latex_elements: dict[str, str] = {
     # The paper size ('letterpaper' or 'a4paper').
     #
     # 'papersize': 'letterpaper',

docs/guide/install.rst

+1, -1

@@ -7,7 +7,7 @@ Installation
 Prerequisites
 -------------
 
-Stable-Baselines3 requires python 3.8+ and PyTorch >= 1.13
+Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3
 
 Windows
 ~~~~~~~

docs/index.rst

+2, -0

@@ -20,6 +20,8 @@ RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for train
 
 SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
 
+SBX (SB3 + Jax): https://github.com/araffin/sbx
+
 
 Main Features
 --------------

docs/misc/changelog.rst

+39, -7

@@ -3,7 +3,42 @@
 Changelog
 ==========
 
-Release 2.4.0a11 (WIP)
+Release 2.5.0a0 (WIP)
+--------------------------
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+- Increased minimum required version of PyTorch to 2.3.0
+- Removed support for Python 3.8
+
+New Features:
+^^^^^^^^^^^^^
+- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too
+- Added official support for Python 3.12
+
+Bug Fixes:
+^^^^^^^^^^
+
+`SB3-Contrib`_
+^^^^^^^^^^^^^^
+
+`RL Zoo`_
+^^^^^^^^^
+
+`SBX`_ (SB3 + Jax)
+^^^^^^^^^^^^^^^^^^
+
+Deprecations:
+^^^^^^^^^^^^^
+
+Others:
+^^^^^^^
+
+Documentation:
+^^^^^^^^^^^^^^
+
+
+Release 2.4.0 (2024-11-18)
 --------------------------
 
 **New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**
@@ -18,13 +53,13 @@ Release 2.4.0a11 (WIP)
 .. warning::
 
   Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024)
-  and PyTorch < 2.0.
-  We highly recommended you to upgrade to Python >= 3.9 and PyTorch >= 2.0.
+  and PyTorch < 2.3.
+  We highly recommended you to upgrade to Python >= 3.9 and PyTorch >= 2.3 (compatible with NumPy v2).
 
 
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
-- Increase minimum required version of Gymnasium to 0.29.1
+- Increased minimum required version of Gymnasium to 0.29.1
 
 New Features:
 ^^^^^^^^^^^^^
@@ -74,9 +109,6 @@ Others:
 - Updated dependencies for read the doc
 - Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, remove the use of ordered dict and rename ``flatten_obs`` to ``stack_obs``
 
-Bug Fixes:
-^^^^^^^^^^
-
 Documentation:
 ^^^^^^^^^^^^^^
 - Updated PPO doc to recommend using CPU with ``MlpPolicy``
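
For the NumPy v2 entry above, a minimal usage sketch (not from the commit; CartPole chosen only for illustration) showing that ``VecNormalize`` returns float32 rewards when reward normalization is enabled:

# Sketch only: with norm_reward=True, VecNormalize yields float32 reward arrays,
# which avoids dtype surprises under NumPy v2.
import gymnasium as gym
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

venv = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]), norm_reward=True)
venv.reset()
actions = np.array([venv.action_space.sample()])
_obs, rewards, _dones, _infos = venv.step(actions)
print(rewards.dtype)  # expected: float32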

pyproject.toml

+2, -2

@@ -1,8 +1,8 @@
 [tool.ruff]
 # Same as Black.
 line-length = 127
-# Assume Python 3.8
-target-version = "py38"
+# Assume Python 3.9
+target-version = "py39"
 
 [tool.ruff.lint]
 # See https://beta.ruff.rs/docs/rules/

setup.py

+4, -4

@@ -77,8 +77,8 @@
     package_data={"stable_baselines3": ["py.typed", "version.txt"]},
     install_requires=[
         "gymnasium>=0.29.1,<1.1.0",
-        "numpy>=1.20,<2.0",  # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
-        "torch>=1.13",
+        "numpy>=1.20,<3.0",
+        "torch>=2.3,<3.0",
         # For saving models
         "cloudpickle",
         # For reading logs
@@ -135,7 +135,7 @@
     long_description=long_description,
     long_description_content_type="text/markdown",
     version=__version__,
-    python_requires=">=3.8",
+    python_requires=">=3.9",
     # PyPI package information.
     project_urls={
         "Code": "https://github.com/DLR-RM/stable-baselines3",
@@ -147,10 +147,10 @@
     },
     classifiers=[
         "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
     ],
 )
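
A small sketch (not part of the diff) for checking installed packages against the widened ranges in ``install_requires`` using ``importlib.metadata``:

# Sketch only: report installed versions against the new ranges in setup.py
# (torch>=2.3,<3.0, numpy>=1.20,<3.0, gymnasium>=0.29.1,<1.1.0).
from importlib.metadata import PackageNotFoundError, version

for package in ("torch", "numpy", "gymnasium"):
    try:
        print(f"{package}: {version(package)}")
    except PackageNotFoundError:
        print(f"{package}: not installed")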

stable_baselines3/a2c/a2c.py

+6, -6

@@ -1,4 +1,4 @@
-from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union
 
 import torch as th
 from gymnasium import spaces
@@ -57,15 +57,15 @@ class A2C(OnPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": ActorCriticPolicy,
         "CnnPolicy": ActorCriticCnnPolicy,
         "MultiInputPolicy": MultiInputActorCriticPolicy,
     }
 
     def __init__(
         self,
-        policy: Union[str, Type[ActorCriticPolicy]],
+        policy: Union[str, type[ActorCriticPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 7e-4,
         n_steps: int = 5,
@@ -78,12 +78,12 @@ def __init__(
         use_rms_prop: bool = True,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
-        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
-        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
+        rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
         normalize_advantage: bool = False,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
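
The ``Dict``/``Type`` to ``dict``/``type`` changes above follow the PEP 585 built-in generics usable from Python 3.9; a self-contained illustration of the pattern (class and function names here are invented for the example, not SB3 API):

# Illustration only (hypothetical names): the annotation style this commit
# adopts, valid at runtime on Python 3.9+ without typing.Dict/List/Type.
from typing import Optional, Union


class ExamplePolicy:
    pass


# Old style: Dict[str, Type[ExamplePolicy]]; new style uses the built-ins.
policy_aliases: dict[str, type[ExamplePolicy]] = {"MlpPolicy": ExamplePolicy}


def resolve(policy: Union[str, type[ExamplePolicy]]) -> Optional[type[ExamplePolicy]]:
    """Return the policy class for a name, mirroring the alias-lookup pattern."""
    if isinstance(policy, str):
        return policy_aliases.get(policy)
    return policy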

stable_baselines3/common/atari_wrappers.py

+2, -2

@@ -1,4 +1,4 @@
-from typing import Dict, SupportsFloat
+from typing import SupportsFloat
 
 import gymnasium as gym
 import numpy as np
@@ -64,7 +64,7 @@ def reset(self, **kwargs) -> AtariResetReturn:
         noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
         assert noops > 0
         obs = np.zeros(0)
-        info: Dict = {}
+        info: dict = {}
         for _ in range(noops):
             obs, _, terminated, truncated, info = self.env.step(self.noop_action)
             if terminated or truncated:

stable_baselines3/common/base_class.py

+20, -19

@@ -6,7 +6,8 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import deque
-from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
+from collections.abc import Iterable
+from typing import Any, ClassVar, Optional, TypeVar, Union
 
 import gymnasium as gym
 import numpy as np
@@ -94,7 +95,7 @@ class BaseAlgorithm(ABC):
     """
 
     # Policy aliases (see _get_policy_from_name())
-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {}
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {}
     policy: BasePolicy
     observation_space: spaces.Space
     action_space: spaces.Space
@@ -104,10 +105,10 @@ class BaseAlgorithm(ABC):
 
     def __init__(
         self,
-        policy: Union[str, Type[BasePolicy]],
+        policy: Union[str, type[BasePolicy]],
         env: Union[GymEnv, str, None],
         learning_rate: Union[float, Schedule],
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
         verbose: int = 0,
@@ -117,7 +118,7 @@ def __init__(
         seed: Optional[int] = None,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
-        supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
+        supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
     ) -> None:
         if isinstance(policy, str):
             self.policy_class = self._get_policy_from_name(policy)
@@ -141,10 +142,10 @@ def __init__(
         self.start_time = 0.0
         self.learning_rate = learning_rate
         self.tensorboard_log = tensorboard_log
-        self._last_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
+        self._last_obs = None  # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
         self._last_episode_starts = None  # type: Optional[np.ndarray]
         # When using VecNormalize:
-        self._last_original_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
+        self._last_original_obs = None  # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
         self._episode_num = 0
         # Used for gSDE only
         self.use_sde = use_sde
@@ -283,7 +284,7 @@ def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps
         """
         self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
 
-    def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
+    def _update_learning_rate(self, optimizers: Union[list[th.optim.Optimizer], th.optim.Optimizer]) -> None:
         """
         Update the optimizers learning rate using the current learning rate schedule
         and the current progress remaining (from 1 to 0).
@@ -299,7 +300,7 @@ def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.o
         for optimizer in optimizers:
             update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))
 
-    def _excluded_save_params(self) -> List[str]:
+    def _excluded_save_params(self) -> list[str]:
         """
         Returns the names of the parameters that should be excluded from being
         saved by pickling. E.g. replay buffers are skipped by default
@@ -320,7 +321,7 @@ def _excluded_save_params(self) -> List[str]:
             "_custom_logger",
         ]
 
-    def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
+    def _get_policy_from_name(self, policy_name: str) -> type[BasePolicy]:
         """
         Get a policy class from its name representation.
 
@@ -337,7 +338,7 @@ def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
         else:
             raise ValueError(f"Policy {policy_name} unknown")
 
-    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
         """
         Get the name of the torch variables that will be saved with
         PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
@@ -387,7 +388,7 @@ def _setup_learn(
         reset_num_timesteps: bool = True,
         tb_log_name: str = "run",
         progress_bar: bool = False,
-    ) -> Tuple[int, BaseCallback]:
+    ) -> tuple[int, BaseCallback]:
         """
         Initialize different variables needed for training.
 
@@ -435,7 +436,7 @@ def _setup_learn(
 
         return total_timesteps, callback
 
-    def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
+    def _update_info_buffer(self, infos: list[dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
         """
         Retrieve reward, episode length, episode success and update the buffer
         if using Monitor wrapper or a GoalEnv.
@@ -535,11 +536,11 @@ def learn(
 
     def predict(
         self,
-        observation: Union[np.ndarray, Dict[str, np.ndarray]],
-        state: Optional[Tuple[np.ndarray, ...]] = None,
+        observation: Union[np.ndarray, dict[str, np.ndarray]],
+        state: Optional[tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
         """
         Get the policy action from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
@@ -640,11 +641,11 @@ def set_parameters(
 
     @classmethod
     def load(  # noqa: C901
-        cls: Type[SelfBaseAlgorithm],
+        cls: type[SelfBaseAlgorithm],
         path: Union[str, pathlib.Path, io.BufferedIOBase],
         env: Optional[GymEnv] = None,
         device: Union[th.device, str] = "auto",
-        custom_objects: Optional[Dict[str, Any]] = None,
+        custom_objects: Optional[dict[str, Any]] = None,
         print_system_info: bool = False,
         force_reset: bool = True,
         **kwargs,
@@ -800,7 +801,7 @@ def load( # noqa: C901
             model.policy.reset_noise()  # type: ignore[operator]
         return model
 
-    def get_parameters(self) -> Dict[str, Dict]:
+    def get_parameters(self) -> dict[str, dict]:
         """
         Return the parameters of the agent. This includes parameters from different networks, e.g.
         critics (value functions) and policies (pi functions).
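
Only the annotations of ``predict`` change above; a short usage sketch (not from the commit, CartPole chosen arbitrarily) showing the unchanged call pattern:

# Usage sketch: predict() still returns an action array and an optional
# recurrent state, now annotated with built-in generics.
from stable_baselines3 import A2C

model = A2C("MlpPolicy", "CartPole-v1", verbose=0)
vec_env = model.get_env()
obs = vec_env.reset()
action, state = model.predict(obs, deterministic=True)
print(action.shape, state)  # state is None for non-recurrent policies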
