From d86a0eade6919942d4f45530d3c805d4eeea95ae Mon Sep 17 00:00:00 2001
From: wwqgtxx
Date: Mon, 20 Sep 2021 19:40:36 +0800
Subject: [PATCH 1/2] decrease memory use by removing unnecessary module inference

---
 torch2trt/torch2trt.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py
index 6a33a9ee..ece88d08 100644
--- a/torch2trt/torch2trt.py
+++ b/torch2trt/torch2trt.py
@@ -522,18 +522,16 @@ def torch2trt(module,
         inputs = tuple(inputs)
     if not isinstance(inputs, tuple):
         inputs = (inputs,)
-
-    # run once to get num outputs
-    outputs = module(*inputs)
-    if not isinstance(outputs, tuple) and not isinstance(outputs, list):
-        outputs = (outputs,)
-
     if input_names is None:
         input_names = default_input_names(len(inputs))
-    if output_names is None:
-        output_names = default_output_names(len(outputs))
-
+
     if use_onnx:
+        # run once to get num outputs
+        outputs = module(*inputs)
+        if not isinstance(outputs, tuple) and not isinstance(outputs, list):
+            outputs = (outputs,)
+        if output_names is None:
+            output_names = default_output_names(len(outputs))
 
         f = io.BytesIO()
         torch.onnx.export(module, inputs, f, input_names=input_names, output_names=output_names)
@@ -553,6 +551,8 @@ def torch2trt(module,
             if not isinstance(outputs, tuple) and not isinstance(outputs, list):
                 outputs = (outputs,)
 
+            if output_names is None:
+                output_names = default_output_names(len(outputs))
             ctx.mark_outputs(outputs, output_names)
 
     builder.max_workspace_size = max_workspace_size

From d68b3e0fee5d3a34c185c81ddb37030468444e4c Mon Sep 17 00:00:00 2001
From: wwqgtxx
Date: Tue, 21 Sep 2021 10:54:34 +0800
Subject: [PATCH 2/2] delete inputs/outputs and empty CUDA cache before build_cuda_engine to provide more GPU memory for the build

---
 torch2trt/torch2trt.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py
index ece88d08..e07d37a7 100644
--- a/torch2trt/torch2trt.py
+++ b/torch2trt/torch2trt.py
@@ -575,7 +575,11 @@ def torch2trt(module,
             inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm
         )
 
+    del inputs
+    del outputs
+    torch.cuda.empty_cache()
     engine = builder.build_cuda_engine(network)
+    torch.cuda.empty_cache()
 
     module_trt = TRTModule(engine, input_names, output_names)
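
For context, the call-site usage is unchanged by PATCH 1/2: the patch only drops the extra up-front forward pass on the non-ONNX path and fills in default output_names from the outputs traced inside the conversion context. A minimal usage sketch follows, assuming torchvision is installed; the resnet18 model and the 1x3x224x224 input are illustrative choices, not part of the patches.

# Minimal usage sketch (illustrative model and input shape).
import torch
import torchvision
from torch2trt import torch2trt

model = torchvision.models.resnet18(pretrained=True).cuda().eval()
x = torch.randn(1, 3, 224, 224).cuda()

# With PATCH 1/2 applied, torch2trt no longer runs model(x) up front just to
# count outputs when use_onnx=False; default output_names are resolved inside
# the conversion context instead.
model_trt = torch2trt(model, [x])

with torch.no_grad():
    print(torch.max(torch.abs(model(x) - model_trt(x))))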
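
PATCH 2/2 relies on PyTorch's caching-allocator behavior: deleting the last Python references to CUDA tensors and then calling torch.cuda.empty_cache() returns cached, unused device memory to the driver, so a non-PyTorch consumer such as TensorRT's engine builder can allocate it. Below is a standalone sketch of that pattern, separate from torch2trt; the buffer sizes and the report helper are made up for illustration.

# Standalone sketch of the free-then-empty-cache pattern applied in PATCH 2/2.
import torch

def report(tag):
    # memory_allocated: bytes held by live tensors; memory_reserved: bytes the
    # caching allocator keeps from the driver, including unused cached blocks.
    print(f"{tag}: allocated={torch.cuda.memory_allocated() >> 20} MiB, "
          f"reserved={torch.cuda.memory_reserved() >> 20} MiB")

buffers = [torch.empty(256, 1024, 1024, device="cuda") for _ in range(2)]  # ~2 GiB total
report("before del")

# Drop the last references, then hand unused cached blocks back to the driver
# so another library (e.g. TensorRT during build_cuda_engine) can claim them.
del buffers
torch.cuda.empty_cache()
report("after empty_cache")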