-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
46 lines (32 loc) · 1.51 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tqdm
import numpy as np
import torch
from torch.autograd import Variable
from utils.datasets import *
from utils.utils import *
def evaluate(model, path, iou_thres, conf_thres, nms_thres, img_size, batch_size):
    """Run ``model`` over the dataset at ``path`` and compute detection metrics.

    Args:
        model: Detection network; switched to eval mode and called as
            ``model(imgs)``.
        path: Forwarded to ``ListDataset`` to locate the validation samples.
        iou_thres: IoU threshold forwarded to ``get_batch_statistics``.
        conf_thres: Confidence threshold forwarded to ``non_max_suppression``.
        nms_thres: NMS threshold forwarded to ``non_max_suppression``.
        img_size: Image size given to ``ListDataset``; also the factor by
            which target boxes are multiplied after ``xywh2xyxy``.
        batch_size: Number of samples per ``DataLoader`` batch.

    Returns:
        The tuple ``(precision, recall, AP, f1, ap_class)`` produced by
        ``ap_per_class``. If no per-sample statistics were collected at all
        (empty dataset, or ``get_batch_statistics`` returned nothing for
        every batch), five empty NumPy arrays are returned instead of
        raising on the unpack below.
    """
    model.eval()  # evaluation mode

    # Load the validation set.
    dataset = ListDataset(path, img_size=img_size, augment=False, multiscale=False)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
        collate_fn=dataset.collate_fn,
    )

    # Use the CUDA tensor type when a GPU is available.
    tensor_type = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    labels = []
    sample_metrics = []
    for _, imgs, targets in tqdm.tqdm(dataloader, desc="Detecting objects"):
        # Column 1 of the targets is collected as the ground-truth labels.
        labels += targets[:, 1].tolist()

        # Convert the box columns with xywh2xyxy, then scale by img_size.
        # NOTE(review): presumably boxes arrive normalized to [0, 1] - confirm
        # against ListDataset.
        targets[:, 2:] = xywh2xyxy(targets[:, 2:])
        targets[:, 2:] *= img_size

        # torch.autograd.Variable is a deprecated no-op wrapper on every
        # PyTorch version that has torch.no_grad(), so a plain tensor is used.
        imgs = imgs.type(tensor_type)

        with torch.no_grad():
            outputs = model(imgs)
            outputs = non_max_suppression(outputs, conf_thres=conf_thres, nms_thres=nms_thres)

        sample_metrics += get_batch_statistics(outputs, targets, iou_threshold=iou_thres)

    if not sample_metrics:
        # Nothing to concatenate: zip(*[]) is empty and the three-way unpack
        # below would raise "not enough values to unpack".
        return tuple(np.array([]) for _ in range(5))

    # Concatenate the per-sample statistics column-wise.
    true_positives, pred_scores, pred_labels = [
        np.concatenate(column, 0) for column in zip(*sample_metrics)
    ]
    precision, recall, AP, f1, ap_class = ap_per_class(
        true_positives, pred_scores, pred_labels, labels
    )
    return precision, recall, AP, f1, ap_class