Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Healthcheck & Prefix as Protocols instead of ABCs requiring inheritance #31

Merged
merged 7 commits into from
Dec 23, 2023
6 changes: 6 additions & 0 deletions src/anycastd/core/_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ class Service:
prefixes: tuple[Prefix, ...]
health_checks: tuple[Healthcheck, ...]

def __post_init__(self) -> None:
if not all(isinstance(_, Prefix) for _ in self.prefixes):
raise TypeError("Prefixes must implement the Prefix protocol")
if not all(isinstance(_, Healthcheck) for _ in self.health_checks):
raise TypeError("Health checks must implement the Healthcheck protocol")

# The _only_once parameter is only used for testing.
# TODO: Look into a better way to do this.
async def run(self, *, _only_once: bool = False) -> None:
Expand Down
47 changes: 19 additions & 28 deletions src/anycastd/healthcheck/_cabourotte/main.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,32 @@
import datetime
from dataclasses import dataclass, field

from anycastd.healthcheck._cabourotte.result import get_result
from anycastd.healthcheck._main import Healthcheck
from anycastd.healthcheck._common import CheckCoroutine, interval_check


class CabourotteHealthcheck(Healthcheck):
@dataclass
class CabourotteHealthcheck:
name: str
url: str
url: str = field(kw_only=True)
interval: datetime.timedelta = field(kw_only=True)

def __init__(self, name: str, *, url: str, interval: datetime.timedelta):
if not isinstance(interval, datetime.timedelta):
_check: CheckCoroutine = field(init=False, repr=False, compare=False)

def __post_init__(self) -> None:
if not isinstance(self.interval, datetime.timedelta):
raise TypeError("Interval must be a timedelta.")
if not isinstance(name, str):
if not isinstance(self.name, str):
raise TypeError("Name must be a string.")
if not isinstance(url, str):
if not isinstance(self.url, str):
raise TypeError("URL must be a string.")
self.name = name
self.url = url
self.__interval = interval

def __repr__(self) -> str:
return (
f"CabourotteHealthcheck(name={self.name!r}, url={self.url!r}, "
f"interval={self.interval!r})"
)

def __eq__(self, other: object) -> bool:
if not isinstance(other, CabourotteHealthcheck):
return NotImplemented

return self.__dict__ == other.__dict__
self._check = interval_check(self.interval, self._get_status)

@property
def interval(self) -> datetime.timedelta:
return self.__interval

async def _check(self) -> bool:
"""Return whether the healthcheck is healthy or not."""
async def _get_status(self) -> bool:
"""Get the current status of the check as reported by cabourotte."""
result = await get_result(self.name, url=self.url)
return result.success

async def is_healthy(self) -> bool:
"""Return whether the healthcheck is healthy or not."""
return await self._check()
38 changes: 38 additions & 0 deletions src/anycastd/healthcheck/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone
from typing import TypeAlias

CheckCoroutine: TypeAlias = Callable[[], Awaitable[bool]]


def interval_check(interval: timedelta, check: CheckCoroutine) -> CheckCoroutine:
"""Wrap a check coroutine to only evaluate it if a given interval has passed.

Wraps a given check coroutine to only evaluate it if a given interval has passed,
returning the last result otherwise.

Args:
interval: The interval to wait between evaluations.
check: A check coroutine to be evaluated.

Returns:
A coroutine returning either the result of the given check coroutine or the
last result returned by it if the interval has not passed.
"""
last_checked: datetime | None = None
last_healthy: bool = False

async def _check() -> bool:
nonlocal last_checked, last_healthy

if last_checked is None or datetime.now(timezone.utc) - last_checked > interval:
healthy = await check()

last_checked = datetime.now(timezone.utc)
last_healthy = healthy

return healthy

return last_healthy

return _check
48 changes: 6 additions & 42 deletions src/anycastd/healthcheck/_main.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,10 @@
import datetime
from abc import ABC, abstractmethod
from typing import final
from typing import Protocol, runtime_checkable


