|
10 | 10 | from sb3_contrib import TQC
|
11 | 11 | from stable_baselines3 import SAC
|
12 | 12 | from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
|
| 13 | +from stable_baselines3.common.logger import TensorBoardOutputFormat |
13 | 14 | from stable_baselines3.common.vec_env import VecEnv
|
14 | 15 |
|
15 | 16 |
|
@@ -193,3 +194,36 @@ def _on_training_end(self) -> None:
|
193 | 194 | if self.verbose > 0:
|
194 | 195 | print("Waiting for training thread to terminate")
|
195 | 196 | self.process.join()
|
| 197 | + |
| 198 | + |
| 199 | +class RawStatisticsCallback(BaseCallback): |
| 200 | + """ |
| 201 | + Callback used for logging raw episode data (return and episode length). |
| 202 | + """ |
| 203 | + |
| 204 | + def __init__(self, verbose=0): |
| 205 | + super(RawStatisticsCallback, self).__init__(verbose) |
| 206 | + # Custom counter to reports stats |
| 207 | + # (and avoid reporting multiple values for the same step) |
| 208 | + self._timesteps_counter = 0 |
| 209 | + self._tensorboard_writer = None |
| 210 | + |
| 211 | + def _init_callback(self) -> None: |
| 212 | + # Retrieve tensorboard writer to not flood the logger output |
| 213 | + for out_format in self.logger.output_formats: |
| 214 | + if isinstance(out_format, TensorBoardOutputFormat): |
| 215 | + self._tensorboard_writer = out_format |
| 216 | + assert self._tensorboard_writer is not None, "You must activate tensorboard logging when using RawStatisticsCallback" |
| 217 | + |
| 218 | + def _on_step(self) -> bool: |
| 219 | + for info in self.locals["infos"]: |
| 220 | + if "episode" in info: |
| 221 | + logger_dict = { |
| 222 | + "raw/rollouts/episodic_return": info["episode"]["r"], |
| 223 | + "raw/rollouts/episodic_length": info["episode"]["l"], |
| 224 | + } |
| 225 | + exclude_dict = {key: None for key in logger_dict.keys()} |
| 226 | + self._timesteps_counter += info["episode"]["l"] |
| 227 | + self._tensorboard_writer.write(logger_dict, exclude_dict, self._timesteps_counter) |
| 228 | + |
| 229 | + return True |
0 commit comments