Commit c7763b7

Add TQC support (#46)
* Add TQC from sb3-contrib
* Update plot script
* Tuned hyperparams for Hopper
* Update plot
* Update hyperparams
* Update plot script: allow to merge files
* Make pytype happy
* Update humanoid params
* Revert Humanoids params
* Fix deps
* Fixes
* Add support for HER + TQC
1 parent 26dfece commit c7763b7

9 files changed (+255 −18)

CHANGELOG.md (+1)

@@ -5,6 +5,7 @@
 ### New Features
 - Added support for `HER`
 - Added low-pass filter wrappers in `utils/wrappers.py`
+- Added `TQC` support, implementation from sb3-contrib

 ### Bug fixes
 - Fixed `TimeFeatureWrapper` inferring max timesteps

Makefile (+1 −1)

@@ -6,7 +6,7 @@ pytest:

 # Type check
 type:
-	pytype ${LINT_PATHS}
+	pytype -j auto ${LINT_PATHS}

 lint:
 	# stop the build if there are Python syntax errors or undefined names

enjoy.py (+5 −5)

@@ -65,11 +65,11 @@ def main():  # noqa: C901

     if args.exp_id == 0:
         args.exp_id = get_latest_run_id(os.path.join(folder, algo), env_id)
-        print("Loading latest experiment, id={}".format(args.exp_id))
+        print(f"Loading latest experiment, id={args.exp_id}")

     # Sanity checks
     if args.exp_id > 0:
-        log_path = os.path.join(folder, algo, "{}_{}".format(env_id, args.exp_id))
+        log_path = os.path.join(folder, algo, f"{env_id}_{args.exp_id}")
     else:
         log_path = os.path.join(folder, algo)

@@ -93,7 +93,7 @@ def main():  # noqa: C901
     if not found:
         raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}")

-    if algo in ["dqn", "ddpg", "sac", "td3"]:
+    if algo in ["dqn", "ddpg", "sac", "td3", "tqc"]:
         args.n_envs = 1

     set_random_seed(args.seed)

@@ -134,7 +134,7 @@ def main():  # noqa: C901
     )

     kwargs = dict(seed=args.seed)
-    if algo in ["dqn", "ddpg", "sac", "her", "td3"]:
+    if algo in ["dqn", "ddpg", "sac", "her", "td3", "tqc"]:
         # Dummy buffer size as we don't need memory to enjoy the trained agent
         kwargs.update(dict(buffer_size=1))

@@ -143,7 +143,7 @@ def main():  # noqa: C901
     obs = env.reset()

     # Force deterministic for DQN, DDPG, SAC and HER (that is a wrapper around)
-    deterministic = args.deterministic or algo in ["dqn", "ddpg", "sac", "her", "td3"] and not args.stochastic
+    deterministic = args.deterministic or algo in ["dqn", "ddpg", "sac", "her", "td3", "tqc"] and not args.stochastic

     state = None
     episode_reward = 0.0
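One subtlety in the deterministic line: in Python, `and` binds tighter than `or`, so `--deterministic` always forces deterministic actions, while `--stochastic` only disables the default for the listed algorithms. The equivalent explicitly parenthesized form, as a clarifying sketch (not a change to the code):

# Equivalent, explicitly parenthesized form of the expression above:
deterministic = args.deterministic or (
    algo in ["dqn", "ddpg", "sac", "her", "td3", "tqc"] and not args.stochastic
)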

hyperparams/tqc.yml (new file, +202)

@@ -0,0 +1,202 @@
# Tuned
MountainCarContinuous-v0:
  n_timesteps: !!float 50000
  policy: 'MlpPolicy'
  learning_rate: !!float 3e-4
  buffer_size: 50000
  batch_size: 512
  ent_coef: 0.1
  train_freq: 32
  gradient_steps: 32
  gamma: 0.9999
  tau: 0.01
  learning_starts: 0
  use_sde: True
  policy_kwargs: "dict(log_std_init=-3.67, net_arch=[64, 64])"

Pendulum-v0:
  n_timesteps: 20000
  policy: 'MlpPolicy'
  learning_rate: !!float 1e-3
  use_sde: True
  n_episodes_rollout: 1
  gradient_steps: -1
  train_freq: -1
  policy_kwargs: "dict(log_std_init=-2, net_arch=[64, 64])"

LunarLanderContinuous-v2:
  n_timesteps: !!float 5e5
  policy: 'MlpPolicy'
  batch_size: 256
  learning_starts: 1000

BipedalWalker-v3:
  n_timesteps: !!float 5e5
  policy: 'MlpPolicy'
  learning_rate: !!float 7.3e-4
  buffer_size: 300000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  tau: 0.02
  train_freq: 64
  gradient_steps: 64
  learning_starts: 10000
  use_sde: True
  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

# Almost tuned
# History wrapper of size 2 for better performance
BipedalWalkerHardcore-v3:
  n_timesteps: !!float 2e6
  policy: 'MlpPolicy'
  learning_rate: lin_7.3e-4
  buffer_size: 1000000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.99
  tau: 0.01
  train_freq: 64
  gradient_steps: 64
  learning_starts: 10000
  use_sde: True
  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300], use_expln=True)"

# === Bullet envs ===

# Tuned
HalfCheetahBulletEnv-v0:
  env_wrapper: utils.wrappers.TimeFeatureWrapper
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  learning_rate: !!float 7.3e-4
  buffer_size: 300000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  tau: 0.02
  train_freq: 64
  gradient_steps: 64
  learning_starts: 10000
  use_sde: True
  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

# Tuned
AntBulletEnv-v0:
  env_wrapper: utils.wrappers.TimeFeatureWrapper
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  learning_rate: !!float 7.3e-4
  buffer_size: 300000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  tau: 0.02
  train_freq: 64
  gradient_steps: 64
  learning_starts: 10000
  use_sde: True
  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

# Tuned
HopperBulletEnv-v0:
  env_wrapper: utils.wrappers.TimeFeatureWrapper
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  learning_rate: lin_7.3e-4
  buffer_size: 300000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  tau: 0.02
  train_freq: 64
  gradient_steps: 64
  learning_starts: 10000
  use_sde: True
  top_quantiles_to_drop_per_net: 5
  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

# Tuned
Walker2DBulletEnv-v0:
  env_wrapper: utils.wrappers.TimeFeatureWrapper
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  learning_rate: lin_7.3e-4
  buffer_size: 300000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  tau: 0.02
  train_freq: 64
  gradient_steps: 64
  learning_starts: 10000
  use_sde: True
  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"


ReacherBulletEnv-v0:
  env_wrapper: utils.wrappers.TimeFeatureWrapper
  n_timesteps: !!float 3e5
  policy: 'MlpPolicy'
  learning_rate: !!float 7.3e-4
  buffer_size: 300000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  tau: 0.02
  train_freq: 64
  gradient_steps: 64
  learning_starts: 10000
  use_sde: True
  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"


# Almost tuned
HumanoidBulletEnv-v0:
  env_wrapper: utils.wrappers.TimeFeatureWrapper
  n_timesteps: !!float 1e7
  policy: 'MlpPolicy'
  learning_rate: lin_7.3e-4
  buffer_size: 300000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  tau: 0.02
  train_freq: 64
  gradient_steps: 64
  learning_starts: 10000
  top_quantiles_to_drop_per_net: 5
  use_sde: True
  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

InvertedDoublePendulumBulletEnv-v0:
  env_wrapper: utils.wrappers.TimeFeatureWrapper
  n_timesteps: !!float 5e5
  policy: 'MlpPolicy'
  learning_rate: !!float 7.3e-4
  buffer_size: 300000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  tau: 0.02
  train_freq: 64
  gradient_steps: 64
  learning_starts: 10000
  use_sde: True
  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"

InvertedPendulumSwingupBulletEnv-v0:
  env_wrapper: utils.wrappers.TimeFeatureWrapper
  n_timesteps: !!float 3e5
  policy: 'MlpPolicy'
  learning_rate: !!float 7.3e-4
  buffer_size: 300000
  batch_size: 256
  ent_coef: 'auto'
  gamma: 0.98
  tau: 0.02
  train_freq: 64
  gradient_steps: 64
  learning_starts: 10000
  use_sde: True
  policy_kwargs: "dict(log_std_init=-3, net_arch=[400, 300])"
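As a quick sanity check of how such an entry is consumed, here is a minimal sketch assuming sb3-contrib is installed and the SB3 version pinned below (whose off-policy algorithms accept keys like n_episodes_rollout). Pendulum-v0 is used because it avoids the zoo-specific `lin_...` schedule strings; zoo features such as env_wrapper and normalization are not handled here:

# Minimal sketch: build a TQC model from one hyperparams/tqc.yml entry.
import gym
import yaml
from sb3_contrib import TQC

with open("hyperparams/tqc.yml") as f:
    hyperparams = yaml.safe_load(f)["Pendulum-v0"]

n_timesteps = int(hyperparams.pop("n_timesteps"))
policy = hyperparams.pop("policy")
# policy_kwargs is stored as a string in the YAML and evaluated at load time
hyperparams["policy_kwargs"] = eval(hyperparams.pop("policy_kwargs", "dict()"))

model = TQC(policy, gym.make("Pendulum-v0"), **hyperparams)
model.learn(n_timesteps)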

requirements.txt (+2 −1)

@@ -1,4 +1,4 @@
-stable-baselines3[extra,tests,docs]>=0.10.0a0
+stable-baselines3[extra,tests,docs]>=0.10.0a1
 box2d-py==2.3.5
 pybullet
 gym-minigrid
@@ -7,3 +7,4 @@ optuna
 pytablewriter
 seaborn
 pyyaml>=5.1
+sb3-contrib>=0.10.0a1

scripts/plot_from_file.py (+25 −2)

@@ -30,6 +30,9 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
 parser = argparse.ArgumentParser("Gather results, plot them and create table")
 parser.add_argument("-i", "--input", help="Input filename (numpy archive)", type=str)
 parser.add_argument("-skip", "--skip-envs", help="Environments to skip", nargs="+", default=[], type=str)
+parser.add_argument("--keep-envs", help="Envs to keep", nargs="+", default=[], type=str)
+parser.add_argument("--skip-keys", help="Keys to skip", nargs="+", default=[], type=str)
+parser.add_argument("--keep-keys", help="Keys to keep", nargs="+", default=[], type=str)
 parser.add_argument("--no-million", action="store_true", default=False, help="Do not convert x-axis to million")
 parser.add_argument("--skip-timesteps", action="store_true", default=False, help="Do not display learning curves")
 parser.add_argument("-o", "--output", help="Output filename (image)", type=str)

@@ -40,6 +43,7 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
 parser.add_argument("-l", "--labels", help="Custom labels", type=str, nargs="+")
 parser.add_argument("-b", "--boxplot", help="Enable boxplot", action="store_true", default=False)
 parser.add_argument("-latex", "--latex", help="Enable latex support", action="store_true", default=False)
+parser.add_argument("--merge", help="Merge with other results files", nargs="+", default=[], type=str)

 args = parser.parse_args()

@@ -69,8 +73,26 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5

 del results["results_table"]

-keys = [key for key in results[list(results.keys())[0]].keys()]
+for filename in args.merge:
+    # Merge other files
+    with open(filename, "rb") as file_handler:
+        results_2 = pickle.load(file_handler)
+        del results_2["results_table"]
+        for key in results.keys():
+            if key in results_2:
+                for new_key in results_2[key].keys():
+                    results[key][new_key] = results_2[key][new_key]
+
+
+keys = [key for key in results[list(results.keys())[0]].keys() if key not in args.skip_keys]
+print(f"keys: {keys}")
+if len(args.keep_keys) > 0:
+    keys = [key for key in keys if key in args.keep_keys]
 envs = [env for env in results.keys() if env not in args.skip_envs]
+
+if len(args.keep_envs) > 0:
+    envs = [env for env in envs if env in args.keep_envs]
+
 labels = {key: key for key in keys}
 if args.labels is not None:
     for key, label in zip(keys, args.labels):

@@ -129,9 +151,10 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
 # plt.title('Influence of the time feature', fontsize=args.fontsize)
 # plt.title('Influence of the network architecture', fontsize=args.fontsize)
 # plt.title('Influence of the exploration variance $log \sigma$', fontsize=args.fontsize)
-plt.title("Influence of the sampling frequency", fontsize=args.fontsize)
+# plt.title("Influence of the sampling frequency", fontsize=args.fontsize)
 # plt.title('Parallel vs No Parallel Sampling', fontsize=args.fontsize)
 # plt.title('Influence of the exploration function input', fontsize=args.fontsize)
+plt.title("PyBullet envs", fontsize=args.fontsize)
 plt.xticks(fontsize=13)
 plt.xlabel("Environment", fontsize=args.fontsize)
 plt.ylabel("Score", fontsize=args.fontsize)
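Note that the merge loop only copies entries for environments already present in the base results file, and same-named keys from the merged file overwrite existing ones. A tiny illustration with made-up data (hypothetical, not from the repo):

# Hypothetical illustration of the --merge behavior above: envs missing from
# the base results are ignored, and same-named algorithm keys are overwritten.
results = {"HopperBulletEnv-v0": {"sac": [1500], "td3": [1400]}}
results_2 = {
    "HopperBulletEnv-v0": {"tqc": [1800], "sac": [1550]},
    "AntBulletEnv-v0": {"tqc": [2000]},  # dropped: not in the base results
}
for key in results.keys():
    if key in results_2:
        for new_key in results_2[key].keys():
            results[key][new_key] = results_2[key][new_key]
# results["HopperBulletEnv-v0"] now holds sac (overwritten), td3, and tqc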

train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@
170170
# HER is only a wrapper around an algo
171171
if args.algo == "her":
172172
algo_ = saved_hyperparams["model_class"]
173-
assert algo_ in {"sac", "ddpg", "dqn", "td3"}, f"{algo_} is not compatible with HER"
173+
assert algo_ in {"sac", "ddpg", "dqn", "td3", "tqc"}, f"{algo_} is not compatible with HER"
174174
# Retrieve the model class
175175
hyperparams["model_class"] = ALGOS[saved_hyperparams["model_class"]]
176176

utils/utils.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,31 @@
66

77
import gym
88
import yaml
9-
10-
# from stable_baselines3.common import logger
119
from stable_baselines3 import A2C, DDPG, DQN, HER, PPO, SAC, TD3
1210
from stable_baselines3.common.monitor import Monitor
13-
14-
# from stable_baselines3.common.cmd_util import make_atari_env
1511
from stable_baselines3.common.utils import set_random_seed
1612
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
1713

14+
try:
15+
from sb3_contrib import TQC # pytype: disable=import-error
16+
except ImportError:
17+
TQC = None
18+
1819
# For custom activation fn
1920
from torch import nn as nn # noqa: F401 pylint: disable=unused-import
2021

21-
ALGOS = {"a2c": A2C, "ddpg": DDPG, "dqn": DQN, "her": HER, "ppo": PPO, "sac": SAC, "td3": TD3}
22+
ALGOS = {
23+
"a2c": A2C,
24+
"ddpg": DDPG,
25+
"dqn": DQN,
26+
"ppo": PPO,
27+
"her": HER,
28+
"sac": SAC,
29+
"td3": TD3,
30+
}
31+
32+
if TQC is not None:
33+
ALGOS["tqc"] = TQC
2234

2335

2436
def flatten_dict_observations(env):
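With the try/except guard, the zoo degrades gracefully when sb3-contrib is absent: TQC simply never enters ALGOS. A small hypothetical helper (not part of the diff) shows how a lookup could surface that clearly instead of failing with a bare KeyError:

# Hypothetical helper: resolve an algo name from ALGOS and give an
# actionable message when the optional dependency is missing.
def get_algo_class(algo_name: str):
    if algo_name not in ALGOS:
        if algo_name == "tqc":
            raise ImportError("TQC requires sb3-contrib: pip install sb3-contrib")
        raise ValueError(f"Unknown algorithm: {algo_name}, choose from {list(ALGOS)}")
    return ALGOS[algo_name]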

utils/wrappers.py (+1 −3)

@@ -1,6 +1,7 @@
 import gym
 import numpy as np
 from matplotlib import pyplot as plt
+from scipy.signal import iirfilter, sosfilt, zpk2sos


 class DoneOnSuccessWrapper(gym.Wrapper):

@@ -170,9 +171,6 @@ def step(self, action):


 # from https://docs.obspy.org
-from scipy.signal import iirfilter, sosfilt, zpk2sos
-
-
 def lowpass(data, freq, df, corners=4, zerophase=False):
     """
     Butterworth-Lowpass Filter.
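For reference, the body of an obspy-style Butterworth lowpass built from those imports looks roughly like this (a sketch following the docs.obspy.org recipe the comment cites, not necessarily line-for-line what the repo contains):

# Sketch of an obspy-style Butterworth lowpass using the imports moved above.
from scipy.signal import iirfilter, sosfilt, zpk2sos

def lowpass(data, freq, df, corners=4, zerophase=False):
    fe = 0.5 * df  # Nyquist frequency for a signal sampled at df Hz
    z, p, k = iirfilter(corners, freq / fe, btype="lowpass", ftype="butter", output="zpk")
    sos = zpk2sos(z, p, k)
    if zerophase:
        firstpass = sosfilt(sos, data)
        # Filter backwards as well to cancel the phase shift
        return sosfilt(sos, firstpass[::-1])[::-1]
    return sosfilt(sos, data)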
