Skip to content

Commit

Permalink
Add signal handler for SIGTERM & SIGINT to terminate services (#54)
Browse files Browse the repository at this point in the history
* Add `terminate` method to service

* Create signal handler for SIGTERM & SIGINT

* Add tests for service task cancellation

* Implement service task cancellation handling

* Use simple sync signal handler that cancels all asyncio tasks

* Add tests for signal handler

* Move asyncio sleep to end of loop

* Remove `_only_once` test specific parameter by introducing a `RuntimeError` in tests instead
  • Loading branch information
SRv6d authored Mar 1, 2024
1 parent 4be226d commit 638303a
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 24 deletions.
28 changes: 27 additions & 1 deletion src/anycastd/core/_run.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import asyncio
import signal
import sys
from collections.abc import Iterable
from functools import partial
from typing import NoReturn

import structlog

from anycastd._configuration import MainConfiguration, config_to_service
from anycastd.core._service import Service

logger = structlog.get_logger()


async def run_from_configuration(configuration: MainConfiguration) -> None:
"""Run anycastd using an instance of the main configuration."""
Expand All @@ -12,11 +20,29 @@ async def run_from_configuration(configuration: MainConfiguration) -> None:


async def run_services(services: Iterable[Service]) -> None:
"""Run services.
"""Run services until termination.
A signal handler is installed to manage termination. When a SIGTERM or SIGINT
signal is received, graceful termination is managed by the handler.
Args:
services: The services to run.
"""
loop = asyncio.get_event_loop()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, partial(signal_handler, sig))

async with asyncio.TaskGroup() as tg:
for service in services:
tg.create_task(service.run())


def signal_handler(signal: signal.Signals) -> NoReturn:
"""Logs the received signal and terminates all tasks."""
msg = f"Received {signal.name}, terminating."

logger.info(msg)
for task in asyncio.all_tasks():
task.cancel(msg)

sys.exit(0)
41 changes: 27 additions & 14 deletions src/anycastd/core/_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class Service:
health_checks: tuple[Healthcheck, ...]

_healthy: bool = field(default=False, init=False, repr=False, compare=False)
_terminate: bool = field(default=False, init=False, repr=False, compare=False)
_log: structlog.typing.FilteringBoundLogger = field(
default=logger, init=False, repr=False, compare=False
)
Expand Down Expand Up @@ -83,27 +84,33 @@ def healthy(self, new_value: bool) -> None:
service_healthy=self.healthy,
)

# The _only_once parameter is only used for testing.
# TODO: Look into a better way to do this.
async def run(self, *, _only_once: bool = False) -> None:
async def run(self) -> None:
"""Run the service.
This will announce the prefixes when all health checks are
passing, and denounce them otherwise.
passing, and denounce them otherwise. If the returned coroutine is cancelled,
the service will be terminated, denouncing all prefixes in the process.
"""
self._log.info(f"Starting service {self.name}.", service_healthy=self.healthy)
while True:
checks_currently_healthy: bool = await self.all_checks_healthy()
try:
while not self._terminate:
checks_currently_healthy: bool = await self.all_checks_healthy()

if checks_currently_healthy and not self.healthy:
self.healthy = True
await self.announce_all_prefixes()
elif not checks_currently_healthy and self.healthy:
self.healthy = False
await self.denounce_all_prefixes()
if checks_currently_healthy and not self.healthy:
self.healthy = True
await self.announce_all_prefixes()
elif not checks_currently_healthy and self.healthy:
self.healthy = False
await self.denounce_all_prefixes()

if _only_once:
break
await asyncio.sleep(0.05)

except asyncio.CancelledError:
self._log.debug(
f"Coroutine for service {self.name} was cancelled.",
service_healthy=self.healthy,
)
await self.terminate()

async def all_checks_healthy(self) -> bool:
"""Runs all checks and returns their cumulative result.
Expand Down Expand Up @@ -144,3 +151,9 @@ async def denounce_all_prefixes(self) -> None:
async with asyncio.TaskGroup() as tg:
for prefix in self.prefixes:
tg.create_task(prefix.denounce())

async def terminate(self) -> None:
"""Terminate the service and denounce its prefixes."""
self._terminate = True
await self.denounce_all_prefixes()
logger.info(f"Service {self.name} terminated.", service=self.name)
56 changes: 55 additions & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import asyncio
import signal

import pytest
from anycastd.core._run import run_services
from anycastd.core._run import run_services, signal_handler
from anycastd.core._service import Service
from structlog.testing import capture_logs


@pytest.fixture
Expand All @@ -13,3 +17,53 @@ async def test_future_created_for_each_service(mock_services):
"""A future is created for each service."""
await run_services(mock_services)
assert all(mock_service.run.called for mock_service in mock_services)


@pytest.mark.parametrize("signal_to_handle", [signal.SIGTERM, signal.SIGINT])
async def test_run_services_installs_signal_handlers(
mocker, mock_services, signal_to_handle: signal.Signals
):
"""A handler is installed for signals we want to handle."""
mock_loop = mocker.create_autospec(asyncio.AbstractEventLoop)
mocker.patch("anycastd.core._run.asyncio.get_event_loop", return_value=mock_loop)

await run_services(mock_services)

assert (
mocker.call(signal_to_handle, mocker.ANY)
in mock_loop.add_signal_handler.mock_calls
)


def test_signal_handler_logs_signal(mocker):
"""The signal handler logs the received signal."""
mocker.patch("anycastd.core._run.asyncio")
mocker.patch("anycastd.core._run.sys")

with capture_logs() as logs:
signal_handler(signal.SIGTERM)

assert logs[0]["event"] == "Received SIGTERM, terminating."
assert logs[0]["log_level"] == "info"


def test_signal_handler_cancels_all_tasks(mocker):
"""The signal handler cancels all tasks."""
tasks = [mocker.create_autospec(asyncio.Task) for _ in range(3)]
mocker.patch("anycastd.core._run.asyncio.all_tasks", return_value=tasks)
mocker.patch("anycastd.core._run.sys")

signal_handler(signal.SIGTERM)

for task in tasks:
task.cancel.assert_called_once_with("Received SIGTERM, terminating.")


def test_signal_handler_exits_with_zero(mocker):
"""The signal handler exits with returncode zero."""
mocker.patch("anycastd.core._run.asyncio")
mock_sys = mocker.patch("anycastd.core._run.sys")

signal_handler(signal.SIGTERM)

mock_sys.exit.assert_called_once_with(0)
88 changes: 80 additions & 8 deletions tests/test_service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import pytest
from anycastd.core import Service
from pytest_mock import MockerFixture
Expand Down Expand Up @@ -35,16 +37,37 @@ def example_service_w_mock_checks(mocker: MockerFixture, example_service) -> Ser
return example_service


async def test_run_awaits_all_checks(mocker: MockerFixture, example_service):
@pytest.fixture
def patch_asyncio_sleep_to_raise(mocker: MockerFixture) -> None:
"""Patch the asyncio.sleep function to raise an exception.
This is useful to terminate the services run loop at the end of the first execution
in tests, instead of running it indefinitely.
"""
mocker.patch(
"anycastd.core._service.asyncio.sleep",
side_effect=RuntimeError("Exit loop"),
)


async def test_run_awaits_all_checks(
mocker: MockerFixture, patch_asyncio_sleep_to_raise, example_service
):
"""When run, the service awaits the status of all its health checks."""
mock_all_checks_healthy = mocker.patch.object(example_service, "all_checks_healthy")
await example_service.run(_only_once=True)

with pytest.raises(RuntimeError, match="Exit loop"):
await example_service.run()

mock_all_checks_healthy.assert_awaited_once()


@pytest.mark.parametrize("was_healthy", [True, False])
async def test_run_announces_all_when_health_state_changes_to_healty(
mocker: MockerFixture, example_service_w_mock_prefixes, was_healthy: bool
mocker: MockerFixture,
patch_asyncio_sleep_to_raise,
example_service_w_mock_prefixes,
was_healthy: bool,
):
"""
When run, all prefixes are announced if the all_checks_healthy method returns True
Expand All @@ -59,7 +82,8 @@ async def test_run_announces_all_when_health_state_changes_to_healty(
example_service_w_mock_prefixes, "announce_all_prefixes"
)

await example_service_w_mock_prefixes.run(_only_once=True)
with pytest.raises(RuntimeError, match="Exit loop"):
await example_service_w_mock_prefixes.run()

if not was_healthy:
mock_announce_all.assert_awaited_once()
Expand All @@ -69,7 +93,10 @@ async def test_run_announces_all_when_health_state_changes_to_healty(

@pytest.mark.parametrize("was_healthy", [True, False])
async def test_run_denounces_all_when_health_state_changes_to_unhealty(
mocker: MockerFixture, example_service_w_mock_prefixes, was_healthy: bool
mocker: MockerFixture,
patch_asyncio_sleep_to_raise,
example_service_w_mock_prefixes,
was_healthy: bool,
):
"""
When run, all prefixes are denounced if the all_checks_healthy method returns False
Expand All @@ -84,7 +111,8 @@ async def test_run_denounces_all_when_health_state_changes_to_unhealty(
example_service_w_mock_prefixes, "denounce_all_prefixes"
)

await example_service_w_mock_prefixes.run(_only_once=True)
with pytest.raises(RuntimeError, match="Exit loop"):
await example_service_w_mock_prefixes.run()

if was_healthy:
mock_denounce_all.assert_awaited_once()
Expand All @@ -93,7 +121,7 @@ async def test_run_denounces_all_when_health_state_changes_to_unhealty(


async def test_run_updates_health_state_when_changed(
mocker: MockerFixture, example_service_w_mock_prefixes
mocker: MockerFixture, patch_asyncio_sleep_to_raise, example_service_w_mock_prefixes
):
"""
When run, the service's health state is updated when the result of the
Expand All @@ -104,7 +132,8 @@ async def test_run_updates_health_state_when_changed(
example_service_w_mock_prefixes, "all_checks_healthy", return_value=True
)

await example_service_w_mock_prefixes.run(_only_once=True)
with pytest.raises(RuntimeError, match="Exit loop"):
await example_service_w_mock_prefixes.run()

assert example_service_w_mock_prefixes.healthy is True

Expand Down Expand Up @@ -259,3 +288,46 @@ async def test_denounce_all_prefixes_awaits_denounce_of_all_prefixes(
await example_service_w_mock_prefixes.denounce_all_prefixes()
for mock_prefix in example_service_w_mock_prefixes.prefixes:
mock_prefix.denounce.assert_awaited_once()


async def test_run_coro_cancellation_logs_termination(example_service, mocker):
"""When the run coroutine is cancelled, the termination is logged."""
# Create a task to run the service loop
run_task = asyncio.create_task(example_service.run())
# Give the event loop some time to poll the task above
await asyncio.sleep(0.3)

with capture_logs() as logs:
# Cancel the task
run_task.cancel()
# Give the event loop some time to cancel
await asyncio.sleep(0.2)

assert (
logs[0]["event"]
== f"Coroutine for service {example_service.name} was cancelled."
)
assert logs[0]["log_level"] == "debug"
assert logs[0]["service_name"] == example_service.name
assert logs[0]["service_healthy"] == example_service.healthy


async def test_run_coro_cancellation_awaits_termination(example_service, mocker):
"""When the run coroutine is cancelled, service termination is awaited."""
mock_terminate = mocker.patch.object(example_service, "terminate")
mock_all_checks_healthy = mocker.patch.object(example_service, "all_checks_healthy")
mock_all_checks_healthy.return_value = True

# Create a task to run the service loop
run_task = asyncio.create_task(example_service.run())
# Give the event loop some time to poll the task above
await asyncio.sleep(0.3)
# Sanity check to make sure the coro actually ran
mock_all_checks_healthy.assert_awaited()

# Cancel the task
run_task.cancel()
# Give the event loop some time to cancel
await asyncio.sleep(0.2)

mock_terminate.assert_awaited_once()

0 comments on commit 638303a

Please sign in to comment.