diff --git a/hyperpars/gcse-boilerplate.yml b/hyperpars/gcse-boilerplate.yml
index c3096e1..d5cbe1e 100644
--- a/hyperpars/gcse-boilerplate.yml
+++ b/hyperpars/gcse-boilerplate.yml
@@ -2,7 +2,6 @@
 env_id: "gcsenv"
 config: {}
-n_envs: 12
 tensorboard: "/home/rstudio/logs"
 use_sde: True
 id: "1"
 
diff --git a/scripts/train_algos.sh b/scripts/train_algos.sh
new file mode 100644
index 0000000..4649012
--- /dev/null
+++ b/scripts/train_algos.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+# move to script directory for normalized relative paths.
+scriptdir="$(dirname "$0")"
+cd "$scriptdir"
+
+python train_gcse_manual.py -t 1000000 -a ppo -ne 10 &
+python train_gcse_manual.py -t 1000000 -a rppo -ne 10 &
+python train_gcse_manual.py -t 1000000 -a her -ne 10 &
+python train_gcse_manual.py -t 1000000 -a tqc -ne 10 &
+python train_gcse_manual.py -t 1000000 -a td3 &
diff --git a/scripts/train_benchmarks.sh b/scripts/train_benchmarks.sh
deleted file mode 100644
index 9f43b24..0000000
--- a/scripts/train_benchmarks.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-#!/bin/bash
-
-# move to script directory for normalized relative paths. 
-scriptdir="$(dirname "$0")"
-cd "$scriptdir"
-
-python train.py --file ../hyperpars/systematic-benchmarks/TQC-Nmem_1/10-tqc_nmem-1_bmk.yml &
-python train.py --file ../hyperpars/systematic-benchmarks/TQC-Nmem_1/11-tqc_nmem-1_bmk.yml &
-python train.py --file ../hyperpars/systematic-benchmarks/TQC-Nmem_1/12-tqc_nmem-1_bmk.yml &
-python train.py --file ../hyperpars/systematic-benchmarks/TQC-Nmem_1/13-tqc_nmem-1_bmk.yml &
-python train.py --file ../hyperpars/systematic-benchmarks/TQC-Nmem_1/14-tqc_nmem-1_bmk.yml &
\ No newline at end of file
diff --git a/scripts/train_gcse_manual.py b/scripts/train_gcse_manual.py
new file mode 100644
index 0000000..64d46d2
--- /dev/null
+++ b/scripts/train_gcse_manual.py
@@ -0,0 +1,41 @@
+#!/opt/venv/bin/python
+import argparse
+parser = argparse.ArgumentParser()
+parser.add_argument("-a", "--algo", help="Algo to train", type=str,
+    choices=[
+        'PPO', 'RecurrentPPO', 'RPPO', 'ppo', 'recurrentppo', 'rppo',
+        'ARS', 'A2C', 'ars', 'a2c',
+        'DDPG', 'ddpg',
+        'HER', 'her',
+        'SAC', 'sac',
+        'TD3', 'td3',
+        'TQC', 'tqc',
+    ]
+)
+parser.add_argument("-t", "--time-steps", help="N. timesteps to train for", type=int)
+parser.add_argument(
+    "-ne",
+    "--n-envs",
+    help="Number of envs to use simultaneously for faster training. "
+    "Check algos for compatibility with this arg.",
+    type=int,
+)
+
+args = parser.parse_args()
+
+manual_kwargs = {}
+if args.algo:
+    manual_kwargs['algo'] = args.algo
+if args.time_steps:
+    manual_kwargs['time_steps'] = args.time_steps
+if args.n_envs:
+    manual_kwargs['n_envs'] = args.n_envs
+
+import os
+boilerplate_cfg = os.path.join("..", "hyperpars", "gcse-boilerplate.yml")
+
+
+import rl4greencrab
+from rl4greencrab import sb3_train
+
+sb3_train(boilerplate_cfg, **manual_kwargs)
diff --git a/src/rl4greencrab/utils/sb3.py b/src/rl4greencrab/utils/sb3.py
index 8573fc6..91634a2 100644
--- a/src/rl4greencrab/utils/sb3.py
+++ b/src/rl4greencrab/utils/sb3.py
@@ -3,25 +3,37 @@ import gymnasium as gym
 from stable_baselines3.common.env_util import make_vec_env
-from stable_baselines3 import PPO, A2C, DQN, SAC, TD3
-from sb3_contrib import TQC, ARS
+from stable_baselines3 import PPO, A2C, DQN, SAC, TD3, HER, DDPG
+from sb3_contrib import TQC, ARS, RecurrentPPO
 
 
 def algorithm(algo):
     algos = {
-        "PPO": PPO,
-        "ARS": ARS,
-        "TQC": TQC,
-        "A2C": A2C,
-        "SAC": SAC,
-        "DQN": DQN,
-        "TD3": TD3,
-        "ppo": PPO,
-        "ars": ARS,
-        "tqc": TQC,
-        "a2c": A2C,
-        "sac": SAC,
-        "dqn": DQN,
-        "td3": TD3,
+        'PPO': PPO,
+        'ppo': PPO,
+        'RecurrentPPO': RecurrentPPO,
+        'RPPO': RecurrentPPO,
+        'recurrentppo': RecurrentPPO,
+        'rppo': RecurrentPPO,
+        #
+        'ARS': ARS,
+        'ars': ARS,
+        'A2C': A2C,
+        'a2c': A2C,
+        #
+        'DDPG': DDPG,
+        'ddpg': DDPG,
+        #
+        'HER': HER,
+        'her': HER,
+        #
+        'SAC': SAC,
+        'sac': SAC,
+        #
+        'TD3': TD3,
+        'td3': TD3,
+        #
+        'TQC': TQC,
+        'tqc': TQC,
     }
     return algos[algo]
 
 
@@ -31,9 +43,12 @@ def sb3_train(config_file, **kwargs):
     options = {**options, **kwargs}
     # updates / expands on yaml options with optional user-provided input
 
-    vec_env = make_vec_env(
-        options["env_id"], options["n_envs"], env_kwargs={"config": options["config"]}
-    )
+    if "n_envs" in options:
+        env = make_vec_env(
+            options["env_id"], options["n_envs"], env_kwargs={"config": options["config"]}
+        )
+    else:
+        env = gym.make(options["env_id"])
 
     ALGO = algorithm(options["algo"])
     model_id = options["algo"] + "-" + options["env_id"] + "-" + options["id"]
     save_id = os.path.join(options["save_path"], model_id)