diff --git a/moshi/moshi/models/lm.py b/moshi/moshi/models/lm.py
index a64d736b..c3e26584 100644
--- a/moshi/moshi/models/lm.py
+++ b/moshi/moshi/models/lm.py
@@ -657,6 +657,8 @@ def __init__(
         report_loss: bool = False,
         return_logits: bool = False,
         audio_silence_frame_cnt: int = 1,
+        on_silence_start: Optional[Callable] = None,
+        on_speech_start: Optional[Callable] = None,
         text_prompt_tokens: Optional[list[int]] = None,
         save_voice_prompt_embeddings: bool = False,
         sample_rate: int = 32000,
@@ -673,6 +675,8 @@
         self.top_k_text = top_k_text
         self.text_prompt_tokens = text_prompt_tokens
         self.audio_silence_frame_cnt = audio_silence_frame_cnt
+        self.on_silence_start = on_silence_start
+        self.on_speech_start = on_speech_start
         self.voice_prompt = None
         self.zero_text_code = 3
         self._frame_rate = frame_rate
@@ -1089,9 +1093,15 @@ def _step_audio_silence(self):
         pass
 
     async def _step_audio_silence_async(self, is_alive: Optional[Callable]=None):
+        # Notify silence started if callback provided
+        if self.on_silence_start is not None:
+            await self.on_silence_start()
         for _ in self._step_audio_silence_core():
             if is_alive is not None and not await is_alive():
                 break
+        # Notify silence ended (speech starting) if callback provided
+        if self.on_speech_start is not None:
+            await self.on_speech_start()
 
     def _step_text_prompt_core(self) -> Iterator[None]:
         for text_prompt_token in self.text_prompt_tokens: