Skip to content

Commit 1d9c197

Browse files
committed
refactor: extract duplicate code, add VAD tests, add duration to response
- Extract duplicate segment processing into inner process_segment() function - Add duration tracking to VAD mode for consistent response format - Add 3 new tests: single segment, multiple segments, VAD not available - Remove unused .claude/REPORT.md from commit
1 parent c52eb77 commit 1d9c197

File tree

2 files changed

+377
-61
lines changed

2 files changed

+377
-61
lines changed

agent_cli/server/whisper/api.py

Lines changed: 229 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,42 @@
1616
from agent_cli.server.whisper.backends.base import InvalidAudioError
1717

1818
if TYPE_CHECKING:
19+
from agent_cli.core.vad import VoiceActivityDetector
1920
from agent_cli.server.whisper.model_registry import WhisperModelRegistry
2021

2122
logger = logging.getLogger(__name__)
2223

24+
# VAD availability check - the vad extra may not be installed
25+
_VAD_AVAILABLE = False
26+
try:
27+
from agent_cli.core.vad import VoiceActivityDetector as _VoiceActivityDetector
28+
29+
_VAD_AVAILABLE = True
30+
except ImportError:
31+
_VoiceActivityDetector = None # type: ignore[misc, assignment]
32+
33+
34+
def _create_vad(
35+
threshold: float,
36+
silence_threshold_ms: int,
37+
min_speech_duration_ms: int,
38+
) -> VoiceActivityDetector:
39+
"""Create a VoiceActivityDetector instance.
40+
41+
Raises ImportError if VAD is not available.
42+
"""
43+
if not _VAD_AVAILABLE:
44+
msg = (
45+
"VAD is not available. Install it with: "
46+
"`pip install agent-cli[vad]` or `uv sync --extra vad`"
47+
)
48+
raise ImportError(msg)
49+
return _VoiceActivityDetector(
50+
threshold=threshold,
51+
silence_threshold_ms=silence_threshold_ms,
52+
min_speech_duration_ms=min_speech_duration_ms,
53+
)
54+
2355

