Skip to content

Commit 8c78653

Browse files
authored
Add LogEveryNTimesteps callback (#2083)
* Add log every n step callback * Rename _dump_logs() and update changelog * Improve error messages
1 parent c5c29a3 commit 8c78653

9 files changed

+72
-15
lines changed

docs/guide/callbacks.rst

+26-2
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ Stable Baselines provides you with a set of common callbacks for:
143143
- evaluating the model periodically and saving the best one (:ref:`EvalCallback`)
144144
- chaining callbacks (:ref:`CallbackList`)
145145
- triggering callback on events (:ref:`EventCallback`, :ref:`EveryNTimesteps`)
146+
- logging data every N timesteps (:ref:`LogEveryNTimesteps`)
146147
- stopping the training early based on a reward threshold (:ref:`StopTrainingOnRewardThreshold <StopTrainingCallback>`)
147148

148149

@@ -313,7 +314,7 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t
313314

314315
.. note::
315316

316-
Because of the way ``PPO1`` and ``TRPO`` work (they rely on MPI), ``n_steps`` is a lower bound between two events.
317+
Because of the way ``VecEnv`` works, ``n_steps`` is a lower bound between two events when using multiple environments.
317318

318319

319320
.. code-block:: python
@@ -330,7 +331,30 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t
330331
331332
model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
332333
333-
model.learn(int(2e4), callback=event_callback)
334+
model.learn(20_000, callback=event_callback)
335+
336+
.. _LogEveryNTimesteps:
337+
338+
LogEveryNTimesteps
339+
^^^^^^^^^^^^^^^^^^
340+
341+
A callback derived from :ref:`EveryNTimesteps` that will dump the logged data every ``n_steps`` timesteps.
342+
343+
344+
.. code-block:: python
345+
346+
import gymnasium as gym
347+
348+
from stable_baselines3 import PPO
349+
from stable_baselines3.common.callbacks import LogEveryNTimesteps
350+
351+
event_callback = LogEveryNTimesteps(n_steps=1_000)
352+
353+
model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
354+
355+
# Disable auto-logging by passing `log_interval=None`
356+
model.learn(10_000, callback=event_callback, log_interval=None)
357+
334358
335359
336360
.. _StopTrainingOnMaxEpisodes:

docs/misc/changelog.rst

+5-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Changelog
44
==========
55

6-
Release 2.6.0a0 (WIP)
6+
Release 2.6.0a1 (WIP)
77
--------------------------
88

99

@@ -13,6 +13,7 @@ Breaking Changes:
1313
New Features:
1414
^^^^^^^^^^^^^
1515
- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists
16+
- Added ``LogEveryNTimesteps`` callback to dump logs every N timesteps (note: you need to pass ``log_interval=None`` to avoid any interference)
1617

1718
Bug Fixes:
1819
^^^^^^^^^^
@@ -29,15 +30,17 @@ Bug Fixes:
2930

3031
Deprecations:
3132
^^^^^^^^^^^^^
33+
- ``algo._dump_logs()`` is deprecated in favor of ``algo.dump_logs()`` and will be removed in SB3 v2.7.0
3234

3335
Others:
3436
^^^^^^^
3537
- Updated black from v24 to v25
38+
- Improved error messages when checking Box space equality (loading ``VecNormalize``)
3639

3740
Documentation:
3841
^^^^^^^^^^^^^^
3942
- Clarify the use of Gym wrappers with ``make_vec_env`` in the section on Vectorized Environments (@pstahlhofen)
40-
43+
- Updated callback doc for ``EveryNTimesteps``
4144

4245
Release 2.5.0 (2025-01-27)
4346
--------------------------

stable_baselines3/common/base_class.py

+10
Original file line numberDiff line numberDiff line change
@@ -865,3 +865,13 @@ def save(
865865
params_to_save = self.get_parameters()
866866

867867
save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
868+
869+
def dump_logs(self) -> None:
870+
"""
871+
Write log data. (Implemented by OffPolicyAlgorithm and OnPolicyAlgorithm)
872+
"""
873+
raise NotImplementedError()
874+
875+
def _dump_logs(self, *args) -> None:
876+
warnings.warn("algo._dump_logs() is deprecated in favor of algo.dump_logs(). It will be removed in SB3 v2.7.0")
877+
self.dump_logs(*args)

stable_baselines3/common/callbacks.py

+15
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,21 @@ def _on_step(self) -> bool:
591591
return True
592592

593593

594+
class LogEveryNTimesteps(EveryNTimesteps):
595+
"""
596+
Log data every ``n_steps`` timesteps
597+
598+
:param n_steps: Number of timesteps between two triggers.
599+
"""
600+
601+
def __init__(self, n_steps: int):
602+
super().__init__(n_steps, callback=ConvertCallback(self._log_data))
603+
604+
def _log_data(self, _locals: dict[str, Any], _globals: dict[str, Any]) -> bool:
605+
self.model.dump_logs()
606+
return True
607+
608+
594609
class StopTrainingOnMaxEpisodes(BaseCallback):
595610
"""
596611
Stop the training once a maximum number of episodes are played.

stable_baselines3/common/off_policy_algorithm.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,9 @@ def _sample_action(
406406
action = buffer_action
407407
return action, buffer_action
408408

409-
def _dump_logs(self) -> None:
409+
def dump_logs(self) -> None:
410410
"""
411-
Write log.
411+
Write log data.
412412
"""
413413
assert self.ep_info_buffer is not None
414414
assert self.ep_success_buffer is not None
@@ -594,7 +594,7 @@ def collect_rollouts(
594594

595595
# Log training infos
596596
if log_interval is not None and self._episode_num % log_interval == 0:
597-
self._dump_logs()
597+
self.dump_logs()
598598
callback.on_rollout_end()
599599

600600
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)

stable_baselines3/common/on_policy_algorithm.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def train(self) -> None:
274274
"""
275275
raise NotImplementedError
276276

277-
def _dump_logs(self, iteration: int) -> None:
277+
def dump_logs(self, iteration: int = 0) -> None:
278278
"""
279279
Write log.
280280
@@ -285,7 +285,8 @@ def _dump_logs(self, iteration: int) -> None:
285285

286286
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
287287
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
288-
self.logger.record("time/iterations", iteration, exclude="tensorboard")
288+
if iteration > 0:
289+
self.logger.record("time/iterations", iteration, exclude="tensorboard")
289290
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
290291
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
291292
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
@@ -331,7 +332,7 @@ def learn(
331332
# Display training infos
332333
if log_interval is not None and iteration % log_interval == 0:
333334
assert self.ep_info_buffer is not None
334-
self._dump_logs(iteration)
335+
self.dump_logs(iteration)
335336

336337
self.train()
337338

stable_baselines3/common/utils.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,14 @@ def check_shape_equal(space1: spaces.Space, space2: spaces.Space) -> None:
244244
:param space2: Other space
245245
"""
246246
if isinstance(space1, spaces.Dict):
247-
assert isinstance(space2, spaces.Dict), "spaces must be of the same type"
248-
assert space1.spaces.keys() == space2.spaces.keys(), "spaces must have the same keys"
247+
assert isinstance(space2, spaces.Dict), f"spaces must be of the same type: {type(space1)} != {type(space2)}"
248+
assert (
249+
space1.spaces.keys() == space2.spaces.keys()
250+
), f"spaces must have the same keys: {list(space1.spaces.keys())} != {list(space2.spaces.keys())}"
249251
for key in space1.spaces.keys():
250252
check_shape_equal(space1.spaces[key], space2.spaces[key])
251253
elif isinstance(space1, spaces.Box):
252-
assert space1.shape == space2.shape, "spaces must have the same shape"
254+
assert space1.shape == space2.shape, f"spaces must have the same shape: {space1.shape} != {space2.shape}"
253255

254256

255257
def is_vectorized_box_observation(observation: np.ndarray, observation_space: spaces.Box) -> bool:

stable_baselines3/version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.6.0a0
1+
2.6.0a1

tests/test_callbacks.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
CheckpointCallback,
1414
EvalCallback,
1515
EveryNTimesteps,
16+
LogEveryNTimesteps,
1617
StopTrainingOnMaxEpisodes,
1718
StopTrainingOnNoModelImprovement,
1819
StopTrainingOnRewardThreshold,
@@ -62,11 +63,12 @@ def test_callbacks(tmp_path, model_class):
6263
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event")
6364

6465
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
66+
log_callback = LogEveryNTimesteps(n_steps=250)
6567

6668
# Stop training if max number of episodes is reached
6769
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=100, verbose=1)
6870

69-
callback = CallbackList([checkpoint_callback, eval_callback, event_callback, callback_max_episodes])
71+
callback = CallbackList([checkpoint_callback, eval_callback, event_callback, log_callback, callback_max_episodes])
7072
model.learn(500, callback=callback)
7173

7274
# Check access to local variables

0 commit comments

Comments
 (0)