-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpose_estimation_gpu.py
More file actions
118 lines (92 loc) · 3.82 KB
/
pose_estimation_gpu.py
File metadata and controls
118 lines (92 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.nn as nn
import torch.utils.data
import numpy as np
# from pose_estimation_loader import ImageLoader, DetectionLoader, DetectionProcessor, DataWriter, Mscoco
# from Alpha.opt import opt
from AlphaPose.dataloader import ImageLoader, DetectionLoader, DetectionProcessor, DataWriter, Mscoco
from AlphaPose.fn import getTime
from AlphaPose.pPose_nms import pose_nms, write_json
from yolo.util import write_results, dynamic_write_results
from SPPE.src.main_fast_inference import *
import os
import glob
import sys
from tqdm import tqdm
import time
class PoseEstimator:
    """Detect people and estimate their poses for a set of input images.

    Pipeline: ``ImageLoader`` -> ``DetectionLoader`` -> ``DetectionProcessor``
    -> ``InferenNet_fast`` (pose network, moved to CUDA) -> ``DataWriter``.
    Call :meth:`run` first, then :meth:`save_result` to write the collected
    results to ``output_path`` via ``write_json``.
    """

    def __init__(self, input_path, output_path, pose_batch_size=80):
        """Store the paths and make sure the output directory exists.

        Args:
            input_path: Passed unchanged to ``ImageLoader``; its expected form
                is defined by that loader (TODO confirm: directory or list).
            output_path: Directory handed to ``write_json``; created, including
                any missing parent directories, if it does not exist.
            pose_batch_size: Maximum number of person crops fed to the pose
                network in one forward pass (80 was the old hard-coded value).
        """
        super().__init__()
        # NOTE(review): presumably needed because the loaders start worker
        # processes that use CUDA, which requires the 'spawn' start method.
        # force=True overrides any start method already set in this process.
        torch.multiprocessing.set_start_method('spawn', force=True)
        self.inputpath = input_path
        self.outputpath = output_path
        self.pose_batch_size = pose_batch_size
        # Created by run(); save_result() checks it to give a clear error.
        self.writer = None
        # Per-frame timings of the last run(): detection ('dt'), pose network
        # ('pt'), and CPU transfer + writer hand-off ('pn').
        self.runtime_profile = {'dt': [], 'pt': [], 'pn': []}
        # makedirs also creates missing parents, and exist_ok avoids the
        # check-then-create race of exists() + mkdir().
        os.makedirs(self.outputpath, exist_ok=True)
        torch.cuda.empty_cache()

    def _estimate_heatmaps(self, pose_model, inps):
        """Run *pose_model* on *inps* in chunks of ``self.pose_batch_size``.

        ``Tensor.split`` yields consecutive slices along dim 0 (the last one may
        be shorter), matching the previous manual slicing. Each chunk is moved
        to CUDA before the forward pass. The chunk outputs are concatenated
        along dim 0 and returned (still on the GPU).
        """
        heatmaps = [pose_model(chunk.cuda()) for chunk in inps.split(self.pose_batch_size)]
        return torch.cat(heatmaps)

    def run(self):
        """Process every input image and hand the results to a ``DataWriter``.

        The writer is stored on ``self.writer``. This method waits until
        ``writer.running()`` reports false, then calls ``writer.stop()``.
        Timings are recorded in ``self.runtime_profile``.
        """
        # Load input images
        data_loader = ImageLoader(self.inputpath, batchSize=1, format='yolo').start()
        print('Loading YOLO model..')
        sys.stdout.flush()
        # Load detection loader
        det_loader = DetectionLoader(data_loader, batchSize=1).start()
        det_processor = DetectionProcessor(det_loader).start()
        # Load pose model
        pose_dataset = Mscoco()
        # NOTE(review): 4 * 1 + 1 mirrors the upstream AlphaPose demo; its exact
        # meaning is defined by InferenNet_fast — confirm before changing.
        pose_model = InferenNet_fast(4 * 1 + 1, pose_dataset)
        pose_model.cuda()
        pose_model.eval()
        runtime_profile = {'dt': [], 'pt': [], 'pn': []}
        self.runtime_profile = runtime_profile
        # Init data writer
        self.writer = DataWriter(False).start()
        data_len = data_loader.length()
        with torch.no_grad():
            for _ in tqdm(range(data_len)):
                start_time = getTime()
                (inps, orig_img, im_name, boxes, scores, pt1, pt2) = det_processor.read()
                # basename handles the platform's separators (not only '/').
                image_name = os.path.basename(im_name)
                if boxes is None or boxes.nelement() == 0:
                    # No detections: pass the frame to the writer without pose data.
                    self.writer.save(None, None, None, None, None, orig_img, image_name)
                    continue
                ckpt_time, det_time = getTime(start_time)
                runtime_profile['dt'].append(det_time)
                # Pose estimation
                hm = self._estimate_heatmaps(pose_model, inps)
                ckpt_time, pose_time = getTime(ckpt_time)
                runtime_profile['pt'].append(pose_time)
                hm = hm.cpu()
                self.writer.save(boxes, scores, hm, pt1, pt2, orig_img, image_name)
                ckpt_time, post_time = getTime(ckpt_time)
                runtime_profile['pn'].append(post_time)
        torch.cuda.empty_cache()
        print('===========================> Finish Model Running.')
        # Wait for the writer to finish; sleep between polls instead of
        # spinning so the waiting thread does not burn a full CPU core.
        while self.writer.running():
            time.sleep(0.01)
        self.writer.stop()

    def save_result(self):
        """Write the results collected by :meth:`run` to ``self.outputpath``.

        Returns:
            Whatever ``write_json`` returns.

        Raises:
            RuntimeError: if called before :meth:`run` created the writer.
        """
        if getattr(self, 'writer', None) is None:
            raise RuntimeError('PoseEstimator.run() must be called before save_result()')
        final_result = self.writer.results()
        return write_json(final_result, self.outputpath)