
Commit ffc2e42

Revert "Remove the usage of transformers.pipeline from BatchedInferencePipeline and fix word timestamps for batched inference (SYSTRAN#921)"
This reverts commit d57c5b4.
1 parent d57c5b4 commit ffc2e42
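
Note: this revert puts BatchedInferencePipeline back onto the transformers Pipeline hook interface, where __call__ sanitizes keyword arguments into per-stage parameter dicts and preprocess, _forward and postprocess are applied lazily to a stream of inputs, in place of the tqdm-driven loop over pre-computed feature batches. As rough orientation, here is a minimal, dependency-free sketch of that hook flow; TinyPipeline and its toy data are illustrative only and are not part of faster_whisper or transformers.

# Illustrative sketch of the Pipeline hook flow restored by this revert.
from typing import Any, Dict, Iterable, Iterator


class TinyPipeline:
    def _sanitize_parameters(self, **kwargs):
        # Split caller kwargs into (preprocess, forward, postprocess) params.
        forward_params = {"scale": kwargs.get("scale", 1.0)}
        return {}, forward_params, {}

    def preprocess(self, item: Dict[str, Any]) -> Dict[str, Any]:
        item["features"] = [float(x) for x in item["inputs"]]
        return item

    def _forward(self, model_inputs: Dict[str, Any], scale: float = 1.0) -> Dict[str, Any]:
        return {"output": [x * scale for x in model_inputs["features"]]}

    def postprocess(self, model_outputs: Dict[str, Any]) -> Dict[str, Any]:
        return model_outputs

    def __call__(self, inputs: Iterable[Dict[str, Any]], **kwargs) -> Iterator[Dict[str, Any]]:
        _, forward_params, _ = self._sanitize_parameters(**kwargs)
        for item in inputs:
            yield self.postprocess(self._forward(self.preprocess(item), **forward_params))


if __name__ == "__main__":
    for out in TinyPipeline()(({"inputs": [1, 2, 3]},), scale=2.0):
        print(out)  # {'output': [2.0, 4.0, 6.0]}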

2 files changed (+172 -72)


faster_whisper/transcribe.py (+170 -70)
@@ -15,7 +15,8 @@
 import torch
 
 from pyannote.audio import Model
-from tqdm import tqdm
+from transformers import Pipeline
+from transformers.pipelines.pt_utils import PipelineIterator
 
 from faster_whisper.audio import decode_audio, pad_or_trim
 from faster_whisper.feature_extractor import FeatureExtractor
@@ -104,7 +105,7 @@ class TranscriptionInfo(NamedTuple):
 # (https://github.com/m-bain/whisperX) and adapted for faster_whisper
 
 
-class BatchedInferencePipeline:
+class BatchedInferencePipeline(Pipeline):
     """
     Huggingface Pipeline wrapper for WhisperModel.
     Copyright (c) 2022, Max Bain
@@ -118,29 +119,55 @@ def __init__(
         use_vad_model: bool = True,
         options: Optional[NamedTuple] = None,
         tokenizer=None,
+        device: Union[int, str, "torch.device"] = -1,
         chunk_length: int = 30,
         vad_device: Union[int, str, "torch.device"] = "auto",
         vad_onset: float = 0.500,
         vad_offset: float = 0.363,
+        framework="pt",
         language: Optional[str] = None,
+        **kwargs,
     ):
         self.model: WhisperModel = model
         self.tokenizer = tokenizer
         self.options = options
         self.preset_language = language
+        self._batch_size = kwargs.pop("batch_size", None)
+        self._num_workers = 0
         self.use_vad_model = use_vad_model
         self.vad_onset = vad_onset
         self.vad_offset = vad_offset
         self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin")
-        if self.use_vad_model:
+        self.vad_model = None
+
+        (
+            self._preprocess_params,
+            self._forward_params,
+            self._postprocess_params,
+        ) = self._sanitize_parameters(**kwargs)
+        self.call_count = 0
+        self.framework = framework
+        if self.framework == "pt":
+            self.device = self.get_device(device)
+        else:
+            self.device = device
+
+        if self.use_vad_model and self.vad_model is None:
             self.vad_device = self.get_device(vad_device)
+
+            # load vad model and perform VAD preprocessing if needed
             self.vad_model = self.load_vad_model(
                 vad_onset=self.vad_onset, vad_offset=self.vad_offset
             )
-        else:
-            self.vad_model = None
         self.chunk_length = chunk_length  # VAD merging size
         self.last_speech_timestamp = 0.0
+        super(Pipeline, self).__init__()
+
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "tokenizer" in kwargs:
+            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
+        return preprocess_kwargs, {}, {}
 
     def get_device(self, device: Union[int, str, "torch.device"]):
         """
@@ -166,17 +193,27 @@ def get_device(self, device: Union[int, str, "torch.device"]):
         else:
             return torch.device(f"cuda:{device}")
 
-    def forward(self, features, segments_metadata, **forward_params):
+    def preprocess(self, inputs):
+        audio = inputs["inputs"]
+        to_cpu = (
+            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
+        )
+        features = self.model.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
+            :, : self.model.feature_extractor.nb_max_frames
+        ]
+
+        inputs["features"] = features
+        del features
+        return inputs
+
+    def _forward(self, model_inputs, **forward_params):
         encoder_output, outputs = self.model.generate_segment_batched(
-            features, self.tokenizer, forward_params
+            model_inputs["features"], self.tokenizer, forward_params
         )
 
+        segment_size = encoder_output.shape[1] * 2
         segmented_outputs = []
-        segment_sizes = []
-        for segment_metadata, output in zip(segments_metadata, outputs):
-            duration = segment_metadata["end_time"] - segment_metadata["start_time"]
-            segment_size = int(duration * self.model.frames_per_second)
-            segment_sizes.append(segment_size)
+        for segment_metadata, output in zip(model_inputs["seg_metadata"], outputs):
             (
                 subsegments,
                 seek,
@@ -186,7 +223,8 @@ def forward(self, features, segments_metadata, **forward_params):
                 tokens=output["tokens"],
                 time_offset=segment_metadata["start_time"],
                 segment_size=segment_size,
-                segment_duration=duration,
+                segment_duration=segment_metadata["end_time"]
+                - segment_metadata["start_time"],
                 seek=0,
             )
             segmented_outputs.append(
@@ -210,13 +248,89 @@ def forward(self, features, segments_metadata, **forward_params):
                 segmented_outputs,
                 self.tokenizer,
                 encoder_output,
-                segment_sizes,
+                segment_size,
                 forward_params["prepend_punctuations"],
                 forward_params["append_punctuations"],
                 self.last_speech_timestamp,
             )
 
-        return segmented_outputs
+        return {"output": segmented_outputs}
+
+    def __call__(self, inputs, options, batch_size=None, **kwargs):
+        if batch_size is None:
+            if self._batch_size is None:
+                batch_size = 1
+            else:
+                batch_size = self._batch_size
+
+        (
+            preprocess_params,
+            forward_params,
+            postprocess_params,
+        ) = self._sanitize_parameters(**kwargs)
+
+        # Fuse __init__ params and __call__ params without modifying the __init__ ones.
+        preprocess_params = {
+            **self._preprocess_params,
+            **preprocess_params,
+        }
+        options_dict = options._asdict()
+        forward_params = {**self._forward_params, **forward_params, **options_dict}
+        postprocess_params = {**self._postprocess_params, **postprocess_params}
+
+        self.call_count += 1
+        if (
+            self.call_count > 10
+            and self.framework == "pt"
+            and self.device.type == "cuda"
+        ):
+            logging.warning(
+                "You seem to be using the pipelines sequentially on GPU. Please use a Dataset"
+            )
+
+        return self.get_iterator(
+            inputs,
+            batch_size,
+            preprocess_params,
+            forward_params,
+            postprocess_params,
+        )
+
+    def postprocess(self, model_outputs):
+        return model_outputs
+
+    def get_iterator(
+        self,
+        inputs,
+        batch_size: int,
+        preprocess_params=None,
+        forward_params=None,
+        postprocess_params=None,
+    ):
+        def stack(items):
+            return {
+                "inputs": [x["inputs"] for x in items],
+                "seg_metadata": [x["seg_metadata"] for x in items],
+                "features": torch.stack([x["features"] for x in items]),
+            }
+
+        if "TOKENIZERS_PARALLELISM" not in os.environ:
+            os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            num_workers=self._num_workers,
+            batch_size=batch_size,
+            collate_fn=stack,
+        )
+        model_iterator = PipelineIterator(
+            dataloader, self.forward, forward_params, loader_batch_size=batch_size
+        )
+        final_iterator = PipelineIterator(
+            model_iterator, self.postprocess, postprocess_params
+        )
+        return final_iterator
 
     def get_language_and_tokenizer(
         self, audio, task: Optional[str] = None, language: Optional[str] = None
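
The get_iterator hunk above restores the lazy batching used by transformers pipelines: the per-segment generator is wrapped in a PipelineIterator, handed to a torch DataLoader whose collate_fn stacks feature tensors while keeping segment metadata as plain lists, and each collated batch then flows through forward and postprocess. Below is a minimal torch-only sketch of that batching step; SegmentStream, the feature shape and the stack helper are illustrative stand-ins, not code from this repository.

# Illustrative sketch: generator-style dataset -> DataLoader with a custom collate_fn.
import torch
from torch.utils.data import DataLoader, IterableDataset


class SegmentStream(IterableDataset):
    """Stand-in for the preprocess stage: yields one features/metadata dict per segment."""

    def __init__(self, n_segments: int = 5):
        self.n_segments = n_segments

    def __iter__(self):
        for i in range(self.n_segments):
            yield {"features": torch.full((80, 10), float(i)), "seg_metadata": {"id": i}}


def stack(items):
    # Mirrors the collate_fn in the diff: metadata stays a list, features are stacked.
    return {
        "seg_metadata": [x["seg_metadata"] for x in items],
        "features": torch.stack([x["features"] for x in items]),
    }


loader = DataLoader(SegmentStream(), batch_size=2, num_workers=0, collate_fn=stack)
for batch in loader:
    # e.g. torch.Size([2, 80, 10]) [0, 1]
    print(batch["features"].shape, [m["id"] for m in batch["seg_metadata"]])
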
@@ -255,8 +369,7 @@ get_language_and_tokenizer
     @staticmethod
     def audio_split(audio, segments, sampling_rate):
         """Returns splitted audio chunks as iterator"""
-        audio_segments = []
-        segments_metadata = []
+
         for seg in segments:
             f1 = int(seg["start"] * sampling_rate)
             f2 = int(seg["end"] * sampling_rate)
@@ -265,9 +378,7 @@ def audio_split(audio, segments, sampling_rate):
                 "end_time": seg["end"],
                 "stitched_seg": seg["segments"],
             }
-            audio_segments.append(audio[f1:f2])
-            segments_metadata.append(seg_metadata)
-        return audio_segments, segments_metadata
+            yield {"inputs": audio[f1:f2], "seg_metadata": seg_metadata}
 
     def load_vad_model(self, vad_onset=0.500, vad_offset=0.363):
         vad_model = Model.from_pretrained(self.vad_model_path)
@@ -462,6 +573,7 @@ def transcribe(
             task,
             all_language_probs,
         ) = self.get_language_and_tokenizer(audio, task, language)
+        batch_size = batch_size or self._batch_size
 
         duration_after_vad = sum(
             segment["end"] - segment["start"] for segment in vad_segments
@@ -511,27 +623,10 @@
             all_language_probs=all_language_probs,
         )
 
-        audio_segments, segments_metadata = self.audio_split(
-            audio, vad_segments, sampling_rate
-        )
-        to_cpu = (
-            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
-        )
-        audio_segments = torch.nested.nested_tensor(audio_segments).to_padded_tensor(
-            padding=0
-        )
-        features = torch.stack(
-            [
-                self.model.feature_extractor(audio_segment, to_cpu=to_cpu)[
-                    ..., : self.model.feature_extractor.nb_max_frames
-                ]
-                for audio_segment in audio_segments
-            ]
-        )
-
         segments = self._batched_segments_generator(
-            features,
-            segments_metadata,
+            audio,
+            vad_segments,
+            sampling_rate,
             batch_size,
             batched_options,
             log_progress,
@@ -540,40 +635,45 @@
         return segments, info
 
     def _batched_segments_generator(
-        self, features, segments_metadata, batch_size, options, log_progress
+        self, audio, vad_segments, sampling_rate, batch_size, options, log_progress
     ):
-        pbar = tqdm(total=len(features), disable=not log_progress, position=0)
         seg_idx = 0
-        for i in range(0, len(features), batch_size):
-            results = self.forward(
-                features[i : i + batch_size],
-                segments_metadata[i : i + batch_size],
-                **options._asdict(),
+        total_segments = len(vad_segments)
+        for idx, out in enumerate(
+            self.__call__(
+                self.audio_split(audio, vad_segments, sampling_rate),
+                batch_size=batch_size,
+                options=options,
             )
+        ):
+            if log_progress:
+                percent_complete = ((idx + 1) / total_segments) * 100
+                self.model.logger.info(f"Progress: {percent_complete:.2f}%...")
+
+            responses = out["output"]
+            if batch_size == 1:
+                responses = responses[0]
+
+            for response in responses:
+                seg_idx += 1
+                segments = Segment(
+                    seek=int(responses[-1]["end"] * self.model.frames_per_second),
+                    id=seg_idx,
+                    text=response["text"],
+                    start=round(response["start"], 3),
+                    end=round(response["end"], 3),
+                    words=(
+                        None
+                        if not options.word_timestamps
+                        else [Word(**word) for word in response["words"]]
+                    ),
+                    tokens=response["tokens"],
+                    avg_logprob=response["avg_logprob"],
+                    no_speech_prob=response["no_speech_prob"],
+                    compression_ratio=response["compression_ratio"],
+                )
+                yield segments
 
-            for result in results:
-                for segment in result:
-                    seg_idx += 1
-                    yield Segment(
-                        seek=int(result[-1]["end"] * self.model.frames_per_second),
-                        id=seg_idx,
-                        text=segment["text"],
-                        start=round(segment["start"], 3),
-                        end=round(segment["end"], 3),
-                        words=(
-                            None
-                            if not options.word_timestamps
-                            else [Word(**word) for word in segment["words"]]
-                        ),
-                        tokens=segment["tokens"],
-                        avg_logprob=segment["avg_logprob"],
-                        no_speech_prob=segment["no_speech_prob"],
-                        compression_ratio=segment["compression_ratio"],
-                    )
-
-            pbar.update(1)
-
-        pbar.close()
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
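
For reference, the reverted audio_split/_batched_segments_generator pair turns VAD segments (given in seconds) into a generator of {"inputs": ..., "seg_metadata": ...} dicts and reports progress as a logged percentage instead of a tqdm bar. A small self-contained sketch of that splitting and progress logic follows; the audio buffer and VAD timestamps below are made up.

# Illustrative sketch of the restored audio_split generator and percentage logging.
import numpy as np

sampling_rate = 16000
audio = np.zeros(10 * sampling_rate, dtype=np.float32)  # 10 s of silence (dummy data)
vad_segments = [
    {"start": 0.0, "end": 2.5, "segments": []},
    {"start": 3.0, "end": 6.0, "segments": []},
]


def audio_split(audio, segments, sampling_rate):
    """Yield one chunk dict per VAD segment, as the reverted staticmethod does."""
    for seg in segments:
        f1 = int(seg["start"] * sampling_rate)
        f2 = int(seg["end"] * sampling_rate)
        seg_metadata = {
            "start_time": seg["start"],
            "end_time": seg["end"],
            "stitched_seg": seg["segments"],
        }
        yield {"inputs": audio[f1:f2], "seg_metadata": seg_metadata}


total_segments = len(vad_segments)
for idx, chunk in enumerate(audio_split(audio, vad_segments, sampling_rate)):
    percent_complete = ((idx + 1) / total_segments) * 100
    print(f"Progress: {percent_complete:.2f}% ({len(chunk['inputs'])} samples)")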

requirements.txt (+2 -2)
@@ -2,7 +2,7 @@ ctranslate2>=4.0,<5
 huggingface_hub>=0.13
 tokenizers>=0.13,<1
 onnxruntime>=1.14,<2
+transformers
 pyannote-audio>=3.1.1
 torch>=2.1.1
-torchaudio>=2.1.2
-tqdm
+torchaudio>=2.1.2
