Skip to content

Commit b8b2d30

Browse files
authored
Add has_attr for VecEnv (#2077)
* Add `has_attr` for `VecEnv` * Add special case for gymnasium<1.0 * Update changelog.rst * Update black version
1 parent ee8a77d commit b8b2d30

File tree

6 files changed

+86
-12
lines changed

6 files changed

+86
-12
lines changed

docs/misc/changelog.rst

+32-9
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,20 @@
33
Changelog
44
==========
55

6-
Release 2.5.0 (2025-01-27)
6+
Release 2.6.0a0 (WIP)
77
--------------------------
88

9-
**New algorithm: SimBa in SBX, NumPy 2.0 support**
10-
119

1210
Breaking Changes:
1311
^^^^^^^^^^^^^^^^^
14-
- Increased minimum required version of PyTorch to 2.3.0
15-
- Removed support for Python 3.8
1612

1713
New Features:
1814
^^^^^^^^^^^^^
19-
- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too
20-
- Added official support for Python 3.12
15+
- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists
2116

2217
Bug Fixes:
2318
^^^^^^^^^^
19+
- `SubProcVecEnv` will now exit gracefully (without big traceback) when using `KeyboardInterrupt`
2420

2521
`SB3-Contrib`_
2622
^^^^^^^^^^^^^^
@@ -30,12 +26,39 @@ Bug Fixes:
3026

3127
`SBX`_ (SB3 + Jax)
3228
^^^^^^^^^^^^^^^^^^
33-
- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL
34-
- Added support for parameter resets
3529

3630
Deprecations:
3731
^^^^^^^^^^^^^
3832

33+
Others:
34+
^^^^^^^
35+
- Updated black from v24 to v25
36+
37+
Documentation:
38+
^^^^^^^^^^^^^^
39+
40+
41+
Release 2.5.0 (2025-01-27)
42+
--------------------------
43+
44+
**New algorithm: SimBa in SBX, NumPy 2.0 support**
45+
46+
47+
Breaking Changes:
48+
^^^^^^^^^^^^^^^^^
49+
- Increased minimum required version of PyTorch to 2.3.0
50+
- Removed support for Python 3.8
51+
52+
New Features:
53+
^^^^^^^^^^^^^
54+
- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too
55+
- Added official support for Python 3.12
56+
57+
`SBX`_ (SB3 + Jax)
58+
^^^^^^^^^^^^^^^^^^
59+
- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL
60+
- Added support for parameter resets
61+
3962
Others:
4063
^^^^^^^
4164
- Updated Dockerfile

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898
# Lint code and sort imports (flake8 and isort replacement)
9999
"ruff>=0.3.1",
100100
# Reformat
101-
"black>=24.2.0,<25",
101+
"black>=25.1.0,<26",
102102
],
103103
"docs": [
104104
"sphinx>=5,<9",

stable_baselines3/common/vec_env/base_vec_env.py

+18
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,21 @@ def close(self) -> None:
147147
"""
148148
raise NotImplementedError()
149149

150+
def has_attr(self, attr_name: str) -> bool:
151+
"""
152+
Check if an attribute exists for a vectorized environment.
153+
154+
:param attr_name: The name of the attribute to check
155+
:return: True if 'attr_name' exists in all environments
156+
"""
157+
# Default implementation, will not work with things that cannot be pickled:
158+
# https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49
159+
try:
160+
self.get_attr(attr_name)
161+
return True
162+
except AttributeError:
163+
return False
164+
150165
@abstractmethod
151166
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
152167
"""
@@ -392,6 +407,9 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
392407
def get_images(self) -> Sequence[Optional[np.ndarray]]:
393408
return self.venv.get_images()
394409

410+
def has_attr(self, attr_name: str) -> bool:
411+
return self.venv.has_attr(attr_name)
412+
395413
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
396414
return self.venv.get_attr(attr_name, indices)
397415

stable_baselines3/common/vec_env/subproc_vec_env.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from stable_baselines3.common.vec_env.patch_gym import _patch_env
1818

1919

20-
def _worker(
20+
def _worker( # noqa: C901
2121
remote: mp.connection.Connection,
2222
parent_remote: mp.connection.Connection,
2323
env_fn_wrapper: CloudpickleWrapper,
@@ -58,6 +58,12 @@ def _worker(
5858
remote.send(method(*data[1], **data[2]))
5959
elif cmd == "get_attr":
6060
remote.send(env.get_wrapper_attr(data))
61+
elif cmd == "has_attr":
62+
try:
63+
env.get_wrapper_attr(data)
64+
remote.send(True)
65+
except AttributeError:
66+
remote.send(False)
6167
elif cmd == "set_attr":
6268
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
6369
elif cmd == "is_wrapped":
@@ -66,6 +72,8 @@ def _worker(
6672
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
6773
except EOFError:
6874
break
75+
except KeyboardInterrupt:
76+
break
6977

7078

7179
class SubprocVecEnv(VecEnv):
@@ -165,6 +173,13 @@ def get_images(self) -> Sequence[Optional[np.ndarray]]:
165173
outputs = [pipe.recv() for pipe in self.remotes]
166174
return outputs
167175

176+
def has_attr(self, attr_name: str) -> bool:
177+
"""Check if an attribute exists for a vectorized environment. (see base class)."""
178+
target_remotes = self._get_target_remotes(indices=None)
179+
for remote in target_remotes:
180+
remote.send(("has_attr", attr_name))
181+
return all([remote.recv() for remote in target_remotes])
182+
168183
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
169184
"""Return attribute from vectorized environment (see base class)."""
170185
target_remotes = self._get_target_remotes(indices)

stable_baselines3/version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.5.0
1+
2.6.0a0

tests/test_vec_envs.py

+18
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,30 @@ def make_env():
123123
# we need a X server to test the "human" mode (uses OpenCV)
124124
# vec_env.render(mode="human")
125125

126+
# Set a new attribute, on the last wrapper and on the env
127+
assert not vec_env.has_attr("dummy")
128+
# Set value for the last wrapper only
129+
vec_env.set_attr("dummy", 12)
130+
assert vec_env.get_attr("dummy") == [12] * N_ENVS
131+
if vec_env_class == DummyVecEnv:
132+
assert vec_env.envs[0].dummy == 12
133+
134+
assert not vec_env.has_attr("dummy2")
135+
# Set the value on the original env
136+
# `set_wrapper_attr` doesn't exist before v1.0
137+
if gym.__version__ > "1":
138+
vec_env.env_method("set_wrapper_attr", "dummy2", 2)
139+
assert vec_env.get_attr("dummy2") == [2] * N_ENVS
140+
if vec_env_class == DummyVecEnv:
141+
assert vec_env.envs[0].unwrapped.dummy2 == 2
142+
126143
env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2)
127144
setattr_results = []
128145
# Set current_step to an arbitrary value
129146
for env_idx in range(N_ENVS):
130147
setattr_results.append(vec_env.set_attr("current_step", env_idx, indices=env_idx))
131148
# Retrieve the value for each environment
149+
assert vec_env.has_attr("current_step")
132150
getattr_results = vec_env.get_attr("current_step")
133151

134152
assert len(env_method_results) == N_ENVS

0 commit comments

Comments
 (0)