diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index c02a185b3..af95ae972 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -14,6 +14,7 @@ New Features:
 ^^^^^^^^^^^^^
 - Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists
 - Added ``LogEveryNTimesteps`` callback to dump logs every N timesteps (note: you need to pass ``log_interval=None`` to avoid any interference)
+- Sped up handling of terminal observations: only the envs with finished episodes are checked (the speed-up will be noticeable with massively parallel envs)
 - Added Gymnasium v1.1 support
 
 Bug Fixes:
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index a2df272d7..0c2582713 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -473,8 +473,8 @@ def _store_transition(
         next_obs = deepcopy(new_obs_)
         # As the VecEnv resets automatically, new_obs is already the
         # first observation of the next episode
-        for i, done in enumerate(dones):
-            if done and infos[i].get("terminal_observation") is not None:
+        for i in dones.nonzero()[0]:
+            if infos[i].get("terminal_observation") is not None:
                 if isinstance(next_obs, dict):
                     next_obs_ = infos[i]["terminal_observation"]
                     # VecNormalize normalizes the terminal observation
@@ -582,19 +582,18 @@ def collect_rollouts(
             # see https://github.com/hill-a/stable-baselines/issues/900
             self._on_step()
 
-            for idx, done in enumerate(dones):
-                if done:
-                    # Update stats
-                    num_collected_episodes += 1
-                    self._episode_num += 1
+            for idx in dones.nonzero()[0]:
+                # Update stats
+                num_collected_episodes += 1
+                self._episode_num += 1
 
-                    if action_noise is not None:
-                        kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
-                        action_noise.reset(**kwargs)
+                if action_noise is not None:
+                    kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
+                    action_noise.reset(**kwargs)
 
-                    # Log training infos
-                    if log_interval is not None and self._episode_num % log_interval == 0:
-                        self.dump_logs()
+                # Log training infos
+                if log_interval is not None and self._episode_num % log_interval == 0:
+                    self.dump_logs()
         callback.on_rollout_end()
 
         return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 0db5ce5d5..ba5b30a9f 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -233,12 +233,8 @@ def collect_rollouts(
 
             # Handle timeout by bootstrapping with value function
             # see GitHub issue #633
-            for idx, done in enumerate(dones):
-                if (
-                    done
-                    and infos[idx].get("terminal_observation") is not None
-                    and infos[idx].get("TimeLimit.truncated", False)
-                ):
+            for idx in dones.nonzero()[0]:
+                if infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False):
                     terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                     with th.no_grad():
                         terminal_value = self.policy.predict_values(terminal_obs)[0]  # type: ignore[arg-type]
diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py
index d1b3ad298..cd1d212f6 100644
--- a/stable_baselines3/common/vec_env/stacked_observations.py
+++ b/stable_baselines3/common/vec_env/stacked_observations.py
@@ -148,28 +148,27 @@ def update(
 
             # From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
             # to [{}, {terminal_obs: {key1: ..., key2: ...}}]
-            for key in stacked_infos.keys():
-                for env_idx in range(len(infos)):
-                    if "terminal_observation" in infos[env_idx]:
+            for env_idx in dones.nonzero()[0]:
+                if "terminal_observation" in infos[env_idx]:
+                    for key in stacked_infos.keys():
                         infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"]
             return stacked_obs, infos
 
         shift = -observations.shape[self.stack_dimension]
         self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension)
-        for env_idx, done in enumerate(dones):
-            if done:
-                if "terminal_observation" in infos[env_idx]:
-                    old_terminal = infos[env_idx]["terminal_observation"]
-                    if self.channels_first:
-                        previous_stack = self.stacked_obs[env_idx, :shift, ...]
-                    else:
-                        previous_stack = self.stacked_obs[env_idx, ..., :shift]
-
-                    new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis)
-                    infos[env_idx]["terminal_observation"] = new_terminal
+        for env_idx in dones.nonzero()[0]:
+            if "terminal_observation" in infos[env_idx]:
+                old_terminal = infos[env_idx]["terminal_observation"]
+                if self.channels_first:
+                    previous_stack = self.stacked_obs[env_idx, :shift, ...]
                 else:
-                    warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
-                self.stacked_obs[env_idx] = 0
+                    previous_stack = self.stacked_obs[env_idx, ..., :shift]
+
+                new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis)
+                infos[env_idx]["terminal_observation"] = new_terminal
+            else:
+                warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
+            self.stacked_obs[env_idx] = 0
         if self.channels_first:
             self.stacked_obs[:, shift:, ...] = observations
         else:
