-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprogress.py
More file actions
123 lines (104 loc) · 4.32 KB
/
progress.py
File metadata and controls
123 lines (104 loc) · 4.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import csv
from hyperparameters import *
import pandas as pd
import matplotlib.pyplot as plt
import logging
import numpy as np
class Writer(object):
    """
    Create a csv file for keeping track of learning progress.

    ``save_progress`` appends one per-step row (step, loss, reward) and
    accumulates loss/reward. ``plot_progress`` saves a figure and then appends
    one per-episode row (episode, avg loss, avg reward) holding the mean of the
    samples accumulated since the previous episode row, then resets the
    accumulators.

    With the "running" flag (default), resuming an existing csv seeds the
    accumulators with the last stored averages, counted as one extra sample,
    so the plotted averages continue smoothly. Plotting simple averages is
    also possible.
    """

    def __init__(self, filename, model_name, running=True):
        self.filename = WRITER_DIRECTORY + filename + '.csv'
        self.model_name = model_name
        # Number of samples currently summed into loss_agg / rewar_agg.
        self.count = 0
        self.episodes = 0
        self.loss_agg = 0
        self.rewar_agg = 0
        self.fieldnames = ["step", "loss", "reward", "episode", "avg loss", "avg reward"]
        self._resume_or_create(running)

    def _resume_or_create(self, running):
        """Continue an existing csv if there is one, otherwise create it with a header row."""
        try:
            df = pd.read_csv(self.filename)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # Missing or zero-byte file: start a fresh csv with just the header.
            self._write_header()
            return
        last_episode = self._last_value(df, "episode")
        if last_episode is None:
            # Header (and possibly step rows) only: no episode was completed yet.
            return
        # pandas reads the column as float because of the empty cells in step rows.
        self.episodes = int(last_episode)
        last_loss = self._last_value(df, "avg loss")
        last_reward = self._last_value(df, "avg reward")
        if running and last_loss is not None and last_reward is not None:
            # Smooth continuation: the previous average counts as one sample.
            self.loss_agg = last_loss
            self.rewar_agg = last_reward
            self.count += 1

    @staticmethod
    def _last_value(df, column):
        """Return the last non-missing entry of ``df[column]``, or None if there is none."""
        values = df[column].dropna()
        if values.empty:
            return None
        return values.iloc[-1]

    def _write_header(self):
        """Create (or truncate) the csv and write the header row."""
        with open(self.filename, 'w', newline='') as file:
            self.writer = csv.DictWriter(file, fieldnames=self.fieldnames)
            self.writer.writeheader()

    def _append_row(self, row):
        """Append one dict row to the csv; fields missing from ``row`` are left empty."""
        with open(self.filename, 'a', newline='') as file:
            self.writer = csv.DictWriter(file, fieldnames=self.fieldnames)
            self.writer.writerow(row)

    def save_progress(self, step, loss, reward):
        """
        Write progress to csv.

        Accumulate loss and reward in each step to compute averages.
        The number of steps before termination is not necessarily the same
        each time, thus counting is necessary.
        """
        self._append_row({"step": step, "loss": loss, "reward": reward})
        self.loss_agg += loss
        self.rewar_agg += reward
        self.count += 1

    def _write_episode_average(self):
        """End the current episode: append its mean loss/reward row and reset the accumulators."""
        self.episodes += 1
        # max(..., 1): an episode with no accumulated samples is written as 0
        # instead of raising ZeroDivisionError.
        samples = max(self.count, 1)
        self._append_row({"episode": self.episodes,
                          "avg loss": self.loss_agg / samples,
                          "avg reward": self.rewar_agg / samples})
        self.count = 0
        self.loss_agg = 0
        self.rewar_agg = 0

    def plot_progress(self, rewards=True, average=True):
        """
        Plot training progress, then save the average over the finished episode.

        :param rewards: also plot the reward series on a second y axis.
        :param average: plot per-episode averages (True) or raw per-step values (False).
        """
        if average:
            fig_name = WRITER_DIRECTORY + self.model_name + '_' + str(self.episodes) + '.png'
            self.plot_progress_sub(fig_name, "episode", "avg reward", "avg loss", rewards)
        else:
            df = pd.read_csv(self.filename)
            # The last csv row may be an episode row with an empty "step" cell,
            # so use the last recorded step (cast back from pandas' float).
            train_step = self._last_value(df, "step")
            train_step = int(train_step) if train_step is not None else 0
            fig_name = WRITER_DIRECTORY + self.model_name + '_' + str(train_step) + '.png'
            self.plot_progress_sub(fig_name, "steps", "reward", "loss", rewards)
        # Plotting marks the end of an episode: save its average and reset.
        self._write_episode_average()

    def plot_progress_sub(self, fig_name, stepstr, rewarstr, lossstr, rewards):
        """
        Subroutine for creating matplotlib plots and saving them to ``fig_name``.

        ``stepstr == "steps"`` is a sentinel meaning "plot against row position";
        any other value names the csv column used for the x axis.
        """
        df = pd.read_csv(self.filename)
        use_column_x = stepstr != "steps"
        if rewards:
            # Step rows and episode rows share the csv, so each column has gaps.
            df_loss = df[lossstr].dropna()
            df_reward = df[rewarstr].dropna()
            fig, ax1 = plt.subplots()
            ax1.set_xlabel(stepstr)
            ax1.set_ylabel(lossstr, color='b')
            if use_column_x:
                df_x = df[stepstr].dropna()
                ax1.plot(df_x, df_loss, 'b')
            else:
                ax1.plot(df_loss, 'b')
            ax2 = ax1.twinx()
            ax2.set_ylabel(rewarstr, color='r')
            if use_column_x:
                ax2.plot(df_x, df_reward, 'r')
            else:
                ax2.plot(df_reward, 'r')
        else:
            fig = plt.figure()
            plt.xlabel(stepstr)
            plt.ylabel(lossstr)
            plt.plot(df[lossstr].dropna())
        plt.title(self.model_name)
        plt.savefig(fig_name)
        # Release the figure; otherwise every call leaks one open figure.
        plt.close(fig)
if __name__ == "__main__":
    # Manual smoke test: log 500 synthetic steps and call plot_progress
    # whenever the step index is a multiple of 100 (steps 0, 100, ..., 400).
    writer = Writer("try_runnung", "model_running", running=True)
    for step in range(500):
        writer.save_progress(step, 100 - step, 2 * step)
        if not step % 100:
            writer.plot_progress()