Skip to content

Commit

Permalink
Merge pull request #475 from niwa2/fix_total_time
Browse files Browse the repository at this point in the history
Fix total time elapsed does not account for all steps
  • Loading branch information
jhj0517 authored Jan 24, 2025
2 parents e44f57f + 009faec commit 0721cc0
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions modules/whisper/base_transcription_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from faster_whisper.vad import VadOptions
import gc
from copy import deepcopy
import time

from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
Expand Down Expand Up @@ -109,6 +110,8 @@ def run(self,
elapsed_time: float
elapsed time for running
"""
start_time = time.time()

if not validate_audio(audio):
logger.info(f"The audio file {audio} is not able to open or corrupted. Please check the file.")
return [Segment()], 0
Expand Down Expand Up @@ -137,10 +140,12 @@ def run(self,

if bgm_params.enable_offload:
self.music_separator.offload()
elapsed_time_bgm_sep = time.time() - start_time

origin_audio = deepcopy(audio)

if vad_params.vad_filter:
progress(0, desc="Filtering silent parts from audio..")
vad_options = VadOptions(
threshold=vad_params.threshold,
min_speech_duration_ms=vad_params.min_speech_duration_ms,
Expand All @@ -160,7 +165,7 @@ def run(self,
else:
vad_params.vad_filter = False

result, elapsed_time = self.transcribe(
result, elapsed_time_transcription = self.transcribe(
audio,
progress,
*whisper_params.to_list()
Expand All @@ -177,20 +182,23 @@ def run(self,
logger.info("VAD detected no speech segments in the audio.")

if diarization_params.is_diarize:
progress(0.99, desc="Diarizing speakers..")
result, elapsed_time_diarization = self.diarizer.run(
audio=origin_audio,
use_auth_token=diarization_params.hf_token if diarization_params.hf_token else os.environ.get("HF_TOKEN"),
transcribed_result=result,
device=diarization_params.diarization_device
)
elapsed_time += elapsed_time_diarization

self.cache_parameters(
params=params,
file_format=file_format,
add_timestamp=add_timestamp
)
return result, elapsed_time

progress(1.0, desc="Finished.")
total_elapsed_time = time.time() - start_time
return result, total_elapsed_time

def transcribe_file(self,
files: Optional[List] = None,
Expand Down

0 comments on commit 0721cc0

Please sign in to comment.