class Healthcheck(ABC):
"""A healthcheck that represents a status."""
@runtime_checkable
class Healthcheck(Protocol):
"""A health check representing the status of a component."""

_last_checked: datetime.datetime | None = None
_last_healthy: bool = False

@property
@abstractmethod
def interval(self) -> datetime.timedelta:
"""The interval between checks."""
raise NotImplementedError

@final
async def is_healthy(self) -> bool:
"""Whether the healthcheck is healthy.

Runs a check and returns it's result if the interval has passed or no
previous check has been run. Otherwise, the previous result is returned.
"""
if (
self._last_checked is None
or datetime.datetime.now(datetime.timezone.utc) - self._last_checked
> self.interval
):
healthy = await self._check()

self._last_checked = datetime.datetime.now(datetime.timezone.utc)
self._last_healthy = healthy

return healthy

return self._last_healthy

@abstractmethod
async def _check(self) -> bool:
"""Run the healthcheck.

This method must be implemented by subclasses and return a boolean
indicating whether the healthcheck passed or failed.
"""
raise NotImplementedError
"""Whether the health checked component is healthy or not."""
...
2 changes: 1 addition & 1 deletion src/anycastd/prefix/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from anycastd.prefix._frrouting.main import FRRoutingPrefix
from anycastd.prefix._main import VRF, Prefix
from anycastd.prefix._main import AFI, VRF, Prefix
33 changes: 18 additions & 15 deletions src/anycastd/prefix/_frrouting/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from contextlib import suppress
from ipaddress import IPv4Network, IPv6Network
from pathlib import Path
from typing import Self, cast
from typing import Self, assert_never, cast

from anycastd._executor import Executor
from anycastd.prefix._frrouting.exceptions import (
Expand All @@ -12,10 +12,10 @@
FRRInvalidVTYSHError,
FRRNoBGPError,
)
from anycastd.prefix._main import VRF, Prefix
from anycastd.prefix._main import AFI, VRF


class FRRoutingPrefix(Prefix):
class FRRoutingPrefix:
vrf: VRF
vtysh: Path
executor: Executor
Expand Down Expand Up @@ -57,16 +57,26 @@ def __eq__(self, other: object) -> bool:
def prefix(self) -> IPv4Network | IPv6Network:
return self.__prefix

@property
def afi(self) -> AFI:
"""The address family of the prefix."""
match self.prefix:
case IPv4Network():
return AFI.IPv4
case IPv6Network():
return AFI.IPv6
case _ as unreachable:
assert_never(unreachable)

async def is_announced(self) -> bool:
"""Returns True if the prefix is announced.

Checks if the respective BGP prefix is configured in the default VRF.
"""
family = get_afi(self)
cmd = (
f"show bgp vrf {self.vrf} {family} unicast {self.prefix} json"
f"show bgp vrf {self.vrf} {self.afi} unicast {self.prefix} json"
if self.vrf
else f"show bgp {family} unicast {self.prefix} json"
else f"show bgp {self.afi} unicast {self.prefix} json"
)
show_prefix = await self._run_vtysh_commands((cmd,))
prefix_info = json.loads(show_prefix)
Expand All @@ -85,14 +95,13 @@ async def announce(self) -> None:

Adds the respective BGP prefix to the default VRF.
"""
family = get_afi(self)
asn = await self._get_local_asn()

await self._run_vtysh_commands(
(
"configure terminal",
f"router bgp {asn} vrf {self.vrf}" if self.vrf else f"router bgp {asn}",
f"address-family {family} unicast",
f"address-family {self.afi} unicast",
f"network {self.prefix}",
)
)
Expand All @@ -102,14 +111,13 @@ async def denounce(self) -> None:

Removes the respective BGP prefix from the default VRF.
"""
family = get_afi(self)
asn = await self._get_local_asn()

