diff --git a/lenet/lenet.py b/lenet/lenet.py
index 593f7bf2..4de6ba97 100644
--- a/lenet/lenet.py
+++ b/lenet/lenet.py
@@ -27,7 +27,7 @@ def load_weights(file):
 
     weight_map = {}
     with open(file, "r") as f:
-        lines = f.readlines()
+        lines = [line.strip() for line in f]
         count = int(lines[0])
         assert count == len(lines) - 1
         for i in range(1, count + 1):
diff --git a/resnet/resnet50.py b/resnet/resnet50.py
index fa3c84a9..9d7ddf48 100644
--- a/resnet/resnet50.py
+++ b/resnet/resnet50.py
@@ -29,7 +29,7 @@ def load_weights(file):
 
     weight_map = {}
     with open(file, "r") as f:
-        lines = f.readlines()
+        lines = [line.strip() for line in f]
         count = int(lines[0])
         assert count == len(lines) - 1
         for i in range(1, count + 1):
@@ -138,7 +138,7 @@ def bottleneck(network, weight_map, input, in_channels, out_channels, stride,
     return relu3
 
 
-def createLenetEngine(maxBatchSize, builder, config, dt):
+def create_engine(maxBatchSize, builder, config, dt):
     weight_map = load_weights(WEIGHT_PATH)
     network = builder.create_network()
 
@@ -233,7 +233,7 @@ def createLenetEngine(maxBatchSize, builder, config, dt):
 def APIToModel(maxBatchSize):
     builder = trt.Builder(TRT_LOGGER)
     config = builder.create_builder_config()
-    engine = createLenetEngine(maxBatchSize, builder, config, trt.float32)
+    engine = create_engine(maxBatchSize, builder, config, trt.float32)
     assert engine
     with open(ENGINE_PATH, "wb") as f:
         f.write(engine.serialize())
diff --git a/tsm/README.md b/tsm/README.md
new file mode 100644
index 00000000..d10e5063
--- /dev/null
+++ b/tsm/README.md
@@ -0,0 +1,66 @@
+# Temporal Shift Module
+
+TSM-R50 from "TSM: Temporal Shift Module for Efficient Video Understanding"
+
+TSM is a widely used action recognition model. This TensorRT implementation is tested with TensorRT 5.1 and TensorRT 7.2.
+
+For the PyTorch implementation, you can refer to [open-mmlab/mmaction2](https://github.com/open-mmlab/mmaction2) or [mit-han-lab/temporal-shift-module](https://github.com/mit-han-lab/temporal-shift-module).
+
+More details about the shift module (the core of TSM) can be found in [test_shift.py](./test_shift.py).
+
+## Tutorial
+
++ A complete example is given in [demo.sh](./demo.sh).
+  + Requirements: `torch>=1.3.0` and `torchvision` installed.
++ Step 1: Train or download TSM-R50 checkpoints from the [official GitHub repo](https://github.com/mit-han-lab/temporal-shift-module) or [MMAction2](https://github.com/open-mmlab/mmaction2).
+  + Supported settings: `num_segments`, `shift_div`, `num_classes`.
+  + Fixed settings: `backbone` (ResNet50), `shift_place` (blockres), `temporal_pool` (False).
++ Step 2: Convert PyTorch checkpoints to TensorRT weights.
+
+```shell
+python gen_wts.py /path/to/pytorch.pth --out-filename /path/to/tensorrt.wts
+```
+
++ Step 3: Modify configs in `tsm_r50.py`.
+
+```python
+BATCH_SIZE = 1
+NUM_SEGMENTS = 8
+INPUT_H = 224
+INPUT_W = 224
+OUTPUT_SIZE = 400
+SHIFT_DIV = 8
+```
+
++ Step 4: Inference with `tsm_r50.py`.
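+
+For example, with the weights produced in Step 2 (the file names below follow [demo.sh](./demo.sh); adjust the paths to your setup):
+
+```shell
+python tsm_r50.py --tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
+    --input-video ./demo.mp4
+```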
+
+```shell
+usage: tsm_r50.py [-h] [--tensorrt-weights TENSORRT_WEIGHTS]
+                  [--input-video INPUT_VIDEO]
+                  [--save-engine-path SAVE_ENGINE_PATH]
+                  [--load-engine-path LOAD_ENGINE_PATH] [--test-mmaction2]
+                  [--mmaction2-config MMACTION2_CONFIG]
+                  [--mmaction2-checkpoint MMACTION2_CHECKPOINT]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --tensorrt-weights TENSORRT_WEIGHTS
+                        Path to TensorRT weights, which is generated by
+                        gen_wts.py
+  --input-video INPUT_VIDEO
+                        Path to local video file
+  --save-engine-path SAVE_ENGINE_PATH
+                        Save engine to local file
+  --load-engine-path LOAD_ENGINE_PATH
+                        Saved engine file path
+  --test-mmaction2      Compare TensorRT results with MMAction2 results
+  --mmaction2-config MMACTION2_CONFIG
+                        Path to MMAction2 config file
+  --mmaction2-checkpoint MMACTION2_CHECKPOINT
+                        Path to MMAction2 checkpoint url or file path
+```
+
+## TODO
+
++ [x] Python shift module
++ [x] Generate wts from official TSM and MMAction2 checkpoints
++ [x] Python API definition
++ [x] Test with MMAction2 demo
++ [x] Tutorial
++ [ ] C++ API definition
diff --git a/tsm/demo.sh b/tsm/demo.sh
new file mode 100644
index 00000000..33b17a5d
--- /dev/null
+++ b/tsm/demo.sh
@@ -0,0 +1,43 @@
+# Step 1: Get checkpoints from mmaction2
+# https://github.com/open-mmlab/mmaction2/tree/master/configs/recognition/tsm
+wget https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x8_50e_kinetics400_rgb/tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth
+
+# Step 2: Convert PyTorch checkpoints to TensorRT weights
+python gen_wts.py tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth --out-filename ./tsm_r50_kinetics400_mmaction2.wts
+
+# Step 3: Skip this step since we use the default settings.
+
+# Step 4: Inference
+# 1) Save a local engine file to `./tsm_r50_kinetics400_mmaction2.trt`.
+python tsm_r50.py \
+    --tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
+    --save-engine-path ./tsm_r50_kinetics400_mmaction2.trt
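+
+# Note: a serialized engine is generally tied to the TensorRT version and the
+# GPU it was built on; if deserialization fails after an upgrade, rebuild the
+# engine from the .wts file above.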
+
+# 2) Predict the recognition result using a single video `demo.mp4`.
+#    Should print `Result class id 6`, aka `arm wrestling`.
+# Download the demo video
+wget https://raw.githubusercontent.com/open-mmlab/mmaction2/master/demo/demo.mp4
+# # use *.wts as input
+# python tsm_r50.py --tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
+#     --input-video ./demo.mp4
+# use the engine file as input
+python tsm_r50.py --load-engine-path ./tsm_r50_kinetics400_mmaction2.trt \
+    --input-video ./demo.mp4
+
+# 3) Optional: Compare inference results with the MMAction2 TSM-R50 model.
+#    MMAction2 has to be installed first, see https://github.com/open-mmlab/mmaction2/blob/master/docs/install.md
+# pip3 install pytest-runner
+# pip3 install mmcv
+# pip3 install mmaction2
+# # use *.wts as input
+# python tsm_r50.py \
+#     --tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
+#     --test-mmaction2 \
+#     --mmaction2-config mmaction2_tsm_r50_config.py \
+#     --mmaction2-checkpoint tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth
+# # use the TensorRT engine as input
+# python tsm_r50.py \
+#     --load-engine-path ./tsm_r50_kinetics400_mmaction2.trt \
+#     --test-mmaction2 \
+#     --mmaction2-config mmaction2_tsm_r50_config.py \
+#     --mmaction2-checkpoint tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth
diff --git a/tsm/gen_wts.py b/tsm/gen_wts.py
new file mode 100644
index 00000000..1c188333
--- /dev/null
+++ b/tsm/gen_wts.py
@@ -0,0 +1,46 @@
+import argparse
+import struct
+
+import torch
+import numpy as np
+
+
+def write_one_weight(writer, name, weight):
+    assert isinstance(weight, np.ndarray)
+    values = weight.reshape(-1)
+    writer.write('{} {}'.format(name, len(values)))
+    for value in values:
+        writer.write(' ')
+        # float to bytes to hex string
+        writer.write(struct.pack('>f', float(value)).hex())
+    writer.write('\n')
+
+
+def convert_name(name):
+    return name.replace("module.", "").replace("base_model.", "").\
+        replace("net.", "").replace("new_fc", "fc").replace("backbone.", "").\
+        replace("cls_head.fc_cls", "fc").replace(".conv.", ".").\
+        replace("conv1.bn", "bn1").replace("conv2.bn", "bn2").\
+        replace("conv3.bn", "bn3").replace("downsample.bn", "downsample.1").\
+        replace("downsample.weight", "downsample.0.weight")
+
+
+def main(args):
+    ckpt = torch.load(args.checkpoint, map_location='cpu')['state_dict']
+    ckpt = {k: v for k, v in ckpt.items() if 'num_batches_tracked' not in k}
+    with open(args.out_filename, "w") as f:
+        f.write(f"{len(ckpt)}\n")
+        for k, v in ckpt.items():
+            key = convert_name(k)
+            write_one_weight(f, key, v.cpu().numpy())
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("checkpoint", type=str, help="Path to checkpoint file")
+    parser.add_argument("--out-filename",
+                        type=str,
+                        default="tsm_r50.wts",
+                        help="Path to converted weights file")
+    args = parser.parse_args()
+    main(args)
diff --git a/tsm/mmaction2_tsm_r50_config.py b/tsm/mmaction2_tsm_r50_config.py
new file mode 100644
index 00000000..477497b6
--- /dev/null
+++ b/tsm/mmaction2_tsm_r50_config.py
@@ -0,0 +1,21 @@
+# model settings
+model = dict(
+    type='Recognizer2D',
+    backbone=dict(
+        type='ResNetTSM',
+        pretrained='torchvision://resnet50',
+        depth=50,
+        norm_eval=False,
+        shift_div=8),
+    cls_head=dict(
+        type='TSMHead',
+        num_classes=400,
+        in_channels=2048,
+        spatial_type='avg',
+        consensus=dict(type='AvgConsensus', dim=1),
+        dropout_ratio=0.5,
+        init_std=0.001,
+        is_shift=True),
+    # model training and testing settings
+    train_cfg=None,
+    test_cfg=dict(average_clips='prob'))
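+
+# These values mirror the constants at the top of tsm_r50.py: num_classes
+# corresponds to OUTPUT_SIZE (400) and shift_div to SHIFT_DIV (8); keep the
+# two files in sync when running --test-mmaction2.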
diff --git a/tsm/test_shift.py b/tsm/test_shift.py
new file mode 100644
index 00000000..f42f878d
--- /dev/null
+++ b/tsm/test_shift.py
@@ -0,0 +1,218 @@
+import numpy as np
+import pycuda.autoinit  # noqa
+import pycuda.driver as cuda
+import tensorrt as trt
+import torch
+from numpy.testing import assert_array_almost_equal
+
+INPUT_BLOB_NAME = 'input'
+OUTPUT_BLOB_NAME = 'output'
+
+
+def shift_mit(x, num_segments, shift_div=8):
+    """Official temporal shift module.
+
+    Code Reference: https://github.com/mit-han-lab/temporal-shift-module/blob/master/ops/temporal_shift.py  # noqa
+    Cannot be converted to an ONNX model.
+    """
+    nt, c, h, w = x.size()
+    n_batch = nt // num_segments
+    x = x.view(n_batch, num_segments, c, h, w)
+
+    fold = c // shift_div
+
+    out = torch.zeros_like(x)
+    out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
+    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # shift right
+    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # no shift
+
+    return out.view(nt, c, h, w)
+
+
+def shift_mmaction2(x, num_segments, shift_div=8):
+    """MMAction2 temporal shift module.
+
+    Code Reference: https://github.com/open-mmlab/mmaction2/blob/master/mmaction/models/backbones/resnet_tsm.py  # noqa
+    Can be converted to an ONNX model.
+    """
+    # [N, C, H, W]
+    n, c, h, w = x.size()
+
+    # [N // num_segments, num_segments, C, H*W]
+    # can't use a 5-dimensional array on the PPL2D backend for Caffe
+    x = x.view(-1, num_segments, c, h * w)
+
+    # get shift fold
+    fold = c // shift_div
+
+    # split c channel into three parts:
+    # left_split, mid_split, right_split
+    left_split = x[:, :, :fold, :]
+    mid_split = x[:, :, fold:2 * fold, :]
+    right_split = x[:, :, 2 * fold:, :]
+
+    # can't use torch.zeros(*A.shape) or torch.zeros_like(A)
+    # because arrays used in Caffe inference must be produced by computation
+
+    # shift left on num_segments channel in `left_split`
+    zeros = left_split - left_split
+    blank = zeros[:, :1, :, :]
+    left_split = left_split[:, 1:, :, :]
+    left_split = torch.cat((left_split, blank), 1)
+
+    # shift right on num_segments channel in `mid_split`
+    zeros = mid_split - mid_split
+    blank = zeros[:, :1, :, :]
+    mid_split = mid_split[:, :-1, :, :]
+    mid_split = torch.cat((blank, mid_split), 1)
+
+    # right_split: no shift
+
+    # concatenate
+    out = torch.cat((left_split, mid_split, right_split), 2)
+
+    # [N, C, H, W]
+    # restore the original dimension
+    return out.view(n, c, h, w)
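+
+
+# A concrete illustration of the shift (values are illustrative, not from the
+# original tests): with num_segments=4, shift_div=4 and c=4, fold is 1, so for
+# each clip channel 0 moves one segment earlier in time, channel 1 moves one
+# segment later, and channels 2-3 pass through unchanged; the vacated segment
+# positions are filled with zeros.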
+
+
+def _tensorrt_shift_module(network,
+                           input,
+                           num_segments=8,
+                           shift_div=8,
+                           input_shape=(16, 64, 32, 32)):
+    """Temporal shift module implemented with the TensorRT network definition API."""
+    fold = input_shape[1] // shift_div
+    batch_size = input_shape[0] // num_segments
+
+    # reshape
+    reshape = network.add_shuffle(input)
+    assert reshape
+    reshape.reshape_dims = (batch_size, num_segments) + tuple(input_shape[-3:])
+
+    # left
+    left_split = network.add_slice(reshape.get_output(0),
+                                   start=(0, 1, 0, 0, 0),
+                                   shape=(batch_size, num_segments - 1, fold,
+                                          input_shape[2], input_shape[3]),
+                                   stride=(1, 1, 1, 1, 1))
+    assert left_split
+    left_split_shape = (batch_size, 1, fold, input_shape[2], input_shape[3])
+    left_blank = network.add_constant(shape=left_split_shape,
+                                      weights=np.zeros(left_split_shape,
+                                                       np.float32))
+    assert left_blank
+    left = network.add_concatenation(
+        [left_split.get_output(0),
+         left_blank.get_output(0)])
+    assert left
+    left.axis = 1
+
+    # mid
+    mid_split_shape = (batch_size, 1, fold, input_shape[2], input_shape[3])
+    mid_blank = network.add_constant(shape=mid_split_shape,
+                                     weights=np.zeros(mid_split_shape,
+                                                      np.float32))
+    assert mid_blank
+    mid_split = network.add_slice(reshape.get_output(0),
+                                  start=(0, 0, fold, 0, 0),
+                                  shape=(batch_size, num_segments - 1, fold,
+                                         input_shape[2], input_shape[3]),
+                                  stride=(1, 1, 1, 1, 1))
+    assert mid_split
+    mid = network.add_concatenation(
+        [mid_blank.get_output(0),
+         mid_split.get_output(0)])
+    assert mid
+    mid.axis = 1
+
+    # right
+    right = network.add_slice(reshape.get_output(0),
+                              start=(0, 0, 2 * fold, 0, 0),
+                              shape=(batch_size, num_segments,
+                                     input_shape[1] - 2 * fold, input_shape[2],
+                                     input_shape[3]),
+                              stride=(1, 1, 1, 1, 1))
+    assert right
+
+    # concat
+    concat = network.add_concatenation(
+        [left.get_output(0),
+         mid.get_output(0),
+         right.get_output(0)])
+    assert concat
+    concat.axis = 2
+
+    # reshape back to the input shape
+    reshape2 = network.add_shuffle(concat.get_output(0))
+    assert reshape2
+    reshape2.reshape_dims = input_shape
+    return reshape2
+
+
+def shift_tensorrt(x, num_segments, shift_div, input_shape):
+    """Test the TensorRT temporal shift module."""
+    assert isinstance(x, np.ndarray)
+
+    gLogger = trt.Logger(trt.Logger.INFO)
+    builder = trt.Builder(gLogger)
+    config = builder.create_builder_config()
+
+    # create network
+    explicit_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    network = builder.create_network(explicit_flag)
+    input = network.add_input(INPUT_BLOB_NAME, trt.float32, input_shape)
+    assert input
+    output = _tensorrt_shift_module(network,
+                                    input,
+                                    num_segments=num_segments,
+                                    shift_div=shift_div,
+                                    input_shape=input_shape)
+    assert output
+
+    # generate engine from builder/network/config
+    output.get_output(0).name = OUTPUT_BLOB_NAME
+    network.mark_output(output.get_output(0))
+    builder.max_batch_size = 1
+    config.max_workspace_size = 1 << 20
+    engine = builder.build_engine(network, config)
+    del network
+    assert engine
+    assert engine.num_bindings == 2, f'{engine.num_bindings}'
+    context = engine.create_execution_context()
+
+    # buffers
+    host_in = cuda.pagelocked_empty(trt.volume(input_shape), dtype=np.float32)
+    np.copyto(host_in, x.ravel())
+    host_out = cuda.pagelocked_empty(trt.volume(input_shape), dtype=np.float32)
+    device_in = cuda.mem_alloc(host_in.nbytes)
+    device_out = cuda.mem_alloc(host_out.nbytes)
+    bindings = [int(device_in), int(device_out)]
+    stream = cuda.Stream()
+
+    # do inference
+    cuda.memcpy_htod_async(device_in, host_in, stream)
+    context.execute_async(bindings=bindings, stream_handle=stream.handle)
+    cuda.memcpy_dtoh_async(host_out, device_out, stream)
+    stream.synchronize()
+
+    return np.array(host_out.reshape(*input_shape))
+
+
+if __name__ == '__main__':
+    INPUT_SHAPE = (16, 64, 32, 32)
+    assert len(INPUT_SHAPE) == 4
+    NUM_SEGMENTS = 8
+    SHIFT_DIV = 8
+
+    # inference
+    inputs = np.random.rand(*INPUT_SHAPE).astype(np.float32)
+    inputs_pytorch = torch.tensor(inputs)
+    with torch.no_grad():
+        rmit = shift_mit(inputs_pytorch, NUM_SEGMENTS, SHIFT_DIV).numpy()
+        rmmaction2 = shift_mmaction2(inputs_pytorch, NUM_SEGMENTS,
+                                     SHIFT_DIV).numpy()
+    rtensorrt = shift_tensorrt(inputs, NUM_SEGMENTS, SHIFT_DIV, INPUT_SHAPE)
+
+    # test results
+    assert_array_almost_equal(rmit, rtensorrt)
+    assert_array_almost_equal(rmmaction2, rtensorrt)
+    print("Tests PASSED")
diff --git a/tsm/tsm_r50.py b/tsm/tsm_r50.py
new file mode 100644
index 00000000..767dac61
--- /dev/null
+++ b/tsm/tsm_r50.py
@@ -0,0 +1,471 @@
+import argparse
+import os
+import struct
+
+import numpy as np
+import pycuda.autoinit  # noqa
+import pycuda.driver as cuda
+import tensorrt as trt
+
+BATCH_SIZE = 1
+NUM_SEGMENTS = 8
+INPUT_H = 224
+INPUT_W = 224
+OUTPUT_SIZE = 400
+SHIFT_DIV = 8
+
+assert INPUT_H % 32 == 0 and INPUT_W % 32 == 0, \
+    "Input height and width should be a multiple of 32."
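+# ResNet-50 downsamples the input by 2 five times (conv1, the max pool, and
+# the first block of layer2/3/4), so the final feature map is
+# (INPUT_H / 32) x (INPUT_W / 32); it must be an exact integer size for the
+# average-pooling window built in create_engine().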
+
+EPS = 1e-5
+INPUT_BLOB_NAME = "data"
+OUTPUT_BLOB_NAME = "prob"
+
+TRT_LOGGER = trt.Logger(trt.Logger.INFO)
+
+
+def load_weights(file):
+    print(f"Loading weights: {file}")
+
+    assert os.path.exists(file), f'Unable to load weight file {file}'
+
+    weight_map = {}
+    with open(file, "r") as f:
+        lines = [line.strip() for line in f]
+        count = int(lines[0])
+        assert count == len(lines) - 1
+        for i in range(1, count + 1):
+            splits = lines[i].split(" ")
+            name = splits[0]
+            cur_count = int(splits[1])
+            assert cur_count + 2 == len(splits)
+            values = []
+            for j in range(2, len(splits)):
+                # hex string to bytes to float
+                values.append(struct.unpack(">f", bytes.fromhex(splits[j]))[0])
+            weight_map[name] = np.array(values, dtype=np.float32)
+
+    return weight_map
+
+
+def add_shift_module(network, input, input_shape, num_segments=8, shift_div=8):
+    fold = input_shape[1] // shift_div
+
+    # left
+    left_split = network.add_slice(input,
+                                   start=(1, 0, 0, 0),
+                                   shape=(num_segments - 1, fold,
+                                          input_shape[2], input_shape[3]),
+                                   stride=(1, 1, 1, 1))
+    assert left_split
+    left_split_shape = (1, fold, input_shape[2], input_shape[3])
+    left_blank = network.add_constant(shape=left_split_shape,
+                                      weights=np.zeros(left_split_shape,
+                                                       np.float32))
+    assert left_blank
+    left = network.add_concatenation(
+        [left_split.get_output(0),
+         left_blank.get_output(0)])
+    assert left
+    left.axis = 0
+
+    # mid
+    mid_split_shape = (1, fold, input_shape[2], input_shape[3])
+    mid_blank = network.add_constant(shape=mid_split_shape,
+                                     weights=np.zeros(mid_split_shape,
+                                                      np.float32))
+    assert mid_blank
+    mid_split = network.add_slice(input,
+                                  start=(0, fold, 0, 0),
+                                  shape=(num_segments - 1, fold,
+                                         input_shape[2], input_shape[3]),
+                                  stride=(1, 1, 1, 1))
+    assert mid_split
+    mid = network.add_concatenation(
+        [mid_blank.get_output(0),
+         mid_split.get_output(0)])
+    assert mid
+    mid.axis = 0
+
+    # right
+    right = network.add_slice(input,
+                              start=(0, 2 * fold, 0, 0),
+                              shape=(num_segments, input_shape[1] - 2 * fold,
+                                     input_shape[2], input_shape[3]),
+                              stride=(1, 1, 1, 1))
+    assert right
+
+    # concat left, mid and right
+    output = network.add_concatenation(
+        [left.get_output(0),
+         mid.get_output(0),
+         right.get_output(0)])
+    assert output
+    output.axis = 1
+    return output
+
+
+def add_batch_norm_2d(network, weight_map, input, layer_name, eps):
+    gamma = weight_map[layer_name + ".weight"]
+    beta = weight_map[layer_name + ".bias"]
+    mean = weight_map[layer_name + ".running_mean"]
+    var = weight_map[layer_name + ".running_var"]
+    var = np.sqrt(var + eps)
+
+    scale = gamma / var
+    shift = -mean / var * gamma + beta
+    return network.add_scale(input=input,
+                             mode=trt.ScaleMode.CHANNEL,
+                             shift=shift,
+                             scale=scale)
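+
+
+# Each bottleneck block below follows torchvision's ResNet-50 layout
+# (1x1 reduce -> 3x3 -> 1x1 expand to 4x the channels, plus a projection
+# shortcut when the stride or channel count changes), with one addition:
+# the temporal shift module is applied to the block input before conv1,
+# which matches TSM's shift_place='blockres' variant.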
+def bottleneck(network, weight_map, input, in_channels, out_channels, stride,
+               layer_name, input_shape):
+    shift = add_shift_module(network, input, input_shape, NUM_SEGMENTS,
+                             SHIFT_DIV)
+    assert shift
+
+    conv1 = network.add_convolution(input=shift.get_output(0),
+                                    num_output_maps=out_channels,
+                                    kernel_shape=(1, 1),
+                                    kernel=weight_map[layer_name +
+                                                      "conv1.weight"],
+                                    bias=trt.Weights())
+    assert conv1
+
+    bn1 = add_batch_norm_2d(network, weight_map, conv1.get_output(0),
+                            layer_name + "bn1", EPS)
+    assert bn1
+
+    relu1 = network.add_activation(bn1.get_output(0),
+                                   type=trt.ActivationType.RELU)
+    assert relu1
+
+    conv2 = network.add_convolution(input=relu1.get_output(0),
+                                    num_output_maps=out_channels,
+                                    kernel_shape=(3, 3),
+                                    kernel=weight_map[layer_name +
+                                                      "conv2.weight"],
+                                    bias=trt.Weights())
+    assert conv2
+    conv2.stride = (stride, stride)
+    conv2.padding = (1, 1)
+
+    bn2 = add_batch_norm_2d(network, weight_map, conv2.get_output(0),
+                            layer_name + "bn2", EPS)
+    assert bn2
+
+    relu2 = network.add_activation(bn2.get_output(0),
+                                   type=trt.ActivationType.RELU)
+    assert relu2
+
+    conv3 = network.add_convolution(input=relu2.get_output(0),
+                                    num_output_maps=out_channels * 4,
+                                    kernel_shape=(1, 1),
+                                    kernel=weight_map[layer_name +
+                                                      "conv3.weight"],
+                                    bias=trt.Weights())
+    assert conv3
+
+    bn3 = add_batch_norm_2d(network, weight_map, conv3.get_output(0),
+                            layer_name + "bn3", EPS)
+    assert bn3
+
+    if stride != 1 or in_channels != 4 * out_channels:
+        conv4 = network.add_convolution(
+            input=input,
+            num_output_maps=out_channels * 4,
+            kernel_shape=(1, 1),
+            kernel=weight_map[layer_name + "downsample.0.weight"],
+            bias=trt.Weights())
+        assert conv4
+        conv4.stride = (stride, stride)
+
+        bn4 = add_batch_norm_2d(network, weight_map, conv4.get_output(0),
+                                layer_name + "downsample.1", EPS)
+        assert bn4
+
+        ew1 = network.add_elementwise(bn4.get_output(0), bn3.get_output(0),
+                                      trt.ElementWiseOperation.SUM)
+    else:
+        ew1 = network.add_elementwise(input, bn3.get_output(0),
+                                      trt.ElementWiseOperation.SUM)
+    assert ew1
+
+    relu3 = network.add_activation(ew1.get_output(0),
+                                   type=trt.ActivationType.RELU)
+    assert relu3
+
+    return relu3
+
+
+def create_engine(maxBatchSize, builder, dt, weights):
+    weight_map = load_weights(weights)
+    network = builder.create_network()
+
+    data = network.add_input(INPUT_BLOB_NAME, dt,
+                             (NUM_SEGMENTS, 3, INPUT_H, INPUT_W))
+    assert data
+
+    conv1 = network.add_convolution(input=data,
+                                    num_output_maps=64,
+                                    kernel_shape=(7, 7),
+                                    kernel=weight_map["conv1.weight"],
+                                    bias=trt.Weights())
+    assert conv1
+    conv1.stride = (2, 2)
+    conv1.padding = (3, 3)
+
+    bn1 = add_batch_norm_2d(network, weight_map, conv1.get_output(0), "bn1",
+                            EPS)
+    assert bn1
+
+    relu1 = network.add_activation(bn1.get_output(0),
+                                   type=trt.ActivationType.RELU)
+    assert relu1
+
+    pool1 = network.add_pooling(input=relu1.get_output(0),
+                                window_size=trt.DimsHW(3, 3),
+                                type=trt.PoolingType.MAX)
+    assert pool1
+    pool1.stride = (2, 2)
+    pool1.padding = (1, 1)
+
+    cur_height = INPUT_H // 4
+    cur_width = INPUT_W // 4
+    x = bottleneck(network, weight_map, pool1.get_output(0), 64, 64, 1,
+                   "layer1.0.", (NUM_SEGMENTS, 64, cur_height, cur_width))
+    x = bottleneck(network, weight_map, x.get_output(0), 256, 64, 1,
+                   "layer1.1.", (NUM_SEGMENTS, 256, cur_height, cur_width))
+    x = bottleneck(network, weight_map, x.get_output(0), 256, 64, 1,
+                   "layer1.2.", (NUM_SEGMENTS, 256, cur_height, cur_width))
+
+    x = bottleneck(network, weight_map, x.get_output(0), 256, 128, 2,
+                   "layer2.0.", (NUM_SEGMENTS, 256, cur_height, cur_width))
+    cur_height = INPUT_H // 8
+    cur_width = INPUT_W // 8
+    x = bottleneck(network, weight_map, x.get_output(0), 512, 128, 1,
+                   "layer2.1.", (NUM_SEGMENTS, 512, cur_height, cur_width))
+    x = bottleneck(network, weight_map, x.get_output(0), 512, 128, 1,
+                   "layer2.2.", (NUM_SEGMENTS, 512, cur_height, cur_width))
+    x = bottleneck(network, weight_map, x.get_output(0), 512, 128, 1,
+                   "layer2.3.", (NUM_SEGMENTS, 512, cur_height, cur_width))
+
+    x = bottleneck(network, weight_map, x.get_output(0), 512, 256, 2,
+                   "layer3.0.", (NUM_SEGMENTS, 512, cur_height, cur_width))
+    cur_height = INPUT_H // 16
+    cur_width = INPUT_W // 16
+    x = bottleneck(network, weight_map, x.get_output(0), 1024, 256, 1,
+                   "layer3.1.", (NUM_SEGMENTS, 1024, cur_height, cur_width))
+    x = bottleneck(network, weight_map, x.get_output(0), 1024, 256, 1,
+                   "layer3.2.", (NUM_SEGMENTS, 1024, cur_height, cur_width))
+    x = bottleneck(network, weight_map, x.get_output(0), 1024, 256, 1,
+                   "layer3.3.", (NUM_SEGMENTS, 1024, cur_height, cur_width))
+    x = bottleneck(network, weight_map, x.get_output(0), 1024, 256, 1,
+                   "layer3.4.", (NUM_SEGMENTS, 1024, cur_height, cur_width))
+    x = bottleneck(network, weight_map, x.get_output(0), 1024, 256, 1,
+                   "layer3.5.", (NUM_SEGMENTS, 1024, cur_height, cur_width))
+
+    x = bottleneck(network, weight_map, x.get_output(0), 1024, 512, 2,
+                   "layer4.0.", (NUM_SEGMENTS, 1024, cur_height, cur_width))
+    cur_height = INPUT_H // 32
+    cur_width = INPUT_W // 32
+    x = bottleneck(network, weight_map, x.get_output(0), 2048, 512, 1,
+                   "layer4.1.", (NUM_SEGMENTS, 2048, cur_height, cur_width))
+    x = bottleneck(network, weight_map, x.get_output(0), 2048, 512, 1,
+                   "layer4.2.", (NUM_SEGMENTS, 2048, cur_height, cur_width))
+
+    pool2 = network.add_pooling(x.get_output(0),
+                                window_size=trt.DimsHW(cur_height, cur_width),
+                                type=trt.PoolingType.AVERAGE)
+    assert pool2
+    pool2.stride = (1, 1)
+
+    fc1 = network.add_fully_connected(input=pool2.get_output(0),
+                                      num_outputs=OUTPUT_SIZE,
+                                      kernel=weight_map['fc.weight'],
+                                      bias=weight_map['fc.bias'])
+    assert fc1
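+
+    # The consensus head below mirrors the MMAction2 TSMHead: per-segment
+    # logits are averaged over the segment dimension (AvgConsensus) and then
+    # softmax-ed, matching test_cfg=dict(average_clips='prob') in
+    # mmaction2_tsm_r50_config.py.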
+
+    reshape = network.add_shuffle(fc1.get_output(0))
+    assert reshape
+    reshape.reshape_dims = (NUM_SEGMENTS, OUTPUT_SIZE)
+
+    reduce = network.add_reduce(reshape.get_output(0),
+                                op=trt.ReduceOperation.AVG,
+                                axes=1,
+                                keep_dims=False)
+    assert reduce
+
+    softmax = network.add_softmax(reduce.get_output(0))
+    assert softmax
+    softmax.axes = 1
+
+    softmax.get_output(0).name = OUTPUT_BLOB_NAME
+    network.mark_output(softmax.get_output(0))
+
+    # Build engine
+    builder.max_batch_size = maxBatchSize
+    builder.max_workspace_size = 1 << 20
+    engine = builder.build_cuda_engine(network)
+
+    del network
+    del weight_map
+
+    return engine
+
+
+def do_inference(context, host_in, host_out, batchSize):
+    device_in = cuda.mem_alloc(host_in.nbytes)
+    device_out = cuda.mem_alloc(host_out.nbytes)
+    bindings = [int(device_in), int(device_out)]
+    stream = cuda.Stream()
+
+    cuda.memcpy_htod_async(device_in, host_in, stream)
+    context.execute_async(batch_size=batchSize,
+                          bindings=bindings,
+                          stream_handle=stream.handle)
+    cuda.memcpy_dtoh_async(host_out, device_out, stream)
+    stream.synchronize()
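+
+
+# Note: do_inference allocates fresh device buffers on every call; for
+# repeated inference you would typically allocate them once and reuse them,
+# but a single allocation is fine for this one-shot demo.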
print(f"{args.save_engine_path} Generated successfully.") + + context = engine.create_execution_context() + assert context + + host_in = cuda.pagelocked_empty(BATCH_SIZE * NUM_SEGMENTS * 3 * INPUT_H * + INPUT_W, + dtype=np.float32) + host_out = cuda.pagelocked_empty(BATCH_SIZE * OUTPUT_SIZE, + dtype=np.float32) + + if args.test_mmaction2: + assert args.mmaction2_config and args.mmaction2_checkpoint, \ + "MMAction2 config and checkpoint couldn't be None" + + data = np.random.randn(BATCH_SIZE, NUM_SEGMENTS, 3, INPUT_H, + INPUT_W).astype(np.float32) + + # TensorRT inference + np.copyto(host_in, data.ravel()) + do_inference(context, host_in, host_out, BATCH_SIZE) + + # pytorch inference + pytorch_results = inference_mmaction2(data, args.mmaction2_config, + args.mmaction2_checkpoint) + + # test + from numpy.testing import assert_array_almost_equal + assert_array_almost_equal(host_out.reshape(-1), + pytorch_results.reshape(-1), + decimal=4) + print("TEST PASSED") + + if args.input_video: + # Get ONE prediction result from ONE video + # Use demo.mp4 from MMAction2 + import cv2 + + # get selected frame id of uniform sampling + cap = cv2.VideoCapture(args.input_video) + sample_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + avg_interval = sample_length / float(NUM_SEGMENTS) + base_offsets = np.arange(NUM_SEGMENTS) * avg_interval + clip_offsets = (base_offsets + avg_interval / 2.0).astype(np.int32) + + # read frames + frames = [] + for i in range(max(clip_offsets) + 1): + flag, frame = cap.read() + if i in clip_offsets: + frames.append(cv2.resize(frame, (INPUT_W, INPUT_W))) + frames = np.array(frames) + + # preprocessing frames + mean = np.array([123.675, 116.28, 103.53]) + std = np.array([58.395, 57.12, 57.375]) + frames = (frames - mean) / std + frames = frames.transpose([0, 3, 1, 2]) + + # TensorRT inference + np.copyto(host_in, frames.ravel()) + do_inference(context, host_in, host_out, BATCH_SIZE) + # For demo.mp4, should be 6, aka arm wrestling + class_id = np.argmax(host_out.reshape(-1)) + print( + f'Result class id {class_id}, socre {round(host_out[class_id]):.2f}' + ) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--tensorrt-weights", + type=str, + default=None, + help="Path to TensorRT weights, which is generated by gen_weights.py") + parser.add_argument("--input-video", + type=str, + default=None, + help="Path to local video file") + parser.add_argument("--save-engine-path", + type=str, + default=None, + help="Save engine to local file") + parser.add_argument("--load-engine-path", + type=str, + default=None, + help="Saved engine file path") + parser.add_argument("--test-mmaction2", + action='store_true', + help="Compare TensorRT results with MMAction2 Results") + parser.add_argument("--mmaction2-config", + type=str, + default=None, + help="Path to MMAction2 config file") + parser.add_argument("--mmaction2-checkpoint", + type=str, + default=None, + help="Path to MMAction2 checkpoint url or file path") + + main(parser.parse_args())