forked from MaciejMazurowski/thyroid-us
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
73 lines (55 loc) · 1.88 KB
/
test.py
File metadata and controls
73 lines (55 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import csv
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import tensorflow as tf
from keras import backend as K
from keras.models import load_model
from sklearn.metrics import roc_auc_score, roc_curve
from data import test_data, test_pids
from focal_loss import focal_loss
from plots import plot_roc
# Module-level settings for the evaluation run:
#   checkpoints_dir -- directory that predict() joins with "weights.h5" to
#                      locate the saved Keras model.
#   batch_size      -- batch size passed to model.predict().
#   nb_categories   -- length of the zero vector each case's summed
#                      prediction starts from in test().
checkpoints_dir, batch_size, nb_categories = "/data/test/checkpoints/", 128, 1
def predict():
    """Run the saved network over the test set.

    Loads ``weights.h5`` from ``checkpoints_dir``, registering the custom
    ``focal_loss_fixed`` loss so Keras can deserialize the model. It then
    predicts on the images returned by ``test_data()`` using ``batch_size``.

    Returns:
        tuple: ``(preds[0], y_test[0])`` -- element 0 of the model output and
        element 0 of the targets. NOTE(review): this presumably selects the
        first head of a multi-output model and its matching target array;
        confirm against the model definition in the training code.
    """
    model_file = os.path.join(checkpoints_dir, "weights.h5")
    loss_objects = {"focal_loss_fixed": focal_loss()}
    model = load_model(model_file, custom_objects=loss_objects)

    images, labels = test_data()
    outputs = model.predict(images, batch_size=batch_size, verbose=1)

    first_output = outputs[0]
    first_labels = labels[0]
    return first_output, first_labels
def _aggregate_by_case(pids, predictions, targets):
    """Group image-level results by case ID.

    Predictions sharing a case ID are added together, starting from a zero
    vector of length ``nb_categories``. For the target, the last one seen for
    each ID is kept. Dict insertion order follows first appearance of each ID.

    Returns:
        tuple: ``(case_scores, case_targets)`` dicts keyed by case ID.
    """
    case_scores = {}
    case_targets = {}
    for pid, pred, target in zip(pids, predictions, targets):
        # NOTE(review): scores are summed, not averaged, so a case with more
        # images can get a larger score. Confirm this is intended before
        # comparing per-case scores across cases.
        case_scores[pid] = case_scores.get(pid, np.zeros(nb_categories)) + pred
        case_targets[pid] = target
    return case_scores, case_targets


def _write_predictions_csv(path, ids, scores, truths):
    """Write one CSV row per case: ID (leading zeros stripped), score, label."""
    # The csv module requires newline="" on the file object; otherwise every
    # row is followed by a blank line on platforms that translate "\n".
    with open(path, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["ID", "Prediction", "Cancer"])
        for pid, score, truth in zip(ids, scores, truths):
            # presumably pids are zero-padded strings -- TODO confirm in data.py
            writer.writerow([pid.lstrip("0"), score, truth[0]])


def test():
    """Evaluate the network per case, then write and plot the results.

    Image-level predictions from :func:`predict` are grouped by the case IDs
    from ``test_pids()``, which must list one ID per test image in the same
    order. Per-case results are written to ``./predictions_test.csv``, and an
    ROC plot is saved as ``roc_test.png`` via ``plot_roc``.

    Raises:
        ValueError: if the number of case IDs differs from the number of
            predictions or targets (previously an IndexError or a silent
            truncation).
    """
    predictions, targets = predict()
    pids = test_pids()
    if not len(pids) == len(predictions) == len(targets):
        raise ValueError(
            "test_pids() returned {} IDs but there are {} predictions and "
            "{} targets".format(len(pids), len(predictions), len(targets))
        )

    case_scores, case_targets = _aggregate_by_case(pids, predictions, targets)

    y_id = list(case_scores)
    y_pred = [case_scores[pid][0] for pid in y_id]
    y_true = [case_targets[pid] for pid in y_id]

    _write_predictions_csv("./predictions_test.csv", y_id, y_pred, y_true)
    plot_roc(y_true, y_pred, figname="roc_test.png")
if __name__ == "__main__":
    # Usage: python test.py <gpu_index>
    # Fail with a usage message instead of an IndexError traceback, and do so
    # before any TensorFlow session or GPU memory is set up.
    if len(sys.argv) < 2:
        sys.exit("usage: {} <gpu_index>".format(sys.argv[0]))

    # TF1-style session: allocate GPU memory on demand (allow_growth) and let
    # ops without a kernel for the requested device be placed elsewhere
    # (allow_soft_placement).
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    K.set_session(sess)

    device = "/gpu:" + sys.argv[1]
    with tf.device(device):
        test()