-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
15 lines (13 loc) · 785 Bytes
/
train.py
File metadata and controls
15 lines (13 loc) · 785 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import gymnasium as gym
import src # Needed to register the environment
from stable_baselines3 import A2C, DQN, PPO
from sb3_contrib import TRPO
from stable_baselines3.common.env_util import make_vec_env
from rl.components import CombinedExtractorBatchNorm
import torch.nn as nn
if __name__ == '__main__':
    # Script entry point: create the 'MicroMouse-v0' environment, train a PPO
    # agent on it for one million timesteps, and save the resulting model.
    import os  # local import: only this entry point needs it

    # Build the maze path from its components so the separator matches the
    # host OS; the previous hard-coded backslashes only resolved on Windows.
    maze_file = os.path.join(
        'micromouse_maze_tool', 'mazefiles', 'binary', '1stworld.maz'
    )

    env = make_vec_env(
        'MicroMouse-v0',
        env_kwargs={'file': maze_file, 'render_mode': 'human'},
        n_envs=1,
    )
    try:
        model = PPO(
            'MultiInputPolicy',
            env,
            verbose=1,
            tensorboard_log='./logs',
            seed=1,
            policy_kwargs={
                'activation_fn': nn.ReLU,
                'features_extractor_class': CombinedExtractorBatchNorm,
            },
        )
        # To continue from a previously saved model instead of starting fresh:
        # model = PPO.load('saved_models/model_good.zip', env)
        model.learn(total_timesteps=1_000_000)
        model.save('saved_models/model.zip')
    finally:
        # Release the environment even if training is interrupted or fails.
        env.close()