-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss_history.py
56 lines (52 loc) · 2.17 KB
/
loss_history.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from pandas import *
from dataprocess import *
import os
import numpy as np
import scipy.io
import csv
import tensorflow as tf
from feature_func import totalfunc
from collections import Counter
from keras.layers.core import Activation, Dense, SpatialDropout1D
from keras.layers import Embedding
from keras.layers import LSTM,Flatten,Dense
from keras.models import Sequential
from keras.callbacks import Callback
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import random
import matplotlib.pyplot as plt
class LossHistory(Callback):
    """Keras callback that records loss/accuracy histories and plots them.

    Four histories are kept, each a dict with a ``'batch'`` list and an
    ``'epoch'`` list:

    * ``losses``   -- value of ``logs['loss']``
    * ``accuracy`` -- value of ``logs['accuracy']`` (falls back to ``'acc'``)
    * ``val_loss`` -- value of ``logs['val_loss']``
    * ``val_acc``  -- value of ``logs['val_accuracy']`` (falls back to
      ``'val_acc'``)

    A metric that is missing from ``logs`` is stored as ``None`` so that all
    four lists for a given scope always have the same length.
    """

    def __init__(self):
        super().__init__()
        # Create the (empty) histories up front so that ``loss_plot`` can be
        # called safely even before training has started.
        self._reset_history()

    def _reset_history(self):
        """Replace all recorded histories with fresh empty lists."""
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}

    def _record(self, scope, logs):
        """Append the metrics found in ``logs`` to the ``scope`` histories.

        Args:
            scope: ``'batch'`` or ``'epoch'``.
            logs: metrics dict handed to the callback hook; may be ``None``.
        """
        # The hooks may be invoked with ``logs=None``; treat that as "no
        # metrics" instead of failing on ``None.get``.
        logs = logs or {}
        self.losses[scope].append(logs.get('loss'))
        # The accuracy key follows the metric name given at compile time, so
        # accept the short 'acc' spelling as well as 'accuracy'.
        self.accuracy[scope].append(logs.get('accuracy', logs.get('acc')))
        self.val_loss[scope].append(logs.get('val_loss'))
        self.val_acc[scope].append(
            logs.get('val_accuracy', logs.get('val_acc')))

    def on_train_begin(self, logs=None):
        """Start every training run with empty histories."""
        self._reset_history()

    def on_batch_end(self, batch, logs=None):
        """Record the metrics reported at the end of a batch.

        NOTE(review): validation metrics are normally only reported at epoch
        end, so the per-batch ``val_*`` entries are expected to be ``None``;
        they are still appended to keep the list lengths aligned.
        """
        self._record('batch', logs)

    def on_epoch_end(self, batch, logs=None):
        """Record the metrics reported at the end of an epoch.

        The first positional argument is the epoch index supplied by the
        framework; it keeps its historical name ``batch`` so the signature
        stays unchanged. It is not used.
        """
        self._record('epoch', logs)

    def loss_plot(self, loss_type):
        """Plot the recorded curves for one scope and show the figure.

        Args:
            loss_type: ``'batch'`` or ``'epoch'``. Training accuracy and loss
                are always drawn; validation accuracy and loss are drawn only
                for ``'epoch'``.

        Raises:
            KeyError: if ``loss_type`` is neither ``'batch'`` nor ``'epoch'``.
        """
        iters = range(len(self.losses[loss_type]))
        # Create a new figure.
        plt.figure()
        # Training accuracy (red curve).
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train acc')
        # Training loss (green curve).
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            # Validation accuracy (blue curve).
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
            # Validation loss (black curve).
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)  # show grid lines
        plt.xlabel(loss_type)
        plt.ylabel('acc-loss')  # axis labels
        plt.legend(loc="upper right")  # legend position
        plt.show()