Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from key_value.aio.wrappers.bulkhead.wrapper import BulkheadWrapper

__all__ = ["BulkheadWrapper"]
129 changes: 129 additions & 0 deletions key-value/key-value-aio/src/key_value/aio/wrappers/bulkhead/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import asyncio
from collections.abc import Callable, Coroutine, Mapping, Sequence
from typing import Any, SupportsFloat, TypeVar

from key_value.shared.errors.wrappers.bulkhead import BulkheadFullError
from typing_extensions import override

from key_value.aio.protocols.key_value import AsyncKeyValue
from key_value.aio.wrappers.base import BaseWrapper

T = TypeVar("T")


class BulkheadWrapper(BaseWrapper):
"""Wrapper that implements the bulkhead pattern to isolate operations with resource pools.

This wrapper limits the number of concurrent operations and queued operations to prevent
resource exhaustion and isolate failures. The bulkhead pattern is inspired by ship bulkheads
that prevent a single hull breach from sinking the entire ship.

Benefits:
- Prevents a single slow or failing backend from consuming all resources
- Limits concurrent requests to protect backend from overload
- Provides bounded queue to prevent unbounded memory growth
- Enables graceful degradation under high load

Example:
bulkhead = BulkheadWrapper(
key_value=store,
max_concurrent=10, # Max 10 concurrent operations
max_waiting=20, # Max 20 operations can wait in queue
)

try:
await bulkhead.get(key="mykey")
except BulkheadFullError:
# Too many concurrent operations, system is overloaded
# Handle gracefully (return cached value, error response, etc.)
pass
"""

def __init__(
self,
key_value: AsyncKeyValue,
max_concurrent: int = 10,
max_waiting: int = 20,
) -> None:
"""Initialize the bulkhead wrapper.

Args:
key_value: The store to wrap.
max_concurrent: Maximum number of concurrent operations. Defaults to 10.
max_waiting: Maximum number of operations that can wait in queue. Defaults to 20.
"""
self.key_value: AsyncKeyValue = key_value
self.max_concurrent: int = max_concurrent
self.max_waiting: int = max_waiting

# Use semaphore to limit concurrent operations
self._semaphore: asyncio.Semaphore = asyncio.Semaphore(max_concurrent)
self._waiting_count: int = 0
self._waiting_lock: asyncio.Lock = asyncio.Lock()

super().__init__()

async def _execute_with_bulkhead(self, operation: Callable[..., Coroutine[Any, Any, T]], *args: Any, **kwargs: Any) -> T:
"""Execute an operation with bulkhead resource limiting."""
# Check if we can accept this operation
async with self._waiting_lock:
if self._waiting_count >= self.max_waiting:
raise BulkheadFullError(max_concurrent=self.max_concurrent, max_waiting=self.max_waiting)
self._waiting_count += 1

try:
# Acquire semaphore to limit concurrency
async with self._semaphore:
# Once we have the semaphore, we're no longer waiting
async with self._waiting_lock:
self._waiting_count -= 1

# Execute the operation
return await operation(*args, **kwargs)
except Exception:
# Make sure to decrement waiting count if we error before acquiring semaphore
async with self._waiting_lock:
# Only decrement if we're still counted as waiting
# (might have already decremented if we got the semaphore)
if self._waiting_count > 0 and self._semaphore.locked():
self._waiting_count -= 1
raise

@override
async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
return await self._execute_with_bulkhead(self.key_value.get, key=key, collection=collection)

@override
async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
return await self._execute_with_bulkhead(self.key_value.get_many, keys=keys, collection=collection)

@override
async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
return await self._execute_with_bulkhead(self.key_value.ttl, key=key, collection=collection)

@override
async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
return await self._execute_with_bulkhead(self.key_value.ttl_many, keys=keys, collection=collection)

@override
async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
return await self._execute_with_bulkhead(self.key_value.put, key=key, value=value, collection=collection, ttl=ttl)

@override
async def put_many(
self,
keys: Sequence[str],
values: Sequence[Mapping[str, Any]],
*,
collection: str | None = None,
ttl: SupportsFloat | None = None,
) -> None:
return await self._execute_with_bulkhead(self.key_value.put_many, keys=keys, values=values, collection=collection, ttl=ttl)

@override
async def delete(self, key: str, *, collection: str | None = None) -> bool:
return await self._execute_with_bulkhead(self.key_value.delete, key=key, collection=collection)

@override
async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int:
return await self._execute_with_bulkhead(self.key_value.delete_many, keys=keys, collection=collection)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from key_value.aio.wrappers.circuit_breaker.wrapper import CircuitBreakerWrapper

