Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up handling of terminal observations #2086

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ New Features:
^^^^^^^^^^^^^
- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists
- Added ``LogEveryNTimesteps`` callback to dump logs every N timesteps (note: you need to pass ``log_interval=None`` to avoid any interference)
- Speed up handling of terminal observations: only the envs with finished episodes are checked (the speedup will be noticeable on massively parallel envs)
- Added Gymnasium v1.1 support

Bug Fixes:
Expand Down
25 changes: 12 additions & 13 deletions stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,8 @@ def _store_transition(
next_obs = deepcopy(new_obs_)
# As the VecEnv resets automatically, new_obs is already the
# first observation of the next episode
for i, done in enumerate(dones):
if done and infos[i].get("terminal_observation") is not None:
for i in dones.nonzero()[0]:
if infos[i].get("terminal_observation") is not None:
if isinstance(next_obs, dict):
next_obs_ = infos[i]["terminal_observation"]
# VecNormalize normalizes the terminal observation
Expand Down Expand Up @@ -582,19 +582,18 @@ def collect_rollouts(
# see https://github.com/hill-a/stable-baselines/issues/900
self._on_step()

for idx, done in enumerate(dones):
if done:
# Update stats
num_collected_episodes += 1
self._episode_num += 1
for idx in dones.nonzero()[0]:
# Update stats
num_collected_episodes += 1
self._episode_num += 1

if action_noise is not None:
kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
action_noise.reset(**kwargs)
if action_noise is not None:
kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
action_noise.reset(**kwargs)

# Log training infos
if log_interval is not None and self._episode_num % log_interval == 0:
self.dump_logs()
# Log training infos
if log_interval is not None and self._episode_num % log_interval == 0:
self.dump_logs()
callback.on_rollout_end()

return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
8 changes: 2 additions & 6 deletions stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,8 @@ def collect_rollouts(

# Handle timeout by bootstrapping with value function
# see GitHub issue #633
for idx, done in enumerate(dones):
if (
done
and infos[idx].get("terminal_observation") is not None
and infos[idx].get("TimeLimit.truncated", False)
):
for idx in dones.nonzero()[0]:
if infos[idx].get("terminal_observation") is not None and infos[idx].get("TimeLimit.truncated", False):
terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
with th.no_grad():
terminal_value = self.policy.predict_values(terminal_obs)[0] # type: ignore[arg-type]
Expand Down
31 changes: 15 additions & 16 deletions stable_baselines3/common/vec_env/stacked_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,28 +148,27 @@ def update(

# From {key1: [{}, {terminal_obs: ...}], key2: [{}, {terminal_obs: ...}]}
# to [{}, {terminal_obs: {key1: ..., key2: ...}}]
for key in stacked_infos.keys():
for env_idx in range(len(infos)):
if "terminal_observation" in infos[env_idx]:
for env_idx in dones.nonzero()[0]:
if "terminal_observation" in infos[env_idx]:
for key in stacked_infos.keys():
infos[env_idx]["terminal_observation"][key] = stacked_infos[key][env_idx]["terminal_observation"]
return stacked_obs, infos

shift = -observations.shape[self.stack_dimension]
self.stacked_obs = np.roll(self.stacked_obs, shift, axis=self.stack_dimension)
for env_idx, done in enumerate(dones):
if done:
if "terminal_observation" in infos[env_idx]:
old_terminal = infos[env_idx]["terminal_observation"]
if self.channels_first:
previous_stack = self.stacked_obs[env_idx, :shift, ...]
else:
previous_stack = self.stacked_obs[env_idx, ..., :shift]

new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis)
infos[env_idx]["terminal_observation"] = new_terminal
for env_idx in dones.nonzero()[0]:
if "terminal_observation" in infos[env_idx]:
old_terminal = infos[env_idx]["terminal_observation"]
if self.channels_first:
previous_stack = self.stacked_obs[env_idx, :shift, ...]
else:
warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
self.stacked_obs[env_idx] = 0
previous_stack = self.stacked_obs[env_idx, ..., :shift]

new_terminal = np.concatenate((previous_stack, old_terminal), axis=self.repeat_axis)
infos[env_idx]["terminal_observation"] = new_terminal
else:
warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
self.stacked_obs[env_idx] = 0
if self.channels_first:
self.stacked_obs[:, shift:, ...] = observations
else:
Expand Down
29 changes: 14 additions & 15 deletions stable_baselines3/common/vec_env/vec_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,20 @@ def step_wait(self) -> VecEnvStepReturn:
self.episode_returns += rewards
self.episode_lengths += 1
new_infos = list(infos[:])
for i in range(len(dones)):
if dones[i]:
info = infos[i].copy()
episode_return = self.episode_returns[i]
episode_length = self.episode_lengths[i]
episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
for key in self.info_keywords:
episode_info[key] = info[key]
info["episode"] = episode_info
self.episode_count += 1
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
if self.results_writer:
self.results_writer.write_row(episode_info)
new_infos[i] = info
for i in dones.nonzero()[0]:
info = infos[i].copy()
episode_return = self.episode_returns[i]
episode_length = self.episode_lengths[i]
episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)}
for key in self.info_keywords:
episode_info[key] = info[key]
info["episode"] = episode_info
self.episode_count += 1
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
if self.results_writer:
self.results_writer.write_row(episode_info)
new_infos[i] = info
return obs, rewards, dones, new_infos

def close(self) -> None:
Expand Down
4 changes: 1 addition & 3 deletions stable_baselines3/common/vec_env/vec_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ def step_wait(self) -> VecEnvStepReturn:
rewards = self.normalize_reward(rewards)

# Normalize the terminal observations
for idx, done in enumerate(dones):
if not done:
continue
for idx in dones.nonzero()[0]:
if "terminal_observation" in infos[idx]:
infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"])

Expand Down
4 changes: 1 addition & 3 deletions stable_baselines3/common/vec_env/vec_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ def step_wait(self) -> VecEnvStepReturn:
observations, rewards, dones, infos = self.venv.step_wait()

# Transpose the terminal observations
for idx, done in enumerate(dones):
if not done:
continue
for idx in dones.nonzero()[0]:
if "terminal_observation" in infos[idx]:
infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"])

Expand Down
5 changes: 2 additions & 3 deletions stable_baselines3/her/her_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,8 @@ def add( # type: ignore[override]
super().add(obs, next_obs, action, reward, done, infos)

# When episode ends, compute and store the episode length
for env_idx in range(self.n_envs):
if done[env_idx]:
self._compute_episode_length(env_idx)
for env_idx in done.nonzero()[0]:
self._compute_episode_length(env_idx)

def _compute_episode_length(self, env_idx: int) -> None:
"""
Expand Down