Skip to content

Commit

Permalink
Added --task parameter to allow translating to English from other languages
Browse files Browse the repository at this point in the history
  • Loading branch information
M4TH1EU committed Jan 28, 2025
1 parent 85bd1e7 commit 1547915
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 0 deletions.
105 changes: 105 additions & 0 deletions tests/test_faster_whisper_translate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Tests for wyoming-faster-whisper"""

import asyncio
import re
import sys
import wave
from asyncio.subprocess import PIPE
from pathlib import Path

import pytest
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioStart, AudioStop, wav_to_chunks
from wyoming.event import async_read_event, async_write_event
from wyoming.info import Describe, Info

# Directory containing this test module; the sample WAV used below is opened from here.
_DIR = Path(__file__).parent
# Parent directory of the tests directory.
_PROGRAM_DIR = _DIR.parent
# Passed to the server as its --data-dir argument.
_LOCAL_DIR = _PROGRAM_DIR / "local"
# Chunk size handed to wav_to_chunks when streaming the test audio.
_SAMPLES_PER_CHUNK = 1024

# Need to give time for the model to download
# Both values are asyncio.wait_for timeouts (seconds) applied to each single
# event read: _START_TIMEOUT while waiting for the Info reply,
# _TRANSCRIBE_TIMEOUT while waiting for the Transcript.
_START_TIMEOUT = 60
_TRANSCRIBE_TIMEOUT = 60


@pytest.mark.asyncio
async def test_faster_whisper() -> None:
    """Check that ``--task translate`` turns French speech into English text.

    Starts ``wyoming_faster_whisper`` as a subprocess speaking the Wyoming
    protocol over stdio with ``--task translate --language fr``, verifies the
    advertised ASR info, streams a known French WAV file, and asserts that the
    returned transcript is the expected English sentence. Finally closes stdin
    and requires the server to exit with status 0.
    """
    model_name = "base-int8"

    proc = await asyncio.create_subprocess_exec(
        sys.executable,
        "-m",
        "wyoming_faster_whisper",
        "--uri",
        "stdio://",
        "--model",
        model_name,
        "--data-dir",
        str(_LOCAL_DIR),
        "--task",
        "translate",
        "--language",
        "fr",
        stdin=PIPE,
        stdout=PIPE,
    )

    try:
        assert proc.stdin is not None
        assert proc.stdout is not None

        # Check info
        await async_write_event(Describe().event(), proc.stdin)
        while True:
            event = await asyncio.wait_for(
                async_read_event(proc.stdout), timeout=_START_TIMEOUT
            )
            assert event is not None

            # Skip anything that is not the Info reply.
            if not Info.is_type(event.type):
                continue

            info = Info.from_event(event)
            assert len(info.asr) == 1, "Expected one asr service"
            asr = info.asr[0]
            assert len(asr.models) > 0, "Expected at least one model"
            assert any(
                m.name == model_name for m in asr.models
            ), "Expected base-int8 model"
            break

        # We want to use the whisper model
        await async_write_event(Transcribe(name=model_name).event(), proc.stdin)

        # Test known WAV
        wav_path = _DIR / "whats_your_name_french.wav"
        with wave.open(str(wav_path), "rb") as example_wav:
            await async_write_event(
                AudioStart(
                    rate=example_wav.getframerate(),
                    width=example_wav.getsampwidth(),
                    channels=example_wav.getnchannels(),
                ).event(),
                proc.stdin,
            )
            for chunk in wav_to_chunks(example_wav, _SAMPLES_PER_CHUNK):
                await async_write_event(chunk.event(), proc.stdin)

        await async_write_event(AudioStop().event(), proc.stdin)

        while True:
            event = await asyncio.wait_for(
                async_read_event(proc.stdout), timeout=_TRANSCRIBE_TIMEOUT
            )
            assert event is not None

            # Skip anything that is not the Transcript reply.
            if not Transcript.is_type(event.type):
                continue

            transcript = Transcript.from_event(event)
            # Normalize case and drop punctuation so only the words are compared.
            text = transcript.text.lower().strip()
            text = re.sub(r"[^a-z ]", "", text)
            assert text == "how do you call yourself"
            break

        # Need to close stdin for graceful termination
        proc.stdin.close()
        await proc.communicate()

        # stderr is not piped, so communicate() yields None for it; report the
        # exit code instead of calling .decode() on None. The server's stderr
        # is inherited from the test process and appears in its output.
        assert (
            proc.returncode == 0
        ), f"Server exited with code {proc.returncode}"
    finally:
        # If an assertion or timeout fired above, the server is still running;
        # kill and reap it so a failed test does not leave a stray process.
        if proc.returncode is None:
            proc.kill()
            await proc.wait()
Binary file added tests/whats_your_name_french.wav
Binary file not shown.
6 changes: 6 additions & 0 deletions wyoming_faster_whisper/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ async def main() -> None:
"--initial-prompt",
help="Optional text to provide as a prompt for the first window",
)
parser.add_argument(
"--task",
default="transcribe",
help="Whether to transcribe or translate (default: transcribe)",
choices=["transcribe", "translate"],
)
#
parser.add_argument("--debug", action="store_true", help="Log DEBUG messages")
parser.add_argument(
Expand Down
2 changes: 2 additions & 0 deletions wyoming_faster_whisper/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
self.model_lock = model_lock
self.initial_prompt = initial_prompt
self._language = self.cli_args.language
self._task = self.cli_args.task
self._wav_dir = tempfile.TemporaryDirectory()
self._wav_path = os.path.join(self._wav_dir.name, "speech.wav")
self._wav_file: Optional[wave.Wave_write] = None
Expand Down Expand Up @@ -71,6 +72,7 @@ async def handle_event(self, event: Event) -> bool:
beam_size=self.cli_args.beam_size,
language=self._language,
initial_prompt=self.initial_prompt,
task=self._task,
)

text = " ".join(segment.text for segment in segments)
Expand Down

0 comments on commit 1547915

Please sign in to comment.