Skip to content

Commit

Permalink
Support TSM-R50 Python API (wang-xinyu#488)
Browse files Browse the repository at this point in the history
* add tensorrt temporal shift module and related pytorch implementations

* add .gitignore and getn weights script.

* rename get_wts.py script

* Add tsm-r50 demo.

* update readme

* remove useless codes

* update readme

* update readme

* remote video and .gitignore, update tutorial

* update readme and tutorial

* fix a few bugs and test on tensorrt 5.1

* update readme
  • Loading branch information
irvingzhang0512 authored Apr 18, 2021
1 parent d9bdd7e commit 8cfc8ee
Show file tree
Hide file tree
Showing 8 changed files with 869 additions and 4 deletions.
2 changes: 1 addition & 1 deletion lenet/lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def load_weights(file):

weight_map = {}
with open(file, "r") as f:
lines = f.readlines()
lines = [line.strip() for line in f]
count = int(lines[0])
assert count == len(lines) - 1
for i in range(1, count + 1):
Expand Down
6 changes: 3 additions & 3 deletions resnet/resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def load_weights(file):

weight_map = {}
with open(file, "r") as f:
lines = f.readlines()
lines = [line.strip() for line in f]
count = int(lines[0])
assert count == len(lines) - 1
for i in range(1, count + 1):
Expand Down Expand Up @@ -138,7 +138,7 @@ def bottleneck(network, weight_map, input, in_channels, out_channels, stride,
return relu3


def createLenetEngine(maxBatchSize, builder, config, dt):
def create_engine(maxBatchSize, builder, config, dt):
weight_map = load_weights(WEIGHT_PATH)
network = builder.create_network()

Expand Down Expand Up @@ -233,7 +233,7 @@ def createLenetEngine(maxBatchSize, builder, config, dt):
def APIToModel(maxBatchSize):
builder = trt.Builder(TRT_LOGGER)
config = builder.create_builder_config()
engine = createLenetEngine(maxBatchSize, builder, config, trt.float32)
engine = create_engine(maxBatchSize, builder, config, trt.float32)
assert engine
with open(ENGINE_PATH, "wb") as f:
f.write(engine.serialize())
Expand Down
66 changes: 66 additions & 0 deletions tsm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Temporal Shift Module

TSM-R50 from "TSM: Temporal Shift Module for Efficient Video Understanding" <https://arxiv.org/abs/1811.08383>

TSM is a widely used Action Recognition model. This TensorRT implementation is tested with TensorRT 5.1 and TensorRT 7.2.

For the PyTorch implementation, you can refer to [open-mmlab/mmaction2](https://github.com/open-mmlab/mmaction2) or [mit-han-lab/temporal-shift-module](https://github.com/mit-han-lab/temporal-shift-module).

For more details about the shift module (which is the core of TSM), refer to [test_shift.py](./test_shift.py).

## Tutorial

+ For a complete example, refer to [demo.sh](./demo.sh)
+ Requirements: Successfully installed `torch>=1.3.0, torchvision`

+ Step 1: Train/Download TSM-R50 checkpoints from [official Github repo](https://github.com/mit-han-lab/temporal-shift-module) or [MMAction2](https://github.com/open-mmlab/mmaction2)
+ Supported settings: `num_segments`, `shift_div`, `num_classes`.
+ Fixed settings: `backbone`(ResNet50), `shift_place`(blockres), `temporal_pool`(False).

+ Step 2: Convert PyTorch checkpoints to TensorRT weights.

```shell
python gen_wts.py /path/to/pytorch.pth --out-filename /path/to/tensorrt.wts
```

+ Step 3: Modify configs in `tsm_r50.py`.

```python
BATCH_SIZE = 1
NUM_SEGMENTS = 8
INPUT_H = 224
INPUT_W = 224
OUTPUT_SIZE = 400
SHIFT_DIV = 8
```

+ Step 4: Inference with `tsm_r50.py`.

```shell
usage: tsm_r50.py [-h] [--tensorrt-weights TENSORRT_WEIGHTS] [--input-video INPUT_VIDEO] [--save-engine-path SAVE_ENGINE_PATH] [--load-engine-path LOAD_ENGINE_PATH] [--test-mmaction2] [--mmaction2-config MMACTION2_CONFIG] [--mmaction2-checkpoint MMACTION2_CHECKPOINT]

optional arguments:
-h, --help show this help message and exit
--tensorrt-weights TENSORRT_WEIGHTS
Path to TensorRT weights, which is generated by gen_wts.py
--input-video INPUT_VIDEO
Path to local video file
--save-engine-path SAVE_ENGINE_PATH
Save engine to local file
--load-engine-path LOAD_ENGINE_PATH
Saved engine file path
--test-mmaction2 Compare TensorRT results with MMAction2 Results
--mmaction2-config MMACTION2_CONFIG
Path to MMAction2 config file
--mmaction2-checkpoint MMACTION2_CHECKPOINT
Path to MMAction2 checkpoint url or file path
```

## TODO

+ [x] Python Shift module.
+ [x] Generate wts of official tsm and mmaction2 tsm.
+ [x] Python API Definition
+ [x] Test with mmaction2 demo
+ [x] Tutorial
+ [ ] C++ API Definition
43 changes: 43 additions & 0 deletions tsm/demo.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# End-to-end TSM-R50 demo. The steps below mirror the "Tutorial" section of
# README.md: download a checkpoint, convert it to TensorRT weights, save an
# engine file and run recognition on a sample video.

# Step 1: Get checkpoints from mmaction2
# https://github.com/open-mmlab/mmaction2/tree/master/configs/recognition/tsm
wget https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x8_50e_kinetics400_rgb/tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth

# Step 2: Convert pytorch checkpoints to TensorRT weights (*.wts)
python gen_wts.py tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth --out-filename ./tsm_r50_kinetics400_mmaction2.wts

# Step 3: (Modify configs in `tsm_r50.py`.) Skip this step since we use default settings.

# Step 4: Inference
# 1) Save local engine file to `./tsm_r50_kinetics400_mmaction2.trt`,
#    using the *.wts file from Step 2 as input.
python tsm_r50.py \
--tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
--save-engine-path ./tsm_r50_kinetics400_mmaction2.trt

# 2) Predict the recognition result using a single video `demo.mp4`.
# Should print `Result class id 6`, aka `arm wrestling`
# Download demo video
wget https://raw.githubusercontent.com/open-mmlab/mmaction2/master/demo/demo.mp4
# Alternative (disabled): use *.wts as input instead of the saved engine
# python tsm_r50.py --tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
# --input-video ./demo.mp4
# use engine file (saved in 1) above) as input
python tsm_r50.py --load-engine-path ./tsm_r50_kinetics400_mmaction2.trt \
--input-video ./demo.mp4

# 3) Optional: Compare inference result with MMAction2 TSM-R50 model
# MMAction2 has to be installed first, please refer to https://github.com/open-mmlab/mmaction2/blob/master/docs/install.md
# pip3 install pytest-runner
# pip3 install mmcv
# pip3 install mmaction2
# # use *.wts as input
# python tsm_r50.py \
# --tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
# --test-mmaction2 \
# --mmaction2-config mmaction2_tsm_r50_config.py \
# --mmaction2-checkpoint tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth
# # use TensorRT engine as input
# python tsm_r50.py \
# --load-engine-path ./tsm_r50_kinetics400_mmaction2.trt \
# --test-mmaction2 \
# --mmaction2-config mmaction2_tsm_r50_config.py \
# --mmaction2-checkpoint tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth
46 changes: 46 additions & 0 deletions tsm/gen_wts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import argparse
import struct

import torch
import numpy as np


def write_one_weight(writer, name, weight):
    """Write one named weight array as a single text line.

    The emitted line is ``<name> <count> <hex> <hex> ...\\n``: the name, the
    number of elements, then every element converted to a Python float,
    packed as a big-endian 32-bit float and rendered as 8 hex digits.

    Args:
        writer: Text-mode file-like object exposing ``write(str)``.
        name: Name written at the start of the line.
        weight: ``np.ndarray`` of any shape; it is flattened before writing.

    Raises:
        AssertionError: If ``weight`` is not an ``np.ndarray``.
    """
    assert isinstance(weight, np.ndarray)
    values = weight.reshape(-1)
    # Build the whole line in memory and write it once: large tensors would
    # otherwise trigger two tiny write() calls per element.
    pack = struct.pack
    hex_values = ''.join(
        # float to bytes to hex_string, each value preceded by one space
        ' ' + pack('>f', float(value)).hex() for value in values)
    writer.write('{} {}{}\n'.format(name, len(values), hex_values))


def convert_name(name):
    """Rewrite a checkpoint parameter name into the name used in the wts file.

    A fixed list of substring substitutions is applied to ``name`` one after
    another; every occurrence of each pattern is replaced. The order matters,
    since later patterns operate on the output of earlier ones (for example
    ``.conv.`` is collapsed before ``downsample.weight`` is matched).

    Args:
        name: Parameter name (a key of the checkpoint ``state_dict``).

    Returns:
        The converted name.
    """
    # (pattern, replacement) pairs, applied strictly in this order.
    # NOTE(review): presumably these cover the official TSM and MMAction2
    # naming schemes mentioned in the README -- confirm against real keys.
    substitutions = (
        ("module.", ""),
        ("base_model.", ""),
        ("net.", ""),
        ("new_fc", "fc"),
        ("backbone.", ""),
        ("cls_head.fc_cls", "fc"),
        (".conv.", "."),
        ("conv1.bn", "bn1"),
        ("conv2.bn", "bn2"),
        ("conv3.bn", "bn3"),
        ("downsample.bn", "downsample.1"),
        ("downsample.weight", "downsample.0.weight"),
    )
    converted = name
    for pattern, replacement in substitutions:
        converted = converted.replace(pattern, replacement)
    return converted


def main(args):
    """Convert a PyTorch checkpoint into a text weights file.

    Loads ``args.checkpoint``, takes its ``'state_dict'`` entry, drops every
    key containing ``num_batches_tracked`` and writes the rest to
    ``args.out_filename``: first a line holding the number of weights, then
    one line per weight (renamed with ``convert_name`` and serialised by
    ``write_one_weight``).

    Args:
        args: Namespace with ``checkpoint`` (input path) and
            ``out_filename`` (output path) attributes.

    Raises:
        KeyError: If the loaded checkpoint has no ``'state_dict'`` entry.
    """
    # NOTE(security): torch.load unpickles the file, so only convert
    # checkpoints obtained from a trusted source.
    # map_location='cpu' lets checkpoints saved from a GPU be converted on a
    # machine without CUDA; every tensor is moved to the CPU below anyway.
    ckpt = torch.load(args.checkpoint, map_location='cpu')['state_dict']
    ckpt = {k: v for k, v in ckpt.items() if 'num_batches_tracked' not in k}
    with open(args.out_filename, "w") as f:
        # Header line: number of weight entries that follow.
        f.write(f"{len(ckpt)}\n")
        for k, v in ckpt.items():
            key = convert_name(k)
            write_one_weight(f, key, v.cpu().numpy())


if __name__ == '__main__':
    # Command-line entry point: one positional checkpoint path plus an
    # optional output path, both handed to main() as a namespace.
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", type=str, help="Path to checkpoint file")
    parser.add_argument("--out-filename",
                        type=str,
                        default="tsm_r50.wts",
                        help="Path to converted weights file")
    args = parser.parse_args()
    main(args)
21 changes: 21 additions & 0 deletions tsm/mmaction2_tsm_r50_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Model settings for the MMAction2 TSM-R50 recognizer, written as plain
# dict literals (same keys, order and values as the dict(...) call form).
model = {
    'type': 'Recognizer2D',
    # Backbone definition.
    'backbone': {
        'type': 'ResNetTSM',
        'pretrained': 'torchvision://resnet50',
        'depth': 50,
        'norm_eval': False,
        'shift_div': 8,
    },
    # Classification head definition.
    'cls_head': {
        'type': 'TSMHead',
        'num_classes': 400,
        'in_channels': 2048,
        'spatial_type': 'avg',
        'consensus': {'type': 'AvgConsensus', 'dim': 1},
        'dropout_ratio': 0.5,
        'init_std': 0.001,
        'is_shift': True,
    },
    # Model training and testing settings.
    'train_cfg': None,
    'test_cfg': {'average_clips': 'prob'},
}
Loading

0 comments on commit 8cfc8ee

Please sign in to comment.