Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions whisper/demo_seek_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env python3
"""Demo: transcribe short audio with whisper-base.en to test seek behavior.

Downloads a short LJ-Speech clip (if not already present), transcribes it
with mlx_whisper, and reports word-overlap accuracy against the known
reference transcription plus the emitted segment boundaries.
"""

import os
import string
import urllib.request

import mlx_whisper

AUDIO_FILE = "LJ037-0171.wav"
AUDIO_URL = f"https://keithito.com/LJ-Speech-Dataset/{AUDIO_FILE}"

# Reference transcription of the clip, used only to score word overlap.
EXPECTED = (
    "The examination and testimony of the experts enabled the commission "
    "to conclude that five shots may have been fired"
)


def _word_set(text: str) -> set:
    """Return the set of lowercase, punctuation-stripped words in *text*."""
    strip = str.maketrans("", "", string.punctuation)
    return set(text.lower().translate(strip).split())


def _download(url: str, dest: str) -> None:
    """Fetch *url* to *dest* atomically.

    Downloading to a temp name and renaming means an interrupted run cannot
    leave a truncated file that a later run would mistake for valid audio.
    """
    tmp = dest + ".part"
    urllib.request.urlretrieve(url, tmp)
    os.replace(tmp, dest)  # atomic rename on both POSIX and Windows


def main() -> None:
    """Run the transcription demo and print accuracy and segments."""
    if not os.path.exists(AUDIO_FILE):
        _download(AUDIO_URL, AUDIO_FILE)

    result = mlx_whisper.transcribe(
        AUDIO_FILE, path_or_hf_repo="mlx-community/whisper-base.en-mlx"
    )

    # Word-recall against the reference. EXPECTED is non-empty by
    # construction, so the division below is safe.
    expected_words = _word_set(EXPECTED)
    actual_words = _word_set(result["text"])
    accuracy = len(expected_words & actual_words) / len(expected_words) * 100

    print(f"Expected: {EXPECTED}")
    print(f"Actual: {result['text'].strip()}")
    print(f"Accuracy: {accuracy:.0f}%")
    print(f"Segments: {len(result['segments'])}")
    for s in result["segments"]:
        print(f" [{s['start']:.2f}s - {s['end']:.2f}s]{s['text']}")


if __name__ == "__main__":
    main()
37 changes: 34 additions & 3 deletions whisper/mlx_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,22 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
last_slice = current_slice

if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
# When single_timestamp_ending and there's remaining audio,
# advance to the timestamp position instead of full segment to avoid
# skipping content in short audio clips.
last_timestamp_token = tokens[-1].item()
if last_timestamp_token != tokenizer.timestamp_begin:
last_timestamp_pos = (
last_timestamp_token - tokenizer.timestamp_begin
)
timestamp_seek = last_timestamp_pos * input_stride
# Only use timestamp-based seek if there's remaining audio
if seek + timestamp_seek < content_frames:
seek += timestamp_seek
else:
seek += segment_size
else:
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = (
Expand Down Expand Up @@ -409,7 +423,24 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
result=result,
)
)
seek += segment_size
# When single_timestamp_ending and there's remaining audio,
# advance to the timestamp position instead of full segment to avoid
# skipping content in short audio clips.
if (
single_timestamp_ending
and len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
timestamp_seek = last_timestamp_pos * input_stride
if seek + timestamp_seek < content_frames:
seek += timestamp_seek
else:
seek += segment_size
else:
seek += segment_size

if word_timestamps:
add_word_timestamps(
Expand Down