Upgrade perf_run script to support TRT 10 and fix some issues #3650

Open · wants to merge 5 commits into base: main
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -36,6 +36,8 @@ def constant_fold(
# The constants are created on CPU to save GPU memory for TensorRT compilation.
# For TRT INetwork construction the constants are moved to CPU in get_attr call.
for node, constant in cf.node_replacements.items():
if node.target == torch.ops.aten.embedding.default:
continue
replace_node_with_constant(
gm, node, torch.nn.Parameter(constant, requires_grad=False)
)
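
The new guard skips `aten.embedding` results during constant folding so the op itself reaches the TensorRT converter instead of being replaced by a baked-in constant. A sketch of the same carve-out written with an extensible exclusion set (the set is hypothetical; the PR checks only `aten.embedding`, and `cf`, `gm`, and `replace_node_with_constant` come from the surrounding pass):

```python
import torch

# Hypothetical exclusion set; the PR hard-codes aten.embedding only.
FOLD_EXCLUDED_TARGETS = {torch.ops.aten.embedding.default}

for node, constant in cf.node_replacements.items():
    if node.target in FOLD_EXCLUDED_TARGETS:
        continue  # leave the node in the graph for the TRT converter
    replace_node_with_constant(
        gm, node, torch.nn.Parameter(constant, requires_grad=False)
    )
```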
@@ -2,7 +2,6 @@

import logging
from contextlib import nullcontext
from tempfile import tempdir
from typing import Any, Dict, List, Optional, Sequence, Tuple

import tensorrt as trt
@@ -539,7 +538,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:

with tempfile.TemporaryDirectory() as tmpdir:
self.cudagraph.debug_dump(
f"{tempdir}/{self.name}_cudagraph.dot"
f"{tmpdir}/{self.name}_cudagraph.dot"
)

self.cudagraph.replay() # type: ignore
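
The removed import is the root cause of this fix: `tempfile.tempdir` is a module-level attribute that defaults to `None`, so the old f-string produced paths like `None/..._cudagraph.dot` instead of using the `TemporaryDirectory` created one line above. A minimal sketch of the corrected pattern:

```python
import tempfile

# tempfile.tempdir is just a module attribute (often None), not a new directory.
with tempfile.TemporaryDirectory() as tmpdir:
    # tmpdir is the path of a real, freshly created directory
    dot_path = f"{tmpdir}/example_cudagraph.dot"
```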
4 changes: 2 additions & 2 deletions tools/perf/README.md
@@ -9,8 +9,6 @@ This is a comprehensive Python benchmark suite to run perf runs using different
5. TensorRT


Note: Please note that for ONNX models, user can convert the ONNX model to TensorRT serialized engine and then use this package.

## Prerequisite

Benchmark scripts depend on the following Python packages in addition to the requirements.txt packages
@@ -47,13 +45,15 @@ Here are the list of `CompileSpec` options that can be provided directly to comp
* `--backends` : Comma separated string of backends. Eg: torch, torch_compile, dynamo, tensorrt
* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `dynamo` or `torch_compile`, the input should be a Pytorch module (instead of a torchscript module).
* `--model_torch` : Name of the PyTorch model file (optional, only necessary if `dynamo` or `torch_compile` is a chosen backend)
* `--onnx` : ONNX model file, which bypasses the step of exporting ONNX from `model_torch`. If this argument is provided, the ONNX model is converted directly to a TRT engine
* `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
* `--batch_size` : Batch size
* `--precision` : Comma separated list of precisions to build TensorRT engine Eg: fp32,fp16
* `--device` : Device ID
* `--truncate` : Truncate long and double weights in the network in Torch-TensorRT
* `--is_trt_engine` : Boolean flag to be enabled if the model file provided is a TensorRT engine.
* `--report` : Path of the output file where performance summary is written.
* `--optimization_level` : Builder optimization level for TensorRT (0 to 5; 5 applies the most optimization, default 3).

Eg:
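
A hypothetical invocation exercising the new `--onnx` and `--optimization_level` flags (the model file names are placeholders):

```sh
python perf_run.py --backends dynamo,tensorrt \
  --model_torch resnet50.pt \
  --onnx resnet50.onnx \
  --inputs "(1, 3, 224, 224)@fp32" \
  --batch_size 1 \
  --precision fp16 \
  --optimization_level 3 \
  --report perf_report.txt
```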

162 changes: 111 additions & 51 deletions tools/perf/perf_run.py
@@ -174,8 +174,7 @@ def run_ts_trt(model, input_tensors, params, precision, batch_size):
compile_settings = {
"inputs": input_tensors,
"enabled_precisions": {precision_to_dtype(precision)},
"truncate_long_and_double": params.get("truncate", False),
"use_python_runtime": params.get("use_python_runtime", False),
"truncate_double": params.get("truncate", False),
}

if precision == "int8":
@@ -274,8 +273,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
ir="dynamo",
enabled_precisions={precision_to_dtype(precision)},
min_block_size=params.get("min_block_size", 1),
debug=False,
truncate_long_and_double=params.get("truncate", False),
truncate_double=params.get("truncate", False),
immutable_weights=params.get("immutable_weights", True),
strip_engine_weights=params.get("strip_engine_weights", False),
refit_identical_engine_weights=params.get(
@@ -284,6 +282,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
cache_built_engines=params.get("cache_built_engines", False),
reuse_cached_engines=params.get("reuse_cached_engines", False),
use_python_runtime=params.get("use_python_runtime", False),
optimization_level=params.get("optimization_level", 3),
)
end_compile = timeit.default_timer()
compile_time_s = end_compile - start_compile
@@ -437,61 +436,106 @@ def run_tensorrt(
precision,
batch_size=1,
):
# Export an ONNX model and convert to TRT
torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
success = parser.parse_from_file("./tmp.onnx")
if not success:
raise ValueError("ONNX conversion failed")

config = builder.create_builder_config()
if precision == "fp16":
config.set_flag(trt.BuilderFlag.FP16)
start_compile = timeit.default_timer()
serialized_engine = builder.build_serialized_network(network, config)
end_compile = timeit.default_timer()
compile_time_s = end_compile - start_compile
compile_time_s = 0
if params["is_trt_engine"]:
serialized_engine = model
else:
if params["onnx"]:
onnx_path = params["onnx"]
else:
onnx_path = "./onnx-trt.onnx"
torch.onnx.export(model, tuple(input_tensors), onnx_path, dynamo=True)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
success = parser.parse_from_file(onnx_path)
if not success:
raise ValueError("ONNX conversion failed")

config = builder.create_builder_config()
if precision == "fp16":
config.set_flag(trt.BuilderFlag.FP16)
config.builder_optimization_level = params.get("optimization_level", 3)
start_compile = timeit.default_timer()
serialized_engine = builder.build_serialized_network(network, config)
end_compile = timeit.default_timer()
compile_time_s = end_compile - start_compile
# Deserialize the TensorRT engine
with trt.Runtime(logger) as runtime:
engine = runtime.deserialize_cuda_engine(serialized_engine)

print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
iters = params.get("iterations", 20)

# Compiling the bindings
bindings = engine.num_bindings * [None]
k = 0
for idx, _ in enumerate(bindings):
dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
shape = tuple(engine.get_binding_shape(idx))
device = torch_device_from_trt(engine.get_location(idx))
if not engine.binding_is_input(idx):
# Output bindings
output = torch.empty(size=shape, dtype=dtype, device=device)
bindings[idx] = output.data_ptr()
else:
# Input bindings
bindings[idx] = input_tensors[k].data_ptr()
k += 1
start_time = timeit.default_timer()
# Get I/O tensor information using TensorRT 10 API
input_names = []
output_names = []
output_dtypes = []
output_shapes = []

for i in range(engine.num_io_tensors):
tensor_name = engine.get_tensor_name(i)
tensor_mode = engine.get_tensor_mode(tensor_name)
tensor_dtype = engine.get_tensor_dtype(tensor_name)
tensor_shape = engine.get_tensor_shape(tensor_name)

if tensor_mode == trt.TensorIOMode.INPUT:
input_names.append(tensor_name)
else: # trt.TensorIOMode.OUTPUT
output_names.append(tensor_name)
output_dtypes.append(torch_dtype_from_trt(tensor_dtype))
output_shapes.append(tuple(tensor_shape))

# Create output tensors
output_tensors = []
for i, (shape, dtype) in enumerate(zip(output_shapes, output_dtypes)):
output = torch.empty(size=shape, dtype=dtype, device="cuda")
output_tensors.append(output)

timings = []
with engine.create_execution_context() as context:
# Set input tensor addresses
for i, (input_name, input_tensor) in enumerate(zip(input_names, input_tensors)):
context.set_tensor_address(input_name, input_tensor.data_ptr())

# Set output tensor addresses
for output_name, output_tensor in zip(output_names, output_tensors):
context.set_tensor_address(output_name, output_tensor.data_ptr())

# Create a dedicated stream for TensorRT execution
dedicated_stream = torch.cuda.Stream()
current_stream = torch.cuda.current_stream()

setup_time = timeit.default_timer()

# Warm up
for i in range(WARMUP_ITER):
context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
# Wait for current stream to finish
dedicated_stream.wait_stream(current_stream)
context.execute_async_v3(dedicated_stream.cuda_stream)
# Wait for TensorRT stream to finish
current_stream.wait_stream(dedicated_stream)
torch.cuda.synchronize()

infer_start_time = timeit.default_timer()
# Performance measurement
for i in range(iters):
start_time = timeit.default_timer()
context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
# Wait for current stream to finish
dedicated_stream.wait_stream(current_stream)
context.execute_async_v3(dedicated_stream.cuda_stream)
# Wait for TensorRT stream to finish
current_stream.wait_stream(dedicated_stream)
torch.cuda.synchronize()
end_time = timeit.default_timer()
meas_time = end_time - start_time
timings.append(meas_time)

end_time = timeit.default_timer()

# to compare against torch-trt dynamo apples to apples
infer_time = (end_time - infer_start_time + setup_time - start_time) / iters
timings.append(infer_time)

recordStats("TensorRT", timings, precision, batch_size, compile_time_s)
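
The binding-index APIs used in the removed code (`num_bindings`, `get_binding_dtype`, `binding_is_input`, `execute_async_v2`) no longer exist in TensorRT 10, which is why this hunk rewrites I/O handling around named tensors and `execute_async_v3`. A condensed, self-contained sketch of the new pattern (function and variable names are illustrative, not the PR's):

```python
import tensorrt as trt
import torch

def run_trt10_once(engine: trt.ICudaEngine, inputs: list) -> dict:
    """Execute a deserialized TRT 10 engine once; `inputs` are CUDA tensors
    in the engine's input order (illustrative helper, not the PR's code)."""
    trt_to_torch = {trt.float32: torch.float32, trt.float16: torch.float16,
                    trt.int32: torch.int32, trt.int64: torch.int64}
    outputs = {}
    with engine.create_execution_context() as context:
        next_input = 0
        for i in range(engine.num_io_tensors):  # replaces engine.num_bindings
            name = engine.get_tensor_name(i)
            if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                context.set_tensor_address(name, inputs[next_input].data_ptr())
                next_input += 1
            else:
                out = torch.empty(
                    tuple(engine.get_tensor_shape(name)),
                    dtype=trt_to_torch[engine.get_tensor_dtype(name)],
                    device="cuda",
                )
                context.set_tensor_address(name, out.data_ptr())
                outputs[name] = out
        stream = torch.cuda.Stream()                  # dedicated TRT stream
        stream.wait_stream(torch.cuda.current_stream())
        context.execute_async_v3(stream.cuda_stream)  # replaces execute_async_v2
        torch.cuda.current_stream().wait_stream(stream)
        torch.cuda.synchronize()
    return outputs
```

Enqueueing on a dedicated stream and synchronizing with `wait_stream` on both sides mirrors how the Torch-TensorRT runtime schedules work, keeping the benchmark comparable to the dynamo path; note that the reported per-iteration time above also folds the one-time I/O setup cost (`setup_time - start_time`) into the average, per the "apples to apples" comment.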

@@ -504,7 +548,6 @@ def run(
params,
precision,
batch_size=1,
is_trt_engine=False,
model_torch=None,
):
for backend in backends:
@@ -523,7 +566,7 @@
print("int8 precision expects calibration cache file for inference")
return False

if (model is None) and (backend in ("tensorrt", "ts_trt", "all")):
if (model is None) and (backend in ("ts_trt", "all")):
warnings.warn(
f"Requested backend {backend} without specifying a TorchScript Model, "
+ "skipping this backend"
Expand All @@ -547,11 +590,10 @@ def run(
batch_size,
)
run_tensorrt(
model,
model_torch,
input_tensors,
params,
precision,
is_trt_engine,
batch_size,
)
run_dynamo(model_torch, input_tensors, params, precision, batch_size)
@@ -604,6 +646,12 @@ def run(
default="",
help="Name of torch model file",
)
arg_parser.add_argument(
"--onnx",
type=str,
default="",
help="ONNX model file which helps bypass the step of exporting ONNX from torchscript model. If this argument is provided, the ONNX will be directly converted to TRT engine",
)
arg_parser.add_argument(
"--inputs",
type=str,
@@ -643,6 +691,12 @@ def run(
action="store_true",
help="Truncate long and double weights in the network in Torch-TensorRT",
)
arg_parser.add_argument(
"--optimization_level",
type=int,
default=3,
help="Builder optimization level for TensorRT",
)
arg_parser.add_argument(
"--is_trt_engine",
action="store_true",
@@ -702,8 +756,13 @@ def run(

# Load TorchScript model, if provided
if os.path.exists(model_name):
print("Loading user provided torchscript model: ", model_name)
model = torch.jit.load(model_name).cuda().eval()
if params["is_trt_engine"]:
with open(model_name, "rb") as f:
model = f.read()
print("Loading user provided trt engine: ", model_name)
else:
print("Loading user provided torchscript model: ", model_name)
model = torch.jit.load(model_name).cuda().eval()

# Load PyTorch Model, if provided
if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
@@ -719,7 +778,9 @@ def run(
)

backends = parse_backends(params["backends"])
if ("dynamo" in backends or "torch_compile" in backends) and (model_torch is None):
if any(
backend in ["dynamo", "torch_compile", "tensorrt"] for backend in backends
) and (model_torch is None):
raise ValueError(
"No Pytorch model (nn.Module) is provided for torchdynamo compilation. Please provide a pytorch model using --model_torch argument"
)
Expand All @@ -746,7 +807,6 @@ def run(
params,
precision,
batch_size,
is_trt_engine,
model_torch=model_torch,
)

3 changes: 1 addition & 2 deletions tools/perf/requirements.txt
@@ -4,6 +4,5 @@ pyyaml
onnx
pandas
transformers
diffusers==0.21.4
diffusers
timm==0.9.8

2 changes: 2 additions & 0 deletions tools/perf/utils.py
@@ -176,6 +176,8 @@ def torch_dtype_from_trt(dtype):
return torch.bool
elif dtype == trt.int32:
return torch.int32
elif dtype == trt.int64:
return torch.int64
elif dtype == trt.float16:
return torch.float16
elif dtype == trt.float32: