Fix tests for mps support #1

Open: wants to merge 50 commits into master

Changes from 43 commits

Commits (50)
ace0516
Use MPS device when available
araffin Jul 4, 2022
9ac6225
Merge branch 'master' into feat/mps-support
araffin Aug 13, 2022
2dcbef9
Update test
araffin Aug 13, 2022
b00ca7f
Merge branch 'master' into feat/mps-support
araffin Aug 16, 2022
06a2124
Merge branch 'master' into feat/mps-support
qgallouedec Sep 28, 2022
6d868c0
Merge branch 'master' into feat/mps-support
qgallouedec Oct 4, 2022
8d79e96
Merge branch 'master' into feat/mps-support
qgallouedec Oct 7, 2022
3276cb0
Merge branch 'master' into feat/mps-support
qgallouedec Oct 10, 2022
f4f6073
Merge branch 'master' into feat/mps-support
qgallouedec Oct 14, 2022
64327c7
Merge branch 'master' into feat/mps-support
qgallouedec Oct 17, 2022
0344c3c
Merge branch 'master' into feat/mps-support
araffin Oct 24, 2022
fa196ab
Merge branch 'master' into feat/mps-support
qgallouedec Nov 2, 2022
efd086e
Merge branch 'master' into feat/mps-support
araffin Nov 18, 2022
7f11843
Merge branch 'master' into feat/mps-support
qgallouedec Dec 7, 2022
c60f681
Merge branch 'master' into feat/mps-support
qgallouedec Dec 20, 2022
92e8d11
Merge branch 'master' into feat/mps-support
araffin Jan 13, 2023
b235c8e
Merge branch 'master' into feat/mps-support
qgallouedec Feb 14, 2023
d4d0536
Merge branch 'master' into feat/mps-support
araffin Apr 3, 2023
0311b62
Merge branch 'master' into feat/mps-support
araffin Apr 21, 2023
086f79a
Merge branch 'master' into feat/mps-support
araffin May 3, 2023
fe606fc
Merge branch 'master' into feat/mps-support
araffin May 24, 2023
34f4819
Merge branch 'master' into feat/mps-support
qgallouedec Jun 30, 2023
ef39571
Merge branch 'master' into feat/mps-support
araffin Aug 17, 2023
d26324c
Merge branch 'master' into feat/mps-support
araffin Aug 30, 2023
1e5dc90
Merge branch 'master' into feat/mps-support
araffin Oct 6, 2023
40ed03c
mps.is_available -> mps.is_built
qgallouedec Oct 6, 2023
e83924b
docstring
qgallouedec Oct 6, 2023
b707480
Merge branch 'master' into feat/mps-support
qgallouedec Nov 2, 2023
81e3c63
Merge branch 'master' into feat/mps-support
araffin Nov 16, 2023
f0e54a7
Merge branch 'master' into feat/mps-support
araffin Jan 10, 2024
d47c586
Merge branch 'master' into feat/mps-support
araffin Apr 18, 2024
b85a2a5
Fix warning
araffin Apr 18, 2024
1c25053
Fix tests
deathcoder Sep 14, 2024
f822ef5
Attempt fix ci: only cast reward from float64 to float32
deathcoder Sep 17, 2024
1ac4a60
allow running workflows from ui
deathcoder Sep 17, 2024
9970f51
Merge pull request #2 from deathcoder/attempt-fix-ci
deathcoder Sep 17, 2024
955382e
Merge branch 'master' into feat/mps-support
araffin Sep 18, 2024
56c153f
Add warning when using PPO on GPU and update doc (#2017)
Dev1nW Oct 7, 2024
3d59b5c
Use uv on GitHub CI for faster download and update changelog (#2026)
araffin Oct 24, 2024
dd3d0ac
Update readme and clarify planned features (#2030)
araffin Oct 29, 2024
5e7372d
Merge branch 'feat/mps-support' into feat/mps-support
araffin Oct 29, 2024
263e657
Merge branch 'master' into feat/mps-support
araffin Oct 29, 2024
8f0b488
Update Gymnasium to v1.0.0 (#1837)
pseudo-rnd-thoughts Nov 4, 2024
e4f4f12
Add note about SAC ent coeff optimization (#2037)
araffin Nov 8, 2024
7c71688
Merge branch 'master' into feat/mps-support
araffin Nov 8, 2024
4c03a25
Merge remote-tracking branch 'origin/feat/mps-support' into feat/mps-…
araffin Nov 8, 2024
020ee42
Release 2.4.0 (#2040)
araffin Nov 18, 2024
daaebd0
Drop python 3.8 and add python 3.12 support (#2041)
araffin Nov 18, 2024
9489b1a
Merge branch 'master' into feat/mps-support
araffin Nov 18, 2024
0ec37d8
Merge branch 'feat/mps-support' into feat/mps-support
araffin Nov 18, 2024
11 changes: 7 additions & 4 deletions .github/workflows/ci.yml
@@ -31,18 +31,21 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# cpu version of pytorch
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu

# Install Atari Roms
pip install autorom
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz

pip install .[extra_no_roms,tests,docs]
uv pip install --system .[extra_no_roms,tests,docs]
# Use headless version
pip install opencv-python-headless
uv pip install --system opencv-python-headless
- name: Lint with ruff
run: |
make lint
4 changes: 2 additions & 2 deletions .readthedocs.yml
@@ -16,6 +16,6 @@ conda:
environment: docs/conda_env.yml

build:
os: ubuntu-22.04
os: ubuntu-24.04

💡 Codebase verification

Ubuntu 24.04 is not yet officially supported by Read the Docs

Ubuntu 24.04 (Noble Numbat) was just released in April 2024 and Read the Docs typically takes some time to officially support new Ubuntu LTS versions. I recommend keeping ubuntu-22.04 for now to ensure stable documentation builds. You can upgrade to 24.04 once it's officially supported by Read the Docs.

  • Revert the OS version in .readthedocs.yml back to ubuntu-22.04
  • Keep the Python environment update to mambaforge-23.11 as it's independent of the OS version
🔗 Analysis chain

Verify Ubuntu 24.04 support on Read the Docs.

While upgrading to Ubuntu 24.04 LTS is forward-thinking, we should verify that Read the Docs officially supports this version to ensure stable documentation builds.


Let me gather more information about Read the Docs' Ubuntu support.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check if ubuntu-24.04 is mentioned in the RTD documentation or changelog
rg -i "ubuntu.?24\.04|jammy" README.md docs/

Length of output: 474


Script:

#!/bin/bash
# Check .readthedocs.yml history to understand the OS version transition
git log -p -- .readthedocs.yml | grep -A 5 -B 5 "os:"

# Fetch current .readthedocs.yml content
cat .readthedocs.yml

Length of output: 944

tools:
python: "mambaforge-22.9"
python: "mambaforge-23.11"
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -6,7 +6,7 @@ into two categories:
- Create an issue about your intended feature, and we shall discuss the design and
implementation. Once we agree that the plan looks good, go ahead and implement it.
2. You want to implement a feature or bug-fix for an outstanding issue
- Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/issues
- Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted
- Pick an issue or feature and comment on the task that you want to work on this feature.
- If you need more context on a particular issue, please ask, and we shall provide.

32 changes: 21 additions & 11 deletions README.md
@@ -1,6 +1,6 @@
<!-- [![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) -->
![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)
[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
[![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)


@@ -22,6 +22,8 @@ These algorithms will make it easier for the research community and industry to
**The performance of each algorithm was tested** (see *Results* section in their respective page),
you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.

We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform.


| **Features** | **Stable-Baselines3** |
| --------------------------- | ----------------------|
@@ -41,7 +43,13 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin

### Planned features

Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones).
Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3; it is now *stable*.
If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement).

While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories:
- newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository
- faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository
- the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299)

## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)

@@ -79,7 +87,7 @@ Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/

We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).

Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)

@@ -97,17 +105,16 @@ It provides a minimal number of features compared to SB3 but can be much faster
### Prerequisites
Stable Baselines3 requires Python 3.8+.

#### Windows 10
#### Windows

To install stable-baselines on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites).


### Install using pip
Install the Stable Baselines3 package:
```sh
pip install 'stable-baselines3[extra]'
```
pip install stable-baselines3[extra]
```
**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)).

This includes optional dependencies like Tensorboard, OpenCV or `ale-py` to train on Atari games. If you do not need those, you can use:
```sh
@@ -177,6 +184,7 @@ All the following examples can be executed online using Google Colab notebooks:
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| ARS<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CrossQ<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
@@ -191,7 +199,7 @@ All the following examples can be executed online using Google Colab notebooks:

<b id="f1">1</b>: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.

Actions `gym.spaces`:
Actions `gymnasium.spaces`:
* `Box`: A N-dimensional box that contains every point in the action space.
* `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
* `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
@@ -218,9 +226,9 @@ To run a single test:
python3 -m pytest -v -k 'test_check_env_dict_action'
```

You can also do a static type check using `pytype` and `mypy`:
You can also do a static type check using `mypy`:
```sh
pip install pytype mypy
pip install mypy
make type
```

@@ -252,6 +260,8 @@ To cite this repository in publications:
}
```

Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988).

## Maintainers

Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
12 changes: 6 additions & 6 deletions docs/conda_env.yml
@@ -1,18 +1,18 @@
name: root
channels:
- pytorch
- defaults
- conda-forge
dependencies:
- cpuonly=1.0=0
- pip=22.3.1
- python=3.8
- pytorch=1.13.0=py3.8_cpu_0
- pip=24.2
- python=3.11
- pytorch=2.5.0=py3.11_cpu_0
- pip:
- gymnasium
- gymnasium>=0.28.1,<0.30
- cloudpickle
- opencv-python-headless
- pandas
- numpy
- numpy>=1.20,<2.0
- matplotlib
- sphinx>=5,<8
- sphinx_rtd_theme>=1.3.0
1 change: 1 addition & 0 deletions docs/guide/algos.rst
@@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary``
=================== =========== ============ ================= =============== ================
ARS [#f1]_ ✔️ ✔️ ❌ ❌ ✔️
A2C ✔️ ✔️ ✔️ ✔️ ✔️
CrossQ [#f1]_ ✔️ ❌ ❌ ❌ ✔️

⚠️ Potential issue

Inconsistency in CrossQ implementation location.

There appears to be an inconsistency in the documentation. The table indicates CrossQ is implemented in SB3 Contrib (via footnote [#f1]), but the note at the bottom of the file mentions it's in the SBX repo. Please clarify which repository actually contains the CrossQ implementation.

DDPG ✔️ ❌ ❌ ❌ ✔️
DQN ❌ ✔️ ❌ ❌ ✔️
HER ✔️ ✔️ ❌ ❌ ✔️
1 change: 1 addition & 0 deletions docs/guide/sb3_contrib.rst
@@ -42,6 +42,7 @@ See documentation for the full list of included features.
- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) <https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/>`_
- `Truncated Quantile Critics (TQC)`_
- `Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
- `Batch Normalization in Deep Reinforcement Learning (CrossQ) <https://openreview.net/forum?id=PczQtTsTIX>`_


**Gym Wrappers**:
Expand Down
4 changes: 3 additions & 1 deletion docs/index.rst
@@ -113,12 +113,14 @@ To cite this project in publications:
url = {http://jmlr.org/papers/v22/20-1364.html}
}

Note: If you need to refer to a specific version of SB3, you can also use the `Zenodo DOI <https://doi.org/10.5281/zenodo.8123988>`_.

Contributing
------------

To any interested in making the rl baselines better, there are still some improvements
that need to be done.
You can check issues in the `repo <https://github.com/DLR-RM/stable-baselines3/issues>`_.
You can check issues in the `repository <https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted>`_.

If you want to contribute, please read `CONTRIBUTING.md <https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md>`_ first.

15 changes: 14 additions & 1 deletion docs/misc/changelog.rst
@@ -3,9 +3,11 @@
Changelog
==========

Release 2.4.0a9 (WIP)
Release 2.4.0a10 (WIP)
--------------------------

**New algorithm: CrossQ in SB3 Contrib**

.. note::

DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
@@ -43,6 +45,10 @@ Bug Fixes:

`SB3-Contrib`_
^^^^^^^^^^^^^^
- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
- Fixed loading QRDQN changes `target_update_interval` (@jak3122)

`RL Zoo`_
^^^^^^^^^
@@ -60,12 +66,17 @@ Others:
- Fixed various typos (@cschindlbeck)
- Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
- Updated PyTorch version on CI to 2.3.1
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
- Switched to uv to download packages faster on GitHub CI
- Updated dependencies for Read the Docs

Bug Fixes:
^^^^^^^^^^

Documentation:
^^^^^^^^^^^^^^
- Updated PPO doc to recommend using CPU with ``MlpPolicy``
- Clarified documentation about planned features and citing software

Release 2.3.2 (2024-04-27)
--------------------------
@@ -653,6 +664,7 @@ New Features:
- Added checkpoints for replay buffer and ``VecNormalize`` statistics (@anand-bala)
- Added option for ``Monitor`` to append to existing file instead of overriding (@sidney-tio)
- The env checker now raises an error when using dict observation spaces and observation keys don't match observation space keys
- Use MacOS Metal "mps" device when available

`SB3-Contrib`_
^^^^^^^^^^^^^^
@@ -710,6 +722,7 @@ Breaking Changes:

New Features:
^^^^^^^^^^^^^
- Save cloudpickle version


`SB3-Contrib`_
17 changes: 17 additions & 0 deletions docs/modules/ppo.rst
Original file line number Diff line number Diff line change
@@ -88,6 +88,23 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
vec_env.render("human")


.. note::

PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

.. code-block::

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

if __name__ == "__main__":
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
model = PPO("MlpPolicy", env, device="cpu")
model.learn(total_timesteps=25_000)

For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.

Results
-------

2 changes: 2 additions & 0 deletions stable_baselines3/common/buffers.py
@@ -135,6 +135,8 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
:return:
"""
if copy:
if hasattr(th, "backends") and th.backends.mps.is_built():
return th.tensor(array, dtype=th.float32, device=self.device)
Comment on lines +139 to +140

🛠️ Refactor suggestion

Modify the MPS check in the to_torch method for better compatibility

Consider changing the conditional check to determine if the device is MPS by checking `self.device.type == "mps"` instead of `hasattr(th, "backends") and th.backends.mps.is_built()`. This ensures more robust detection of the MPS backend and better compatibility across different PyTorch versions.

Apply this diff to modify the condition:

 def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
     if copy:
-        if hasattr(th, "backends") and th.backends.mps.is_built():
+        if self.device.type == "mps":
             return th.tensor(array, dtype=th.float32, device=self.device)
         return th.tensor(array, device=self.device)
     return th.as_tensor(array, device=self.device)

return th.tensor(array, device=self.device)
return th.as_tensor(array, device=self.device)
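
For context, a standalone sketch of why the float32 cast exists (not part of the diff; assumes a PyTorch build with the MPS backend): NumPy arrays default to float64, which MPS cannot store, so data must be downcast when moved to an `mps` device. Note the distinction between `th.backends.mps.is_built()` (PyTorch was compiled with MPS support) and `th.backends.mps.is_available()` (this machine can actually use it).

```python
import numpy as np
import torch as th

# Prefer MPS only when this machine can actually use it, else fall back to CPU.
device = th.device("mps") if th.backends.mps.is_available() else th.device("cpu")

array = np.zeros(4)  # NumPy defaults to float64
if device.type == "mps":
    # MPS has no float64 storage: downcast while copying to the device.
    tensor = th.tensor(array, dtype=th.float32, device=device)
else:
    tensor = th.tensor(array, device=device)

print(tensor.dtype, tensor.device)
```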

4 changes: 2 additions & 2 deletions stable_baselines3/common/envs/bit_flipping_env.py
@@ -78,11 +78,11 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
if self.discrete_obs_space:
# The internal state is the binary representation of the
# observed one
return int(sum(state[i] * 2**i for i in range(len(state))))
return int(sum(int(state[i]) * 2**i for i in range(len(state))))

Simplify state conversion using NumPy vectorization

Currently, the state conversion uses a generator expression with explicit loops and integer casting:

return int(sum(int(state[i]) * 2**i for i in range(len(state))))

This can be simplified and optimized by leveraging NumPy's vectorized operations, which are more efficient and concise. Consider rewriting the code using np.dot:

Apply this diff to simplify the code:

- return int(sum(int(state[i]) * 2**i for i in range(len(state))))
+ return int(state.dot(2 ** np.arange(len(state))))

This approach eliminates the explicit loop and casting, improving performance and readability.
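
As a quick sanity check (standalone, with a made-up state), both expressions agree:

```python
import numpy as np

state = np.array([1, 0, 1, 1], dtype=np.int8)

loop_result = int(sum(int(state[i]) * 2**i for i in range(len(state))))
vectorized_result = int(state.dot(2 ** np.arange(len(state))))

assert loop_result == vectorized_result == 13  # 1*1 + 0*2 + 1*4 + 1*8
```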


if self.image_obs_space:
size = np.prod(self.image_shape)
image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8)))
image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))

Optimize image creation by preallocating the array

Instead of concatenating arrays to create the image, you can preallocate the array and assign values directly. This avoids unnecessary memory allocation and improves performance:

Apply this diff to optimize the code:

- image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
+ image = np.zeros(size, dtype=np.uint8)
+ image[:len(state)] = state.astype(np.uint8) * 255

This refactored code preallocates a zero-filled array of the required size and assigns the scaled state values directly to the beginning of the array.
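
A small standalone equivalence check for the two constructions (hypothetical sizes, not part of the diff):

```python
import numpy as np

state = np.array([1, 0, 1], dtype=np.int8)
size = 8  # hypothetical flattened image size

# Original: build the padded image by concatenation.
concatenated = np.concatenate(
    (state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8))
)

# Suggested: preallocate once and fill the leading entries in place.
preallocated = np.zeros(size, dtype=np.uint8)
preallocated[: len(state)] = state.astype(np.uint8) * 255

assert np.array_equal(concatenated, preallocated)
```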


return image.reshape(self.image_shape).astype(np.uint8)
return state

23 changes: 23 additions & 0 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -1,5 +1,6 @@
import sys
import time
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
@@ -135,6 +136,28 @@ def _setup_model(self) -> None:
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
)
self.policy = self.policy.to(self.device)
# Warn when not using CPU with MlpPolicy
self._maybe_recommend_cpu()

def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
"""
Recommend to use CPU only when using A2C/PPO with MlpPolicy.

:param: The name of the class for the default MlpPolicy.
"""
policy_class_name = self.policy_class.__name__
if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
warnings.warn(
f"You are trying to run {self.__class__.__name__} on the GPU, "
"but it is primarily intended to run on the CPU when not using a CNN policy "
f"(you are using {policy_class_name} which should be a MlpPolicy). "
"See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
"for more info. "
"You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU."
"Note: The model will train, but the GPU utilization will be poor and "
"the training might take longer than on CPU.",
UserWarning,
)
Comment on lines +142 to +160

🛠️ Refactor suggestion

Adjust default policy class name and enhance MPS-specific guidance

The implementation is good but needs two adjustments:

  1. The default mlp_class_name parameter value doesn't match the intended behavior:

    • The docstring mentions MlpPolicy, but the default is "ActorCriticPolicy"
    • This could trigger warnings for non-MLP policies
  2. Given this PR's focus on MPS support, the warning message should include MPS-specific guidance.

Here's the suggested implementation:

-    def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
+    def _maybe_recommend_cpu(self, mlp_class_name: str = "MlpPolicy") -> None:
         """
         Recommend to use CPU only when using A2C/PPO with MlpPolicy.
 
         :param: The name of the class for the default MlpPolicy.
         """
         policy_class_name = self.policy_class.__name__
         if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
             warnings.warn(
                 f"You are trying to run {self.__class__.__name__} on the GPU, "
                 "but it is primarily intended to run on the CPU when not using a CNN policy "
                 f"(you are using {policy_class_name} which should be a MlpPolicy). "
+                "This is especially important for MPS (Apple Silicon GPU) users. "
                 "See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
                 "for more info. "
                 "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU."
                 "Note: The model will train, but the GPU utilization will be poor and "
                 "the training might take longer than on CPU.",
                 UserWarning,
             )
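
To see the warning fire, a minimal standalone sketch (assuming a machine where CUDA is actually available; with the PR's code, `PPO("MlpPolicy", ...)` resolves to `ActorCriticPolicy`, which matches the default `mlp_class_name`):

```python
import warnings

from stable_baselines3 import PPO

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # MlpPolicy on a GPU device should trigger the UserWarning added above.
    PPO("MlpPolicy", "CartPole-v1", device="cuda")

assert any("primarily intended to run on the CPU" in str(w.message) for w in caught)
```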


def collect_rollouts(
self,