Skip to content

Commit 512eea9

Browse files
authored
Warn users when using multi-dim MultiDiscrete obs space (DLR-RM#2003)
* Update env checker to warn users when using multi-dim MultiDiscrete obs space * Update changelog
1 parent 9a3b28b commit 512eea9

File tree

4 files changed

+20
-2
lines changed

4 files changed

+20
-2
lines changed

docs/misc/changelog.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Changelog
44
==========
55

6-
Release 2.4.0a8 (WIP)
6+
Release 2.4.0a9 (WIP)
77
--------------------------
88

99
.. note::
@@ -13,13 +13,21 @@ Release 2.4.0a8 (WIP)
1313
To suppress the warning, simply save the model again.
1414
You can find more info in `PR #1963 <https://github.com/DLR-RM/stable-baselines3/pull/1963>`_
1515

16+
.. warning::
17+
18+
Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024)
19+
and PyTorch < 2.0.
20+
We highly recommended you to upgrade to Python >= 3.9 and PyTorch >= 2.0.
21+
22+
1623
Breaking Changes:
1724
^^^^^^^^^^^^^^^^^
1825

1926
New Features:
2027
^^^^^^^^^^^^^
2128
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
2229
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
30+
- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
2331

2432
Bug Fixes:
2533
^^^^^^^^^^

stable_baselines3/common/env_checker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
9898
"is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is."
9999
)
100100

101+
if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1:
102+
warnings.warn(
103+
f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} "
104+
"which is currently not supported by Stable-Baselines3. "
105+
"Please convert it to a 1D array using a wrapper: "
106+
"https://github.com/DLR-RM/stable-baselines3/issues/1836."
107+
)
108+
101109
if isinstance(observation_space, spaces.Tuple):
102110
warnings.warn(
103111
"The observation space is a Tuple, "

stable_baselines3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.4.0a8
1+
2.4.0a9

tests/test_envs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def patched_step(_action):
123123
spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
124124
# Non zero start index
125125
spaces.Discrete(3, start=-1),
126+
# 2D MultiDiscrete
127+
spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])),
126128
# Non zero start index (MultiDiscrete)
127129
spaces.MultiDiscrete([4, 4], start=[1, 0]),
128130
# Non zero start index inside a Dict

0 commit comments

Comments
 (0)