diff --git a/stable_baselines3/common/vec_env/vec_monitor.py b/stable_baselines3/common/vec_env/vec_monitor.py
index 4aa9325f6..b578f1612 100644
--- a/stable_baselines3/common/vec_env/vec_monitor.py
+++ b/stable_baselines3/common/vec_env/vec_monitor.py
@@ -77,21 +77,20 @@ def step_wait(self) -> VecEnvStepReturn:
         self.episode_returns += rewards
         self.episode_lengths += 1
         new_infos = list(infos[:])
-        for i in range(len(dones)):
-            if dones[i]:
-                info = infos[i].copy()
-                episode_return = self.episode_returns[i]
-                episode_length = self.episode_lengths[i]
-                episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
-                for key in self.info_keywords:
-                    episode_info[key] = info[key]
-                info["episode"] = episode_info
-                self.episode_count += 1
-                self.episode_returns[i] = 0
-                self.episode_lengths[i] = 0
-                if self.results_writer:
-                    self.results_writer.write_row(episode_info)
-                new_infos[i] = info
+        for i in dones.nonzero()[0]:
+            info = infos[i].copy()
+            episode_return = self.episode_returns[i]
+            episode_length = self.episode_lengths[i]
+            episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
+            for key in self.info_keywords:
+                episode_info[key] = info[key]
+            info["episode"] = episode_info
+            self.episode_count += 1
+            self.episode_returns[i] = 0
+            self.episode_lengths[i] = 0
+            if self.results_writer:
+                self.results_writer.write_row(episode_info)
+            new_infos[i] = info
         return obs, rewards, dones, new_infos
 
     def close(self) -> None:
diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py
index 439243f9f..ed39fa08b 100644
--- a/stable_baselines3/common/vec_env/vec_normalize.py
+++ b/stable_baselines3/common/vec_env/vec_normalize.py
@@ -197,9 +197,7 @@ def step_wait(self) -> VecEnvStepReturn:
         rewards = self.normalize_reward(rewards)
 
         # Normalize the terminal observations
-        for idx, done in enumerate(dones):
-            if not done:
-                continue
+        for idx in dones.nonzero()[0]:
             if "terminal_observation" in infos[idx]:
                 infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])
 
diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py
index 3fade64d1..0f3ba30ed 100644
--- a/stable_baselines3/common/vec_env/vec_transpose.py
+++ b/stable_baselines3/common/vec_env/vec_transpose.py
@@ -97,9 +97,7 @@ def step_wait(self) -> VecEnvStepReturn:
         observations, rewards, dones, infos = self.venv.step_wait()
 
         # Transpose the terminal observations
-        for idx, done in enumerate(dones):
-            if not done:
-                continue
+        for idx in dones.nonzero()[0]:
             if "terminal_observation" in infos[idx]:
                 infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"])
 
diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py
index e914c7ec0..e801ef733 100644
--- a/stable_baselines3/her/her_replay_buffer.py
+++ b/stable_baselines3/her/her_replay_buffer.py
@@ -162,9 +162,8 @@ def add(  # type: ignore[override]
         super().add(obs, next_obs, action, reward, done, infos)
 
         # When episode ends, compute and store the episode length
-        for env_idx in range(self.n_envs):
-            if done[env_idx]:
-                self._compute_episode_length(env_idx)
+        for env_idx in done.nonzero()[0]:
+            self._compute_episode_length(env_idx)
 
     def _compute_episode_length(self, env_idx: int) -> None:
         """