diff --git a/changelog/3175.added.md b/changelog/3175.added.md
new file mode 100644
index 0000000000..16c9460681
--- /dev/null
+++ b/changelog/3175.added.md
@@ -0,0 +1,41 @@
+- Added additional functionality related to "thinking" for Google and Anthropic
+  LLMs.
+
+  1. New typed parameters for Google and Anthropic LLMs that control the
+     models' thinking behavior (like how much thinking to do, and whether to
+     output thoughts or thought summaries):
+     - `AnthropicLLMService.ThinkingConfig`
+     - `GoogleLLMService.ThinkingConfig`
+  2. New frames for representing thoughts output by LLMs:
+     - `LLMThoughtStartFrame`
+     - `LLMThoughtTextFrame`
+     - `LLMThoughtEndFrame`
+  3. A mechanism for appending arbitrary context messages after a function call
+     message, used specifically to support Google's function-call-related
+     "thought signatures", which are necessary to ensure thinking continuity
+     between function calls in a chain (where the model thinks, makes a function
+     call, thinks some more, etc.). See:
+     - `append_extra_context_messages` field in `FunctionCallInProgressFrame` and
+       helper types
+     - `GoogleLLMService` leveraging the new mechanism to add a Google-specific
+       `"fn_thought_signature"` message
+     - `LLMAssistantAggregator` handling of `append_extra_context_messages`
+     - `GeminiLLMAdapter` handling of `"fn_thought_signature"` messages
+  4. A generic mechanism for recording LLM thoughts to context, used
+     specifically to support Anthropic, whose thought signatures are expected to
+     appear alongside the text of the thoughts within assistant context
+     messages. See:
+     - `LLMThoughtEndFrame.signature`
+     - `LLMAssistantAggregator` handling of the above field
+     - `AnthropicLLMAdapter` handling of `"thought"` context messages
+  5. Google-specific logic for inserting non-function-call-related thought
+     signatures into the context, to help maintain thinking continuity in a
+     chain of LLM calls. See:
+     - `GoogleLLMService` sending `LLMMessagesAppendFrame`s to add LLM-specific
+       `"non_fn_thought_signature"` messages to context
+     - `GeminiLLMAdapter` handling of `"non_fn_thought_signature"` messages
+  6. An expansion of `TranscriptProcessor` to process LLM thoughts in addition
+     to user and assistant utterances.
See: + - `TranscriptProcessor(process_thoughts=True)` (defaults to `False`) + - `ThoughtTranscriptionMessage`, which is now also emitted with the + `"on_transcript_update"` event diff --git a/examples/foundational/07n-interruptible-google-http.py b/examples/foundational/07n-interruptible-google-http.py index 2ef65e4741..4a03829904 100644 --- a/examples/foundational/07n-interruptible-google-http.py +++ b/examples/foundational/07n-interruptible-google-http.py @@ -75,8 +75,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): llm = GoogleLLMService( api_key=os.getenv("GOOGLE_API_KEY"), model="gemini-2.5-flash", - # turn on thinking if you want it - # params=GoogleLLMService.InputParams(extra={"thinking_config": {"thinking_budget": 4096}}),) + # force a certain amount of thinking if you want it + # params=GoogleLLMService.InputParams( + # thinking=GoogleLLMService.ThinkingConfig(thinking_budget=4096) + # ), ) messages = [ diff --git a/examples/foundational/07n-interruptible-google.py b/examples/foundational/07n-interruptible-google.py index 73dd49e781..28b61c151d 100644 --- a/examples/foundational/07n-interruptible-google.py +++ b/examples/foundational/07n-interruptible-google.py @@ -75,8 +75,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): llm = GoogleLLMService( api_key=os.getenv("GOOGLE_API_KEY"), model="gemini-2.5-flash", - # turn on thinking if you want it - # params=GoogleLLMService.InputParams(extra={"thinking_config": {"thinking_budget": 4096}}),) + # force a certain amount of thinking if you want it + # params=GoogleLLMService.InputParams( + # thinking=GoogleLLMService.ThinkingConfig(thinking_budget=4096) + # ), ) messages = [ diff --git a/examples/foundational/07s-interruptible-google-audio-in.py b/examples/foundational/07s-interruptible-google-audio-in.py index 67772e40d2..90bff60621 100644 --- a/examples/foundational/07s-interruptible-google-audio-in.py +++ b/examples/foundational/07s-interruptible-google-audio-in.py @@ -224,8 +224,10 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): llm = GoogleLLMService( api_key=os.getenv("GOOGLE_API_KEY"), model="gemini-2.5-flash", - # turn on thinking if you want it - # params=GoogleLLMService.InputParams(extra={"thinking_config": {"thinking_budget": 4096}}), + # force a certain amount of thinking if you want it + # params=GoogleLLMService.InputParams( + # thinking=GoogleLLMService.ThinkingConfig(thinking_budget=4096) + # ), ) tts = GoogleTTSService( diff --git a/examples/foundational/49a-thinking-anthropic.py b/examples/foundational/49a-thinking-anthropic.py new file mode 100644 index 0000000000..4066a15c04 --- /dev/null +++ b/examples/foundational/49a-thinking-anthropic.py @@ -0,0 +1,161 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import os + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams +from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3 +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.audio.vad.vad_analyzer import VADParams +from pipecat.frames.frames import LLMRunFrame, ThoughtTranscriptionMessage, TranscriptionMessage +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from 
pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair +from pipecat.processors.transcript_processor import TranscriptProcessor +from pipecat.runner.types import RunnerArguments +from pipecat.runner.utils import create_transport +from pipecat.services.anthropic.llm import AnthropicLLMService +from pipecat.services.cartesia.tts import CartesiaTTSService +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.daily.transport import DailyParams +from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams + +load_dotenv(override=True) + +# We store functions so objects (e.g. SileroVADAnalyzer) don't get +# instantiated. The function will be called when the desired transport gets +# selected. +transport_params = { + "daily": lambda: DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), + "twilio": lambda: FastAPIWebsocketParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), + "webrtc": lambda: TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), +} + + +async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): + logger.info(f"Starting bot") + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady + ) + + llm = AnthropicLLMService( + api_key=os.getenv("ANTHROPIC_API_KEY"), + params=AnthropicLLMService.InputParams( + thinking=AnthropicLLMService.ThinkingConfig(type="enabled", budget_tokens=2048) + ), + ) + + transcript = TranscriptProcessor(process_thoughts=True) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.", + }, + ] + + context = LLMContext(messages) + context_aggregator = LLMContextAggregatorPair(context) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, + transcript.user(), # User transcripts + context_aggregator.user(), # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + transcript.assistant(), # Assistant transcripts (including thoughts) + context_aggregator.assistant(), # Assistant spoken responses + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + messages.append( + { + "role": "user", + "content": "Say hello briefly.", + } + ) + # Here are some example prompts conducive to demonstrating + # thinking (picked from Google and Anthropic docs). 
+ # messages.append( + # { + # "role": "user", + # "content": "Analogize photosynthesis and growing up. Keep your answer concise.", + # # "content": "Compare and contrast electric cars and hybrid cars." + # # "content": "Are there an infinite number of prime numbers such that n mod 4 == 3?" + # } + # ) + await task.queue_frames([LLMRunFrame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + await task.cancel() + + # Register event handler for transcript updates + @transcript.event_handler("on_transcript_update") + async def on_transcript_update(processor, frame): + for msg in frame.messages: + if isinstance(msg, (ThoughtTranscriptionMessage, TranscriptionMessage)): + timestamp = f"[{msg.timestamp}] " if msg.timestamp else "" + role = "THOUGHT" if isinstance(msg, ThoughtTranscriptionMessage) else msg.role + logger.info(f"Transcript: {timestamp}{role}: {msg.content}") + + runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) + + await runner.run(task) + + +async def bot(runner_args: RunnerArguments): + """Main bot entry point compatible with Pipecat Cloud.""" + transport = await create_transport(runner_args, transport_params) + await run_bot(transport, runner_args) + + +if __name__ == "__main__": + from pipecat.runner.run import main + + main() diff --git a/examples/foundational/49b-thinking-google.py b/examples/foundational/49b-thinking-google.py new file mode 100644 index 0000000000..947ab39c9f --- /dev/null +++ b/examples/foundational/49b-thinking-google.py @@ -0,0 +1,166 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import os + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams +from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3 +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.audio.vad.vad_analyzer import VADParams +from pipecat.frames.frames import LLMRunFrame, ThoughtTranscriptionMessage, TranscriptionMessage +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair +from pipecat.processors.transcript_processor import TranscriptProcessor +from pipecat.runner.types import RunnerArguments +from pipecat.runner.utils import create_transport +from pipecat.services.cartesia.tts import CartesiaTTSService +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.google.llm import GoogleLLMService +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.daily.transport import DailyParams +from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams + +load_dotenv(override=True) + +# We store functions so objects (e.g. SileroVADAnalyzer) don't get +# instantiated. The function will be called when the desired transport gets +# selected. 
+transport_params = { + "daily": lambda: DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), + "twilio": lambda: FastAPIWebsocketParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), + "webrtc": lambda: TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), +} + + +async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): + logger.info(f"Starting bot") + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady + ) + + llm = GoogleLLMService( + api_key=os.getenv("GOOGLE_API_KEY"), + # model="gemini-3-pro-preview", # A more powerful reasoning model, but slower + params=GoogleLLMService.InputParams( + thinking=GoogleLLMService.ThinkingConfig( + # thinking_level="low", # Use this field instead of thinking_budget for Gemini 3 Pro. Defaults to "high". + thinking_budget=-1, # Dynamic thinking + include_thoughts=True, + ) + ), + ) + + transcript = TranscriptProcessor(process_thoughts=True) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.", + }, + ] + + context = LLMContext(messages) + context_aggregator = LLMContextAggregatorPair(context) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, + transcript.user(), # User transcripts + context_aggregator.user(), # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + transcript.assistant(), # Assistant transcripts (including thoughts) + context_aggregator.assistant(), # Assistant spoken responses + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + messages.append( + { + "role": "user", + "content": "Say hello briefly.", + } + ) + # Here are some example prompts conducive to demonstrating + # thinking (picked from Google and Anthropic docs). + # messages.append( + # { + # "role": "user", + # "content": "Analogize photosynthesis and growing up. Keep your answer concise.", + # # "content": "Compare and contrast electric cars and hybrid cars." + # # "content": "Are there an infinite number of prime numbers such that n mod 4 == 3?" 
+ # } + # ) + await task.queue_frames([LLMRunFrame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + await task.cancel() + + # Register event handler for transcript updates + @transcript.event_handler("on_transcript_update") + async def on_transcript_update(processor, frame): + for msg in frame.messages: + if isinstance(msg, (ThoughtTranscriptionMessage, TranscriptionMessage)): + timestamp = f"[{msg.timestamp}] " if msg.timestamp else "" + role = "THOUGHT" if isinstance(msg, ThoughtTranscriptionMessage) else msg.role + logger.info(f"Transcript: {timestamp}{role}: {msg.content}") + + runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) + + await runner.run(task) + + +async def bot(runner_args: RunnerArguments): + """Main bot entry point compatible with Pipecat Cloud.""" + transport = await create_transport(runner_args, transport_params) + await run_bot(transport, runner_args) + + +if __name__ == "__main__": + from pipecat.runner.run import main + + main() diff --git a/examples/foundational/49c-thinking-functions-anthropic.py b/examples/foundational/49c-thinking-functions-anthropic.py new file mode 100644 index 0000000000..e821b9d095 --- /dev/null +++ b/examples/foundational/49c-thinking-functions-anthropic.py @@ -0,0 +1,185 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import os + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams +from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3 +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.audio.vad.vad_analyzer import VADParams +from pipecat.frames.frames import LLMRunFrame, ThoughtTranscriptionMessage, TranscriptionMessage +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair +from pipecat.processors.transcript_processor import TranscriptProcessor +from pipecat.runner.types import RunnerArguments +from pipecat.runner.utils import create_transport +from pipecat.services.anthropic.llm import AnthropicLLMService +from pipecat.services.cartesia.tts import CartesiaTTSService +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.llm_service import FunctionCallParams +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.daily.transport import DailyParams +from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams + +load_dotenv(override=True) + + +async def check_flight_status(params: FunctionCallParams, flight_number: str): + """Check the status of a flight. Returns status (e.g., "on time", "delayed") and departure time. + + Args: + flight_number (str): The flight number, e.g. "AA100". + """ + await params.result_callback({"status": "delayed", "departure_time": "14:30"}) + + +async def book_taxi(params: FunctionCallParams, time: str): + """Book a taxi for a given time. Returns status (e.g., "done"). + + Args: + time (str): The time to book the taxi for, e.g. "15:00". 
+ """ + await params.result_callback({"status": "done"}) + + +# We store functions so objects (e.g. SileroVADAnalyzer) don't get +# instantiated. The function will be called when the desired transport gets +# selected. +transport_params = { + "daily": lambda: DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), + "twilio": lambda: FastAPIWebsocketParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), + "webrtc": lambda: TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), +} + + +async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): + logger.info(f"Starting bot") + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady + ) + + llm = AnthropicLLMService( + api_key=os.getenv("ANTHROPIC_API_KEY"), + params=AnthropicLLMService.InputParams( + thinking=AnthropicLLMService.ThinkingConfig(type="enabled", budget_tokens=2048) + ), + ) + + llm.register_direct_function(check_flight_status) + llm.register_direct_function(book_taxi) + + tools = ToolsSchema(standard_tools=[check_flight_status, book_taxi]) + + transcript = TranscriptProcessor(process_thoughts=True) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.", + }, + ] + + context = LLMContext(messages, tools) + context_aggregator = LLMContextAggregatorPair(context) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, + transcript.user(), # User transcripts + context_aggregator.user(), # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + transcript.assistant(), # Assistant transcripts (including thoughts) + context_aggregator.assistant(), # Assistant spoken responses + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + messages.append( + { + "role": "user", + "content": "Say hello briefly.", + } + ) + # Here is an example prompt conducive to demonstrating thinking and + # function calling. + # This example comes from Gemini docs. 
+ # messages.append( + # { + # "role": "user", + # "content": "Check the status of flight AA100 and, if it's delayed, book me a taxi 2 hours before its departure time.", + # } + # ) + await task.queue_frames([LLMRunFrame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + await task.cancel() + + @transcript.event_handler("on_transcript_update") + async def on_transcript_update(processor, frame): + for msg in frame.messages: + if isinstance(msg, (ThoughtTranscriptionMessage, TranscriptionMessage)): + timestamp = f"[{msg.timestamp}] " if msg.timestamp else "" + role = "THOUGHT" if isinstance(msg, ThoughtTranscriptionMessage) else msg.role + logger.info(f"Transcript: {timestamp}{role}: {msg.content}") + + runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) + + await runner.run(task) + + +async def bot(runner_args: RunnerArguments): + """Main bot entry point compatible with Pipecat Cloud.""" + transport = await create_transport(runner_args, transport_params) + await run_bot(transport, runner_args) + + +if __name__ == "__main__": + from pipecat.runner.run import main + + main() diff --git a/examples/foundational/49d-thinking-functions-google.py b/examples/foundational/49d-thinking-functions-google.py new file mode 100644 index 0000000000..cdf4621b19 --- /dev/null +++ b/examples/foundational/49d-thinking-functions-google.py @@ -0,0 +1,190 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import os + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.adapters.schemas.tools_schema import ToolsSchema +from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams +from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3 +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.audio.vad.vad_analyzer import VADParams +from pipecat.frames.frames import LLMRunFrame, ThoughtTranscriptionMessage, TranscriptionMessage +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair +from pipecat.processors.transcript_processor import TranscriptProcessor +from pipecat.runner.types import RunnerArguments +from pipecat.runner.utils import create_transport +from pipecat.services.cartesia.tts import CartesiaTTSService +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.google.llm import GoogleLLMService +from pipecat.services.llm_service import FunctionCallParams +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.daily.transport import DailyParams +from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams + +load_dotenv(override=True) + + +async def check_flight_status(params: FunctionCallParams, flight_number: str): + """Check the status of a flight. Returns status (e.g., "on time", "delayed") and departure time. + + Args: + flight_number (str): The flight number, e.g. "AA100". + """ + await params.result_callback({"status": "delayed", "departure_time": "14:30"}) + + +async def book_taxi(params: FunctionCallParams, time: str): + """Book a taxi for a given time. Returns status (e.g., "done"). 
+ + Args: + time (str): The time to book the taxi for, e.g. "15:00". + """ + await params.result_callback({"status": "done"}) + + +# We store functions so objects (e.g. SileroVADAnalyzer) don't get +# instantiated. The function will be called when the desired transport gets +# selected. +transport_params = { + "daily": lambda: DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), + "twilio": lambda: FastAPIWebsocketParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), + "webrtc": lambda: TransportParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)), + turn_analyzer=LocalSmartTurnAnalyzerV3(params=SmartTurnParams()), + ), +} + + +async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): + logger.info(f"Starting bot") + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady + ) + + llm = GoogleLLMService( + api_key=os.getenv("GOOGLE_API_KEY"), + # model="gemini-3-pro-preview", # A more powerful reasoning model, but slower + params=GoogleLLMService.InputParams( + thinking=GoogleLLMService.ThinkingConfig( + # thinking_level="low", # Use this field instead of thinking_budget for Gemini 3 Pro. Defaults to "high". + thinking_budget=-1, # Dynamic thinking + include_thoughts=True, + ) + ), + ) + + llm.register_direct_function(check_flight_status) + llm.register_direct_function(book_taxi) + + tools = ToolsSchema(standard_tools=[check_flight_status, book_taxi]) + + transcript = TranscriptProcessor(process_thoughts=True) + + messages = [ + { + "role": "system", + "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.", + }, + ] + + context = LLMContext(messages, tools) + context_aggregator = LLMContextAggregatorPair(context) + + pipeline = Pipeline( + [ + transport.input(), # Transport user input + stt, + transcript.user(), # User transcripts + context_aggregator.user(), # User responses + llm, # LLM + tts, # TTS + transport.output(), # Transport bot output + transcript.assistant(), # Assistant transcripts (including thoughts) + context_aggregator.assistant(), # Assistant spoken responses + ] + ) + + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, + ) + + @transport.event_handler("on_client_connected") + async def on_client_connected(transport, client): + logger.info(f"Client connected") + # Kick off the conversation. + messages.append( + { + "role": "user", + "content": "Say hello briefly.", + } + ) + # Here is an example prompt conducive to demonstrating thinking and + # function calling. + # This example comes from Gemini docs. 
+ # messages.append( + # { + # "role": "user", + # "content": "Check the status of flight AA100 and, if it's delayed, book me a taxi 2 hours before its departure time.", + # } + # ) + await task.queue_frames([LLMRunFrame()]) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + await task.cancel() + + @transcript.event_handler("on_transcript_update") + async def on_transcript_update(processor, frame): + for msg in frame.messages: + if isinstance(msg, (ThoughtTranscriptionMessage, TranscriptionMessage)): + timestamp = f"[{msg.timestamp}] " if msg.timestamp else "" + role = "THOUGHT" if isinstance(msg, ThoughtTranscriptionMessage) else msg.role + logger.info(f"Transcript: {timestamp}{role}: {msg.content}") + + runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) + + await runner.run(task) + + +async def bot(runner_args: RunnerArguments): + """Main bot entry point compatible with Pipecat Cloud.""" + transport = await create_transport(runner_args, transport_params) + await run_bot(transport, runner_args) + + +if __name__ == "__main__": + from pipecat.runner.run import main + + main() diff --git a/pyproject.toml b/pyproject.toml index bb72dba61d..5cff2fa142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ fal = [ "fal-client~=0.5.9" ] fireworks = [] fish = [ "ormsgpack~=1.7.0", "pipecat-ai[websockets-base]" ] gladia = [ "pipecat-ai[websockets-base]" ] -google = [ "google-cloud-speech>=2.33.0,<3", "google-cloud-texttospeech>=2.31.0,<3", "google-genai>=1.41.0,<2", "pipecat-ai[websockets-base]" ] +google = [ "google-cloud-speech>=2.33.0,<3", "google-cloud-texttospeech>=2.31.0,<3", "google-genai>=1.51.0,<2", "pipecat-ai[websockets-base]" ] gradium = [ "pipecat-ai[websockets-base]" ] grok = [] groq = [ "groq~=0.23.0" ] diff --git a/scripts/evals/run-release-evals.py b/scripts/evals/run-release-evals.py index f45128133e..863514c643 100644 --- a/scripts/evals/run-release-evals.py +++ b/scripts/evals/run-release-evals.py @@ -74,6 +74,11 @@ def EVAL_VISION_IMAGE(*, eval_speaks_first: bool = False): eval_speaks_first=True, ) +EVAL_FLIGHT_STATUS = EvalConfig( + prompt="Check the status of flight AA100.", + eval="The user says something about the status of flight AA100, such as whether it's on time or delayed.", +) + TESTS_07 = [ # 07 series @@ -204,6 +209,13 @@ def EVAL_VISION_IMAGE(*, eval_speaks_first: bool = False): ("44-voicemail-detection.py", EVAL_CONVERSATION), ] +TESTS_49 = [ + ("49a-thinking-anthropic.py", EVAL_SIMPLE_MATH), + ("49b-thinking-google.py", EVAL_SIMPLE_MATH), + ("49c-thinking-functions-anthropic.py", EVAL_FLIGHT_STATUS), + ("49d-thinking-functions-google.py", EVAL_FLIGHT_STATUS), +] + TESTS = [ *TESTS_07, *TESTS_12, @@ -216,6 +228,7 @@ def EVAL_VISION_IMAGE(*, eval_speaks_first: bool = False): *TESTS_40, *TESTS_43, *TESTS_44, + *TESTS_49, ] diff --git a/src/pipecat/adapters/services/anthropic_adapter.py b/src/pipecat/adapters/services/anthropic_adapter.py index 75fa5899d1..4cecaf8108 100644 --- a/src/pipecat/adapters/services/anthropic_adapter.py +++ b/src/pipecat/adapters/services/anthropic_adapter.py @@ -165,9 +165,44 @@ def _from_universal_context_messages( def _from_universal_context_message(self, message: LLMContextMessage) -> MessageParam: if isinstance(message, LLMSpecificMessage): - return copy.deepcopy(message.message) + return self._from_anthropic_specific_message(message) return self._from_standard_message(message) + def 
_from_anthropic_specific_message(self, message: LLMSpecificMessage) -> MessageParam: + """Convert LLMSpecificMessage to Anthropic format. + + Anthropic-specific messages may either be special thought messages that + need to be handled in a special way, or messages already in Anthropic + format. + + Args: + message: Anthropic-specific message. + """ + # Handle special case of thought messages. + # These can be converted to standalone "assistant" messages; later + # these thinking messages will be properly merged into the assistant + # response messages before the context is sent to Anthropic for the + # next turn. + if ( + isinstance(message.message, dict) + and message.message.get("type") == "thought" + and (text := message.message.get("text")) + and (signature := message.message.get("signature")) + ): + return { + "role": "assistant", + "content": [ + { + "type": "thinking", + "thinking": text, + "signature": signature, + } + ], + } + + # Fall back to assuming that the message is already in Anthropic format + return copy.deepcopy(message.message) + def _from_standard_message(self, message: LLMStandardMessage) -> MessageParam: """Convert standard universal context message to Anthropic format. diff --git a/src/pipecat/adapters/services/gemini_adapter.py b/src/pipecat/adapters/services/gemini_adapter.py index a4f70b1fa4..5a7387aca4 100644 --- a/src/pipecat/adapters/services/gemini_adapter.py +++ b/src/pipecat/adapters/services/gemini_adapter.py @@ -209,16 +209,55 @@ def _from_universal_context_messages( system_instruction = None messages = [] tool_call_id_to_name_mapping = {} + non_fn_thought_signatures = [] - # Process each message, preserving Google-formatted messages and converting others + # Process each message, converting to Google format as needed for message in universal_context_messages: - result = self._from_universal_context_message( + # We have a Google-specific message; this may either be a + # thought-signature-containing message that we need to handle in a + # special way, or a message already in Google format that we can + # use directly + if isinstance(message, LLMSpecificMessage): + # Special handling for function-call-related thought signature + # messages + if ( + isinstance(message.message, dict) + and message.message.get("type") == "fn_thought_signature" + and (thought_signature := message.message.get("signature")) + ): + self._apply_function_thought_signature_to_messages( + thought_signature, message.message.get("tool_call_id"), messages + ) + continue + + # Special handling for non-function-call-related thought- + # signature-containing messages + if ( + isinstance(message.message, dict) + and message.message.get("type") == "non_fn_thought_signature" + and (thought_signature := message.message.get("signature")) + and (bookmark := message.message.get("bookmark")) + ): + non_fn_thought_signatures.append( + {"signature": thought_signature, "bookmark": bookmark} + ) + continue + + # Fall back to assuming that the message is already in Google + # format + messages.append(message.message) + continue + + # We have a standard universal context message; convert it to + # Google format + result = self._from_standard_message( message, params=self.MessageConversionParams( already_have_system_instruction=bool(system_instruction), tool_call_id_to_name_mapping=tool_call_id_to_name_mapping, ), ) + # Each result is either a Content or a system instruction if result.content: messages.append(result.content) @@ -229,6 +268,10 @@ def _from_universal_context_messages( if 
result.tool_call_id_to_name_mapping: tool_call_id_to_name_mapping.update(result.tool_call_id_to_name_mapping) + # Apply non-function-call-related thought signatures to the appropriate + # messages + self._apply_non_function_thought_signatures_to_messages(non_fn_thought_signatures, messages) + # Check if we only have function-related messages (no regular text) has_regular_messages = any( len(msg.parts) == 1 @@ -247,13 +290,6 @@ def _from_universal_context_messages( return self.ConvertedMessages(messages=messages, system_instruction=system_instruction) - def _from_universal_context_message( - self, message: LLMContextMessage, *, params: MessageConversionParams - ) -> MessageConversionResult: - if isinstance(message, LLMSpecificMessage): - return self.MessageConversionResult(content=message.message) - return self._from_standard_message(message, params=params) - def _from_standard_message( self, message: LLMStandardMessage, *, params: MessageConversionParams ) -> MessageConversionResult: @@ -410,3 +446,137 @@ def _from_standard_message( content=Content(role=role, parts=parts), tool_call_id_to_name_mapping=tool_call_id_to_name_mapping, ) + + def _apply_function_thought_signature_to_messages( + self, thought_signature: bytes, tool_call_id: str, messages: List[Content] + ) -> None: + """Apply a function-related thought signature to the corresponding function call message. + + Args: + thought_signature: The thought signature bytes to apply. + tool_call_id: ID of the tool call message to find and modify. + messages: List of messages to search through. + """ + # Search backwards through messages to find the matching function call + for message in reversed(messages): + if not isinstance(message, Content) or not message.parts: + continue + # Find the specific part with the matching function call + for part in message.parts: + if ( + hasattr(part, "function_call") + and part.function_call + and part.function_call.id == tool_call_id + ): + part.thought_signature = thought_signature + break + else: + # Continue outer loop if inner loop didn't break + continue + # Break outer loop if inner loop broke (found match) + break + + def _apply_non_function_thought_signatures_to_messages( + self, thought_signatures: List[dict], messages: List[Content] + ) -> None: + """Apply (optional, but recommended) non-function-call-related thought signatures to the last part of corresponding non-function-call assistant messages. + + Gemini 3 Pro (and, somewhat surprisingly, other models, too, when + functions are involved in the conversation) outputs thought signatures + at the end of assistant responses. + + Args: + thought_signatures: A list of dicts containing: + - "signature": a thought signature + - "bookmark": a bookmark to identify the message part to apply the signature to. + The bookmark may contain either: + - "text" + - "inline_data" + messages: List of messages to search through. + """ + if not thought_signatures: + return + + # For debugging, print out thought signatures and their bookmarks + logger.trace(f"Thought signatures to apply: {len(thought_signatures)}") + for ts in thought_signatures: + bookmark = ts.get("bookmark") + if bookmark.get("text"): + text = bookmark["text"] + log_display_text = f"{text[:50]}..." 
if len(text) > 50 else text + logger.trace(f" - At text: {log_display_text}") + elif bookmark.get("inline_data"): + logger.trace(f" - At inline data") + + # Find all assistant (model) messages that aren't function calls + non_fn_assistant_messages = [] + for message in messages: + if not isinstance(message, Content) or not message.parts: + continue + # Check if this is a model message without function calls + if message.role == "model": + has_function_call = any( + hasattr(part, "function_call") and part.function_call for part in message.parts + ) + if not has_function_call: + non_fn_assistant_messages.append(message) + + # Apply thought signatures to the corresponding assistant messages + # Match them using content heuristics, maintaining order (messages without signatures are skipped) + message_start_index = 0 # Track where to start searching for the next match + for thought_signature_dict in thought_signatures: + signature = thought_signature_dict.get("signature") + bookmark = thought_signature_dict.get("bookmark") + if not signature: + continue + + # Search through remaining non-function assistant messages for a match + for i in range(message_start_index, len(non_fn_assistant_messages)): + message = non_fn_assistant_messages[i] + if not message.parts: + continue + + last_part = message.parts[-1] + matched = False + + # If it's a text bookmark, check that the last message part text has the same text or + # - is a prefix of that text (in case spoken text was truncated due to interruption) + # - is prefixed by that text (in case bookmark represents just first chunk of multi-chunk text) + if bookmark_text := bookmark.get("text"): + if hasattr(last_part, "text") and last_part.text: + # Normalize whitespace for comparison + signed_text = " ".join(bookmark_text.split()) + last_text = " ".join(last_part.text.split()) + if ( + last_text == signed_text + or signed_text.startswith(last_text) + or last_text.startswith(signed_text) + ): + log_display_text = ( + f"{last_part.text[:50]}..." 
+ if len(last_part.text) > 50 + else last_part.text + ) + logger.trace( + f"Applying thought signature to part with matching text: {log_display_text}" + ) + last_part.thought_signature = signature + matched = True + + # Check if signed part has inline_data and last message part has matching inline_data + elif inline_data := bookmark.get("inline_data"): + if ( + hasattr(last_part, "inline_data") + and last_part.inline_data + and last_part.inline_data.data == inline_data.data + ): + logger.trace( + f"Applying thought signature to part with matching inline_data" + ) + last_part.thought_signature = signature + matched = True + + # If we found a match, update start index and stop searching for this signed part + if matched: + message_start_index = i + 1 + break diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 9cb969f283..542f21fc01 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -38,7 +38,7 @@ from pipecat.utils.utils import obj_count, obj_id if TYPE_CHECKING: - from pipecat.processors.aggregators.llm_context import LLMContext, NotGiven + from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage, NotGiven from pipecat.processors.frame_processor import FrameProcessor @@ -512,6 +512,15 @@ class TranscriptionMessage: timestamp: Optional[str] = None +@dataclass +class ThoughtTranscriptionMessage: + """An LLM thought message in a conversation transcript.""" + + role: Literal["assistant"] = field(default="assistant", init=False) + content: str + timestamp: Optional[str] = None + + @dataclass class TranscriptionUpdateFrame(DataFrame): """Frame containing new messages added to conversation transcript. @@ -556,7 +565,7 @@ class TranscriptionUpdateFrame(DataFrame): messages: List of new transcript messages that were added. """ - messages: List[TranscriptionMessage] + messages: List[TranscriptionMessage | ThoughtTranscriptionMessage] def __str__(self): pts = format_pts(self.pts) @@ -577,6 +586,75 @@ class LLMContextFrame(Frame): context: "LLMContext" +@dataclass +class LLMThoughtStartFrame(ControlFrame): + """Frame indicating the start of an LLM thought. + + Parameters: + append_to_context: Whether the thought should be appended to the LLM context. + If it is appended, the `llm` field is required, since it will be + appended as an `LLMSpecificMessage`. + llm: Optional identifier of the LLM provider for LLM-specific handling. + Only required if `append_to_context` is True, as the thought is + appended to context as an `LLMSpecificMessage`. + """ + + append_to_context: bool = False + llm: Optional[str] = None + + def __post_init__(self): + super().__post_init__() + if self.append_to_context and self.llm is None: + raise ValueError("When append_to_context is True, llm must be set") + + def __str__(self): + pts = format_pts(self.pts) + return ( + f"{self.name}(pts: {pts}, append_to_context: {self.append_to_context}, llm: {self.llm})" + ) + + +@dataclass +class LLMThoughtTextFrame(DataFrame): + """Frame containing the text (or text chunk) of an LLM thought. + + Note that despite this containing text, it is a DataFrame and not a + TextFrame, to avoid most typical text processing, such as TTS. + + Parameters: + text: The text (or text chunk) of the thought. 
+ """ + + text: str + includes_inter_frame_spaces: bool = field(init=False) + + def __post_init__(self): + super().__post_init__() + # Assume that thought text chunks include all necessary spaces + self.includes_inter_frame_spaces = True + + def __str__(self): + pts = format_pts(self.pts) + return f"{self.name}(pts: {pts}, thought text: {self.text})" + + +@dataclass +class LLMThoughtEndFrame(ControlFrame): + """Frame indicating the end of an LLM thought. + + Parameters: + signature: Optional signature associated with the thought. + This is used by Anthropic, which includes a signature at the end of + each thought. + """ + + signature: Any = None + + def __str__(self): + pts = format_pts(self.pts) + return f"{self.name}(pts: {pts}, signature: {self.signature})" + + @dataclass class LLMMessagesFrame(DataFrame): """Frame containing LLM messages for chat completion. @@ -1119,12 +1197,16 @@ class FunctionCallFromLLM: tool_call_id: A unique identifier for the function call. arguments: The arguments to pass to the function. context: The LLM context when the function call was made. + append_extra_context_messages: Optional extra messages to append to the + context after the function call message. Used to add Google + function-call-related thought signatures to the context. """ function_name: str tool_call_id: str arguments: Mapping[str, Any] context: Any + append_extra_context_messages: Optional[List["LLMContextMessage"]] = None @dataclass @@ -1663,13 +1745,16 @@ class FunctionCallInProgressFrame(ControlFrame, UninterruptibleFrame): tool_call_id: Unique identifier for this function call. arguments: Arguments passed to the function. cancel_on_interruption: Whether to cancel this call if interrupted. - + append_extra_context_messages: Optional extra messages to append to the + context after the function call message. Used to add Google + function-call-related thought signatures to the context. """ function_name: str tool_call_id: str arguments: Any cancel_on_interruption: bool = False + append_extra_context_messages: Optional[List["LLMContextMessage"]] = None @dataclass diff --git a/src/pipecat/processors/aggregators/llm_response_universal.py b/src/pipecat/processors/aggregators/llm_response_universal.py index 69fc649ceb..debfefb076 100644 --- a/src/pipecat/processors/aggregators/llm_response_universal.py +++ b/src/pipecat/processors/aggregators/llm_response_universal.py @@ -47,6 +47,9 @@ LLMRunFrame, LLMSetToolChoiceFrame, LLMSetToolsFrame, + LLMThoughtEndFrame, + LLMThoughtStartFrame, + LLMThoughtTextFrame, SpeechControlParamsFrame, StartFrame, TextFrame, @@ -592,6 +595,10 @@ def __init__( self._function_calls_in_progress: Dict[str, Optional[FunctionCallInProgressFrame]] = {} self._context_updated_tasks: Set[asyncio.Task] = set() + self._thought_aggregation_enabled = False + self._thought_llm: str = "" + self._thought_aggregation: List[TextPartForConcatenation] = [] + @property def has_function_calls_in_progress(self) -> bool: """Check if there are any function calls currently in progress. 
@@ -601,6 +608,17 @@ def has_function_calls_in_progress(self) -> bool: """ return bool(self._function_calls_in_progress) + async def reset(self): + """Reset the aggregation state.""" + await super().reset() + await self._reset_thought_aggregation() # Just to be safe + + async def _reset_thought_aggregation(self): + """Reset the thought aggregation state.""" + self._thought_aggregation_enabled = False + self._thought_llm = "" + self._thought_aggregation = [] + async def process_frame(self, frame: Frame, direction: FrameDirection): """Process frames for assistant response aggregation and function call management. @@ -619,6 +637,12 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self._handle_llm_end(frame) elif isinstance(frame, TextFrame): await self._handle_text(frame) + elif isinstance(frame, LLMThoughtStartFrame): + await self._handle_thought_start(frame) + elif isinstance(frame, LLMThoughtTextFrame): + await self._handle_thought_text(frame) + elif isinstance(frame, LLMThoughtEndFrame): + await self._handle_thought_end(frame) elif isinstance(frame, LLMRunFrame): await self._handle_llm_run(frame) elif isinstance(frame, LLMMessagesAppendFrame): @@ -716,6 +740,10 @@ async def _handle_function_call_in_progress(self, frame: FunctionCallInProgressF } ) + # Append to context any specified extra context messages + if frame.append_extra_context_messages: + self._context.add_messages(frame.append_extra_context_messages) + self._function_calls_in_progress[frame.tool_call_id] = frame async def _handle_function_call_result(self, frame: FunctionCallResultFrame): @@ -824,6 +852,47 @@ async def _handle_text(self, frame: TextFrame): ) ) + async def _handle_thought_start(self, frame: LLMThoughtStartFrame): + if not self._started: + return + + await self._reset_thought_aggregation() + self._thought_aggregation_enabled = frame.append_to_context + self._thought_llm = frame.llm + + async def _handle_thought_text(self, frame: LLMThoughtTextFrame): + if not self._started or not self._thought_aggregation_enabled: + return + + # Make sure we really have text (spaces count, too!) 
+ if len(frame.text) == 0: + return + + self._thought_aggregation.append( + TextPartForConcatenation( + frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces + ) + ) + + async def _handle_thought_end(self, frame: LLMThoughtEndFrame): + if not self._started or not self._thought_aggregation_enabled: + return + + thought = concatenate_aggregated_text(self._thought_aggregation) + llm = self._thought_llm + await self._reset_thought_aggregation() + + self._context.add_message( + LLMSpecificMessage( + llm=llm, + message={ + "type": "thought", + "text": thought, + "signature": frame.signature, + }, + ) + ) + def _context_updated_task_finished(self, task: asyncio.Task): self._context_updated_tasks.discard(task) diff --git a/src/pipecat/processors/transcript_processor.py b/src/pipecat/processors/transcript_processor.py index 93e0c37b46..0dd59f1b37 100644 --- a/src/pipecat/processors/transcript_processor.py +++ b/src/pipecat/processors/transcript_processor.py @@ -20,6 +20,10 @@ EndFrame, Frame, InterruptionFrame, + LLMThoughtEndFrame, + LLMThoughtStartFrame, + LLMThoughtTextFrame, + ThoughtTranscriptionMessage, TranscriptionFrame, TranscriptionMessage, TranscriptionUpdateFrame, @@ -81,92 +85,98 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): class AssistantTranscriptProcessor(BaseTranscriptProcessor): - """Processes assistant TTS text frames into timestamped conversation messages. + """Processes assistant TTS text frames and LLM thought frames into timestamped messages. - This processor aggregates TTS text frames into complete utterances and emits them as - transcript messages. Utterances are completed when: + This processor aggregates both TTS text frames and LLM thought frames into + complete utterances and thoughts, emitting them as transcript messages. + An assistant utterance is completed when: - The bot stops speaking (BotStoppedSpeakingFrame) - The bot is interrupted (InterruptionFrame) - - The pipeline ends (EndFrame) + - The pipeline ends (EndFrame, CancelFrame) + + A thought is completed when: + - The thought ends (LLMThoughtEndFrame) + - The bot is interrupted (InterruptionFrame) + - The pipeline ends (EndFrame, CancelFrame) """ - def __init__(self, **kwargs): + def __init__(self, *, process_thoughts: bool = False, **kwargs): """Initialize processor with aggregation state. Args: + process_thoughts: Whether to process LLM thought frames. Defaults to False. **kwargs: Additional arguments passed to parent class. """ super().__init__(**kwargs) - self._current_text_parts: List[TextPartForConcatenation] = [] - self._aggregation_start_time: Optional[str] = None - - async def _emit_aggregated_text(self): - """Aggregates and emits text fragments as a transcript message. - - This method uses a heuristic to automatically detect whether text fragments - contain embedded spacing (spaces at the beginning or end of fragments) or not, - and applies the appropriate joining strategy. It handles fragments from different - TTS services with different formatting patterns. - - Examples: - Fragments with embedded spacing (concatenated):: - - TTSTextFrame: ["Hello"] - TTSTextFrame: [" there"] # Leading space - TTSTextFrame: ["!"] - TTSTextFrame: [" How"] # Leading space - TTSTextFrame: ["'s"] - TTSTextFrame: [" it"] # Leading space - - Result: "Hello there! 
How's it" - Fragments with trailing spaces (concatenated):: + self._process_thoughts = process_thoughts + self._current_assistant_text_parts: List[TextPartForConcatenation] = [] + self._assistant_text_start_time: Optional[str] = None - TTSTextFrame: ["Hel"] - TTSTextFrame: ["lo "] # Trailing space - TTSTextFrame: ["to "] # Trailing space - TTSTextFrame: ["you"] + self._current_thought_parts: List[TextPartForConcatenation] = [] + self._thought_start_time: Optional[str] = None + self._thought_active = False - Result: "Hello to you" - - Word-by-word fragments without spacing (joined with spaces):: - - TTSTextFrame: ["Hello"] - TTSTextFrame: ["there"] - TTSTextFrame: ["how"] - TTSTextFrame: ["are"] - TTSTextFrame: ["you"] + async def _emit_aggregated_assistant_text(self): + """Aggregates and emits text fragments as a transcript message. - Result: "Hello there how are you" + This method aggregates text fragments that may arrive in multiple + TTSTextFrame instances and emits them as a single TranscriptionMessage. """ - if self._current_text_parts and self._aggregation_start_time: - content = concatenate_aggregated_text(self._current_text_parts) + if self._current_assistant_text_parts and self._assistant_text_start_time: + content = concatenate_aggregated_text(self._current_assistant_text_parts) if content: logger.trace(f"Emitting aggregated assistant message: {content}") message = TranscriptionMessage( role="assistant", content=content, - timestamp=self._aggregation_start_time, + timestamp=self._assistant_text_start_time, ) await self._emit_update([message]) else: logger.trace("No content to emit after stripping whitespace") # Reset aggregation state - self._current_text_parts = [] - self._aggregation_start_time = None + self._current_assistant_text_parts = [] + self._assistant_text_start_time = None + + async def _emit_aggregated_thought(self): + """Aggregates and emits thought text fragments as a thought transcript message. + + This method aggregates thought fragments that may arrive in multiple + LLMThoughtTextFrame instances and emits them as a single ThoughtTranscriptionMessage. + """ + if self._current_thought_parts and self._thought_start_time: + content = concatenate_aggregated_text(self._current_thought_parts) + if content: + logger.trace(f"Emitting aggregated thought message: {content}") + message = ThoughtTranscriptionMessage( + content=content, + timestamp=self._thought_start_time, + ) + await self._emit_update([message]) + else: + logger.trace("No thought content to emit after stripping whitespace") + + # Reset aggregation state + self._current_thought_parts = [] + self._thought_start_time = None + self._thought_active = False async def process_frame(self, frame: Frame, direction: FrameDirection): - """Process frames into assistant conversation messages. + """Process frames into assistant conversation messages and thought messages. 
Handles different frame types: - TTSTextFrame: Aggregates text for current utterance + - LLMThoughtStartFrame: Begins aggregating a new thought + - LLMThoughtTextFrame: Aggregates text for current thought + - LLMThoughtEndFrame: Completes current thought - BotStoppedSpeakingFrame: Completes current utterance - - InterruptionFrame: Completes current utterance due to interruption - - EndFrame: Completes current utterance at pipeline end - - CancelFrame: Completes current utterance due to cancellation + - InterruptionFrame: Completes current utterance and thought due to interruption + - EndFrame: Completes current utterance and thought at pipeline end + - CancelFrame: Completes current utterance and thought due to cancellation Args: frame: Input frame to process. @@ -178,14 +188,40 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): # Push frame first otherwise our emitted transcription update frame # might get cleaned up. await self.push_frame(frame, direction) - # Emit accumulated text with interruptions - await self._emit_aggregated_text() + # Emit accumulated text and thought with interruptions + await self._emit_aggregated_assistant_text() + if self._process_thoughts and self._thought_active: + await self._emit_aggregated_thought() + elif isinstance(frame, LLMThoughtStartFrame): + # Start a new thought + if self._process_thoughts: + self._thought_active = True + self._thought_start_time = time_now_iso8601() + self._current_thought_parts = [] + # Push frame. + await self.push_frame(frame, direction) + elif isinstance(frame, LLMThoughtTextFrame): + # Aggregate thought text if we have an active thought + if self._process_thoughts and self._thought_active: + self._current_thought_parts.append( + TextPartForConcatenation( + frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces + ) + ) + # Push frame. + await self.push_frame(frame, direction) + elif isinstance(frame, LLMThoughtEndFrame): + # Emit accumulated thought when thought ends + if self._process_thoughts and self._thought_active: + await self._emit_aggregated_thought() + # Push frame. + await self.push_frame(frame, direction) elif isinstance(frame, TTSTextFrame): # Start timestamp on first text part - if not self._aggregation_start_time: - self._aggregation_start_time = time_now_iso8601() + if not self._assistant_text_start_time: + self._assistant_text_start_time = time_now_iso8601() - self._current_text_parts.append( + self._current_assistant_text_parts.append( TextPartForConcatenation( frame.text, includes_inter_part_spaces=frame.includes_inter_frame_spaces ) @@ -195,7 +231,10 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): await self.push_frame(frame, direction) elif isinstance(frame, (BotStoppedSpeakingFrame, EndFrame)): # Emit accumulated text when bot finishes speaking or pipeline ends. - await self._emit_aggregated_text() + await self._emit_aggregated_assistant_text() + # Emit accumulated thought at pipeline end if still active + if isinstance(frame, EndFrame) and self._process_thoughts and self._thought_active: + await self._emit_aggregated_thought() # Push frame. await self.push_frame(frame, direction) else: @@ -206,7 +245,8 @@ class TranscriptProcessor: """Factory for creating and managing transcript processors. Provides unified access to user and assistant transcript processors - with shared event handling. + with shared event handling. The assistant processor handles both TTS text + and LLM thought frames. 
Example:: @@ -221,7 +261,7 @@ class TranscriptProcessor: llm, tts, transport.output(), - transcript.assistant_tts(), # Assistant transcripts + transcript.assistant(), # Assistant transcripts (including thoughts) context_aggregator.assistant(), ] ) @@ -231,8 +271,14 @@ async def handle_update(processor, frame): print(f"New messages: {frame.messages}") """ - def __init__(self): - """Initialize factory.""" + def __init__(self, *, process_thoughts: bool = False): + """Initialize factory. + + Args: + process_thoughts: Whether the assistant processor should handle LLM thought + frames. Defaults to False. + """ + self._process_thoughts = process_thoughts self._user_processor = None self._assistant_processor = None self._event_handlers = {} @@ -267,7 +313,9 @@ def assistant(self, **kwargs) -> AssistantTranscriptProcessor: The assistant transcript processor instance. """ if self._assistant_processor is None: - self._assistant_processor = AssistantTranscriptProcessor(**kwargs) + self._assistant_processor = AssistantTranscriptProcessor( + process_thoughts=self._process_thoughts, **kwargs + ) # Apply any registered event handlers for event_name, handler in self._event_handlers.items(): diff --git a/src/pipecat/services/anthropic/llm.py b/src/pipecat/services/anthropic/llm.py index a5c67e90ea..348745e84b 100644 --- a/src/pipecat/services/anthropic/llm.py +++ b/src/pipecat/services/anthropic/llm.py @@ -17,7 +17,7 @@ import json import re from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union import httpx from loguru import logger @@ -40,6 +40,9 @@ LLMFullResponseStartFrame, LLMMessagesFrame, LLMTextFrame, + LLMThoughtEndFrame, + LLMThoughtStartFrame, + LLMThoughtTextFrame, LLMUpdateSettingsFrame, UserImageRawFrame, ) @@ -110,6 +113,24 @@ class AnthropicLLMService(LLMService): # Overriding the default adapter to use the Anthropic one. adapter_class = AnthropicLLMAdapter + class ThinkingConfig(BaseModel): + """Configuration for extended thinking. + + Parameters: + type: Type of thinking mode (currently only "enabled" or "disabled"). + budget_tokens: Maximum number of tokens for thinking. + With today's models, the minimum is 1024. + Only allowed if type is "enabled". + """ + + # Why `| str` here? To not break compatibility in case Anthropic adds + # more types in the future. + type: Literal["enabled", "disabled"] | str + + # Why not enforce minimum of 1024 here? To not break compatibility in + # case Anthropic changes this requirement in the future. + budget_tokens: int + class InputParams(BaseModel): """Input parameters for Anthropic model inference. @@ -124,6 +145,10 @@ class InputParams(BaseModel): temperature: Sampling temperature between 0.0 and 1.0. top_k: Top-k sampling parameter. top_p: Top-p sampling parameter between 0.0 and 1.0. + thinking: Extended thinking configuration. + Enabling extended thinking causes the model to spend more time "thinking" before responding. + It also causes this service to emit LLMThought*Frames during response generation. + Extended thinking is disabled by default. extra: Additional parameters to pass to the API. 
""" @@ -133,6 +158,9 @@ class InputParams(BaseModel): temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0) top_k: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0) top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0) + thinking: Optional["AnthropicLLMService.ThinkingConfig"] = Field( + default_factory=lambda: NOT_GIVEN + ) extra: Optional[Dict[str, Any]] = Field(default_factory=dict) def model_post_init(self, __context): @@ -191,6 +219,7 @@ def __init__( "temperature": params.temperature, "top_k": params.top_k, "top_p": params.top_p, + "thinking": params.thinking, "extra": params.extra if isinstance(params.extra, dict) else {}, } @@ -354,12 +383,21 @@ async def _process_context(self, context: OpenAILLMContext | LLMContext): "top_p": self._settings["top_p"], } + # Add thinking parameter if set + if self._settings["thinking"]: + params["thinking"] = self._settings["thinking"].model_dump(exclude_unset=True) + # Messages, system, tools params.update(params_from_context) params.update(self._settings["extra"]) - response = await self._create_message_stream(self._client.messages.create, params) + # "Interleaved thinking" needed to allow thinking between sequences + # of function calls, when extended thinking is enabled. + # Note that this requires us to use `client.beta`, below. + params.update({"betas": ["interleaved-thinking-2025-05-14"]}) + + response = await self._create_message_stream(self._client.beta.messages.create, params) await self.stop_ttfb_metrics() @@ -380,10 +418,21 @@ async def _process_context(self, context: OpenAILLMContext | LLMContext): completion_tokens_estimate += self._estimate_tokens( event.delta.partial_json ) + elif hasattr(event.delta, "thinking"): + await self.push_frame(LLMThoughtTextFrame(text=event.delta.thinking)) + elif hasattr(event.delta, "signature"): + await self.push_frame(LLMThoughtEndFrame(signature=event.delta.signature)) elif event.type == "content_block_start": if event.content_block.type == "tool_use": tool_use_block = event.content_block json_accumulator = "" + elif event.content_block.type == "thinking": + await self.push_frame( + LLMThoughtStartFrame( + append_to_context=True, + llm=self.get_llm_adapter().id_for_llm_specific_messages, + ) + ) elif ( event.type == "message_delta" and hasattr(event.delta, "stop_reason") diff --git a/src/pipecat/services/google/llm.py b/src/pipecat/services/google/llm.py index 840b473b2d..6ccea5eff6 100644 --- a/src/pipecat/services/google/llm.py +++ b/src/pipecat/services/google/llm.py @@ -16,7 +16,7 @@ import os import uuid from dataclasses import dataclass -from typing import Any, AsyncIterator, Dict, List, Optional +from typing import Any, AsyncIterator, Dict, List, Literal, Optional from loguru import logger from PIL import Image @@ -32,14 +32,18 @@ LLMContextFrame, LLMFullResponseEndFrame, LLMFullResponseStartFrame, + LLMMessagesAppendFrame, LLMMessagesFrame, LLMTextFrame, + LLMThoughtEndFrame, + LLMThoughtStartFrame, + LLMThoughtTextFrame, LLMUpdateSettingsFrame, OutputImageRawFrame, UserImageRawFrame, ) from pipecat.metrics.metrics import LLMTokenUsage -from pipecat.processors.aggregators.llm_context import LLMContext +from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage from pipecat.processors.aggregators.llm_response import ( LLMAssistantAggregatorParams, LLMUserAggregatorParams, @@ -665,6 +669,34 @@ class GoogleLLMService(LLMService): # Overriding the default adapter to use the Gemini one. 
adapter_class = GeminiLLMAdapter + class ThinkingConfig(BaseModel): + """Configuration for controlling the model's internal "thinking" process used before generating a response. + + Gemini 2.5 and 3 series models have this thinking process. + + Parameters: + thinking_level: Thinking level for Gemini 3 Pro. Can be "low" or "high". + If not provided, Gemini 3 Pro defaults to "high". + Note: Gemini 2.5 series should use thinking_budget instead. + thinking_budget: Token budget for thinking, for Gemini 2.5 series. + -1 for dynamic thinking (model decides), 0 to disable thinking, + or a specific token count (e.g., 128-32768 for 2.5 Pro). + If not provided, most models today default to dynamic thinking. + See https://ai.google.dev/gemini-api/docs/thinking#set-budget + for default values and allowed ranges. + Note: Gemini 3 Pro should use thinking_level instead. + include_thoughts: Whether to include thought summaries in the response. + Today's models default to not including thoughts (False). + """ + + thinking_budget: Optional[int] = Field(default=None) + + # Why `| str` here? To not break compatibility in case Google adds more + # levels in the future. + thinking_level: Optional[Literal["low", "high"] | str] = Field(default=None) + + include_thoughts: Optional[bool] = Field(default=None) + class InputParams(BaseModel): """Input parameters for Google AI models. @@ -673,6 +705,12 @@ class InputParams(BaseModel): temperature: Sampling temperature between 0.0 and 2.0. top_k: Top-k sampling parameter. top_p: Top-p sampling parameter between 0.0 and 1.0. + thinking: Thinking configuration with thinking_budget, thinking_level, and include_thoughts. + Used to control the model's internal "thinking" process used before generating a response. + Gemini 2.5 series models use thinking_budget; Gemini 3 models use thinking_level. + If this is not provided, Pipecat disables thinking for all + models where that's possible (the 2.5 series, except 2.5 Pro), + to reduce latency. extra: Additional parameters as a dictionary. 
""" @@ -680,6 +718,7 @@ class InputParams(BaseModel): temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) top_k: Optional[int] = Field(default=None, ge=0) top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) + thinking: Optional["GoogleLLMService.ThinkingConfig"] = Field(default=None) extra: Optional[Dict[str, Any]] = Field(default_factory=dict) def __init__( @@ -720,6 +759,7 @@ def __init__( "temperature": params.temperature, "top_k": params.top_k, "top_p": params.top_p, + "thinking": params.thinking, "extra": params.extra if isinstance(params.extra, dict) else {}, } self._tools = tools @@ -830,6 +870,12 @@ async def _stream_content( if v is not None } + # Add thinking parameters if configured + if self._settings["thinking"]: + generation_params["thinking_config"] = self._settings["thinking"].model_dump( + exclude_unset=True + ) + if self._settings["extra"]: generation_params.update(self._settings["extra"]) @@ -896,6 +942,7 @@ async def _process_context(self, context: OpenAILLMContext | LLMContext): ) function_calls = [] + previous_part = None async for chunk in response: # Stop TTFB metrics after the first chunk await self.stop_ttfb_metrics() @@ -918,9 +965,17 @@ async def _process_context(self, context: OpenAILLMContext | LLMContext): for candidate in chunk.candidates: if candidate.content and candidate.content.parts: for part in candidate.content.parts: - if not part.thought and part.text: - search_result += part.text - await self.push_frame(LLMTextFrame(part.text)) + if part.text: + if part.thought: + # Gemini emits fully-formed thoughts rather + # than chunks so bracket each thought in + # start/end + await self.push_frame(LLMThoughtStartFrame()) + await self.push_frame(LLMThoughtTextFrame(part.text)) + await self.push_frame(LLMThoughtEndFrame()) + else: + search_result += part.text + await self.push_frame(LLMTextFrame(part.text)) elif part.function_call: function_call = part.function_call id = function_call.id or str(uuid.uuid4()) @@ -931,6 +986,17 @@ async def _process_context(self, context: OpenAILLMContext | LLMContext): tool_call_id=id, function_name=function_call.name, arguments=function_call.args or {}, + append_extra_context_messages=[ + self.get_llm_adapter().create_llm_specific_message( + { + "type": "fn_thought_signature", + "signature": part.thought_signature, + "tool_call_id": id, + } + ) + ] + if part.thought_signature + else None, ) ) elif part.inline_data and part.inline_data.data: @@ -940,6 +1006,50 @@ async def _process_context(self, context: OpenAILLMContext | LLMContext): ) await self.push_frame(frame) + # With Gemini 3 Pro (and, contrary to Google's + # docs, other models models, too, especially when + # functions are involved in the conversation), + # thought signatures can be associated with any + # kind of Part, not just function calls. + # + # They should always be included in the last + # response Part. (*) + # + # (*) Since we're using the streaming API, though, + # where text Parts may be split across multiple + # chunks (each represented by a Part, confusingly), + # signatures may actually appear with the first + # chunk (Gemini 2.5) or in a trailing empty-text + # chunk (Gemini 3 Pro). + if part.thought_signature and not part.function_call: + # Save a "bookmark" for the signature, so we + # can later stick it in the right place in + # context when sending it back to the LLM to + # continue the conversation. 
+ bookmark = {} + if part.inline_data and part.inline_data.data: + bookmark["inline_data"] = {"inline_data": part.inline_data} + elif part.text is not None: + # Account for Gemini 3 Pro trailing + # empty-text chunk by using search_result, + # which accumulates all text so far. + bookmark["text"] = search_result + await self.push_frame( + LLMMessagesAppendFrame( + [ + self.get_llm_adapter().create_llm_specific_message( + { + "type": "non_fn_thought_signature", + "signature": part.thought_signature, + "bookmark": bookmark, + } + ) + ] + ) + ) + + previous_part = part + if ( candidate.grounding_metadata and candidate.grounding_metadata.grounding_chunks diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 6e53552638..91358dcf3c 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -14,6 +14,7 @@ Awaitable, Callable, Dict, + List, Mapping, Optional, Protocol, @@ -44,7 +45,11 @@ StartFrame, UserImageRequestFrame, ) -from pipecat.processors.aggregators.llm_context import LLMContext, LLMSpecificMessage +from pipecat.processors.aggregators.llm_context import ( + LLMContext, + LLMContextMessage, + LLMSpecificMessage, +) from pipecat.processors.aggregators.llm_response import ( LLMAssistantAggregatorParams, LLMUserAggregatorParams, @@ -127,6 +132,9 @@ class FunctionCallRunnerItem: tool_call_id: A unique identifier for the function call. arguments: The arguments for the function. context: The LLM context. + append_extra_context_messages: Optional extra messages to append to the + context after the function call message. Used to add Google + function-call-related thought signatures to the context. run_llm: Optional flag to control LLM execution after function call. """ @@ -135,6 +143,7 @@ class FunctionCallRunnerItem: tool_call_id: str arguments: Mapping[str, Any] context: OpenAILLMContext | LLMContext + append_extra_context_messages: Optional[List[LLMContextMessage]] = None run_llm: Optional[bool] = None @@ -456,6 +465,7 @@ async def run_function_calls(self, function_calls: Sequence[FunctionCallFromLLM] tool_call_id=function_call.tool_call_id, arguments=function_call.arguments, context=function_call.context, + append_extra_context_messages=function_call.append_extra_context_messages, ) ) @@ -580,6 +590,7 @@ async def _run_function_call(self, runner_item: FunctionCallRunnerItem): function_name=runner_item.function_name, tool_call_id=runner_item.tool_call_id, arguments=runner_item.arguments, + append_extra_context_messages=runner_item.append_extra_context_messages, cancel_on_interruption=item.cancel_on_interruption, ) diff --git a/tests/test_transcript_processor.py b/tests/test_transcript_processor.py index d86e42101e..c8e15eb244 100644 --- a/tests/test_transcript_processor.py +++ b/tests/test_transcript_processor.py @@ -16,6 +16,10 @@ BotStoppedSpeakingFrame, CancelFrame, InterruptionFrame, + LLMThoughtEndFrame, + LLMThoughtStartFrame, + LLMThoughtTextFrame, + ThoughtTranscriptionMessage, TranscriptionFrame, TranscriptionMessage, TranscriptionUpdateFrame, @@ -485,3 +489,309 @@ def make_tts_text_frame(text: str) -> TTSTextFrame: self.assertEqual(message.role, "assistant") # Should be properly joined without extra spaces self.assertEqual(message.content, "Hello there! 
How's it going?") + + +class TestThoughtTranscription(unittest.IsolatedAsyncioTestCase): + """Tests for thought transcription in AssistantTranscriptProcessor""" + + async def test_basic_thought_transcription(self): + """Test basic thought frame processing""" + processor = AssistantTranscriptProcessor(process_thoughts=True) + + received_updates: List[TranscriptionUpdateFrame] = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + # Create frames for a simple thought + frames_to_send = [ + LLMThoughtStartFrame(), + LLMThoughtTextFrame(text="Let me think about this..."), + LLMThoughtEndFrame(), + ] + + expected_down_frames = [ + LLMThoughtStartFrame, + LLMThoughtTextFrame, + TranscriptionUpdateFrame, + LLMThoughtEndFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + # Verify update was received + self.assertEqual(len(received_updates), 1) + message = received_updates[0].messages[0] + self.assertIsInstance(message, ThoughtTranscriptionMessage) + self.assertEqual(message.content, "Let me think about this...") + self.assertIsNotNone(message.timestamp) + + async def test_thought_aggregation(self): + """Test that thought text frames are properly aggregated""" + processor = AssistantTranscriptProcessor(process_thoughts=True) + + received_updates: List[TranscriptionUpdateFrame] = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + # Create frames simulating chunked thought text + frames_to_send = [ + LLMThoughtStartFrame(), + LLMThoughtTextFrame(text="The user "), + LLMThoughtTextFrame(text="is asking "), + LLMThoughtTextFrame(text="about electric "), + LLMThoughtTextFrame(text="cars."), + LLMThoughtEndFrame(), + ] + + expected_down_frames = [ + LLMThoughtStartFrame, + LLMThoughtTextFrame, + LLMThoughtTextFrame, + LLMThoughtTextFrame, + LLMThoughtTextFrame, + TranscriptionUpdateFrame, + LLMThoughtEndFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + # Verify aggregation + self.assertEqual(len(received_updates), 1) + message = received_updates[0].messages[0] + self.assertIsInstance(message, ThoughtTranscriptionMessage) + self.assertEqual(message.content, "The user is asking about electric cars.") + + async def test_thought_with_interruption(self): + """Test that thoughts are properly captured when interrupted""" + processor = AssistantTranscriptProcessor(process_thoughts=True) + + received_updates: List[TranscriptionUpdateFrame] = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + frames_to_send = [ + LLMThoughtStartFrame(), + LLMThoughtTextFrame(text="I need to consider "), + LLMThoughtTextFrame(text="multiple factors"), + SleepFrame(), + InterruptionFrame(), # User interrupts + ] + + expected_down_frames = [ + LLMThoughtStartFrame, + LLMThoughtTextFrame, + LLMThoughtTextFrame, + InterruptionFrame, + TranscriptionUpdateFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + # Verify thought was captured on interruption + self.assertEqual(len(received_updates), 1) + message = received_updates[0].messages[0] + self.assertIsInstance(message, 
ThoughtTranscriptionMessage) + self.assertEqual(message.content, "I need to consider multiple factors") + + async def test_thought_with_cancel(self): + """Test that thoughts are properly captured when cancelled""" + processor = AssistantTranscriptProcessor(process_thoughts=True) + + received_updates: List[TranscriptionUpdateFrame] = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + frames_to_send = [ + LLMThoughtStartFrame(), + LLMThoughtTextFrame(text="Starting analysis"), + SleepFrame(), + CancelFrame(), + ] + + expected_down_frames = [ + LLMThoughtStartFrame, + LLMThoughtTextFrame, + CancelFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + send_end_frame=False, + ) + + # Verify thought was captured on cancellation + self.assertEqual(len(received_updates), 1) + message = received_updates[0].messages[0] + self.assertIsInstance(message, ThoughtTranscriptionMessage) + self.assertEqual(message.content, "Starting analysis") + + async def test_thought_with_end_frame(self): + """Test that thoughts are captured when pipeline ends normally""" + processor = AssistantTranscriptProcessor(process_thoughts=True) + + received_updates: List[TranscriptionUpdateFrame] = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + frames_to_send = [ + LLMThoughtStartFrame(), + LLMThoughtTextFrame(text="Final thought"), + # Pipeline ends here; run_test will automatically send EndFrame + ] + + expected_down_frames = [ + LLMThoughtStartFrame, + LLMThoughtTextFrame, + TranscriptionUpdateFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + # Verify thought was captured on EndFrame + self.assertEqual(len(received_updates), 1) + message = received_updates[0].messages[0] + self.assertIsInstance(message, ThoughtTranscriptionMessage) + self.assertEqual(message.content, "Final thought") + + async def test_multiple_thoughts(self): + """Test multiple separate thoughts in sequence""" + processor = AssistantTranscriptProcessor(process_thoughts=True) + + received_updates: List[TranscriptionUpdateFrame] = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + frames_to_send = [ + # First thought + LLMThoughtStartFrame(), + LLMThoughtTextFrame(text="First consideration"), + LLMThoughtEndFrame(), + # Second thought + LLMThoughtStartFrame(), + LLMThoughtTextFrame(text="Second consideration"), + LLMThoughtEndFrame(), + ] + + expected_down_frames = [ + LLMThoughtStartFrame, + LLMThoughtTextFrame, + TranscriptionUpdateFrame, + LLMThoughtEndFrame, + LLMThoughtStartFrame, + LLMThoughtTextFrame, + TranscriptionUpdateFrame, + LLMThoughtEndFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + # Verify both thoughts were captured + self.assertEqual(len(received_updates), 2) + + first_message = received_updates[0].messages[0] + self.assertIsInstance(first_message, ThoughtTranscriptionMessage) + self.assertEqual(first_message.content, "First consideration") + + second_message = received_updates[1].messages[0] + self.assertIsInstance(second_message, ThoughtTranscriptionMessage) + 
self.assertEqual(second_message.content, "Second consideration") + + # Verify timestamps are different + self.assertNotEqual(first_message.timestamp, second_message.timestamp) + + async def test_empty_thought_handling(self): + """Test that empty thoughts are not emitted""" + processor = AssistantTranscriptProcessor(process_thoughts=True) + + received_updates: List[TranscriptionUpdateFrame] = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + frames_to_send = [ + LLMThoughtStartFrame(), + LLMThoughtTextFrame(text=""), # Empty + LLMThoughtTextFrame(text=" "), # Just whitespace + LLMThoughtEndFrame(), + ] + + expected_down_frames = [ + LLMThoughtStartFrame, + LLMThoughtTextFrame, + LLMThoughtTextFrame, + LLMThoughtEndFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + # Verify no updates emitted for empty content + self.assertEqual(len(received_updates), 0) + + async def test_thought_without_start_frame(self): + """Test that thought text without start frame is ignored""" + processor = AssistantTranscriptProcessor(process_thoughts=True) + + received_updates: List[TranscriptionUpdateFrame] = [] + + @processor.event_handler("on_transcript_update") + async def handle_update(proc, frame: TranscriptionUpdateFrame): + received_updates.append(frame) + + # Send thought text without start frame + frames_to_send = [ + LLMThoughtTextFrame(text="This should be ignored"), + LLMThoughtEndFrame(), + ] + + expected_down_frames = [ + LLMThoughtTextFrame, + LLMThoughtEndFrame, + ] + + await run_test( + processor, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + # Verify no updates since thought wasn't properly started + self.assertEqual(len(received_updates), 0) diff --git a/uv.lock b/uv.lock index cdb0842847..cebf3ab96e 100644 --- a/uv.lock +++ b/uv.lock @@ -1853,6 +1853,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/be/a4/7319a2a8add4cc352be9e3efeff5e2aacee917c85ca2fa1647e29089983c/google_auth-2.41.1-py2.py3-none-any.whl", hash = "sha256:754843be95575b9a19c604a848a41be03f7f2afd8c019f716dc1f51ee41c639d", size = 221302, upload-time = "2025-09-30T22:51:24.212Z" }, ] +[package.optional-dependencies] +requests = [ + { name = "requests" }, +] + [[package]] name = "google-cloud-speech" version = "2.33.0" @@ -1920,11 +1925,11 @@ wheels = [ [[package]] name = "google-genai" -version = "1.41.0" +version = "1.53.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, - { name = "google-auth" }, + { name = "google-auth", extra = ["requests"] }, { name = "httpx" }, { name = "pydantic" }, { name = "requests" }, @@ -1932,9 +1937,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/72/8b/ee20bcf707769b3b0e1106c3b5c811507736af7e8a60f29a70af1750ba19/google_genai-1.41.0.tar.gz", hash = "sha256:134f861bb0ace4e34af0501ecb75ceee15f7662fd8120698cd185e8cb39f2800", size = 245812, upload-time = "2025-10-02T22:30:29.699Z" } +sdist = { url = "https://files.pythonhosted.org/packages/de/b3/36fbfde2e21e6d3bc67780b61da33632f495ab1be08076cf0a16af74098f/google_genai-1.53.0.tar.gz", hash = "sha256:938a26d22f3fd32c6eeeb4276ef204ef82884e63af9842ce3eac05ceb39cbd8d", size = 260102, upload-time = "2025-12-03T17:21:23.233Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/15/14/e5e8fbca8863fee718208566c4e927b8e9f45fd46ec5cf89e24759da545b/google_genai-1.41.0-py3-none-any.whl", hash = "sha256:111a3ee64c1a0927d3879faddb368234594432479a40c311e5fe4db338ca8778", size = 245931, upload-time = "2025-10-02T22:30:27.885Z" }, + { url = "https://files.pythonhosted.org/packages/40/f2/97fefdd1ad1f3428321bac819ae7a83ccc59f6439616054736b7819fa56c/google_genai-1.53.0-py3-none-any.whl", hash = "sha256:65a3f99e5c03c372d872cda7419f5940e723374bb12a2f3ffd5e3e56e8eb2094", size = 262015, upload-time = "2025-12-03T17:21:21.934Z" }, ] [[package]] @@ -4695,7 +4700,7 @@ requires-dist = [ { name = "faster-whisper", marker = "extra == 'whisper'", specifier = "~=1.1.1" }, { name = "google-cloud-speech", marker = "extra == 'google'", specifier = ">=2.33.0,<3" }, { name = "google-cloud-texttospeech", marker = "extra == 'google'", specifier = ">=2.31.0,<3" }, - { name = "google-genai", marker = "extra == 'google'", specifier = ">=1.41.0,<2" }, + { name = "google-genai", marker = "extra == 'google'", specifier = ">=1.51.0,<2" }, { name = "groq", marker = "extra == 'groq'", specifier = "~=0.23.0" }, { name = "hume", marker = "extra == 'hume'", specifier = ">=0.11.2" }, { name = "langchain", marker = "extra == 'langchain'", specifier = "~=0.3.20" },