diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 822e0cb3f..a0077a367 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,8 +20,13 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] - + python-version: ["3.9", "3.10", "3.11", "3.12"] + include: + # Default version + - gymnasium-version: "1.0.0" + # Add a new config to test gym<1.0 + - python-version: "3.10" + gymnasium-version: "0.29.1" steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -31,18 +36,21 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + # Use uv for faster downloads + pip install uv # cpu version of pytorch - pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu - - # Install Atari Roms - pip install autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz + # See https://github.com/astral-sh/uv/issues/1497 + uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu - pip install .[extra_no_roms,tests,docs] + uv pip install --system .[extra,tests,docs] # Use headless version - pip install opencv-python-headless + uv pip install --system opencv-python-headless + - name: Install specific version of gym + run: | + uv pip install --system gymnasium==${{ matrix.gymnasium-version }} + uv pip install --system "numpy<2" + # Only run for python 3.10, downgrade gym to 0.29.1, numpy<2 + if: matrix.gymnasium-version != '1.0.0' - name: Lint with ruff run: | make lint @@ -55,8 +63,6 @@ jobs: - name: Type check run: | make type - # Do not run for python 3.8 (mypy internal error) - if: matrix.python-version != '3.8' - name: Test with pytest run: | make pytest diff --git a/.readthedocs.yml b/.readthedocs.yml index dbb2fad03..26f0c883b 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -16,6 +16,6 @@ conda: environment: docs/conda_env.yml build: - os: ubuntu-22.04 + os: ubuntu-24.04 tools: - python: "mambaforge-22.9" + python: "mambaforge-23.11" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d295269a9..cc5d1075b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,7 +6,7 @@ into two categories: - Create an issue about your intended feature, and we shall discuss the design and implementation. Once we agree that the plan looks good, go ahead and implement it. 2. You want to implement a feature or bug-fix for an outstanding issue - - Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/issues + - Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted - Pick an issue or feature and comment on the task that you want to work on this feature. - If you need more context on a particular issue, please ask, and we shall provide. 
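The matrix above runs the default Gymnasium 1.0.0 everywhere and adds one Python 3.10 job pinned to Gymnasium 0.29.1 with `numpy<2`. As a minimal illustrative sketch (not part of the workflow itself; it assumes the third-party `packaging` helper is installed), test code can branch on whichever version a given job resolved:

```python
# Hedged sketch: detect at runtime which Gymnasium generation is installed,
# mirroring the two configurations exercised by the CI matrix above.
import gymnasium
from packaging.version import Version  # assumption: packaging is available

GYM_V1 = Version(gymnasium.__version__) >= Version("1.0.0")
print(f"Gymnasium {gymnasium.__version__} detected, v1.0 API: {GYM_V1}")
```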
diff --git a/README.md b/README.md index 52634e486..9ae78b239 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg) -[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) +[![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml) +[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml) [![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) @@ -22,6 +22,8 @@ These algorithms will make it easier for the research community and industry to **The performance of each algorithm was tested** (see *Results* section in their respective page), you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details. +We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform. + | **Features** | **Stable-Baselines3** | | --------------------------- | ----------------------| @@ -41,7 +43,13 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin ### Planned features -Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones). +Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3: it is now *stable*. +If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement). + +While SB3 development is now focused on bug fixes and maintenance (doc updates, user experience, ...), there is more active development going on in the associated repositories: +- newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository +- faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository +- the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299) ## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3) @@ -79,7 +87,7 @@ Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/ We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) -This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
+This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO). Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/) @@ -92,22 +100,21 @@ It provides a minimal number of features compared to SB3 but can be much faster ## Installation -**Note:** Stable-Baselines3 supports PyTorch >= 1.13 +**Note:** Stable-Baselines3 supports PyTorch >= 2.3 ### Prerequisites -Stable Baselines3 requires Python 3.8+. +Stable Baselines3 requires Python 3.9+. -#### Windows 10 +#### Windows To install stable-baselines on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites). ### Install using pip Install the Stable Baselines3 package: +```sh +pip install 'stable-baselines3[extra]' ``` -pip install stable-baselines3[extra] -``` -**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)). This includes optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use: ```sh @@ -177,6 +184,7 @@ All the following examples can be executed online using Google Colab notebooks: | ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- | | ARS[1](#f1) | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| CrossQ[1](#f1) | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | @@ -191,7 +199,7 @@ All the following examples can be executed online using Google Colab notebooks: 1: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository. -Actions `gym.spaces`: +Actions `gymnasium.spaces`: * `Box`: An N-dimensional box that contains every point in the action space. * `Discrete`: A list of possible actions, where each timestep only one of the actions can be used. * `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used. @@ -218,9 +226,9 @@ To run a single test: python3 -m pytest -v -k 'test_check_env_dict_action' ``` -You can also do a static type check using `pytype` and `mypy`: +You can also do a static type check using `mypy`: ```sh -pip install pytype mypy +pip install mypy make type ``` @@ -252,6 +260,8 @@ To cite this repository in publications: } ``` +Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988).
+ ## Maintainers Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec). diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 53fecf278..ee491017b 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -1,19 +1,19 @@ name: root channels: - pytorch - - defaults + - conda-forge dependencies: - cpuonly=1.0=0 - - pip=22.3.1 - - python=3.8 - - pytorch=1.13.0=py3.8_cpu_0 + - pip=24.2 + - python=3.11 + - pytorch=2.5.0=py3.11_cpu_0 - pip: - - gymnasium + - gymnasium>=0.29.1,<1.1.0 - cloudpickle - opencv-python-headless - pandas - - numpy + - numpy>=1.20,<3.0 - matplotlib - - sphinx>=5,<8 + - sphinx>=5,<9 - sphinx_rtd_theme>=1.3.0 - sphinx_copybutton diff --git a/docs/conf.py b/docs/conf.py index bd6365701..7e0555e57 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,7 +14,6 @@ import datetime import os import sys -from typing import Dict # We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support # PyEnchant. @@ -151,7 +150,7 @@ def setup(app): # -- Options for LaTeX output ------------------------------------------------ -latex_elements: Dict[str, str] = { +latex_elements: dict[str, str] = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index d5e7ae1d2..db03ba292 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` =================== =========== ============ ================= =============== ================ ARS [#f1]_ ✔️ ✔️ ❌ ❌ ✔️ A2C ✔️ ✔️ ✔️ ✔️ ✔️ +CrossQ [#f1]_ ✔️ ❌ ❌ ❌ ✔️ DDPG ✔️ ❌ ❌ ❌ ✔️ DQN ❌ ✔️ ❌ ❌ ✔️ HER ✔️ ✔️ ❌ ❌ ✔️ diff --git a/docs/guide/install.rst b/docs/guide/install.rst index 587234b00..4bdd0a007 100644 --- a/docs/guide/install.rst +++ b/docs/guide/install.rst @@ -7,7 +7,7 @@ Installation Prerequisites ------------- -Stable-Baselines3 requires python 3.8+ and PyTorch >= 1.13 +Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3 Windows ~~~~~~~ diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst index 445832c59..8ec912e15 100644 --- a/docs/guide/sb3_contrib.rst +++ b/docs/guide/sb3_contrib.rst @@ -42,6 +42,7 @@ See documentation for the full list of included features. - `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) `_ - `Truncated Quantile Critics (TQC)`_ - `Trust Region Policy Optimization (TRPO) `_ +- `Batch Normalization in Deep Reinforcement Learning (CrossQ) `_ **Gym Wrappers**: diff --git a/docs/index.rst b/docs/index.rst index c8a70a94b..6b6018b42 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -20,6 +20,8 @@ RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for train SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib +SBX (SB3 + Jax): https://github.com/araffin/sbx + Main Features -------------- @@ -113,12 +115,14 @@ To cite this project in publications: url = {http://jmlr.org/papers/v22/20-1364.html} } +Note: If you need to refer to a specific version of SB3, you can also use the `Zenodo DOI `_. 
+ Contributing ------------ To any interested in making the rl baselines better, there are still some improvements that need to be done. -You can check issues in the `repo <https://github.com/DLR-RM/stable-baselines3/issues>`_. +You can check issues in the `repository <https://github.com/DLR-RM/stable-baselines3/issues>`_. If you want to contribute, please read `CONTRIBUTING.md <https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md>`_ first. diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e8a2984d2..937853979 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,9 +3,46 @@ Changelog ========== -Release 2.4.0a9 (WIP) +Release 2.5.0a0 (WIP) -------------------------- +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Increased minimum required version of PyTorch to 2.3.0 +- Removed support for Python 3.8 + +New Features: +^^^^^^^^^^^^^ +- Added support for NumPy v2.0: ``VecNormalize`` now casts normalized rewards to float32; the bit flipping env was also updated to avoid overflow issues +- Added official support for Python 3.12 + +Bug Fixes: +^^^^^^^^^^ + +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ + +`SBX`_ (SB3 + Jax) +^^^^^^^^^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ + + +Release 2.4.0 (2024-11-18) +-------------------------- + +**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support** + .. note:: DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about @@ -16,18 +53,20 @@ Release 2.4.0a9 (WIP) .. warning:: Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024) - and PyTorch < 2.0. - We highly recommended you to upgrade to Python >= 3.9 and PyTorch >= 2.0. + and PyTorch < 2.3. + We highly recommend that you upgrade to Python >= 3.9 and PyTorch >= 2.3 (compatible with NumPy v2). Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Increased minimum required version of Gymnasium to 0.29.1 New Features: ^^^^^^^^^^^^^ - Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ) - Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle) - Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces +- Added support for Gymnasium v1.0 Bug Fixes: ^^^^^^^^^^ @@ -43,6 +82,10 @@ Bug Fixes: `SB3-Contrib`_ ^^^^^^^^^^^^^^ +- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen) +- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen) +- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger) +- Fixed a bug where loading QRDQN changed `target_update_interval` (@jak3122) `RL Zoo`_ ^^^^^^^^^ @@ -51,6 +94,7 @@ Bug Fixes: `SBX`_ (SB3 + Jax) ^^^^^^^^^^^^^^^^^^ - Added CNN support for DQN +- Bug fix for SAC and related algorithms: optimize the log of the entropy coefficient to be consistent with SB3 Deprecations: ^^^^^^^^^^^^^ @@ -60,12 +104,16 @@ Others: - Fixed various typos (@cschindlbeck) - Remove unnecessary SDE noise resampling in PPO update (@brn-dev) - Updated PyTorch version on CI to 2.3.1 - -Bug Fixes: -^^^^^^^^^^ +- Added a warning to recommend using CPU with on-policy algorithms (A2C/PPO) and ``MlpPolicy`` +- Switched to uv to download packages faster on GitHub CI +- Updated dependencies for Read the Docs +- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, removed the use of ordered dict and renamed ``flatten_obs`` to ``stack_obs`` Documentation: ^^^^^^^^^^^^^^ +- Updated PPO doc to recommend using CPU with ``MlpPolicy`` +- Clarified documentation about 
planned features and citing software +- Added a note about the fact we are optimizing log of ent coeff for SAC Release 2.3.2 (2024-04-27) -------------------------- @@ -653,6 +701,7 @@ New Features: - Added checkpoints for replay buffer and ``VecNormalize`` statistics (@anand-bala) - Added option for ``Monitor`` to append to existing file instead of overriding (@sidney-tio) - The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys +- Use MacOS Metal "mps" device when available `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -710,6 +759,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Save cloudpickle version `SB3-Contrib`_ diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst index 85d486661..78f70f698 100644 --- a/docs/modules/dqn.rst +++ b/docs/modules/dqn.rst @@ -25,6 +25,7 @@ Notes - Original paper: https://arxiv.org/abs/1312.5602 - Further reference: https://www.nature.com/articles/nature14236 +- Tutorial "From Tabular Q-Learning to DQN": https://github.com/araffin/rlss23-dqn-tutorial .. note:: This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay. diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index b5e667241..4285cfb50 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -88,6 +88,23 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments. vec_env.render("human") +.. note:: + + PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``: + + .. code-block:: + + from stable_baselines3 import PPO + from stable_baselines3.common.env_util import make_vec_env + from stable_baselines3.common.vec_env import SubprocVecEnv + + if __name__=="__main__": + env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv) + model = PPO("MlpPolicy", env, device="cpu") + model.learn(total_timesteps=25_000) + + For more information, see :ref:`Vectorized Environments `, `Issue #1245 `_ or the `Multiprocessing notebook `_. + Results ------- diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst index 960a282dc..cf6191bc4 100644 --- a/docs/modules/sac.rst +++ b/docs/modules/sac.rst @@ -35,6 +35,9 @@ Notes which is the equivalent to the inverse of reward scale in the original SAC paper. The main reason is that it avoids having too high errors when updating the Q functions. +.. note:: + When automatically adjusting the temperature (alpha/entropy coefficient), we optimize the logarithm of the entropy coefficient instead of the entropy coefficient itself. This is consistent with the original implementation and has proven to be more stable + (see issues `GH#36 `_, `#55 `_ and others). .. note:: diff --git a/pyproject.toml b/pyproject.toml index dd435a33e..89af5a67f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,8 @@ [tool.ruff] # Same as Black. line-length = 127 -# Assume Python 3.8 -target-version = "py38" +# Assume Python 3.9 +target-version = "py39" [tool.ruff.lint] # See https://beta.ruff.rs/docs/rules/ @@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"] # ClassVar, implicit optional check not needed for tests "./tests/*.py" = ["RUF012", "RUF013"] - [tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. 
max-complexity = 15 diff --git a/setup.py b/setup.py index 9d56dfd77..fa24fc8a3 100644 --- a/setup.py +++ b/setup.py @@ -70,39 +70,15 @@ """ # noqa:E501 -# Atari Games download is sometimes problematic: -# https://github.com/Farama-Foundation/AutoROM/issues/39 -# That's why we define extra packages without it. -extra_no_roms = [ - # For render - "opencv-python", - "pygame", - # Tensorboard support - "tensorboard>=2.9.1", - # Checking memory taken by replay buffer - "psutil", - # For progress bar callback - "tqdm", - "rich", - # For atari games, - "shimmy[atari]~=1.3.0", - "pillow", -] - -extra_packages = extra_no_roms + [ # noqa: RUF005 - # For atari roms, - "autorom[accept-rom-license]~=0.6.1", -] - setup( name="stable_baselines3", packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium>=0.28.1,<0.30", - "numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302 - "torch>=1.13", + "gymnasium>=0.29.1,<1.1.0", + "numpy>=1.20,<3.0", + "torch>=2.3,<3.0", # For saving models "cloudpickle", # For reading logs @@ -125,7 +101,7 @@ "black>=24.2.0,<25", ], "docs": [ - "sphinx>=5,<8", + "sphinx>=5,<9", "sphinx-autobuild", "sphinx-rtd-theme>=1.3.0", # For spelling @@ -133,8 +109,21 @@ # Copy button for code snippets "sphinx_copybutton", ], - "extra": extra_packages, - "extra_no_roms": extra_no_roms, + "extra": [ + # For render + "opencv-python", + "pygame", + # Tensorboard support + "tensorboard>=2.9.1", + # Checking memory taken by replay buffer + "psutil", + # For progress bar callback + "tqdm", + "rich", + # For atari games, + "ale-py>=0.9.0", + "pillow", + ], }, description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.", author="Antonin Raffin", @@ -146,7 +135,7 @@ long_description=long_description, long_description_content_type="text/markdown", version=__version__, - python_requires=">=3.8", + python_requires=">=3.9", # PyPI package information. 
project_urls={ "Code": "https://github.com/DLR-RM/stable-baselines3", @@ -158,10 +147,10 @@ }, classifiers=[ "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ], ) diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 718571f0c..a125aaef6 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Optional, TypeVar, Union import torch as th from gymnasium import spaces @@ -57,7 +57,7 @@ class A2C(OnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ - policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = { + policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { "MlpPolicy": ActorCriticPolicy, "CnnPolicy": ActorCriticCnnPolicy, "MultiInputPolicy": MultiInputActorCriticPolicy, @@ -65,7 +65,7 @@ class A2C(OnPolicyAlgorithm): def __init__( self, - policy: Union[str, Type[ActorCriticPolicy]], + policy: Union[str, type[ActorCriticPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 7e-4, n_steps: int = 5, @@ -78,12 +78,12 @@ def __init__( use_rms_prop: bool = True, use_sde: bool = False, sde_sample_freq: int = -1, - rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, - rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, + rollout_buffer_class: Optional[type[RolloutBuffer]] = None, + rollout_buffer_kwargs: Optional[dict[str, Any]] = None, normalize_advantage: bool = False, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: Union[th.device, str] = "auto", diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index bbdba9a3d..83a64a5c7 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -1,4 +1,4 @@ -from typing import Dict, SupportsFloat +from typing import SupportsFloat import gymnasium as gym import numpy as np @@ -64,7 +64,7 @@ def reset(self, **kwargs) -> AtariResetReturn: noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) assert noops > 0 obs = np.zeros(0) - info: Dict = {} + info: dict = {} for _ in range(noops): obs, _, terminated, truncated, info = self.env.step(self.noop_action) if terminated or truncated: diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index e43955f94..412f9dda2 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -6,7 +6,8 @@ import warnings from abc import ABC, abstractmethod from collections import deque -from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union +from collections.abc import Iterable +from typing import Any, ClassVar, Optional, TypeVar, Union import gymnasium as gym import numpy as np @@ -94,7 +95,7 @@ class BaseAlgorithm(ABC): """ # Policy aliases (see _get_policy_from_name()) - policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {} + policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {} policy: BasePolicy observation_space: spaces.Space action_space: spaces.Space @@ 
-104,10 +105,10 @@ class BaseAlgorithm(ABC): def __init__( self, - policy: Union[str, Type[BasePolicy]], + policy: Union[str, type[BasePolicy]], env: Union[GymEnv, str, None], learning_rate: Union[float, Schedule], - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, verbose: int = 0, @@ -117,7 +118,7 @@ def __init__( seed: Optional[int] = None, use_sde: bool = False, sde_sample_freq: int = -1, - supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None, + supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None, ) -> None: if isinstance(policy, str): self.policy_class = self._get_policy_from_name(policy) @@ -141,10 +142,10 @@ def __init__( self.start_time = 0.0 self.learning_rate = learning_rate self.tensorboard_log = tensorboard_log - self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]] + self._last_obs = None # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]] self._last_episode_starts = None # type: Optional[np.ndarray] # When using VecNormalize: - self._last_original_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]] + self._last_original_obs = None # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]] self._episode_num = 0 # Used for gSDE only self.use_sde = use_sde @@ -283,7 +284,7 @@ def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps """ self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps) - def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None: + def _update_learning_rate(self, optimizers: Union[list[th.optim.Optimizer], th.optim.Optimizer]) -> None: """ Update the optimizers learning rate using the current learning rate schedule and the current progress remaining (from 1 to 0). @@ -299,7 +300,7 @@ def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.o for optimizer in optimizers: update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining)) - def _excluded_save_params(self) -> List[str]: + def _excluded_save_params(self) -> list[str]: """ Returns the names of the parameters that should be excluded from being saved by pickling. E.g. replay buffers are skipped by default @@ -320,7 +321,7 @@ def _excluded_save_params(self) -> List[str]: "_custom_logger", ] - def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]: + def _get_policy_from_name(self, policy_name: str) -> type[BasePolicy]: """ Get a policy class from its name representation. @@ -337,7 +338,7 @@ def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]: else: raise ValueError(f"Policy {policy_name} unknown") - def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + def _get_torch_save_params(self) -> tuple[list[str], list[str]]: """ Get the name of the torch variables that will be saved with PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default @@ -387,7 +388,7 @@ def _setup_learn( reset_num_timesteps: bool = True, tb_log_name: str = "run", progress_bar: bool = False, - ) -> Tuple[int, BaseCallback]: + ) -> tuple[int, BaseCallback]: """ Initialize different variables needed for training. 
@@ -435,7 +436,7 @@ def _setup_learn( return total_timesteps, callback - def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None: + def _update_info_buffer(self, infos: list[dict[str, Any]], dones: Optional[np.ndarray] = None) -> None: """ Retrieve reward, episode length, episode success and update the buffer if using Monitor wrapper or a GoalEnv. @@ -535,11 +536,11 @@ def learn( def predict( self, - observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[Tuple[np.ndarray, ...]] = None, + observation: Union[np.ndarray, dict[str, np.ndarray]], + state: Optional[tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]: """ Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). @@ -640,11 +641,11 @@ def set_parameters( @classmethod def load( # noqa: C901 - cls: Type[SelfBaseAlgorithm], + cls: type[SelfBaseAlgorithm], path: Union[str, pathlib.Path, io.BufferedIOBase], env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", - custom_objects: Optional[Dict[str, Any]] = None, + custom_objects: Optional[dict[str, Any]] = None, print_system_info: bool = False, force_reset: bool = True, **kwargs, @@ -800,7 +801,7 @@ def load( # noqa: C901 model.policy.reset_noise() # type: ignore[operator] return model - def get_parameters(self) -> Dict[str, Dict]: + def get_parameters(self) -> dict[str, dict]: """ Return the parameters of the agent. This includes parameters from different networks, e.g. critics (value functions) and policies (pi functions). diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index b2fc5a710..d0eb0856c 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -1,6 +1,7 @@ import warnings from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from collections.abc import Generator +from typing import Any, Optional, Union import numpy as np import torch as th @@ -36,7 +37,7 @@ class BaseBuffer(ABC): """ observation_space: spaces.Space - obs_shape: Tuple[int, ...] + obs_shape: tuple[int, ...] 
def __init__( self, @@ -135,14 +136,16 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor: :return: """ if copy: + if hasattr(th, "backends") and th.backends.mps.is_built(): + return th.tensor(array, dtype=th.float32, device=self.device) return th.tensor(array, device=self.device) return th.as_tensor(array, device=self.device) @staticmethod def _normalize_obs( - obs: Union[np.ndarray, Dict[str, np.ndarray]], + obs: Union[np.ndarray, dict[str, np.ndarray]], env: Optional[VecNormalize] = None, - ) -> Union[np.ndarray, Dict[str, np.ndarray]]: + ) -> Union[np.ndarray, dict[str, np.ndarray]]: if env is not None: return env.normalize_obs(obs) return obs @@ -250,7 +253,7 @@ def add( action: np.ndarray, reward: np.ndarray, done: np.ndarray, - infos: List[Dict[str, Any]], + infos: list[dict[str, Any]], ) -> None: # Reshape needed when using multiple envs with discrete observations # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) @@ -538,9 +541,9 @@ class DictReplayBuffer(ReplayBuffer): """ observation_space: spaces.Dict - obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment] - observations: Dict[str, np.ndarray] # type: ignore[assignment] - next_observations: Dict[str, np.ndarray] # type: ignore[assignment] + obs_shape: dict[str, tuple[int, ...]] # type: ignore[assignment] + observations: dict[str, np.ndarray] # type: ignore[assignment] + next_observations: dict[str, np.ndarray] # type: ignore[assignment] def __init__( self, @@ -609,12 +612,12 @@ def __init__( def add( # type: ignore[override] self, - obs: Dict[str, np.ndarray], - next_obs: Dict[str, np.ndarray], + obs: dict[str, np.ndarray], + next_obs: dict[str, np.ndarray], action: np.ndarray, reward: np.ndarray, done: np.ndarray, - infos: List[Dict[str, Any]], + infos: list[dict[str, Any]], ) -> None: # Copy to avoid modification by reference for key in self.observations.keys(): @@ -718,8 +721,8 @@ class DictRolloutBuffer(RolloutBuffer): """ observation_space: spaces.Dict - obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment] - observations: Dict[str, np.ndarray] # type: ignore[assignment] + obs_shape: dict[str, tuple[int, ...]] # type: ignore[assignment] + observations: dict[str, np.ndarray] # type: ignore[assignment] def __init__( self, @@ -757,7 +760,7 @@ def reset(self) -> None: def add( # type: ignore[override] self, - obs: Dict[str, np.ndarray], + obs: dict[str, np.ndarray], action: np.ndarray, reward: np.ndarray, episode_start: np.ndarray, diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index c7841866b..31c3a24a7 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -1,7 +1,7 @@ import os import warnings from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import gymnasium as gym import numpy as np @@ -45,8 +45,8 @@ def __init__(self, verbose: int = 0): # n_envs * n times env.step() was called self.num_timesteps = 0 # type: int self.verbose = verbose - self.locals: Dict[str, Any] = {} - self.globals: Dict[str, Any] = {} + self.locals: dict[str, Any] = {} + self.globals: dict[str, Any] = {} # Sometimes, for event callback, it is useful # to have access to the parent object self.parent = None # type: Optional[BaseCallback] @@ -75,7 +75,7 @@ def init_callback(self, model: "base_class.BaseAlgorithm") -> None: def _init_callback(self) -> None: pass - def 
on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None: + def on_training_start(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None: # Those are reference and will be updated automatically self.locals = locals_ self.globals = globals_ @@ -125,7 +125,7 @@ def on_rollout_end(self) -> None: def _on_rollout_end(self) -> None: pass - def update_locals(self, locals_: Dict[str, Any]) -> None: + def update_locals(self, locals_: dict[str, Any]) -> None: """ Update the references to the local variables. @@ -134,7 +134,7 @@ def update_locals(self, locals_: Dict[str, Any]) -> None: self.locals.update(locals_) self.update_child_locals(locals_) - def update_child_locals(self, locals_: Dict[str, Any]) -> None: + def update_child_locals(self, locals_: dict[str, Any]) -> None: """ Update the references to the local variables on sub callbacks. @@ -177,7 +177,7 @@ def _on_event(self) -> bool: def _on_step(self) -> bool: return True - def update_child_locals(self, locals_: Dict[str, Any]) -> None: + def update_child_locals(self, locals_: dict[str, Any]) -> None: """ Update the references to the local variables. @@ -195,7 +195,7 @@ class CallbackList(BaseCallback): sequentially. """ - def __init__(self, callbacks: List[BaseCallback]): + def __init__(self, callbacks: list[BaseCallback]): super().__init__() assert isinstance(callbacks, list) self.callbacks = callbacks @@ -231,7 +231,7 @@ def _on_training_end(self) -> None: for callback in self.callbacks: callback.on_training_end() - def update_child_locals(self, locals_: Dict[str, Any]) -> None: + def update_child_locals(self, locals_: dict[str, Any]) -> None: """ Update the references to the local variables. @@ -328,7 +328,7 @@ class ConvertCallback(BaseCallback): :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ - def __init__(self, callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]], verbose: int = 0): + def __init__(self, callback: Optional[Callable[[dict[str, Any], dict[str, Any]], bool]], verbose: int = 0): super().__init__(verbose) self.callback = callback @@ -405,12 +405,12 @@ def __init__( if log_path is not None: log_path = os.path.join(log_path, "evaluations") self.log_path = log_path - self.evaluations_results: List[List[float]] = [] - self.evaluations_timesteps: List[int] = [] - self.evaluations_length: List[List[int]] = [] + self.evaluations_results: list[list[float]] = [] + self.evaluations_timesteps: list[int] = [] + self.evaluations_length: list[list[int]] = [] # For computing success rate - self._is_success_buffer: List[bool] = [] - self.evaluations_successes: List[List[bool]] = [] + self._is_success_buffer: list[bool] = [] + self.evaluations_successes: list[list[bool]] = [] def _init_callback(self) -> None: # Does not work in some corner cases, where the wrapper is not the same @@ -427,7 +427,7 @@ def _init_callback(self) -> None: if self.callback_on_new_best is not None: self.callback_on_new_best.init_callback(self.model) - def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None: + def _log_success_callback(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None: """ Callback passed to the ``evaluate_policy`` function in order to log the success rate (when applicable), @@ -530,7 +530,7 @@ def _on_step(self) -> bool: return continue_training - def update_child_locals(self, locals_: Dict[str, Any]) -> None: + def update_child_locals(self, locals_: dict[str, Any]) -> None: """ Update the references 
to the local variables. diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 132a35348..380898a50 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,7 +1,7 @@ """Probability distributions.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union +from typing import Any, Optional, TypeVar, Union import numpy as np import torch as th @@ -30,7 +30,7 @@ def __init__(self): self.distribution = None @abstractmethod - def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]: + def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, tuple[nn.Module, nn.Parameter]]: """Create the layers and parameters that represent the distribution. Subclasses must define this, but the arguments and return type vary between @@ -98,7 +98,7 @@ def actions_from_params(self, *args, **kwargs) -> th.Tensor: """ @abstractmethod - def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]: + def log_prob_from_params(self, *args, **kwargs) -> tuple[th.Tensor, th.Tensor]: """ Returns samples and the associated log probabilities from the probability distribution given its parameters. @@ -135,7 +135,7 @@ def __init__(self, action_dim: int): self.mean_actions = None self.log_std = None - def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]: + def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> tuple[nn.Module, nn.Parameter]: """ Create the layers and parameter that represent the distribution: one output will be the mean of the Gaussian, the other parameter will be the @@ -190,7 +190,7 @@ def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deter self.proba_distribution(mean_actions, log_std) return self.get_actions(deterministic=deterministic) - def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> tuple[th.Tensor, th.Tensor]: """ Compute the log probability of taking an action given the distribution parameters. 
@@ -254,7 +254,7 @@ def mode(self) -> th.Tensor: # Squash the output return th.tanh(self.gaussian_actions) - def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> tuple[th.Tensor, th.Tensor]: action = self.actions_from_params(mean_actions, log_std) log_prob = self.log_prob(action, self.gaussian_actions) return action, log_prob @@ -305,7 +305,7 @@ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = Fa self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) - def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]: actions = self.actions_from_params(action_logits) log_prob = self.log_prob(actions) return actions, log_prob @@ -318,7 +318,7 @@ class MultiCategoricalDistribution(Distribution): :param action_dims: List of sizes of discrete action spaces """ - def __init__(self, action_dims: List[int]): + def __init__(self, action_dims: list[int]): super().__init__() self.action_dims = action_dims @@ -362,7 +362,7 @@ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = Fa self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) - def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]: actions = self.actions_from_params(action_logits) log_prob = self.log_prob(actions) return actions, log_prob @@ -412,7 +412,7 @@ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = Fa self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) - def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]: actions = self.actions_from_params(action_logits) log_prob = self.log_prob(actions) return actions, log_prob @@ -513,7 +513,7 @@ def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None: def proba_distribution_net( self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None - ) -> Tuple[nn.Module, nn.Parameter]: + ) -> tuple[nn.Module, nn.Parameter]: """ Create the layers and parameter that represent the distribution: one output will be the deterministic action, the other parameter will be the @@ -611,7 +611,7 @@ def actions_from_params( def log_prob_from_params( self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor - ) -> Tuple[th.Tensor, th.Tensor]: + ) -> tuple[th.Tensor, th.Tensor]: actions = self.actions_from_params(mean_actions, log_std, latent_sde) log_prob = self.log_prob(actions) return actions, log_prob @@ -661,7 +661,7 @@ def log_prob_correction(self, x: th.Tensor) -> th.Tensor: def make_proba_distribution( - action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None + action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[dict[str, Any]] = None ) -> Distribution: """ Return an instance of Distribution for the correct type of action space diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index e47dd123a..0310bcfe7 100644 --- a/stable_baselines3/common/env_checker.py +++ 
b/stable_baselines3/common/env_checker.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, Union +from typing import Any, Union import gymnasium as gym import numpy as np @@ -172,10 +172,10 @@ def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name: def _check_goal_env_compute_reward( - obs: Dict[str, Union[np.ndarray, int]], + obs: dict[str, Union[np.ndarray, int]], env: gym.Env, reward: float, - info: Dict[str, Any], + info: dict[str, Any], ) -> None: """ Check that reward is computed with `compute_reward` diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index 0132c32f8..bbe281f27 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import Any, Callable, Optional, Union import gymnasium as gym @@ -9,7 +9,7 @@ from stable_baselines3.common.vec_env.patch_gym import _patch_env -def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]: +def unwrap_wrapper(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> Optional[gym.Wrapper]: """ Retrieve a ``VecEnvWrapper`` object by recursively searching. @@ -25,7 +25,7 @@ def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[g return None -def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool: +def is_wrapped(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> bool: """ Check if a given environment has been wrapped with a given wrapper. @@ -43,11 +43,11 @@ def make_vec_env( start_index: int = 0, monitor_dir: Optional[str] = None, wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None, - env_kwargs: Optional[Dict[str, Any]] = None, - vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None, - vec_env_kwargs: Optional[Dict[str, Any]] = None, - monitor_kwargs: Optional[Dict[str, Any]] = None, - wrapper_kwargs: Optional[Dict[str, Any]] = None, + env_kwargs: Optional[dict[str, Any]] = None, + vec_env_cls: Optional[type[Union[DummyVecEnv, SubprocVecEnv]]] = None, + vec_env_kwargs: Optional[dict[str, Any]] = None, + monitor_kwargs: Optional[dict[str, Any]] = None, + wrapper_kwargs: Optional[dict[str, Any]] = None, ) -> VecEnv: """ Create a wrapped, monitored ``VecEnv``. @@ -134,11 +134,11 @@ def make_atari_env( seed: Optional[int] = None, start_index: int = 0, monitor_dir: Optional[str] = None, - wrapper_kwargs: Optional[Dict[str, Any]] = None, - env_kwargs: Optional[Dict[str, Any]] = None, - vec_env_cls: Optional[Union[Type[DummyVecEnv], Type[SubprocVecEnv]]] = None, - vec_env_kwargs: Optional[Dict[str, Any]] = None, - monitor_kwargs: Optional[Dict[str, Any]] = None, + wrapper_kwargs: Optional[dict[str, Any]] = None, + env_kwargs: Optional[dict[str, Any]] = None, + vec_env_cls: Optional[Union[type[DummyVecEnv], type[SubprocVecEnv]]] = None, + vec_env_kwargs: Optional[dict[str, Any]] = None, + monitor_kwargs: Optional[dict[str, Any]] = None, ) -> VecEnv: """ Create a wrapped, monitored VecEnv for Atari. 
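The `make_vec_env` keyword arguments retyped above (`env_kwargs`, `vec_env_cls`, `monitor_kwargs`) accept the same plain dicts and classes as before; only the annotations move to built-in generics. A hedged usage sketch (the `g` keyword assumes Gymnasium's `Pendulum-v1` accepts it):

```python
# Illustrative only: calling make_vec_env with the arguments whose
# type hints are modernized in the diff above.
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

if __name__ == "__main__":
    vec_env = make_vec_env(
        "Pendulum-v1",
        n_envs=4,
        vec_env_cls=SubprocVecEnv,  # type[...] instead of Type[...]
        env_kwargs={"g": 9.81},  # dict[str, Any], forwarded to gym.make()
        monitor_kwargs={"allow_early_resets": True},  # dict[str, Any]
    )
    print(vec_env.reset().shape)  # (4, 3): one observation per env
```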
diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index 3ea0c7bb0..d0d48df0c 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Union import numpy as np from gymnasium import Env, spaces @@ -75,14 +75,17 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]: :param state: :return: """ + if self.discrete_obs_space: + # Convert from int8 to int32 for NumPy 2.0 + state = state.astype(np.int32) # The internal state is the binary representation of the # observed one - return int(sum(state[i] * 2**i for i in range(len(state)))) + return int(sum(int(state[i]) * 2**i for i in range(len(state)))) if self.image_obs_space: size = np.prod(self.image_shape) - image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8))) + image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8))) return image.reshape(self.image_shape).astype(np.uint8) return state @@ -163,7 +166,7 @@ def _make_observation_space(self, discrete_obs_space: bool, image_obs_space: boo } ) - def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: + def _get_obs(self) -> dict[str, Union[int, np.ndarray]]: """ Helper to create the observation. @@ -178,8 +181,8 @@ def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: ) def reset( - self, *, seed: Optional[int] = None, options: Optional[Dict] = None - ) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]: + self, *, seed: Optional[int] = None, options: Optional[dict] = None + ) -> tuple[dict[str, Union[int, np.ndarray]], dict]: if seed is not None: self._obs_space.seed(seed) self.current_step = 0 @@ -207,7 +210,7 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: return obs, reward, terminated, truncated, info def compute_reward( - self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]] + self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[dict[str, Any]] ) -> np.float32: # As we are using a vectorized version, we need to keep track of the `batch_size` if isinstance(achieved_goal, int): diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index 99a664999..0c5610446 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union +from typing import Any, Generic, Optional, TypeVar, Union import gymnasium as gym import numpy as np @@ -34,7 +34,7 @@ def __init__(self, dim: Optional[int] = None, space: Optional[spaces.Space] = No self.num_resets = -1 # Becomes 0 after __init__ exits. 
self.reset() - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[T, Dict]: + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[T, dict]: if seed is not None: super().reset(seed=seed) self.current_step = 0 @@ -42,7 +42,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) - self._choose_next_state() return self.state, {} - def step(self, action: T) -> Tuple[T, float, bool, bool, Dict[str, Any]]: + def step(self, action: T) -> tuple[T, float, bool, bool, dict[str, Any]]: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 @@ -74,7 +74,7 @@ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_l super().__init__(ep_length=ep_length, space=space) self.eps = eps - def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]: + def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 @@ -142,7 +142,7 @@ def __init__( self.ep_length = 10 self.current_step = 0 - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]: + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[np.ndarray, dict]: if seed is not None: super().reset(seed=seed) self.current_step = 0 diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index f34d13b7c..8749f82f3 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Union +from typing import Optional, Union import gymnasium as gym import numpy as np @@ -73,7 +73,7 @@ def __init__( self.init_possible_transitions() self.num_col = num_col - self.state_mapping: List[Dict[str, np.ndarray]] = [] + self.state_mapping: list[dict[str, np.ndarray]] = [] self.init_state_mapping(num_col, num_row) self.max_state = len(self.state_mapping) - 1 @@ -94,7 +94,7 @@ def init_state_mapping(self, num_col: int, num_row: int) -> None: for j in range(num_row): self.state_mapping.append({"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)}) - def get_state_mapping(self) -> Dict[str, np.ndarray]: + def get_state_mapping(self) -> dict[str, np.ndarray]: """ Uses the state to get the observation mapping. @@ -166,7 +166,7 @@ def render(self, mode: str = "human") -> None: """ print(self.log) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[Dict[str, np.ndarray], Dict]: + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[dict[str, np.ndarray], dict]: """ Resets the environment state and step count and returns reset observation. 
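The `reset`/`step` annotations above follow the Gymnasium API: `reset()` returns `(obs, info)` and `step()` returns `(obs, reward, terminated, truncated, info)`. A small sketch using SB3's own test env (assuming `IdentityEnv` is exported from `stable_baselines3.common.envs`, as in current releases):

```python
# Hedged sketch of the return shapes described by the annotations above:
# reset() -> tuple[obs, dict] and step() -> tuple[obs, float, bool, bool, dict].
from stable_baselines3.common.envs import IdentityEnv

env = IdentityEnv(dim=3, ep_length=10)
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(obs, reward, terminated, truncated, info)
```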
diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index c9253a899..e66448b51 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import gymnasium as gym import numpy as np @@ -14,11 +14,11 @@ def evaluate_policy( n_eval_episodes: int = 10, deterministic: bool = True, render: bool = False, - callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None, + callback: Optional[Callable[[dict[str, Any], dict[str, Any]], None]] = None, reward_threshold: Optional[float] = None, return_episode_rewards: bool = False, warn: bool = True, -) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]: +) -> Union[tuple[float, float], tuple[list[float], list[int]]]: """ Runs policy for ``n_eval_episodes`` episodes and returns average reward. If a vector env is passed in, this divides the episodes to evaluate onto the diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 8ceda71ed..8d707cba5 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -5,8 +5,9 @@ import tempfile import warnings from collections import defaultdict +from collections.abc import Mapping, Sequence from io import TextIOBase -from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union +from typing import Any, Optional, TextIO, Union import matplotlib.figure import numpy as np @@ -114,7 +115,7 @@ class KVWriter: Key Value writer """ - def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None: + def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None: """ Write a dictionary to file @@ -136,7 +137,7 @@ class SeqWriter: sequence writer """ - def write_sequence(self, sequence: List[str]) -> None: + def write_sequence(self, sequence: list[str]) -> None: """ write_sequence an array to file @@ -172,7 +173,7 @@ def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36): else: raise ValueError(f"Expected file or str, got {filename_or_file}") - def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None: + def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None: # Create strings for printing key2str = {} tag = "" @@ -244,7 +245,7 @@ def _truncate(self, string: str) -> str: string = string[: self.max_length - 3] + "..." 
return string - def write_sequence(self, sequence: List[str]) -> None: + def write_sequence(self, sequence: list[str]) -> None: for i, elem in enumerate(sequence): self.file.write(elem) if i < len(sequence) - 1: # add space unless this is the last one @@ -260,7 +261,7 @@ def close(self) -> None: self.file.close() -def filter_excluded_keys(key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], _format: str) -> Dict[str, Any]: +def filter_excluded_keys(key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], _format: str) -> dict[str, Any]: """ Filters the keys specified by ``key_exclude`` for the specified format @@ -286,7 +287,7 @@ class JSONOutputFormat(KVWriter): def __init__(self, filename: str): self.file = open(filename, "w") - def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None: + def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None: def cast_to_json_serializable(value: Any): if isinstance(value, Video): raise FormatUnsupportedError(["json"], "video") @@ -328,12 +329,12 @@ class CSVOutputFormat(KVWriter): """ def __init__(self, filename: str): - self.file = open(filename, "w+t") - self.keys: List[str] = [] + self.file = open(filename, "w+") + self.keys: list[str] = [] self.separator = "," self.quotechar = '"' - def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None: + def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None: # Add our current row to the history key_values = filter_excluded_keys(key_values, key_excluded, "csv") extra_keys = key_values.keys() - self.keys @@ -399,7 +400,7 @@ def __init__(self, folder: str): self.writer = SummaryWriter(log_dir=folder) self._is_closed = False - def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None: + def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None: assert not self._is_closed, "The SummaryWriter was closed, please re-create one." for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): if excluded is not None and "tensorboard" in excluded: @@ -481,16 +482,16 @@ class Logger: :param output_formats: the list of output formats """ - def __init__(self, folder: Optional[str], output_formats: List[KVWriter]): - self.name_to_value: Dict[str, float] = defaultdict(float) # values this iteration - self.name_to_count: Dict[str, int] = defaultdict(int) - self.name_to_excluded: Dict[str, Tuple[str, ...]] = {} + def __init__(self, folder: Optional[str], output_formats: list[KVWriter]): + self.name_to_value: dict[str, float] = defaultdict(float) # values this iteration + self.name_to_count: dict[str, int] = defaultdict(int) + self.name_to_excluded: dict[str, tuple[str, ...]] = {} self.level = INFO self.dir = folder self.output_formats = output_formats @staticmethod - def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[str, ...]: + def to_tuple(string_or_tuple: Optional[Union[str, tuple[str, ...]]]) -> tuple[str, ...]: """ Helper function to convert str to tuple of str. 
""" @@ -500,7 +501,7 @@ def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[st return string_or_tuple return (string_or_tuple,) - def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + def record(self, key: str, value: Any, exclude: Optional[Union[str, tuple[str, ...]]] = None) -> None: """ Log a value of some diagnostic Call this once for each diagnostic quantity, each iteration @@ -513,7 +514,7 @@ def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, . self.name_to_value[key] = value self.name_to_excluded[key] = self.to_tuple(exclude) - def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, tuple[str, ...]]] = None) -> None: """ The same as record(), but if called many times, values averaged. @@ -624,7 +625,7 @@ def close(self) -> None: # Misc # ---------------------------------------- - def _do_log(self, args: Tuple[Any, ...]) -> None: + def _do_log(self, args: tuple[Any, ...]) -> None: """ log to the requested format outputs @@ -635,7 +636,7 @@ def _do_log(self, args: Tuple[Any, ...]) -> None: _format.write_sequence(list(map(str, args))) -def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger: +def configure(folder: Optional[str] = None, format_strings: Optional[list[str]] = None) -> Logger: """ Configure the current logger. diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index fb8ce33c6..80dfd4668 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -5,7 +5,7 @@ import os import time from glob import glob -from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union +from typing import Any, Optional, SupportsFloat, Union import gymnasium as gym import pandas @@ -33,8 +33,8 @@ def __init__( env: gym.Env, filename: Optional[str] = None, allow_early_resets: bool = True, - reset_keywords: Tuple[str, ...] = (), - info_keywords: Tuple[str, ...] = (), + reset_keywords: tuple[str, ...] = (), + info_keywords: tuple[str, ...] = (), override_existing: bool = True, ): super().__init__(env=env) @@ -52,16 +52,16 @@ def __init__( self.reset_keywords = reset_keywords self.info_keywords = info_keywords self.allow_early_resets = allow_early_resets - self.rewards: List[float] = [] + self.rewards: list[float] = [] self.needs_reset = True - self.episode_returns: List[float] = [] - self.episode_lengths: List[int] = [] - self.episode_times: List[float] = [] + self.episode_returns: list[float] = [] + self.episode_lengths: list[int] = [] + self.episode_times: list[float] = [] self.total_steps = 0 # extra info about the current episode, that was passed in during reset() - self.current_reset_info: Dict[str, Any] = {} + self.current_reset_info: dict[str, Any] = {} - def reset(self, **kwargs) -> Tuple[ObsType, Dict[str, Any]]: + def reset(self, **kwargs) -> tuple[ObsType, dict[str, Any]]: """ Calls the Gym environment reset. 
Can only be called if the environment is over, or if allow_early_resets is True @@ -82,7 +82,7 @@ def reset(self, **kwargs) -> Tuple[ObsType, Dict[str, Any]]: self.current_reset_info[key] = value return self.env.reset(**kwargs) - def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """ Step the environment with the given action @@ -126,7 +126,7 @@ def get_total_steps(self) -> int: """ return self.total_steps - def get_episode_rewards(self) -> List[float]: + def get_episode_rewards(self) -> list[float]: """ Returns the rewards of all the episodes @@ -134,7 +134,7 @@ def get_episode_rewards(self) -> List[float]: """ return self.episode_returns - def get_episode_lengths(self) -> List[int]: + def get_episode_lengths(self) -> list[int]: """ Returns the number of timesteps of all the episodes @@ -142,7 +142,7 @@ def get_episode_lengths(self) -> List[int]: """ return self.episode_lengths - def get_episode_times(self) -> List[float]: + def get_episode_times(self) -> list[float]: """ Returns the runtime in seconds of all the episodes @@ -175,8 +175,8 @@ class ResultsWriter: def __init__( self, filename: str = "", - header: Optional[Dict[str, Union[float, str]]] = None, - extra_keys: Tuple[str, ...] = (), + header: Optional[dict[str, Union[float, str]]] = None, + extra_keys: tuple[str, ...] = (), override_existing: bool = True, ): if header is None: @@ -200,7 +200,7 @@ def __init__( self.file_handler.flush() - def write_row(self, epinfo: Dict[str, float]) -> None: + def write_row(self, epinfo: dict[str, float]) -> None: """ Write row of monitor data to csv log file. @@ -217,7 +217,7 @@ def close(self) -> None: self.file_handler.close() -def get_monitor_files(path: str) -> List[str]: +def get_monitor_files(path: str) -> list[str]: """ get all the monitor files in the given path diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index 01670e6e4..991cd23ea 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -1,6 +1,7 @@ import copy from abc import ABC, abstractmethod -from typing import Iterable, List, Optional +from collections.abc import Iterable +from typing import Optional import numpy as np from numpy.typing import DTypeLike @@ -153,11 +154,11 @@ def base_noise(self, base_noise: ActionNoise) -> None: self._base_noise = base_noise @property - def noises(self) -> List[ActionNoise]: + def noises(self) -> list[ActionNoise]: return self._noises @noises.setter - def noises(self, noises: List[ActionNoise]) -> None: + def noises(self, noises: list[ActionNoise]) -> None: noises = list(noises) # raises TypeError if not iterable assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}." 
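The hunks above replace `typing.Dict`/`List`/`Tuple`/`Callable` imports with the built-in generics standardized by PEP 585, which requires Python 3.9 or newer. A minimal sketch of how the updated `evaluate_policy` signature reads at a call site; the callback name, episode count, and environment id are illustrative, not part of the patch:

```python
from typing import Any

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy


def eval_callback(locals_: dict[str, Any], globals_: dict[str, Any]) -> None:
    # Called by evaluate_policy after every environment step; matches the new
    # dict[str, Any]-based callback annotation
    pass


model = PPO("MlpPolicy", "CartPole-v1", device="cpu")
mean_reward, std_reward = evaluate_policy(
    model, model.get_env(), n_eval_episodes=5, callback=eval_callback
)
```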
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index c460d0236..6a043e7ac 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -4,7 +4,7 @@ import time import warnings from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Optional, TypeVar, Union import numpy as np import torch as th @@ -79,7 +79,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): def __init__( self, - policy: Union[str, Type[BasePolicy]], + policy: Union[str, type[BasePolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule], buffer_size: int = 1_000_000, # 1e6 @@ -87,13 +87,13 @@ def __init__( batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = (1, "step"), + train_freq: Union[int, tuple[int, str]] = (1, "step"), gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[Type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + replay_buffer_class: Optional[type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[dict[str, Any]] = None, optimize_memory_usage: bool = False, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, verbose: int = 0, @@ -105,7 +105,7 @@ def __init__( sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, sde_support: bool = True, - supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None, + supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None, ): super().__init__( policy=policy, @@ -256,7 +256,7 @@ def _setup_learn( reset_num_timesteps: bool = True, tb_log_name: str = "run", progress_bar: bool = False, - ) -> Tuple[int, BaseCallback]: + ) -> tuple[int, BaseCallback]: """ cf `BaseAlgorithm`. """ @@ -362,7 +362,7 @@ def _sample_action( learning_starts: int, action_noise: Optional[ActionNoise] = None, n_envs: int = 1, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray]: """ Sample an action according to the exploration policy. This is either done by sampling the probability distribution of the policy, @@ -442,10 +442,10 @@ def _store_transition( self, replay_buffer: ReplayBuffer, buffer_action: np.ndarray, - new_obs: Union[np.ndarray, Dict[str, np.ndarray]], + new_obs: Union[np.ndarray, dict[str, np.ndarray]], reward: np.ndarray, dones: np.ndarray, - infos: List[Dict[str, Any]], + infos: list[dict[str, Any]], ) -> None: """ Store transition in the replay buffer. 
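Only the annotations of the off-policy constructor change here: `train_freq` still accepts either an int or a `(frequency, unit)` pair, now typed as `tuple[int, str]`. A usage sketch with SAC; the environment id and hyperparameters are placeholders, not a recommendation from the patch:

```python
from stable_baselines3 import SAC

# train_freq: int or (frequency, unit) tuple, with unit "step" or "episode"
model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    train_freq=(1, "episode"),
    gradient_steps=-1,  # -1: as many gradient steps as transitions collected
    verbose=1,
)
model.learn(total_timesteps=1_000)
```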
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 262453721..ac4c0970c 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -1,6 +1,7 @@ import sys import time -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union +import warnings +from typing import Any, Optional, TypeVar, Union import numpy as np import torch as th @@ -59,7 +60,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): def __init__( self, - policy: Union[str, Type[ActorCriticPolicy]], + policy: Union[str, type[ActorCriticPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule], n_steps: int, @@ -70,17 +71,17 @@ def __init__( max_grad_norm: float, use_sde: bool, sde_sample_freq: int, - rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, - rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, + rollout_buffer_class: Optional[type[RolloutBuffer]] = None, + rollout_buffer_kwargs: Optional[dict[str, Any]] = None, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, monitor_wrapper: bool = True, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: Union[th.device, str] = "auto", _init_setup_model: bool = True, - supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None, + supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None, ): super().__init__( policy=policy, @@ -135,6 +136,28 @@ def _setup_model(self) -> None: self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs ) self.policy = self.policy.to(self.device) + # Warn when not using CPU with MlpPolicy + self._maybe_recommend_cpu() + + def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None: + """ + Recommend to use CPU only when using A2C/PPO with MlpPolicy. + + :param: The name of the class for the default MlpPolicy. + """ + policy_class_name = self.policy_class.__name__ + if self.device != th.device("cpu") and policy_class_name == mlp_class_name: + warnings.warn( + f"You are trying to run {self.__class__.__name__} on the GPU, " + "but it is primarily intended to run on the CPU when not using a CNN policy " + f"(you are using {policy_class_name} which should be a MlpPolicy). " + "See https://github.com/DLR-RM/stable-baselines3/issues/1245 " + "for more info. " + "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU." 
+ "Note: The model will train, but the GPU utilization will be poor and " + "the training might take longer than on CPU.", + UserWarning, + ) def collect_rollouts( self, @@ -316,7 +339,7 @@ def learn( return self - def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + def _get_torch_save_params(self) -> tuple[list[str], list[str]]: state_dicts = ["policy", "policy.optimizer"] return state_dicts, [] diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index f9c4285dc..e20256f0c 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -5,7 +5,7 @@ import warnings from abc import ABC, abstractmethod from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Optional, TypeVar, Union import numpy as np import torch as th @@ -64,12 +64,12 @@ def __init__( self, observation_space: spaces.Space, action_space: spaces.Space, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, features_extractor: Optional[BaseFeaturesExtractor] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, ): super().__init__() @@ -95,9 +95,9 @@ def __init__( def _update_features_extractor( self, - net_kwargs: Dict[str, Any], + net_kwargs: dict[str, Any], features_extractor: Optional[BaseFeaturesExtractor] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Update the network keyword arguments and create a new features extractor object if needed. If a ``features_extractor`` object is passed, then it will be shared. @@ -130,7 +130,7 @@ def extract_features(self, obs: PyTorchObs, features_extractor: BaseFeaturesExtr preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) return features_extractor(preprocessed_obs) - def _get_constructor_parameters(self) -> Dict[str, Any]: + def _get_constructor_parameters(self) -> dict[str, Any]: """ Get data that need to be saved in order to re-create the model when loading it from disk. @@ -164,7 +164,7 @@ def save(self, path: str) -> None: th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) @classmethod - def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel: + def load(cls: type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel: """ Load model from path. @@ -210,7 +210,7 @@ def set_training_mode(self, mode: bool) -> None: """ self.train(mode) - def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> bool: + def is_vectorized_observation(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> bool: """ Check whether or not the observation is vectorized, apply transposition to image (so that they are channel-first) if needed. @@ -233,7 +233,7 @@ def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np. 
) return vectorized_env - def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[PyTorchObs, bool]: + def obs_to_tensor(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> tuple[PyTorchObs, bool]: """ Convert an input observation to a PyTorch tensor that can be fed to a model. Includes sugar-coating to handle different observations (e.g. normalizing images). @@ -330,11 +330,11 @@ def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.T def predict( self, - observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[Tuple[np.ndarray, ...]] = None, + observation: Union[np.ndarray, dict[str, np.ndarray]], + state: Optional[tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]: """ Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). @@ -450,20 +450,20 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.Tanh, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, log_std_init: float = 0.0, full_std: bool = True, use_expln: bool = False, squash_output: bool = False, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, share_features_extractor: bool = True, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, ): if optimizer_kwargs is None: optimizer_kwargs = {} @@ -534,7 +534,7 @@ def __init__( self._build(lr_schedule) - def _get_constructor_parameters(self) -> Dict[str, Any]: + def _get_constructor_parameters(self) -> dict[str, Any]: data = super()._get_constructor_parameters() default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) # type: ignore[arg-type, return-value] @@ -633,7 +633,7 @@ def _build(self, lr_schedule: Schedule) -> None: # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) # type: ignore[call-arg] - def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + def forward(self, obs: th.Tensor, deterministic: bool = False) -> tuple[th.Tensor, th.Tensor, th.Tensor]: """ Forward pass in all the networks (actor and critic) @@ -659,7 +659,7 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso def extract_features( # type: ignore[override] self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None - ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: + ) -> Union[th.Tensor, tuple[th.Tensor, th.Tensor]]: """ Preprocess the observation if needed and extract features. 
@@ -716,7 +716,7 @@ def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.T """ return self.get_distribution(observation).get_actions(deterministic=deterministic) - def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: + def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: """ Evaluate actions according to the current policy, given the observations. @@ -800,20 +800,20 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.Tanh, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, log_std_init: float = 0.0, full_std: bool = True, use_expln: bool = False, squash_output: bool = False, - features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[dict[str, Any]] = None, share_features_extractor: bool = True, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, ): super().__init__( observation_space, @@ -873,20 +873,20 @@ def __init__( observation_space: spaces.Dict, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.Tanh, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, log_std_init: float = 0.0, full_std: bool = True, use_expln: bool = False, squash_output: bool = False, - features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, share_features_extractor: bool = True, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, ): super().__init__( observation_space, @@ -942,10 +942,10 @@ def __init__( self, observation_space: spaces.Space, action_space: spaces.Box, - net_arch: List[int], + net_arch: list[int], features_extractor: BaseFeaturesExtractor, features_dim: int, - activation_fn: Type[nn.Module] = nn.ReLU, + activation_fn: type[nn.Module] = nn.ReLU, normalize_images: bool = True, n_critics: int = 2, share_features_extractor: bool = True, @@ -961,14 +961,14 @@ def __init__( self.share_features_extractor = share_features_extractor self.n_critics = n_critics - self.q_networks: List[nn.Module] = [] + self.q_networks: list[nn.Module] = [] for idx in range(n_critics): q_net_list = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) q_net = nn.Sequential(*q_net_list) self.add_module(f"qf{idx}", q_net) self.q_networks.append(q_net) - def forward(self, obs: th.Tensor, actions: 
th.Tensor) -> Tuple[th.Tensor, ...]: + def forward(self, obs: th.Tensor, actions: th.Tensor) -> tuple[th.Tensor, ...]: # Learn the features extractor using the policy loss only # when the features_extractor is shared with the actor with th.set_grad_enabled(not self.share_features_extractor): diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index d0bfbcd1e..a35f8b76f 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -1,5 +1,5 @@ import warnings -from typing import Dict, Tuple, Union +from typing import Union import numpy as np import torch as th @@ -90,10 +90,10 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> def preprocess_obs( - obs: Union[th.Tensor, Dict[str, th.Tensor]], + obs: Union[th.Tensor, dict[str, th.Tensor]], observation_space: spaces.Space, normalize_images: bool = True, -) -> Union[th.Tensor, Dict[str, th.Tensor]]: +) -> Union[th.Tensor, dict[str, th.Tensor]]: """ Preprocess observation to be to a neural network. For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]) @@ -107,7 +107,7 @@ def preprocess_obs( """ if isinstance(observation_space, spaces.Dict): # Do not modify by reference the original observation - assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}" + assert isinstance(obs, dict), f"Expected dict, got {type(obs)}" preprocessed_obs = {} for key, _obs in obs.items(): preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images) @@ -142,7 +142,7 @@ def preprocess_obs( def get_obs_shape( observation_space: spaces.Space, -) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]: +) -> Union[tuple[int, ...], dict[str, tuple[int, ...]]]: """ Get the shape of the observation (useful for the buffers). 
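The `_maybe_recommend_cpu` helper added in the on_policy_algorithm.py hunk above emits a `UserWarning` when A2C/PPO run with the default `ActorCriticPolicy` on a non-CPU device. A small sketch of how a user silences it by forcing the CPU, as the warning message suggests; the algorithm choice and timestep count are arbitrary:

```python
from stable_baselines3 import A2C

# With the plain MlpPolicy, A2C/PPO now warn when placed on GPU/MPS;
# passing device="cpu" (or exporting CUDA_VISIBLE_DEVICES="") keeps training on the CPU.
model = A2C("MlpPolicy", "CartPole-v1", device="cpu")
model.learn(total_timesteps=2_000)
```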
diff --git a/stable_baselines3/common/results_plotter.py b/stable_baselines3/common/results_plotter.py index f4c1a7a05..f09a54e58 100644 --- a/stable_baselines3/common/results_plotter.py +++ b/stable_baselines3/common/results_plotter.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional, Tuple +from typing import Callable, Optional import numpy as np import pandas as pd @@ -29,7 +29,7 @@ def rolling_window(array: np.ndarray, window: int) -> np.ndarray: return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides) -def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]: +def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> tuple[np.ndarray, np.ndarray]: """ Apply a function to the rolling window of 2 arrays @@ -44,7 +44,7 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callabl return var_1[window - 1 :], function_on_var2 -def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]: +def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> tuple[np.ndarray, np.ndarray]: """ Decompose a data frame variable to x and ys @@ -69,7 +69,7 @@ def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray def plot_curves( - xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2) + xy_list: list[tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: tuple[int, int] = (8, 2) ) -> None: """ plot the curves @@ -99,7 +99,7 @@ def plot_curves( def plot_results( - dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2) + dirs: list[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: tuple[int, int] = (8, 2) ) -> None: """ Plot the results using csv files from ``Monitor`` wrapper. diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py index ac3538c50..c8f03b212 100644 --- a/stable_baselines3/common/running_mean_std.py +++ b/stable_baselines3/common/running_mean_std.py @@ -1,10 +1,8 @@ -from typing import Tuple - import numpy as np class RunningMeanStd: - def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()): + def __init__(self, epsilon: float = 1e-4, shape: tuple[int, ...] = ()): """ Calculates the running mean and std of a data stream https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index a85c9c2ec..8b545f898 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -12,7 +12,7 @@ import pickle import warnings import zipfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Union import cloudpickle import torch as th @@ -73,7 +73,7 @@ def is_json_serializable(item: Any) -> bool: return json_serializable -def data_to_json(data: Dict[str, Any]) -> str: +def data_to_json(data: dict[str, Any]) -> str: """ Turn data (class parameters) into a JSON string for storing @@ -128,7 +128,7 @@ def data_to_json(data: Dict[str, Any]) -> str: return json_string -def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: +def json_to_data(json_string: str, custom_objects: Optional[dict[str, Any]] = None) -> dict[str, Any]: """ Turn JSON serialization of class-parameters back into dictionary. 
@@ -293,9 +293,9 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O def save_to_zip_file( save_path: Union[str, pathlib.Path, io.BufferedIOBase], - data: Optional[Dict[str, Any]] = None, - params: Optional[Dict[str, Any]] = None, - pytorch_variables: Optional[Dict[str, Any]] = None, + data: Optional[dict[str, Any]] = None, + params: Optional[dict[str, Any]] = None, + pytorch_variables: Optional[dict[str, Any]] = None, verbose: int = 0, ) -> None: """ @@ -376,11 +376,11 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: in def load_from_zip_file( load_path: Union[str, pathlib.Path, io.BufferedIOBase], load_data: bool = True, - custom_objects: Optional[Dict[str, Any]] = None, + custom_objects: Optional[dict[str, Any]] = None, device: Union[th.device, str] = "auto", verbose: int = 0, print_system_info: bool = False, -) -> Tuple[Optional[Dict[str, Any]], TensorDict, Optional[TensorDict]]: +) -> tuple[Optional[dict[str, Any]], TensorDict, Optional[TensorDict]]: """ Load model data from a .zip archive diff --git a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py index 25f0a6f96..036958460 100644 --- a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py +++ b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, Iterable, Optional +from collections.abc import Iterable +from typing import Any, Callable, Optional import torch from torch.optim import Optimizer @@ -67,7 +68,7 @@ def __init__( defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay) super().__init__(params, defaults) - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: super().__setstate__(state) for group in self.param_groups: group.setdefault("momentum", 0) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 234b91551..6c6aa2ddd 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Optional, Union import gymnasium as gym import torch as th @@ -110,13 +110,13 @@ def forward(self, observations: th.Tensor) -> th.Tensor: def create_mlp( input_dim: int, output_dim: int, - net_arch: List[int], - activation_fn: Type[nn.Module] = nn.ReLU, + net_arch: list[int], + activation_fn: type[nn.Module] = nn.ReLU, squash_output: bool = False, with_bias: bool = True, - pre_linear_modules: Optional[List[Type[nn.Module]]] = None, - post_linear_modules: Optional[List[Type[nn.Module]]] = None, -) -> List[nn.Module]: + pre_linear_modules: Optional[list[type[nn.Module]]] = None, + post_linear_modules: Optional[list[type[nn.Module]]] = None, +) -> list[nn.Module]: """ Create a multi layer perceptron (MLP), which is a collection of fully-connected layers each followed by an activation function. 
@@ -211,14 +211,14 @@ class MlpExtractor(nn.Module): def __init__( self, feature_dim: int, - net_arch: Union[List[int], Dict[str, List[int]]], - activation_fn: Type[nn.Module], + net_arch: Union[list[int], dict[str, list[int]]], + activation_fn: type[nn.Module], device: Union[th.device, str] = "auto", ) -> None: super().__init__() device = get_device(device) - policy_net: List[nn.Module] = [] - value_net: List[nn.Module] = [] + policy_net: list[nn.Module] = [] + value_net: list[nn.Module] = [] last_layer_dim_pi = feature_dim last_layer_dim_vf = feature_dim @@ -249,7 +249,7 @@ def __init__( self.policy_net = nn.Sequential(*policy_net).to(device) self.value_net = nn.Sequential(*value_net).to(device) - def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def forward(self, features: th.Tensor) -> tuple[th.Tensor, th.Tensor]: """ :return: latent_policy, latent_value of the specified network. If all layers are shared, then ``latent_policy == latent_value`` @@ -288,7 +288,7 @@ def __init__( # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty! super().__init__(observation_space, features_dim=1) - extractors: Dict[str, nn.Module] = {} + extractors: dict[str, nn.Module] = {} total_concat_size = 0 for key, subspace in observation_space.spaces.items(): @@ -313,7 +313,7 @@ def forward(self, observations: TensorDict) -> th.Tensor: return th.cat(encoded_tensor_list, dim=1) -def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]: +def get_actor_critic_arch(net_arch: Union[list[int], dict[str, list[int]]]) -> tuple[list[int], list[int]]: """ Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG). 
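`create_mlp` gains `pre_linear_modules`/`post_linear_modules` hooks in this patch. A sketch of the intended usage, assuming (as for `nn.LayerNorm`/`nn.BatchNorm1d`) that each hook class is instantiated with the matching layer width; the dimensions below are illustrative:

```python
import torch.nn as nn

from stable_baselines3.common.torch_layers import create_mlp

# create_mlp returns a list of layers; wrap it in nn.Sequential to get a module.
layers = create_mlp(
    input_dim=8,
    output_dim=2,
    net_arch=[64, 64],
    activation_fn=nn.ReLU,
    # new hook: add a normalization layer after each hidden nn.Linear
    post_linear_modules=[nn.LayerNorm],
)
q_net = nn.Sequential(*layers)
```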
diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 042c66f9c..b7c578ac0 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -1,7 +1,7 @@ """Common aliases for type hints""" from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Protocol, SupportsFloat, Union import gymnasium as gym import numpy as np @@ -13,14 +13,14 @@ from stable_baselines3.common.vec_env import VecEnv GymEnv = Union[gym.Env, "VecEnv"] -GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] -GymResetReturn = Tuple[GymObs, Dict] -AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]] -GymStepReturn = Tuple[GymObs, float, bool, bool, Dict] -AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]] -TensorDict = Dict[str, th.Tensor] -OptimizerStateDict = Dict[str, Any] -MaybeCallback = Union[None, Callable, List["BaseCallback"], "BaseCallback"] +GymObs = Union[tuple, dict[str, Any], np.ndarray, int] +GymResetReturn = tuple[GymObs, dict] +AtariResetReturn = tuple[np.ndarray, dict[str, Any]] +GymStepReturn = tuple[GymObs, float, bool, bool, dict] +AtariStepReturn = tuple[np.ndarray, SupportsFloat, bool, bool, dict[str, Any]] +TensorDict = dict[str, th.Tensor] +OptimizerStateDict = dict[str, Any] +MaybeCallback = Union[None, Callable, list["BaseCallback"], "BaseCallback"] PyTorchObs = Union[th.Tensor, TensorDict] # A schedule takes the remaining progress as input @@ -81,11 +81,11 @@ class TrainFreq(NamedTuple): class PolicyPredictor(Protocol): def predict( self, - observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[Tuple[np.ndarray, ...]] = None, + observation: Union[np.ndarray, dict[str, np.ndarray]], + state: Optional[tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]: """ Get the policy action from an observation (and optional hidden state). Includes sugar-coating to handle different observations (e.g. normalizing images). diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index bcde1cfa0..562ff132a 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -4,8 +4,9 @@ import random import re from collections import deque +from collections.abc import Iterable from itertools import zip_longest -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Optional, Union import cloudpickle import gymnasium as gym @@ -29,8 +30,8 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None: """ Seed the different random generators. - :param seed: - :param using_cuda: + :param seed: Seed + :param using_cuda: Whether CUDA is currently used """ # Seed python RNG random.seed(seed) @@ -46,7 +47,7 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None: # From stable baselines -def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: +def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float: """ Computes fraction of variance that ypred explains about y. 
Returns 1 - Var[y-ypred] / Var[y] @@ -62,7 +63,7 @@ def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: """ assert y_true.ndim == 1 and y_pred.ndim == 1 var_y = np.var(y_true) - return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + return np.nan if var_y == 0 else float(1 - np.var(y_true - y_pred) / var_y) def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None: @@ -140,19 +141,20 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device: """ Retrieve PyTorch device. It checks that the requested device is available first. - For now, it supports only cpu and cuda. - By default, it tries to use the gpu. + For now, it supports only CPU and CUDA. + By default, it tries to use the GPU. - :param device: One for 'auto', 'cuda', 'cpu' + :param device: One of "auto", "cuda", "cpu", + or any PyTorch supported device (for instance "mps") :return: Supported Pytorch device """ - # Cuda by default + # MPS/CUDA by default if device == "auto": - device = "cuda" + device = get_available_accelerator() # Force conversion to th.device device = th.device(device) - # Cuda not available + # CUDA not available if device.type == th.device("cuda").type and not th.cuda.is_available(): return th.device("cpu") @@ -415,7 +417,7 @@ def safe_mean(arr: Union[np.ndarray, list, deque]) -> float: return np.nan if len(arr) == 0 else float(np.mean(arr)) # type: ignore[arg-type] -def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]: +def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> list[th.Tensor]: """ Extract parameters from the state dict of ``model`` if the name contains one of the strings in ``included_names``. @@ -473,7 +475,7 @@ def polyak_update( th.add(target_param.data, param.data, alpha=tau, out=target_param.data) -def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]: +def obs_as_tensor(obs: Union[np.ndarray, dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]: """ Moves the observation to the given device. @@ -484,6 +486,8 @@ def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.devi if isinstance(obs, np.ndarray): return th.as_tensor(obs, device=device) elif isinstance(obs, dict): + if hasattr(th, "backends") and th.backends.mps.is_built(): + return {key: th.as_tensor(_obs, dtype=th.float32, device=device) for (key, _obs) in obs.items()} return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()} else: raise Exception(f"Unrecognized type of observation {type(obs)}") @@ -517,7 +521,22 @@ def should_collect_more_steps( ) -def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: +def get_available_accelerator() -> str: + """ + Return the available accelerator + (currently checking only for CUDA and MPS device) + """ + if hasattr(th, "backends") and th.backends.mps.is_built(): + # MacOS Metal GPU + th.set_default_dtype(th.float32) + return "mps" + elif th.cuda.is_available(): + return "cuda" + else: + return "cpu" + + +def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]: """ Retrieve system and python env info for the current system. 
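With the new `get_available_accelerator` helper, `device="auto"` now prefers MPS or CUDA before falling back to CPU, and the system report exposes an `Accelerator` entry instead of `GPU Enabled`. A small sketch; the printed values depend on the host machine:

```python
from stable_baselines3.common.utils import get_device, get_system_info

# "auto" resolves to MPS (Apple Silicon) or CUDA when available, otherwise CPU
device = get_device("auto")
print(device)

# The system report lists an "Accelerator" entry
info_dict, info_str = get_system_info(print_info=False)
print(info_dict["Accelerator"])
```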
@@ -532,7 +551,7 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: "Python": platform.python_version(), "Stable-Baselines3": sb3.__version__, "PyTorch": th.__version__, - "GPU Enabled": str(th.cuda.is_available()), + "Accelerator": get_available_accelerator(), "Numpy": np.__version__, "Cloudpickle": cloudpickle.__version__, "Gymnasium": gym.__version__, diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 5f73d3978..9a60c07dc 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Optional, Type, TypeVar +from typing import Optional, TypeVar from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv @@ -16,7 +16,7 @@ VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper) -def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]: +def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]: """ Retrieve a ``VecEnvWrapper`` object by recursively searching. @@ -42,7 +42,7 @@ def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]: return unwrap_vec_wrapper(env, VecNormalize) -def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> bool: +def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: type[VecEnvWrapper]) -> bool: """ Check if an environment is already wrapped in a given ``VecEnvWrapper``. diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 8e0c8cc69..b85c1cf88 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -1,8 +1,9 @@ import inspect import warnings from abc import ABC, abstractmethod +from collections.abc import Iterable, Sequence from copy import deepcopy -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Optional, Union import cloudpickle import gymnasium as gym @@ -14,10 +15,10 @@ VecEnvIndices = Union[None, int, Iterable[int]] # VecEnvObs is what is returned by the reset() method # it contains the observation for each env -VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]] +VecEnvObs = Union[np.ndarray, dict[str, np.ndarray], tuple[np.ndarray, ...]] # VecEnvStepReturn is what is returned by the step() method # it contains the observation, reward, done, info for each env -VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]] +VecEnvStepReturn = tuple[VecEnvObs, np.ndarray, np.ndarray, list[dict]] def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray: # pragma: no cover @@ -65,11 +66,11 @@ def __init__( self.observation_space = observation_space self.action_space = action_space # store info returned by the reset method - self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)] + self.reset_infos: list[dict[str, Any]] = [{} for _ in range(num_envs)] # seeds to be used in the next call to env.reset() - self._seeds: List[Optional[int]] = [None for _ in range(num_envs)] + self._seeds: list[Optional[int]] = [None for _ in range(num_envs)] # options to be used in the next call to env.reset() - self._options: List[Dict[str, Any]] = [{} for _ in range(num_envs)] + self._options: 
list[dict[str, Any]] = [{} for _ in range(num_envs)] try: render_modes = self.get_attr("render_mode") @@ -147,7 +148,7 @@ def close(self) -> None: raise NotImplementedError() @abstractmethod - def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: """ Return attribute from vectorized environment. @@ -170,7 +171,7 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> raise NotImplementedError() @abstractmethod - def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: + def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]: """ Call instance methods of vectorized environments. @@ -183,7 +184,7 @@ def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = No raise NotImplementedError() @abstractmethod - def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]: """ Check if environments are wrapped with a given wrapper. @@ -292,7 +293,7 @@ def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: self._seeds = [seed + idx for idx in range(self.num_envs)] return self._seeds - def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None: + def set_options(self, options: Optional[Union[list[dict], dict]] = None) -> None: """ Set environment options for all environments. If a dict is passed instead of a list, the same options will be used for all environments. @@ -379,7 +380,7 @@ def step_wait(self) -> VecEnvStepReturn: def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]: return self.venv.seed(seed) - def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None: + def set_options(self, options: Optional[Union[list[dict], dict]] = None) -> None: return self.venv.set_options(options) def close(self) -> None: @@ -391,16 +392,16 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: def get_images(self) -> Sequence[Optional[np.ndarray]]: return self.venv.get_images() - def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: return self.venv.get_attr(attr_name, indices) def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: return self.venv.set_attr(attr_name, value, indices) - def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: + def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]: return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs) - def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]: return self.venv.env_is_wrapped(wrapper_class, indices=indices) def __getattr__(self, name: str) -> Any: @@ -419,7 +420,7 @@ def __getattr__(self, name: str) -> Any: return self.getattr_recursive(name) - def _get_all_attributes(self) -> Dict[str, Any]: + def _get_all_attributes(self) -> dict[str, Any]: """Get all (inherited) instance and class attributes 
:return: all_attributes diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 15ecfb681..4069356d2 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -1,14 +1,15 @@ import warnings from collections import OrderedDict +from collections.abc import Sequence from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Sequence, Type +from typing import Any, Callable, Optional import gymnasium as gym import numpy as np from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn from stable_baselines3.common.vec_env.patch_gym import _patch_env -from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info +from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info class DummyVecEnv(VecEnv): @@ -26,7 +27,7 @@ class DummyVecEnv(VecEnv): actions: np.ndarray - def __init__(self, env_fns: List[Callable[[], gym.Env]]): + def __init__(self, env_fns: list[Callable[[], gym.Env]]): self.envs = [_patch_env(fn()) for fn in env_fns] if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs): raise ValueError( @@ -46,7 +47,7 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]]): self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys]) self.buf_dones = np.zeros((self.num_envs,), dtype=bool) self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) - self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)] + self.buf_infos: list[dict[str, Any]] = [{} for _ in range(self.num_envs)] self.metadata = env.metadata def step_async(self, actions: np.ndarray) -> None: @@ -110,12 +111,12 @@ def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload] def _obs_from_buf(self) -> VecEnvObs: - return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) + return dict_to_obs(self.observation_space, deepcopy(self.buf_obs)) - def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: """Return attribute from vectorized environment (see base class).""" target_envs = self._get_target_envs(indices) - return [getattr(env_i, attr_name) for env_i in target_envs] + return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs] def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: """Set attribute inside vectorized environments (see base class).""" @@ -123,12 +124,12 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> for env_i in target_envs: setattr(env_i, attr_name, value) - def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: + def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]: """Call instance methods of vectorized environments.""" target_envs = self._get_target_envs(indices) - return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] + return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs] - def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + def env_is_wrapped(self, 
wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]: """Check if worker environments are wrapped with a given wrapper""" target_envs = self._get_target_envs(indices) # Import here to avoid a circular import @@ -136,6 +137,6 @@ def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndice return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs] - def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]: + def _get_target_envs(self, indices: VecEnvIndices) -> list[gym.Env]: indices = self._get_indices(indices) return [self.envs[i] for i in indices] diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index 6ba655ebf..874809a03 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -43,7 +43,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma "Missing shimmy installation. You provided an OpenAI Gym environment. " "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " "In order to use OpenAI Gym environments with SB3, you need to " - "install shimmy (`pip install 'shimmy>=0.2.1'`)." + "install shimmy (`pip install 'shimmy>=2.0'`)." ) from e warnings.warn( diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index b6a759f30..d1b3ad298 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -1,12 +1,13 @@ import warnings -from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, TypeVar, Union +from collections.abc import Mapping +from typing import Any, Generic, Optional, TypeVar, Union import numpy as np from gymnasium import spaces from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first -TObs = TypeVar("TObs", np.ndarray, Dict[str, np.ndarray]) +TObs = TypeVar("TObs", np.ndarray, dict[str, np.ndarray]) class StackedObservations(Generic[TObs]): @@ -66,7 +67,7 @@ def __init__( @staticmethod def compute_stacking( n_stack: int, observation_space: spaces.Box, channels_order: Optional[str] = None - ) -> Tuple[bool, int, Tuple[int, ...], int]: + ) -> tuple[bool, int, tuple[int, ...], int]: """ Calculates the parameters in order to stack observations @@ -119,8 +120,8 @@ def update( self, observations: TObs, dones: np.ndarray, - infos: List[Dict[str, Any]], - ) -> Tuple[TObs, List[Dict[str, Any]]]: + infos: list[dict[str, Any]], + ) -> tuple[TObs, list[dict[str, Any]]]: """ Add the observations to the stack and use the dones to update the infos. 
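`DummyVecEnv.get_attr`/`env_method` now resolve names through Gymnasium's `get_wrapper_attr` instead of plain `getattr`, so attributes and methods defined on intermediate wrappers stay reachable. The user-facing calls are unchanged; a sketch with arbitrary environment id and seed:

```python
import gymnasium as gym

from stable_baselines3.common.vec_env import DummyVecEnv

vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
# Attribute lookup now goes through get_wrapper_attr on each sub-environment
render_modes = vec_env.get_attr("render_mode")  # one entry per sub-environment
# Method calls are resolved the same way
results = vec_env.env_method("reset", seed=0)
vec_env.close()
```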
diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index c598c735a..225eadd79 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -1,7 +1,7 @@ import multiprocessing as mp import warnings -from collections import OrderedDict -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import gymnasium as gym import numpy as np @@ -27,7 +27,7 @@ def _worker( parent_remote.close() env = _patch_env(env_fn_wrapper.var()) - reset_info: Optional[Dict[str, Any]] = {} + reset_info: Optional[dict[str, Any]] = {} while True: try: cmd, data = remote.recv() @@ -54,10 +54,10 @@ def _worker( elif cmd == "get_spaces": remote.send((env.observation_space, env.action_space)) elif cmd == "env_method": - method = getattr(env, data[0]) + method = env.get_wrapper_attr(data[0]) remote.send(method(*data[1], **data[2])) elif cmd == "get_attr": - remote.send(getattr(env, data)) + remote.send(env.get_wrapper_attr(data)) elif cmd == "set_attr": remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value] elif cmd == "is_wrapped": @@ -92,7 +92,7 @@ class SubprocVecEnv(VecEnv): Defaults to 'forkserver' on available platforms, and 'spawn' otherwise. """ - def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None): + def __init__(self, env_fns: list[Callable[[], gym.Env]], start_method: Optional[str] = None): self.waiting = False self.closed = False n_envs = len(env_fns) @@ -129,7 +129,7 @@ def step_wait(self) -> VecEnvStepReturn: results = [remote.recv() for remote in self.remotes] self.waiting = False obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment] - return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value] + return _stack_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value] def reset(self) -> VecEnvObs: for env_idx, remote in enumerate(self.remotes): @@ -139,7 +139,7 @@ def reset(self) -> VecEnvObs: # Seeds and options are only used once self._reset_seeds() self._reset_options() - return _flatten_obs(obs, self.observation_space) + return _stack_obs(obs, self.observation_space) def close(self) -> None: if self.closed: @@ -165,7 +165,7 @@ def get_images(self) -> Sequence[Optional[np.ndarray]]: outputs = [pipe.recv() for pipe in self.remotes] return outputs - def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: """Return attribute from vectorized environment (see base class).""" target_remotes = self._get_target_remotes(indices) for remote in target_remotes: @@ -180,21 +180,21 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> for remote in target_remotes: remote.recv() - def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: + def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]: """Call instance methods of vectorized environments.""" target_remotes = self._get_target_remotes(indices) for remote in target_remotes: remote.send(("env_method", (method_name, method_args, method_kwargs))) return [remote.recv() for remote 
in target_remotes] - def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]: """Check if worker environments are wrapped with a given wrapper""" target_remotes = self._get_target_remotes(indices) for remote in target_remotes: remote.send(("is_wrapped", wrapper_class)) return [remote.recv() for remote in target_remotes] - def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]: + def _get_target_remotes(self, indices: VecEnvIndices) -> list[Any]: """ Get the connection object needed to communicate with the wanted envs that are in subprocesses. @@ -206,27 +206,28 @@ def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]: return [self.remotes[i] for i in indices] -def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs: +def _stack_obs(obs_list: Union[list[VecEnvObs], tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs: """ - Flatten observations, depending on the observation space. + Stack observations (convert from a list of single env obs to a stack of obs), + depending on the observation space. :param obs: observations. A list or tuple of observations, one per environment. Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays. - :return: flattened observations. - A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays. + :return: Concatenated observations. + A NumPy array or a dict or tuple of stacked numpy arrays. Each NumPy array has the environment index as its first axis. """ - assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" - assert len(obs) > 0, "need observations from at least one environment" + assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment" + assert len(obs_list) > 0, "need observations from at least one environment" if isinstance(space, spaces.Dict): - assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" - assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" - return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) + assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces" + assert isinstance(obs_list[0], dict), "non-dict observation for environment with Dict observation space" + return {key: np.stack([single_obs[key] for single_obs in obs_list]) for key in space.spaces.keys()} # type: ignore[call-overload] elif isinstance(space, spaces.Tuple): - assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" + assert isinstance(obs_list[0], tuple), "non-tuple observation for environment with Tuple observation space" obs_len = len(space.spaces) - return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index] + return tuple(np.stack([single_obs[i] for single_obs in obs_list]) for i in range(obs_len)) # type: ignore[index] else: - return np.stack(obs) # type: ignore[arg-type] + return np.stack(obs_list) # type: ignore[arg-type] diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 855f50edc..c1babd87b 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -2,8 +2,7 @@ Helpers for dealing with vectorized environments. 
""" -from collections import OrderedDict -from typing import Any, Dict, List, Tuple +from typing import Any import numpy as np from gymnasium import spaces @@ -12,18 +11,7 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs -def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: - """ - Deep-copy a dict of numpy arrays. - - :param obs: a dict of numpy arrays. - :return: a dict of copied numpy arrays. - """ - assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" - return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) - - -def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: +def dict_to_obs(obs_space: spaces.Space, obs_dict: dict[Any, np.ndarray]) -> VecEnvObs: """ Convert an internal representation raw_obs into the appropriate type specified by space. @@ -44,7 +32,7 @@ def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> Vec return obs_dict[None] -def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]: +def obs_space_info(obs_space: spaces.Space) -> tuple[list[str], dict[Any, tuple[int, ...]], dict[Any, np.dtype]]: """ Get dict-structured information about a gym.Space. @@ -60,13 +48,13 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ """ check_for_nested_spaces(obs_space) if isinstance(obs_space, spaces.Dict): - assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" + assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces" subspaces = obs_space.spaces elif isinstance(obs_space, spaces.Tuple): - subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment] + subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment,misc] else: assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" - subspaces = {None: obs_space} # type: ignore[assignment] + subspaces = {None: obs_space} # type: ignore[assignment,dict-item] keys = [] shapes = {} dtypes = {} @@ -74,4 +62,4 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ keys.append(key) shapes[key] = box.shape dtypes[key] = box.dtype - return keys, shapes, dtypes + return keys, shapes, dtypes # type: ignore[return-value] diff --git a/stable_baselines3/common/vec_env/vec_check_nan.py b/stable_baselines3/common/vec_env/vec_check_nan.py index 170f36ec8..1d775aad5 100644 --- a/stable_baselines3/common/vec_env/vec_check_nan.py +++ b/stable_baselines3/common/vec_env/vec_check_nan.py @@ -1,5 +1,4 @@ import warnings -from typing import List, Tuple import numpy as np from gymnasium import spaces @@ -48,7 +47,7 @@ def reset(self) -> VecEnvObs: self._observations = observations return observations - def check_array_value(self, name: str, value: np.ndarray) -> List[Tuple[str, str]]: + def check_array_value(self, name: str, value: np.ndarray) -> list[tuple[str, str]]: """ Check for inf and NaN for a single numpy array. 
diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index daa2b365c..2142bcb9e 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union +from collections.abc import Mapping +from typing import Any, Optional, Union import numpy as np from gymnasium import spaces @@ -29,17 +30,17 @@ def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[st def step_wait( self, - ) -> Tuple[ - Union[np.ndarray, Dict[str, np.ndarray]], + ) -> tuple[ + Union[np.ndarray, dict[str, np.ndarray]], np.ndarray, np.ndarray, - List[Dict[str, Any]], + list[dict[str, Any]], ]: observations, rewards, dones, infos = self.venv.step_wait() observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type] return observations, rewards, dones, infos - def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: + def reset(self) -> Union[np.ndarray, dict[str, np.ndarray]]: """ Reset all environments """ diff --git a/stable_baselines3/common/vec_env/vec_monitor.py b/stable_baselines3/common/vec_env/vec_monitor.py index 0d7f18a5e..4aa9325f6 100644 --- a/stable_baselines3/common/vec_env/vec_monitor.py +++ b/stable_baselines3/common/vec_env/vec_monitor.py @@ -1,6 +1,6 @@ import time import warnings -from typing import Optional, Tuple +from typing import Optional import numpy as np @@ -27,7 +27,7 @@ def __init__( self, venv: VecEnv, filename: Optional[str] = None, - info_keywords: Tuple[str, ...] = (), + info_keywords: tuple[str, ...] = (), ): # Avoid circular import from stable_baselines3.common.monitor import Monitor, ResultsWriter diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index ab1d8403a..5f0ee1c25 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -1,7 +1,7 @@ import inspect import pickle from copy import deepcopy -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import numpy as np from gymnasium import spaces @@ -29,8 +29,8 @@ class VecNormalize(VecEnvWrapper): If not specified, all keys will be normalized. """ - obs_spaces: Dict[str, spaces.Space] - old_obs: Union[np.ndarray, Dict[str, np.ndarray]] + obs_spaces: dict[str, spaces.Space] + old_obs: Union[np.ndarray, dict[str, np.ndarray]] def __init__( self, @@ -42,7 +42,7 @@ def __init__( clip_reward: float = 10.0, gamma: float = 0.99, epsilon: float = 1e-8, - norm_obs_keys: Optional[List[str]] = None, + norm_obs_keys: Optional[list[str]] = None, ): VecEnvWrapper.__init__(self, venv) @@ -125,7 +125,7 @@ def _sanity_checks(self) -> None: f"not {self.observation_space}" ) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: """ Gets state for pickling. @@ -138,7 +138,7 @@ def __getstate__(self) -> Dict[str, Any]: del state["returns"] return state - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: """ Restores pickled state. 
@@ -229,7 +229,7 @@ def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarr
         """
         return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean
 
-    def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def normalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[np.ndarray, dict[str, np.ndarray]]:
         """
         Normalize observations using this VecNormalize's observations statistics.
         Calling this method does not update statistics.
@@ -254,9 +254,11 @@ def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
         """
         if self.norm_reward:
             reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
-        return reward
+        # Note: we cast to float32 as the result would otherwise be float64 (the Python default float type)
+        # This cast is needed because `RunningMeanStd` keeps stats in float64
+        return reward.astype(np.float32)
 
-    def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def unnormalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[np.ndarray, dict[str, np.ndarray]]:
         # Avoid modifying by reference the original object
         obs_ = deepcopy(obs)
         if self.norm_obs:
@@ -274,7 +276,7 @@ def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray:
             return reward * np.sqrt(self.ret_rms.var + self.epsilon)
         return reward
 
-    def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def get_original_obs(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
         """
         Returns an unnormalized version of the observations from the most recent
         step or reset.
@@ -287,7 +289,7 @@ def get_original_reward(self) -> np.ndarray:
         """
         return self.old_reward.copy()
 
-    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def reset(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
         """
         Reset all environments
         :return: first observation of the episode
diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py
index 487bd8c07..3fade64d1 100644
--- a/stable_baselines3/common/vec_env/vec_transpose.py
+++ b/stable_baselines3/common/vec_env/vec_transpose.py
@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import Dict, Union
+from typing import Union
 
 import numpy as np
 from gymnasium import spaces
@@ -73,7 +73,7 @@ def transpose_image(image: np.ndarray) -> np.ndarray:
             return np.transpose(image, (2, 0, 1))
         return np.transpose(image, (0, 3, 1, 2))
 
-    def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]:
+    def transpose_observations(self, observations: Union[np.ndarray, dict]) -> Union[np.ndarray, dict]:
         """
         Transpose (if needed) and return new observations.
@@ -106,7 +106,7 @@ def step_wait(self) -> VecEnvStepReturn: assert isinstance(observations, (np.ndarray, dict)) return self.transpose_observations(observations), rewards, dones, infos - def reset(self) -> Union[np.ndarray, Dict]: + def reset(self) -> Union[np.ndarray, dict]: """ Reset all environments """ diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 52faebd1f..add3846b6 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -1,7 +1,9 @@ import os +import os.path from typing import Callable -from gymnasium.wrappers.monitoring import video_recorder +import numpy as np +from gymnasium import error, logger from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv @@ -13,6 +15,11 @@ class VecVideoRecorder(VecEnvWrapper): Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. It requires ffmpeg or avconv to be installed on the machine. + Note: for now it only allows to record one video and all videos + must have at least two frames. + + The video recorder code was adapted from Gymnasium v1.0. + :param venv: :param video_folder: Where to save videos :param record_video_trigger: Function that defines when to start recording. @@ -22,8 +29,6 @@ class VecVideoRecorder(VecEnvWrapper): :param name_prefix: Prefix to the video name """ - video_recorder: video_recorder.VideoRecorder - def __init__( self, venv: VecEnv, @@ -51,6 +56,8 @@ def __init__( self.env.metadata = metadata assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}" + self.frames_per_sec = self.env.metadata.get("render_fps", 30) + self.record_video_trigger = record_video_trigger self.video_folder = os.path.abspath(video_folder) # Create output folder if needed @@ -60,54 +67,88 @@ def __init__( self.step_id = 0 self.video_length = video_length + self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4" + self.video_path = os.path.join(self.video_folder, self.video_name) + self.recording = False - self.recorded_frames = 0 + self.recorded_frames: list[np.ndarray] = [] + + try: + import moviepy # noqa: F401 + except ImportError as e: + raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e def reset(self) -> VecEnvObs: obs = self.venv.reset() - self.start_video_recorder() + if self._video_enabled(): + self._start_video_recorder() return obs - def start_video_recorder(self) -> None: - self.close_video_recorder() - - video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" - base_path = os.path.join(self.video_folder, video_name) - self.video_recorder = video_recorder.VideoRecorder( - env=self.env, base_path=base_path, metadata={"step_id": self.step_id} - ) - - self.video_recorder.capture_frame() - self.recorded_frames = 1 - self.recording = True + def _start_video_recorder(self) -> None: + self._start_recording() + self._capture_frame() def _video_enabled(self) -> bool: return self.record_video_trigger(self.step_id) def step_wait(self) -> VecEnvStepReturn: - obs, rews, dones, infos = self.venv.step_wait() + obs, rewards, dones, infos = self.venv.step_wait() self.step_id += 1 if self.recording: - self.video_recorder.capture_frame() - 
self.recorded_frames += 1 - if self.recorded_frames > self.video_length: - print(f"Saving video to {self.video_recorder.path}") - self.close_video_recorder() + self._capture_frame() + if len(self.recorded_frames) > self.video_length: + print(f"Saving video to {self.video_path}") + self._stop_recording() elif self._video_enabled(): - self.start_video_recorder() + self._start_video_recorder() - return obs, rews, dones, infos + return obs, rewards, dones, infos - def close_video_recorder(self) -> None: - if self.recording: - self.video_recorder.close() - self.recording = False - self.recorded_frames = 1 + def _capture_frame(self) -> None: + assert self.recording, "Cannot capture a frame, recording wasn't started." + + frame = self.env.render() + if isinstance(frame, list): + frame = frame[-1] + + if isinstance(frame, np.ndarray): + self.recorded_frames.append(frame) + else: + self._stop_recording() + logger.warn( + f"Recording stopped: expected type of frame returned by render to be a numpy array, got instead {type(frame)}." + ) def close(self) -> None: + """Closes the wrapper then the video recorder.""" VecEnvWrapper.close(self) - self.close_video_recorder() + if self.recording: + self._stop_recording() + + def _start_recording(self) -> None: + """Start a new recording. If it is already recording, stops the current recording before starting the new one.""" + if self.recording: + self._stop_recording() + + self.recording = True + + def _stop_recording(self) -> None: + """Stop current recording and saves the video.""" + assert self.recording, "_stop_recording was called, but no recording was started" + + if len(self.recorded_frames) == 0: + logger.warn("Ignored saving a video as there were zero frames to save.") + else: + from moviepy.video.io.ImageSequenceClip import ImageSequenceClip + + clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec) + clip.write_videofile(self.video_path) + + self.recorded_frames = [] + self.recording = False - def __del__(self): - self.close_video_recorder() + def __del__(self) -> None: + """Warn the user in case last video wasn't saved.""" + if len(self.recorded_frames) > 0: + logger.warn("Unable to save last video! 
Did you call close()?") diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index 2fe2fdfc4..d94fa1812 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Optional, TypeVar, Union import torch as th @@ -55,7 +55,7 @@ class DDPG(TD3): def __init__( self, - policy: Union[str, Type[TD3Policy]], + policy: Union[str, type[TD3Policy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 1e-3, buffer_size: int = 1_000_000, # 1e6 @@ -63,14 +63,14 @@ def __init__( batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = 1, + train_freq: Union[int, tuple[int, str]] = 1, gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[Type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + replay_buffer_class: Optional[type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[dict[str, Any]] = None, optimize_memory_usage: bool = False, tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: Union[th.device, str] = "auto", diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 894ed9f04..a3f200e59 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, ClassVar, Optional, TypeVar, Union import numpy as np import torch as th @@ -62,7 +62,7 @@ class DQN(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ - policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = { + policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { "MlpPolicy": MlpPolicy, "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, @@ -75,7 +75,7 @@ class DQN(OffPolicyAlgorithm): def __init__( self, - policy: Union[str, Type[DQNPolicy]], + policy: Union[str, type[DQNPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 1e-4, buffer_size: int = 1_000_000, # 1e6 @@ -83,10 +83,10 @@ def __init__( batch_size: int = 32, tau: float = 1.0, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = 4, + train_freq: Union[int, tuple[int, str]] = 4, gradient_steps: int = 1, - replay_buffer_class: Optional[Type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + replay_buffer_class: Optional[type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[dict[str, Any]] = None, optimize_memory_usage: bool = False, target_update_interval: int = 10000, exploration_fraction: float = 0.1, @@ -95,7 +95,7 @@ def __init__( max_grad_norm: float = 10, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: Union[th.device, str] = "auto", @@ -227,11 +227,11 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: def predict( self, - observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[Tuple[np.ndarray, ...]] = None, + observation: Union[np.ndarray, dict[str, np.ndarray]], + state: Optional[tuple[np.ndarray, ...]] = None, 
episode_start: Optional[np.ndarray] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]: """ Overrides the base_class predict function to include epsilon-greedy exploration. @@ -273,10 +273,10 @@ def learn( progress_bar=progress_bar, ) - def _excluded_save_params(self) -> List[str]: + def _excluded_save_params(self) -> list[str]: return [*super()._excluded_save_params(), "q_net", "q_net_target"] - def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + def _get_torch_save_params(self) -> tuple[list[str], list[str]]: state_dicts = ["policy", "policy.optimizer"] return state_dicts, [] diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index bfefc8137..95f05d8ca 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional import torch as th from gymnasium import spaces @@ -35,8 +35,8 @@ def __init__( action_space: spaces.Discrete, features_extractor: BaseFeaturesExtractor, features_dim: int, - net_arch: Optional[List[int]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, + net_arch: Optional[list[int]] = None, + activation_fn: type[nn.Module] = nn.ReLU, normalize_images: bool = True, ) -> None: super().__init__( @@ -71,7 +71,7 @@ def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Te action = q_values.argmax(dim=1).reshape(-1) return action - def _get_constructor_parameters(self) -> Dict[str, Any]: + def _get_constructor_parameters(self) -> dict[str, Any]: data = super()._get_constructor_parameters() data.update( @@ -113,13 +113,13 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Discrete, lr_schedule: Schedule, - net_arch: Optional[List[int]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + net_arch: Optional[list[int]] = None, + activation_fn: type[nn.Module] = nn.ReLU, + features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, ) -> None: super().__init__( observation_space, @@ -183,7 +183,7 @@ def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor: def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor: return self.q_net._predict(obs, deterministic=deterministic) - def _get_constructor_parameters(self) -> Dict[str, Any]: + def _get_constructor_parameters(self) -> dict[str, Any]: data = super()._get_constructor_parameters() data.update( @@ -237,13 +237,13 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Discrete, lr_schedule: Schedule, - net_arch: Optional[List[int]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, - features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + net_arch: Optional[list[int]] = None, + activation_fn: type[nn.Module] = nn.ReLU, + features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: 
Optional[dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, ) -> None: super().__init__( observation_space, @@ -282,13 +282,13 @@ def __init__( observation_space: spaces.Dict, action_space: spaces.Discrete, lr_schedule: Schedule, - net_arch: Optional[List[int]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, - features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + net_arch: Optional[list[int]] = None, + activation_fn: type[nn.Module] = nn.ReLU, + features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, ) -> None: super().__init__( observation_space, diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 20214e72c..956aabc92 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -1,6 +1,6 @@ import copy import warnings -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union import numpy as np import torch as th @@ -98,7 +98,7 @@ def __init__( self.ep_length = np.zeros((self.buffer_size, self.n_envs), dtype=np.int64) self._current_ep_start = np.zeros(self.n_envs, dtype=np.int64) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: """ Gets state for pickling. @@ -109,7 +109,7 @@ def __getstate__(self) -> Dict[str, Any]: del state["env"] return state - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: """ Restores pickled state. @@ -134,12 +134,12 @@ def set_env(self, env: VecEnv) -> None: def add( # type: ignore[override] self, - obs: Dict[str, np.ndarray], - next_obs: Dict[str, np.ndarray], + obs: dict[str, np.ndarray], + next_obs: dict[str, np.ndarray], action: np.ndarray, reward: np.ndarray, done: np.ndarray, - infos: List[Dict[str, Any]], + infos: list[dict[str, Any]], ) -> None: # When the buffer is full, we rewrite on old episodes. 
When we start to # rewrite on an old episodes, we want the whole old episode to be deleted diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 52ee2eb64..03cbc2464 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Optional, TypeVar, Union import numpy as np import torch as th @@ -71,7 +71,7 @@ class PPO(OnPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ - policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = { + policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { "MlpPolicy": ActorCriticPolicy, "CnnPolicy": ActorCriticCnnPolicy, "MultiInputPolicy": MultiInputActorCriticPolicy, @@ -79,7 +79,7 @@ class PPO(OnPolicyAlgorithm): def __init__( self, - policy: Union[str, Type[ActorCriticPolicy]], + policy: Union[str, type[ActorCriticPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 2048, @@ -95,12 +95,12 @@ def __init__( max_grad_norm: float = 0.5, use_sde: bool = False, sde_sample_freq: int = -1, - rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, - rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, + rollout_buffer_class: Optional[type[RolloutBuffer]] = None, + rollout_buffer_kwargs: Optional[dict[str, Any]] = None, target_kl: Optional[float] = None, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: Union[th.device, str] = "auto", diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 6185e2992..330467727 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union import torch as th from gymnasium import spaces @@ -51,10 +51,10 @@ def __init__( self, observation_space: spaces.Space, action_space: spaces.Box, - net_arch: List[int], + net_arch: list[int], features_extractor: nn.Module, features_dim: int, - activation_fn: Type[nn.Module] = nn.ReLU, + activation_fn: type[nn.Module] = nn.ReLU, use_sde: bool = False, log_std_init: float = -3, full_std: bool = True, @@ -102,7 +102,7 @@ def __init__( self.mu = nn.Linear(last_layer_dim, action_dim) self.log_std = nn.Linear(last_layer_dim, action_dim) # type: ignore[assignment] - def _get_constructor_parameters(self) -> Dict[str, Any]: + def _get_constructor_parameters(self) -> dict[str, Any]: data = super()._get_constructor_parameters() data.update( @@ -144,7 +144,7 @@ def reset_noise(self, batch_size: int = 1) -> None: assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg self.action_dist.sample_weights(self.log_std, batch_size=batch_size) - def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: + def get_action_dist_params(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor, dict[str, th.Tensor]]: """ Get the parameters for the action distribution. 
@@ -169,7 +169,7 @@ def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor: # Note: the action is squashed return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs) - def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]: + def action_log_prob(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor]: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # return action and associated log prob return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs) @@ -216,17 +216,17 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.ReLU, use_sde: bool = False, log_std_init: float = -3, use_expln: bool = False, clip_mean: float = 2.0, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, ): @@ -309,7 +309,7 @@ def _build(self, lr_schedule: Schedule) -> None: # Target networks should always be in eval mode self.critic_target.set_training_mode(False) - def _get_constructor_parameters(self) -> Dict[str, Any]: + def _get_constructor_parameters(self) -> dict[str, Any]: data = super()._get_constructor_parameters() data.update( @@ -400,17 +400,17 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.ReLU, use_sde: bool = False, log_std_init: float = -3, use_expln: bool = False, clip_mean: float = 2.0, - features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, ): @@ -466,17 +466,17 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.ReLU, use_sde: bool = False, log_std_init: float = -3, use_expln: bool = False, clip_mean: float = 2.0, - features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = 
None, + features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, ): diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index bf0fa5028..8cb2ae53d 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, ClassVar, Optional, TypeVar, Union import numpy as np import torch as th @@ -77,7 +77,7 @@ class SAC(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ - policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = { + policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { "MlpPolicy": MlpPolicy, "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, @@ -89,7 +89,7 @@ class SAC(OffPolicyAlgorithm): def __init__( self, - policy: Union[str, Type[SACPolicy]], + policy: Union[str, type[SACPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 3e-4, buffer_size: int = 1_000_000, # 1e6 @@ -97,11 +97,11 @@ def __init__( batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = 1, + train_freq: Union[int, tuple[int, str]] = 1, gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[Type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + replay_buffer_class: Optional[type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[dict[str, Any]] = None, optimize_memory_usage: bool = False, ent_coef: Union[str, float] = "auto", target_update_interval: int = 1, @@ -111,7 +111,7 @@ def __init__( use_sde_at_warmup: bool = False, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: Union[th.device, str] = "auto", @@ -313,10 +313,10 @@ def learn( progress_bar=progress_bar, ) - def _excluded_save_params(self) -> List[str]: + def _excluded_save_params(self) -> list[str]: return super()._excluded_save_params() + ["actor", "critic", "critic_target"] # noqa: RUF005 - def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + def _get_torch_save_params(self) -> tuple[list[str], list[str]]: state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] if self.ent_coef_optimizer is not None: saved_pytorch_variables = ["log_ent_coef"] diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index a15be0396..aa7ea8069 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Optional, Union import torch as th from gymnasium import spaces @@ -36,10 +36,10 @@ def __init__( self, observation_space: spaces.Space, action_space: spaces.Box, - net_arch: List[int], + net_arch: list[int], features_extractor: nn.Module, features_dim: int, - activation_fn: Type[nn.Module] = nn.ReLU, + activation_fn: type[nn.Module] = nn.ReLU, 
normalize_images: bool = True, ): super().__init__( @@ -59,7 +59,7 @@ def __init__( # Deterministic action self.mu = nn.Sequential(*actor_net) - def _get_constructor_parameters(self) -> Dict[str, Any]: + def _get_constructor_parameters(self) -> dict[str, Any]: data = super()._get_constructor_parameters() data.update( @@ -116,13 +116,13 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.ReLU, + features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, ): @@ -207,7 +207,7 @@ def _build(self, lr_schedule: Schedule) -> None: self.actor_target.set_training_mode(False) self.critic_target.set_training_mode(False) - def _get_constructor_parameters(self) -> Dict[str, Any]: + def _get_constructor_parameters(self) -> dict[str, Any]: data = super()._get_constructor_parameters() data.update( @@ -285,13 +285,13 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, - features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.ReLU, + features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, ): @@ -339,13 +339,13 @@ def __init__( observation_space: spaces.Dict, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, - features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.ReLU, + features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, ): diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 
a61d954bc..affb9c9f8 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, ClassVar, Optional, TypeVar, Union import numpy as np import torch as th @@ -65,7 +65,7 @@ class TD3(OffPolicyAlgorithm): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ - policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = { + policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = { "MlpPolicy": MlpPolicy, "CnnPolicy": CnnPolicy, "MultiInputPolicy": MultiInputPolicy, @@ -78,7 +78,7 @@ class TD3(OffPolicyAlgorithm): def __init__( self, - policy: Union[str, Type[TD3Policy]], + policy: Union[str, type[TD3Policy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 1e-3, buffer_size: int = 1_000_000, # 1e6 @@ -86,18 +86,18 @@ def __init__( batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = 1, + train_freq: Union[int, tuple[int, str]] = 1, gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[Type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + replay_buffer_class: Optional[type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[dict[str, Any]] = None, optimize_memory_usage: bool = False, policy_delay: int = 2, target_policy_noise: float = 0.2, target_noise_clip: float = 0.5, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: Union[th.device, str] = "auto", @@ -228,9 +228,9 @@ def learn( progress_bar=progress_bar, ) - def _excluded_save_params(self) -> List[str]: + def _excluded_save_params(self) -> list[str]: return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"] # noqa: RUF005 - def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + def _get_torch_save_params(self) -> tuple[list[str], list[str]]: state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] return state_dicts, [] diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 636c433a1..b8feefb94 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a9 +2.5.0a0 diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index f093e47e7..305f56fc2 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional import gymnasium as gym import numpy as np @@ -72,7 +72,7 @@ def step(self, action): terminated = truncated = False return self.observation_space.sample(), reward, terminated, truncated, {} - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): if seed is not None: self.observation_space.seed(seed) return self.observation_space.sample(), {} @@ -117,12 +117,11 @@ def test_consistency(model_class): """ use_discrete_actions = model_class == DQN dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True) + dict_env.seed(10) dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) - dict_env.seed(10) obs, _ = dict_env.reset() - kwargs = {} n_steps = 256 if model_class 
in {A2C, PPO}: diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 48eae12d0..0f8ae6253 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -1,5 +1,4 @@ from copy import deepcopy -from typing import Tuple import gymnasium as gym import numpy as np @@ -55,7 +54,7 @@ def test_squashed_gaussian(model_class): @pytest.fixture() -def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.ndarray, np.ndarray]: +def dummy_model_distribution_obs_and_actions() -> tuple[A2C, np.ndarray, np.ndarray]: """ Fixture creating a Pendulum-v1 gym env, an A2C model and sampling 10 random observations and actions from the env :return: A2C model, random observations, random actions diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 62dd6ffa6..c1598cad4 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional import gymnasium as gym import numpy as np @@ -135,7 +135,7 @@ def test_check_env_detailed_error(obs_tuple, method): class TestEnv(gym.Env): action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): return wrong_obs if method == "reset" else good_obs, {} def step(self, action): @@ -162,7 +162,7 @@ def __init__(self, steps_before_termination: int = 1): self._steps_called = 0 self._terminated = False - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[int, Dict]: + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[int, dict]: super().reset(seed=seed) self._steps_called = 0 @@ -170,7 +170,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) - return 0, {} - def step(self, action: np.ndarray) -> Tuple[int, float, bool, bool, Dict[str, Any]]: + def step(self, action: np.ndarray) -> tuple[int, float, bool, bool, dict[str, Any]]: self._steps_called += 1 assert not self._terminated diff --git a/tests/test_gae.py b/tests/test_gae.py index 83b95a4c0..6e32d0f87 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Optional import gymnasium as gym import numpy as np @@ -23,7 +23,7 @@ def __init__(self, max_steps=8): def seed(self, seed): self.observation_space.seed(seed) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): if seed is not None: self.observation_space.seed(seed) self.n_steps = 0 @@ -53,7 +53,7 @@ def __init__(self, n_states=4): self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.current_state = 0 - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): if seed is not None: super().reset(seed=seed) @@ -73,7 +73,7 @@ def _on_rollout_end(self): buffer = self.model.rollout_buffer rollout_size = buffer.size() - max_steps = self.training_env.envs[0].max_steps + max_steps = self.training_env.envs[0].get_wrapper_attr("max_steps") gamma = self.model.gamma gae_lambda = self.model.gae_lambda value = self.model.policy.constant_value diff --git a/tests/test_logger.py b/tests/test_logger.py index bc18bf2ce..039e3f4ac 100644 --- 
a/tests/test_logger.py +++ b/tests/test_logger.py @@ -2,8 +2,8 @@ import os import sys import time +from collections.abc import Sequence from io import TextIOBase -from typing import Sequence from unittest import mock import gymnasium as gym @@ -592,6 +592,7 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path): """ STATS_WINDOW_SIZE = 10 + # Add dummy successes with 0.3, 0.5 and 0.8 success_rate of length STATS_WINDOW_SIZE dummy_successes = [ [True] * 3 + [False] * 7, @@ -603,16 +604,17 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path): # Monitor the env to track the success info monitor_file = str(tmp_path / "monitor.csv") env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",)) + steps_per_log = env.unwrapped.steps_per_log # Equip the model of a custom logger to check the success_rate info - model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1) + model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=steps_per_log, verbose=1) logger = InMemoryLogger() model.set_logger(logger) # Make the model learn and check that the success rate corresponds to the ratio of dummy successes - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.3 - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.5 - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.8 diff --git a/tests/test_run.py b/tests/test_run.py index 31c7b956e..4acabb692 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,6 +1,7 @@ import gymnasium as gym import numpy as np import pytest +import torch as th from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_util import make_vec_env @@ -211,8 +212,11 @@ def test_warn_dqn_multi_env(): def test_ppo_warnings(): - """Test that PPO warns and errors correctly on - problematic rollout buffer sizes""" + """ + Test that PPO warns and errors correctly on + problematic rollout buffer sizes, + and recommend using CPU. 
+ """ # Only 1 step: advantage normalization will return NaN with pytest.raises(AssertionError): @@ -234,3 +238,9 @@ def test_ppo_warnings(): loss = model.logger.name_to_value["train/loss"] assert loss > 0 assert not np.isnan(loss) # check not nan (since nan does not equal nan) + + with pytest.warns(UserWarning, match="You are trying to run PPO on the GPU"): + model = PPO("MlpPolicy", "Pendulum-v1") + # Pretend to be on the GPU + model.device = th.device("cuda") + model._maybe_recommend_cpu() diff --git a/tests/test_spaces.py b/tests/test_spaces.py index e006c1f96..102c0ef8f 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,9 +1,10 @@ from dataclasses import dataclass -from typing import Dict, Optional +from typing import Optional import gymnasium as gym import numpy as np import pytest +import torch as th from gymnasium import spaces from gymnasium.spaces.space import Space @@ -24,7 +25,7 @@ class DummyEnv(gym.Env): def step(self, action): return self.observation_space.sample(), 0.0, False, False, {} - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {} @@ -151,6 +152,8 @@ def test_discrete_obs_space(model_class, env): ], ) def test_float64_action_space(model_class, obs_space, action_space): + if hasattr(th, "backends") and th.backends.mps.is_built(): + pytest.skip("MPS framework doesn't support float64") env = DummyEnv(obs_space, action_space) env = gym.wrappers.TimeLimit(env, max_episode_steps=200) if isinstance(env.observation_space, spaces.Dict): diff --git a/tests/test_tensorboard.py b/tests/test_tensorboard.py index eee0ec0aa..b81586367 100644 --- a/tests/test_tensorboard.py +++ b/tests/test_tensorboard.py @@ -1,5 +1,5 @@ import os -from typing import Dict, Union +from typing import Union import pytest @@ -24,7 +24,7 @@ class HParamCallback(BaseCallback): """ def _on_training_start(self) -> None: - hparam_dict: Dict[str, Union[str, float]] = { + hparam_dict: dict[str, Union[str, float]] = { "algorithm": self.model.__class__.__name__, # Ignore type checking for gamma, see https://github.com/DLR-RM/stable-baselines3/pull/1194/files#r1035006458 "gamma": self.model.gamma, # type: ignore[attr-defined] @@ -33,7 +33,7 @@ def _on_training_start(self) -> None: hparam_dict["learning rate"] = self.model.learning_rate # define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag # Tensorbaord will find & display metrics from the `SCALARS` tab - metric_dict: Dict[str, float] = { + metric_dict: dict[str, float] = { "rollout/ep_len_mean": 0, } self.logger.record( diff --git a/tests/test_utils.py b/tests/test_utils.py index 4cc8b7e9f..1280f5772 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,7 @@ import os import shutil +import ale_py import gymnasium as gym import numpy as np import pytest @@ -24,6 +25,8 @@ ) from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv +gym.register_envs(ale_py) + @pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")]) @pytest.mark.parametrize("n_envs", [1, 2]) @@ -177,7 +180,7 @@ def test_custom_vec_env(tmp_path): @pytest.mark.parametrize("direct_policy", [False, True]) -def test_evaluate_policy(direct_policy: bool): +def test_evaluate_policy(direct_policy): model = A2C("MlpPolicy", "Pendulum-v1", seed=0) n_steps_per_episode, n_eval_episodes = 200, 2 @@ 
-442,9 +445,10 @@ def test_get_system_info(): assert info["Stable-Baselines3"] == str(sb3.__version__) assert "Python" in info_str assert "PyTorch" in info_str - assert "GPU Enabled" in info_str + assert "Accelerator" in info_str assert "Numpy" in info_str assert "Gym" in info_str + assert "Cloudpickle" in info_str def test_is_vectorized_observation(): diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index a9516ae25..93185fa3c 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -4,7 +4,7 @@ import multiprocessing import os import warnings -from typing import Dict, Optional +from typing import Optional import gymnasium as gym import numpy as np @@ -30,9 +30,9 @@ def __init__(self, space, render_mode: str = "rgb_array"): self.current_step = 0 self.ep_length = 4 self.render_mode = render_mode - self.current_options: Optional[Dict] = None + self.current_options: Optional[dict] = None - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): if seed is not None: self.seed(seed) self.current_step = 0 @@ -193,7 +193,7 @@ def __init__(self, max_steps): self.max_steps = max_steps self.current_step = 0 - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): self.current_step = 0 return np.array([self.current_step], dtype="int"), {} @@ -307,7 +307,7 @@ def test_vecenv_dict_spaces(vec_env_class): space = spaces.Dict(SPACES) def obs_assert(obs): - assert isinstance(obs, collections.OrderedDict) + assert isinstance(obs, dict) assert obs.keys() == space.spaces.keys() for key, values in obs.items(): check_vecenv_obs(values, space.spaces[key]) diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index b7d71b748..3db5fcd47 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -1,5 +1,5 @@ import operator -from typing import Any, Dict, Optional +from typing import Any, Optional import gymnasium as gym import numpy as np @@ -22,7 +22,7 @@ class DummyRewardEnv(gym.Env): - metadata: Dict[str, Any] = {} + metadata: dict[str, Any] = {} def __init__(self, return_reward_idx=0): self.action_space = spaces.Discrete(2) @@ -39,7 +39,7 @@ def step(self, action): truncated = self.t == len(self.returned_rewards) return np.array([returned_value]), returned_value, terminated, truncated, {} - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): if seed is not None: super().reset(seed=seed) self.t = 0 @@ -62,7 +62,7 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {} @@ -94,7 +94,7 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): if seed is not None: super().reset(seed=seed) return self.observation_space.sample(), {}
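Most hunks in this diff swap `typing.Dict`, `List`, `Tuple` and `Type` annotations for the builtin generics standardized by PEP 585, which are usable as annotations from Python 3.9; `Optional` and `Union` are kept because the `X | Y` union syntax only became available at runtime in Python 3.10. A minimal before/after sketch with a hypothetical helper, not taken from the codebase:

    from typing import Any, Optional

    # Before: from typing import Dict, List
    # def make_kwargs(keys: List[str], default: Optional[Any] = None) -> Dict[str, Any]: ...

    # After: builtin generics, as used throughout this diff
    def make_kwargs(keys: list[str], default: Optional[Any] = None) -> dict[str, Any]:
        """Map each key to a default value (illustration only)."""
        return {key: default for key in keys}

    assert make_kwargs(["tau", "gamma"]) == {"tau": None, "gamma": None}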