@@ -1,18 +1,21 @@
 #!/usr/bin/env python
-from pathlib import Path
 import argparse
 import os
+from pathlib import Path
 
 import numpy as np
 import pympi
 import torch
 from pose_format import Pose
 from pose_format.utils.generic import pose_normalization_info, pose_hide_legs, normalize_hands_3d
+from functools import lru_cache
 
 from sign_language_segmentation.src.utils.probs_to_segments import probs_to_segments
 
+DEFAULT_MODEL = "model_E1s-1.pth"
+
 
-def add_optical_flow(pose: Pose)-> None:
+def add_optical_flow(pose: Pose) -> None:
     from pose_format.numpy.representation.distance import DistanceRepresentation
     from pose_format.utils.optical_flow import OpticalFlowCalculator
 
@@ -44,6 +47,7 @@ def process_pose(pose: Pose, optical_flow=False, hand_normalization=False) -> Pose:
     return pose
 
 
+@lru_cache(maxsize=1)
 def load_model(model_path: str):
     model = torch.jit.load(model_path)
     model.eval()
@@ -58,7 +62,7 @@ def predict(model, pose: Pose):
     return model(pose_data)
 
 
-def save_pose_segments(tiers:dict, tier_id:str, input_file_path:Path)-> None:
+def save_pose_segments(tiers: dict, tier_id: str, input_file_path: Path) -> None:
     # reload it without any of the processing, so we get all the original points and such.
     with input_file_path.open("rb") as f:
         pose = Pose.read(f.read())
@@ -83,42 +87,64 @@ def get_args():
     )
     parser.add_argument("--video", default=None, required=False, type=str, help="path to video file")
     parser.add_argument("--subtitles", default=None, required=False, type=str, help="path to subtitle file")
-    parser.add_argument("--model", default="model_E1s-1.pth", required=False, type=str, help="path to model file")
+    parser.add_argument("--model", default=DEFAULT_MODEL, required=False, type=str, help="path to model file")
     parser.add_argument("--no-pose-link", action="store_true", help="whether to link the pose file")
 
     return parser.parse_args()
 
 
-def main():
-    args = get_args()
+def segment_pose(pose: Pose, model: str = DEFAULT_MODEL, verbose=True):
+    if "E4" in model:
+        pose = process_pose(pose, optical_flow=True, hand_normalization=True)
+    else:
+        pose = process_pose(pose)
 
-    print("Loading pose ...")
-    with open(args.pose, "rb") as f:
-        pose = Pose.read(f.read())
-    if "E4" in args.model:
-        pose = process_pose(pose, optical_flow=True, hand_normalization=True)
-    else:
-        pose = process_pose(pose)
-
-    print("Loading model ...")
+    if verbose:
+        print("Loading model ...")
     install_dir = str(os.path.dirname(os.path.abspath(__file__)))
-    model = load_model(os.path.join(install_dir, "dist", args.model))
+    model = load_model(os.path.join(install_dir, "dist", model))
 
-    print("Estimating segments ...")
+    if verbose:
+        print("Estimating segments ...")
     probs = predict(model, pose)
 
     sign_segments = probs_to_segments(probs["sign"], 60, 50)
     sentence_segments = probs_to_segments(probs["sentence"], 90, 90)
 
-    print("Building ELAN file ...")
+    if verbose:
+        print("Building ELAN file ...")
+    eaf = pympi.Elan.Eaf(author="sign-language-processing/transcription")
+
+    fps = pose.body.fps
+
     tiers = {
         "SIGN": sign_segments,
         "SENTENCE": sentence_segments,
     }
 
-    fps = pose.body.fps
+    for tier_id, segments in tiers.items():
+        eaf.add_tier(tier_id)
+        for segment in segments:
+            if segment["end"] == segment["start"]:
+                segment["end"] += 1
+
+            # convert frame numbers to millisecond timestamps, for Elan
+            start_time_ms = int(segment["start"] / fps * 1000)
+            end_time_ms = int(segment["end"] / fps * 1000)
+            eaf.add_annotation(tier_id, start_time_ms, end_time_ms)
+
+    return eaf, tiers
+
+
+def main():
+    args = get_args()
+
+    print("Loading pose ...")
+    with open(args.pose, "rb") as f:
+        pose = Pose.read(f.read())
+
+    eaf, tiers = segment_pose(pose, model=args.model)
 
-    eaf = pympi.Elan.Eaf(author="sign-language-processing/transcription")
     if args.video is not None:
         mimetype = None  # pympi is not familiar with mp4 files
         if args.video.endswith(".mp4"):
@@ -128,18 +154,6 @@ def main():
     if not args.no_pose_link:
         eaf.add_linked_file(args.pose, mimetype="application/pose")
 
-    for tier_id, segments in tiers.items():
-        eaf.add_tier(tier_id)
-        for segment in segments:
-            # convert frame numbers to millisecond timestamps, for Elan
-            start_time_ms = int(segment["start"] / fps * 1000)
-            end_time_ms = int(segment["end"] / fps * 1000)
-            eaf.add_annotation(tier_id, start_time_ms, end_time_ms)
-
-    if args.save_segments:
-        print(f"Saving {args.save_segments} cropped .pose files")
-        save_pose_segments(tiers, tier_id=args.save_segments, input_file_path=args.pose)
-
     if args.subtitles and os.path.exists(args.subtitles):
         import srt
 
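
A minimal usage sketch of the new segment_pose() helper introduced in this diff, as it could be called from Python code rather than the CLI. This is not part of the commit; the import path, the example file names, and the assumption that the default model ships under the package's dist/ directory are all hypothetical:

    from pose_format import Pose

    from sign_language_segmentation.bin import segment_pose  # import path is an assumption

    # Load a pose file from disk ("example.pose" is a placeholder name).
    with open("example.pose", "rb") as f:
        pose = Pose.read(f.read())

    # Returns a pympi Eaf object plus the frame-indexed segments per tier,
    # using the bundled default model; verbose=False suppresses progress prints.
    eaf, tiers = segment_pose(pose, verbose=False)

    eaf.to_file("example.eaf")  # write the ELAN file via pympi
    print(len(tiers["SIGN"]), "sign segments,", len(tiers["SENTENCE"]), "sentence segments")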