Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add has_attr for VecEnv #2077

Merged
merged 4 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,20 @@
Changelog
==========

Release 2.5.0 (2025-01-27)
Release 2.6.0a0 (WIP)
--------------------------

**New algorithm: SimBa in SBX, NumPy 2.0 support**


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Increased minimum required version of PyTorch to 2.3.0
- Removed support for Python 3.8

New Features:
^^^^^^^^^^^^^
- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too
- Added official support for Python 3.12
- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists

Bug Fixes:
^^^^^^^^^^
- `SubProcVecEnv` will now exit gracefully (without big traceback) when using `KeyboardInterrupt`

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand All @@ -30,12 +26,39 @@ Bug Fixes:

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL
- Added support for parameter resets

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^
- Updated black from v24 to v25

Documentation:
^^^^^^^^^^^^^^


Release 2.5.0 (2025-01-27)
--------------------------

**New algorithm: SimBa in SBX, NumPy 2.0 support**


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Increased minimum required version of PyTorch to 2.3.0
- Removed support for Python 3.8

New Features:
^^^^^^^^^^^^^
- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too
- Added official support for Python 3.12

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL
- Added support for parameter resets

Others:
^^^^^^^
- Updated Dockerfile
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
# Lint code and sort imports (flake8 and isort replacement)
"ruff>=0.3.1",
# Reformat
"black>=24.2.0,<25",
"black>=25.1.0,<26",
],
"docs": [
"sphinx>=5,<9",
Expand Down
18 changes: 18 additions & 0 deletions stable_baselines3/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,21 @@ def close(self) -> None:
"""
raise NotImplementedError()

def has_attr(self, attr_name: str) -> bool:
"""
Check if an attribute exists for a vectorized environment.

:param attr_name: The name of the attribute to check
:return: True if 'attr_name' exists in all environments
"""
# Default implementation, will not work with things that cannot be pickled:
# https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49
try:
self.get_attr(attr_name)
return True
except AttributeError:
return False

@abstractmethod
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
"""
Expand Down Expand Up @@ -392,6 +407,9 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
def get_images(self) -> Sequence[Optional[np.ndarray]]:
return self.venv.get_images()

def has_attr(self, attr_name: str) -> bool:
return self.venv.has_attr(attr_name)

def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
return self.venv.get_attr(attr_name, indices)

Expand Down
17 changes: 16 additions & 1 deletion stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from stable_baselines3.common.vec_env.patch_gym import _patch_env


def _worker(
def _worker( # noqa: C901
remote: mp.connection.Connection,
parent_remote: mp.connection.Connection,
env_fn_wrapper: CloudpickleWrapper,
Expand Down Expand Up @@ -58,6 +58,12 @@ def _worker(
remote.send(method(*data[1], **data[2]))
elif cmd == "get_attr":
remote.send(env.get_wrapper_attr(data))
elif cmd == "has_attr":
try:
env.get_wrapper_attr(data)
remote.send(True)
except AttributeError:
remote.send(False)
elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
elif cmd == "is_wrapped":
Expand All @@ -66,6 +72,8 @@ def _worker(
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
except EOFError:
break
except KeyboardInterrupt:
break


class SubprocVecEnv(VecEnv):
Expand Down Expand Up @@ -165,6 +173,13 @@ def get_images(self) -> Sequence[Optional[np.ndarray]]:
outputs = [pipe.recv() for pipe in self.remotes]
return outputs

def has_attr(self, attr_name: str) -> bool:
"""Check if an attribute exists for a vectorized environment. (see base class)."""
target_remotes = self._get_target_remotes(indices=None)
for remote in target_remotes:
remote.send(("has_attr", attr_name))
return all([remote.recv() for remote in target_remotes])

def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_remotes = self._get_target_remotes(indices)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.5.0
2.6.0a0
18 changes: 18 additions & 0 deletions tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,30 @@ def make_env():
# we need a X server to test the "human" mode (uses OpenCV)
# vec_env.render(mode="human")

# Set a new attribute, on the last wrapper and on the env
assert not vec_env.has_attr("dummy")
# Set value for the last wrapper only
vec_env.set_attr("dummy", 12)
assert vec_env.get_attr("dummy") == [12] * N_ENVS
if vec_env_class == DummyVecEnv:
assert vec_env.envs[0].dummy == 12

assert not vec_env.has_attr("dummy2")
# Set the value on the original env
# `set_wrapper_attr` doesn't exist before v1.0
if gym.__version__ > "1":
vec_env.env_method("set_wrapper_attr", "dummy2", 2)
assert vec_env.get_attr("dummy2") == [2] * N_ENVS
if vec_env_class == DummyVecEnv:
assert vec_env.envs[0].unwrapped.dummy2 == 2

env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2)
setattr_results = []
# Set current_step to an arbitrary value
for env_idx in range(N_ENVS):
setattr_results.append(vec_env.set_attr("current_step", env_idx, indices=env_idx))
# Retrieve the value for each environment
assert vec_env.has_attr("current_step")
getattr_results = vec_env.get_attr("current_step")

assert len(env_method_results) == N_ENVS
Expand Down