@@ -15,8 +15,7 @@
 import torch
 
 from pyannote.audio import Model
-from transformers import Pipeline
-from transformers.pipelines.pt_utils import PipelineIterator
+from tqdm import tqdm
 
 from faster_whisper.audio import decode_audio, pad_or_trim
 from faster_whisper.feature_extractor import FeatureExtractor
@@ -105,7 +104,7 @@ class TranscriptionInfo(NamedTuple):
 # (https://github.com/m-bain/whisperX) and adapted for faster_whisper
 
 
-class BatchedInferencePipeline(Pipeline):
+class BatchedInferencePipeline:
     """
     Huggingface Pipeline wrapper for WhisperModel.
     Copyright (c) 2022, Max Bain
@@ -119,55 +118,29 @@ def __init__(
         use_vad_model: bool = True,
         options: Optional[NamedTuple] = None,
         tokenizer=None,
-        device: Union[int, str, "torch.device"] = -1,
         chunk_length: int = 30,
         vad_device: Union[int, str, "torch.device"] = "auto",
         vad_onset: float = 0.500,
         vad_offset: float = 0.363,
-        framework="pt",
         language: Optional[str] = None,
-        **kwargs,
     ):
         self.model: WhisperModel = model
         self.tokenizer = tokenizer
         self.options = options
         self.preset_language = language
-        self._batch_size = kwargs.pop("batch_size", None)
-        self._num_workers = 0
         self.use_vad_model = use_vad_model
         self.vad_onset = vad_onset
         self.vad_offset = vad_offset
         self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin")
-        self.vad_model = None
-
-        (
-            self._preprocess_params,
-            self._forward_params,
-            self._postprocess_params,
-        ) = self._sanitize_parameters(**kwargs)
-        self.call_count = 0
-        self.framework = framework
-        if self.framework == "pt":
-            self.device = self.get_device(device)
-        else:
-            self.device = device
-
-        if self.use_vad_model and self.vad_model is None:
+        if self.use_vad_model:
             self.vad_device = self.get_device(vad_device)
-
-            # load vad model and perform VAD preprocessing if needed
             self.vad_model = self.load_vad_model(
                 vad_onset=self.vad_onset, vad_offset=self.vad_offset
             )
+        else:
+            self.vad_model = None
         self.chunk_length = chunk_length  # VAD merging size
         self.last_speech_timestamp = 0.0
-        super(Pipeline, self).__init__()
-
-    def _sanitize_parameters(self, **kwargs):
-        preprocess_kwargs = {}
-        if "tokenizer" in kwargs:
-            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
-        return preprocess_kwargs, {}, {}
 
     def get_device(self, device: Union[int, str, "torch.device"]):
         """
@@ -193,27 +166,17 @@ def get_device(self, device: Union[int, str, "torch.device"]):
         else:
             return torch.device(f"cuda:{device}")
 
-    def preprocess(self, inputs):
-        audio = inputs["inputs"]
-        to_cpu = (
-            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
-        )
-        features = self.model.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
-            :, : self.model.feature_extractor.nb_max_frames
-        ]
-
-        inputs["features"] = features
-        del features
-        return inputs
-
-    def _forward(self, model_inputs, **forward_params):
+    def forward(self, features, segments_metadata, **forward_params):
         encoder_output, outputs = self.model.generate_segment_batched(
-            model_inputs["features"], self.tokenizer, forward_params
+            features, self.tokenizer, forward_params
         )
 
-        segment_size = encoder_output.shape[1] * 2
         segmented_outputs = []
-        for segment_metadata, output in zip(model_inputs["seg_metadata"], outputs):
+        segment_sizes = []
+        for segment_metadata, output in zip(segments_metadata, outputs):
+            duration = segment_metadata["end_time"] - segment_metadata["start_time"]
+            segment_size = int(duration * self.model.frames_per_second)
+            segment_sizes.append(segment_size)
             (
                 subsegments,
                 seek,
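The fixed `segment_size = encoder_output.shape[1] * 2` is replaced by a per-segment size derived from each VAD segment's duration. A worked sketch, assuming Whisper's usual 100 feature frames per second for `frames_per_second` (the code reads it from `self.model.frames_per_second`):

```python
frames_per_second = 100  # assumed value for illustration
seg_metadata = {"start_time": 3.2, "end_time": 10.7}

duration = seg_metadata["end_time"] - seg_metadata["start_time"]  # 7.5 seconds
segment_size = int(duration * frames_per_second)                  # 750 frames
```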
@@ -223,8 +186,7 @@ def _forward(self, model_inputs, **forward_params):
                 tokens=output["tokens"],
                 time_offset=segment_metadata["start_time"],
                 segment_size=segment_size,
-                segment_duration=segment_metadata["end_time"]
-                - segment_metadata["start_time"],
+                segment_duration=duration,
                 seek=0,
             )
             segmented_outputs.append(
@@ -248,89 +210,13 @@ def _forward(self, model_inputs, **forward_params):
                 segmented_outputs,
                 self.tokenizer,
                 encoder_output,
-                segment_size,
+                segment_sizes,
                 forward_params["prepend_punctuations"],
                 forward_params["append_punctuations"],
                 self.last_speech_timestamp,
             )
 
-        return {"output": segmented_outputs}
-
-    def __call__(self, inputs, options, batch_size=None, **kwargs):
-        if batch_size is None:
-            if self._batch_size is None:
-                batch_size = 1
-            else:
-                batch_size = self._batch_size
-
-        (
-            preprocess_params,
-            forward_params,
-            postprocess_params,
-        ) = self._sanitize_parameters(**kwargs)
-
-        # Fuse __init__ params and __call__ params without modifying the __init__ ones.
-        preprocess_params = {
-            **self._preprocess_params,
-            **preprocess_params,
-        }
-        options_dict = options._asdict()
-        forward_params = {**self._forward_params, **forward_params, **options_dict}
-        postprocess_params = {**self._postprocess_params, **postprocess_params}
-
-        self.call_count += 1
-        if (
-            self.call_count > 10
-            and self.framework == "pt"
-            and self.device.type == "cuda"
-        ):
-            logging.warning(
-                "You seem to be using the pipelines sequentially on GPU. Please use a Dataset"
-            )
-
-        return self.get_iterator(
-            inputs,
-            batch_size,
-            preprocess_params,
-            forward_params,
-            postprocess_params,
-        )
-
-    def postprocess(self, model_outputs):
-        return model_outputs
-
-    def get_iterator(
-        self,
-        inputs,
-        batch_size: int,
-        preprocess_params=None,
-        forward_params=None,
-        postprocess_params=None,
-    ):
-        def stack(items):
-            return {
-                "inputs": [x["inputs"] for x in items],
-                "seg_metadata": [x["seg_metadata"] for x in items],
-                "features": torch.stack([x["features"] for x in items]),
-            }
-
-        if "TOKENIZERS_PARALLELISM" not in os.environ:
-            os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
-        dataloader = torch.utils.data.DataLoader(
-            dataset,
-            num_workers=self._num_workers,
-            batch_size=batch_size,
-            collate_fn=stack,
-        )
-        model_iterator = PipelineIterator(
-            dataloader, self.forward, forward_params, loader_batch_size=batch_size
-        )
-        final_iterator = PipelineIterator(
-            model_iterator, self.postprocess, postprocess_params
-        )
-        return final_iterator
+        return segmented_outputs
 
     def get_language_and_tokenizer(
         self, audio, task: Optional[str] = None, language: Optional[str] = None
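`forward()` now returns the list of segmented outputs directly instead of the `{"output": ...}` dict that the HF-style `_forward()` produced, so callers consume it without unwrapping. A hedged sketch of the new contract, assuming `features`, `segments_metadata`, and `options` were built as in `transcribe()` below:

```python
# One batch in, a list of per-segment results out; no dict wrapper.
results = pipeline.forward(features, segments_metadata, **options._asdict())
for result in results:       # one result per VAD segment in the batch
    for segment in result:   # one dict per decoded sub-segment
        print(segment["start"], segment["end"], segment["text"])
```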
@@ -369,7 +255,8 @@ def get_language_and_tokenizer(
     @staticmethod
     def audio_split(audio, segments, sampling_rate):
-        """Returns splitted audio chunks as iterator"""
-
+        """Returns split audio chunks and their metadata"""
+        audio_segments = []
+        segments_metadata = []
         for seg in segments:
             f1 = int(seg["start"] * sampling_rate)
             f2 = int(seg["end"] * sampling_rate)
@@ -378,7 +265,9 @@ def audio_split(audio, segments, sampling_rate):
                 "end_time": seg["end"],
                 "stitched_seg": seg["segments"],
             }
-            yield {"inputs": audio[f1:f2], "seg_metadata": seg_metadata}
+            audio_segments.append(audio[f1:f2])
+            segments_metadata.append(seg_metadata)
+        return audio_segments, segments_metadata
 
     def load_vad_model(self, vad_onset=0.500, vad_offset=0.363):
         vad_model = Model.from_pretrained(self.vad_model_path)
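`audio_split` is now eager: instead of yielding one dict per chunk for a `PipelineIterator` to consume, it returns two parallel lists. A small usage sketch, assuming a 1-D 16 kHz waveform tensor and VAD segments with `start`/`end` in seconds (the fake audio and segment values are illustrative):

```python
import torch

audio = torch.randn(16_000 * 10)  # 10 s of fake 16 kHz audio
vad_segments = [
    {"start": 0.5, "end": 3.1, "segments": []},
    {"start": 4.0, "end": 9.2, "segments": []},
]
# audio_split is a @staticmethod on the class changed in this diff.
audio_segments, segments_metadata = BatchedInferencePipeline.audio_split(
    audio, vad_segments, 16_000
)
assert len(audio_segments) == len(segments_metadata) == 2
```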
@@ -573,7 +462,6 @@ def transcribe(
             task,
             all_language_probs,
         ) = self.get_language_and_tokenizer(audio, task, language)
-        batch_size = batch_size or self._batch_size
 
         duration_after_vad = sum(
             segment["end"] - segment["start"] for segment in vad_segments
@@ -623,10 +511,27 @@ def transcribe(
             all_language_probs=all_language_probs,
         )
 
+        audio_segments, segments_metadata = self.audio_split(
+            audio, vad_segments, sampling_rate
+        )
+        to_cpu = (
+            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
+        )
+        audio_segments = torch.nested.nested_tensor(audio_segments).to_padded_tensor(
+            padding=0
+        )
+        features = torch.stack(
+            [
+                self.model.feature_extractor(audio_segment, to_cpu=to_cpu)[
+                    ..., : self.model.feature_extractor.nb_max_frames
+                ]
+                for audio_segment in audio_segments
+            ]
+        )
+
         segments = self._batched_segments_generator(
-            audio,
-            vad_segments,
-            sampling_rate,
+            features,
+            segments_metadata,
             batch_size,
             batched_options,
             log_progress,
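The variable-length audio chunks are batched by padding them to a common length through `torch.nested`. A minimal self-contained sketch of just that step (the chunk lengths here are illustrative):

```python
import torch

# Two chunks of different lengths, e.g. 3 s and 7 s of 16 kHz audio.
chunks = [torch.randn(48_000), torch.randn(112_000)]

# nested_tensor holds the ragged batch; to_padded_tensor right-pads
# every chunk with zeros to the longest length in the batch.
padded = torch.nested.nested_tensor(chunks).to_padded_tensor(padding=0)
print(padded.shape)  # torch.Size([2, 112000])
```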
@@ -635,45 +540,40 @@ def transcribe(
         return segments, info
 
     def _batched_segments_generator(
-        self, audio, vad_segments, sampling_rate, batch_size, options, log_progress
+        self, features, segments_metadata, batch_size, options, log_progress
     ):
+        pbar = tqdm(total=len(features), disable=not log_progress, position=0)
         seg_idx = 0
-        total_segments = len(vad_segments)
-        for idx, out in enumerate(
-            self.__call__(
-                self.audio_split(audio, vad_segments, sampling_rate),
-                batch_size=batch_size,
-                options=options,
+        for i in range(0, len(features), batch_size):
+            results = self.forward(
+                features[i : i + batch_size],
+                segments_metadata[i : i + batch_size],
+                **options._asdict(),
             )
-        ):
-            if log_progress:
-                percent_complete = ((idx + 1) / total_segments) * 100
-                self.model.logger.info(f"Progress: {percent_complete:.2f}%...")
-
-            responses = out["output"]
-            if batch_size == 1:
-                responses = responses[0]
-
-            for response in responses:
-                seg_idx += 1
-                segments = Segment(
-                    seek=int(responses[-1]["end"] * self.model.frames_per_second),
-                    id=seg_idx,
-                    text=response["text"],
-                    start=round(response["start"], 3),
-                    end=round(response["end"], 3),
-                    words=(
-                        None
-                        if not options.word_timestamps
-                        else [Word(**word) for word in response["words"]]
-                    ),
-                    tokens=response["tokens"],
-                    avg_logprob=response["avg_logprob"],
-                    no_speech_prob=response["no_speech_prob"],
-                    compression_ratio=response["compression_ratio"],
-                )
-                yield segments
 
+            for result in results:
+                for segment in result:
+                    seg_idx += 1
+                    yield Segment(
+                        seek=int(result[-1]["end"] * self.model.frames_per_second),
+                        id=seg_idx,
+                        text=segment["text"],
+                        start=round(segment["start"], 3),
+                        end=round(segment["end"], 3),
+                        words=(
+                            None
+                            if not options.word_timestamps
+                            else [Word(**word) for word in segment["words"]]
+                        ),
+                        tokens=segment["tokens"],
+                        avg_logprob=segment["avg_logprob"],
+                        no_speech_prob=segment["no_speech_prob"],
+                        compression_ratio=segment["compression_ratio"],
+                    )
+
+                pbar.update(1)
+
+        pbar.close()
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
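The generator above replaces the transformers `PipelineIterator`/`DataLoader` stack with a plain slicing loop plus `tqdm`. The same pattern in isolation; the names here are illustrative, not part of faster_whisper:

```python
from tqdm import tqdm

def run_batched(items, metadata, batch_size, forward, show_progress=True):
    """Slice parallel lists into batches, run forward per batch, and
    stream individual results while tracking per-item progress."""
    pbar = tqdm(total=len(items), disable=not show_progress)
    for i in range(0, len(items), batch_size):
        results = forward(items[i : i + batch_size], metadata[i : i + batch_size])
        for result in results:
            yield result
            pbar.update(1)
    pbar.close()
```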