From 6d5ab0aa4e1b22db1ef2dc988171b6fb7a911e3f Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN <antonin.raffin@ensta.org>
Date: Mon, 17 Feb 2025 11:31:08 +0100
Subject: [PATCH 1/2] Speedup handling of terminal observations

---
 docs/misc/changelog.rst                       |  3 +-
 .../common/off_policy_algorithm.py            | 25 ++++++++--------
 .../common/on_policy_algorithm.py             |  8 ++---
 .../common/vec_env/stacked_observations.py    | 29 ++++++++++---------
 .../common/vec_env/vec_monitor.py             | 29 +++++++++----------
 .../common/vec_env/vec_normalize.py           |  4 +--
 .../common/vec_env/vec_transpose.py           |  4 +--
 stable_baselines3/her/her_replay_buffer.py    |  5 ++--
 stable_baselines3/version.txt                 |  2 +-
 9 files changed, 50 insertions(+), 59 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index cd2bc23f7..3a9b52a4d 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 2.6.0a1 (WIP)
+Release 2.6.0a2 (WIP)
 --------------------------
 
 
@@ -14,6 +14,7 @@ New Features:
 ^^^^^^^^^^^^^
 - Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists
 - Added ``LogEveryNTimesteps`` callback to dump logs every N timesteps (note: you need to pass ``log_interval=None`` to avoid any interference)
+- Speed up handling of terminal observations, only the env with finished episodes are checked (the speed up will be noticable on massively parallel env)
 
 Bug Fixes:
 ^^^^^^^^^^
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index b778480d4..c9c3a7534 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -473,8 +473,8 @@ def _store_transition(
         next_obs = deepcopy(new_obs_)
         # As the VecEnv resets automatically, new_obs is already the
         # first observation of the next episode
-        for i, done in enumerate(dones):
-            if done and infos[i].get("terminal_observation") is not None:
+        for i in dones.nonzero()[0]:
+            if infos[i].get("terminal_observation") is not None:
                 if isinstance(next_obs, dict):
                     next_obs_ = infos[i]["terminal_observation"]
                     # VecNormalize normalizes the terminal observation
@@ -582,19 +582,18 @@ def collect_rollouts(
             # see https://github.com/hill-a/stable-baselines/issues/900
             self._on_step()
 
-            for idx, done in enumerate(dones):
-                if done:
-                    # Update stats
-                    num_collected_episodes += 1
-                    self._episode_num += 1
+            for idx in dones.nonzero()[0]:
+                # Update stats
+                num_collected_episodes += 1
+                self._episode_num += 1
 
-                    if action_noise is not None:
-                        kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
-                        action_noise.reset(**kwargs)
+                if action_noise is not None:
+                    kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
+                    action_noise.reset(**kwargs)
 
-                    # Log training infos
-                    if log_interval is not None and self._episode_num % log_interval == 0:
-                        self.dump_logs()
+                # Log training infos
+                if log_interval is not None and self._episode_num % log_interval == 0:
+                    self.dump_logs()
         callback.on_rollout_end()
 
         return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 0db5ce5d5..ba5b30a9f 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -233,12 +233,8 @@ def collect_rollouts(
 
             # Handle timeout by bootstrapping with value function
             # see GitHub issue #633
-            for idx, done in enumerate(dones):
-                if (
-                    done
-                    and infos[idx].get("terminal_observation") is not None
-                    and infos[idx].get("TimeLimit.truncated", False)
-                ):
+            for idx in dones.nonzero()[0]:
+                if infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False):
                     terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                     with th.no_grad():
                         terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py
index d1b3ad298..a3a716fe8 100644
--- a/stable_baselines3/common/vec_env/stacked_observations.py
+++ b/stable_baselines3/common/vec_env/stacked_observations.py
@@ -149,27 +149,28 @@ def update(
             # From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
             # to [{}, {terminal_obs: {key1: ..., key2: ...}}]
             for key in stacked_infos.keys():
-                for env_idx in range(len(infos)):
+                # Optimization: only check for env where done=True
+                for env_idx in dones.nonzero()[0]:
                     if "terminal_observation" in infos[env_idx]:
                         infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"]
             return stacked_obs, infos
 
         shift = -observations.shape[self.stack_dimension]
         self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension)
-        for env_idx, done in enumerate(dones):
-            if done:
-                if "terminal_observation" in infos[env_idx]:
-                    old_terminal = infos[env_idx]["terminal_observation"]
-                    if self.channels_first:
-                        previous_stack = self.stacked_obs[env_idx, :shift, ...]
-                    else:
-                        previous_stack = self.stacked_obs[env_idx, ..., :shift]
-
-                    new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis)
-                    infos[env_idx]["terminal_observation"] = new_terminal
+        # Optimization: only check for env where done=True
+        for env_idx in dones.nonzero()[0]:
+            if "terminal_observation" in infos[env_idx]:
+                old_terminal = infos[env_idx]["terminal_observation"]
+                if self.channels_first:
+                    previous_stack = self.stacked_obs[env_idx, :shift, ...]
                 else:
-                    warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
-                self.stacked_obs[env_idx] = 0
+                    previous_stack = self.stacked_obs[env_idx, ..., :shift]
+
+                new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis)
+                infos[env_idx]["terminal_observation"] = new_terminal
+            else:
+                warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
+            self.stacked_obs[env_idx] = 0
         if self.channels_first:
             self.stacked_obs[:, shift:, ...] = observations
         else:
