
Commit 955382e

Merge branch 'master' into feat/mps-support
2 parents b85a2a5 + 512eea9 commit 955382e


45 files changed: +539 -167 lines changed

.github/workflows/ci.yml

Lines changed: 37 additions & 35 deletions
@@ -5,9 +5,9 @@ name: CI
 
 on:
   push:
-    branches: [ master ]
+    branches: [master]
   pull_request:
-    branches: [ master ]
+    branches: [master]
 
 jobs:
   build:
@@ -23,38 +23,40 @@ jobs:
         python-version: ["3.8", "3.9", "3.10", "3.11"]
 
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          # cpu version of pytorch
-          pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          # cpu version of pytorch
+          pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
 
-          # Install Atari Roms
-          pip install autorom
-          wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
-          base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
-          AutoROM --accept-license --source-file Roms.tar.gz
+          # Install Atari Roms
+          pip install autorom
+          wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
+          base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
+          AutoROM --accept-license --source-file Roms.tar.gz
 
-          pip install .[extra_no_roms,tests,docs]
-          # Use headless version
-          pip install opencv-python-headless
-      - name: Lint with ruff
-        run: |
-          make lint
-      - name: Build the doc
-        run: |
-          make doc
-      - name: Check codestyle
-        run: |
-          make check-codestyle
-      - name: Type check
-        run: |
-          make type
-      - name: Test with pytest
-        run: |
-          make pytest
+          pip install .[extra_no_roms,tests,docs]
+          # Use headless version
+          pip install opencv-python-headless
+      - name: Lint with ruff
+        run: |
+          make lint
+      - name: Build the doc
+        run: |
+          make doc
+      - name: Check codestyle
+        run: |
+          make check-codestyle
+      - name: Type check
+        run: |
+          make type
+        # Do not run for python 3.8 (mypy internal error)
+        if: matrix.python-version != '3.8'
+      - name: Test with pytest
+        run: |
+          make pytest

CODE_OF_CONDUCT.md

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 We as members, contributors, and leaders pledge to make participation in our
 community a harassment-free experience for everyone, regardless of age, body
 size, visible or invisible disability, ethnicity, sex characteristics, gender
-identity and expression, level of experience, education, socio-economic status,
+identity and expression, level of experience, education, socioeconomic status,
 nationality, personal appearance, race, religion, or sexual identity
 and orientation.
 

Makefile

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ lint:
 	# see https://www.flake8rules.com/
 	ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
 	# exit-zero treats all errors as warnings.
-	ruff check ${LINT_PATHS} --exit-zero
+	ruff check ${LINT_PATHS} --exit-zero --output-format=concise
 
 format:
 	# Sort imports

README.md

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,3 @@
-<img src="docs/\_static/img/logo.png" align="right" width="40%"/>
-
 <!-- [![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) -->
 ![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)
 [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
@@ -8,6 +6,8 @@
 
 # Stable Baselines3
 
+<img src="docs/\_static/img/logo.png" align="right" width="40%"/>
+
 Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
 
 You can read a detailed presentation of Stable Baselines3 in the [v1.0 blog post](https://araffin.github.io/post/sb3/) or our [JMLR paper](https://jmlr.org/papers/volume22/20-1364/20-1364.pdf).
@@ -85,7 +85,7 @@ Documentation is available online: [https://sb3-contrib.readthedocs.io/](https:/
 
 ## Stable-Baselines Jax (SBX)
 
-[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax.
+[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax, with recent algorithms like DroQ or CrossQ.
 
 It provides a minimal number of features compared to SB3 but can be much faster (up to 20x times!): https://twitter.com/araffin2/status/1590714558628253698
 
@@ -192,7 +192,7 @@ All the following examples can be executed online using Google Colab notebooks:
 <b id="f1">1</b>: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.
 
 Actions `gym.spaces`:
-* `Box`: A N-dimensional box that containes every point in the action space.
+* `Box`: A N-dimensional box that contains every point in the action space.
 * `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
 * `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
 * `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination.
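
For context (not part of this commit's diff), a minimal sketch of how these four action-space types are declared with the Gymnasium ``spaces`` API that SB3 builds on; the bounds and sizes below are illustrative only:

    import numpy as np
    from gymnasium import spaces

    # Box: an N-dimensional box that contains every point in the action space
    box = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
    # Discrete: exactly one of n actions per timestep
    discrete = spaces.Discrete(4)
    # MultiDiscrete: one action per discrete set each timestep
    multi_discrete = spaces.MultiDiscrete([3, 2])
    # MultiBinary: any combination of n binary actions each timestep
    multi_binary = spaces.MultiBinary(5)

    for space in (box, discrete, multi_discrete, multi_binary):
        print(type(space).__name__, space.sample())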

docs/guide/algos.rst

Lines changed: 2 additions & 1 deletion
@@ -43,7 +43,8 @@ Actions ``gym.spaces``:
 
 .. note::
 
-    More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`.
+    More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`
+    and in our :ref:`SBX (SB3 + Jax) repo <sbx>` (DroQ, CrossQ, ...).
 
 .. note::
 
docs/guide/examples.rst

Lines changed: 5 additions & 5 deletions
@@ -128,7 +128,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
 
     :param env_id: the environment ID
    :param num_env: the number of environments you wish to have in subprocesses
-    :param seed: the inital seed for RNG
+    :param seed: the initial seed for RNG
     :param rank: index of the subprocess
     """
     def _init():
@@ -179,9 +179,9 @@ Multiprocessing with off-policy algorithms
 
    vec_env = make_vec_env("Pendulum-v0", n_envs=4, seed=0)
 
-    # We collect 4 transitions per call to `ènv.step()`
-    # and performs 2 gradient steps per call to `ènv.step()`
-    # if gradient_steps=-1, then we would do 4 gradients steps per call to `ènv.step()`
+    # We collect 4 transitions per call to `env.step()`
+    # and performs 2 gradient steps per call to `env.step()`
+    # if gradient_steps=-1, then we would do 4 gradients steps per call to `env.step()`
     model = SAC("MlpPolicy", vec_env, train_freq=1, gradient_steps=2, verbose=1)
     model.learn(total_timesteps=10_000)
 
@@ -436,7 +436,7 @@ will compute a running average and standard deviation of input features (it can
     log_dir = "/tmp/"
     model.save(log_dir + "ppo_halfcheetah")
     stats_path = os.path.join(log_dir, "vec_normalize.pkl")
-    env.save(stats_path)
+    vec_env.save(stats_path)
 
     # To demonstrate loading
     del model, vec_env
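
For context (outside the diff), a minimal sketch of the save/load round trip implied by the corrected ``vec_env.save(stats_path)`` line, assuming the standard ``VecNormalize`` API; the environment id and paths are illustrative:

    import os

    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import VecNormalize

    # Train with normalized observations and rewards
    vec_env = VecNormalize(make_vec_env("Pendulum-v1", n_envs=4), norm_obs=True, norm_reward=True)
    model = PPO("MlpPolicy", vec_env, verbose=0)
    model.learn(total_timesteps=1_000)

    log_dir = "/tmp/"
    model.save(os.path.join(log_dir, "ppo_pendulum"))
    stats_path = os.path.join(log_dir, "vec_normalize.pkl")
    # Save the running statistics of the VecNormalize wrapper itself (hence vec_env.save)
    vec_env.save(stats_path)

    # Later: rebuild the vectorized env, reload the statistics, then reload the model
    vec_env = VecNormalize.load(stats_path, make_vec_env("Pendulum-v1", n_envs=4))
    vec_env.training = False      # do not update the running statistics at test time
    vec_env.norm_reward = False   # reward normalization is not needed for evaluation
    model = PPO.load(os.path.join(log_dir, "ppo_pendulum"), env=vec_env)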

docs/guide/rl_tips.rst

Lines changed: 22 additions & 21 deletions
@@ -4,7 +4,7 @@
 Reinforcement Learning Tips and Tricks
 ======================================
 
-The aim of this section is to help you do reinforcement learning experiments.
+The aim of this section is to help you run reinforcement learning experiments.
 It covers general advice about RL (where to start, which algorithm to choose, how to evaluate an algorithm, ...),
 as well as tips and tricks when using a custom environment or implementing an RL algorithm.
 
@@ -14,6 +14,11 @@ as well as tips and tricks when using a custom environment or implementing an RL
 this section in more details. You can also find the `slides here <https://araffin.github.io/slides/rlvs-tips-tricks/>`_.
 
 
+.. note::
+
+    We also have a `video on Designing and Running Real-World RL Experiments <https://youtu.be/eZ6ZEpCi6D8>`_, slides `can be found online <https://araffin.github.io/slides/design-real-rl-experiments/>`_.
+
+
 General advice when using Reinforcement Learning
 ================================================
 
@@ -103,19 +108,19 @@ and this `issue <https://github.com/hill-a/stable-baselines/issues/199>`_ by Cé
 Which algorithm should I use?
 =============================
 
-There is no silver bullet in RL, depending on your needs and problem, you may choose one or the other.
+There is no silver bullet in RL, you can choose one or the other depending on your needs and problems.
 The first distinction comes from your action space, i.e., do you have discrete (e.g. LEFT, RIGHT, ...)
 or continuous actions (ex: go to a certain speed)?
 
-Some algorithms are only tailored for one or the other domain: ``DQN`` only supports discrete actions, where ``SAC`` is restricted to continuous actions.
+Some algorithms are only tailored for one or the other domain: ``DQN`` supports only discrete actions, while ``SAC`` is restricted to continuous actions.
 
-The second difference that will help you choose is whether you can parallelize your training or not.
+The second difference that will help you decide is whether you can parallelize your training or not.
 If what matters is the wall clock training time, then you should lean towards ``A2C`` and its derivatives (PPO, ...).
 Take a look at the `Vectorized Environments <vec_envs.html>`_ to learn more about training with multiple workers.
 
-To accelerate training, you can also take a look at `SBX`_, which is SB3 + Jax, it has fewer features than SB3 but can be up to 20x faster than SB3 PyTorch thanks to JIT compilation of the gradient update.
+To accelerate training, you can also take a look at `SBX`_, which is SB3 + Jax, it has less features than SB3 but can be up to 20x faster than SB3 PyTorch thanks to JIT compilation of the gradient update.
 
-In sparse reward settings, we either recommend to use dedicated methods like HER (see below) or population-based algorithms like ARS (available in our :ref:`contrib repo <sb3_contrib>`).
+In sparse reward settings, we either recommend using either dedicated methods like HER (see below) or population-based algorithms like ARS (available in our :ref:`contrib repo <sb3_contrib>`).
 
 To sum it up:
 
@@ -146,7 +151,7 @@ Continuous Actions
 Continuous Actions - Single Process
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3`` and ``TQC`` (available in our :ref:`contrib repo <sb3_contrib>`).
+Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3``, ``CrossQ`` and ``TQC`` (available in our :ref:`contrib repo <sb3_contrib>` and :ref:`SBX (SB3 + Jax) repo <sbx>`).
 Please use the hyperparameters in the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ for best results.
 
 If you want an extremely sample-efficient algorithm, we recommend using the `DroQ configuration <https://twitter.com/araffin2/status/1575439865222660098>`_ in `SBX`_ (it does many gradient steps per step in the environment).
@@ -155,8 +160,7 @@ If you want an extremely sample-efficient algorithm, we recommend using the `Dro
 Continuous Actions - Multiprocessed
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Take a look at ``PPO``, ``TRPO`` (available in our :ref:`contrib repo <sb3_contrib>`) or ``A2C``. Again, don't forget to take the hyperparameters from the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_
-for continuous actions problems (cf *Bullet* envs).
+Take a look at ``PPO``, ``TRPO`` (available in our :ref:`contrib repo <sb3_contrib>`) or ``A2C``. Again, don't forget to take the hyperparameters from the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ for continuous actions problems (cf *Bullet* envs).
 
 .. note::
 
@@ -181,26 +185,23 @@ Tips and Tricks when creating a custom environment
 ==================================================
 
 If you want to learn about how to create a custom environment, we recommend you read this `page <custom_env.html>`_.
-We also provide a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for
-a concrete example of creating a custom gym environment.
+We also provide a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for a concrete example of creating a custom gym environment.
 
 Some basic advice:
 
-- always normalize your observation space when you can, i.e., when you know the boundaries
-- normalize your action space and make it symmetric when continuous (cf potential issue below) A good practice is to rescale your actions to lie in [-1, 1]. This does not limit you as you can easily rescale the action inside the environment
-- start with shaped reward (i.e. informative reward) and simplified version of your problem
-- debug with random actions to check that your environment works and follows the gym interface:
+- always normalize your observation space if you can, i.e. if you know the boundaries
+- normalize your action space and make it symmetric if it is continuous (see potential problem below) A good practice is to rescale your actions so that they lie in [-1, 1]. This does not limit you, as you can easily rescale the action within the environment.
+- start with a shaped reward (i.e. informative reward) and a simplified version of your problem
+- debug with random actions to check if your environment works and follows the gym interface (with ``check_env``, see below)
 
-Two important things to keep in mind when creating a custom environment is to avoid breaking Markov assumption
+Two important things to keep in mind when creating a custom environment are avoiding breaking the Markov assumption
 and properly handle termination due to a timeout (maximum number of steps in an episode).
-For instance, if there is some time delay between action and observation (e.g. due to wifi communication), you should give a history of observations
-as input.
+For example, if there is a time delay between action and observation (e.g. due to wifi communication), you should provide a history of observations as input.
 
 Termination due to timeout (max number of steps per episode) needs to be handled separately.
 You should return ``truncated = True``.
 If you are using the gym ``TimeLimit`` wrapper, this will be done automatically.
-You can read `Time Limit in RL <https://arxiv.org/abs/1712.00378>`_ or take a look at the `RL Tips and Tricks video <https://www.youtube.com/watch?v=Ikngt0_DXJg>`_
-for more details.
+You can read `Time Limit in RL <https://arxiv.org/abs/1712.00378>`_, take a look at the `Designing and Running Real-World RL Experiments video <https://youtu.be/eZ6ZEpCi6D8>`_ or `RL Tips and Tricks video <https://www.youtube.com/watch?v=Ikngt0_DXJg>`_ for more details.
 
 
 We provide a helper to check that your environment runs without error:
@@ -234,7 +235,7 @@ If you want to quickly try a random agent on your environment, you can also do:
 
 Most reinforcement learning algorithms rely on a Gaussian distribution (initially centered at 0 with std 1) for continuous actions.
 So, if you forget to normalize the action space when using a custom environment,
-this can harm learning and be difficult to debug (cf attached image and `issue #473 <https://github.com/hill-a/stable-baselines/issues/473>`_).
+this can harm learning and can be difficult to debug (cf attached image and `issue #473 <https://github.com/hill-a/stable-baselines/issues/473>`_).
 
 .. figure:: ../_static/img/mistake.png
 
docs/guide/sbx.rst

Lines changed: 9 additions & 6 deletions
@@ -17,6 +17,7 @@ Implemented algorithms:
 - Deep Q Network (DQN)
 - Twin Delayed DDPG (TD3)
 - Deep Deterministic Policy Gradient (DDPG)
+- Batch Normalization in Deep Reinforcement Learning (CrossQ)
 
 
 As SBX follows SB3 API, it is also compatible with the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_.
@@ -29,16 +30,17 @@ For that you will need to create two files:
     import rl_zoo3
     import rl_zoo3.train
     from rl_zoo3.train import train
-
-    from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
+    from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
 
     rl_zoo3.ALGOS["ddpg"] = DDPG
     rl_zoo3.ALGOS["dqn"] = DQN
-    rl_zoo3.ALGOS["droq"] = DroQ
+    # See SBX readme to use DroQ configuration
+    # rl_zoo3.ALGOS["droq"] = DroQ
     rl_zoo3.ALGOS["sac"] = SAC
    rl_zoo3.ALGOS["ppo"] = PPO
     rl_zoo3.ALGOS["td3"] = TD3
     rl_zoo3.ALGOS["tqc"] = TQC
+    rl_zoo3.ALGOS["crossq"] = CrossQ
     rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
     rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
 
@@ -56,16 +58,17 @@ Then you can call ``python train_sbx.py --algo sac --env Pendulum-v1`` and use t
     import rl_zoo3
     import rl_zoo3.enjoy
     from rl_zoo3.enjoy import enjoy
-
-    from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
+    from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
 
     rl_zoo3.ALGOS["ddpg"] = DDPG
     rl_zoo3.ALGOS["dqn"] = DQN
-    rl_zoo3.ALGOS["droq"] = DroQ
+    # See SBX readme to use DroQ configuration
+    # rl_zoo3.ALGOS["droq"] = DroQ
     rl_zoo3.ALGOS["sac"] = SAC
     rl_zoo3.ALGOS["ppo"] = PPO
     rl_zoo3.ALGOS["td3"] = TD3
     rl_zoo3.ALGOS["tqc"] = TQC
+    rl_zoo3.ALGOS["crossq"] = CrossQ
     rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
     rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
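
For reference (not part of the diff), this is roughly what the complete ``train_sbx.py`` described above could look like once an entry point is added; the ``__main__`` guard is an assumption about how the RL Zoo script is usually invoked:

    import rl_zoo3
    import rl_zoo3.train
    from rl_zoo3.train import train
    from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

    # Point the RL Zoo registry at the SBX (Jax) implementations
    rl_zoo3.ALGOS["ddpg"] = DDPG
    rl_zoo3.ALGOS["dqn"] = DQN
    rl_zoo3.ALGOS["sac"] = SAC
    rl_zoo3.ALGOS["ppo"] = PPO
    rl_zoo3.ALGOS["td3"] = TD3
    rl_zoo3.ALGOS["tqc"] = TQC
    rl_zoo3.ALGOS["crossq"] = CrossQ
    rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
    rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

    if __name__ == "__main__":
        # e.g. python train_sbx.py --algo sac --env Pendulum-v1
        train()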

docs/guide/tensorboard.rst

Lines changed: 5 additions & 1 deletion
@@ -192,6 +192,7 @@ Here is an example of how to render an episode and log the resulting video to Te
 
     import gymnasium as gym
     import torch as th
+    import numpy as np
 
     from stable_baselines3 import A2C
     from stable_baselines3.common.callbacks import BaseCallback
@@ -226,6 +227,9 @@ Here is an example of how to render an episode and log the resulting video to Te
                 :param _locals: A dictionary containing all local variables of the callback's scope
                 :param _globals: A dictionary containing all global variables of the callback's scope
                 """
+                # We expect `render()` to return a uint8 array with values in [0, 255] or a float array
+                # with values in [0, 1], as described in
+                # https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_video
                 screen = self._eval_env.render(mode="rgb_array")
                 # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
                 screens.append(screen.transpose(2, 0, 1))
@@ -239,7 +243,7 @@ Here is an example of how to render an episode and log the resulting video to Te
             )
             self.logger.record(
                 "trajectory/video",
-                Video(th.ByteTensor([screens]), fps=40),
+                Video(th.from_numpy(np.asarray([screens])), fps=40),
                 exclude=("stdout", "log", "json", "csv"),
             )
             return True
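
Outside the diff, a minimal sketch showing the corrected tensor conversion in a standalone script, assuming the standard SB3 ``Video`` logger and a Gymnasium env created with ``render_mode="rgb_array"``; the environment id and fps are illustrative:

    import gymnasium as gym
    import numpy as np
    import torch as th

    from stable_baselines3 import A2C
    from stable_baselines3.common.logger import Video

    model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=0)
    model.learn(total_timesteps=1_000)

    eval_env = gym.make("CartPole-v1", render_mode="rgb_array")
    obs, _ = eval_env.reset()
    screens, done = [], False
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, _, terminated, truncated, _ = eval_env.step(int(action))
        done = terminated or truncated
        # render() returns an HxWxC uint8 frame; PyTorch expects CxHxW
        screens.append(eval_env.render().transpose(2, 0, 1))

    model.logger.record(
        "trajectory/video",
        # Shape: (batch=1, time, channels, height, width)
        Video(th.from_numpy(np.asarray([screens])), fps=40),
        exclude=("stdout", "log", "json", "csv"),
    )
    model.logger.dump(step=0)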
