diff --git a/src/runloop_api_client/lib/polling.py b/src/runloop_api_client/lib/polling.py index 8031fd92b..899d2a9bf 100644 --- a/src/runloop_api_client/lib/polling.py +++ b/src/runloop_api_client/lib/polling.py @@ -2,21 +2,26 @@ from typing import Any, TypeVar, Callable, Optional from dataclasses import dataclass -T = TypeVar('T') +T = TypeVar("T") + @dataclass class PollingConfig: """Configuration for polling behavior""" + interval_seconds: float = 1.0 max_attempts: int = 120 timeout_seconds: Optional[float] = None + class PollingTimeout(Exception): """Raised when polling exceeds max attempts or timeout""" + def __init__(self, message: str, last_value: Any): self.last_value = last_value super().__init__(f"{message}. Last retrieved value: {last_value}") + def poll_until( retriever: Callable[[], T], is_terminal: Callable[[T], bool], @@ -25,27 +30,27 @@ def poll_until( ) -> T: """ Poll until a condition is met or timeout/max attempts are reached. - + Args: retriever: Callable that returns the object to check is_terminal: Callable that returns True when polling should stop config: Optional polling configuration on_error: Optional error handler that can return a value to continue polling or re-raise the exception to stop polling - + Returns: The final state of the polled object - + Raises: PollingTimeout: When max attempts or timeout is reached """ if config is None: config = PollingConfig() - + attempts = 0 start_time = time.time() last_result = None - + while True: try: last_result = retriever() @@ -54,23 +59,17 @@ def poll_until( last_result = on_error(e) else: raise - + if is_terminal(last_result): return last_result - + attempts += 1 if attempts >= config.max_attempts: - raise PollingTimeout( - f"Exceeded maximum attempts ({config.max_attempts})", - last_result - ) - + raise PollingTimeout(f"Exceeded maximum attempts ({config.max_attempts})", last_result) + if config.timeout_seconds is not None: elapsed = time.time() - start_time if elapsed >= config.timeout_seconds: - raise PollingTimeout( - f"Exceeded timeout of {config.timeout_seconds} seconds", - last_result - ) - + raise PollingTimeout(f"Exceeded timeout of {config.timeout_seconds} seconds", last_result) + time.sleep(config.interval_seconds) diff --git a/tests/api_resources/devboxes/test_executions.py b/tests/api_resources/devboxes/test_executions.py index 32da6909f..c712dc20b 100755 --- a/tests/api_resources/devboxes/test_executions.py +++ b/tests/api_resources/devboxes/test_executions.py @@ -4,12 +4,15 @@ import os from typing import Any, cast +from unittest.mock import Mock, patch import pytest from tests.utils import assert_matches_type from runloop_api_client import Runloop, AsyncRunloop from runloop_api_client.types import DevboxExecutionDetailView, DevboxAsyncExecutionDetailView +from runloop_api_client._exceptions import APIStatusError +from runloop_api_client.lib.polling import PollingConfig, PollingTimeout base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -224,6 +227,172 @@ def test_path_params_kill(self, client: Runloop) -> None: devbox_id="devbox_id", ) + # Polling method tests + @parametrize + def test_method_await_completed_success(self, client: Runloop) -> None: + """Test await_completed with successful polling to completed state""" + + # Mock the wait_for_status calls - first returns running, then completed + mock_execution_running = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="running", + stdout="Starting...", + stderr="", + ) + + mock_execution_completed = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="completed", + stdout="Starting...\nFinished!", + stderr="", + ) + + with patch.object(client.devboxes.executions, "_post") as mock_post: + mock_post.side_effect = [mock_execution_running, mock_execution_completed] + + result = client.devboxes.executions.await_completed("execution_id", "devbox_id") + + assert result.execution_id == "execution_id" + assert result.status == "completed" + assert mock_post.call_count == 2 + + @parametrize + def test_method_await_completed_immediate_success(self, client: Runloop) -> None: + """Test await_completed when execution is already completed""" + + mock_execution_completed = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="completed", + stdout="Already finished!", + stderr="", + ) + + with patch.object(client.devboxes.executions, "_post") as mock_post: + mock_post.return_value = mock_execution_completed + + result = client.devboxes.executions.await_completed("execution_id", "devbox_id") + + assert result.execution_id == "execution_id" + assert result.status == "completed" + assert mock_post.call_count == 1 + + @parametrize + def test_method_await_completed_timeout_handling(self, client: Runloop) -> None: + """Test await_completed handles 408 timeouts correctly""" + + # Create a mock 408 response + mock_response = Mock() + mock_response.status_code = 408 + mock_408_error = APIStatusError("Request timeout", response=mock_response, body=None) + + mock_execution_completed = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="completed", + stdout="Finished after timeout!", + stderr="", + ) + + with patch.object(client.devboxes.executions, "_post") as mock_post: + # First call raises 408, second call succeeds + mock_post.side_effect = [mock_408_error, mock_execution_completed] + + result = client.devboxes.executions.await_completed("execution_id", "devbox_id") + + assert result.execution_id == "execution_id" + assert result.status == "completed" + assert mock_post.call_count == 2 + + @parametrize + def test_method_await_completed_other_error(self, client: Runloop) -> None: + """Test await_completed re-raises non-408 errors""" + + # Create a mock 500 response + mock_response = Mock() + mock_response.status_code = 500 + mock_500_error = APIStatusError("Internal server error", response=mock_response, body=None) + + with patch.object(client.devboxes.executions, "_post") as mock_post: + mock_post.side_effect = mock_500_error + + with pytest.raises(APIStatusError, match="Internal server error"): + client.devboxes.executions.await_completed("execution_id", "devbox_id") + + @parametrize + def test_method_await_completed_with_config(self, client: Runloop) -> None: + """Test await_completed with custom polling configuration""" + + mock_execution_completed = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="completed", + stdout="Finished with config!", + stderr="", + ) + + config = PollingConfig(interval_seconds=0.1, max_attempts=10) + + with patch.object(client.devboxes.executions, "_post") as mock_post: + mock_post.return_value = mock_execution_completed + + result = client.devboxes.executions.await_completed("execution_id", "devbox_id", config=config) + + assert result.execution_id == "execution_id" + assert result.status == "completed" + + @parametrize + def test_method_await_completed_polling_timeout(self, client: Runloop) -> None: + """Test await_completed raises PollingTimeout when max attempts exceeded""" + + mock_execution_running = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="running", + stdout="Still running...", + stderr="", + ) + + config = PollingConfig(interval_seconds=0.01, max_attempts=2) + + with patch.object(client.devboxes.executions, "_post") as mock_post: + mock_post.return_value = mock_execution_running + + with pytest.raises(PollingTimeout): + client.devboxes.executions.await_completed("execution_id", "devbox_id", config=config) + + @parametrize + def test_method_await_completed_various_statuses(self, client: Runloop) -> None: + """Test await_completed correctly handles different execution statuses""" + + # Test with queued status first + mock_execution_queued = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="queued", + stdout="", + stderr="", + ) + + mock_execution_completed = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="completed", + stdout="Done!", + stderr="", + ) + + with patch.object(client.devboxes.executions, "_post") as mock_post: + mock_post.side_effect = [mock_execution_queued, mock_execution_completed] + + result = client.devboxes.executions.await_completed("execution_id", "devbox_id") + + assert result.execution_id == "execution_id" + assert result.status == "completed" + assert mock_post.call_count == 2 + class TestAsyncExecutions: parametrize = pytest.mark.parametrize( @@ -436,3 +605,173 @@ async def test_path_params_kill(self, async_client: AsyncRunloop) -> None: execution_id="", devbox_id="devbox_id", ) + + # Async polling method tests + @parametrize + async def test_method_await_completed_success(self, async_client: AsyncRunloop) -> None: + """Test await_completed with successful polling to completed state""" + + # Mock the wait_for_status calls - first returns running, then completed + mock_execution_running = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="running", + stdout="Starting...", + stderr="", + ) + + mock_execution_completed = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="completed", + stdout="Starting...\nFinished!", + stderr="", + ) + + with patch.object(async_client.devboxes.executions, "_post") as mock_post: + mock_post.side_effect = [mock_execution_running, mock_execution_completed] + + result = await async_client.devboxes.executions.await_completed("execution_id", devbox_id="devbox_id") + + assert result.execution_id == "execution_id" + assert result.status == "completed" + assert mock_post.call_count == 2 + + @parametrize + async def test_method_await_completed_immediate_success(self, async_client: AsyncRunloop) -> None: + """Test await_completed when execution is already completed""" + + mock_execution_completed = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="completed", + stdout="Already finished!", + stderr="", + ) + + with patch.object(async_client.devboxes.executions, "_post") as mock_post: + mock_post.return_value = mock_execution_completed + + result = await async_client.devboxes.executions.await_completed("execution_id", devbox_id="devbox_id") + + assert result.execution_id == "execution_id" + assert result.status == "completed" + assert mock_post.call_count == 1 + + @parametrize + async def test_method_await_completed_timeout_handling(self, async_client: AsyncRunloop) -> None: + """Test await_completed handles 408 timeouts correctly""" + + # Create a mock 408 response + mock_response = Mock() + mock_response.status_code = 408 + mock_408_error = APIStatusError("Request timeout", response=mock_response, body=None) + + mock_execution_completed = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="completed", + stdout="Finished after timeout!", + stderr="", + ) + + with patch.object(async_client.devboxes.executions, "_post") as mock_post: + # First call raises 408, second call succeeds + mock_post.side_effect = [mock_408_error, mock_execution_completed] + + result = await async_client.devboxes.executions.await_completed("execution_id", devbox_id="devbox_id") + + assert result.execution_id == "execution_id" + assert result.status == "completed" + assert mock_post.call_count == 2 + + @parametrize + async def test_method_await_completed_other_error(self, async_client: AsyncRunloop) -> None: + """Test await_completed re-raises non-408 errors""" + + # Create a mock 500 response + mock_response = Mock() + mock_response.status_code = 500 + mock_500_error = APIStatusError("Internal server error", response=mock_response, body=None) + + with patch.object(async_client.devboxes.executions, "_post") as mock_post: + mock_post.side_effect = mock_500_error + + with pytest.raises(APIStatusError, match="Internal server error"): + await async_client.devboxes.executions.await_completed("execution_id", devbox_id="devbox_id") + + @parametrize + async def test_method_await_completed_with_config(self, async_client: AsyncRunloop) -> None: + """Test await_completed with custom polling configuration""" + + mock_execution_completed = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="completed", + stdout="Finished with config!", + stderr="", + ) + + config = PollingConfig(interval_seconds=0.1, max_attempts=10) + + with patch.object(async_client.devboxes.executions, "_post") as mock_post: + mock_post.return_value = mock_execution_completed + + result = await async_client.devboxes.executions.await_completed( + "execution_id", devbox_id="devbox_id", polling_config=config + ) + + assert result.execution_id == "execution_id" + assert result.status == "completed" + + @parametrize + async def test_method_await_completed_polling_timeout(self, async_client: AsyncRunloop) -> None: + """Test await_completed raises PollingTimeout when max attempts exceeded""" + + mock_execution_running = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="running", + stdout="Still running...", + stderr="", + ) + + config = PollingConfig(interval_seconds=0.01, max_attempts=2) + + with patch.object(async_client.devboxes.executions, "_post") as mock_post: + mock_post.return_value = mock_execution_running + + with pytest.raises(PollingTimeout): + await async_client.devboxes.executions.await_completed( + "execution_id", devbox_id="devbox_id", polling_config=config + ) + + @parametrize + async def test_method_await_completed_various_statuses(self, async_client: AsyncRunloop) -> None: + """Test await_completed correctly handles different execution statuses""" + + # Test with queued status first + mock_execution_queued = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="queued", + stdout="", + stderr="", + ) + + mock_execution_completed = DevboxAsyncExecutionDetailView( + devbox_id="devbox_id", + execution_id="execution_id", + status="completed", + stdout="Done!", + stderr="", + ) + + with patch.object(async_client.devboxes.executions, "_post") as mock_post: + mock_post.side_effect = [mock_execution_queued, mock_execution_completed] + + result = await async_client.devboxes.executions.await_completed("execution_id", devbox_id="devbox_id") + + assert result.execution_id == "execution_id" + assert result.status == "completed" + assert mock_post.call_count == 2 diff --git a/tests/api_resources/test_devboxes.py b/tests/api_resources/test_devboxes.py index 9c9c78e49..f27641463 100644 --- a/tests/api_resources/test_devboxes.py +++ b/tests/api_resources/test_devboxes.py @@ -4,6 +4,7 @@ import os from typing import Any, cast +from unittest.mock import ANY, Mock, patch import httpx import pytest @@ -31,6 +32,9 @@ SyncDiskSnapshotsCursorIDPage, AsyncDiskSnapshotsCursorIDPage, ) +from runloop_api_client._exceptions import RunloopError, APIStatusError +from runloop_api_client.lib.polling import PollingConfig, PollingTimeout +from runloop_api_client.types.shared.launch_parameters import LaunchParameters base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -965,6 +969,287 @@ def test_path_params_write_file_contents(self, client: Runloop) -> None: file_path="file_path", ) + # Polling method tests + @parametrize + def test_method_await_running_success(self, client: Runloop) -> None: + """Test await_running with successful polling to running state""" + + # Mock the wait_for_status calls - first returns provisioning, then running + mock_devbox_provisioning = DevboxView( + id="test_id", + status="provisioning", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + mock_devbox_running = DevboxView( + id="test_id", + status="running", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + with patch.object(client.devboxes, "_post") as mock_post: + mock_post.side_effect = [mock_devbox_provisioning, mock_devbox_running] + + result = client.devboxes.await_running("test_id") + + assert result.id == "test_id" + assert result.status == "running" + assert mock_post.call_count == 2 + + @parametrize + def test_method_await_running_immediate_success(self, client: Runloop) -> None: + """Test await_running when devbox is already running""" + + mock_devbox_running = DevboxView( + id="test_id", + status="running", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + with patch.object(client.devboxes, "_post") as mock_post: + mock_post.return_value = mock_devbox_running + + result = client.devboxes.await_running("test_id") + + assert result.id == "test_id" + assert result.status == "running" + assert mock_post.call_count == 1 + + @parametrize + def test_method_await_running_failure_state(self, client: Runloop) -> None: + """Test await_running when devbox enters failure state""" + + mock_devbox_failed = DevboxView( + id="test_id", + status="failure", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + with patch.object(client.devboxes, "_post") as mock_post: + mock_post.return_value = mock_devbox_failed + + with pytest.raises(RunloopError, match="Devbox entered non-running terminal state: failure"): + client.devboxes.await_running("test_id") + + @parametrize + def test_method_await_running_timeout_handling(self, client: Runloop) -> None: + """Test await_running handles 408 timeouts correctly""" + + # Create a mock 408 response + mock_response = Mock() + mock_response.status_code = 408 + mock_408_error = APIStatusError("Request timeout", response=mock_response, body=None) + + mock_devbox_running = DevboxView( + id="test_id", + status="running", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + with patch.object(client.devboxes, "_post") as mock_post: + # First call raises 408, second call succeeds + mock_post.side_effect = [mock_408_error, mock_devbox_running] + + result = client.devboxes.await_running("test_id") + + assert result.id == "test_id" + assert result.status == "running" + assert mock_post.call_count == 2 + + @parametrize + def test_method_await_running_other_error(self, client: Runloop) -> None: + """Test await_running re-raises non-408 errors""" + + # Create a mock 500 response + mock_response = Mock() + mock_response.status_code = 500 + mock_500_error = APIStatusError("Internal server error", response=mock_response, body=None) + + with patch.object(client.devboxes, "_post") as mock_post: + mock_post.side_effect = mock_500_error + + with pytest.raises(APIStatusError, match="Internal server error"): + client.devboxes.await_running("test_id") + + @parametrize + def test_method_await_running_with_config(self, client: Runloop) -> None: + """Test await_running with custom polling configuration""" + + mock_devbox_running = DevboxView( + id="test_id", + status="running", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + config = PollingConfig(interval_seconds=0.1, max_attempts=10) + + with patch.object(client.devboxes, "_post") as mock_post: + mock_post.return_value = mock_devbox_running + + result = client.devboxes.await_running("test_id", polling_config=config) + + assert result.id == "test_id" + assert result.status == "running" + + @parametrize + def test_method_await_running_polling_timeout(self, client: Runloop) -> None: + """Test await_running raises PollingTimeout when max attempts exceeded""" + + mock_devbox_provisioning = DevboxView( + id="test_id", + status="provisioning", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + config = PollingConfig(interval_seconds=0.01, max_attempts=2) + + with patch.object(client.devboxes, "_post") as mock_post: + mock_post.return_value = mock_devbox_provisioning + + with pytest.raises(PollingTimeout): + client.devboxes.await_running("test_id", polling_config=config) + + @parametrize + def test_method_create_and_await_running_success(self, client: Runloop) -> None: + """Test create_and_await_running successful flow""" + + mock_devbox_creating = DevboxView( + id="test_id", + status="provisioning", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + mock_devbox_running = DevboxView( + id="test_id", + status="running", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + with patch.object(client.devboxes, "create") as mock_create: + with patch.object(client.devboxes, "await_running") as mock_await: + mock_create.return_value = mock_devbox_creating + mock_await.return_value = mock_devbox_running + + result = client.devboxes.create_and_await_running(name="test") + + assert result.id == "test_id" + assert result.status == "running" + mock_create.assert_called_once() + mock_await.assert_called_once_with( + "test_id", polling_config=None, extra_headers=None, extra_query=None, extra_body=None, timeout=ANY + ) + + @parametrize + def test_method_create_and_await_running_with_config(self, client: Runloop) -> None: + """Test create_and_await_running with custom polling configuration""" + + mock_devbox_creating = DevboxView( + id="test_id", + status="provisioning", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + mock_devbox_running = DevboxView( + id="test_id", + status="running", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + config = PollingConfig(interval_seconds=0.1, max_attempts=10) + + with patch.object(client.devboxes, "create") as mock_create: + with patch.object(client.devboxes, "await_running") as mock_await: + mock_create.return_value = mock_devbox_creating + mock_await.return_value = mock_devbox_running + + result = client.devboxes.create_and_await_running(name="test", polling_config=config) + + assert result.id == "test_id" + assert result.status == "running" + mock_await.assert_called_once_with( + "test_id", polling_config=config, extra_headers=None, extra_query=None, extra_body=None, timeout=ANY + ) + + @parametrize + def test_method_create_and_await_running_create_failure(self, client: Runloop) -> None: + """Test create_and_await_running when create fails""" + + mock_response = Mock() + mock_response.status_code = 400 + mock_error = APIStatusError("Bad request", response=mock_response, body=None) + + with patch.object(client.devboxes, "create") as mock_create: + mock_create.side_effect = mock_error + + with pytest.raises(APIStatusError, match="Bad request"): + client.devboxes.create_and_await_running(name="test") + + @parametrize + def test_method_create_and_await_running_await_failure(self, client: Runloop) -> None: + """Test create_and_await_running when await_running fails""" + + mock_devbox_creating = DevboxView( + id="test_id", + status="provisioning", + capabilities=[], + create_time_ms=1234567890, + launch_parameters=LaunchParameters(resource_size_request="X_SMALL"), + metadata={}, + state_transitions=[], + ) + + with patch.object(client.devboxes, "create") as mock_create: + with patch.object(client.devboxes, "await_running") as mock_await: + mock_create.return_value = mock_devbox_creating + mock_await.side_effect = RunloopError("Devbox entered non-running terminal state: failed") + + with pytest.raises(RunloopError, match="Devbox entered non-running terminal state: failed"): + client.devboxes.create_and_await_running(name="test") + class TestAsyncDevboxes: parametrize = pytest.mark.parametrize( diff --git a/tests/test_polling.py b/tests/test_polling.py new file mode 100644 index 000000000..74819531b --- /dev/null +++ b/tests/test_polling.py @@ -0,0 +1,262 @@ +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from src.runloop_api_client.lib.polling import PollingConfig, PollingTimeout, poll_until + + +class TestPollingConfig: + """Test PollingConfig dataclass""" + + def test_default_config(self): + config = PollingConfig() + assert config.interval_seconds == 1.0 + assert config.max_attempts == 120 + assert config.timeout_seconds is None + + def test_custom_config(self): + config = PollingConfig(interval_seconds=0.5, max_attempts=10, timeout_seconds=30.0) + assert config.interval_seconds == 0.5 + assert config.max_attempts == 10 + assert config.timeout_seconds == 30.0 + + +class TestPollingTimeout: + """Test PollingTimeout exception""" + + def test_polling_timeout_initialization(self): + last_value = {"status": "running"} + exception = PollingTimeout("Test message", last_value) + + assert exception.last_value == last_value + assert "Test message" in str(exception) + assert "Last retrieved value: {'status': 'running'}" in str(exception) + + +class TestPollUntil: + """Test poll_until function""" + + def test_immediate_success(self): + """Test when condition is met on first attempt""" + retriever = Mock(return_value="completed") + is_terminal = Mock(return_value=True) + + result = poll_until(retriever, is_terminal) + + assert result == "completed" + assert retriever.call_count == 1 + assert is_terminal.call_count == 1 + is_terminal.assert_called_with("completed") + + def test_success_after_multiple_attempts(self): + """Test when condition is met after several attempts""" + values = ["pending", "running", "completed"] + retriever = Mock(side_effect=values) + is_terminal = Mock(side_effect=[False, False, True]) + + with patch("time.sleep") as mock_sleep: + result = poll_until(retriever, is_terminal) + + assert result == "completed" + assert retriever.call_count == 3 + assert is_terminal.call_count == 3 + assert mock_sleep.call_count == 2 # Should sleep between attempts + + def test_custom_config_interval(self): + """Test with custom polling interval""" + retriever = Mock(side_effect=["pending", "completed"]) + is_terminal = Mock(side_effect=[False, True]) + config = PollingConfig(interval_seconds=0.1) + + with patch("time.sleep") as mock_sleep: + result = poll_until(retriever, is_terminal, config) + + assert result == "completed" + mock_sleep.assert_called_with(0.1) + + def test_max_attempts_exceeded(self): + """Test when max attempts is exceeded""" + retriever = Mock(return_value="still_running") + is_terminal = Mock(return_value=False) + config = PollingConfig(max_attempts=3, interval_seconds=0.01) + + with patch("time.sleep"): + with pytest.raises(PollingTimeout) as exc_info: + poll_until(retriever, is_terminal, config) + + assert "Exceeded maximum attempts (3)" in str(exc_info.value) + assert exc_info.value.last_value == "still_running" + assert retriever.call_count == 3 + + def test_timeout_exceeded(self): + """Test when timeout is exceeded""" + retriever = Mock(return_value="still_running") + is_terminal = Mock(return_value=False) + config = PollingConfig(timeout_seconds=0.1, interval_seconds=0.01) + + # Mock time.time to simulate timeout + start_time = 1000.0 + with patch("time.time", side_effect=[start_time, start_time + 0.05, start_time + 0.15]): + with patch("time.sleep"): + with pytest.raises(PollingTimeout) as exc_info: + poll_until(retriever, is_terminal, config) + + assert "Exceeded timeout of 0.1 seconds" in str(exc_info.value) + assert exc_info.value.last_value == "still_running" + + def test_error_without_handler(self): + """Test that exceptions are re-raised when no error handler is provided""" + retriever = Mock(side_effect=ValueError("Test error")) + is_terminal = Mock(return_value=False) + + with pytest.raises(ValueError, match="Test error"): + poll_until(retriever, is_terminal) + + def test_error_with_handler_continue(self): + """Test error handler that allows polling to continue""" + retriever = Mock(side_effect=[ValueError("Test error"), "recovered"]) + is_terminal = Mock(side_effect=[False, True]) + + def error_handler(_: Exception) -> str: + return "error_handled" + + with patch("time.sleep"): + result = poll_until(retriever, is_terminal, on_error=error_handler) + + assert result == "recovered" + assert retriever.call_count == 2 + assert is_terminal.call_count == 2 + + def test_error_with_handler_reraise(self): + """Test error handler that re-raises the exception""" + retriever = Mock(side_effect=ValueError("Test error")) + is_terminal = Mock(return_value=False) + + def error_handler(e: Exception) -> None: + raise e + + with pytest.raises(ValueError, match="Test error"): + poll_until(retriever, is_terminal, on_error=error_handler) + + def test_error_handler_return_terminal_value(self): + """Test error handler that returns a terminal value""" + retriever = Mock(side_effect=ValueError("Test error")) + is_terminal = Mock(side_effect=[True]) # Terminal condition met on error handler return + + def error_handler(_: Exception) -> str: + return "error_terminal" + + result = poll_until(retriever, is_terminal, on_error=error_handler) + + assert result == "error_terminal" + assert retriever.call_count == 1 + assert is_terminal.call_count == 1 + + def test_multiple_errors_with_handler(self): + """Test multiple errors with handler""" + retriever = Mock(side_effect=[ValueError("Error 1"), RuntimeError("Error 2"), "success"]) + is_terminal = Mock(side_effect=[False, False, True]) + + error_count = 0 + + def error_handler(_: Exception) -> str: + nonlocal error_count + error_count += 1 + return f"handled_error_{error_count}" + + with patch("time.sleep"): + result = poll_until(retriever, is_terminal, on_error=error_handler) + + assert result == "success" + assert error_count == 2 + assert retriever.call_count == 3 + + def test_none_values_handling(self): + """Test handling of None values""" + retriever = Mock(side_effect=[None, None, "final"]) + is_terminal = Mock(side_effect=[False, False, True]) + + with patch("time.sleep"): + result = poll_until(retriever, is_terminal) + + assert result == "final" + assert retriever.call_count == 3 + + def test_complex_object_polling(self): + """Test polling with complex objects""" + + class Status: + def __init__(self, state: str, progress: int): + self.state = state + self.progress = progress + + statuses = [Status("starting", 0), Status("running", 50), Status("completed", 100)] + + retriever = Mock(side_effect=statuses) + is_terminal = Mock(side_effect=[False, False, True]) + + with patch("time.sleep"): + result = poll_until(retriever, is_terminal) + + assert result.state == "completed" + assert result.progress == 100 + + def test_zero_max_attempts(self): + """Test with zero max attempts""" + retriever = Mock(return_value="value") + is_terminal = Mock(return_value=False) + config = PollingConfig(max_attempts=0) + + with pytest.raises(PollingTimeout) as exc_info: + poll_until(retriever, is_terminal, config) + + assert "Exceeded maximum attempts (0)" in str(exc_info.value) + assert retriever.call_count == 1 # Retriever is called once, then attempts check happens + + def test_negative_interval(self): + """Test with negative interval (should still work)""" + retriever = Mock(side_effect=["first", "second"]) + is_terminal = Mock(side_effect=[False, True]) + config = PollingConfig(interval_seconds=-0.1) + + with patch("time.sleep") as mock_sleep: + result = poll_until(retriever, is_terminal, config) + + assert result == "second" + mock_sleep.assert_called_with(-0.1) + + def test_both_timeout_and_max_attempts(self): + """Test when both timeout and max_attempts are set""" + retriever = Mock(return_value="still_running") + is_terminal = Mock(return_value=False) + config = PollingConfig(max_attempts=5, timeout_seconds=0.1, interval_seconds=0.01) + + # Mock time to hit timeout before max_attempts + start_time = 1000.0 + with patch("time.time", side_effect=[start_time, start_time + 0.05, start_time + 0.15]): + with patch("time.sleep"): + with pytest.raises(PollingTimeout) as exc_info: + poll_until(retriever, is_terminal, config) + + # Should hit timeout first + assert "Exceeded timeout of 0.1 seconds" in str(exc_info.value) + assert retriever.call_count == 2 # Called twice before timeout + + def test_terminal_condition_changes(self): + """Test when terminal condition logic changes during polling""" + retriever = Mock(side_effect=["value1", "value2", "value3"]) + + call_count = 0 + + def dynamic_terminal(_: Any) -> bool: + nonlocal call_count + call_count += 1 + # First two calls return False, third returns True + return call_count >= 3 + + with patch("time.sleep"): + result = poll_until(retriever, dynamic_terminal) + + assert result == "value3" + assert retriever.call_count == 3