Skip to content

Commit 6ec178f

Browse files
committed
[AI] Fix tensorboard
1 parent 34e61f3 commit 6ec178f

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

ai-example.py

Lines changed: 7 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -73,9 +73,12 @@ def init_argparse() -> argparse.ArgumentParser:
7373
parser.add_argument("-p", "--plot", action=argparse.BooleanOptionalAction)
7474
parser.add_argument('-w', '--weights', type=str, help='a trained model weights')
7575
parser.add_argument("-d", "--days", type=int, default=365)
76+
parser.add_argument("-ev", "--evaluate", action=argparse.BooleanOptionalAction)
77+
parser.add_argument("-ep", "--epochs", type=int, default=20)
7678
return parser
7779

7880

81+
7982
def main():
8083
parser = init_argparse()
8184
args = parser.parse_args()
@@ -90,7 +93,7 @@ def main():
9093
agent = obs.DQNAgent(action_size)
9194

9295
logdir = "tensorboard_logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")
93-
tensorboard_callback = TensorBoard(log_dir=logdir)
96+
tensorboard_callback = TensorBoard(log_dir=logdir, histogram_freq=1, write_images=False, batch_size=args.batch_size)
9497

9598
if args.weights:
9699
print(f"Loading model {args.weights}...")
@@ -107,7 +110,9 @@ def main():
107110

108111
if args.train and len(agent.memory) > args.batch_size:
109112
print("Starting replay...")
110-
agent.replay(args.batch_size, tensorboard_callback)
113+
score = agent.replay(args.batch_size, args.epochs, args.evaluate, tensorboard_callback)
114+
if args.evaluate:
115+
print(f"Score = {score}")
111116

112117
if args.train and (episode + 1) % 10 == 0: # checkpoint weights
113118
print("Saving...")

octobot_script/ai/agents.py

Lines changed: 6 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -23,7 +23,7 @@ def act(self, state):
2323
act_values = self.model.predict(state)
2424
return np.argmax(act_values[0]) # returns action
2525

26-
def replay(self, batch_size=32, tensorboard_callback=None):
26+
def replay(self, batch_size=32, epochs=1, evaluate=False, tensorboard_callback=None):
2727
# pylint: disable=unsubscriptable-object
2828
""" vectorized implementation; 30x speed up compared with for loop """
2929
minibatch = random.sample(self.memory, batch_size)
@@ -44,10 +44,14 @@ def replay(self, batch_size=32, tensorboard_callback=None):
4444
# make the agent to approximately map the current state to future discounted reward
4545
target_f[range(batch_size), actions] = target
4646

47-
self.model.fit(states, target_f, epochs=1, verbose=0)
47+
self.model.fit(states, target_f, batch_size=batch_size, epochs=epochs, verbose=0, callbacks=[tensorboard_callback])
4848

4949
if self.epsilon > self.epsilon_min:
5050
self.epsilon *= self.epsilon_decay
51+
52+
if evaluate:
53+
return self.model.evaluate(states, target_f, batch_size=32)
54+
return 0
5155

5256
def load(self, name):
5357
self.model.load_weights(name)

0 commit comments

Comments (0)