"""Trains a DQN/DDQN to solve CartPole-v0 problem
"""
from keras.layers import Dense, Input
from keras.models import Model
from keras.optimizers import Adam
from collections import deque
import numpy as np
import random
import argparse
import gym
from gym import wrappers, logger


class DQNAgent():
    def __init__(self, state_space, action_space, args, episodes=500):
        self.action_space = action_space
        # experience buffer
        self.memory = []
        # discount rate
        self.gamma = 0.9
        # initially 100% exploration, 0% exploitation
        self.epsilon = 1.0
        # iteratively apply decay until 10% exploration/90% exploitation
        self.epsilon_min = 0.1
        self.epsilon_decay = self.epsilon_min / self.epsilon
        self.epsilon_decay = self.epsilon_decay ** (1. / float(episodes))
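        # (note added for clarity) epsilon_decay = (epsilon_min / epsilon)**(1 / episodes),
        # so after `episodes` calls to update_epsilon() epsilon has decayed
        # from 1.0 to approximately epsilon_min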
        # Q Network weights filename
        self.weights_file = 'dqn_cartpole.h5'
        # Q Network for training
        n_inputs = state_space.shape[0]
        n_outputs = action_space.n
        self.q_model = self.build_model(n_inputs, n_outputs)
        self.q_model.compile(loss='mse', optimizer=Adam())
        # target Q Network
        self.target_q_model = self.build_model(n_inputs, n_outputs)
        # copy Q Network params to target Q Network
        self.update_weights()

        self.replay_counter = 0
        self.ddqn = True if args.ddqn else False
        if self.ddqn:
            print("----------Double DQN--------")
        else:
            print("-------------DQN------------")
    # Q Network is 256-256-256 MLP
    def build_model(self, n_inputs, n_outputs):
        inputs = Input(shape=(n_inputs, ), name='state')
        x = Dense(256, activation='relu')(inputs)
        x = Dense(256, activation='relu')(x)
        x = Dense(256, activation='relu')(x)
        x = Dense(n_outputs, activation='linear', name='action')(x)
        q_model = Model(inputs, x)
        q_model.summary()
        return q_model

    # save Q Network params to a file
    def save_weights(self):
        self.q_model.save_weights(self.weights_file)

    # copy trained Q Network params to target Q Network
    def update_weights(self):
        self.target_q_model.set_weights(self.q_model.get_weights())

    # eps-greedy policy
    def act(self, state):
        if np.random.rand() < self.epsilon:
            # explore - do random action
            return self.action_space.sample()
        # exploit
        q_values = self.q_model.predict(state)
        # select the action with max Q-value
        return np.argmax(q_values[0])

    # store experiences in the replay buffer
    def remember(self, state, action, reward, next_state, done):
        item = (state, action, reward, next_state, done)
        self.memory.append(item)

    # compute Q_max
    # use of target Q Network solves the non-stationarity problem
    def get_target_q_value(self, next_state, reward):
        # max Q value among next state's actions
        if self.ddqn:
            # DDQN
            # current Q Network selects the action
            # a'_max = argmax_a' Q(s', a')
            action = np.argmax(self.q_model.predict(next_state)[0])
            # target Q Network evaluates the action
            # Q_max = Q_target(s', a'_max)
            q_value = self.target_q_model.predict(next_state)[0][action]
        else:
            # DQN chooses the max Q value among next actions
            # selection and evaluation of action is on the target Q Network
            # Q_max = max_a' Q_target(s', a')
            q_value = np.amax(self.target_q_model.predict(next_state)[0])

        # Q_max = reward + gamma * Q_max
        q_value *= self.gamma
        q_value += reward
        return q_value

    # experience replay addresses the correlation issue between samples
    def replay(self, batch_size):
        # sars = state, action, reward, state' (next_state)
        sars_batch = random.sample(self.memory, batch_size)
        state_batch, q_values_batch = [], []

        # fixme: for speedup, this could be done on the tensor level
        # but easier to understand using a loop
        for state, action, reward, next_state, done in sars_batch:
            # policy prediction for a given state
            q_values = self.q_model.predict(state)

            # get Q_max
            q_value = self.get_target_q_value(next_state, reward)

            # correction on the Q value for the action used
            q_values[0][action] = reward if done else q_value

            # collect batch state-q_value mapping
            state_batch.append(state[0])
            q_values_batch.append(q_values[0])

        # train the Q Network
        self.q_model.fit(np.array(state_batch),
                         np.array(q_values_batch),
                         batch_size=batch_size,
                         epochs=1,
                         verbose=0)

        # update exploration-exploitation probability
        self.update_epsilon()

        # copy new params to the target Q Network after every 10 training updates
        if self.replay_counter % 10 == 0:
            self.update_weights()

        self.replay_counter += 1

    # decrease the exploration, increase exploitation
    def update_epsilon(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument('env_id',
                        nargs='?',
                        default='CartPole-v0',
                        help='Select the environment to run')
    parser.add_argument("-d",
                        "--ddqn",
                        action='store_true',
                        help="Use Double DQN")
    args = parser.parse_args()

    # the number of consecutive trials used to decide if the problem is solved
    win_trials = 100

    # CartPole-v0 is considered solved if, over 100 consecutive trials,
    # the cart pole has not fallen over and the average reward is at
    # least 195.0
    # a reward of +1 is provided for every timestep the pole remains
    # upright
    win_reward = { 'CartPole-v0' : 195.0 }

    # stores the reward per episode
    scores = deque(maxlen=win_trials)

    logger.setLevel(logger.ERROR)
    env = gym.make(args.env_id)

    outdir = "/tmp/dqn-%s" % args.env_id
    if args.ddqn:
        outdir = "/tmp/ddqn-%s" % args.env_id
    env = wrappers.Monitor(env, directory=outdir, force=True)
    env.seed(0)

    # instantiate the DQN/DDQN agent
    agent = DQNAgent(env.observation_space, env.action_space, args)

    # should be solved in this number of episodes
    episode_count = 3000
    state_size = env.observation_space.shape[0]
    batch_size = 64

    # by default, CartPole-v0 has max episode steps = 200
    # you can use this to experiment beyond 200
    # env._max_episode_steps = 4000

    # Q-Learning sampling and fitting
    for episode in range(episode_count):
        state = env.reset()
        state = np.reshape(state, [1, state_size])
        done = False
        total_reward = 0
        while not done:
            # in CartPole-v0, action=0 is left and action=1 is right
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            # in CartPole-v0:
            # state = [pos, vel, theta, angular speed]
            next_state = np.reshape(next_state, [1, state_size])
            # store every experience unit in replay buffer
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

        # call experience replay once per episode
        if len(agent.memory) >= batch_size:
            agent.replay(batch_size)

        scores.append(total_reward)
        mean_score = np.mean(scores)
        if mean_score >= win_reward[args.env_id] and episode >= win_trials:
            print("Solved in episode %d: Mean survival = %0.2lf in %d episodes"
                  % (episode, mean_score, win_trials))
            print("Epsilon: ", agent.epsilon)
            agent.save_weights()
            break
        if episode % win_trials == 0:
            print("Episode %d: Mean survival = %0.2lf in %d episodes" %
                  (episode, mean_score, win_trials))

    # close the env and write monitor result info to disk
    env.close()
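
    # Note: a minimal evaluation sketch (an assumption, not part of the original
    # script) - reload the saved weights into a fresh agent/env and act greedily:
    #   agent.q_model.load_weights('dqn_cartpole.h5')
    #   action = np.argmax(agent.q_model.predict(state)[0])  # greedy, epsilon = 0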