-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path evaluate_test_set_predictions.py
executable file
·70 lines (50 loc) · 1.77 KB
/
evaluate_test_set_predictions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# evaluate_test_set_predictions.py
# Copyright 2019 Robert Jones [email protected] Craic Computing LLC
# This software is made freely available under the terms of the MIT license
# given a test file and a file of prediction results, report the record ID, the label and the prediction
# TSV input format is
#
# <id> <label> <arbitrary char> <text>
# Results file format is
#
# <probability state 0> <probability state 1>
import os
import argparse
def _read_labels(path):
    """Return ``(ids, labels)`` parsed from a training-format TSV file.

    Each non-blank line must contain at least two tab-separated fields:
    the record ID and the label. A label of exactly "1" maps to 1; any
    other value maps to 0. Blank lines are skipped.

    Raises:
        ValueError: if a non-blank line has fewer than two fields.
    """
    ids = []
    labels = []
    with open(path, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank / trailing lines instead of crashing
            fields = line.rstrip().split("\t")
            if len(fields) < 2:
                raise ValueError("{}:{:d}: expected at least <id>\\t<label>, got {!r}".format(
                    path, line_no, line.rstrip()))
            ids.append(fields[0])
            labels.append(1 if fields[1] == "1" else 0)
    return ids, labels


def _read_predictions(path):
    """Return the predicted class (0 or 1) for each line of a results file.

    Each non-blank line holds tab-separated probabilities for state 0 and
    state 1. The prediction is 1 only when P(state 1) is strictly greater
    than P(state 0), so ties resolve to 0. Blank lines are skipped.

    Raises:
        ValueError: if a line has fewer than two fields or a field is not
            a valid float.
    """
    predictions = []
    with open(path, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank / trailing lines instead of crashing
            fields = line.rstrip().split("\t")
            if len(fields) < 2:
                raise ValueError("{}:{:d}: expected <prob 0>\\t<prob 1>, got {!r}".format(
                    path, line_no, line.rstrip()))
            neg_prob = float(fields[0])
            pos_prob = float(fields[1])
            predictions.append(1 if pos_prob > neg_prob else 0)
    return predictions


def main(argv=None):
    """Compare test-set labels with predictions and report accuracy.

    Prints one tab-separated line per record: ``id``, true label,
    predicted label, and ``*`` when they agree. The final line has the form
    ``n <count> correct <count> <accuracy to 3 decimals>``.

    Args:
        argv: optional list of command-line arguments. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Exits via ``ArgumentParser.error`` (status 2) in three cases: an input
    line is malformed, the two files hold different numbers of records, or
    the TSV has no records.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--tsv", required=True, help="path to training format TSV file")
    ap.add_argument("--results", required=True, help="path to prediction results")
    args = vars(ap.parse_args(argv))
    test_tsv_file = args["tsv"]
    results_file = args["results"]

    try:
        ids, labels = _read_labels(test_tsv_file)
        results = _read_predictions(results_file)
    except (IOError, ValueError) as err:
        ap.error(str(err))

    # Records are paired by position, so the counts must match exactly;
    # otherwise the comparison is meaningless (or would index past the end).
    if len(results) != len(ids):
        ap.error("{} has {:d} records but {} has {:d} predictions".format(
            test_tsv_file, len(ids), results_file, len(results)))

    n = len(ids)
    if n == 0:
        ap.error("no records found in {}".format(test_tsv_file))

    n_correct = 0
    for record_id, label, result in zip(ids, labels, results):
        flag = ""
        if label == result:
            flag = "*"
            n_correct += 1
        print("{:s}\t{:d}\t{:d}\t{:s}".format(record_id, label, result, flag))
    # float() keeps true division even under Python 2.
    print("n {:d} correct {:d} {:.3f}".format(n, n_correct, float(n_correct) / n))
if __name__ == "__main__":
    # Script entry point: run the evaluation only when executed directly,
    # not when this module is imported.
    main()