
Commit 963caaf

Add provided template.
1 parent c61d4b2 commit 963caaf

10 files changed, +870 -0 lines changed

AgentPlayer.py

+127
@@ -0,0 +1,127 @@
# AI Agent. It will learn and play the snake game
import os
import torch
import numpy as np
from SnakeGame import SnakeGame, Direction
from Model import Linear_QNet
from Helper import plot


class Agent:

    def __init__(self):
        # pick the compute device (GPU if available) without shadowing torch.device
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        self.n_games = 0
        self.model = Linear_QNet(11, 256, 3)

    def get_state(self, game):
        point_l = (game.headx - 20, game.heady)
        point_r = (game.headx + 20, game.heady)
        point_u = (game.headx, game.heady + 20)
        point_d = (game.headx, game.heady - 20)

        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        state = [
            # Danger straight
            (dir_r and game.iscollision(point_r)) or
            (dir_l and game.iscollision(point_l)) or
            (dir_u and game.iscollision(point_u)) or
            (dir_d and game.iscollision(point_d)),

            # Danger right
            (dir_u and game.iscollision(point_r)) or
            (dir_d and game.iscollision(point_l)) or
            (dir_l and game.iscollision(point_u)) or
            (dir_r and game.iscollision(point_d)),

            # Danger left
            (dir_d and game.iscollision(point_r)) or
            (dir_u and game.iscollision(point_l)) or
            (dir_r and game.iscollision(point_u)) or
            (dir_l and game.iscollision(point_d)),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Apple location
            game.applex < game.headx,  # Apple left
            game.applex > game.headx,  # Apple right
            game.appley > game.heady,  # Apple up
            game.appley < game.heady   # Apple down
        ]

        return np.array(state, dtype=int)

    def get_action(self, state):
        final_move = [0, 0, 0]
        state0 = torch.tensor(state, dtype=torch.float)
        prediction = self.model(state0)
        move = torch.argmax(prediction).item()
        final_move[move] = 1
        return final_move


def get_trained_model(file_name='model79.pth'):
    trained_model = None
    model_folder_path = './model'
    file_name = os.path.join(model_folder_path, file_name)
    if os.path.exists(file_name):
        try:
            trained_model = Linear_QNet(11, 256, 3)
            checkpoint = torch.load(file_name)
            trained_model.load_state_dict(checkpoint['model_state_dict'])
        except Exception:
            print("Continue...")
    return trained_model


def play():
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    agent = Agent()
    trained_model = get_trained_model()
    if trained_model is None:
        print("No trained model found.")
        return
    else:
        agent.model = trained_model
        agent.model.eval()

    game = SnakeGame(delay=0.000)
    while True:
        # get old state
        state_old = agent.get_state(game)

        # get move
        action = agent.get_action(state_old)

        # perform move and get new state
        game.play_step(action)
        score = game.score
        done = game.game_over
        if done:
            # plot result
            game.reset()
            agent.n_games += 1

            print('Game', agent.n_games, 'Score', score, 'Record:', record)

            plot_scores.append(score)
            total_score += score
            mean_score = total_score / agent.n_games
            plot_mean_scores.append(mean_score)
            plot(plot_scores, plot_mean_scores, "Playing...")


if __name__ == '__main__':
    play()
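
For reference, get_state above encodes the game as an 11-element binary vector: three danger flags (straight, right, left relative to the current heading), four heading flags, and four apple-location flags. As a worked illustration (assuming the snake is heading RIGHT, nothing collidable within one 20-pixel step, and the apple above and to the right of the head), the resulting vector would be:

import numpy as np

# [danger straight, danger right, danger left,
#  dir_l, dir_r, dir_u, dir_d,
#  apple left, apple right, apple up, apple down]
example_state = np.array([0, 0, 0,
                          0, 1, 0, 0,
                          0, 1, 1, 0], dtype=int)

This is the same layout that the 11-unit input of Linear_QNet consumes in get_action.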

AgentTrainer.py

+205
@@ -0,0 +1,205 @@
# AI Agent. It will learn and play the snake game
import os
import torch
import random
import numpy as np
from collections import deque
from SnakeGame import SnakeGame, Direction  # , Point
from Model import Linear_QNet, QTrainer
from Helper import plot


MAX_MEMORY = 200_000
BATCH_SIZE = 2000
LR = 0.001


class Agent:

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # randomness
        self.gamma = 0.8  # discount rate
        self.memory = deque(maxlen=MAX_MEMORY)  # popleft()
        self.model = Linear_QNet(11, 256, 3)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)
        self.is_trained_model = False

    def get_state(self, game):
        point_l = (game.headx - 20, game.heady)
        point_r = (game.headx + 20, game.heady)
        point_u = (game.headx, game.heady + 20)
        point_d = (game.headx, game.heady - 20)

        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        state = [
            # Danger straight
            (dir_r and game.iscollision(point_r)) or
            (dir_l and game.iscollision(point_l)) or
            (dir_u and game.iscollision(point_u)) or
            (dir_d and game.iscollision(point_d)),

            # Danger right
            (dir_u and game.iscollision(point_r)) or
            (dir_d and game.iscollision(point_l)) or
            (dir_l and game.iscollision(point_u)) or
            (dir_r and game.iscollision(point_d)),

            # Danger left
            (dir_d and game.iscollision(point_r)) or
            (dir_u and game.iscollision(point_l)) or
            (dir_r and game.iscollision(point_u)) or
            (dir_l and game.iscollision(point_d)),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Apple location
            game.applex < game.headx,  # Apple left
            game.applex > game.headx,  # Apple right
            game.appley > game.heady,  # Apple up
            game.appley < game.heady   # Apple down
        ]

        return np.array(state, dtype=int)

    def get_distance(self, game):
        return game.head.distance(game.apple)

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))  # popleft if MAX_MEMORY is reached

    def train_long_memory(self):
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE)  # list of tuples
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        # random moves: tradeoff between exploration and exploitation
        self.epsilon = 80 - self.n_games
        final_move = [0, 0, 0]
        if random.randint(0, 200) < self.epsilon and not self.is_trained_model:
            move = random.randint(0, 2)
            final_move[move] = 1
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            prediction = self.model(state0)
            move = torch.argmax(prediction).item()
            final_move[move] = 1

        return final_move


def get_trained_model(file_name='model.pth'):
    trained_model = None
    record = 0
    model_folder_path = './model'
    file_name = os.path.join(model_folder_path, file_name)
    if os.path.exists(file_name):
        try:
            trained_model = Linear_QNet(11, 256, 3)
            checkpoint = torch.load(file_name)
            trained_model.load_state_dict(checkpoint['model_state_dict'])
            record = checkpoint['record']
        except Exception:
            print("Continue...")

    return trained_model, record


def rebuild():
    # convert a bare state-dict checkpoint (model68.pth) into the
    # {'record', 'model_state_dict'} format expected by get_trained_model()
    record = 68
    model_folder_path = './model'
    source_file = 'model68.pth'
    destination_file = 'model.pth'
    source = os.path.join(model_folder_path, source_file)
    destination = os.path.join(model_folder_path, destination_file)
    trained_model = Linear_QNet(11, 256, 3)
    trained_model.load_state_dict(torch.load(source))
    torch.save(
        {
            'record': record,
            'model_state_dict': trained_model.state_dict(),
        }, destination)


def train():
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    agent = Agent()
    trained_model, record = get_trained_model()
    if trained_model is not None:
        # load the previously saved model and continue training
        agent.is_trained_model = True
        agent.model = trained_model
        # rebind the trainer so optimisation targets the loaded weights
        agent.trainer = QTrainer(agent.model, lr=LR, gamma=agent.gamma)
        agent.model.train()

    game = SnakeGame(delay=0.00)
    while True:
        # get current state
        state_current = agent.get_state(game)
        distance_current = agent.get_distance(game)

        # get move
        action = agent.get_action(state_current)

        # perform move and get new state
        reward = 0
        game.play_step(action)
        done = game.game_over
        score = game.score

        # assign reward based on result of the move
        if game.game_over:
            reward = -10

        if game.eatapple():
            reward = 10

        state_new = agent.get_state(game)
        distance_new = agent.get_distance(game)

        # small bonus if the snake moved toward the apple
        if distance_new < distance_current:
            reward += 1

        # train short memory
        agent.train_short_memory(state_current, action, reward, state_new, done)

        # remember
        agent.remember(state_current, action, reward, state_new, done)

        if done:
            # train long memory, plot result
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()

            if score > record:
                record = score
                agent.model.save(record)

            print('Game', agent.n_games, 'Score', score, 'Record:', record)

            plot_scores.append(score)
            total_score += score
            mean_score = total_score / agent.n_games
            plot_mean_scores.append(mean_score)
            plot(plot_scores, plot_mean_scores)


if __name__ == '__main__':
    # rebuild()
    train()
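
Model.py, referenced by the imports above, is among the commit's 10 changed files but is not shown in this part of the diff. Purely as orientation, here is a minimal sketch of what Linear_QNet and QTrainer could look like given their call sites (Linear_QNet(11, 256, 3), model.save(record), QTrainer(model, lr, gamma), and train_step called on both single transitions and batched tuples). The ReLU activation, Adam optimizer, MSE loss, and checkpoint keys are assumptions, not the committed code.

# Hypothetical sketch of Model.py, inferred from its usage in AgentPlayer.py
# and AgentTrainer.py; only the layer sizes are taken from the code above.
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Linear_QNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return self.linear2(x)

    def save(self, record, file_name='model.pth'):
        # checkpoint keys mirror what get_trained_model() reads back
        folder = './model'
        os.makedirs(folder, exist_ok=True)
        torch.save({'record': record, 'model_state_dict': self.state_dict()},
                   os.path.join(folder, file_name))


class QTrainer:
    def __init__(self, model, lr, gamma):
        self.model = model
        self.gamma = gamma
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        state = torch.tensor(np.array(state), dtype=torch.float)
        next_state = torch.tensor(np.array(next_state), dtype=torch.float)
        action = torch.tensor(np.array(action), dtype=torch.long)
        reward = torch.tensor(np.array(reward), dtype=torch.float)

        if len(state.shape) == 1:
            # train_short_memory passes single transitions; add a batch dimension
            state = torch.unsqueeze(state, 0)
            next_state = torch.unsqueeze(next_state, 0)
            action = torch.unsqueeze(action, 0)
            reward = torch.unsqueeze(reward, 0)
            done = (done,)

        # Q values predicted for the current states
        pred = self.model(state)

        # Bellman target: r + gamma * max Q(next_state) for non-terminal steps
        target = pred.clone()
        for idx in range(len(done)):
            q_new = reward[idx]
            if not done[idx]:
                q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))
            target[idx][torch.argmax(action[idx]).item()] = q_new

        self.optimizer.zero_grad()
        loss = self.criterion(target, pred)
        loss.backward()
        self.optimizer.step()

The train_step shown handles both the single transitions from train_short_memory and the batched tuples from train_long_memory by inserting a batch dimension when needed.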

Helper.py

+17
@@ -0,0 +1,17 @@
import matplotlib.pyplot as plt
from IPython import display

plt.ion()


def plot(scores, mean_scores, title="Training..."):
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.clf()
    plt.title(title)
    plt.xlabel('Number of Games')
    plt.ylabel('Score')
    plt.plot(scores)
    plt.plot(mean_scores)
    plt.ylim(ymin=0)
    plt.text(len(scores) - 1, scores[-1], str(scores[-1]))
    plt.text(len(mean_scores) - 1, mean_scores[-1], str(mean_scores[-1]))
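
Helper.plot relies on IPython.display, so it assumes a Jupyter/IPython environment. If the trainer were run as a plain script instead (an assumption about deployment, not something shown in this commit), a variant along these lines, driven by plt.pause, would keep the live plot updating:

# Minimal non-notebook variant of Helper.plot (hypothetical; the committed
# helper targets IPython display instead).
import matplotlib.pyplot as plt

plt.ion()


def plot_script(scores, mean_scores, title="Training..."):
    plt.clf()
    plt.title(title)
    plt.xlabel('Number of Games')
    plt.ylabel('Score')
    plt.plot(scores, label='score')
    plt.plot(mean_scores, label='mean score')
    plt.ylim(ymin=0)
    plt.legend(loc='upper left')
    plt.text(len(scores) - 1, scores[-1], str(scores[-1]))
    plt.text(len(mean_scores) - 1, mean_scores[-1], str(mean_scores[-1]))
    plt.pause(0.1)  # let the GUI event loop redraw the figure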
