|
4 | 4 | import numpy as np
|
5 | 5 | from typing import BinaryIO, Union, List, Optional, Tuple
|
6 | 6 | import warnings
|
| 7 | +import bisect |
7 | 8 | import faster_whisper
|
8 |
| -from modules.whisper.data_classes import * |
9 | 9 | from faster_whisper.transcribe import SpeechTimestampsMap
|
10 | 10 | import gradio as gr
|
11 | 11 |
|
| 12 | +from modules.whisper.data_classes import * |
| 13 | + |
12 | 14 |
|
13 | 15 | class SileroVAD:
|
14 | 16 | def __init__(self):
|
@@ -58,6 +60,7 @@ def run(self,
|
58 | 60 | vad_options=vad_parameters,
|
59 | 61 | progress=progress
|
60 | 62 | )
|
| 63 | + |
61 | 64 | audio = self.collect_chunks(audio, speech_chunks)
|
62 | 65 | duration_after_vad = audio.shape[0] / sampling_rate
|
63 | 66 |
|
@@ -94,35 +97,27 @@ def get_speech_timestamps(
|
94 | 97 | min_silence_duration_ms = vad_options.min_silence_duration_ms
|
95 | 98 | window_size_samples = self.window_size_samples
|
96 | 99 | speech_pad_ms = vad_options.speech_pad_ms
|
97 |
| - sampling_rate = 16000 |
98 |
| - min_speech_samples = sampling_rate * min_speech_duration_ms / 1000 |
99 |
| - speech_pad_samples = sampling_rate * speech_pad_ms / 1000 |
| 100 | + min_speech_samples = self.sampling_rate * min_speech_duration_ms / 1000 |
| 101 | + speech_pad_samples = self.sampling_rate * speech_pad_ms / 1000 |
100 | 102 | max_speech_samples = (
|
101 |
| - sampling_rate * max_speech_duration_s |
| 103 | + self.sampling_rate * max_speech_duration_s |
102 | 104 | - window_size_samples
|
103 | 105 | - 2 * speech_pad_samples
|
104 | 106 | )
|
105 |
| - min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 |
106 |
| - min_silence_samples_at_max_speech = sampling_rate * 98 / 1000 |
| 107 | + min_silence_samples = self.sampling_rate * min_silence_duration_ms / 1000 |
| 108 | + min_silence_samples_at_max_speech = self.sampling_rate * 98 / 1000 |
107 | 109 |
|
108 | 110 | audio_length_samples = len(audio)
|
109 | 111 |
|
110 |
| - state, context = self.model.get_initial_states(batch_size=1) |
111 |
| - |
112 |
| - speech_probs = [] |
113 |
| - for current_start_sample in range(0, audio_length_samples, window_size_samples): |
114 |
| - progress(current_start_sample/audio_length_samples, desc="Detecting speeches only using VAD...") |
115 |
| - |
116 |
| - chunk = audio[current_start_sample: current_start_sample + window_size_samples] |
117 |
| - if len(chunk) < window_size_samples: |
118 |
| - chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk)))) |
119 |
| - speech_prob, state, context = self.model(chunk, state, context, sampling_rate) |
120 |
| - speech_probs.append(speech_prob) |
| 112 | + padded_audio = np.pad( |
| 113 | + audio, (0, window_size_samples - audio.shape[0] % window_size_samples) |
| 114 | + ) |
| 115 | + speech_probs = self.model(padded_audio.reshape(1, -1)).squeeze(0) |
121 | 116 |
|
122 | 117 | triggered = False
|
123 | 118 | speeches = []
|
124 | 119 | current_speech = {}
|
125 |
| - neg_threshold = threshold - 0.15 |
| 120 | + neg_threshold = vad_options.neg_threshold |
126 | 121 |
|
127 | 122 | # to save potential segment end (and tolerate some silence)
|
128 | 123 | temp_end = 0
|
@@ -258,8 +253,23 @@ def restore_speech_timestamps(
|
258 | 253 | ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
|
259 | 254 |
|
260 | 255 | for segment in segments:
|
261 |
| - segment.start = ts_map.get_original_time(segment.start) |
262 |
| - segment.end = ts_map.get_original_time(segment.end) |
| 256 | + if segment.words: |
| 257 | + words = [] |
| 258 | + for word in segment.words: |
| 259 | + # Ensure the word start and end times are resolved to the same chunk. |
| 260 | + middle = (word.start + word.end) / 2 |
| 261 | + chunk_index = ts_map.get_chunk_index(middle) |
| 262 | + word.start = ts_map.get_original_time(word.start, chunk_index) |
| 263 | + word.end = ts_map.get_original_time(word.end, chunk_index) |
| 264 | + words.append(word) |
| 265 | + |
| 266 | + segment.start = words[0].start |
| 267 | + segment.end = words[-1].end |
| 268 | + segment.words = words |
| 269 | + |
| 270 | + else: |
| 271 | + segment.start = ts_map.get_original_time(segment.start) |
| 272 | + segment.end = ts_map.get_original_time(segment.end) |
263 | 273 |
|
264 | 274 | return segments
|
265 | 275 |
|
0 commit comments