Skip to content

Commit

Permalink
small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
davide97l committed Feb 18, 2020
1 parent 6de7543 commit c3480b4
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
2 changes: 2 additions & 0 deletions DQN.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ def __init__(self, params):
def train(self, bat_s, bat_a, bat_t, bat_n, bat_r):
    """Run one Q-learning update for a batch of transitions.

    The network output ``self.y`` is first evaluated on the next states;
    the row-wise maximum of that output is then fed back as ``self.q_t``
    for a single optimisation step on the current states.

    Args:
        bat_s: Batch of current states, fed to ``self.x`` for the update.
        bat_a: Batch fed to ``self.actions``.
        bat_t: Batch fed to ``self.terminals``.
        bat_n: Batch of next states, fed to ``self.x`` for the lookahead;
            must expose ``shape[0]`` (the batch size).
        bat_r: Batch fed to ``self.rewards``.

    Returns:
        A ``(cnt, cost)`` tuple holding the values fetched from
        ``self.global_step`` and ``self.cost``.
    """
    # Inputs that are identical for both session calls.
    shared_feed = {
        self.actions: bat_a,
        self.terminals: bat_t,
        self.rewards: bat_r,
    }

    # Get the Q-values of the next states; q_t is fed zeros here as a
    # placeholder value because the real targets are not known yet.
    lookahead_feed = dict(shared_feed)
    lookahead_feed[self.x] = bat_n
    lookahead_feed[self.q_t] = np.zeros(bat_n.shape[0])
    next_q_values = self.sess.run(self.y, feed_dict=lookahead_feed)
    best_next_q = np.amax(next_q_values, axis=1)

    # Make one training step on the current states.
    update_feed = dict(shared_feed)
    update_feed[self.x] = bat_s
    update_feed[self.q_t] = best_next_q
    fetches = [self.optim, self.global_step, self.cost]
    _, cnt, cost = self.sess.run(fetches, feed_dict=update_feed)
    return cnt, cost

Expand Down
Binary file modified Pacman performances.xlsx
Binary file not shown.
11 changes: 6 additions & 5 deletions dqnAgents.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# only the values of params can be modified
params = {
# Model backups
'load_file': "model-smallClassic_784471_5703", # relative path to the saved model
'load_file': None, # relative path to the saved model
'save_file': "smallClassic", # name of the model
'save_interval': 100000, # Number of steps between each checkpoint

Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(self, width, height, numTraining=0):
def getMove(self):

# Exploit / Explore
if np.random.rand() > self.params['eps']:
if np.random.rand() >= self.params['eps']:
# Exploit action
self.Q_pred = self.qnet.sess.run(
self.qnet.y,
Expand Down Expand Up @@ -195,6 +195,8 @@ def observation_step(self, state):
self.frame += 1
self.params['eps'] = max(self.params['eps_final'],
1.00 - float(self.cnt) / float(self.params['eps_step']))
if self.numeps >= params['num_training']:
params['eps'] = 0

# Do an observation after each step (this method is called in the game.py file after each step)
def observationFunction(self, state):
Expand Down Expand Up @@ -378,16 +380,15 @@ def registerInitialState(self, state): # inspects the starting state
self.frame = 0
self.numeps += 1

if self.numeps >params['num_training']:
if self.numeps >= params['num_training']:
params['eps'] = 0

# Returns an action from the agent (this method is called in the game.py file when the agent has to select an action)
def getAction(self, state):
    """Return the move to play in ``state``.

    The move proposed by ``self.getMove()`` is returned when it is among
    the actions reported by ``state.getLegalActions(0)``; otherwise a
    randomly chosen legal action is returned instead.

    Args:
        state: Object exposing ``getLegalActions(index)``.

    Returns:
        The proposed move if legal, else a random element of the legal
        actions.

    Raises:
        IndexError: If the proposed move is not legal and the list of
            legal actions is empty (raised by ``random.choice``).
    """
    move = self.getMove()

    # Fall back to a random legal action when the proposed move is not
    # legal (the earlier assignment of Directions.STOP was immediately
    # overwritten and has been dropped).
    legal = state.getLegalActions(0)
    if move not in legal:
        move = random.choice(legal)

    return move

0 comments on commit c3480b4

Please sign in to comment.