Add LogEveryNTimesteps callback #2083

Merged 3 commits on Feb 14, 2025

Changes from 1 commit
Rename _dump_logs() and update changelog
araffin committed Feb 11, 2025
commit a8d71d6ef4d958dae4e127a853400cfa68640094
28 changes: 26 additions & 2 deletions docs/guide/callbacks.rst
@@ -143,6 +143,7 @@ Stable Baselines provides you with a set of common callbacks for:
- evaluating the model periodically and saving the best one (:ref:`EvalCallback`)
- chaining callbacks (:ref:`CallbackList`)
- triggering callback on events (:ref:`EventCallback`, :ref:`EveryNTimesteps`)
+- logging data every N timesteps (:ref:`LogEveryNTimesteps`)
- stopping the training early based on a reward threshold (:ref:`StopTrainingOnRewardThreshold <StopTrainingCallback>`)


@@ -313,7 +314,7 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t

.. note::

-   Because of the way ``PPO1`` and ``TRPO`` work (they rely on MPI), ``n_steps`` is a lower bound between two events.
+   Because of the way ``VecEnv`` works, ``n_steps`` is a lower bound between two events when using multiple environments.


.. code-block:: python
@@ -330,7 +331,30 @@ An :ref:`EventCallback` that will trigger its child callback every ``n_steps`` t

    model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)

-    model.learn(int(2e4), callback=event_callback)
+    model.learn(20_000, callback=event_callback)
+
+.. _LogEveryNTimesteps:
+
+LogEveryNTimesteps
+^^^^^^^^^^^^^^^^^^
+
+A callback derived from :ref:`EveryNTimesteps` that will dump the logged data every ``n_steps`` timesteps.
+
+
+.. code-block:: python
+
+    import gymnasium as gym
+
+    from stable_baselines3 import PPO
+    from stable_baselines3.common.callbacks import LogEveryNTimesteps
+
+    event_callback = LogEveryNTimesteps(n_steps=1_000)
+
+    model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
+
+    # Disable auto-logging by passing `log_interval=None`
+    model.learn(10_000, callback=event_callback, log_interval=None)



.. _StopTrainingOnMaxEpisodes:
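
Since ``LogEveryNTimesteps`` is a regular callback, it can be combined with the other callbacks listed above. A minimal sketch (hypothetical usage, not part of this diff) chaining it with an ``EvalCallback`` through a ``CallbackList``:

.. code-block:: python

    import gymnasium as gym

    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import CallbackList, EvalCallback, LogEveryNTimesteps

    eval_env = gym.make("Pendulum-v1")
    callback = CallbackList(
        [
            LogEveryNTimesteps(n_steps=1_000),  # dump logs every 1000 steps
            EvalCallback(eval_env, eval_freq=5_000),  # evaluate every 5000 steps
        ]
    )

    model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)
    # log_interval=None disables the built-in logging, so only the callback logs
    model.learn(20_000, callback=callback, log_interval=None)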
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

-Release 2.6.0a0 (WIP)
+Release 2.6.0a1 (WIP)
--------------------------


@@ -13,6 +13,7 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists
+- Added ``LogEveryNTimesteps`` callback to dump logs every N timesteps (note: you need to pass ``log_interval=None`` to avoid any interference)

Bug Fixes:
^^^^^^^^^^
@@ -29,6 +30,7 @@ Bug Fixes:

Deprecations:
^^^^^^^^^^^^^
+- ``algo._dump_logs()`` is deprecated in favor of ``algo.dump_logs()`` and will be removed in SB3 v2.7.0

Others:
^^^^^^^
8 changes: 6 additions & 2 deletions stable_baselines3/common/base_class.py
@@ -866,8 +866,12 @@ def save(

        save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)

-    @abstractmethod
-    def _dump_logs(self) -> None:
+    def dump_logs(self) -> None:
        """
        Write log data. (Implemented by OffPolicyAlgorithm and OnPolicyAlgorithm)
        """
+        raise NotImplementedError()
+
+    def _dump_logs(self, *args) -> None:
+        warnings.warn("algo._dump_logs() is deprecated in favor of algo.dump_logs(). It will be removed in SB3 v2.7.0")
+        self.dump_logs(*args)
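
For downstream code, the migration is a plain rename; a minimal sketch (hypothetical usage, assuming SB3 >= 2.6):

.. code-block:: python

    from stable_baselines3 import PPO

    model = PPO("MlpPolicy", "Pendulum-v1")
    model.learn(2_000, log_interval=None)

    model.dump_logs()   # new public method
    model._dump_logs()  # still works, but emits a warning; removed in SB3 v2.7.0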
2 changes: 1 addition & 1 deletion stable_baselines3/common/callbacks.py
@@ -602,7 +602,7 @@ def __init__(self, n_steps: int):
        super().__init__(n_steps, callback=ConvertCallback(self._log_data))

    def _log_data(self, _locals: dict[str, Any], _globals: dict[str, Any]) -> bool:
-        self.model._dump_logs()
+        self.model.dump_logs()
        return True
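
The snippet above shows the implementation pattern: ``LogEveryNTimesteps`` is an ``EveryNTimesteps`` whose child callback is a ``ConvertCallback`` wrapping a plain function. The same pattern can drive any custom periodic action; a minimal sketch with a hypothetical checkpoint function:

.. code-block:: python

    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import ConvertCallback, EveryNTimesteps

    model = PPO("MlpPolicy", "Pendulum-v1", verbose=1)

    def save_checkpoint(_locals, _globals) -> bool:
        # Hypothetical periodic action, called every n_steps timesteps
        model.save("ppo_pendulum_checkpoint")
        return True  # returning False would stop training

    event_callback = EveryNTimesteps(n_steps=5_000, callback=ConvertCallback(save_checkpoint))
    model.learn(20_000, callback=event_callback)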


4 changes: 2 additions & 2 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -406,7 +406,7 @@ def _sample_action(
            action = buffer_action
        return action, buffer_action

-    def _dump_logs(self) -> None:
+    def dump_logs(self) -> None:
        """
        Write log data.
        """
@@ -594,7 +594,7 @@ def collect_rollouts(

                    # Log training infos
                    if log_interval is not None and self._episode_num % log_interval == 0:
-                        self._dump_logs()
+                        self.dump_logs()
        callback.on_rollout_end()

        return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
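
Note the gating above: for off-policy algorithms, ``log_interval`` counts completed episodes, not timesteps. A minimal sketch (hypothetical values):

.. code-block:: python

    from stable_baselines3 import SAC

    model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
    # log_interval counts episodes here: logs are dumped every 4 completed episodes
    model.learn(10_000, log_interval=4)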
4 changes: 2 additions & 2 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -274,7 +274,7 @@ def train(self) -> None:
"""
raise NotImplementedError

def _dump_logs(self, iteration: int = 0) -> None:
def dump_logs(self, iteration: int = 0) -> None:
"""
Write log.

@@ -332,7 +332,7 @@ def learn(
            # Display training infos
            if log_interval is not None and iteration % log_interval == 0:
                assert self.ep_info_buffer is not None
-                self._dump_logs(iteration)
+                self.dump_logs(iteration)

            self.train()
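
For on-policy algorithms, by contrast, ``log_interval`` counts rollout iterations of ``n_steps * n_envs`` timesteps each. A minimal sketch (hypothetical values):

.. code-block:: python

    from stable_baselines3 import PPO

    model = PPO("MlpPolicy", "Pendulum-v1", n_steps=2048, verbose=1)
    # log_interval counts iterations here: with a single env,
    # logs are dumped every 5 * 2048 = 10240 timesteps
    model.learn(50_000, log_interval=5)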

2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.6.0a0
+2.6.0a1