diff --git a/changelog/3189.added.md b/changelog/3189.added.md new file mode 100644 index 0000000000..5dbcb60583 --- /dev/null +++ b/changelog/3189.added.md @@ -0,0 +1,6 @@ +- Data and control frames can now be marked as non-interruptible by using the + `UninterruptibleFrame` mixin. Frames marked as `UninterruptibleFrame` will not + be interrupted during processing, and any queued frames of this type will be + retained in the internal queues. This is useful when you need ordered frames + (data or control) that should not be discarded or cancelled due to + interruptions. diff --git a/changelog/3189.changed.md b/changelog/3189.changed.md new file mode 100644 index 0000000000..f8f24a856a --- /dev/null +++ b/changelog/3189.changed.md @@ -0,0 +1,3 @@ +- `FunctionCallInProgressFrame` and `FunctionCallResultFrame` have changed from + system frames to a control frame and a data frame, respectively, and are now + both marked as `UninterruptibleFrame`. diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 488685fc3f..9cb969f283 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -186,6 +186,20 @@ class ControlFrame(Frame): # +@dataclass +class UninterruptibleFrame: + """A marker for data or control frames that must not be interrupted. + + Frames with this mixin are still ordered normally, but unlike other frames, + they are preserved during interruptions: they remain in internal queues and + any task processing them will not be cancelled. This ensures the frame is + always delivered and processed to completion. + + """ + + pass + + @dataclass class AudioRawFrame: """A frame containing a chunk of raw audio. @@ -696,6 +710,44 @@ class LLMConfigureOutputFrame(DataFrame): skip_tts: bool +@dataclass +class FunctionCallResultProperties: + """Properties for configuring function call result behavior. + + Parameters: + run_llm: Whether to run the LLM after receiving this result. + on_context_updated: Callback to execute when context is updated. + """ + + run_llm: Optional[bool] = None + on_context_updated: Optional[Callable[[], Awaitable[None]]] = None + + +@dataclass +class FunctionCallResultFrame(DataFrame, UninterruptibleFrame): + """Frame containing the result of an LLM function call. + + This is an uninterruptible frame because once a result is generated we + always want to update the context. + + Parameters: + function_name: Name of the function that was executed. + tool_call_id: Unique identifier for the function call. + arguments: Arguments that were passed to the function. + result: The result returned by the function. + run_llm: Whether to run the LLM after this result. + properties: Additional properties for result handling. + + """ + + function_name: str + tool_call_id: str + arguments: Any + result: Any + run_llm: Optional[bool] = None + properties: Optional[FunctionCallResultProperties] = None + + @dataclass class TTSSpeakFrame(DataFrame): """Frame containing text that should be spoken by TTS. @@ -1089,23 +1141,6 @@ class FunctionCallsStartedFrame(SystemFrame): function_calls: Sequence[FunctionCallFromLLM] -@dataclass -class FunctionCallInProgressFrame(SystemFrame): - """Frame signaling that a function call is currently executing. - - Parameters: - function_name: Name of the function being executed. - tool_call_id: Unique identifier for this function call. - arguments: Arguments passed to the function. - cancel_on_interruption: Whether to cancel this call if interrupted. - """ - - function_name: str - tool_call_id: str - arguments: Any - cancel_on_interruption: bool = False - - @dataclass class FunctionCallCancelFrame(SystemFrame): """Frame signaling that a function call has been cancelled. @@ -1119,40 +1154,6 @@ class FunctionCallCancelFrame(SystemFrame): tool_call_id: str -@dataclass -class FunctionCallResultProperties: - """Properties for configuring function call result behavior. - - Parameters: - run_llm: Whether to run the LLM after receiving this result. - on_context_updated: Callback to execute when context is updated. - """ - - run_llm: Optional[bool] = None - on_context_updated: Optional[Callable[[], Awaitable[None]]] = None - - -@dataclass -class FunctionCallResultFrame(SystemFrame): - """Frame containing the result of an LLM function call. - - Parameters: - function_name: Name of the function that was executed. - tool_call_id: Unique identifier for the function call. - arguments: Arguments that were passed to the function. - result: The result returned by the function. - run_llm: Whether to run the LLM after this result. - properties: Additional properties for result handling. - """ - - function_name: str - tool_call_id: str - arguments: Any - result: Any - run_llm: Optional[bool] = None - properties: Optional[FunctionCallResultProperties] = None - - @dataclass class STTMuteFrame(SystemFrame): """Frame to mute/unmute the Speech-to-Text service. @@ -1650,6 +1651,27 @@ def __post_init__(self): self.skip_tts = None +@dataclass +class FunctionCallInProgressFrame(ControlFrame, UninterruptibleFrame): + """Frame signaling that a function call is currently executing. + + This is an uninterruptible frame because we always want to update the + context. + + Parameters: + function_name: Name of the function being executed. + tool_call_id: Unique identifier for this function call. + arguments: Arguments passed to the function. + cancel_on_interruption: Whether to cancel this call if interrupted. + + """ + + function_name: str + tool_call_id: str + arguments: Any + cancel_on_interruption: bool = False + + @dataclass class TTSStartedFrame(ControlFrame): """Frame indicating the beginning of a TTS response. diff --git a/src/pipecat/processors/frame_processor.py b/src/pipecat/processors/frame_processor.py index ec9ebab98c..ed77506ddb 100644 --- a/src/pipecat/processors/frame_processor.py +++ b/src/pipecat/processors/frame_processor.py @@ -33,6 +33,7 @@ InterruptionTaskFrame, StartFrame, SystemFrame, + UninterruptibleFrame, ) from pipecat.metrics.metrics import LLMTokenUsage, MetricsData from pipecat.observers.base_observer import BaseObserver, FrameProcessed, FramePushed @@ -211,6 +212,7 @@ def __init__( # The input task that handles all types of frames. It processes system # frames right away and queues non-system frames for later processing. self.__should_block_system_frames = False + self.__input_queue = FrameProcessorQueue() self.__input_event: Optional[asyncio.Event] = None self.__input_frame_task: Optional[asyncio.Task] = None @@ -220,8 +222,10 @@ def __init__( # called. To resume processing frames we need to call # `resume_processing_frames()` which will wake up the event. self.__should_block_frames = False + self.__process_queue = asyncio.Queue() self.__process_event: Optional[asyncio.Event] = None self.__process_frame_task: Optional[asyncio.Task] = None + self.__process_current_frame: Optional[Frame] = None # To interrupt a pipeline, we push an `InterruptionTaskFrame` upstream. # Then we wait for the corresponding `InterruptionFrame` to travel from @@ -805,8 +809,12 @@ async def _start_interruption(self): # interruption). Instead we just drain the queue because this is # an interruption. self.__reset_process_task() + elif isinstance(self.__process_current_frame, UninterruptibleFrame): + # We don't want to cancel UninterruptibleFrame, so we simply + # cleanup the queue. + self.__reset_process_queue() else: - # Cancel and re-create the process task including the queue. + # Cancel and re-create the process task. await self.__cancel_process_task() self.__create_process_task() except Exception as e: @@ -872,7 +880,6 @@ def __create_input_task(self): if not self.__input_frame_task: self.__input_event = asyncio.Event() - self.__input_queue = FrameProcessorQueue() self.__input_frame_task = self.create_task(self.__input_frame_task_handler()) async def __cancel_input_task(self): @@ -890,9 +897,7 @@ def __create_process_task(self): return if not self.__process_frame_task: - self.__should_block_frames = False - self.__process_event = asyncio.Event() - self.__process_queue = asyncio.Queue() + self.__reset_process_task() self.__process_frame_task = self.create_task(self.__process_frame_task_handler()) def __reset_process_task(self): @@ -902,10 +907,26 @@ def __reset_process_task(self): self.__should_block_frames = False self.__process_event = asyncio.Event() + self.__reset_process_queue() + + def __reset_process_queue(self): + """Reset non-system frame processing queue.""" + # Create a new queue to insert UninterruptibleFrame frames. + new_queue = asyncio.Queue() + + # Process current queue and keep UninterruptibleFrame frames. while not self.__process_queue.empty(): - self.__process_queue.get_nowait() + item = self.__process_queue.get_nowait() + if isinstance(item, UninterruptibleFrame): + new_queue.put_nowait(item) self.__process_queue.task_done() + # Put back UninterruptibleFrame frames into our process queue. + while not new_queue.empty(): + item = new_queue.get_nowait() + self.__process_queue.put_nowait(item) + new_queue.task_done() + async def __cancel_process_task(self): """Cancel the non-system frame processing task.""" if self.__process_frame_task: @@ -959,8 +980,12 @@ async def __input_frame_task_handler(self): async def __process_frame_task_handler(self): """Handle non-system frames from the process queue.""" while True: + self.__process_current_frame = None + (frame, direction, callback) = await self.__process_queue.get() + self.__process_current_frame = frame + if self.__should_block_frames and self.__process_event: logger.trace(f"{self}: frame processing paused") await self.__process_event.wait() diff --git a/tests/test_frame_processor.py b/tests/test_frame_processor.py index d0072e5fb3..2a3e1e66c0 100644 --- a/tests/test_frame_processor.py +++ b/tests/test_frame_processor.py @@ -6,13 +6,17 @@ import asyncio import unittest +from dataclasses import dataclass from pipecat.frames.frames import ( + DataFrame, EndFrame, Frame, InterruptionFrame, OutputTransportMessageUrgentFrame, + SystemFrame, TextFrame, + UninterruptibleFrame, ) from pipecat.pipeline.pipeline import Pipeline from pipecat.processors.filters.identity_filter import IdentityFilter @@ -110,3 +114,75 @@ async def process_frame(self, frame: Frame, direction: FrameDirection): expected_down_frames=expected_down_frames, send_end_frame=False, ) + + async def test_interruptible_frames(self): + @dataclass + class TestInterruptibleFrame(DataFrame): + text: str + + class DelayTestFrameProcessor(FrameProcessor): + """This processor just delays processing frames so we have time to + try to interrupt them. + """ + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if not isinstance(frame, SystemFrame): + # Sleep more than SleepFrame default. + await asyncio.sleep(0.4) + await self.push_frame(frame, direction) + + pipeline = Pipeline([DelayTestFrameProcessor()]) + + frames_to_send = [ + TestInterruptibleFrame(text="Hello from Pipecat!"), + # Make sure we hit the DelayTestFrameProcessor first. + SleepFrame(), + # Just a random interruption. This should cause the interruption of + # TestInterruptibleFrame. + InterruptionFrame(), + ] + expected_down_frames = [InterruptionFrame] + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + ) + + async def test_uninterruptible_frames(self): + @dataclass + class TestUninterruptibleFrame(DataFrame, UninterruptibleFrame): + text: str + + class DelayTestFrameProcessor(FrameProcessor): + """This processor just delays processing non-InterruptionFrame so we + have time to try to interrupt them. + + """ + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if not isinstance(frame, SystemFrame): + # Sleep more than SleepFrame default. + await asyncio.sleep(0.4) + await self.push_frame(frame, direction) + + pipeline = Pipeline([DelayTestFrameProcessor()]) + + frames_to_send = [ + TestUninterruptibleFrame(text="Hello from Pipecat!"), + # Make sure we hit the DelayTestFrameProcessor first. + SleepFrame(), + # Just a random interruption. This should not cause the interruption + # of TestUninterruptibleFrame. + InterruptionFrame(), + ] + expected_down_frames = [ + InterruptionFrame, + TestUninterruptibleFrame, + ] + await run_test( + pipeline, + frames_to_send=frames_to_send, + expected_down_frames=expected_down_frames, + )