diff --git a/stable_baselines3/common/vec_env/vec_monitor.py b/stable_baselines3/common/vec_env/vec_monitor.py
index 4aa9325f6..b578f1612 100644
--- a/stable_baselines3/common/vec_env/vec_monitor.py
+++ b/stable_baselines3/common/vec_env/vec_monitor.py
@@ -77,21 +77,20 @@ def step_wait(self) -> VecEnvStepReturn:
         self.episode_returns += rewards
         self.episode_lengths += 1
         new_infos = list(infos[:])
-        for i in range(len(dones)):
-            if dones[i]:
-                info = infos[i].copy()
-                episode_return = self.episode_returns[i]
-                episode_length = self.episode_lengths[i]
-                episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
-                for key in self.info_keywords:
-                    episode_info[key] = info[key]
-                info["episode"] = episode_info
-                self.episode_count += 1
-                self.episode_returns[i] = 0
-                self.episode_lengths[i] = 0
-                if self.results_writer:
-                    self.results_writer.write_row(episode_info)
-                new_infos[i] = info
+        for i in dones.nonzero()[0]:
+            info = infos[i].copy()
+            episode_return = self.episode_returns[i]
+            episode_length = self.episode_lengths[i]
+            episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
+            for key in self.info_keywords:
+                episode_info[key] = info[key]
+            info["episode"] = episode_info
+            self.episode_count += 1
+            self.episode_returns[i] = 0
+            self.episode_lengths[i] = 0
+            if self.results_writer:
+                self.results_writer.write_row(episode_info)
+            new_infos[i] = info
         return obs, rewards, dones, new_infos
 
     def close(self) -> None:
diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py
index 439243f9f..ed39fa08b 100644
--- a/stable_baselines3/common/vec_env/vec_normalize.py
+++ b/stable_baselines3/common/vec_env/vec_normalize.py
@@ -197,9 +197,7 @@ def step_wait(self) -> VecEnvStepReturn:
         rewards = self.normalize_reward(rewards)
 
         # Normalize the terminal observations
-        for idx, done in enumerate(dones):
-            if not done:
-                continue
+        for idx in dones.nonzero()[0]:
             if "terminal_observation" in infos[idx]:
                 infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])
 
diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py
index 3fade64d1..0f3ba30ed 100644
--- a/stable_baselines3/common/vec_env/vec_transpose.py
+++ b/stable_baselines3/common/vec_env/vec_transpose.py
@@ -97,9 +97,7 @@ def step_wait(self) -> VecEnvStepReturn:
         observations, rewards, dones, infos = self.venv.step_wait()
 
         # Transpose the terminal observations
-        for idx, done in enumerate(dones):
-            if not done:
-                continue
+        for idx in dones.nonzero()[0]:
             if "terminal_observation" in infos[idx]:
                 infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"])
 
diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py
index e914c7ec0..e801ef733 100644
--- a/stable_baselines3/her/her_replay_buffer.py
+++ b/stable_baselines3/her/her_replay_buffer.py
@@ -162,9 +162,8 @@ def add(  # type: ignore[override]
         super().add(obs, next_obs, action, reward, done, infos)
 
         # When episode ends, compute and store the episode length
-        for env_idx in range(self.n_envs):
-            if done[env_idx]:
-                self._compute_episode_length(env_idx)
+        for env_idx in done.nonzero()[0]:
+            self._compute_episode_length(env_idx)
 
     def _compute_episode_length(self, env_idx: int) -> None:
         """
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index 5809eab2a..814d58f33 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-2.6.0a1
+2.6.0a2

From 34cc97c51c51f6185f95f1f0098daf8b5c419a72 Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN <antonin.raffin@ensta.org>
Date: Mon, 17 Feb 2025 14:00:47 +0100
Subject: [PATCH 2/2] Check for dones first

---
 stable_baselines3/common/vec_env/stacked_observations.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py
index a3a716fe8..cd1d212f6 100644
--- a/stable_baselines3/common/vec_env/stacked_observations.py
+++ b/stable_baselines3/common/vec_env/stacked_observations.py
@@ -148,16 +148,14 @@ def update(
 
             # From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
             # to [{}, {terminal_obs: {key1: ..., key2: ...}}]
-            for key in stacked_infos.keys():
-                # Optimization: only check for env where done=True
-                for env_idx in dones.nonzero()[0]:
-                    if "terminal_observation" in infos[env_idx]:
+            for env_idx in dones.nonzero()[0]:
+                if "terminal_observation" in infos[env_idx]:
+                    for key in stacked_infos.keys():
                         infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"]
             return stacked_obs, infos
 
         shift = -observations.shape[self.stack_dimension]
         self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension)
-        # Optimization: only check for env where done=True
         for env_idx in dones.nonzero()[0]:
             if "terminal_observation" in infos[env_idx]:
                 old_terminal = infos[env_idx]["terminal_observation"]