forked from wang-xinyu/tensorrtx
Support TSM-R50 Python API (wang-xinyu#488)
* add tensorrt temporal shift module and related pytorch implementations
* add .gitignore and gen weights script
* rename get_wts.py script
* Add tsm-r50 demo
* update readme
* remove useless codes
* update readme
* update readme
* remove video and .gitignore, update tutorial
* update readme and tutorial
* fix a few bugs and test on tensorrt 5.1
* update readme
1 parent d9bdd7e · commit 8cfc8ee
Showing 8 changed files with 869 additions and 4 deletions.
@@ -0,0 +1,66 @@
# Temporal Shift Module

TSM-R50 from "TSM: Temporal Shift Module for Efficient Video Understanding" <https://arxiv.org/abs/1811.08383>

TSM is a widely used action recognition model. This TensorRT implementation has been tested with TensorRT 5.1 and TensorRT 7.2.

For the PyTorch implementation, refer to [open-mmlab/mmaction2](https://github.com/open-mmlab/mmaction2) or [mit-han-lab/temporal-shift-module](https://github.com/mit-han-lab/temporal-shift-module).

More details about the shift module (the core of TSM) can be found in [test_shift.py](./test_shift.py).
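
For a quick intuition, the sketch below shows the temporal shift operation as it appears in the widely used open-source TSM implementations (an illustration only, not the code in this repo): a fraction of the channels is shifted one frame earlier in time, a fraction one frame later, and the rest are left in place.

```python
import torch

def temporal_shift(x, num_segments=8, shift_div=8):
    # x: (N * T, C, H, W) with T == num_segments frames per clip.
    nt, c, h, w = x.size()
    n = nt // num_segments
    x = x.view(n, num_segments, c, h, w)
    fold = c // shift_div
    out = torch.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]                  # shift left: take from the next frame
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # shift right: take from the previous frame
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]             # remaining channels are not shifted
    return out.view(nt, c, h, w)
```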

## Tutorial

+ For a complete, runnable example, see [demo.sh](./demo.sh).
+ Requirements: `torch>=1.3.0` and `torchvision` installed.

+ Step 1: Train or download TSM-R50 checkpoints from the [official GitHub repo](https://github.com/mit-han-lab/temporal-shift-module) or [MMAction2](https://github.com/open-mmlab/mmaction2).
  Supported settings: `num_segments`, `shift_div`, `num_classes`.
  Fixed settings: `backbone` (ResNet50), `shift_place` (blockres), `temporal_pool` (False).

+ Step 2: Convert PyTorch checkpoints to TensorRT weights.

```shell
python gen_wts.py /path/to/pytorch.pth --out-filename /path/to/tensorrt.wts
```
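
The generated `.wts` file is plain text: the first line is the number of weight tensors, and each subsequent line stores one tensor as its (renamed) key, its element count, and the values as big-endian float32 hex strings (see `gen_wts.py` below; the names and values here are only illustrative).

```
<number of tensors>
conv1.weight 9408 3f800000 bec0a3d7 ...
fc.bias 400 3a83126f 3b23d70a ...
```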

+ Step 3: Modify configs in `tsm_r50.py`.

```python
BATCH_SIZE = 1
NUM_SEGMENTS = 8
INPUT_H = 224
INPUT_W = 224
OUTPUT_SIZE = 400
SHIFT_DIV = 8
```
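
For orientation, these constants presumably describe the engine's input and output (an assumption based on the usual TSM data layout, not spelled out in this README; `tsm_r50.py` is authoritative): one clip is fed as `BATCH_SIZE * NUM_SEGMENTS` RGB frames and the engine returns `OUTPUT_SIZE` class scores.

```python
import numpy as np

# Hypothetical dummy input matching the constants above, assuming segments
# are stacked along the batch axis in channels-first (CHW) order.
BATCH_SIZE, NUM_SEGMENTS, INPUT_H, INPUT_W = 1, 8, 224, 224
dummy_clip = np.zeros((BATCH_SIZE * NUM_SEGMENTS, 3, INPUT_H, INPUT_W), dtype=np.float32)
print(dummy_clip.shape)  # (8, 3, 224, 224)
```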

+ Step 4: Inference with `tsm_r50.py`.

```shell
usage: tsm_r50.py [-h] [--tensorrt-weights TENSORRT_WEIGHTS]
                  [--input-video INPUT_VIDEO]
                  [--save-engine-path SAVE_ENGINE_PATH]
                  [--load-engine-path LOAD_ENGINE_PATH] [--test-mmaction2]
                  [--mmaction2-config MMACTION2_CONFIG]
                  [--mmaction2-checkpoint MMACTION2_CHECKPOINT]

optional arguments:
  -h, --help            show this help message and exit
  --tensorrt-weights TENSORRT_WEIGHTS
                        Path to TensorRT weights, generated by gen_wts.py
  --input-video INPUT_VIDEO
                        Path to a local video file
  --save-engine-path SAVE_ENGINE_PATH
                        Save the built engine to this local file
  --load-engine-path LOAD_ENGINE_PATH
                        Path to a previously saved engine file
  --test-mmaction2      Compare TensorRT results with MMAction2 results
  --mmaction2-config MMACTION2_CONFIG
                        Path to the MMAction2 config file
  --mmaction2-checkpoint MMACTION2_CHECKPOINT
                        MMAction2 checkpoint URL or local file path
```
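
For orientation, `--load-engine-path` presumably just deserializes the engine that `--save-engine-path` wrote earlier. A minimal sketch with the TensorRT Python API is shown below (an illustration under that assumption; the real logic lives in `tsm_r50.py`).

```python
import tensorrt as trt

ENGINE_PATH = "./tsm_r50_kinetics400_mmaction2.trt"  # produced by --save-engine-path

logger = trt.Logger(trt.Logger.WARNING)
with open(ENGINE_PATH, "rb") as f, trt.Runtime(logger) as runtime:
    # Deserialize the engine, then create an execution context for inference.
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()
```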

## TODO

+ [x] Python Shift module.
+ [x] Generate wts of official TSM and MMAction2 TSM.
+ [x] Python API definition.
+ [x] Test with MMAction2 demo.
+ [x] Tutorial.
+ [ ] C++ API definition.
@@ -0,0 +1,43 @@
# Step 1: Get checkpoints from mmaction2
# https://github.com/open-mmlab/mmaction2/tree/master/configs/recognition/tsm
wget https://download.openmmlab.com/mmaction/recognition/tsm/tsm_r50_1x1x8_50e_kinetics400_rgb/tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth

# Step 2: Convert the PyTorch checkpoint to TensorRT weights
python gen_wts.py tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth --out-filename ./tsm_r50_kinetics400_mmaction2.wts

# Step 3: Skip this step since we use the default settings.

# Step 4: Inference
# 1) Save a local engine file to `./tsm_r50_kinetics400_mmaction2.trt`.
python tsm_r50.py \
    --tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
    --save-engine-path ./tsm_r50_kinetics400_mmaction2.trt

# 2) Predict the recognition result for a single video `demo.mp4`.
# Should print `Result class id 6`, aka `arm wrestling`.
# Download the demo video
wget https://raw.githubusercontent.com/open-mmlab/mmaction2/master/demo/demo.mp4
# # use *.wts as input
# python tsm_r50.py --tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
#     --input-video ./demo.mp4
# use the engine file as input
python tsm_r50.py --load-engine-path ./tsm_r50_kinetics400_mmaction2.trt \
    --input-video ./demo.mp4

# 3) Optional: Compare inference results with the MMAction2 TSM-R50 model.
# MMAction2 must be installed first; see https://github.com/open-mmlab/mmaction2/blob/master/docs/install.md
# pip3 install pytest-runner
# pip3 install mmcv
# pip3 install mmaction2
# # use *.wts as input
# python tsm_r50.py \
#     --tensorrt-weights ./tsm_r50_kinetics400_mmaction2.wts \
#     --test-mmaction2 \
#     --mmaction2-config mmaction2_tsm_r50_config.py \
#     --mmaction2-checkpoint tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth
# # use TensorRT engine as input
# python tsm_r50.py \
#     --load-engine-path ./tsm_r50_kinetics400_mmaction2.trt \
#     --test-mmaction2 \
#     --mmaction2-config mmaction2_tsm_r50_config.py \
#     --mmaction2-checkpoint tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth
@@ -0,0 +1,46 @@
import argparse
import struct

import torch
import numpy as np


def write_one_weight(writer, name, weight):
    assert isinstance(weight, np.ndarray)
    values = weight.reshape(-1)
    writer.write('{} {}'.format(name, len(values)))
    for value in values:
        writer.write(' ')
        # float to bytes to hex_string
        writer.write(struct.pack('>f', float(value)).hex())
    writer.write('\n')


def convert_name(name):
    # Map parameter names from the official TSM / MMAction2 checkpoints to the
    # plain ResNet50-style names expected by the TensorRT network definition.
    return name.replace("module.", "").replace("base_model.", "").\
        replace("net.", "").replace("new_fc", "fc").replace("backbone.", "").\
        replace("cls_head.fc_cls", "fc").replace(".conv.", ".").\
        replace("conv1.bn", "bn1").replace("conv2.bn", "bn2").\
        replace("conv3.bn", "bn3").replace("downsample.bn", "downsample.1").\
        replace("downsample.weight", "downsample.0.weight")


def main(args):
    ckpt = torch.load(args.checkpoint)['state_dict']
    ckpt = {k: v for k, v in ckpt.items() if 'num_batches_tracked' not in k}
    with open(args.out_filename, "w") as f:
        f.write(f"{len(ckpt)}\n")
        for k, v in ckpt.items():
            key = convert_name(k)
            write_one_weight(f, key, v.cpu().numpy())


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", type=str, help="Path to checkpoint file")
    parser.add_argument("--out-filename",
                        type=str,
                        default="tsm_r50.wts",
                        help="Path to converted weights file")
    args = parser.parse_args()
    main(args)
@@ -0,0 +1,21 @@
# model settings
model = dict(
    type='Recognizer2D',
    backbone=dict(
        type='ResNetTSM',
        pretrained='torchvision://resnet50',
        depth=50,
        norm_eval=False,
        shift_div=8),
    cls_head=dict(
        type='TSMHead',
        num_classes=400,
        in_channels=2048,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.5,
        init_std=0.001,
        is_shift=True),
    # model training and testing settings
    train_cfg=None,
    test_cfg=dict(average_clips='prob'))
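
For context, a minimal sketch of how `tsm_r50.py --test-mmaction2` presumably uses this config to build the reference PyTorch model (MMAction2 0.x API; treat the exact calls as an assumption, the actual comparison code is in `tsm_r50.py`):

```python
from mmaction.apis import init_recognizer

# Build the reference MMAction2 model from this config plus the downloaded
# checkpoint; feed it the same preprocessed clip as the TensorRT engine and
# compare the resulting class scores.
config_file = 'mmaction2_tsm_r50_config.py'
checkpoint_file = 'tsm_r50_1x1x8_50e_kinetics400_rgb_20200607-af7fb746.pth'
model = init_recognizer(config_file, checkpoint_file, device='cpu')
model.eval()
```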