Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tests for mps support #1

Open
wants to merge 50 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
ace0516
Use MPS device when available
araffin Jul 4, 2022
9ac6225
Merge branch 'master' into feat/mps-support
araffin Aug 13, 2022
2dcbef9
Update test
araffin Aug 13, 2022
b00ca7f
Merge branch 'master' into feat/mps-support
araffin Aug 16, 2022
06a2124
Merge branch 'master' into feat/mps-support
qgallouedec Sep 28, 2022
6d868c0
Merge branch 'master' into feat/mps-support
qgallouedec Oct 4, 2022
8d79e96
Merge branch 'master' into feat/mps-support
qgallouedec Oct 7, 2022
3276cb0
Merge branch 'master' into feat/mps-support
qgallouedec Oct 10, 2022
f4f6073
Merge branch 'master' into feat/mps-support
qgallouedec Oct 14, 2022
64327c7
Merge branch 'master' into feat/mps-support
qgallouedec Oct 17, 2022
0344c3c
Merge branch 'master' into feat/mps-support
araffin Oct 24, 2022
fa196ab
Merge branch 'master' into feat/mps-support
qgallouedec Nov 2, 2022
efd086e
Merge branch 'master' into feat/mps-support
araffin Nov 18, 2022
7f11843
Merge branch 'master' into feat/mps-support
qgallouedec Dec 7, 2022
c60f681
Merge branch 'master' into feat/mps-support
qgallouedec Dec 20, 2022
92e8d11
Merge branch 'master' into feat/mps-support
araffin Jan 13, 2023
b235c8e
Merge branch 'master' into feat/mps-support
qgallouedec Feb 14, 2023
d4d0536
Merge branch 'master' into feat/mps-support
araffin Apr 3, 2023
0311b62
Merge branch 'master' into feat/mps-support
araffin Apr 21, 2023
086f79a
Merge branch 'master' into feat/mps-support
araffin May 3, 2023
fe606fc
Merge branch 'master' into feat/mps-support
araffin May 24, 2023
34f4819
Merge branch 'master' into feat/mps-support
qgallouedec Jun 30, 2023
ef39571
Merge branch 'master' into feat/mps-support
araffin Aug 17, 2023
d26324c
Merge branch 'master' into feat/mps-support
araffin Aug 30, 2023
1e5dc90
Merge branch 'master' into feat/mps-support
araffin Oct 6, 2023
40ed03c
mps.is_available -> mps.is_built
qgallouedec Oct 6, 2023
e83924b
docstring
qgallouedec Oct 6, 2023
b707480
Merge branch 'master' into feat/mps-support
qgallouedec Nov 2, 2023
81e3c63
Merge branch 'master' into feat/mps-support
araffin Nov 16, 2023
f0e54a7
Merge branch 'master' into feat/mps-support
araffin Jan 10, 2024
d47c586
Merge branch 'master' into feat/mps-support
araffin Apr 18, 2024
b85a2a5
Fix warning
araffin Apr 18, 2024
1c25053
Fix tests
deathcoder Sep 14, 2024
f822ef5
Attempt fix ci: only cast reward from float64 to float32
deathcoder Sep 17, 2024
1ac4a60
allow running workflows from ui
deathcoder Sep 17, 2024
9970f51
Merge pull request #2 from deathcoder/attempt-fix-ci
deathcoder Sep 17, 2024
955382e
Merge branch 'master' into feat/mps-support
araffin Sep 18, 2024
56c153f
Add warning when using PPO on GPU and update doc (#2017)
Dev1nW Oct 7, 2024
3d59b5c
Use uv on GitHub CI for faster download and update changelog (#2026)
araffin Oct 24, 2024
dd3d0ac
Update readme and clarify planned features (#2030)
araffin Oct 29, 2024
5e7372d
Merge branch 'feat/mps-support' into feat/mps-support
araffin Oct 29, 2024
263e657
Merge branch 'master' into feat/mps-support
araffin Oct 29, 2024
8f0b488
Update Gymnasium to v1.0.0 (#1837)
pseudo-rnd-thoughts Nov 4, 2024
e4f4f12
Add note about SAC ent coeff optimization (#2037)
araffin Nov 8, 2024
7c71688
Merge branch 'master' into feat/mps-support
araffin Nov 8, 2024
4c03a25
Merge remote-tracking branch 'origin/feat/mps-support' into feat/mps-…
araffin Nov 8, 2024
020ee42
Release 2.4.0 (#2040)
araffin Nov 18, 2024
daaebd0
Drop python 3.8 and add python 3.12 support (#2041)
araffin Nov 18, 2024
9489b1a
Merge branch 'master' into feat/mps-support
araffin Nov 18, 2024
0ec37d8
Merge branch 'feat/mps-support' into feat/mps-support
araffin Nov 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

python-version: ["3.9", "3.10", "3.11", "3.12"]
include:
# Default version
- gymnasium-version: "1.0.0"
# Add a new config to test gym<1.0
- python-version: "3.10"
gymnasium-version: "0.29.1"
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -31,18 +36,21 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# cpu version of pytorch
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu

# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu

pip install .[extra_no_roms,tests,docs]
uv pip install --system .[extra,tests,docs]
# Use headless version
pip install opencv-python-headless
uv pip install --system opencv-python-headless
- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
uv pip install --system "numpy<2"
# Only run for python 3.10, downgrade gym to 0.29.1, numpy<2
if: matrix.gymnasium-version != '1.0.0'
- name: Lint with ruff
run: |
make lint
Expand All @@ -55,8 +63,6 @@ jobs:
- name: Type check
run: |
make type
# Do not run for python 3.8 (mypy internal error)
if: matrix.python-version != '3.8'
- name: Test with pytest
run: |
make pytest
4 changes: 2 additions & 2 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ conda:
environment: docs/conda_env.yml

build:
os: ubuntu-22.04
os: ubuntu-24.04
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Ubuntu 24.04 is not yet officially supported by Read the Docs

Ubuntu 24.04 (Jammy) was just released in April 2024 and Read the Docs typically takes some time to officially support new Ubuntu LTS versions. I recommend keeping ubuntu-22.04 for now to ensure stable documentation builds. You can upgrade to 24.04 once it's officially supported by Read the Docs.

  • Revert the OS version in .readthedocs.yml back to ubuntu-22.04
  • Keep the Python environment update to mambaforge-23.11 as it's independent of the OS version
🔗 Analysis chain

Verify Ubuntu 24.04 support on Read the Docs.

While upgrading to Ubuntu 24.04 LTS is forward-thinking, we should verify that Read the Docs officially supports this version to ensure stable documentation builds.


Let me gather more information about Read the Docs' Ubuntu support.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if ubuntu-24.04 is mentioned in the RTD documentation or changelog
rg -i "ubuntu.?24\.04|jammy" README.md docs/

Length of output: 474


Script:

#!/bin/bash
# Check .readthedocs.yml history to understand the OS version transition
git log -p -- .readthedocs.yml | grep -A 5 -B 5 "os:"

# Fetch current .readthedocs.yml content
cat .readthedocs.yml

Length of output: 944

tools:
python: "mambaforge-22.9"
python: "mambaforge-23.11"
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ into two categories:
- Create an issue about your intended feature, and we shall discuss the design and
implementation. Once we agree that the plan looks good, go ahead and implement it.
2. You want to implement a feature or bug-fix for an outstanding issue
- Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/issues
- Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted
- Pick an issue or feature and comment on the task that you want to work on this feature.
- If you need more context on a particular issue, please ask, and we shall provide.

Expand Down
36 changes: 23 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<!-- [![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) -->
![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)
[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
[![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)


Expand All @@ -22,6 +22,8 @@ These algorithms will make it easier for the research community and industry to
**The performance of each algorithm was tested** (see *Results* section in their respective page),
you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.

We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform.


| **Features** | **Stable-Baselines3** |
| --------------------------- | ----------------------|
Expand All @@ -41,7 +43,13 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin

### Planned features

Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones).
Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*.
If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement).

While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories:
- newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository
- faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository
- the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299)

## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)

Expand Down Expand Up @@ -79,7 +87,7 @@ Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/

We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).

Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)

Expand All @@ -92,22 +100,21 @@ It provides a minimal number of features compared to SB3 but can be much faster

## Installation

**Note:** Stable-Baselines3 supports PyTorch >= 1.13
**Note:** Stable-Baselines3 supports PyTorch >= 2.3

### Prerequisites
Stable Baselines3 requires Python 3.8+.
Stable Baselines3 requires Python 3.9+.

#### Windows 10
#### Windows

To install stable-baselines on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites).


### Install using pip
Install the Stable Baselines3 package:
```sh
pip install 'stable-baselines3[extra]'
```
pip install stable-baselines3[extra]
```
**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)).

This includes an optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use:
```sh
Expand Down Expand Up @@ -177,6 +184,7 @@ All the following examples can be executed online using Google Colab notebooks:
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| ARS<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CrossQ<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
Expand All @@ -191,7 +199,7 @@ All the following examples can be executed online using Google Colab notebooks:

<b id="f1">1</b>: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.

Actions `gym.spaces`:
Actions `gymnasium.spaces`:
* `Box`: A N-dimensional box that contains every point in the action space.
* `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
* `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
Expand All @@ -218,9 +226,9 @@ To run a single test:
python3 -m pytest -v -k 'test_check_env_dict_action'
```

You can also do a static type check using `pytype` and `mypy`:
You can also do a static type check using `mypy`:
```sh
pip install pytype mypy
pip install mypy
make type
```

Expand Down Expand Up @@ -252,6 +260,8 @@ To cite this repository in publications:
}
```

Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988).

## Maintainers

Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
Expand Down
14 changes: 7 additions & 7 deletions docs/conda_env.yml
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
name: root
channels:
- pytorch
- defaults
- conda-forge
dependencies:
- cpuonly=1.0=0
- pip=22.3.1
- python=3.8
- pytorch=1.13.0=py3.8_cpu_0
- pip=24.2
- python=3.11
- pytorch=2.5.0=py3.11_cpu_0
- pip:
- gymnasium
- gymnasium>=0.29.1,<1.1.0
- cloudpickle
- opencv-python-headless
- pandas
- numpy
- numpy>=1.20,<3.0
- matplotlib
- sphinx>=5,<8
- sphinx>=5,<9
- sphinx_rtd_theme>=1.3.0
- sphinx_copybutton
3 changes: 1 addition & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import datetime
import os
import sys
from typing import Dict

# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
# PyEnchant.
Expand Down Expand Up @@ -151,7 +150,7 @@ def setup(app):

# -- Options for LaTeX output ------------------------------------------------

latex_elements: Dict[str, str] = {
latex_elements: dict[str, str] = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
Expand Down
1 change: 1 addition & 0 deletions docs/guide/algos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary``
=================== =========== ============ ================= =============== ================
ARS [#f1]_ ✔️ ✔️ ❌ ❌ ✔️
A2C ✔️ ✔️ ✔️ ✔️ ✔️
CrossQ [#f1]_ ✔️ ❌ ❌ ❌ ✔️
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Inconsistency in CrossQ implementation location.

There appears to be an inconsistency in the documentation. The table indicates CrossQ is implemented in SB3 Contrib (via footnote [#f1]), but the note at the bottom of the file mentions it's in the SBX repo. Please clarify which repository actually contains the CrossQ implementation.

DDPG ✔️ ❌ ❌ ❌ ✔️
DQN ❌ ✔️ ❌ ❌ ✔️
HER ✔️ ✔️ ❌ ❌ ✔️
Expand Down
2 changes: 1 addition & 1 deletion docs/guide/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Installation
Prerequisites
-------------

Stable-Baselines3 requires python 3.8+ and PyTorch >= 1.13
Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3

Windows
~~~~~~~
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix Python version inconsistency in Windows section

The Windows section still references Python 3.8, which contradicts the new 3.9+ requirement.

Apply this change:

-You need an environment with Python version 3.8 or above.
+You need an environment with Python version 3.9 or above.

Committable suggestion skipped: line range outside the PR's diff.

Expand Down
1 change: 1 addition & 0 deletions docs/guide/sb3_contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ See documentation for the full list of included features.
- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) <https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/>`_
- `Truncated Quantile Critics (TQC)`_
- `Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
- `Batch Normalization in Deep Reinforcement Learning (CrossQ) <https://openreview.net/forum?id=PczQtTsTIX>`_


**Gym Wrappers**:
Expand Down
6 changes: 5 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for train

SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

SBX (SB3 + Jax): https://github.com/araffin/sbx


Main Features
--------------
Expand Down Expand Up @@ -113,12 +115,14 @@ To cite this project in publications:
url = {http://jmlr.org/papers/v22/20-1364.html}
}

Note: If you need to refer to a specific version of SB3, you can also use the `Zenodo DOI <https://doi.org/10.5281/zenodo.8123988>`_.

Contributing
------------

To any interested in making the rl baselines better, there are still some improvements
that need to be done.
You can check issues in the `repo <https://github.com/DLR-RM/stable-baselines3/issues>`_.
You can check issues in the `repository <https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted>`_.

If you want to contribute, please read `CONTRIBUTING.md <https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md>`_ first.

Expand Down
Loading
Loading