1616from agent_cli .server .whisper .backends .base import InvalidAudioError
1717
1818if TYPE_CHECKING :
19+ from agent_cli .core .vad import VoiceActivityDetector
1920 from agent_cli .server .whisper .model_registry import WhisperModelRegistry
2021
2122logger = logging .getLogger (__name__ )
2223
24+ # VAD availability check - the vad extra may not be installed
25+ _VAD_AVAILABLE = False
26+ try :
27+ from agent_cli .core .vad import VoiceActivityDetector as _VoiceActivityDetector
28+
29+ _VAD_AVAILABLE = True
30+ except ImportError :
31+ _VoiceActivityDetector = None # type: ignore[misc, assignment]
32+
33+
def _create_vad(
    threshold: float,
    silence_threshold_ms: int,
    min_speech_duration_ms: int,
) -> "VoiceActivityDetector":
    """Create a VoiceActivityDetector instance.

    Raises ImportError if VAD is not available.
    """
    if _VAD_AVAILABLE:
        return _VoiceActivityDetector(
            threshold=threshold,
            silence_threshold_ms=silence_threshold_ms,
            min_speech_duration_ms=min_speech_duration_ms,
        )
    msg = (
        "VAD is not available. Install it with: "
        "`pip install agent-cli[vad]` or `uv sync --extra vad`"
    )
    raise ImportError(msg)
