Skip to content

Commit

Permalink
train_algos bash script, train_gcse_manual python script, yaml configs
Browse files Browse the repository at this point in the history
  • Loading branch information
felimomo committed Feb 26, 2024
1 parent 5c08c59 commit 72e2cf8
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 31 deletions.
1 change: 0 additions & 1 deletion hyperpars/gcse-boilerplate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

env_id: "gcsenv"
config: {}
n_envs: 12
tensorboard: "/home/rstudio/logs"
use_sde: True
id: "1"
Expand Down
11 changes: 11 additions & 0 deletions scripts/train_algos.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/bin/bash

# move to script directory for normalized relative paths.
scriptdir="$(dirname "$0")"
cd "$scriptdir"

python train_gcse_manual.py -t 1000000 -a ppo -ne 10 &
python train_gcse_manual.py -t 1000000 -a rppo -ne 10 &
python train_gcse_manual.py -t 1000000 -a her -ne 10 &
python train_gcse_manual.py -t 1000000 -a tqc -ne 10 &
python train_gcse_manual.py -t 1000000 -a td3 &
11 changes: 0 additions & 11 deletions scripts/train_benchmarks.sh

This file was deleted.

41 changes: 41 additions & 0 deletions scripts/train_gcse_manual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/opt/venv/bin/python
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-a", "--algo", help="Algo to train", type=str,
choices=[
'PPO', 'RecurrentPPO', 'ppo', 'recurrentppo',
'ARS', 'A2C', 'ars', 'a2c',
'DDPG', 'ddpg',
'HER', 'her',
'SAC', 'sac'
'TD3', 'td3',
'TQC', 'tqc',
]
)
parser.add_argument("-t", "--time-steps", help="N. timesteps to train for", type=int)
parser.add_argument(
"-ne",
"--n-envs",
help="Number of envs to use simultaneously for faster training. "
"Check algos for compatibility with this arg.",
type=int,
)

args = parser.parse_args()

manual_kwargs = {}
if args.algo:
manual_kwargs['algo'] = args.algo
if args.time_steps:
manual_kwargs['time_steps'] = args.time_steps
if args.n_envs:
manual_kwargs['n_envs'] = args.n_envs

import os
boilerplate_cfg = os.path.join("..", "hyperpars", "gcse-boilerplate.yml")


import rl4greencrab
from rl4greencrab import sb3_train

sb3_train(boilerplate_cfg, **manual_kwargs)
53 changes: 34 additions & 19 deletions src/rl4greencrab/utils/sb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,37 @@

import gymnasium as gym
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import PPO, A2C, DQN, SAC, TD3
from sb3_contrib import TQC, ARS
from stable_baselines3 import PPO, A2C, DQN, SAC, TD3, HER
from sb3_contrib import TQC, ARS, RecurrentPPO

def algorithm(algo):
algos = {
"PPO": PPO,
"ARS": ARS,
"TQC": TQC,
"A2C": A2C,
"SAC": SAC,
"DQN": DQN,
"TD3": TD3,
"ppo": PPO,
"ars": ARS,
"tqc": TQC,
"a2c": A2C,
"sac": SAC,
"dqn": DQN,
"td3": TD3,
'PPO': PPO,
'ppo': PPO,
'RecurrentPPO': RecurrentPPO,
'RPPO': RecurrentPPO,
'recurrentppo': RecurrentPPO,
'rppo': RecurrentPPO,
#
'ARS': ARS,
'ars': ARS,
'A2C': A2C,
'a2c':A2C ,
#
'DDPG': DDPG,
'ddpg': DDPG,
#
'HER': HER,
'her': HER,
#
'SAC': SAC,
'sac': SAC,
#
'TD3': TD3,
'td3': TD3,
#
'TQC': TQC,
'tqc': TQC,
}
return algos[algo]

Expand All @@ -31,9 +43,12 @@ def sb3_train(config_file, **kwargs):
options = {**options, **kwargs}
# updates / expands on yaml options with optional user-provided input

vec_env = make_vec_env(
options["env_id"], options["n_envs"], env_kwargs={"config": options["config"]}
)
if "n_envs" in options:
env = make_vec_env(
options["env_id"], options["n_envs"], env_kwargs={"config": options["config"]}
)
else:
env = gym.make(options["env_id"])
ALGO = algorithm(options["algo"])
model_id = options["algo"] + "-" + options["env_id"] + "-" + options["id"]
save_id = os.path.join(options["save_path"], model_id)
Expand Down

0 comments on commit 72e2cf8

Please sign in to comment.