From 3db876ca9940a67a101c7e3c0bf4a7e4883370d9 Mon Sep 17 00:00:00 2001 From: Hamdan Anwar Sayeed <96612374+s-hamdananwar@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:28:08 -0600 Subject: [PATCH 1/7] fix(pipeline_agent): clear user transcript when before_llm_cb returns false --- livekit-agents/livekit/agents/pipeline/pipeline_agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index e2ff635d1..2afcf25b9 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -753,6 +753,7 @@ async def _synthesize_answer_task( llm_stream = await llm_stream if llm_stream is False: + self._transcribed_text = "" handle.cancel() return From 6e13457fbb8881459abdee43bf89934654232e56 Mon Sep 17 00:00:00 2001 From: Hamdan <96612374+s-hamdananwar@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:29:47 -0600 Subject: [PATCH 2/7] Create slimy-carpets-dress.md --- .changeset/slimy-carpets-dress.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/slimy-carpets-dress.md diff --git a/.changeset/slimy-carpets-dress.md b/.changeset/slimy-carpets-dress.md new file mode 100644 index 000000000..76f92e304 --- /dev/null +++ b/.changeset/slimy-carpets-dress.md @@ -0,0 +1,5 @@ +--- +"livekit-agents": patch +--- + +fix(pipeline_agent): clear user transcript when before_llm_cb returns false From e2755d6e6ab959daa3e66b19fee5498a61df8c9b Mon Sep 17 00:00:00 2001 From: Hamdan <96612374+s-hamdananwar@users.noreply.github.com> Date: Thu, 30 Jan 2025 10:01:35 -0600 Subject: [PATCH 3/7] Update livekit-agents/livekit/agents/pipeline/pipeline_agent.py Co-authored-by: David Zhao --- livekit-agents/livekit/agents/pipeline/pipeline_agent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 2afcf25b9..dc285bc1f 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -753,7 +753,11 @@ async def _synthesize_answer_task( llm_stream = await llm_stream if llm_stream is False: - self._transcribed_text = "" + # user chose not to synthesize an answer, so we do not want to + # leave the same question in chat context. otherwise it would be + # unintentionally committed when the next set of speech comes in. + if len(self._transcribed_text) >= len(handle.user_question): + self._transcribed_text = self._transcribed_text[len(handle.user_question) :] handle.cancel() return From da8fbd4e6eaad7d275843f43eb916b2eee14e668 Mon Sep 17 00:00:00 2001 From: Hamdan Anwar Sayeed <96612374+s-hamdananwar@users.noreply.github.com> Date: Thu, 30 Jan 2025 11:41:04 -0600 Subject: [PATCH 4/7] ruff fix --- .../livekit/agents/ipc/job_thread_executor.py | 6 ++-- .../livekit/agents/ipc/proc_client.py | 6 ++-- .../livekit/agents/ipc/supervised_proc.py | 6 ++-- .../livekit/agents/pipeline/pipeline_agent.py | 16 +++++----- livekit-agents/livekit/agents/stt/stt.py | 6 ++-- .../livekit/plugins/llama_index/llm.py | 6 ++-- .../livekit/plugins/silero/vad.py | 4 +-- .../livekit/plugins/turn_detector/eou.py | 6 ++-- tests/test_create_func.py | 30 +++++++++---------- tests/test_ipc.py | 6 ++-- tests/test_llm.py | 18 +++++------ tests/test_message_change.py | 6 ++-- tests/test_stt.py | 6 ++-- tests/test_tts.py | 6 ++-- 14 files changed, 64 insertions(+), 64 deletions(-) diff --git a/livekit-agents/livekit/agents/ipc/job_thread_executor.py b/livekit-agents/livekit/agents/ipc/job_thread_executor.py index a847e0de6..6705422ab 100644 --- a/livekit-agents/livekit/agents/ipc/job_thread_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_thread_executor.py @@ -140,9 +140,9 @@ async def initialize(self) -> None: channel.arecv_message(self._pch, proto.IPC_MESSAGES), timeout=self._opts.initialize_timeout, ) - assert isinstance(init_res, proto.InitializeResponse), ( - "first message must be InitializeResponse" - ) + assert isinstance( + init_res, proto.InitializeResponse + ), "first message must be InitializeResponse" except asyncio.TimeoutError: self._initialize_fut.set_exception( asyncio.TimeoutError("runner initialization timed out") diff --git a/livekit-agents/livekit/agents/ipc/proc_client.py b/livekit-agents/livekit/agents/ipc/proc_client.py index 97080e8cc..76b77fb88 100644 --- a/livekit-agents/livekit/agents/ipc/proc_client.py +++ b/livekit-agents/livekit/agents/ipc/proc_client.py @@ -53,9 +53,9 @@ def initialize(self) -> None: cch = aio.duplex_unix._Duplex.open(self._mp_cch) first_req = recv_message(cch, IPC_MESSAGES) - assert isinstance(first_req, InitializeRequest), ( - "first message must be proto.InitializeRequest" - ) + assert isinstance( + first_req, InitializeRequest + ), "first message must be proto.InitializeRequest" self._init_req = first_req self._initialize_fnc(self._init_req, self) diff --git a/livekit-agents/livekit/agents/ipc/supervised_proc.py b/livekit-agents/livekit/agents/ipc/supervised_proc.py index dfd18172d..e56119876 100644 --- a/livekit-agents/livekit/agents/ipc/supervised_proc.py +++ b/livekit-agents/livekit/agents/ipc/supervised_proc.py @@ -165,9 +165,9 @@ async def initialize(self) -> None: channel.arecv_message(self._pch, proto.IPC_MESSAGES), timeout=self._opts.initialize_timeout, ) - assert isinstance(init_res, proto.InitializeResponse), ( - "first message must be InitializeResponse" - ) + assert isinstance( + init_res, proto.InitializeResponse + ), "first message must be InitializeResponse" except asyncio.TimeoutError: self._initialize_fut.set_exception( asyncio.TimeoutError("process initialization timed out") diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index dc285bc1f..3b4c7abd1 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -757,7 +757,9 @@ async def _synthesize_answer_task( # leave the same question in chat context. otherwise it would be # unintentionally committed when the next set of speech comes in. if len(self._transcribed_text) >= len(handle.user_question): - self._transcribed_text = self._transcribed_text[len(handle.user_question) :] + self._transcribed_text = self._transcribed_text[ + len(handle.user_question) : + ] handle.cancel() return @@ -903,9 +905,9 @@ async def _execute_function_calls() -> None: return assert isinstance(speech_handle.source, LLMStream) - assert not user_question or speech_handle.user_committed, ( - "user speech should have been committed before using tools" - ) + assert ( + not user_question or speech_handle.user_committed + ), "user speech should have been committed before using tools" llm_stream = speech_handle.source @@ -1055,9 +1057,9 @@ def _synthesize_agent_speech( speech_id: str, source: str | LLMStream | AsyncIterable[str], ) -> SynthesisHandle: - assert self._agent_output is not None, ( - "agent output should be initialized when ready" - ) + assert ( + self._agent_output is not None + ), "agent output should be initialized when ready" tk = SpeechDataContextVar.set(SpeechData(speech_id)) diff --git a/livekit-agents/livekit/agents/stt/stt.py b/livekit-agents/livekit/agents/stt/stt.py index 06fbd7204..a0956e621 100644 --- a/livekit-agents/livekit/agents/stt/stt.py +++ b/livekit-agents/livekit/agents/stt/stt.py @@ -248,9 +248,9 @@ async def _metrics_monitor_task( async for ev in event_aiter: if ev.type == SpeechEventType.RECOGNITION_USAGE: - assert ev.recognition_usage is not None, ( - "recognition_usage must be provided for RECOGNITION_USAGE event" - ) + assert ( + ev.recognition_usage is not None + ), "recognition_usage must be provided for RECOGNITION_USAGE event" duration = time.perf_counter() - start_time stt_metrics = STTMetrics( diff --git a/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py b/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py index 6a5ec46db..9f674717d 100644 --- a/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py +++ b/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py @@ -74,9 +74,9 @@ async def _run(self) -> None: "The last message in the chat context must be from the user" ) - assert isinstance(user_msg.content, str), ( - "user message content must be a string" - ) + assert isinstance( + user_msg.content, str + ), "user message content must be a string" try: if not self._stream: diff --git a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py index 1604a99a6..ee6d27599 100644 --- a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py +++ b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py @@ -86,9 +86,7 @@ async def entrypoint(ctx: JobContext): if __name__ == "__main__": - cli.run_app( - WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm) - ) + cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm)) ``` Args: diff --git a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py index 18944548b..b42075445 100644 --- a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py +++ b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py @@ -193,9 +193,9 @@ async def predict_end_of_turn( timeout=timeout, ) - assert result is not None, ( - "end_of_utterance prediction should always returns a result" - ) + assert ( + result is not None + ), "end_of_utterance prediction should always returns a result" result_json = json.loads(result.decode()) return result_json["eou_probability"] diff --git a/tests/test_create_func.py b/tests/test_create_func.py index 202a4e4f2..a81d31d93 100644 --- a/tests/test_create_func.py +++ b/tests/test_create_func.py @@ -16,9 +16,9 @@ def test_fn( pass fnc_ctx = TestFunctionContext() - assert "test_function" in fnc_ctx.ai_functions, ( - "Function should be registered in ai_functions" - ) + assert ( + "test_function" in fnc_ctx.ai_functions + ), "Function should be registered in ai_functions" fnc_info = fnc_ctx.ai_functions["test_function"] build_info = _oai_api.build_oai_function_description(fnc_info) @@ -69,9 +69,9 @@ def test_fn(self): pass fnc_ctx = TestFunctionContext() - assert "test_fn" in fnc_ctx.ai_functions, ( - "Function should be registered in ai_functions" - ) + assert ( + "test_fn" in fnc_ctx.ai_functions + ), "Function should be registered in ai_functions" assert fnc_ctx.ai_functions["test_fn"].description == "A simple test function" @@ -92,9 +92,9 @@ def optional_fn( pass fnc_ctx = TestFunctionContext() - assert "optional_function" in fnc_ctx.ai_functions, ( - "Function should be registered in ai_functions" - ) + assert ( + "optional_function" in fnc_ctx.ai_functions + ), "Function should be registered in ai_functions" fnc_info = fnc_ctx.ai_functions["optional_function"] build_info = _oai_api.build_oai_function_description(fnc_info) @@ -159,9 +159,9 @@ def list_fn( pass fnc_ctx = TestFunctionContext() - assert "list_function" in fnc_ctx.ai_functions, ( - "Function should be registered in ai_functions" - ) + assert ( + "list_function" in fnc_ctx.ai_functions + ), "Function should be registered in ai_functions" fnc_info = fnc_ctx.ai_functions["list_function"] build_info = _oai_api.build_oai_function_description(fnc_info) @@ -202,9 +202,9 @@ def enum_fn( pass fnc_ctx = TestFunctionContext() - assert "enum_function" in fnc_ctx.ai_functions, ( - "Function should be registered in ai_functions" - ) + assert ( + "enum_function" in fnc_ctx.ai_functions + ), "Function should be registered in ai_functions" fnc_info = fnc_ctx.ai_functions["enum_function"] build_info = _oai_api.build_oai_function_description(fnc_info) diff --git a/tests/test_ipc.py b/tests/test_ipc.py index 645808827..4e1fd4fe7 100644 --- a/tests/test_ipc.py +++ b/tests/test_ipc.py @@ -354,9 +354,9 @@ async def test_shutdown_no_job(): assert proc.exitcode == 0 assert not proc.killed - assert start_args.shutdown_counter.value == 0, ( - "shutdown_cb isn't called when there is no job" - ) + assert ( + start_args.shutdown_counter.value == 0 + ), "shutdown_cb isn't called when there is no job" async def test_job_slow_shutdown(): diff --git a/tests/test_llm.py b/tests/test_llm.py index 59d456ce0..da4c9520b 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -195,9 +195,9 @@ async def test_cancelled_calls(llm_factory: Callable[[], llm.LLM]): await stream.aclose() assert len(calls) == 1 - assert isinstance(calls[0].exception, asyncio.CancelledError), ( - "toggle_light should have been cancelled" - ) + assert isinstance( + calls[0].exception, asyncio.CancelledError + ), "toggle_light should have been cancelled" @pytest.mark.parametrize("llm_factory", LLMS) @@ -220,9 +220,9 @@ async def test_calls_arrays(llm_factory: Callable[[], llm.LLM]): call = calls[0] currencies = call.call_info.arguments["currencies"] assert len(currencies) == 3, "select_currencies should have 3 currencies" - assert "eur" in currencies and "gbp" in currencies and "sek" in currencies, ( - "select_currencies should have eur, gbp, sek" - ) + assert ( + "eur" in currencies and "gbp" in currencies and "sek" in currencies + ), "select_currencies should have eur, gbp, sek" @pytest.mark.parametrize("llm_factory", LLMS) @@ -342,9 +342,9 @@ async def test_tool_choice_options( if tool_choice == "none" and isinstance(input_llm, anthropic.LLM): assert True else: - assert call_names == expected_calls, ( - f"Test '{description}' failed: Expected calls {expected_calls}, but got {call_names}" - ) + assert ( + call_names == expected_calls + ), f"Test '{description}' failed: Expected calls {expected_calls}, but got {call_names}" async def _request_fnc_call( diff --git a/tests/test_message_change.py b/tests/test_message_change.py index 394dd1abc..a02f395a7 100644 --- a/tests/test_message_change.py +++ b/tests/test_message_change.py @@ -43,9 +43,9 @@ def test_find_longest_increasing_subsequence(indices, expected_seq, desc): assert result[0] == 0, f"First index not included in {desc}" # Verify sequence matches expected - assert result_seq == expected_seq, ( - f"Wrong sequence in {desc}: expected {expected_seq}, got {result_seq}" - ) + assert ( + result_seq == expected_seq + ), f"Wrong sequence in {desc}: expected {expected_seq}, got {result_seq}" @pytest.mark.parametrize( diff --git a/tests/test_stt.py b/tests/test_stt.py index 9a80aa4be..d1f340b1e 100644 --- a/tests/test_stt.py +++ b/tests/test_stt.py @@ -99,9 +99,9 @@ async def _stream_output(): async for event in stream: if event.type == agents.stt.SpeechEventType.START_OF_SPEECH: - assert recv_end, ( - "START_OF_SPEECH recv but no END_OF_SPEECH has been sent before" - ) + assert ( + recv_end + ), "START_OF_SPEECH recv but no END_OF_SPEECH has been sent before" assert not recv_start recv_end = False recv_start = True diff --git a/tests/test_tts.py b/tests/test_tts.py index 991c84377..95d0d445a 100644 --- a/tests/test_tts.py +++ b/tests/test_tts.py @@ -38,9 +38,9 @@ async def _assert_valid_synthesized_audio( merged_frame = merge_frames(frames) assert merged_frame.sample_rate == tts.sample_rate, "sample rate should be the same" - assert merged_frame.num_channels == tts.num_channels, ( - "num channels should be the same" - ) + assert ( + merged_frame.num_channels == tts.num_channels + ), "num channels should be the same" SYNTHESIZE_TTS: list[Callable[[], tts.TTS]] = [ From 3976d3160abfe3158d6d8382a9e9c71e069fe928 Mon Sep 17 00:00:00 2001 From: Hamdan Anwar Sayeed <96612374+s-hamdananwar@users.noreply.github.com> Date: Thu, 30 Jan 2025 11:47:14 -0600 Subject: [PATCH 5/7] Revert "ruff fix" This reverts commit da8fbd4e6eaad7d275843f43eb916b2eee14e668. --- .../livekit/agents/ipc/job_thread_executor.py | 6 ++-- .../livekit/agents/ipc/proc_client.py | 6 ++-- .../livekit/agents/ipc/supervised_proc.py | 6 ++-- .../livekit/agents/pipeline/pipeline_agent.py | 16 +++++----- livekit-agents/livekit/agents/stt/stt.py | 6 ++-- .../livekit/plugins/llama_index/llm.py | 6 ++-- .../livekit/plugins/silero/vad.py | 4 ++- .../livekit/plugins/turn_detector/eou.py | 6 ++-- tests/test_create_func.py | 30 +++++++++---------- tests/test_ipc.py | 6 ++-- tests/test_llm.py | 18 +++++------ tests/test_message_change.py | 6 ++-- tests/test_stt.py | 6 ++-- tests/test_tts.py | 6 ++-- 14 files changed, 64 insertions(+), 64 deletions(-) diff --git a/livekit-agents/livekit/agents/ipc/job_thread_executor.py b/livekit-agents/livekit/agents/ipc/job_thread_executor.py index 6705422ab..a847e0de6 100644 --- a/livekit-agents/livekit/agents/ipc/job_thread_executor.py +++ b/livekit-agents/livekit/agents/ipc/job_thread_executor.py @@ -140,9 +140,9 @@ async def initialize(self) -> None: channel.arecv_message(self._pch, proto.IPC_MESSAGES), timeout=self._opts.initialize_timeout, ) - assert isinstance( - init_res, proto.InitializeResponse - ), "first message must be InitializeResponse" + assert isinstance(init_res, proto.InitializeResponse), ( + "first message must be InitializeResponse" + ) except asyncio.TimeoutError: self._initialize_fut.set_exception( asyncio.TimeoutError("runner initialization timed out") diff --git a/livekit-agents/livekit/agents/ipc/proc_client.py b/livekit-agents/livekit/agents/ipc/proc_client.py index 76b77fb88..97080e8cc 100644 --- a/livekit-agents/livekit/agents/ipc/proc_client.py +++ b/livekit-agents/livekit/agents/ipc/proc_client.py @@ -53,9 +53,9 @@ def initialize(self) -> None: cch = aio.duplex_unix._Duplex.open(self._mp_cch) first_req = recv_message(cch, IPC_MESSAGES) - assert isinstance( - first_req, InitializeRequest - ), "first message must be proto.InitializeRequest" + assert isinstance(first_req, InitializeRequest), ( + "first message must be proto.InitializeRequest" + ) self._init_req = first_req self._initialize_fnc(self._init_req, self) diff --git a/livekit-agents/livekit/agents/ipc/supervised_proc.py b/livekit-agents/livekit/agents/ipc/supervised_proc.py index e56119876..dfd18172d 100644 --- a/livekit-agents/livekit/agents/ipc/supervised_proc.py +++ b/livekit-agents/livekit/agents/ipc/supervised_proc.py @@ -165,9 +165,9 @@ async def initialize(self) -> None: channel.arecv_message(self._pch, proto.IPC_MESSAGES), timeout=self._opts.initialize_timeout, ) - assert isinstance( - init_res, proto.InitializeResponse - ), "first message must be InitializeResponse" + assert isinstance(init_res, proto.InitializeResponse), ( + "first message must be InitializeResponse" + ) except asyncio.TimeoutError: self._initialize_fut.set_exception( asyncio.TimeoutError("process initialization timed out") diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 3b4c7abd1..dc285bc1f 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -757,9 +757,7 @@ async def _synthesize_answer_task( # leave the same question in chat context. otherwise it would be # unintentionally committed when the next set of speech comes in. if len(self._transcribed_text) >= len(handle.user_question): - self._transcribed_text = self._transcribed_text[ - len(handle.user_question) : - ] + self._transcribed_text = self._transcribed_text[len(handle.user_question) :] handle.cancel() return @@ -905,9 +903,9 @@ async def _execute_function_calls() -> None: return assert isinstance(speech_handle.source, LLMStream) - assert ( - not user_question or speech_handle.user_committed - ), "user speech should have been committed before using tools" + assert not user_question or speech_handle.user_committed, ( + "user speech should have been committed before using tools" + ) llm_stream = speech_handle.source @@ -1057,9 +1055,9 @@ def _synthesize_agent_speech( speech_id: str, source: str | LLMStream | AsyncIterable[str], ) -> SynthesisHandle: - assert ( - self._agent_output is not None - ), "agent output should be initialized when ready" + assert self._agent_output is not None, ( + "agent output should be initialized when ready" + ) tk = SpeechDataContextVar.set(SpeechData(speech_id)) diff --git a/livekit-agents/livekit/agents/stt/stt.py b/livekit-agents/livekit/agents/stt/stt.py index a0956e621..06fbd7204 100644 --- a/livekit-agents/livekit/agents/stt/stt.py +++ b/livekit-agents/livekit/agents/stt/stt.py @@ -248,9 +248,9 @@ async def _metrics_monitor_task( async for ev in event_aiter: if ev.type == SpeechEventType.RECOGNITION_USAGE: - assert ( - ev.recognition_usage is not None - ), "recognition_usage must be provided for RECOGNITION_USAGE event" + assert ev.recognition_usage is not None, ( + "recognition_usage must be provided for RECOGNITION_USAGE event" + ) duration = time.perf_counter() - start_time stt_metrics = STTMetrics( diff --git a/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py b/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py index 9f674717d..6a5ec46db 100644 --- a/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py +++ b/livekit-plugins/livekit-plugins-llama-index/livekit/plugins/llama_index/llm.py @@ -74,9 +74,9 @@ async def _run(self) -> None: "The last message in the chat context must be from the user" ) - assert isinstance( - user_msg.content, str - ), "user message content must be a string" + assert isinstance(user_msg.content, str), ( + "user message content must be a string" + ) try: if not self._stream: diff --git a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py index ee6d27599..1604a99a6 100644 --- a/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py +++ b/livekit-plugins/livekit-plugins-silero/livekit/plugins/silero/vad.py @@ -86,7 +86,9 @@ async def entrypoint(ctx: JobContext): if __name__ == "__main__": - cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm)) + cli.run_app( + WorkerOptions(entrypoint_fnc=entrypoint, prewarm_fnc=prewarm) + ) ``` Args: diff --git a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py index b42075445..18944548b 100644 --- a/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py +++ b/livekit-plugins/livekit-plugins-turn-detector/livekit/plugins/turn_detector/eou.py @@ -193,9 +193,9 @@ async def predict_end_of_turn( timeout=timeout, ) - assert ( - result is not None - ), "end_of_utterance prediction should always returns a result" + assert result is not None, ( + "end_of_utterance prediction should always returns a result" + ) result_json = json.loads(result.decode()) return result_json["eou_probability"] diff --git a/tests/test_create_func.py b/tests/test_create_func.py index a81d31d93..202a4e4f2 100644 --- a/tests/test_create_func.py +++ b/tests/test_create_func.py @@ -16,9 +16,9 @@ def test_fn( pass fnc_ctx = TestFunctionContext() - assert ( - "test_function" in fnc_ctx.ai_functions - ), "Function should be registered in ai_functions" + assert "test_function" in fnc_ctx.ai_functions, ( + "Function should be registered in ai_functions" + ) fnc_info = fnc_ctx.ai_functions["test_function"] build_info = _oai_api.build_oai_function_description(fnc_info) @@ -69,9 +69,9 @@ def test_fn(self): pass fnc_ctx = TestFunctionContext() - assert ( - "test_fn" in fnc_ctx.ai_functions - ), "Function should be registered in ai_functions" + assert "test_fn" in fnc_ctx.ai_functions, ( + "Function should be registered in ai_functions" + ) assert fnc_ctx.ai_functions["test_fn"].description == "A simple test function" @@ -92,9 +92,9 @@ def optional_fn( pass fnc_ctx = TestFunctionContext() - assert ( - "optional_function" in fnc_ctx.ai_functions - ), "Function should be registered in ai_functions" + assert "optional_function" in fnc_ctx.ai_functions, ( + "Function should be registered in ai_functions" + ) fnc_info = fnc_ctx.ai_functions["optional_function"] build_info = _oai_api.build_oai_function_description(fnc_info) @@ -159,9 +159,9 @@ def list_fn( pass fnc_ctx = TestFunctionContext() - assert ( - "list_function" in fnc_ctx.ai_functions - ), "Function should be registered in ai_functions" + assert "list_function" in fnc_ctx.ai_functions, ( + "Function should be registered in ai_functions" + ) fnc_info = fnc_ctx.ai_functions["list_function"] build_info = _oai_api.build_oai_function_description(fnc_info) @@ -202,9 +202,9 @@ def enum_fn( pass fnc_ctx = TestFunctionContext() - assert ( - "enum_function" in fnc_ctx.ai_functions - ), "Function should be registered in ai_functions" + assert "enum_function" in fnc_ctx.ai_functions, ( + "Function should be registered in ai_functions" + ) fnc_info = fnc_ctx.ai_functions["enum_function"] build_info = _oai_api.build_oai_function_description(fnc_info) diff --git a/tests/test_ipc.py b/tests/test_ipc.py index 4e1fd4fe7..645808827 100644 --- a/tests/test_ipc.py +++ b/tests/test_ipc.py @@ -354,9 +354,9 @@ async def test_shutdown_no_job(): assert proc.exitcode == 0 assert not proc.killed - assert ( - start_args.shutdown_counter.value == 0 - ), "shutdown_cb isn't called when there is no job" + assert start_args.shutdown_counter.value == 0, ( + "shutdown_cb isn't called when there is no job" + ) async def test_job_slow_shutdown(): diff --git a/tests/test_llm.py b/tests/test_llm.py index da4c9520b..59d456ce0 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -195,9 +195,9 @@ async def test_cancelled_calls(llm_factory: Callable[[], llm.LLM]): await stream.aclose() assert len(calls) == 1 - assert isinstance( - calls[0].exception, asyncio.CancelledError - ), "toggle_light should have been cancelled" + assert isinstance(calls[0].exception, asyncio.CancelledError), ( + "toggle_light should have been cancelled" + ) @pytest.mark.parametrize("llm_factory", LLMS) @@ -220,9 +220,9 @@ async def test_calls_arrays(llm_factory: Callable[[], llm.LLM]): call = calls[0] currencies = call.call_info.arguments["currencies"] assert len(currencies) == 3, "select_currencies should have 3 currencies" - assert ( - "eur" in currencies and "gbp" in currencies and "sek" in currencies - ), "select_currencies should have eur, gbp, sek" + assert "eur" in currencies and "gbp" in currencies and "sek" in currencies, ( + "select_currencies should have eur, gbp, sek" + ) @pytest.mark.parametrize("llm_factory", LLMS) @@ -342,9 +342,9 @@ async def test_tool_choice_options( if tool_choice == "none" and isinstance(input_llm, anthropic.LLM): assert True else: - assert ( - call_names == expected_calls - ), f"Test '{description}' failed: Expected calls {expected_calls}, but got {call_names}" + assert call_names == expected_calls, ( + f"Test '{description}' failed: Expected calls {expected_calls}, but got {call_names}" + ) async def _request_fnc_call( diff --git a/tests/test_message_change.py b/tests/test_message_change.py index a02f395a7..394dd1abc 100644 --- a/tests/test_message_change.py +++ b/tests/test_message_change.py @@ -43,9 +43,9 @@ def test_find_longest_increasing_subsequence(indices, expected_seq, desc): assert result[0] == 0, f"First index not included in {desc}" # Verify sequence matches expected - assert ( - result_seq == expected_seq - ), f"Wrong sequence in {desc}: expected {expected_seq}, got {result_seq}" + assert result_seq == expected_seq, ( + f"Wrong sequence in {desc}: expected {expected_seq}, got {result_seq}" + ) @pytest.mark.parametrize( diff --git a/tests/test_stt.py b/tests/test_stt.py index d1f340b1e..9a80aa4be 100644 --- a/tests/test_stt.py +++ b/tests/test_stt.py @@ -99,9 +99,9 @@ async def _stream_output(): async for event in stream: if event.type == agents.stt.SpeechEventType.START_OF_SPEECH: - assert ( - recv_end - ), "START_OF_SPEECH recv but no END_OF_SPEECH has been sent before" + assert recv_end, ( + "START_OF_SPEECH recv but no END_OF_SPEECH has been sent before" + ) assert not recv_start recv_end = False recv_start = True diff --git a/tests/test_tts.py b/tests/test_tts.py index 95d0d445a..991c84377 100644 --- a/tests/test_tts.py +++ b/tests/test_tts.py @@ -38,9 +38,9 @@ async def _assert_valid_synthesized_audio( merged_frame = merge_frames(frames) assert merged_frame.sample_rate == tts.sample_rate, "sample rate should be the same" - assert ( - merged_frame.num_channels == tts.num_channels - ), "num channels should be the same" + assert merged_frame.num_channels == tts.num_channels, ( + "num channels should be the same" + ) SYNTHESIZE_TTS: list[Callable[[], tts.TTS]] = [ From 4385f3937142d4babae514124b1b5bb745baedc2 Mon Sep 17 00:00:00 2001 From: Hamdan Anwar Sayeed <96612374+s-hamdananwar@users.noreply.github.com> Date: Thu, 30 Jan 2025 11:49:53 -0600 Subject: [PATCH 6/7] ruff fix --- .../livekit/agents/pipeline/pipeline_agent.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index dc285bc1f..3b4c7abd1 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -757,7 +757,9 @@ async def _synthesize_answer_task( # leave the same question in chat context. otherwise it would be # unintentionally committed when the next set of speech comes in. if len(self._transcribed_text) >= len(handle.user_question): - self._transcribed_text = self._transcribed_text[len(handle.user_question) :] + self._transcribed_text = self._transcribed_text[ + len(handle.user_question) : + ] handle.cancel() return @@ -903,9 +905,9 @@ async def _execute_function_calls() -> None: return assert isinstance(speech_handle.source, LLMStream) - assert not user_question or speech_handle.user_committed, ( - "user speech should have been committed before using tools" - ) + assert ( + not user_question or speech_handle.user_committed + ), "user speech should have been committed before using tools" llm_stream = speech_handle.source @@ -1055,9 +1057,9 @@ def _synthesize_agent_speech( speech_id: str, source: str | LLMStream | AsyncIterable[str], ) -> SynthesisHandle: - assert self._agent_output is not None, ( - "agent output should be initialized when ready" - ) + assert ( + self._agent_output is not None + ), "agent output should be initialized when ready" tk = SpeechDataContextVar.set(SpeechData(speech_id)) From 3e4de3afb9d375c0bac518213252f663117e4164 Mon Sep 17 00:00:00 2001 From: Hamdan Anwar Sayeed <96612374+s-hamdananwar@users.noreply.github.com> Date: Thu, 30 Jan 2025 12:13:07 -0600 Subject: [PATCH 7/7] updated ruff fix --- .../livekit/agents/pipeline/pipeline_agent.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py index 3b4c7abd1..da1beb908 100644 --- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py +++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py @@ -905,9 +905,9 @@ async def _execute_function_calls() -> None: return assert isinstance(speech_handle.source, LLMStream) - assert ( - not user_question or speech_handle.user_committed - ), "user speech should have been committed before using tools" + assert not user_question or speech_handle.user_committed, ( + "user speech should have been committed before using tools" + ) llm_stream = speech_handle.source @@ -1057,9 +1057,9 @@ def _synthesize_agent_speech( speech_id: str, source: str | LLMStream | AsyncIterable[str], ) -> SynthesisHandle: - assert ( - self._agent_output is not None - ), "agent output should be initialized when ready" + assert self._agent_output is not None, ( + "agent output should be initialized when ready" + ) tk = SpeechDataContextVar.set(SpeechData(speech_id))