
Commit d57c5b4

Remove the usage of transformers.pipeline from BatchedInferencePipeline and fix word timestamps for batched inference (SYSTRAN#921)
* fix word timestamps for batched inference
* remove hf pipeline
1 parent 83a368e commit d57c5b4

2 files changed: +72 −172 lines changed

faster_whisper/transcribe.py

+70 −170
@@ -15,8 +15,7 @@
 import torch
 
 from pyannote.audio import Model
-from transformers import Pipeline
-from transformers.pipelines.pt_utils import PipelineIterator
+from tqdm import tqdm
 
 from faster_whisper.audio import decode_audio, pad_or_trim
 from faster_whisper.feature_extractor import FeatureExtractor
@@ -105,7 +104,7 @@ class TranscriptionInfo(NamedTuple):
 # (https://github.com/m-bain/whisperX) and adapted for faster_whisper
 
 
-class BatchedInferencePipeline(Pipeline):
+class BatchedInferencePipeline:
     """
     Huggingface Pipeline wrapper for WhisperModel.
     Copyright (c) 2022, Max Bain
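
Since BatchedInferencePipeline no longer subclasses transformers' Pipeline, it is constructed directly around a WhisperModel. A minimal sketch, assuming the class is exported from the package root and using placeholder model name and device; the constructor arguments it relies on appear in the next hunk.

from faster_whisper import BatchedInferencePipeline, WhisperModel

# Hypothetical setup: wrap a CTranslate2 Whisper model for batched inference.
model = WhisperModel("large-v3", device="cuda")

# The former HF Pipeline arguments (device, framework, **kwargs) are gone.
# When use_vad_model=True the pyannote VAD model is loaded eagerly;
# otherwise self.vad_model is simply left as None.
batched = BatchedInferencePipeline(model, use_vad_model=True, chunk_length=30)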
@@ -119,55 +118,29 @@ def __init__(
         use_vad_model: bool = True,
         options: Optional[NamedTuple] = None,
         tokenizer=None,
-        device: Union[int, str, "torch.device"] = -1,
         chunk_length: int = 30,
         vad_device: Union[int, str, "torch.device"] = "auto",
         vad_onset: float = 0.500,
         vad_offset: float = 0.363,
-        framework="pt",
         language: Optional[str] = None,
-        **kwargs,
     ):
         self.model: WhisperModel = model
         self.tokenizer = tokenizer
         self.options = options
         self.preset_language = language
-        self._batch_size = kwargs.pop("batch_size", None)
-        self._num_workers = 0
         self.use_vad_model = use_vad_model
         self.vad_onset = vad_onset
         self.vad_offset = vad_offset
         self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin")
-        self.vad_model = None
-
-        (
-            self._preprocess_params,
-            self._forward_params,
-            self._postprocess_params,
-        ) = self._sanitize_parameters(**kwargs)
-        self.call_count = 0
-        self.framework = framework
-        if self.framework == "pt":
-            self.device = self.get_device(device)
-        else:
-            self.device = device
-
-        if self.use_vad_model and self.vad_model is None:
+        if self.use_vad_model:
             self.vad_device = self.get_device(vad_device)
-
-            # load vad model and perform VAD preprocessing if needed
             self.vad_model = self.load_vad_model(
                 vad_onset=self.vad_onset, vad_offset=self.vad_offset
             )
+        else:
+            self.vad_model = None
         self.chunk_length = chunk_length  # VAD merging size
         self.last_speech_timestamp = 0.0
-        super(Pipeline, self).__init__()
-
-    def _sanitize_parameters(self, **kwargs):
-        preprocess_kwargs = {}
-        if "tokenizer" in kwargs:
-            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
-        return preprocess_kwargs, {}, {}
 
     def get_device(self, device: Union[int, str, "torch.device"]):
         """
@@ -193,27 +166,17 @@ def get_device(self, device: Union[int, str, "torch.device"]):
         else:
             return torch.device(f"cuda:{device}")
 
-    def preprocess(self, inputs):
-        audio = inputs["inputs"]
-        to_cpu = (
-            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
-        )
-        features = self.model.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
-            :, : self.model.feature_extractor.nb_max_frames
-        ]
-
-        inputs["features"] = features
-        del features
-        return inputs
-
-    def _forward(self, model_inputs, **forward_params):
+    def forward(self, features, segments_metadata, **forward_params):
         encoder_output, outputs = self.model.generate_segment_batched(
-            model_inputs["features"], self.tokenizer, forward_params
+            features, self.tokenizer, forward_params
         )
 
-        segment_size = encoder_output.shape[1] * 2
         segmented_outputs = []
-        for segment_metadata, output in zip(model_inputs["seg_metadata"], outputs):
+        segment_sizes = []
+        for segment_metadata, output in zip(segments_metadata, outputs):
+            duration = segment_metadata["end_time"] - segment_metadata["start_time"]
+            segment_size = int(duration * self.model.frames_per_second)
+            segment_sizes.append(segment_size)
             (
                 subsegments,
                 seek,
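
The word-timestamp fix boils down to sizing each segment from its own VAD duration rather than from the padded 30-second encoder window. A rough worked example, assuming the default Whisper feature rate of 100 frames per second (the actual value comes from self.model.frames_per_second):

frames_per_second = 100  # assumed default; the code reads self.model.frames_per_second

# Old behaviour: segment_size = encoder_output.shape[1] * 2, i.e. the full
# padded window (typically 1500 * 2 = 3000 frames) regardless of segment length.
# New behaviour: each segment is sized from its own VAD boundaries.
segment_metadata = {"start_time": 12.4, "end_time": 19.9}  # example values
duration = segment_metadata["end_time"] - segment_metadata["start_time"]  # 7.5 s
segment_size = int(duration * frames_per_second)  # 750 frames instead of 3000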
@@ -223,8 +186,7 @@ def _forward(self, model_inputs, **forward_params):
                 tokens=output["tokens"],
                 time_offset=segment_metadata["start_time"],
                 segment_size=segment_size,
-                segment_duration=segment_metadata["end_time"]
-                - segment_metadata["start_time"],
+                segment_duration=duration,
                 seek=0,
             )
             segmented_outputs.append(
@@ -248,89 +210,13 @@ def _forward(self, model_inputs, **forward_params):
                 segmented_outputs,
                 self.tokenizer,
                 encoder_output,
-                segment_size,
+                segment_sizes,
                 forward_params["prepend_punctuations"],
                 forward_params["append_punctuations"],
                 self.last_speech_timestamp,
             )
 
-        return {"output": segmented_outputs}
-
-    def __call__(self, inputs, options, batch_size=None, **kwargs):
-        if batch_size is None:
-            if self._batch_size is None:
-                batch_size = 1
-            else:
-                batch_size = self._batch_size
-
-        (
-            preprocess_params,
-            forward_params,
-            postprocess_params,
-        ) = self._sanitize_parameters(**kwargs)
-
-        # Fuse __init__ params and __call__ params without modifying the __init__ ones.
-        preprocess_params = {
-            **self._preprocess_params,
-            **preprocess_params,
-        }
-        options_dict = options._asdict()
-        forward_params = {**self._forward_params, **forward_params, **options_dict}
-        postprocess_params = {**self._postprocess_params, **postprocess_params}
-
-        self.call_count += 1
-        if (
-            self.call_count > 10
-            and self.framework == "pt"
-            and self.device.type == "cuda"
-        ):
-            logging.warning(
-                "You seem to be using the pipelines sequentially on GPU. Please use a Dataset"
-            )
-
-        return self.get_iterator(
-            inputs,
-            batch_size,
-            preprocess_params,
-            forward_params,
-            postprocess_params,
-        )
-
-    def postprocess(self, model_outputs):
-        return model_outputs
-
-    def get_iterator(
-        self,
-        inputs,
-        batch_size: int,
-        preprocess_params=None,
-        forward_params=None,
-        postprocess_params=None,
-    ):
-        def stack(items):
-            return {
-                "inputs": [x["inputs"] for x in items],
-                "seg_metadata": [x["seg_metadata"] for x in items],
-                "features": torch.stack([x["features"] for x in items]),
-            }
-
-        if "TOKENIZERS_PARALLELISM" not in os.environ:
-            os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-        dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
-        dataloader = torch.utils.data.DataLoader(
-            dataset,
-            num_workers=self._num_workers,
-            batch_size=batch_size,
-            collate_fn=stack,
-        )
-        model_iterator = PipelineIterator(
-            dataloader, self.forward, forward_params, loader_batch_size=batch_size
-        )
-        final_iterator = PipelineIterator(
-            model_iterator, self.postprocess, postprocess_params
-        )
-        return final_iterator
+        return segmented_outputs
 
     def get_language_and_tokenizer(
         self, audio, task: Optional[str] = None, language: Optional[str] = None
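
With the __call__/get_iterator plumbing removed, forward() hands its results straight back instead of wrapping them for the HF pipeline machinery. A small before/after sketch with dummy values (not from the commit):

# Dummy output for one batch containing a single VAD segment.
segmented_outputs = [[{"text": "hello world", "start": 0.0, "end": 1.2}]]

# Before: results were wrapped in a dict and unwrapped downstream.
responses = {"output": segmented_outputs}["output"]

# After: the list comes back as-is and is iterated directly by the
# batched segments generator further down in this diff.
responses = segmented_outputs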
@@ -369,7 +255,8 @@ def get_language_and_tokenizer(
     @staticmethod
     def audio_split(audio, segments, sampling_rate):
         """Returns splitted audio chunks as iterator"""
-
+        audio_segments = []
+        segments_metadata = []
         for seg in segments:
             f1 = int(seg["start"] * sampling_rate)
             f2 = int(seg["end"] * sampling_rate)
@@ -378,7 +265,9 @@ def audio_split(audio, segments, sampling_rate):
                 "end_time": seg["end"],
                 "stitched_seg": seg["segments"],
             }
-            yield {"inputs": audio[f1:f2], "seg_metadata": seg_metadata}
+            audio_segments.append(audio[f1:f2])
+            segments_metadata.append(seg_metadata)
+        return audio_segments, segments_metadata
 
     def load_vad_model(self, vad_onset=0.500, vad_offset=0.363):
         vad_model = Model.from_pretrained(self.vad_model_path)
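
audio_split now materializes two parallel lists (audio slices and their metadata) instead of yielding one dict per segment. A self-contained sketch with made-up VAD segments and a silent one-minute input; the package-level import and the shape of the "segments" field are assumptions:

import torch

from faster_whisper import BatchedInferencePipeline  # assumed export, as above

sampling_rate = 16000
audio = torch.zeros(sampling_rate * 60)  # one minute of silence, illustrative only

vad_segments = [
    {"start": 0.5, "end": 4.2, "segments": [(0.5, 4.2)]},
    {"start": 10.0, "end": 17.5, "segments": [(10.0, 17.5)]},
]

# audio_split is a staticmethod, so it can be called on the class directly.
audio_segments, segments_metadata = BatchedInferencePipeline.audio_split(
    audio, vad_segments, sampling_rate
)
# audio_segments[0] has 59200 samples (3.7 s), audio_segments[1] has 120000 (7.5 s)
# segments_metadata[0] == {"start_time": 0.5, "end_time": 4.2, "stitched_seg": [(0.5, 4.2)]}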
@@ -573,7 +462,6 @@ def transcribe(
             task,
             all_language_probs,
         ) = self.get_language_and_tokenizer(audio, task, language)
-        batch_size = batch_size or self._batch_size
 
         duration_after_vad = sum(
             segment["end"] - segment["start"] for segment in vad_segments
@@ -623,10 +511,27 @@
             all_language_probs=all_language_probs,
         )
 
+        audio_segments, segments_metadata = self.audio_split(
+            audio, vad_segments, sampling_rate
+        )
+        to_cpu = (
+            self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
+        )
+        audio_segments = torch.nested.nested_tensor(audio_segments).to_padded_tensor(
+            padding=0
+        )
+        features = torch.stack(
+            [
+                self.model.feature_extractor(audio_segment, to_cpu=to_cpu)[
+                    ..., : self.model.feature_extractor.nb_max_frames
+                ]
+                for audio_segment in audio_segments
+            ]
+        )
+
         segments = self._batched_segments_generator(
-            audio,
-            vad_segments,
-            sampling_rate,
+            features,
+            segments_metadata,
             batch_size,
             batched_options,
             log_progress,
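
The padding step above packs ragged audio segments into one dense batch before feature extraction. A toy illustration of the same torch.nested call with two dummy segments:

import torch

audio_segments = [torch.ones(3), torch.ones(5)]  # two segments of different length
padded = torch.nested.nested_tensor(audio_segments).to_padded_tensor(padding=0)

print(padded.shape)  # torch.Size([2, 5]) -- zero-padded to the longest segment
print(padded[0])     # tensor([1., 1., 1., 0., 0.])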
@@ -635,45 +540,40 @@
         return segments, info
 
     def _batched_segments_generator(
-        self, audio, vad_segments, sampling_rate, batch_size, options, log_progress
+        self, features, segments_metadata, batch_size, options, log_progress
     ):
+        pbar = tqdm(total=len(features), disable=not log_progress, position=0)
         seg_idx = 0
-        total_segments = len(vad_segments)
-        for idx, out in enumerate(
-            self.__call__(
-                self.audio_split(audio, vad_segments, sampling_rate),
-                batch_size=batch_size,
-                options=options,
+        for i in range(0, len(features), batch_size):
+            results = self.forward(
+                features[i : i + batch_size],
+                segments_metadata[i : i + batch_size],
+                **options._asdict(),
             )
-        ):
-            if log_progress:
-                percent_complete = ((idx + 1) / total_segments) * 100
-                self.model.logger.info(f"Progress: {percent_complete:.2f}%...")
-
-            responses = out["output"]
-            if batch_size == 1:
-                responses = responses[0]
-
-            for response in responses:
-                seg_idx += 1
-                segments = Segment(
-                    seek=int(responses[-1]["end"] * self.model.frames_per_second),
-                    id=seg_idx,
-                    text=response["text"],
-                    start=round(response["start"], 3),
-                    end=round(response["end"], 3),
-                    words=(
-                        None
-                        if not options.word_timestamps
-                        else [Word(**word) for word in response["words"]]
-                    ),
-                    tokens=response["tokens"],
-                    avg_logprob=response["avg_logprob"],
-                    no_speech_prob=response["no_speech_prob"],
-                    compression_ratio=response["compression_ratio"],
-                )
-                yield segments
 
+            for result in results:
+                for segment in result:
+                    seg_idx += 1
+                    yield Segment(
+                        seek=int(result[-1]["end"] * self.model.frames_per_second),
+                        id=seg_idx,
+                        text=segment["text"],
+                        start=round(segment["start"], 3),
+                        end=round(segment["end"], 3),
+                        words=(
+                            None
+                            if not options.word_timestamps
+                            else [Word(**word) for word in segment["words"]]
+                        ),
+                        tokens=segment["tokens"],
+                        avg_logprob=segment["avg_logprob"],
+                        no_speech_prob=segment["no_speech_prob"],
+                        compression_ratio=segment["compression_ratio"],
+                    )
+
+            pbar.update(1)
+
+        pbar.close()
         # revert the tokenizer if multilingual inference is enabled
         if self.preset_language is None:
             self.tokenizer = None
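
End to end, batching and progress reporting now live in this generator (plain slicing plus tqdm) rather than in transformers' pipeline code. A usage sketch with placeholder model name, file name, and batch size, assuming the batch_size and log_progress keyword arguments used in the diff:

from faster_whisper import BatchedInferencePipeline, WhisperModel

model = WhisperModel("large-v3", device="cuda")  # placeholder model/device
batched = BatchedInferencePipeline(model)

# batch_size and log_progress are forwarded to the generator above;
# tqdm now reports progress instead of the old logger messages.
segments, info = batched.transcribe("audio.wav", batch_size=8, log_progress=True)

for segment in segments:  # lazily yields Segment namedtuples, one per sub-segment
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")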

requirements.txt

+2 −2
@@ -2,7 +2,7 @@ ctranslate2>=4.0,<5
 huggingface_hub>=0.13
 tokenizers>=0.13,<1
 onnxruntime>=1.14,<2
-transformers
 pyannote-audio>=3.1.1
 torch>=2.1.1
-torchaudio>=2.1.2
+torchaudio>=2.1.2
+tqdm
