From 6f557349b06f5eb5ef880b2a5d26dabb4b5f87c2 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 30 Jan 2025 10:39:08 +0100 Subject: [PATCH 1/4] Add `has_attr` for `VecEnv` --- docs/misc/changelog.rst | 39 ++++++++++++++----- .../common/vec_env/base_vec_env.py | 18 +++++++++ .../common/vec_env/subproc_vec_env.py | 17 +++++++- stable_baselines3/version.txt | 2 +- tests/test_vec_envs.py | 16 ++++++++ 5 files changed, 81 insertions(+), 11 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cf0db00fa..3f63f61f8 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,21 +3,16 @@ Changelog ========== -Release 2.5.0 (2025-01-27) +Release 2.6.0a0 (WIP) -------------------------- -**New algorithm: SimBa in SBX, NumPy 2.0 support** - Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Increased minimum required version of PyTorch to 2.3.0 -- Removed support for Python 3.8 New Features: ^^^^^^^^^^^^^ -- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too -- Added official support for Python 3.12 +- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists Bug Fixes: ^^^^^^^^^^ @@ -30,12 +25,38 @@ Bug Fixes: `SBX`_ (SB3 + Jax) ^^^^^^^^^^^^^^^^^^ -- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL -- Added support for parameter resets Deprecations: ^^^^^^^^^^^^^ +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ + + +Release 2.5.0 (2025-01-27) +-------------------------- + +**New algorithm: SimBa in SBX, NumPy 2.0 support** + + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Increased minimum required version of PyTorch to 2.3.0 +- Removed support for Python 3.8 + +New Features: +^^^^^^^^^^^^^ +- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too +- Added official support for Python 3.12 + +`SBX`_ (SB3 + Jax) +^^^^^^^^^^^^^^^^^^ +- 
Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL +- Added support for parameter resets + Others: ^^^^^^^ - Updated Dockerfile diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 71ee15e61..370113108 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -147,6 +147,21 @@ def close(self) -> None: """ raise NotImplementedError() + def has_attr(self, attr_name: str) -> bool: + """ + Check if an attribute exists for a vectorized environment. + + :param attr_name: The name of the attribute to check + :return: True if 'attr_name' exists in all environments + """ + # Default implementation, will not work with things that cannot be pickled: + # https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49 + try: + self.get_attr(attr_name) + return True + except AttributeError: + return False + @abstractmethod def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: """ @@ -392,6 +407,9 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]: def get_images(self) -> Sequence[Optional[np.ndarray]]: return self.venv.get_images() + def has_attr(self, attr_name: str) -> bool: + return self.venv.has_attr(attr_name) + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: return self.venv.get_attr(attr_name, indices) diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 225eadd79..1563d70b1 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -17,7 +17,7 @@ from stable_baselines3.common.vec_env.patch_gym import _patch_env -def _worker( +def _worker( # noqa: C901 remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper, @@ -58,6 +58,12 @@ def _worker( 
remote.send(method(*data[1], **data[2])) elif cmd == "get_attr": remote.send(env.get_wrapper_attr(data)) + elif cmd == "has_attr": + try: + env.get_wrapper_attr(data) + remote.send(True) + except AttributeError: + remote.send(False) elif cmd == "set_attr": remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value] elif cmd == "is_wrapped": @@ -66,6 +72,8 @@ def _worker( raise NotImplementedError(f"`{cmd}` is not implemented in the worker") except EOFError: break + except KeyboardInterrupt: + break class SubprocVecEnv(VecEnv): @@ -165,6 +173,13 @@ def get_images(self) -> Sequence[Optional[np.ndarray]]: outputs = [pipe.recv() for pipe in self.remotes] return outputs + def has_attr(self, attr_name: str) -> bool: + """Check if an attribute exists for a vectorized environment. (see base class).""" + target_remotes = self._get_target_remotes(indices=None) + for remote in target_remotes: + remote.send(("has_attr", attr_name)) + return all([remote.recv() for remote in target_remotes]) + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: """Return attribute from vectorized environment (see base class).""" target_remotes = self._get_target_remotes(indices) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 437459cd9..3d87ca93f 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.5.0 +2.6.0a0 diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 7e4e5ec0b..f8737018b 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -123,12 +123,28 @@ def make_env(): # we need a X server to test the "human" mode (uses OpenCV) # vec_env.render(mode="human") + # Set a new attribute, on the last wrapper and on the env + assert not vec_env.has_attr("dummy") + # Set value for the last wrapper only + vec_env.set_attr("dummy", 12) + assert vec_env.get_attr("dummy") == [12] * N_ENVS + if vec_env_class == DummyVecEnv: + assert vec_env.envs[0].dummy == 12 
+ + assert not vec_env.has_attr("dummy2") + # Set the value on the original env + vec_env.env_method("set_wrapper_attr", "dummy2", 2) + assert vec_env.get_attr("dummy2") == [2] * N_ENVS + if vec_env_class == DummyVecEnv: + assert vec_env.envs[0].unwrapped.dummy2 == 2 + env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2) setattr_results = [] # Set current_step to an arbitrary value for env_idx in range(N_ENVS): setattr_results.append(vec_env.set_attr("current_step", env_idx, indices=env_idx)) # Retrieve the value for each environment + assert vec_env.has_attr("current_step") getattr_results = vec_env.get_attr("current_step") assert len(env_method_results) == N_ENVS From d498a035da5b08697d8c8e4dc9e421a7015610cb Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 30 Jan 2025 11:41:07 +0100 Subject: [PATCH 2/4] Add special case for gymnasium<1.0 --- tests/test_vec_envs.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index f8737018b..43a693ddd 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -133,10 +133,12 @@ def make_env(): assert not vec_env.has_attr("dummy2") # Set the value on the original env - vec_env.env_method("set_wrapper_attr", "dummy2", 2) - assert vec_env.get_attr("dummy2") == [2] * N_ENVS - if vec_env_class == DummyVecEnv: - assert vec_env.envs[0].unwrapped.dummy2 == 2 + # `set_wrapper_attr` doesn't exist before v1.0 + if gym.__version__ > "1": + vec_env.env_method("set_wrapper_attr", "dummy2", 2) + assert vec_env.get_attr("dummy2") == [2] * N_ENVS + if vec_env_class == DummyVecEnv: + assert vec_env.envs[0].unwrapped.dummy2 == 2 env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2) setattr_results = [] From 2140944487e4ed4afe8b63c77fca600db6d3dd20 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 31 Jan 2025 22:17:36 +0100 Subject: [PATCH 3/4] Update changelog.rst --- 
docs/misc/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3f63f61f8..e8f7faa70 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -16,6 +16,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- ``SubprocVecEnv`` will now exit gracefully (without a big traceback) on ``KeyboardInterrupt`` `SB3-Contrib`_ ^^^^^^^^^^^^^^ From ebe007da17663ef4afeed6ac57960873f6252cb3 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 3 Feb 2025 09:18:15 +0100 Subject: [PATCH 4/4] Update black version --- docs/misc/changelog.rst | 1 + setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e8f7faa70..c7967a1e1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -32,6 +32,7 @@ Deprecations: Others: ^^^^^^^ +- Updated black from v24 to v25 Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index fa24fc8a3..8123cf43a 100644 --- a/setup.py +++ b/setup.py @@ -98,7 +98,7 @@ # Lint code and sort imports (flake8 and isort replacement) "ruff>=0.3.1", # Reformat - "black>=24.2.0,<25", + "black>=25.1.0,<26", ], "docs": [ "sphinx>=5,<9",