Skip to content

Commit 1027245

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d8615d3 commit 1027245

File tree

1 file changed

+68
-29
lines changed

1 file changed

+68
-29
lines changed

notebooks/build_engine.py

+68-29
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@
1818
import sys
1919

2020
from numpy.core.fromnumeric import trace
21+
2122
sys.path.append("./")
2223

23-
import logging
2424
import argparse
25+
import logging
26+
import traceback
2527

2628
import numpy as np
27-
import tensorrt as trt
28-
import pycuda.driver as cuda
2929
import pycuda.autoinit
30-
import traceback
31-
30+
import pycuda.driver as cuda
31+
import tensorrt as trt
3232
from yolort.v5.utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages
3333

3434
logging.basicConfig(level=logging.INFO)
@@ -54,8 +54,10 @@ def __init__(self, calib_shape=None, calib_dtype=None) -> None:
5454
self.shape = (self.batch_size, 3, *calib_shape)
5555
self.num_images = len(self.dataset)
5656
self.image_index = 0
57-
58-
def get_batch(self, ):
57+
58+
def get_batch(
59+
self,
60+
):
5961
return iter(self.dataset)
6062

6163

@@ -73,7 +75,7 @@ def __init__(self, cache_file):
7375
self.image_batcher: ImageBatcher = None
7476
self.batch_allocation = None
7577
self.batch_generator = None
76-
78+
7779
def set_image_batcher(self, image_batcher: ImageBatcher):
7880
"""
7981
Define the image batcher to use, if any. If using only the cache file, an image batcher doesn't need
@@ -111,16 +113,20 @@ def get_batch(self, names):
111113
image = image[np.newaxis, :, :, :]
112114
batch, _, _, _ = image.shape
113115
self.image_batcher.image_index += 1
114-
115-
log.info("Calibrating image {} / {}".format(self.image_batcher.image_index, self.image_batcher.num_images))
116+
117+
log.info(
118+
"Calibrating image {} / {}".format(
119+
self.image_batcher.image_index, self.image_batcher.num_images
120+
)
121+
)
116122
cuda.memcpy_htod(self.batch_allocation, np.ascontiguousarray(batch))
117123
return [int(self.batch_allocation)]
118124
except StopIteration:
119125
log.info("Finished calibration batches")
120126
return None
121127
except Exception:
122128
traceback.print_exc()
123-
129+
124130
def read_calibration_cache(self):
125131
"""
126132
Overrides from trt.IInt8EntropyCalibrator2.
@@ -171,7 +177,7 @@ def create_network(self, onnx_path):
171177
Parse the ONNX graph and create the corresponding TensorRT network definition.
172178
:param onnx_path: The path to the ONNX graph to load.
173179
"""
174-
network_flags = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
180+
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
175181

176182
self.network = self.builder.create_network(network_flags)
177183
self.parser = trt.OnnxParser(self.network, self.trt_logger)
@@ -196,8 +202,16 @@ def create_network(self, onnx_path):
196202
assert self.batch_size > 0
197203
self.builder.max_batch_size = self.batch_size
198204

199-
def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=25000,
200-
calib_batch_size=8, calib_preprocessor=None):
205+
def create_engine(
206+
self,
207+
engine_path,
208+
precision,
209+
calib_input=None,
210+
calib_cache=None,
211+
calib_num_images=25000,
212+
calib_batch_size=8,
213+
calib_preprocessor=None,
214+
):
201215
"""
202216
Build the TensorRT engine and serialize it to disk.
203217
:param engine_path: The path where to serialize the engine to.
@@ -229,9 +243,7 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No
229243
if not os.path.exists(calib_cache):
230244
calib_shape = [calib_batch_size] + list(inputs[0].shape[1:])
231245
calib_dtype = trt.nptype(inputs[0].dtype)
232-
self.config.int8_calibrator.set_image_batcher(
233-
ImageBatcher(calib_shape, calib_dtype)
234-
)
246+
self.config.int8_calibrator.set_image_batcher(ImageBatcher(calib_shape, calib_dtype))
235247

236248
with self.builder.build_engine(self.network, self.config) as engine:
237249
with open(engine_path, "wb") as f:
@@ -242,26 +254,53 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No
242254
def main(args):
243255
builder = EngineBuilder(args.verbose)
244256
builder.create_network(args.onnx)
245-
builder.create_engine(args.engine, args.precision, args.calib_input, args.calib_cache, args.calib_num_images,
246-
args.calib_batch_size, args.calib_preprocessor)
257+
builder.create_engine(
258+
args.engine,
259+
args.precision,
260+
args.calib_input,
261+
args.calib_cache,
262+
args.calib_num_images,
263+
args.calib_batch_size,
264+
args.calib_preprocessor,
265+
)
247266

248267

249268
if __name__ == "__main__":
250269
parser = argparse.ArgumentParser()
251270
parser.add_argument("-o", "--onnx", help="The input ONNX model file to load")
252271
parser.add_argument("-e", "--engine", help="The output path for the TRT engine")
253-
parser.add_argument("-p", "--precision", default="fp16", choices=["fp32", "fp16", "int8"],
254-
help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'")
272+
parser.add_argument(
273+
"-p",
274+
"--precision",
275+
default="fp16",
276+
choices=["fp32", "fp16", "int8"],
277+
help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'",
278+
)
255279
parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output")
256280
parser.add_argument("--calib_input", help="The directory holding images to use for calibration")
257-
parser.add_argument("--calib_cache", default="./calibration.cache",
258-
help="The file path for INT8 calibration cache to use, default: ./calibration.cache")
259-
parser.add_argument("--calib_num_images", default=10, type=int,
260-
help="The maximum number of images to use for calibration, default: 25000")
261-
parser.add_argument("--calib_batch_size", default=1, type=int,
262-
help="The batch size for the calibration process, default: 1")
263-
parser.add_argument("--calib_preprocessor", default="V2", choices=["V1", "V1MS", "V2"],
264-
help="Set the calibration image preprocessor to use, either 'V2', 'V1' or 'V1MS', default: V2")
281+
parser.add_argument(
282+
"--calib_cache",
283+
default="./calibration.cache",
284+
help="The file path for INT8 calibration cache to use, default: ./calibration.cache",
285+
)
286+
parser.add_argument(
287+
"--calib_num_images",
288+
default=10,
289+
type=int,
290+
help="The maximum number of images to use for calibration, default: 25000",
291+
)
292+
parser.add_argument(
293+
"--calib_batch_size",
294+
default=1,
295+
type=int,
296+
help="The batch size for the calibration process, default: 1",
297+
)
298+
parser.add_argument(
299+
"--calib_preprocessor",
300+
default="V2",
301+
choices=["V1", "V1MS", "V2"],
302+
help="Set the calibration image preprocessor to use, either 'V2', 'V1' or 'V1MS', default: V2",
303+
)
265304
args = parser.parse_args()
266305
if not all([args.onnx, args.engine]):
267306
parser.print_help()

0 commit comments

Comments
 (0)