Skip to content

Commit f421dad

Browse files
vwxyzjn and araffin authored
Add a callback to log raw stats (#216)
* Add a callback to log raw stats
* Fixes and use tensorboard output directly
* Add test case and changelog
* fix CI
* Update test

Co-authored-by: Antonin Raffin <[email protected]>
1 parent b7c948f commit f421dad

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
- Upgrade to Stable-Baselines3 (SB3) >= 1.4.1a1
55
- Upgrade to sb3-contrib >= 1.4.1a1
66
- Upgraded to gym 0.21
7-
- Support experiment tracking via Weights and Biases (@vwxyzjn)
7+
- Support experiment tracking via Weights and Biases via the `--track` flag (@vwxyzjn)
8+
- Support tracking raw episodic stats via `RawStatisticsCallback` (@vwxyzjn, see https://github.com/DLR-RM/rl-baselines3-zoo/pull/216)
89

910
### New Features
1011
- Verbose mode for each trial (when doing hyperparam optimization) can now be activated using the debug mode (verbose == 2)

tests/test_callbacks.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import subprocess
2+
3+
4+
def _assert_eq(left, right):
5+
assert left == right, f"{left} != {right}"
6+
7+
8+
def test_raw_stat_callback(tmp_path):
    """
    Smoke test: run ``train.py`` with ``RawStatisticsCallback`` enabled.

    Launches a short PPO run on CartPole-v1 in a subprocess, with tensorboard
    logging directed to ``tmp_path``, and checks that the process exits with
    return code 0.

    :param tmp_path: directory passed to ``--tensorboard-log``
        (presumably pytest's ``tmp_path`` fixture)
    """
    # Function-scope import so this block does not rely on a file-level ``import sys``
    import sys

    # Use the interpreter that is running the tests rather than whatever
    # "python" resolves to on PATH (which may be another environment).
    # Fall back to "python" if the interpreter path is unavailable.
    python_executable = sys.executable or "python"

    args = [
        "-n",
        str(200),
        "--algo",
        "ppo",
        "--env",
        "CartPole-v1",
        "-params",
        "callback:'utils.callbacks.RawStatisticsCallback'",
        "--tensorboard-log",
        f"{tmp_path}",
    ]

    return_code = subprocess.call([python_executable, "train.py"] + args)
    _assert_eq(return_code, 0)

utils/callbacks.py

+34
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sb3_contrib import TQC
1111
from stable_baselines3 import SAC
1212
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
13+
from stable_baselines3.common.logger import TensorBoardOutputFormat
1314
from stable_baselines3.common.vec_env import VecEnv
1415

1516

@@ -193,3 +194,36 @@ def _on_training_end(self) -> None:
193194
if self.verbose > 0:
194195
print("Waiting for training thread to terminate")
195196
self.process.join()
197+
198+
199+
class RawStatisticsCallback(BaseCallback):
    """
    Callback that writes raw per-episode data (return and episode length)
    directly to the tensorboard output format.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        # Running sum of the lengths of the episodes logged so far.
        # It is used as the step value when writing, which avoids
        # reporting multiple values for the same step.
        self._timesteps_counter = 0
        # Assigned in ``_init_callback``
        self._tensorboard_writer = None

    def _init_callback(self) -> None:
        # Grab the tensorboard output format so that stats can be written
        # to it directly, without flooding the rest of the logger output.
        # There is no early exit: if several formats match, the last one is kept.
        for candidate in self.logger.output_formats:
            if isinstance(candidate, TensorBoardOutputFormat):
                self._tensorboard_writer = candidate
        assert self._tensorboard_writer is not None, "You must activate tensorboard logging when using RawStatisticsCallback"

    def _on_step(self) -> bool:
        """
        Write the return and length of every ``info`` entry that has an
        ``"episode"`` key to the tensorboard writer.

        :return: always ``True``
        """
        for info in self.locals["infos"]:
            if "episode" not in info:
                continue

            episode = info["episode"]
            stats = {
                "raw/rollouts/episodic_return": episode["r"],
                "raw/rollouts/episodic_length": episode["l"],
            }
            # Same keys as ``stats``, each mapped to None
            excluded = dict.fromkeys(stats)
            # Advance the custom step counter by the episode length before writing
            self._timesteps_counter += episode["l"]
            self._tensorboard_writer.write(stats, excluded, self._timesteps_counter)

        return True

0 commit comments

Comments
 (0)