diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 9fad65dbc2..7665e3b3a5 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -33,7 +33,7 @@ jobs: - name: Install dependencies run: | - uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain + uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra daily - name: Run tests with coverage run: | diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 857ebb4893..f5c9ceaf0d 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -37,7 +37,7 @@ jobs: - name: Install dependencies run: | - uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain + uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra daily - name: Test with pytest run: | diff --git a/src/pipecat/transports/daily/transport.py b/src/pipecat/transports/daily/transport.py index 51ef637b8f..036778e9f9 100644 --- a/src/pipecat/transports/daily/transport.py +++ b/src/pipecat/transports/daily/transport.py @@ -567,8 +567,17 @@ async def send_message( Returns: error: An error description or None. """ + # Wait for join to complete with timeout + # If already joined, this returns immediately if not self._joined: - return "Unable to send messages before joining." + try: + await asyncio.wait_for(self._joined_event.wait(), timeout=10.0) + except asyncio.TimeoutError: + return "Join operation timed out, unable to send message." + + # Double-check we're still joined (could have been cleared if left during wait) + if not self._joined: + return "Transport disconnected while waiting to send message." participant_id = None if isinstance( @@ -770,6 +779,8 @@ async def join(self): logger.error(error_msg) await self._callbacks.on_error(error_msg) self._joining = False + # Ensure any waiting callers are notified of failure + self._joined_event.set() # Allows send attempts to fail immediately instead of hanging async def _join(self): """Execute the actual room join operation.""" @@ -1050,8 +1061,16 @@ async def send_prebuilt_chat_message( Returns: error: An error description or None. """ + # Wait for join to complete with timeout + if not self._joined: + try: + await asyncio.wait_for(self._joined_event.wait(), timeout=10.0) + except asyncio.TimeoutError: + return "Join operation timed out, unable to send message." + + # Double-check we're still joined if not self._joined: - return "Can't send message if not joined" + return "Transport disconnected while waiting to send message." future = self._get_event_loop().create_future() self._client.send_prebuilt_chat_message( @@ -1960,7 +1979,7 @@ async def send_message( """ error = await self._client.send_message(frame) if error: - logger.error(f"Unable to send message: {error}") + logger.error(f"Unable to send message: {error}", extra={"frame": frame}) async def register_video_destination(self, destination: str): """Register a video output destination. diff --git a/tests/test_daily_transport_service.py b/tests/test_daily_transport_service.py index aabbd733da..8f8a8a6d3c 100644 --- a/tests/test_daily_transport_service.py +++ b/tests/test_daily_transport_service.py @@ -4,7 +4,141 @@ # SPDX-License-Identifier: BSD 2-Clause License # +import asyncio import unittest +from unittest.mock import MagicMock + +from pipecat.frames.frames import OutputTransportMessageFrame +from pipecat.transports.daily.transport import DailyTransportClient + + +class TestDailyTransportRaceCondition(unittest.IsolatedAsyncioTestCase): + """Tests for the race condition fix in DailyTransport.send_message()""" + + async def test_send_message_waits_for_join(self): + """Test that send_message() waits for join to complete instead of rejecting immediately.""" + + # Create a mock transport object with just the attributes we need + transport = MagicMock(spec=DailyTransportClient) + transport._joined = False + transport._joined_event = asyncio.Event() + transport._client = MagicMock() + + # Mock the send_app_message to succeed via completion callback + def mock_send(msg, pid, completion): + completion(None) + + transport._client.send_app_message = mock_send + transport._get_event_loop = MagicMock(return_value=asyncio.get_event_loop()) + + # Set up the joined event to fire after a short delay + async def set_joined_after_delay(): + await asyncio.sleep(0.05) + transport._joined = True + transport._joined_event.set() + + send_message = DailyTransportClient.send_message + + # Schedule the event setter + task = asyncio.create_task(set_joined_after_delay()) + + # Call the real send_message with our mock object + frame = OutputTransportMessageFrame(message="test message") + result = await send_message(transport, frame) + + await task + + # Should succeed (no error) + self.assertIsNone(result) + + async def test_send_message_timeout_if_join_slow(self): + """Test that send_message() times out if join takes longer than 10 seconds.""" + + # Create a mock transport that never joins + transport = MagicMock(spec=DailyTransportClient) + transport._joined = False + transport._joined_event = asyncio.Event() # Event that never gets set + transport._client = MagicMock() + transport._get_event_loop = MagicMock(return_value=asyncio.get_event_loop()) + + # Bind the real send_message method + + send_message = DailyTransportClient.send_message + + frame = OutputTransportMessageFrame(message="test message") + + # Call send_message - it should timeout after ~10 seconds + # For testing, we'll wrap it with a shorter timeout to fail fast + start = asyncio.get_event_loop().time() + result = await asyncio.wait_for(send_message(transport, frame), timeout=11.0) + elapsed = asyncio.get_event_loop().time() - start + + # Should fail with timeout error (took at least 10 seconds) + self.assertGreaterEqual(elapsed, 9.5) + self.assertIn("timed out", result.lower() if result else "") + + async def test_send_message_already_joined(self): + """Test that send_message() returns immediately if already joined.""" + + # Create a mock transport that's already joined + transport = MagicMock(spec=DailyTransportClient) + transport._joined = True + transport._joined_event = asyncio.Event() + transport._joined_event.set() + transport._client = MagicMock() + transport._get_event_loop = MagicMock(return_value=asyncio.get_event_loop()) + + # Mock the send_app_message to succeed + def mock_send(msg, pid, completion): + completion(None) + + transport._client.send_app_message = mock_send + + # Bind the real send_message method + + send_message = DailyTransportClient.send_message + + frame = OutputTransportMessageFrame(message="test message") + + start_time = asyncio.get_event_loop().time() + result = await send_message(transport, frame) + elapsed = asyncio.get_event_loop().time() - start_time + + # Should succeed immediately + self.assertIsNone(result) + # Should not take significant time + self.assertLess(elapsed, 0.1) + + async def test_send_message_disconnects_during_wait(self): + """Test that send_message() handles disconnect during wait.""" + + transport = MagicMock(spec=DailyTransportClient) + transport._joined = False + transport._joined_event = asyncio.Event() + transport._client = MagicMock() + transport._get_event_loop = MagicMock(return_value=asyncio.get_event_loop()) + + # Simulate transport being left while waiting + async def clear_joined_during_wait(): + await asyncio.sleep(0.05) + transport._joined = False + transport._joined_event.set() + + # Bind the real method + + send_message = DailyTransportClient.send_message + + frame = OutputTransportMessageFrame(message="test message") + + # Schedule disconnect + task = asyncio.create_task(clear_joined_during_wait()) + + result = await send_message(transport, frame) + + await task + + # Should fail because transport disconnected + self.assertIn("disconnected", result.lower() if result else "") class TestDailyTransport(unittest.IsolatedAsyncioTestCase):