54+
2355
2456def _split_seconds (seconds : float ) -> tuple [int , int , int , int ]:
2557 """Split seconds into (hours, minutes, seconds, milliseconds)."""
@@ -316,17 +348,41 @@ async def stream_transcription(
316348 websocket : WebSocket ,
317349 model : Annotated [str | None , Query (description = "Model to use" )] = None ,
318350 language : Annotated [str | None , Query (description = "Language code" )] = None ,
351+ use_vad : Annotated [
352+ bool ,
353+ Query (description = "Enable VAD for streaming partial results" ),
354+ ] = True ,
355+ vad_threshold : Annotated [
356+ float ,
357+ Query (description = "Speech detection threshold (0.0-1.0)" , ge = 0.0 , le = 1.0 ),
358+ ] = 0.3 ,
359+ vad_silence_ms : Annotated [
360+ int ,
361+ Query (description = "Silence duration (ms) to end speech segment" , ge = 100 , le = 5000 ),
362+ ] = 1000 ,
363+ vad_min_speech_ms : Annotated [
364+ int ,
365+ Query (
366+ description = "Minimum speech duration (ms) to trigger transcription" ,
367+ ge = 50 ,
368+ le = 2000 ,
369+ ),
370+ ] = 250 ,
319371 ) -> None :
320- """WebSocket endpoint for streaming transcription.
372+ """WebSocket endpoint for streaming transcription with optional VAD .
321373
322374 Protocol:
323375 - Client sends binary audio chunks (16kHz, 16-bit, mono PCM)
324376 - Client sends b"EOS" to signal end of audio
325377 - Server sends JSON messages with transcription results
326378
379+ When use_vad=True (default):
380+ - Partial transcriptions are sent as speech segments complete
381+ - Final message contains combined text from all segments
382+
327383 Message format from server:
328- {"type": "partial", "text": "...", "is_final": false}
329- {"type": "final", "text": "...", "is_final": true, "segments ": [ ...] }
384+ {"type": "partial", "text": "...", "is_final": false, "language": "..." }
385+ {"type": "final", "text": "...", "is_final": true, "language ": " ...", ... }
330386 {"type": "error", "message": "..."}
331387 """
332388 await websocket .accept ()
@@ -340,74 +396,188 @@ async def stream_transcription(
340396 await websocket .close ()
341397 return
342398
343- # Collect audio data
344- audio_buffer = io .BytesIO ()
345- wav_file : wave .Wave_write | None = None
346-
347- try :
348- while True :
349- data = await websocket .receive_bytes ()
350-
351- # Initialize WAV file on first chunk (before EOS check)
352- if wav_file is None :
353- wav_file = wave .open (audio_buffer , "wb" ) # noqa: SIM115
354- setup_wav_file (wav_file )
355-
356- # Check for end of stream (EOS marker)
357- eos_marker = b"EOS"
358- eos_len = len (eos_marker )
359- if data == eos_marker :
360- break
361- if data [- eos_len :] == eos_marker :
362- # Write remaining data before EOS marker
363- if len (data ) > eos_len :
364- wav_file .writeframes (data [:- eos_len ])
365- break
366-
367- wav_file .writeframes (data )
368-
369- # Close WAV file
370- if wav_file is not None :
371- wav_file .close ()
372-
373- # Get audio data
374- audio_buffer .seek (0 )
375- audio_data = audio_buffer .read ()
376-
377- if not audio_data :
378- await websocket .send_json ({"type" : "error" , "message" : "No audio received" })
399+ # Initialize VAD if requested
400+ vad = None
401+ if use_vad :
402+ try :
403+ vad = _create_vad (
404+ threshold = vad_threshold ,
405+ silence_threshold_ms = vad_silence_ms ,
406+ min_speech_duration_ms = vad_min_speech_ms ,
407+ )
408+ except ImportError as e :
409+ await websocket .send_json ({"type" : "error" , "message" : str (e )})
379410 await websocket .close ()
380411 return
381412
382- # Transcribe
383- try :
384- result = await manager .transcribe (
385- audio_data ,
386- language = language ,
387- task = "transcribe" ,
388- )
413+ try :
414+ if vad is not None :
415+ # VAD-enabled streaming mode
416+ await _stream_with_vad (websocket , manager , vad , language )
417+ else :
418+ # Legacy buffered mode (no VAD)
419+ await _stream_buffered (websocket , manager , language )
420+ except Exception as e :
421+ logger .exception ("WebSocket error" )
422+ with contextlib .suppress (Exception ):
423+ await websocket .send_json ({"type" : "error" , "message" : str (e )})
424+ finally :
425+ with contextlib .suppress (Exception ):
426+ await websocket .close ()
389427
428+ async def _stream_with_vad (
429+ websocket : WebSocket ,
430+ manager : Any ,
431+ vad : VoiceActivityDetector ,
432+ language : str | None ,
433+ ) -> None :
434+ """Handle streaming transcription with VAD-based segmentation."""
435+ all_segments_text : list [str ] = []
436+ total_duration : float = 0.0
437+ final_language : str | None = None
438+ eos_marker = b"EOS"
439+ eos_len = len (eos_marker )
440+
441+ async def process_segment (segment : bytes ) -> None :
442+ """Transcribe segment and send partial result."""
443+ nonlocal final_language , total_duration
444+ result = await _transcribe_segment (manager , segment , language )
445+ if result and result .text .strip ():
446+ all_segments_text .append (result .text .strip ())
447+ final_language = result .language
448+ total_duration += result .duration
390449 await websocket .send_json (
391450 {
392- "type" : "final " ,
393- "text" : result .text ,
394- "is_final" : True ,
451+ "type" : "partial " ,
452+ "text" : result .text . strip () ,
453+ "is_final" : False ,
395454 "language" : result .language ,
396- "duration" : result .duration ,
397- "segments" : result .segments ,
398455 },
399456 )
400457
401- except Exception as e :
402- await websocket .send_json ({"type" : "error" , "message" : str (e )})
458+ while True :
459+ data = await websocket .receive_bytes ()
460+
461+ # Check for end of stream
462+ is_eos = data == eos_marker
463+ audio_chunk = b""
464+
465+ if is_eos :
466+ pass # No audio to process
467+ elif data [- eos_len :] == eos_marker :
468+ # Audio followed by EOS marker
469+ audio_chunk = data [:- eos_len ]
470+ is_eos = True
471+ else :
472+ audio_chunk = data
473+
474+ # Process audio chunk through VAD
475+ if audio_chunk :
476+ _is_speaking , segment = vad .process_chunk (audio_chunk )
477+ if segment :
478+ await process_segment (segment )
479+
480+ if is_eos :
481+ # Flush any remaining audio in VAD buffer
482+ if remaining := vad .flush ():
483+ await process_segment (remaining )
484+ break
485+
486+ # Send final combined result
487+ final_text = " " .join (all_segments_text )
488+ await websocket .send_json (
489+ {
490+ "type" : "final" ,
491+ "text" : final_text ,
492+ "is_final" : True ,
493+ "language" : final_language ,
494+ "duration" : total_duration ,
495+ },
496+ )
497+
498+ async def _stream_buffered (
499+ websocket : WebSocket ,
500+ manager : Any ,
501+ language : str | None ,
502+ ) -> None :
503+ """Handle streaming transcription with buffered mode (no VAD)."""
504+ audio_buffer = io .BytesIO ()
505+ wav_file : wave .Wave_write | None = None
506+ eos_marker = b"EOS"
507+ eos_len = len (eos_marker )
508+
509+ while True :
510+ data = await websocket .receive_bytes ()
511+
512+ # Initialize WAV file on first chunk (before EOS check)
513+ if wav_file is None :
514+ wav_file = wave .open (audio_buffer , "wb" ) # noqa: SIM115
515+ setup_wav_file (wav_file )
516+
517+ # Check for end of stream
518+ if data == eos_marker :
519+ break
520+ if data [- eos_len :] == eos_marker :
521+ # Write remaining data before EOS marker
522+ if len (data ) > eos_len :
523+ wav_file .writeframes (data [:- eos_len ])
524+ break
525+
526+ wav_file .writeframes (data )
527+
528+ # Close WAV file
529+ if wav_file is not None :
530+ wav_file .close ()
531+
532+ # Get audio data
533+ audio_buffer .seek (0 )
534+ audio_data = audio_buffer .read ()
535+
536+ if not audio_data :
537+ await websocket .send_json ({"type" : "error" , "message" : "No audio received" })
538+ return
539+
540+ # Transcribe
541+ try :
542+ result = await manager .transcribe (
543+ audio_data ,
544+ language = language ,
545+ task = "transcribe" ,
546+ )
403547
548+ await websocket .send_json (
549+ {
550+ "type" : "final" ,
551+ "text" : result .text ,
552+ "is_final" : True ,
553+ "language" : result .language ,
554+ "duration" : result .duration ,
555+ "segments" : result .segments ,
556+ },
557+ )
404558 except Exception as e :
405- logger .exception ("WebSocket error" )
406- with contextlib .suppress (Exception ):
407- await websocket .send_json ({"type" : "error" , "message" : str (e )})
559+ await websocket .send_json ({"type" : "error" , "message" : str (e )})
408560
409- finally :
410- with contextlib .suppress (Exception ):
411- await websocket .close ()
async def _transcribe_segment(
    manager: Any,
    segment: bytes,
    language: str | None,
) -> Any | None:
    """Transcribe a raw PCM audio segment by wrapping it in WAV format.

    Returns the transcription result, or None on any failure — errors are
    logged rather than raised so one bad segment does not kill the stream.
    """
    try:
        # The backend expects a complete WAV payload, so wrap the raw PCM
        # bytes in a WAV container first.
        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav:
            setup_wav_file(wav)
            wav.writeframes(segment)
        buffer.seek(0)
        return await manager.transcribe(
            buffer.read(),
            language=language,
            task="transcribe",
        )
    except Exception:
        logger.exception("Failed to transcribe segment")
        return None
412582
413583 return app
0 commit comments