import sys

from numpy.core.fromnumeric import trace
+
sys.path.append("./")

-import logging
import argparse
+import logging
+import traceback

import numpy as np
-import tensorrt as trt
-import pycuda.driver as cuda
import pycuda.autoinit
-import traceback
-
+import pycuda.driver as cuda
+import tensorrt as trt
from yolort.v5.utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages

logging.basicConfig(level=logging.INFO)
@@ -54,8 +54,10 @@ def __init__(self, calib_shape=None, calib_dtype=None) -> None:
        self.shape = (self.batch_size, 3, *calib_shape)
        self.num_images = len(self.dataset)
        self.image_index = 0
-
-    def get_batch(self, ):
+
+    def get_batch(
+        self,
+    ):
        return iter(self.dataset)

@@ -73,7 +75,7 @@ def __init__(self, cache_file):
        self.image_batcher: ImageBatcher = None
        self.batch_allocation = None
        self.batch_generator = None
-
+
    def set_image_batcher(self, image_batcher: ImageBatcher):
        """
        Define the image batcher to use, if any. If using only the cache file, an image batcher doesn't need
@@ -111,16 +113,20 @@ def get_batch(self, names):
            image = image[np.newaxis, :, :, :]
            batch, _, _, _ = image.shape
            self.image_batcher.image_index += 1
-
-            log.info("Calibrating image {} / {}".format(self.image_batcher.image_index, self.image_batcher.num_images))
+
+            log.info(
+                "Calibrating image {} / {}".format(
+                    self.image_batcher.image_index, self.image_batcher.num_images
+                )
+            )
            cuda.memcpy_htod(self.batch_allocation, np.ascontiguousarray(batch))
            return [int(self.batch_allocation)]
        except StopIteration:
            log.info("Finished calibration batches")
            return None
        except Exception:
            traceback.print_exc()
-
+
    def read_calibration_cache(self):
        """
        Overrides from trt.IInt8EntropyCalibrator2.
@@ -171,7 +177,7 @@ def create_network(self, onnx_path):
        Parse the ONNX graph and create the corresponding TensorRT network definition.
        :param onnx_path: The path to the ONNX graph to load.
        """
-        network_flags = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

        self.network = self.builder.create_network(network_flags)
        self.parser = trt.OnnxParser(self.network, self.trt_logger)
@@ -196,8 +202,16 @@ def create_network(self, onnx_path):
        assert self.batch_size > 0
        self.builder.max_batch_size = self.batch_size

-    def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=25000,
-                      calib_batch_size=8, calib_preprocessor=None):
+    def create_engine(
+        self,
+        engine_path,
+        precision,
+        calib_input=None,
+        calib_cache=None,
+        calib_num_images=25000,
+        calib_batch_size=8,
+        calib_preprocessor=None,
+    ):
        """
        Build the TensorRT engine and serialize it to disk.
        :param engine_path: The path where to serialize the engine to.
@@ -229,9 +243,7 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No
            if not os.path.exists(calib_cache):
                calib_shape = [calib_batch_size] + list(inputs[0].shape[1:])
                calib_dtype = trt.nptype(inputs[0].dtype)
-                self.config.int8_calibrator.set_image_batcher(
-                    ImageBatcher(calib_shape, calib_dtype)
-                )
+                self.config.int8_calibrator.set_image_batcher(ImageBatcher(calib_shape, calib_dtype))

        with self.builder.build_engine(self.network, self.config) as engine:
            with open(engine_path, "wb") as f:
@@ -242,26 +254,53 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No
def main(args):
    builder = EngineBuilder(args.verbose)
    builder.create_network(args.onnx)
-    builder.create_engine(args.engine, args.precision, args.calib_input, args.calib_cache, args.calib_num_images,
-                          args.calib_batch_size, args.calib_preprocessor)
+    builder.create_engine(
+        args.engine,
+        args.precision,
+        args.calib_input,
+        args.calib_cache,
+        args.calib_num_images,
+        args.calib_batch_size,
+        args.calib_preprocessor,
+    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--onnx", help="The input ONNX model file to load")
    parser.add_argument("-e", "--engine", help="The output path for the TRT engine")
-    parser.add_argument("-p", "--precision", default="fp16", choices=["fp32", "fp16", "int8"],
-                        help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'")
+    parser.add_argument(
+        "-p",
+        "--precision",
+        default="fp16",
+        choices=["fp32", "fp16", "int8"],
+        help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'",
+    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output")
    parser.add_argument("--calib_input", help="The directory holding images to use for calibration")
-    parser.add_argument("--calib_cache", default="./calibration.cache",
-                        help="The file path for INT8 calibration cache to use, default: ./calibration.cache")
-    parser.add_argument("--calib_num_images", default=10, type=int,
-                        help="The maximum number of images to use for calibration, default: 25000")
-    parser.add_argument("--calib_batch_size", default=1, type=int,
-                        help="The batch size for the calibration process, default: 1")
-    parser.add_argument("--calib_preprocessor", default="V2", choices=["V1", "V1MS", "V2"],
-                        help="Set the calibration image preprocessor to use, either 'V2', 'V1' or 'V1MS', default: V2")
+    parser.add_argument(
+        "--calib_cache",
+        default="./calibration.cache",
+        help="The file path for INT8 calibration cache to use, default: ./calibration.cache",
+    )
+    parser.add_argument(
+        "--calib_num_images",
+        default=10,
+        type=int,
+        help="The maximum number of images to use for calibration, default: 25000",
+    )
+    parser.add_argument(
+        "--calib_batch_size",
+        default=1,
+        type=int,
+        help="The batch size for the calibration process, default: 1",
+    )
+    parser.add_argument(
+        "--calib_preprocessor",
+        default="V2",
+        choices=["V1", "V1MS", "V2"],
+        help="Set the calibration image preprocessor to use, either 'V2', 'V1' or 'V1MS', default: V2",
+    )
    args = parser.parse_args()
    if not all([args.onnx, args.engine]):
        parser.print_help()
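
Usage sketch (not part of the diff): the reformatted create_engine keeps the same call pattern used in main(), so an INT8 build could also be driven programmatically. The ONNX path, engine path, and calibration directory below are hypothetical placeholders, and the snippet assumes this script is importable as a module.

    # Hypothetical example; file names and paths are placeholders, not from this PR.
    builder = EngineBuilder(verbose=False)
    builder.create_network("yolov5s.onnx")        # placeholder ONNX export
    builder.create_engine(
        "yolov5s-int8.engine",                    # placeholder output engine path
        "int8",
        calib_input="./coco/val2017",             # placeholder calibration image directory
        calib_cache="./calibration.cache",
        calib_num_images=10,
        calib_batch_size=1,
    )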