Skip to content

Commit 8e3744d

Browse files
vwxyzjn and araffin authored
Support experiment tracking with W&B (#213)
* Support experiment tracking with W&B
* Quick fix
* Fix CI
* Update train.py (Co-authored-by: Antonin RAFFIN <[email protected]>)
* fix CI
* Add documentation
* Update CHANGELOG.md (Co-authored-by: Antonin RAFFIN <[email protected]>)
* Address comments
* Update (Co-authored-by: Antonin RAFFIN <[email protected]>)
1 parent 0e00446 commit 8e3744d

File tree

6 files changed

+60
-7
lines changed

6 files changed

+60
-7
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@ rl-trained_agents/
1313
htmlcov/
1414
git_rewrite_commit_history.sh
1515
.vscode/
16+
wandb
17+
runs

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
- Upgrade to Stable-Baselines3 (SB3) >= 1.4.1a1
55
- Upgrade to sb3-contrib >= 1.4.1a1
66
- Upgraded to gym 0.21
7+
- Support experiment tracking via Weights and Biases (@vwxyzjn)
78

89
### New Features
910

README.md

+14
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,20 @@ Note that the default hyperparameters used in the zoo when tuning are not always
196196

197197
When working with continuous actions, we recommend to enable [gSDE](https://arxiv.org/abs/2005.05719) by uncommenting lines in [utils/hyperparams_opt.py](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/utils/hyperparams_opt.py).
198198

199+
200+
## Experiment tracking
201+
202+
We support tracking experiment data such as learning curves and hyperparameters via [Weights and Biases](https://wandb.ai).
203+
204+
The following command
205+
```
206+
python train.py --algo ppo --env CartPole-v1 --track --wandb-project-name sb3
207+
```
208+
209+
yields a tracked experiment at this [URL](https://wandb.ai/openrlbenchmark/sb3/runs/1b65ldmh).
210+
211+
212+
199213
## Env normalization
200214
201215
In the hyperparameter file, `normalize: True` means that the training environment will be wrapped in a [VecNormalize](https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_normalize.py#L13) wrapper.

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ cloudpickle>=1.5.0
1313
plotly
1414
panda-gym==1.1.1 # tmp fix: until compatibility with panda-gym v2
1515
rliable>=1.0.5
16+
wandb

train.py

+40-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import difflib
33
import importlib
44
import os
5+
import time
56
import uuid
67

78
import gym
@@ -114,6 +115,14 @@
114115
help="Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)",
115116
)
116117
parser.add_argument("-uuid", "--uuid", action="store_true", default=False, help="Ensure that the run has a unique ID")
118+
parser.add_argument(
119+
"--track",
120+
action="store_true",
121+
default=False,
122+
help="if toggled, this experiment will be tracked with Weights and Biases",
123+
)
124+
parser.add_argument("--wandb-project-name", type=str, default="sb3", help="the wandb's project name")
125+
parser.add_argument("--wandb-entity", type=str, default=None, help="the entity (team) of wandb's project")
117126
args = parser.parse_args()
118127

119128
# Going through custom gym packages to let them register in the global registory
@@ -153,6 +162,26 @@
153162
print("=" * 10, env_id, "=" * 10)
154163
print(f"Seed: {args.seed}")
155164

165+
if args.track:
166+
try:
167+
import wandb
168+
except ImportError:
169+
raise ImportError(
170+
"if you want to use Weights & Biases to track experiment, please install W&B via `pip install wandb`"
171+
)
172+
173+
run_name = f"{args.env}__{args.algo}__{args.seed}__{int(time.time())}"
174+
run = wandb.init(
175+
name=run_name,
176+
project=args.wandb_project_name,
177+
entity=args.wandb_entity,
178+
config=vars(args),
179+
sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
180+
monitor_gym=True, # auto-upload the videos of agents playing the game
181+
save_code=True, # optional
182+
)
183+
args.tensorboard_log = f"runs/{run_name}"
184+
156185
exp_manager = ExperimentManager(
157186
args,
158187
args.algo,
@@ -188,11 +217,17 @@
188217
)
189218

190219
# Prepare experiment and launch hyperparameter optimization if needed
191-
model = exp_manager.setup_experiment()
220+
results = exp_manager.setup_experiment()
221+
if results is not None:
222+
model, saved_hyperparams = results
223+
if args.track:
224+
# we need to save the loaded hyperparameters
225+
args.saved_hyperparams = saved_hyperparams
226+
run.config.setdefaults(vars(args))
192227

193-
# Normal training
194-
if model is not None:
195-
exp_manager.learn(model)
196-
exp_manager.save_trained_model(model)
228+
# Normal training
229+
if model is not None:
230+
exp_manager.learn(model)
231+
exp_manager.save_trained_model(model)
197232
else:
198233
exp_manager.hyperparameters_optimization()

utils/exp_manager.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def __init__(
153153
)
154154
self.params_path = f"{self.save_path}/{self.env_id}"
155155

156-
def setup_experiment(self) -> Optional[BaseAlgorithm]:
156+
def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:
157157
"""
158158
Read hyperparameters, pre-process them (create schedules, wrappers, callbacks, action noise objects)
159159
create the environment and possibly the model.
@@ -187,7 +187,7 @@ def setup_experiment(self) -> Optional[BaseAlgorithm]:
187187
)
188188

189189
self._save_config(saved_hyperparams)
190-
return model
190+
return model, saved_hyperparams
191191

192192
def learn(self, model: BaseAlgorithm) -> None:
193193
"""

0 commit comments

Comments (0)