# AI Agent. It will learn and play the snake game
import os
import torch
import random
import numpy as np
from collections import deque
from SnakeGame import SnakeGame, Direction  # , Point
from Model import Linear_QNet, QTrainer
from Helper import plot


MAX_MEMORY = 200_000
BATCH_SIZE = 2000
LR = 0.001
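# MAX_MEMORY caps the replay deque (old transitions are dropped once it is full),
# BATCH_SIZE is how many stored transitions are sampled for each long-memory
# training pass, and LR is the learning rate handed to QTrainer.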

class Agent:

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # randomness
        self.gamma = 0.8  # discount rate
        self.memory = deque(maxlen=MAX_MEMORY)  # popleft() when full
        self.model = Linear_QNet(11, 256, 3)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)
        self.is_trained_model = False  # set to True once a saved model has been loaded

    def get_state(self, game):
        point_l = (game.headx - 20, game.heady)
        point_r = (game.headx + 20, game.heady)
        point_u = (game.headx, game.heady + 20)
        point_d = (game.headx, game.heady - 20)

        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        state = [
            # Danger straight
            (dir_r and game.iscollision(point_r)) or
            (dir_l and game.iscollision(point_l)) or
            (dir_u and game.iscollision(point_u)) or
            (dir_d and game.iscollision(point_d)),

            # Danger right
            (dir_u and game.iscollision(point_r)) or
            (dir_d and game.iscollision(point_l)) or
            (dir_l and game.iscollision(point_u)) or
            (dir_r and game.iscollision(point_d)),

            # Danger left
            (dir_d and game.iscollision(point_r)) or
            (dir_u and game.iscollision(point_l)) or
            (dir_r and game.iscollision(point_u)) or
            (dir_l and game.iscollision(point_d)),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Apple location
            game.applex < game.headx,  # Apple left
            game.applex > game.headx,  # Apple right
            game.appley > game.heady,  # Apple up
            game.appley < game.heady   # Apple down
        ]

        return np.array(state, dtype=int)
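
    # The state is an 11-value binary vector: 3 danger flags (straight/right/left
    # relative to the current heading), 4 one-hot move-direction flags, and 4 flags
    # for where the apple lies relative to the head. This matches the 11 inputs of
    # Linear_QNet(11, 256, 3).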

    def get_distance(self, game):
        return game.head.distance(game.apple)

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))  # popleft if MAX_MEMORY is reached

    def train_long_memory(self):
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE)  # list of tuples
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)
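
    # Long-memory training is experience replay: a random mini-batch is drawn from
    # the whole deque and passed to QTrainer.train_step in a single call, which
    # decorrelates consecutive moves instead of learning only from the latest step.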

    def train_short_memory(self, state, action, reward, next_state, done):
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        # random moves: tradeoff exploration / exploitation
        self.epsilon = 80 - self.n_games
        final_move = [0, 0, 0]
        if random.randint(0, 200) < self.epsilon and not self.is_trained_model:
            move = random.randint(0, 2)
            final_move[move] = 1
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            prediction = self.model(state0)
            move = torch.argmax(prediction).item()
            final_move[move] = 1

        return final_move
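
    # Epsilon-greedy exploration: for roughly the first 80 games a move is random
    # with about (80 - n_games) chances in 200; otherwise (and always once a
    # pretrained model has been loaded) the move is the argmax of the network
    # output, returned as a one-hot list of length 3.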

def get_trained_model(file_name='model.pth'):
    trained_model = None
    record = 0
    model_folder_path = './model'
    file_name = os.path.join(model_folder_path, file_name)
    if os.path.exists(file_name):
        try:
            trained_model = Linear_QNet(11, 256, 3)
            checkpoint = torch.load(file_name)
            trained_model.load_state_dict(checkpoint['model_state_dict'])
            record = checkpoint['record']
        except Exception:
            # fall back to training from scratch if the checkpoint cannot be read
            trained_model = None
            print("Could not load the saved model, continuing without it...")

    return trained_model, record
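
# The checkpoint read here (and written by rebuild() below) is a dict with two
# keys: 'model_state_dict' holding the network weights and 'record' holding the
# best score so far. agent.model.save(record) in train() presumably writes the
# same layout; that method is defined in Model.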

def rebuild():
    # Wrap a bare state_dict checkpoint (model68.pth) into the dict format that
    # get_trained_model() expects, and save it as model.pth.
    record = 68
    model_folder_path = './model'
    source_file = 'model68.pth'
    destination_file = 'model.pth'
    source = os.path.join(model_folder_path, source_file)
    destination = os.path.join(model_folder_path, destination_file)
    trained_model = Linear_QNet(11, 256, 3)
    trained_model.load_state_dict(torch.load(source))
    torch.save(
        {
            'record': record,
            'model_state_dict': trained_model.state_dict(),
        }, destination)

def train():
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    agent = Agent()
    trained_model, record = get_trained_model()
    if trained_model is not None:
        # load the previously saved model and continue training
        agent.is_trained_model = True
        agent.model = trained_model
        # rebuild the trainer so it optimizes the loaded model, not the fresh one from __init__
        agent.trainer = QTrainer(agent.model, lr=LR, gamma=agent.gamma)
        agent.model.train()

    game = SnakeGame(delay=0.00)
    while True:
        # get current state
        state_current = agent.get_state(game)
        distance_current = agent.get_distance(game)

        # get move
        action = agent.get_action(state_current)

        # perform move and get new state
        reward = 0
        game.play_step(action)
        done = game.game_over
        score = game.score

        # assign reward based on the result of the move
        if game.game_over:
            reward = -10

        if game.eatapple():
            reward = 10

        state_new = agent.get_state(game)
        distance_new = agent.get_distance(game)

        # small extra reward if the snake moved toward the apple
        if distance_new < distance_current:
            reward += 1

        # train short memory
        agent.train_short_memory(state_current, action, reward, state_new, done)

        # remember
        agent.remember(state_current, action, reward, state_new, done)

        if done:
            # train long memory, plot result
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()

            if score > record:
                record = score
                agent.model.save(record)

            print('Game', agent.n_games, 'Score', score, 'Record:', record)

            plot_scores.append(score)
            total_score += score
            mean_score = total_score / agent.n_games
            plot_mean_scores.append(mean_score)
            plot(plot_scores, plot_mean_scores)


if __name__ == '__main__':
    # rebuild()
    train()