2456
def _split_seconds(seconds: float) -> tuple[int, int, int, int]:
2557
"""Split seconds into (hours, minutes, seconds, milliseconds)."""
@@ -316,17 +348,41 @@ async def stream_transcription(
316348
websocket: WebSocket,
317349
model: Annotated[str | None, Query(description="Model to use")] = None,
318350
language: Annotated[str | None, Query(description="Language code")] = None,
351+
use_vad: Annotated[
352+
bool,
353+
Query(description="Enable VAD for streaming partial results"),
354+
] = True,
355+
vad_threshold: Annotated[
356+
float,
357+
Query(description="Speech detection threshold (0.0-1.0)", ge=0.0, le=1.0),
358+
] = 0.3,
359+
vad_silence_ms: Annotated[
360+
int,
361+
Query(description="Silence duration (ms) to end speech segment", ge=100, le=5000),
362+
] = 1000,
363+
vad_min_speech_ms: Annotated[
364+
int,
365+
Query(
366+
description="Minimum speech duration (ms) to trigger transcription",
367+
ge=50,
368+
le=2000,
369+
),
370+
] = 250,
319371
) -> None:
320-
"""WebSocket endpoint for streaming transcription.
372+
"""WebSocket endpoint for streaming transcription with optional VAD.
321373
322374
Protocol:
323375
- Client sends binary audio chunks (16kHz, 16-bit, mono PCM)
324376
- Client sends b"EOS" to signal end of audio
325377
- Server sends JSON messages with transcription results
326378
379+
When use_vad=True (default):
380+
- Partial transcriptions are sent as speech segments complete
381+
- Final message contains combined text from all segments
382+
327383
Message format from server:
328-
{"type": "partial", "text": "...", "is_final": false}
329-
{"type": "final", "text": "...", "is_final": true, "segments": [...]}
384+
{"type": "partial", "text": "...", "is_final": false, "language": "..."}
385+
{"type": "final", "text": "...", "is_final": true, "language": "...", ...}
330386
{"type": "error", "message": "..."}
331387
"""
332388
await websocket.accept()
@@ -340,74 +396,188 @@ async def stream_transcription(
340396
await websocket.close()
341397
return
342398

343-
# Collect audio data
344-
audio_buffer = io.BytesIO()
345-
wav_file: wave.Wave_write | None = None
346-
347-
try:
348-
while True:
349-
data = await websocket.receive_bytes()
350-
351-
# Initialize WAV file on first chunk (before EOS check)
352-
if wav_file is None:
353-
wav_file = wave.open(audio_buffer, "wb") # noqa: SIM115
354-
setup_wav_file(wav_file)
355-
356-
# Check for end of stream (EOS marker)
357-
eos_marker = b"EOS"
358-
eos_len = len(eos_marker)
359-
if data == eos_marker:
360-
break
361-
if data[-eos_len:] == eos_marker:
362-
# Write remaining data before EOS marker
363-
if len(data) > eos_len:
364-
wav_file.writeframes(data[:-eos_len])
365-
break
366-
367-
wav_file.writeframes(data)
368-
369-
# Close WAV file
370-
if wav_file is not None:
371-
wav_file.close()
372-
373-
# Get audio data
374-
audio_buffer.seek(0)
375-
audio_data = audio_buffer.read()
376-
377-
if not audio_data:
378-
await websocket.send_json({"type": "error", "message": "No audio received"})
399+
# Initialize VAD if requested
400+
vad = None
401+
if use_vad:
402+
try:
403+
vad = _create_vad(
404+
threshold=vad_threshold,
405+
silence_threshold_ms=vad_silence_ms,
406+
min_speech_duration_ms=vad_min_speech_ms,
407+
)
408+
except ImportError as e:
409+
await websocket.send_json({"type": "error", "message": str(e)})
379410
await websocket.close()
380411
return
381412

382-
# Transcribe
383-
try:
384-
result = await manager.transcribe(
385-
audio_data,
386-
language=language,
387-
task="transcribe",
388-
)
413+
try:
414+
if vad is not None:
415+
# VAD-enabled streaming mode
416+
await _stream_with_vad(websocket, manager, vad, language)
417+
else:
418+
# Legacy buffered mode (no VAD)
419+
await _stream_buffered(websocket, manager, language)
420+
except Exception as e:
421+
logger.exception("WebSocket error")
422+
with contextlib.suppress(Exception):
423+
await websocket.send_json({"type": "error", "message": str(e)})
424+
finally:
425+
with contextlib.suppress(Exception):
426+
await websocket.close()
389427

428+
async def _stream_with_vad(
    websocket: WebSocket,
    manager: Any,
    vad: VoiceActivityDetector,
    language: str | None,
) -> None:
    """Handle streaming transcription with VAD-based segmentation.

    Feeds incoming PCM chunks through the VAD; each completed speech
    segment is transcribed and emitted as a "partial" message. After the
    client's EOS marker any buffered audio is flushed, then a single
    "final" message with the combined text and total duration is sent.

    Args:
        websocket: Accepted WebSocket delivering binary PCM chunks.
        manager: Model manager exposing an async ``transcribe`` method.
        vad: Voice activity detector used to segment the stream.
        language: Optional language code forwarded to the backend.
    """
    all_segments_text: list[str] = []
    total_duration: float = 0.0
    final_language: str | None = None
    eos_marker = b"EOS"
    eos_len = len(eos_marker)

    async def process_segment(segment: bytes) -> None:
        """Transcribe segment and send partial result."""
        nonlocal final_language, total_duration
        result = await _transcribe_segment(manager, segment, language)
        if not result:
            return
        # Hoisted: the original computed result.text.strip() three times.
        text = result.text.strip()
        if not text:
            return
        all_segments_text.append(text)
        final_language = result.language
        total_duration += result.duration
        await websocket.send_json(
            {
                "type": "partial",
                "text": text,
                "is_final": False,
                "language": result.language,
            },
        )

    while True:
        data = await websocket.receive_bytes()

        # The end-of-stream marker may arrive alone or appended to the
        # tail of a final audio chunk.
        is_eos = data == eos_marker
        audio_chunk = b""

        if is_eos:
            pass  # Bare EOS: no audio to process.
        elif data.endswith(eos_marker):
            # Audio followed by EOS marker.
            audio_chunk = data[:-eos_len]
            is_eos = True
        else:
            audio_chunk = data

        # Feed audio through the VAD; a completed segment triggers an
        # immediate partial transcription.
        if audio_chunk:
            _is_speaking, segment = vad.process_chunk(audio_chunk)
            if segment:
                await process_segment(segment)

        if is_eos:
            # Flush any remaining audio in the VAD buffer.
            if remaining := vad.flush():
                await process_segment(remaining)
            break

    # Send final combined result.
    await websocket.send_json(
        {
            "type": "final",
            "text": " ".join(all_segments_text),
            "is_final": True,
            "language": final_language,
            "duration": total_duration,
        },
    )
497+
498+
async def _stream_buffered(
    websocket: WebSocket,
    manager: Any,
    language: str | None,
) -> None:
    """Handle streaming transcription with buffered mode (no VAD).

    Accumulates every incoming chunk into one in-memory WAV file and runs
    a single transcription once the client signals end-of-stream.
    """
    eos = b"EOS"
    buffer = io.BytesIO()
    writer: wave.Wave_write | None = None

    while True:
        chunk = await websocket.receive_bytes()

        # Lazily open the WAV writer on the very first message, even when
        # that message is only the EOS marker.
        if writer is None:
            writer = wave.open(buffer, "wb")  # noqa: SIM115
            setup_wav_file(writer)

        if chunk == eos:
            break
        if chunk[-len(eos):] == eos:
            # Final chunk: audio payload followed by the marker.
            payload = chunk[: -len(eos)]
            if payload:
                writer.writeframes(payload)
            break

        writer.writeframes(chunk)

    if writer is not None:
        writer.close()

    buffer.seek(0)
    audio_data = buffer.read()

    if not audio_data:
        await websocket.send_json({"type": "error", "message": "No audio received"})
        return

    # Single transcription over the whole buffered recording; failures are
    # reported to the client rather than raised.
    try:
        result = await manager.transcribe(
            audio_data,
            language=language,
            task="transcribe",
        )

        await websocket.send_json(
            {
                "type": "final",
                "text": result.text,
                "is_final": True,
                "language": result.language,
                "duration": result.duration,
                "segments": result.segments,
            },
        )
    except Exception as e:
        await websocket.send_json({"type": "error", "message": str(e)})
408560

409-
finally:
410-
with contextlib.suppress(Exception):
411-
await websocket.close()
561+
async def _transcribe_segment(
    manager: Any,
    segment: bytes,
    language: str | None,
) -> Any | None:
    """Transcribe a raw PCM audio segment by wrapping it in WAV format.

    Returns the backend result, or None when transcription fails — the
    error is logged instead of propagated so that one bad segment does
    not terminate the stream.
    """
    try:
        # Raw PCM must be wrapped in a WAV container before transcription.
        container = io.BytesIO()
        writer = wave.open(container, "wb")
        try:
            setup_wav_file(writer)
            writer.writeframes(segment)
        finally:
            writer.close()
        return await manager.transcribe(
            container.getvalue(),
            language=language,
            task="transcribe",
        )
    except Exception:
        logger.exception("Failed to transcribe segment")
        return None
412582

413583
return app

0 commit comments

Comments
 (0)