From 6d5ab0aa4e1b22db1ef2dc988171b6fb7a911e3f Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN <antonin.raffin@ensta.org> Date: Mon, 17 Feb 2025 11:31:08 +0100 Subject: [PATCH 1/2] Speedup handling of terminal observations --- docs/misc/changelog.rst | 3 +- .../common/off_policy_algorithm.py | 25 ++++++++-------- .../common/on_policy_algorithm.py | 8 ++--- .../common/vec_env/stacked_observations.py | 29 ++++++++++--------- .../common/vec_env/vec_monitor.py | 29 +++++++++---------- .../common/vec_env/vec_normalize.py | 4 +-- .../common/vec_env/vec_transpose.py | 4 +-- stable_baselines3/her/her_replay_buffer.py | 5 ++-- stable_baselines3/version.txt | 2 +- 9 files changed, 50 insertions(+), 59 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cd2bc23f7..3a9b52a4d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.6.0a1 (WIP) +Release 2.6.0a2 (WIP) -------------------------- @@ -14,6 +14,7 @@ New Features: ^^^^^^^^^^^^^ - Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists - Added ``LogEveryNTimesteps`` callback to dump logs every N timesteps (note: you need to pass ``log_interval=None`` to avoid any interference) +- Speed up handling of terminal observations, only the envs with finished episodes are checked (the speed up will be noticeable on massively parallel envs) Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index b778480d4..c9c3a7534 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -473,8 +473,8 @@ def _store_transition( next_obs = deepcopy(new_obs_) # As the VecEnv resets automatically, new_obs is already the # first observation of the next episode - for i, done in enumerate(dones): - if done and infos[i].get("terminal_observation") is not None: + for i in dones.nonzero()[0]: + if
infos[i].get("terminal_observation") is not None: if isinstance(next_obs, dict): next_obs_ = infos[i]["terminal_observation"] # VecNormalize normalizes the terminal observation @@ -582,19 +582,18 @@ def collect_rollouts( # see https://github.com/hill-a/stable-baselines/issues/900 self._on_step() - for idx, done in enumerate(dones): - if done: - # Update stats - num_collected_episodes += 1 - self._episode_num += 1 + for idx in dones.nonzero()[0]: + # Update stats + num_collected_episodes += 1 + self._episode_num += 1 - if action_noise is not None: - kwargs = dict(indices=[idx]) if env.num_envs > 1 else {} - action_noise.reset(**kwargs) + if action_noise is not None: + kwargs = dict(indices=[idx]) if env.num_envs > 1 else {} + action_noise.reset(**kwargs) - # Log training infos - if log_interval is not None and self._episode_num % log_interval == 0: - self.dump_logs() + # Log training infos + if log_interval is not None and self._episode_num % log_interval == 0: + self.dump_logs() callback.on_rollout_end() return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training) diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 0db5ce5d5..ba5b30a9f 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -233,12 +233,8 @@ def collect_rollouts( # Handle timeout by bootstrapping with value function # see GitHub issue #633 - for idx, done in enumerate(dones): - if ( - done - and infos[idx].get("terminal_observation") is not None - and infos[idx].get("TimeLimit.truncated", False) - ): + for idx in dones.nonzero()[0]: + if infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False): terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0] with th.no_grad(): terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type] diff --git 
a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index d1b3ad298..a3a716fe8 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -149,27 +149,28 @@ def update( # From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]} # to [{}, {terminal_obs: {key1: ..., key2: ...}}] for key in stacked_infos.keys(): - for env_idx in range(len(infos)): + # Optimization: only check for env where done=True + for env_idx in dones.nonzero()[0]: if "terminal_observation" in infos[env_idx]: infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"] return stacked_obs, infos shift = -observations.shape[self.stack_dimension] self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension) - for env_idx, done in enumerate(dones): - if done: - if "terminal_observation" in infos[env_idx]: - old_terminal = infos[env_idx]["terminal_observation"] - if self.channels_first: - previous_stack = self.stacked_obs[env_idx, :shift, ...] - else: - previous_stack = self.stacked_obs[env_idx, ..., :shift] - - new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis) - infos[env_idx]["terminal_observation"] = new_terminal + # Optimization: only check for env where done=True + for env_idx in dones.nonzero()[0]: + if "terminal_observation" in infos[env_idx]: + old_terminal = infos[env_idx]["terminal_observation"] + if self.channels_first: + previous_stack = self.stacked_obs[env_idx, :shift, ...] 
else: - warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") - self.stacked_obs[env_idx] = 0 + previous_stack = self.stacked_obs[env_idx, ..., :shift] + + new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis) + infos[env_idx]["terminal_observation"] = new_terminal + else: + warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") + self.stacked_obs[env_idx] = 0 if self.channels_first: self.stacked_obs[:, shift:, ...] = observations else: diff --git a/stable_baselines3/common/vec_env/vec_monitor.py b/stable_baselines3/common/vec_env/vec_monitor.py index 4aa9325f6..b578f1612 100644 --- a/stable_baselines3/common/vec_env/vec_monitor.py +++ b/stable_baselines3/common/vec_env/vec_monitor.py @@ -77,21 +77,20 @@ def step_wait(self) -> VecEnvStepReturn: self.episode_returns += rewards self.episode_lengths += 1 new_infos = list(infos[:]) - for i in range(len(dones)): - if dones[i]: - info = infos[i].copy() - episode_return = self.episode_returns[i] - episode_length = self.episode_lengths[i] - episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)} - for key in self.info_keywords: - episode_info[key] = info[key] - info["episode"] = episode_info - self.episode_count += 1 - self.episode_returns[i] = 0 - self.episode_lengths[i] = 0 - if self.results_writer: - self.results_writer.write_row(episode_info) - new_infos[i] = info + for i in dones.nonzero()[0]: + info = infos[i].copy() + episode_return = self.episode_returns[i] + episode_length = self.episode_lengths[i] + episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)} + for key in self.info_keywords: + episode_info[key] = info[key] + info["episode"] = episode_info + self.episode_count += 1 + self.episode_returns[i] = 0 + self.episode_lengths[i] = 0 + if self.results_writer: + self.results_writer.write_row(episode_info) + new_infos[i] = info return obs, 
rewards, dones, new_infos def close(self) -> None: diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 439243f9f..ed39fa08b 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -197,9 +197,7 @@ def step_wait(self) -> VecEnvStepReturn: rewards = self.normalize_reward(rewards) # Normalize the terminal observations - for idx, done in enumerate(dones): - if not done: - continue + for idx in dones.nonzero()[0]: if "terminal_observation" in infos[idx]: infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"]) diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index 3fade64d1..0f3ba30ed 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -97,9 +97,7 @@ def step_wait(self) -> VecEnvStepReturn: observations, rewards, dones, infos = self.venv.step_wait() # Transpose the terminal observations - for idx, done in enumerate(dones): - if not done: - continue + for idx in dones.nonzero()[0]: if "terminal_observation" in infos[idx]: infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"]) diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index e914c7ec0..e801ef733 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -162,9 +162,8 @@ def add( # type: ignore[override] super().add(obs, next_obs, action, reward, done, infos) # When episode ends, compute and store the episode length - for env_idx in range(self.n_envs): - if done[env_idx]: - self._compute_episode_length(env_idx) + for env_idx in done.nonzero()[0]: + self._compute_episode_length(env_idx) def _compute_episode_length(self, env_idx: int) -> None: """ diff --git 
a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 5809eab2a..814d58f33 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.6.0a1 +2.6.0a2 From 34cc97c51c51f6185f95f1f0098daf8b5c419a72 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN <antonin.raffin@ensta.org> Date: Mon, 17 Feb 2025 14:00:47 +0100 Subject: [PATCH 2/2] Check for dones first --- stable_baselines3/common/vec_env/stacked_observations.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index a3a716fe8..cd1d212f6 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -148,16 +148,14 @@ def update( # From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]} # to [{}, {terminal_obs: {key1: ..., key2: ...}}] - for key in stacked_infos.keys(): - # Optimization: only check for env where done=True - for env_idx in dones.nonzero()[0]: - if "terminal_observation" in infos[env_idx]: + for env_idx in dones.nonzero()[0]: + if "terminal_observation" in infos[env_idx]: + for key in stacked_infos.keys(): infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"] return stacked_obs, infos shift = -observations.shape[self.stack_dimension] self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension) - # Optimization: only check for env where done=True for env_idx in dones.nonzero()[0]: if "terminal_observation" in infos[env_idx]: old_terminal = infos[env_idx]["terminal_observation"]