-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
28 lines (20 loc) · 3.12 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from allennlp.training.metrics import BooleanAccuracy, CategoricalAccuracy
import torch
# Instantiate the two allennlp accuracy metrics that this script compares:
# acc1 is fed predicted/gold index tensors, acc2 the raw score matrix.
acc1, acc2 = BooleanAccuracy(), CategoricalAccuracy()
# Gold labels: 32 rows, each a one-hot vector over 3 columns
# (only columns 0 and 1 are ever hot in this sample).
t = [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
# Predicted scores: 32 rows x 3 columns, every value in [0, 1].
# Rows do not sum to 1, so these look like independent per-class scores
# (e.g. sigmoid outputs) rather than a softmax distribution -- TODO confirm.
p = [[0.03937098756432533, 0.033875323832035065, 0.05988157168030739], [0.9998742341995239, 0.999727189540863, 0.9997968077659607], [0.07989044487476349, 0.4079543650150299, 0.0020040522795170546], [0.21875324845314026, 0.9538058638572693, 0.016182225197553635], [0.18089905381202698, 0.17723192274570465, 6.718063900734705e-07], [0.7318723201751709, 0.9755846858024597, 0.8627891540527344], [0.9454143047332764, 0.9724324941635132, 0.9989460110664368], [0.011294073425233364, 0.027126627042889595, 0.00047193223144859076], [0.793698251247406, 0.11089510470628738, 0.9887107014656067], [0.9415146112442017, 0.9599100947380066, 0.9392073154449463], [0.9976443648338318, 0.9777897596359253, 0.981238842010498], [0.9845157861709595, 0.9840825796127319, 0.9996356964111328], [0.507988452911377, 0.47337979078292847, 0.0045616538263857365], [0.9517683386802673, 0.3026416003704071, 0.9998192191123962], [0.9398511648178101, 0.9846286773681641, 0.9881545305252075], [0.45978933572769165, 0.8496305346488953, 0.21800664067268372], [0.14222687482833862, 0.686724841594696, 0.0027226004749536514], [0.9919804334640503, 0.9903678894042969, 0.9939683675765991], [0.9405227303504944, 0.9095374941825867, 0.999981164932251], [0.9996458292007446, 0.9999607801437378, 0.9987685084342957], [0.9981439113616943, 0.9996436834335327, 0.999957799911499], [0.9810636043548584, 0.9899020195007324, 0.14351676404476166], [0.035644277930259705, 0.19097495079040527, 0.018681759014725685], [0.07512019574642181, 0.19563919305801392, 0.06413274258375168], [0.6643369793891907, 0.88751620054245, 0.002016345737501979], [0.7577943205833435, 0.9865579009056091, 0.11876775324344635], [0.8203856945037842, 0.8323215246200562, 0.6869458556175232], [0.8889847993850708, 0.9707444906234741, 0.9969865679740906], [0.9793078303337097, 0.9979743361473083, 0.9975119829177856], [0.657213568687439, 0.46762025356292725, 0.9853191375732422], [0.3546280264854431, 0.4987420439720154, 0.9012939929962158], [0.21208778023719788, 
0.9804193377494812, 0.13975170254707336]]
# Per-example mask, written out as all ones (every example kept).
m = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
# Convert the literal lists to tensors. `t` keeps torch's inferred dtype;
# the scores and the mask are built directly as float32.
t = torch.tensor(t)
p = torch.tensor(p, dtype=torch.float32)
m = torch.tensor(m, dtype=torch.float32)
# Multiply the mask by zero in place, so every mask entry is 0 for both
# metric calls below.
m.mul_(0)

# Reduce each row to the column index of its largest entry, kept as floats.
tl = t.argmax(dim=-1).float()
pl = p.argmax(dim=-1).float()

# Feed predicted/gold indices plus the zeroed mask to the first metric.
acc1(pl, tl, mask=m)

print(tl)
print(pl)
# Hand-computed fraction of matching positions over ALL rows; the mask is
# intentionally not applied here, for comparison with the metric outputs.
print(tl.eq(pl).sum().float() / tl.size(0))
print(acc1.get_metric(reset=True))

# The second metric takes the raw score matrix rather than argmax indices.
acc2(p, tl, mask=m)
print(acc2.get_metric(reset=True))