
Commit f71d490

Upgrade Stable-Baselines3 (#68)
* Update train freq
* Upgrade Stable-Baselines3
* Additional fixes
* Bump version
* Upgrade docker
* Pre-process train freq
1 parent 0923702 commit f71d490

11 files changed (+42, -51 lines)

CHANGELOG.md (+2, -1)

```diff
@@ -1,11 +1,12 @@
-## Pre-Release 0.11.0a7 (WIP)
+## Pre-Release 0.11.1 (2021-02-27)
 
 ### Breaking Changes
 - Removed `LinearNormalActionNoise`
 - Evaluation is now deterministic by default, except for Atari games
 - `sb3_contrib` is now required
 - `TimeFeatureWrapper` was moved to the contrib repo
 - Replaced old `plot_train.py` script with updated `plot_training_success.py`
+- Renamed ``n_episodes_rollout`` to ``train_freq`` tuple to match latest version of SB3
 
 ### New Features
 - Added option to choose which `VecEnv` class to use for multiprocessing
```
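The renamed setting maps directly onto the new Stable-Baselines3 API. As a minimal sketch (assuming SB3 >= 0.11, where `train_freq` accepts either an int, in environment steps, or a `(frequency, unit)` tuple), the old `n_episodes_rollout=1` becomes:

```python
from stable_baselines3 import SAC

# Minimal sketch of the renamed setting, assuming the SB3 >= 0.11 API:
# train_freq=(1, "episode") collects one full episode per rollout, replacing
# the old n_episodes_rollout=1; gradient_steps=-1 then performs as many
# gradient steps as transitions were collected.
model = SAC("MlpPolicy", "Pendulum-v0", train_freq=(1, "episode"), gradient_steps=-1)
model.learn(total_timesteps=1_000)
```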

hyperparams/ddpg.yml (+11, -12)

```diff
@@ -14,7 +14,7 @@ Pendulum-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -27,7 +27,7 @@ LunarLanderContinuous-v2:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -40,7 +40,7 @@ BipedalWalker-v3:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -54,7 +54,7 @@ BipedalWalkerHardcore-v3:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -69,7 +69,7 @@ HalfCheetahBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -84,7 +84,7 @@ AntBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 7e-4
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -100,7 +100,6 @@ HopperBulletEnv-v0:
   noise_std: 0.1
   train_freq: 64
   gradient_steps: 64
-  n_episodes_rollout: -1
   batch_size: 256
   learning_rate: !!float 7e-4
   policy_kwargs: "dict(net_arch=[400, 300])"
@@ -116,7 +115,7 @@ Walker2DBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   batch_size: 256
   learning_rate: !!float 7e-4
   policy_kwargs: "dict(net_arch=[400, 300])"
@@ -133,7 +132,7 @@ HumanoidBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -148,7 +147,7 @@ ReacherBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -163,7 +162,7 @@ InvertedDoublePendulumBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -178,6 +177,6 @@ InvertedPendulumSwingupBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
```
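For reference, the updated `Pendulum-v0` entry above corresponds roughly to the following constructor call (a sketch, not the zoo's own loading code; the zoo builds the model from the YAML through its experiment manager, and the `noise_type`/`noise_std` keys are zoo-specific settings that are turned into an `action_noise` object separately, so they are omitted here):

```python
from stable_baselines3 import DDPG

# Rough Python equivalent of the updated Pendulum-v0 DDPG entry above (a sketch):
# the YAML list [1, "episode"] reaches SB3 as the tuple (1, "episode").
model = DDPG(
    "MlpPolicy",
    "Pendulum-v0",
    train_freq=(1, "episode"),
    gradient_steps=-1,
    learning_rate=1e-3,
    policy_kwargs=dict(net_arch=[400, 300]),
)
```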

hyperparams/her.yml (+3, -4)

```diff
@@ -13,9 +13,9 @@ NeckGoalEnvRelativeSparse-v2:
   ent_coef: 'auto'
   gamma: 0.99
   tau: 0.02
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   gradient_steps: -1
-  train_freq: -1
+
   # 10 episodes of warm-up
   learning_starts: 1500
   use_sde_at_warmup: True
@@ -40,9 +40,8 @@ NeckGoalEnvRelativeDense-v2:
   ent_coef: 'auto'
   gamma: 0.99
   tau: 0.02
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   gradient_steps: -1
-  train_freq: -1
   # 10 episodes of warm-up
   learning_starts: 1500
   use_sde_at_warmup: True
```

hyperparams/sac.yml (+4, -7)

```diff
@@ -16,9 +16,8 @@ NeckEnvRelative-v2:
   ent_coef: 'auto'
   gamma: 0.99
   tau: 0.02
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   gradient_steps: -1
-  train_freq: -1
   # 10 episodes of warm-up
   learning_starts: 3000
   use_sde_at_warmup: True
@@ -69,9 +68,8 @@ Pendulum-v0:
   policy: 'MlpPolicy'
   learning_rate: !!float 1e-3
   use_sde: True
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   gradient_steps: -1
-  train_freq: -1
   policy_kwargs: "dict(log_std_init=-2, net_arch=[64, 64])"
 
 LunarLanderContinuous-v2:
@@ -297,7 +295,7 @@ CarRacing-v0:
   gamma: 0.98
   tau: 0.02
   train_freq: 64
-  # n_episodes_rollout: 1
+  # train_freq: [1, "episode"]
   gradient_steps: 64
   # sde_sample_freq: 64
   learning_starts: 1000
@@ -323,8 +321,7 @@ donkey-generated-track-v0:
   gamma: 0.99
   tau: 0.02
   # train_freq: 64
-  train_freq: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   # gradient_steps: -1
   gradient_steps: 64
   learning_starts: 500
```

hyperparams/td3.yml (+12, -12)

```diff
@@ -14,7 +14,7 @@ Pendulum-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -27,7 +27,7 @@ LunarLanderContinuous-v2:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -40,7 +40,7 @@ BipedalWalker-v3:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -54,7 +54,7 @@ BipedalWalkerHardcore-v3:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -69,7 +69,7 @@ HalfCheetahBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -83,7 +83,7 @@ AntBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -97,7 +97,7 @@ HopperBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -111,7 +111,7 @@ Walker2DBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -127,7 +127,7 @@ HumanoidBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -142,7 +142,7 @@ ReacherBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -157,7 +157,7 @@ InvertedDoublePendulumBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
 
@@ -172,7 +172,7 @@ InvertedPendulumSwingupBulletEnv-v0:
   noise_type: 'normal'
   noise_std: 0.1
   gradient_steps: -1
-  n_episodes_rollout: 1
+  train_freq: [1, "episode"]
   learning_rate: !!float 1e-3
   policy_kwargs: "dict(net_arch=[400, 300])"
```

hyperparams/tqc.yml (-1)

```diff
@@ -19,7 +19,6 @@ Pendulum-v0:
   policy: 'MlpPolicy'
   learning_rate: !!float 1e-3
   use_sde: True
-  n_episodes_rollout: -1
   gradient_steps: 64
   train_freq: 64
   policy_kwargs: "dict(log_std_init=-2, net_arch=[64, 64])"
```

requirements.txt (+2, -2)

```diff
@@ -1,4 +1,4 @@
-stable-baselines3[extra,tests,docs]>=0.11.0a4
+stable-baselines3[extra,tests,docs]>=0.11.1
 box2d-py==2.3.8
 pybullet
 gym-minigrid
@@ -7,4 +7,4 @@ optuna
 pytablewriter
 seaborn
 pyyaml>=5.1
-sb3-contrib>=0.11.0a4
+sb3-contrib>=0.11.1
```

scripts/build_docker.sh (+1, -1)

```diff
@@ -3,7 +3,7 @@
 PARENT=stablebaselines/stable-baselines3
 
 TAG=stablebaselines/rl-baselines3-zoo
-VERSION=0.11.0a4
+VERSION=0.11.1
 
 if [[ ${USE_GPU} == "True" ]]; then
   PARENT="${PARENT}:${VERSION}"
```

utils/exp_manager.py (+4)

```diff
@@ -309,6 +309,10 @@ def _preprocess_hyperparams(
         hyperparams = self._preprocess_her_model_class(hyperparams)
         hyperparams = self._preprocess_schedules(hyperparams)
 
+        # Pre-process train_freq
+        if "train_freq" in hyperparams and isinstance(hyperparams["train_freq"], list):
+            hyperparams["train_freq"] = tuple(hyperparams["train_freq"])
+
         # Should we overwrite the number of timesteps?
         if self.n_timesteps > 0:
             if self.verbose:
```
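This pre-processing step is needed because YAML has no tuple literal: `train_freq: [1, "episode"]` loads as a Python list, while SB3 expects an int or a `(frequency, unit)` tuple. A small standalone illustration (a sketch using PyYAML directly, outside the experiment manager):

```python
import yaml

# YAML has no tuple type, so the value loads as a list...
hyperparams = yaml.safe_load('train_freq: [1, "episode"]')
assert hyperparams["train_freq"] == [1, "episode"]

# ...and the same conversion as in _preprocess_hyperparams turns it into the
# tuple form that SB3's off-policy algorithms expect.
if isinstance(hyperparams.get("train_freq"), list):
    hyperparams["train_freq"] = tuple(hyperparams["train_freq"])

assert hyperparams["train_freq"] == (1, "episode")
```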

utils/hyperparams_opt.py (+2, -10)

```diff
@@ -212,12 +212,10 @@ def sample_td3_params(trial: optuna.Trial) -> Dict[str, Any]:
     episodic = trial.suggest_categorical("episodic", [True, False])
 
     if episodic:
-        n_episodes_rollout = 1
-        train_freq, gradient_steps = -1, -1
+        train_freq, gradient_steps = (1, "episode"), -1
     else:
         train_freq = trial.suggest_categorical("train_freq", [1, 16, 128, 256, 1000, 2000])
         gradient_steps = train_freq
-        n_episodes_rollout = -1
 
     noise_type = trial.suggest_categorical("noise_type", ["ornstein-uhlenbeck", "normal", None])
     noise_std = trial.suggest_uniform("noise_std", 0, 1)
@@ -241,7 +239,6 @@ def sample_td3_params(trial: optuna.Trial) -> Dict[str, Any]:
         "buffer_size": buffer_size,
         "train_freq": train_freq,
         "gradient_steps": gradient_steps,
-        "n_episodes_rollout": n_episodes_rollout,
         "policy_kwargs": dict(net_arch=net_arch),
     }
 
@@ -274,12 +271,10 @@ def sample_ddpg_params(trial: optuna.Trial) -> Dict[str, Any]:
     episodic = trial.suggest_categorical("episodic", [True, False])
 
     if episodic:
-        n_episodes_rollout = 1
-        train_freq, gradient_steps = -1, -1
+        train_freq, gradient_steps = (1, "episode"), -1
     else:
         train_freq = trial.suggest_categorical("train_freq", [1, 16, 128, 256, 1000, 2000])
         gradient_steps = train_freq
-        n_episodes_rollout = -1
 
     noise_type = trial.suggest_categorical("noise_type", ["ornstein-uhlenbeck", "normal", None])
     noise_std = trial.suggest_uniform("noise_std", 0, 1)
@@ -302,7 +297,6 @@ def sample_ddpg_params(trial: optuna.Trial) -> Dict[str, Any]:
         "buffer_size": buffer_size,
         "train_freq": train_freq,
         "gradient_steps": gradient_steps,
-        "n_episodes_rollout": n_episodes_rollout,
         "policy_kwargs": dict(net_arch=net_arch),
     }
 
@@ -337,7 +331,6 @@ def sample_dqn_params(trial: optuna.Trial) -> Dict[str, Any]:
     train_freq = trial.suggest_categorical("train_freq", [1, 4, 8, 16, 128, 256, 1000])
     subsample_steps = trial.suggest_categorical("subsample_steps", [1, 2, 4, 8])
     gradient_steps = max(train_freq // subsample_steps, 1)
-    n_episodes_rollout = -1
 
     net_arch = trial.suggest_categorical("net_arch", ["tiny", "small", "medium"])
 
@@ -350,7 +343,6 @@ def sample_dqn_params(trial: optuna.Trial) -> Dict[str, Any]:
         "buffer_size": buffer_size,
         "train_freq": train_freq,
         "gradient_steps": gradient_steps,
-        "n_episodes_rollout": n_episodes_rollout,
         "exploration_fraction": exploration_fraction,
         "exploration_final_eps": exploration_final_eps,
         "target_update_interval": target_update_interval,
```

version.txt (+1, -1)

```diff
@@ -1 +1 @@
-0.11.0a7
+0.11.1
```
