-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.py
36 lines (28 loc) · 1.14 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import argparse
from sklearn import metrics
import network
from dataset import MelSpecEncoded
import splitter as sp
from trainer import TrainCache
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('timestamp', type=str, help='models timestamp')
args = parser.parse_args()
model, params = TrainCache.load(network.MelCNN2d, args.timestamp)
prefixes = ['CatsAndDogs']
if params['n_classes'] == 2:
prefixes = [p + '_binary' for p in prefixes]
dataset = MelSpecEncoded(prefixes)
loaders = sp.TestSplitter(batch_size=dataset.size).get_all(dataset)
data = next(iter(loaders))
x, y = data['x'], data['y']
y_hat = model(x).argmax(axis=1)
y = y.detach().cpu().numpy()
y_hat = y_hat.detach().cpu().numpy()
y = y == dataset.pos_class
y_hat = y_hat == dataset.pos_class
print(f'Accuracy: {metrics.accuracy_score(y, y_hat):.4f}')
print(f'Recall: {metrics.recall_score(y, y_hat):.4f}')
print(f'Precision: {metrics.precision_score(y, y_hat):.4f}')
print(f'Kappa: {metrics.cohen_kappa_score(y, y_hat):.4f}')
print(f'F1: {metrics.f1_score(y, y_hat):.4f}')