__all__ = ["CircuitBreakerWrapper"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import time
from collections.abc import Callable, Coroutine, Mapping, Sequence
from enum import Enum
from typing import Any, SupportsFloat, TypeVar

from key_value.shared.errors.wrappers.circuit_breaker import CircuitOpenError
from typing_extensions import override

from key_value.aio.protocols.key_value import AsyncKeyValue
from key_value.aio.wrappers.base import BaseWrapper

T = TypeVar("T")


class CircuitState(Enum):
"""States for the circuit breaker."""

CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, blocking requests
HALF_OPEN = "half_open" # Testing if service recovered


class CircuitBreakerWrapper(BaseWrapper):
"""Wrapper that implements the circuit breaker pattern to prevent cascading failures.

This wrapper tracks operation failures and opens the circuit after a threshold of consecutive
failures. When the circuit is open, requests are blocked immediately without attempting the
operation. After a recovery timeout, the circuit moves to half-open state to test if the
backend has recovered.

The circuit breaker pattern is essential for production resilience as it:
- Prevents cascading failures when a backend becomes unhealthy
- Reduces load on failing backends, giving them time to recover
- Provides fast failure responses instead of waiting for timeouts
- Automatically attempts recovery after a configured timeout

Example:
circuit_breaker = CircuitBreakerWrapper(
key_value=store,
failure_threshold=5, # Open after 5 consecutive failures
recovery_timeout=30.0, # Try recovery after 30 seconds
success_threshold=2, # Close after 2 successes in half-open
)

try:
value = await circuit_breaker.get(key="mykey")
except CircuitOpenError:
# Circuit is open, backend is considered unhealthy
# Handle gracefully (use cache, return default, etc.)
pass
"""

def __init__(
self,
key_value: AsyncKeyValue,
failure_threshold: int = 5,
recovery_timeout: float = 30.0,
success_threshold: int = 2,
error_types: tuple[type[Exception], ...] = (Exception,),
) -> None:
"""Initialize the circuit breaker wrapper.

Args:
key_value: The store to wrap.
failure_threshold: Number of consecutive failures before opening the circuit. Defaults to 5.
recovery_timeout: Seconds to wait before attempting recovery (moving to half-open). Defaults to 30.0.
success_threshold: Number of consecutive successes in half-open state before closing the circuit. Defaults to 2.
error_types: Tuple of exception types that count as failures. Defaults to (Exception,).
"""
self.key_value: AsyncKeyValue = key_value
self.failure_threshold: int = failure_threshold
self.recovery_timeout: float = recovery_timeout
self.success_threshold: int = success_threshold
self.error_types: tuple[type[Exception], ...] = error_types

# Circuit state
self._state: CircuitState = CircuitState.CLOSED
self._failure_count: int = 0
self._success_count: int = 0
self._last_failure_time: float | None = None

super().__init__()

def _check_circuit(self) -> None:
"""Check the circuit state and potentially transition states."""
if self._state == CircuitState.OPEN:
# Check if we should move to half-open
if self._last_failure_time is not None and time.time() - self._last_failure_time >= self.recovery_timeout:
self._state = CircuitState.HALF_OPEN
self._success_count = 0
else:
# Circuit is still open, raise error
raise CircuitOpenError(failure_count=self._failure_count, last_failure_time=self._last_failure_time)

def _on_success(self) -> None:
"""Handle successful operation."""
if self._state == CircuitState.HALF_OPEN:
self._success_count += 1
if self._success_count >= self.success_threshold:
# Close the circuit
self._state = CircuitState.CLOSED
self._failure_count = 0
self._success_count = 0
elif self._state == CircuitState.CLOSED:
# Reset failure count on success
self._failure_count = 0

def _on_failure(self) -> None:
"""Handle failed operation."""
self._last_failure_time = time.time()

if self._state == CircuitState.HALF_OPEN:
# Failed in half-open, go back to open
self._state = CircuitState.OPEN
self._success_count = 0
elif self._state == CircuitState.CLOSED:
self._failure_count += 1
if self._failure_count >= self.failure_threshold:
# Open the circuit
self._state = CircuitState.OPEN

async def _execute_with_circuit_breaker(self, operation: Callable[..., Coroutine[Any, Any, T]], *args: Any, **kwargs: Any) -> T:
"""Execute an operation with circuit breaker logic."""
self._check_circuit()

try:
result = await operation(*args, **kwargs)
except self.error_types:
self._on_failure()
raise
else:
self._on_success()
return result

@override
async def get(self, key: str, *, collection: str | None = None) -> dict[str, Any] | None:
return await self._execute_with_circuit_breaker(self.key_value.get, key=key, collection=collection)

@override
async def get_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[dict[str, Any] | None]:
return await self._execute_with_circuit_breaker(self.key_value.get_many, keys=keys, collection=collection)

@override
async def ttl(self, key: str, *, collection: str | None = None) -> tuple[dict[str, Any] | None, float | None]:
return await self._execute_with_circuit_breaker(self.key_value.ttl, key=key, collection=collection)

@override
async def ttl_many(self, keys: Sequence[str], *, collection: str | None = None) -> list[tuple[dict[str, Any] | None, float | None]]:
return await self._execute_with_circuit_breaker(self.key_value.ttl_many, keys=keys, collection=collection)

@override
async def put(self, key: str, value: Mapping[str, Any], *, collection: str | None = None, ttl: SupportsFloat | None = None) -> None:
return await self._execute_with_circuit_breaker(self.key_value.put, key=key, value=value, collection=collection, ttl=ttl)

@override
async def put_many(
self,
keys: Sequence[str],
values: Sequence[Mapping[str, Any]],
*,
collection: str | None = None,
ttl: SupportsFloat | None = None,
) -> None:
return await self._execute_with_circuit_breaker(self.key_value.put_many, keys=keys, values=values, collection=collection, ttl=ttl)

@override
async def delete(self, key: str, *, collection: str | None = None) -> bool:
return await self._execute_with_circuit_breaker(self.key_value.delete, key=key, collection=collection)

@override
async def delete_many(self, keys: Sequence[str], *, collection: str | None = None) -> int:
return await self._execute_with_circuit_breaker(self.key_value.delete_many, keys=keys, collection=collection)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from key_value.aio.wrappers.rate_limit.wrapper import RateLimitWrapper

__all__ = ["RateLimitWrapper"]
Loading
Loading