-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathtrain.py
54 lines (39 loc) · 1.74 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
from dataclasses import dataclass
from src.gym import PathPlanningGymFactory
from src.trainer.agent import AgentFactory
from src.base.evaluator import Evaluator
from src.base.logger import Logger
from src.trainer.trainer import TrainerFactory
from utils import AbstractParams
@dataclass
class PathPlanningParams(AbstractParams):
    """Top-level configuration for a path-planning training run.

    Groups the parameter objects for each component the training script builds:
    the trainer, the gym environment, the logger, the evaluator and the agent.
    Command-line parsing and config persistence come from ``AbstractParams``.

    NOTE(review): every default below is a single instance created once when the
    class is defined, so all ``PathPlanningParams()`` objects built without
    explicit arguments share the same nested params objects. If these Params
    types are non-frozen dataclasses (unhashable), Python >= 3.11 rejects them as
    defaults with ``ValueError``; ``field(default_factory=...)`` would fix both,
    but confirm first that ``AbstractParams`` does not read ``Field.default`` or
    the class attributes. TODO confirm.
    """

    # Passed to TrainerFactory.create by the training script.
    trainer: TrainerFactory.default_param_type() = TrainerFactory.default_params()
    # Passed to PathPlanningGymFactory.create to build the environment.
    gym: PathPlanningGymFactory.default_param_type() = PathPlanningGymFactory.default_params()
    # Passed to the Logger constructor.
    logger: Logger.Params = Logger.Params()
    # Passed to the Evaluator constructor.
    evaluator: Evaluator.Params = Evaluator.Params()
    # Passed to AgentFactory.create together with the gym's observation/action spaces.
    agent: AgentFactory.default_param_type() = AgentFactory.default_params()
def main():
    """Train a path-planning agent configured from the command line.

    Parses the command line into ``PathPlanningParams``, builds the gym, agent,
    logger, trainer and evaluator from it, saves the resolved configuration as
    ``config.json`` in the log directory, runs training, and exports the trained
    agent with ``save_keras`` into ``models/`` under the log directory.

    The gym is closed even if any step after its creation raises (including a
    ``KeyboardInterrupt`` during training); the model is only exported when
    training returns normally.
    """
    import os  # local import keeps this entry point self-contained

    parser = argparse.ArgumentParser()
    parser = PathPlanningParams.add_args_to_parser(parser)
    args = parser.parse_args()
    # Keep the args object returned alongside params; args.verbose, args.gpu and
    # args.gpu_id are read from it below.
    params, args = PathPlanningParams.from_parsed_args(args)
    log_dir = params.create_folders(args)

    gym = PathPlanningGymFactory.create(params.gym)
    try:
        # The agent is built against the environment's spaces.
        action_space = gym.action_space
        obs_space = gym.observation_space
        agent = AgentFactory.create(params.agent, obs_space=obs_space, act_space=action_space)
        if args.verbose:
            agent.summary()

        logger = Logger(params.logger, log_dir, agent)
        trainer = TrainerFactory.create(params.trainer, gym=gym, logger=logger, agent=agent)
        evaluator = Evaluator(params.evaluator, trainer, gym)
        # The evaluator needs the trainer, which in turn needs the logger, so the
        # evaluator can only be attached to the logger after all three exist.
        logger.evaluator = evaluator

        # os.path.join gives the same path as plain concatenation when log_dir
        # already ends with a separator, and a correct one when it does not.
        params.save_to(os.path.join(params.log_dir, "config.json"))

        # Only reports the requested device. Any actual device selection is
        # presumably done elsewhere (e.g. while parsing args) — TODO confirm.
        if not args.gpu and args.gpu_id is None:
            print("Running on CPU")
        else:
            print(f"Running on GPU {args.gpu_id}" if args.gpu_id is not None else "Running on GPU")

        trainer.train()
        agent.save_keras(os.path.join(params.log_dir, "models/"))
    finally:
        # Release the environment even when construction or training fails.
        gym.close()


if __name__ == "__main__":
    main()