-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathonnx_inference.py
82 lines (65 loc) · 2.82 KB
/
onnx_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from configparser import Interpolation
import numpy as np
import os
import cv2
import onnx
import onnxruntime as ort
import argparse
from loguru import logger
# Maps the onnxruntime input type string to the numpy dtype the input is cast to.
# Unknown types fall back to float32 (the usual dtype of exported classifiers).
_ORT_TO_NUMPY = {
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
    "tensor(float16)": np.float16,
}


def _select_providers():
    """Return the onnxruntime execution providers to request, in priority order."""
    if ort.get_device() == "GPU":
        logger.info("Inferencing on GPU")
        # Keep the CPU provider listed as a fallback behind CUDA.
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    logger.info("Inferencing on CPU")
    return ["CPUExecutionProvider"]


def _check_model(model_path):
    """Load the ONNX file and validate it with onnx.checker.

    Best effort: any failure is logged with its traceback and not re-raised,
    so the caller still attempts inference (which reports its own errors).
    """
    try:
        model = onnx.load(model_path)
        logger.info("Checking model...")
        onnx.checker.check_model(model)
        # Pass the graph text as an argument so braces in it are not treated
        # as loguru format placeholders.
        logger.debug("{}", onnx.helper.printable_graph(model.graph))
        logger.info("Model checked...")
    except Exception:
        logger.exception("Error while setting up model...")


def _load_images(image_path, img_size):
    """Read, resize and scale every decodable image in ``image_path``.

    Returns ``(names, images)`` in matching order (sorted by file name). Each
    image is a channels-first array of shape (3, img_size, img_size) with
    values divided by 255. Files OpenCV cannot decode are skipped with a warning.
    """
    names, images = [], []
    for name in sorted(os.listdir(image_path)):
        img = cv2.imread(os.path.join(image_path, name), cv2.IMREAD_COLOR)
        if img is None:  # cv2.imread returns None for unreadable / non-image files
            logger.warning("Skipping {}: not a readable image", name)
            continue
        img = cv2.resize(img, (img_size, img_size))
        # OpenCV gives (height, width, channels); move channels to the front.
        # NOTE(review): channels stay in OpenCV's BGR order and only /255
        # scaling is applied -- confirm this matches the training preprocessing.
        images.append(np.moveaxis(img, -1, 0) / 255.0)
        names.append(name)
    return names, images


@logger.catch()
def main(args):
    """Classify every image in a directory with an ONNX classifier.

    Options read from ``args`` (see ``arguement_parser``):

    * ``model_path`` -- ONNX model file; checked with onnx.checker, then run
      with onnxruntime.
    * ``image_path`` -- directory whose files are read with OpenCV.
    * ``labels`` -- comma-separated class names, indexed by the argmax of the
      model's first output.
    * ``batch_size`` -- images sent to the model per ``run`` call (values
      below 1 are treated as 1).
    * ``img_size`` -- side length each image is resized to.
    * ``target_size`` -- expected number of classes; a warning is logged when
      it differs from the number of labels.

    Prints one ``Image : <name>, Class : <label>`` line per readable image.
    Errors during inference are logged with their traceback, not raised.
    """
    batch_size = max(1, args.batch_size)  # range() rejects a step of 0
    providers = _select_providers()

    labels = [label.strip() for label in args.labels.split(",")]
    if len(labels) != args.target_size:
        logger.warning(
            "{} labels given but target_size is {}", len(labels), args.target_size
        )

    _check_model(args.model_path)

    try:
        logger.info("Running inference...")
        ort_session = ort.InferenceSession(args.model_path, providers=providers)
        # Use the model's declared input name/type instead of assuming "input".
        model_input = ort_session.get_inputs()[0]
        dtype = _ORT_TO_NUMPY.get(model_input.type, np.float32)

        names, images = _load_images(args.image_path, args.img_size)
        if not images:
            logger.warning("No readable images found in {}", args.image_path)
            return

        for start in range(0, len(images), batch_size):
            stop = start + batch_size
            batch = np.stack(images[start:stop]).astype(dtype)
            # First output; presumably (batch, num_classes) scores -- TODO confirm.
            scores = ort_session.run(None, {model_input.name: batch})[0]
            for image_name, row in zip(names[start:stop], scores):
                print(
                    "Image : {0}, Class : {1}".format(
                        image_name, labels[int(np.argmax(row))]
                    )
                )
    except Exception as e:
        # Pass e as a format argument: loguru drops extra args without a {} placeholder.
        logger.exception("Exception occurred : {}", e)
def arguement_parser(argv=None):
    """Parse command-line options for ONNX image-classification inference.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as a normal script invocation does.

    Returns:
        argparse.Namespace with ``model_path``, ``image_path``, ``labels``,
        ``batch_size``, ``img_size`` and ``target_size``.
    """
    parser = argparse.ArgumentParser(description="Run ONNX image-classification inference")
    parser.add_argument('--model_path', type=str, default="/home/sahil/Documents/Classifiers/Timm_Pipeline/classifier.onnx", help="ONNX model path")
    parser.add_argument('--image_path', type=str, default="/home/sahil/Documents/Classifiers/sample_data", help='Path to images')
    parser.add_argument('--labels', type=str, default="buildings, forests, mountains, glacier, street, sea", help='Comma-separated class labels')
    parser.add_argument('--batch_size', type=int, default=4, help='batch size for inference')
    parser.add_argument('--img_size', type=int, default=224, help='Input image size')
    parser.add_argument('--target_size', type=int, default=6, help='Number of classes')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse command-line options, then run inference with them.
    args = arguement_parser()
    main(args)