diff --git a/src/anycastd/core/_run.py b/src/anycastd/core/_run.py index 555aa3c..c507753 100644 --- a/src/anycastd/core/_run.py +++ b/src/anycastd/core/_run.py @@ -1,9 +1,17 @@ import asyncio +import signal +import sys from collections.abc import Iterable +from functools import partial +from typing import NoReturn + +import structlog from anycastd._configuration import MainConfiguration, config_to_service from anycastd.core._service import Service +logger = structlog.get_logger() + async def run_from_configuration(configuration: MainConfiguration) -> None: """Run anycastd using an instance of the main configuration.""" @@ -12,11 +20,29 @@ async def run_from_configuration(configuration: MainConfiguration) -> None: async def run_services(services: Iterable[Service]) -> None: - """Run services. + """Run services until termination. + + A signal handler is installed to manage termination. When a SIGTERM or SIGINT + signal is received, graceful termination is managed by the handler. Args: services: The services to run. """ + loop = asyncio.get_event_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, partial(signal_handler, sig)) + async with asyncio.TaskGroup() as tg: for service in services: tg.create_task(service.run()) + + +def signal_handler(signal: signal.Signals) -> NoReturn: + """Logs the received signal and terminates all tasks.""" + msg = f"Received {signal.name}, terminating." + + logger.info(msg) + for task in asyncio.all_tasks(): + task.cancel(msg) + + sys.exit(0) diff --git a/src/anycastd/core/_service.py b/src/anycastd/core/_service.py index 2a2991b..35be441 100644 --- a/src/anycastd/core/_service.py +++ b/src/anycastd/core/_service.py @@ -51,6 +51,7 @@ class Service: health_checks: tuple[Healthcheck, ...] _healthy: bool = field(default=False, init=False, repr=False, compare=False) + _terminate: bool = field(default=False, init=False, repr=False, compare=False) _log: structlog.typing.FilteringBoundLogger = field( default=logger, init=False, repr=False, compare=False ) @@ -83,27 +84,33 @@ def healthy(self, new_value: bool) -> None: service_healthy=self.healthy, ) - # The _only_once parameter is only used for testing. - # TODO: Look into a better way to do this. - async def run(self, *, _only_once: bool = False) -> None: + async def run(self) -> None: """Run the service. This will announce the prefixes when all health checks are - passing, and denounce them otherwise. + passing, and denounce them otherwise. If the returned coroutine is cancelled, + the service will be terminated, denouncing all prefixes in the process. """ self._log.info(f"Starting service {self.name}.", service_healthy=self.healthy) - while True: - checks_currently_healthy: bool = await self.all_checks_healthy() + try: + while not self._terminate: + checks_currently_healthy: bool = await self.all_checks_healthy() - if checks_currently_healthy and not self.healthy: - self.healthy = True - await self.announce_all_prefixes() - elif not checks_currently_healthy and self.healthy: - self.healthy = False - await self.denounce_all_prefixes() + if checks_currently_healthy and not self.healthy: + self.healthy = True + await self.announce_all_prefixes() + elif not checks_currently_healthy and self.healthy: + self.healthy = False + await self.denounce_all_prefixes() - if _only_once: - break + await asyncio.sleep(0.05) + + except asyncio.CancelledError: + self._log.debug( + f"Coroutine for service {self.name} was cancelled.", + service_healthy=self.healthy, + ) + await self.terminate() async def all_checks_healthy(self) -> bool: """Runs all checks and returns their cumulative result. @@ -144,3 +151,9 @@ async def denounce_all_prefixes(self) -> None: async with asyncio.TaskGroup() as tg: for prefix in self.prefixes: tg.create_task(prefix.denounce()) + + async def terminate(self) -> None: + """Terminate the service and denounce its prefixes.""" + self._terminate = True + await self.denounce_all_prefixes() + logger.info(f"Service {self.name} terminated.", service=self.name) diff --git a/tests/test_run.py b/tests/test_run.py index 7fb2d4d..f60bfbb 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,6 +1,10 @@ +import asyncio +import signal + import pytest -from anycastd.core._run import run_services +from anycastd.core._run import run_services, signal_handler from anycastd.core._service import Service +from structlog.testing import capture_logs @pytest.fixture @@ -13,3 +17,53 @@ async def test_future_created_for_each_service(mock_services): """A future is created for each service.""" await run_services(mock_services) assert all(mock_service.run.called for mock_service in mock_services) + + +@pytest.mark.parametrize("signal_to_handle", [signal.SIGTERM, signal.SIGINT]) +async def test_run_services_installs_signal_handlers( + mocker, mock_services, signal_to_handle: signal.Signals +): + """A handler is installed for signals we want to handle.""" + mock_loop = mocker.create_autospec(asyncio.AbstractEventLoop) + mocker.patch("anycastd.core._run.asyncio.get_event_loop", return_value=mock_loop) + + await run_services(mock_services) + + assert ( + mocker.call(signal_to_handle, mocker.ANY) + in mock_loop.add_signal_handler.mock_calls + ) + + +def test_signal_handler_logs_signal(mocker): + """The signal handler logs the received signal.""" + mocker.patch("anycastd.core._run.asyncio") + mocker.patch("anycastd.core._run.sys") + + with capture_logs() as logs: + signal_handler(signal.SIGTERM) + + assert logs[0]["event"] == "Received SIGTERM, terminating." + assert logs[0]["log_level"] == "info" + + +def test_signal_handler_cancels_all_tasks(mocker): + """The signal handler cancels all tasks.""" + tasks = [mocker.create_autospec(asyncio.Task) for _ in range(3)] + mocker.patch("anycastd.core._run.asyncio.all_tasks", return_value=tasks) + mocker.patch("anycastd.core._run.sys") + + signal_handler(signal.SIGTERM) + + for task in tasks: + task.cancel.assert_called_once_with("Received SIGTERM, terminating.") + + +def test_signal_handler_exits_with_zero(mocker): + """The signal handler exits with returncode zero.""" + mocker.patch("anycastd.core._run.asyncio") + mock_sys = mocker.patch("anycastd.core._run.sys") + + signal_handler(signal.SIGTERM) + + mock_sys.exit.assert_called_once_with(0) diff --git a/tests/test_service.py b/tests/test_service.py index 7ec41d9..94a3e0e 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from anycastd.core import Service from pytest_mock import MockerFixture @@ -35,16 +37,37 @@ def example_service_w_mock_checks(mocker: MockerFixture, example_service) -> Ser return example_service -async def test_run_awaits_all_checks(mocker: MockerFixture, example_service): +@pytest.fixture +def patch_asyncio_sleep_to_raise(mocker: MockerFixture) -> None: + """Patch the asyncio.sleep function to raise an exception. + + This is useful to terminate the services run loop at the end of the first execution + in tests, instead of running it indefinitely. + """ + mocker.patch( + "anycastd.core._service.asyncio.sleep", + side_effect=RuntimeError("Exit loop"), + ) + + +async def test_run_awaits_all_checks( + mocker: MockerFixture, patch_asyncio_sleep_to_raise, example_service +): """When run, the service awaits the status of all its health checks.""" mock_all_checks_healthy = mocker.patch.object(example_service, "all_checks_healthy") - await example_service.run(_only_once=True) + + with pytest.raises(RuntimeError, match="Exit loop"): + await example_service.run() + mock_all_checks_healthy.assert_awaited_once() @pytest.mark.parametrize("was_healthy", [True, False]) async def test_run_announces_all_when_health_state_changes_to_healty( - mocker: MockerFixture, example_service_w_mock_prefixes, was_healthy: bool + mocker: MockerFixture, + patch_asyncio_sleep_to_raise, + example_service_w_mock_prefixes, + was_healthy: bool, ): """ When run, all prefixes are announced if the all_checks_healthy method returns True @@ -59,7 +82,8 @@ async def test_run_announces_all_when_health_state_changes_to_healty( example_service_w_mock_prefixes, "announce_all_prefixes" ) - await example_service_w_mock_prefixes.run(_only_once=True) + with pytest.raises(RuntimeError, match="Exit loop"): + await example_service_w_mock_prefixes.run() if not was_healthy: mock_announce_all.assert_awaited_once() @@ -69,7 +93,10 @@ async def test_run_announces_all_when_health_state_changes_to_healty( @pytest.mark.parametrize("was_healthy", [True, False]) async def test_run_denounces_all_when_health_state_changes_to_unhealty( - mocker: MockerFixture, example_service_w_mock_prefixes, was_healthy: bool + mocker: MockerFixture, + patch_asyncio_sleep_to_raise, + example_service_w_mock_prefixes, + was_healthy: bool, ): """ When run, all prefixes are denounced if the all_checks_healthy method returns False @@ -84,7 +111,8 @@ async def test_run_denounces_all_when_health_state_changes_to_unhealty( example_service_w_mock_prefixes, "denounce_all_prefixes" ) - await example_service_w_mock_prefixes.run(_only_once=True) + with pytest.raises(RuntimeError, match="Exit loop"): + await example_service_w_mock_prefixes.run() if was_healthy: mock_denounce_all.assert_awaited_once() @@ -93,7 +121,7 @@ async def test_run_denounces_all_when_health_state_changes_to_unhealty( async def test_run_updates_health_state_when_changed( - mocker: MockerFixture, example_service_w_mock_prefixes + mocker: MockerFixture, patch_asyncio_sleep_to_raise, example_service_w_mock_prefixes ): """ When run, the service's health state is updated when the result of the @@ -104,7 +132,8 @@ async def test_run_updates_health_state_when_changed( example_service_w_mock_prefixes, "all_checks_healthy", return_value=True ) - await example_service_w_mock_prefixes.run(_only_once=True) + with pytest.raises(RuntimeError, match="Exit loop"): + await example_service_w_mock_prefixes.run() assert example_service_w_mock_prefixes.healthy is True @@ -259,3 +288,46 @@ async def test_denounce_all_prefixes_awaits_denounce_of_all_prefixes( await example_service_w_mock_prefixes.denounce_all_prefixes() for mock_prefix in example_service_w_mock_prefixes.prefixes: mock_prefix.denounce.assert_awaited_once() + + +async def test_run_coro_cancellation_logs_termination(example_service, mocker): + """When the run coroutine is cancelled, the termination is logged.""" + # Create a task to run the service loop + run_task = asyncio.create_task(example_service.run()) + # Give the event loop some time to poll the task above + await asyncio.sleep(0.3) + + with capture_logs() as logs: + # Cancel the task + run_task.cancel() + # Give the event loop some time to cancel + await asyncio.sleep(0.2) + + assert ( + logs[0]["event"] + == f"Coroutine for service {example_service.name} was cancelled." + ) + assert logs[0]["log_level"] == "debug" + assert logs[0]["service_name"] == example_service.name + assert logs[0]["service_healthy"] == example_service.healthy + + +async def test_run_coro_cancellation_awaits_termination(example_service, mocker): + """When the run coroutine is cancelled, service termination is awaited.""" + mock_terminate = mocker.patch.object(example_service, "terminate") + mock_all_checks_healthy = mocker.patch.object(example_service, "all_checks_healthy") + mock_all_checks_healthy.return_value = True + + # Create a task to run the service loop + run_task = asyncio.create_task(example_service.run()) + # Give the event loop some time to poll the task above + await asyncio.sleep(0.3) + # Sanity check to make sure the coro actually ran + mock_all_checks_healthy.assert_awaited() + + # Cancel the task + run_task.cancel() + # Give the event loop some time to cancel + await asyncio.sleep(0.2) + + mock_terminate.assert_awaited_once()