@@ -15,7 +15,8 @@
 import torch
 
 from pyannote.audio import Model
-from tqdm import tqdm
+from transformers import Pipeline
+from transformers.pipelines.pt_utils import PipelineIterator
 
 from faster_whisper.audio import decode_audio, pad_or_trim
 from faster_whisper.feature_extractor import FeatureExtractor
@@ -104,7 +105,7 @@ class TranscriptionInfo(NamedTuple):
 # (https://github.com/m-bain/whisperX) and adapted for faster_whisper
 
 
-class BatchedInferencePipeline:
+class BatchedInferencePipeline(Pipeline):
     """
     Huggingface Pipeline wrapper for WhisperModel.
     Copyright (c) 2022, Max Bain
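
The class now inherits from `transformers.Pipeline`, whose contract requires four hooks: `_sanitize_parameters` to split kwargs into per-stage dicts, `preprocess` to turn one raw input into model-ready tensors, `_forward` to run the model on a collated batch, and `postprocess` to format outputs. A minimal sketch of that contract, for orientation only (the `DoublingPipeline` class and `scale` parameter are hypothetical, not part of this diff):

import torch
from transformers import Pipeline

class DoublingPipeline(Pipeline):
    # Hypothetical subclass showing the four hooks the HF base class expects.
    def _sanitize_parameters(self, **kwargs):
        # Route user kwargs to the (preprocess, forward, postprocess) stages.
        forward_kwargs = {}
        if "scale" in kwargs:
            forward_kwargs["scale"] = kwargs["scale"]
        return {}, forward_kwargs, {}

    def preprocess(self, inputs):
        # One raw input -> model-ready tensors.
        return {"x": torch.tensor([float(inputs)])}

    def _forward(self, model_inputs, scale=2.0):
        # Runs on a collated batch; the base Pipeline.forward wraps this
        # with device placement.
        return {"y": model_inputs["x"] * scale}

    def postprocess(self, model_outputs):
        return model_outputs["y"]
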
@@ -118,29 +119,55 @@ def __init__(
         use_vad_model: bool = True,
         options: Optional[NamedTuple] = None,
         tokenizer=None,
+        device: Union[int, str, "torch.device"] = -1,
         chunk_length: int = 30,
         vad_device: Union[int, str, "torch.device"] = "auto",
         vad_onset: float = 0.500,
         vad_offset: float = 0.363,
+        framework="pt",
         language: Optional[str] = None,
+        **kwargs,
     ):
         self.model: WhisperModel = model
         self.tokenizer = tokenizer
         self.options = options
         self.preset_language = language
+        self._batch_size = kwargs.pop("batch_size", None)
+        self._num_workers = 0
         self.use_vad_model = use_vad_model
         self.vad_onset = vad_onset
         self.vad_offset = vad_offset
         self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin")
-        if self.use_vad_model:
+        self.vad_model = None
+
+        (
+            self._preprocess_params,
+            self._forward_params,
+            self._postprocess_params,
+        ) = self._sanitize_parameters(**kwargs)
+        self.call_count = 0
+        self.framework = framework
+        if self.framework == "pt":
+            self.device = self.get_device(device)
+        else:
+            self.device = device
+
+        if self.use_vad_model and self.vad_model is None:
             self.vad_device = self.get_device(vad_device)
+
+            # load vad model and perform VAD preprocessing if needed
             self.vad_model = self.load_vad_model(
                 vad_onset=self.vad_onset, vad_offset=self.vad_offset
             )
-        else:
-            self.vad_model = None
         self.chunk_length = chunk_length  # VAD merging size
         self.last_speech_timestamp = 0.0
+        super(Pipeline, self).__init__()
+
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "tokenizer" in kwargs:
+            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
+        return preprocess_kwargs, {}, {}
 
     def get_device(self, device: Union[int, str, "torch.device"]):
         """
@@ -166,17 +193,27 @@ def get_device(self, device: Union[int, str, "torch.device"]):
         else:
             return torch.device(f"cuda:{device}")
 
-    def forward(self, features, segments_metadata, **forward_params):
+    def preprocess(self, inputs):
+        audio = inputs["inputs"]
+        to_cpu = (
+            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
+        )
+        features = self.model.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
+            :, : self.model.feature_extractor.nb_max_frames
+        ]
+
+        inputs["features"] = features
+        del features
+        return inputs
+
+    def _forward(self, model_inputs, **forward_params):
         encoder_output, outputs = self.model.generate_segment_batched(
-            features, self.tokenizer, forward_params
+            model_inputs["features"], self.tokenizer, forward_params
         )
 
+        segment_size = encoder_output.shape[1] * 2
         segmented_outputs = []
-        segment_sizes = []
-        for segment_metadata, output in zip(segments_metadata, outputs):
-            duration = segment_metadata["end_time"] - segment_metadata["start_time"]
-            segment_size = int(duration * self.model.frames_per_second)
-            segment_sizes.append(segment_size)
+        for segment_metadata, output in zip(model_inputs["seg_metadata"], outputs):
             (
                 subsegments,
                 seek,
@@ -186,7 +223,8 @@ def forward(self, features, segments_metadata, **forward_params):
                 tokens=output["tokens"],
                 time_offset=segment_metadata["start_time"],
                 segment_size=segment_size,
-                segment_duration=duration,
+                segment_duration=segment_metadata["end_time"]
+                - segment_metadata["start_time"],
                 seek=0,
             )
             segmented_outputs.append(
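
A note on the new `segment_size` (my reading of the change, not stated in the diff): Whisper's encoder downsamples the mel features by a factor of 2, so `encoder_output.shape[1] * 2` recovers the padded window length in feature frames, replacing the per-segment sizes previously derived from VAD durations. The arithmetic under assumed Whisper defaults:

# Illustrative arithmetic under assumed Whisper defaults (not from the diff).
frames_per_second = 100   # 10 ms feature hop on 16 kHz audio
chunk_length = 30         # seconds per padded segment (see __init__ above)
encoder_frames = 1500     # encoder output length for one 30 s chunk

segment_size = encoder_frames * 2
assert segment_size == chunk_length * frames_per_second  # 3000 feature frames
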
@@ -210,13 +248,89 @@ def forward(self, features, segments_metadata, **forward_params):
                 segmented_outputs,
                 self.tokenizer,
                 encoder_output,
-                segment_sizes,
+                segment_size,
                 forward_params["prepend_punctuations"],
                 forward_params["append_punctuations"],
                 self.last_speech_timestamp,
             )
 
-        return segmented_outputs
+        return {"output": segmented_outputs}
+
+    def __call__(self, inputs, options, batch_size=None, **kwargs):
+        if batch_size is None:
+            if self._batch_size is None:
+                batch_size = 1
+            else:
+                batch_size = self._batch_size
+
+        (
+            preprocess_params,
+            forward_params,
+            postprocess_params,
+        ) = self._sanitize_parameters(**kwargs)
+
+        # Fuse __init__ params and __call__ params without modifying the __init__ ones.
+        preprocess_params = {
+            **self._preprocess_params,
+            **preprocess_params,
+        }
+        options_dict = options._asdict()
+        forward_params = {**self._forward_params, **forward_params, **options_dict}
+        postprocess_params = {**self._postprocess_params, **postprocess_params}
+
+        self.call_count += 1
+        if (
+            self.call_count > 10
+            and self.framework == "pt"
+            and self.device.type == "cuda"
+        ):
+            logging.warning(
+                "You seem to be using the pipelines sequentially on GPU. Please use a Dataset"
+            )
+
+        return self.get_iterator(
+            inputs,
+            batch_size,
+            preprocess_params,
+            forward_params,
+            postprocess_params,
+        )
+
+    def postprocess(self, model_outputs):
+        return model_outputs
+
+    def get_iterator(
+        self,
+        inputs,
+        batch_size: int,
+        preprocess_params=None,
+        forward_params=None,
+        postprocess_params=None,
+    ):
+        def stack(items):
+            return {
+                "inputs": [x["inputs"] for x in items],
+                "seg_metadata": [x["seg_metadata"] for x in items],
+                "features": torch.stack([x["features"] for x in items]),
+            }
+
+        if "TOKENIZERS_PARALLELISM" not in os.environ:
+            os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            num_workers=self._num_workers,
+            batch_size=batch_size,
+            collate_fn=stack,
+        )
+        model_iterator = PipelineIterator(
+            dataloader, self.forward, forward_params, loader_batch_size=batch_size
+        )
+        final_iterator = PipelineIterator(
+            model_iterator, self.postprocess, postprocess_params
+        )
+        return final_iterator
 
     def get_language_and_tokenizer(
         self, audio, task: Optional[str] = None, language: Optional[str] = None
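
The `get_iterator` chain above follows the stock `transformers` streaming pattern: one `PipelineIterator` applies `preprocess` lazily, a `DataLoader` collates those items into batches, and two more iterators apply `forward` and `postprocess`. A self-contained sketch of the same wiring, with hypothetical stand-in stage functions rather than the real model hooks:

import torch
from transformers.pipelines.pt_utils import PipelineIterator

# Hypothetical stand-ins for preprocess / forward / postprocess.
def preprocess(x):
    return {"features": torch.tensor([float(x)])}

def forward(batch):
    return {"out": batch["features"] * 2}

def postprocess(out):
    return out

def collate(items):
    return {"features": torch.stack([i["features"] for i in items])}

stage1 = PipelineIterator(range(8), preprocess, {})
loader = torch.utils.data.DataLoader(stage1, batch_size=4, collate_fn=collate)
stage2 = PipelineIterator(loader, forward, {})
final = PipelineIterator(stage2, postprocess, {})

for batch in final:
    print(batch["out"].shape)  # torch.Size([4, 1]) for each batch of 4 items
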
@@ -255,8 +369,7 @@ def get_language_and_tokenizer(
     @staticmethod
     def audio_split(audio, segments, sampling_rate):
         """Returns splitted audio chunks as iterator"""
-        audio_segments = []
-        segments_metadata = []
+
         for seg in segments:
             f1 = int(seg["start"] * sampling_rate)
             f2 = int(seg["end"] * sampling_rate)
@@ -265,9 +378,7 @@ def audio_split(audio, segments, sampling_rate):
                 "end_time": seg["end"],
                 "stitched_seg": seg["segments"],
             }
-            audio_segments.append(audio[f1:f2])
-            segments_metadata.append(seg_metadata)
-        return audio_segments, segments_metadata
+            yield {"inputs": audio[f1:f2], "seg_metadata": seg_metadata}
 
     def load_vad_model(self, vad_onset=0.500, vad_offset=0.363):
         vad_model = Model.from_pretrained(self.vad_model_path)
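
Since `audio_split` is now a generator, chunks and their metadata are produced lazily as the DataLoader pulls items. An illustrative call, using fabricated placeholder audio and VAD segments rather than real VAD output:

import torch

audio = torch.zeros(16000 * 10)  # 10 s of silence at 16 kHz, placeholder only
vad_segments = [
    {"start": 0.5, "end": 2.0, "segments": []},
    {"start": 3.0, "end": 5.5, "segments": []},
]

# audio_split is a @staticmethod, so it can be called on the class itself.
for item in BatchedInferencePipeline.audio_split(audio, vad_segments, 16000):
    meta = item["seg_metadata"]
    print(meta["start_time"], meta["end_time"], item["inputs"].shape)
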
@@ -462,6 +573,7 @@ def transcribe(
             task,
             all_language_probs,
         ) = self.get_language_and_tokenizer(audio, task, language)
+        batch_size = batch_size or self._batch_size
 
         duration_after_vad = sum(
             segment["end"] - segment["start"] for segment in vad_segments
@@ -511,27 +623,10 @@ def transcribe(
             all_language_probs=all_language_probs,
         )
 
-        audio_segments, segments_metadata = self.audio_split(
-            audio, vad_segments, sampling_rate
-        )
-        to_cpu = (
-            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
-        )
-        audio_segments = torch.nested.nested_tensor(audio_segments).to_padded_tensor(
-            padding=0
-        )
-        features = torch.stack(
-            [
-                self.model.feature_extractor(audio_segment, to_cpu=to_cpu)[
-                    ..., : self.model.feature_extractor.nb_max_frames
-                ]
-                for audio_segment in audio_segments
-            ]
-        )
-
         segments = self._batched_segments_generator(
-            features,
-            segments_metadata,
+            audio,
+            vad_segments,
+            sampling_rate,
             batch_size,
             batched_options,
             log_progress,
@@ -540,40 +635,45 @@ def transcribe(
         return segments, info
 
     def _batched_segments_generator(
-        self, features, segments_metadata, batch_size, options, log_progress
+        self, audio, vad_segments, sampling_rate, batch_size, options, log_progress
     ):
-        pbar = tqdm(total=len(features), disable=not log_progress, position=0)
         seg_idx = 0
-        for i in range(0, len(features), batch_size):
-            results = self.forward(
-                features[i : i + batch_size],
-                segments_metadata[i : i + batch_size],
-                **options._asdict(),
+        total_segments = len(vad_segments)
+        for idx, out in enumerate(
+            self.__call__(
+                self.audio_split(audio, vad_segments, sampling_rate),
+                batch_size=batch_size,
+                options=options,
             )
+        ):
+            if log_progress:
+                percent_complete = ((idx + 1) / total_segments) * 100
+                self.model.logger.info(f"Progress: {percent_complete:.2f}%...")
+
+            responses = out["output"]
+            if batch_size == 1:
+                responses = responses[0]
+
+            for response in responses:
+                seg_idx += 1
+                segments = Segment(
+                    seek=int(responses[-1]["end"] * self.model.frames_per_second),
+                    id=seg_idx,
+                    text=response["text"],
+                    start=round(response["start"], 3),
+                    end=round(response["end"], 3),
+                    words=(
+                        None
+                        if not options.word_timestamps
+                        else [Word(**word) for word in response["words"]]
+                    ),
+                    tokens=response["tokens"],
+                    avg_logprob=response["avg_logprob"],
+                    no_speech_prob=response["no_speech_prob"],
+                    compression_ratio=response["compression_ratio"],
+                )
+                yield segments
 
-            for result in results:
-                for segment in result:
-                    seg_idx += 1
-                    yield Segment(
-                        seek=int(result[-1]["end"] * self.model.frames_per_second),
-                        id=seg_idx,
-                        text=segment["text"],
-                        start=round(segment["start"], 3),
-                        end=round(segment["end"], 3),
-                        words=(
-                            None
-                            if not options.word_timestamps
-                            else [Word(**word) for word in segment["words"]]
-                        ),
-                        tokens=segment["tokens"],
-                        avg_logprob=segment["avg_logprob"],
-                        no_speech_prob=segment["no_speech_prob"],
-                        compression_ratio=segment["compression_ratio"],
-                    )
-
-            pbar.update(1)
-
-        pbar.close()
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
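
Taken together, a hedged sketch of how the refactored pipeline would be driven (model size, compute type, and file name are illustrative; the import path assumes the class lives in `faster_whisper/transcribe.py` as in this diff):

from faster_whisper import WhisperModel
from faster_whisper.transcribe import BatchedInferencePipeline

model = WhisperModel("small", device="cpu", compute_type="int8")
pipeline = BatchedInferencePipeline(model=model)

# transcribe() returns a lazy generator of Segment objects plus info;
# batch_size flows through __call__ into the DataLoader built above.
segments, info = pipeline.transcribe("audio.wav", batch_size=8)
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")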