diff --git a/whisper/demo_seek_fix.py b/whisper/demo_seek_fix.py new file mode 100644 index 000000000..d870f7de5 --- /dev/null +++ b/whisper/demo_seek_fix.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +"""Demo: transcribe short audio with whisper-base.en to test seek behavior.""" + +import os +import string +import urllib.request +import mlx_whisper + +# Download test audio +audio = "LJ037-0171.wav" +if not os.path.exists(audio): + urllib.request.urlretrieve(f"https://keithito.com/LJ-Speech-Dataset/{audio}", audio) + +# Expected transcription +expected = "The examination and testimony of the experts enabled the commission to conclude that five shots may have been fired" + +# Transcribe +result = mlx_whisper.transcribe(audio, path_or_hf_repo="mlx-community/whisper-base.en-mlx") + +# Compute accuracy +strip = str.maketrans("", "", string.punctuation) +expected_words = set(expected.lower().translate(strip).split()) +actual_words = set(result["text"].lower().translate(strip).split()) +accuracy = len(expected_words & actual_words) / len(expected_words) * 100 + +# Output +print(f"Expected: {expected}") +print(f"Actual: {result['text'].strip()}") +print(f"Accuracy: {accuracy:.0f}%") +print(f"Segments: {len(result['segments'])}") +for s in result["segments"]: + print(f"  [{s['start']:.2f}s - {s['end']:.2f}s]{s['text']}") diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py index bced16a58..ae4af6dff 100644 --- a/whisper/mlx_whisper/transcribe.py +++ b/whisper/mlx_whisper/transcribe.py @@ -380,8 +380,22 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: last_slice = current_slice if single_timestamp_ending: - # single timestamp at the end means no speech after the last timestamp. - seek += segment_size + # When single_timestamp_ending and there's remaining audio, + # advance to the timestamp position instead of full segment to avoid + # skipping content in short audio clips.
+ last_timestamp_token = tokens[-1].item() + if last_timestamp_token != tokenizer.timestamp_begin: + last_timestamp_pos = ( + last_timestamp_token - tokenizer.timestamp_begin + ) + timestamp_seek = last_timestamp_pos * input_stride + # Only use timestamp-based seek if there's remaining audio + if seek + timestamp_seek < content_frames: + seek += timestamp_seek + else: + seek += segment_size + else: + seek += segment_size else: # otherwise, ignore the unfinished segment and seek to the last timestamp last_timestamp_pos = ( @@ -409,7 +423,24 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: result=result, ) ) - seek += segment_size + # When single_timestamp_ending and there's remaining audio, + # advance to the timestamp position instead of full segment to avoid + # skipping content in short audio clips. + if ( + single_timestamp_ending + and len(timestamps) > 0 + and timestamps[-1].item() != tokenizer.timestamp_begin + ): + last_timestamp_pos = ( + timestamps[-1].item() - tokenizer.timestamp_begin + ) + timestamp_seek = last_timestamp_pos * input_stride + if seek + timestamp_seek < content_frames: + seek += timestamp_seek + else: + seek += segment_size + else: + seek += segment_size if word_timestamps: add_word_timestamps(