await self._run_vtysh_commands(
(
"configure terminal",
f"router bgp {asn} vrf {self.vrf}" if self.vrf else f"router bgp {asn}",
f"address-family {family} unicast",
f"address-family {self.afi} unicast",
f"no network {self.prefix}",
)
)
Expand Down Expand Up @@ -203,8 +211,3 @@ async def new(
return await cls(
prefix=prefix, vrf=vrf, vtysh=vtysh, executor=executor
).validate()


def get_afi(prefix: Prefix) -> str:
"""Return the FRR string AFI for the given IP type."""
return "ipv6" if not isinstance(prefix.prefix, IPv4Network) else "ipv4"
49 changes: 23 additions & 26 deletions src/anycastd/prefix/_main.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,39 @@
from abc import ABC, abstractmethod
from enum import StrEnum
from ipaddress import IPv4Network, IPv6Network
from typing import TypeAlias
from typing import Protocol, TypeAlias, runtime_checkable

VRF: TypeAlias = str | None


class Prefix(ABC):
class AFI(StrEnum):
"""The IP address family."""

IPv4 = "ipv4"
IPv6 = "ipv6"


@runtime_checkable
class Prefix(Protocol):
"""An IP prefix that can be announced or denounced."""

@property
@abstractmethod
def prefix(self) -> IPv4Network | IPv6Network:
"""The IP prefix."""
raise NotImplementedError
...

@abstractmethod
async def is_announced(self) -> bool:
"""Whether the prefix is currently announced.
@property
def afi(self) -> AFI:
"""The address family of the prefix."""
...

This method must be implemented by subclasses and return a boolean
indicating whether the prefix is currently announced.
"""
raise NotImplementedError
async def is_announced(self) -> bool:
"""Whether the prefix is currently announced."""
...

@abstractmethod
async def announce(self) -> None:
"""Announce the prefix.
"""Announce the prefix."""
...

This method must be implemented by subclasses and announce the
prefix if it isn't announced already.
"""
raise NotImplementedError

@abstractmethod
async def denounce(self) -> None:
"""Denounce the prefix.

This method must be implemented by subclasses and denounce the
prefix if it is announced.
"""
raise NotImplementedError
"""Denounce the prefix."""
...
13 changes: 2 additions & 11 deletions tests/dummy.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
import datetime
from ipaddress import IPv4Network, IPv6Network

from anycastd.healthcheck import Healthcheck
from anycastd.prefix import Prefix


class DummyHealthcheck(Healthcheck):
class DummyHealthcheck:
"""A dummy healthcheck to test the abstract base class."""

def __init__(self, interval: datetime.timedelta, *args, **kwargs):
self.__interval = interval

@property
def interval(self) -> datetime.timedelta:
return self.__interval

async def _check(self) -> bool:
async def is_healthy(self) -> bool:
"""Always healthy."""
return True

Expand Down
10 changes: 5 additions & 5 deletions tests/healthcheck/cabourotte/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,23 @@ def test_non_equal():


@pytest.mark.asyncio
async def test__check_awaits_get_result(mocker: MockerFixture):
"""The check method awaits the result of get_result."""
async def test_get_status_awaits_get_result(mocker: MockerFixture):
"""The get status method awaits the result of get_result."""
name = "test"
url = "https://example.com"
healthcheck = CabourotteHealthcheck(
name, url=url, interval=datetime.timedelta(seconds=10)
)
mock_get_result = mocker.patch("anycastd.healthcheck._cabourotte.main.get_result")

await healthcheck._check()
await healthcheck._get_status()

mock_get_result.assert_awaited_once_with(name, url=url)


@pytest.mark.parametrize("success", [True, False])
@pytest.mark.asyncio
async def test__check_returns_result_success(success: bool, mocker: MockerFixture):
async def test_get_status_returns_result(success: bool, mocker: MockerFixture):
"""The check method returns True if the result is successful and False otherwise."""
healthcheck = CabourotteHealthcheck(
"test", url="https://example.com", interval=datetime.timedelta(seconds=10)
Expand All @@ -105,4 +105,4 @@ async def test__check_returns_result_success(success: bool, mocker: MockerFixtur
"anycastd.healthcheck._cabourotte.main.get_result", return_value=mock_result
)

assert await healthcheck._check() == success
assert await healthcheck._get_status() == success
Loading
Loading