Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions backend/app/api/v1/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@
router = APIRouter()
logger = logging.getLogger(__name__)

# Backward-compatible test aliases for the shared rate limiter.
_scan_attempts_by_user = guard_scan_rate_limiter._local_attempts_by_key
_RATE_LIMIT_REQUESTS = settings.GUARD_RATE_LIMIT_REQUESTS


Expand Down
45 changes: 45 additions & 0 deletions backend/app/core/rate_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ def __init__(
self,
failure_threshold: int = 5,
recovery_timeout: int = 30,
cleanup_interval: int = 100,
) -> None:
self._local_attempts_by_key: dict[str, deque[datetime]] = defaultdict(deque)
self._local_window_seconds_by_key: dict[str, int] = {}
self._local_lock = Lock()
self._local_cleanup_interval = max(1, cleanup_interval)
self._local_cleanup_calls = 0
self._redis_client: Optional[object] = None
self._redis_script: Optional[object] = None

Expand All @@ -59,6 +63,38 @@ def __init__(
"failures_open": 0,
}

def clear_local_attempts(self) -> None:
    """Reset the in-memory fallback state used when Redis is unavailable.

    Empties both the per-key attempt history and the per-key window sizes,
    and restarts the call counter that schedules the periodic stale-key
    sweep. All of it happens while holding the local lock.
    """
    with self._local_lock:
        self._local_cleanup_calls = 0
        for store in (
            self._local_attempts_by_key,
            self._local_window_seconds_by_key,
        ):
            store.clear()

def cleanup_stale_local_attempts(self, now: Optional[datetime] = None) -> int:
    """Drop expired in-memory keys and report how many were discarded.

    Public entry point: acquires the local lock, then delegates to the
    lock-held helper, forwarding ``now`` unchanged.

    Args:
        now: Reference time for expiry; passed through as-is (the helper
            substitutes the current UTC time when it is omitted).

    Returns:
        The number of keys the helper removed.
    """
    with self._local_lock:
        removed_count = self._cleanup_stale_local_attempts_locked(now=now)
    return removed_count

def _cleanup_stale_local_attempts_locked(self, now: Optional[datetime] = None) -> int:
    """Prune expired attempts and drop keys left empty; return the count.

    Does not take the local lock itself; the public wrapper and the
    periodic sweep both call it while already holding the lock.

    Args:
        now: Reference time for expiry. The current UTC time is used when
            it is omitted.

    Returns:
        The number of keys removed because no attempts remained.
    """
    if not now:
        now = datetime.now(timezone.utc)

    attempts_by_key = self._local_attempts_by_key
    windows_by_key = self._local_window_seconds_by_key
    pruned = 0

    # Iterate over a snapshot of the keys: entries are deleted as we go.
    for key in list(attempts_by_key):
        timestamps = attempts_by_key[key]

        # A key with no recorded window is treated as a zero-length window,
        # so every attempt at or before ``now`` counts as expired.
        window = windows_by_key.get(key)
        cutoff = now - timedelta(seconds=0 if window is None else window)

        # Attempts are appended in time order, so the oldest sit on the left.
        while timestamps and timestamps[0] <= cutoff:
            timestamps.popleft()

        if timestamps:
            continue

        attempts_by_key.pop(key, None)
        windows_by_key.pop(key, None)
        pruned += 1

    return pruned

def _get_redis_client(self) -> Optional[object]:
if not settings.REDIS_URL or redis is None:
return None
Expand Down Expand Up @@ -104,6 +140,10 @@ def _check_local(
window_start = now - timedelta(seconds=window_seconds)

with self._local_lock:
existing_window = self._local_window_seconds_by_key.get(key, 0)
if window_seconds > existing_window:
self._local_window_seconds_by_key[key] = window_seconds

attempts = self._local_attempts_by_key[key]

while attempts and attempts[0] <= window_start:
Expand All @@ -129,6 +169,11 @@ def _check_local(
for _ in range(cost):
attempts.append(now)

self._local_cleanup_calls += 1
if self._local_cleanup_calls >= self._local_cleanup_interval:
self._local_cleanup_calls = 0
self._cleanup_stale_local_attempts_locked(now=now)

return False, 0

def check_and_consume(
Expand Down
4 changes: 2 additions & 2 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def clear_guard_rate_limits():
from app.core.rate_limit import guard_scan_rate_limiter

# 1. Clear local memory
guard_scan_rate_limiter._local_attempts_by_key.clear()
guard_scan_rate_limiter.clear_local_attempts()

# 2. Clear Redis
redis_client = guard_scan_rate_limiter._get_redis_client()
Expand All @@ -169,6 +169,6 @@ def clear_guard_rate_limits():
yield

# Clean up after the test completes
guard_scan_rate_limiter._local_attempts_by_key.clear()
guard_scan_rate_limiter.clear_local_attempts()
if redis_client is not None:
redis_client.flushdb()
6 changes: 2 additions & 4 deletions backend/tests/integration/test_rate_limiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,14 @@

import pytest

from app.api.v1 import guard as guard_api
from app.core.rate_limit import guard_scan_rate_limiter
from app.core.security import create_access_token
from app.models.user import User


@pytest.fixture(autouse=True)
def reset_rate_limiter():
guard_mod = sys.modules.get("app.api.v1.guard")
if guard_mod is not None and hasattr(guard_mod, "_scan_attempts_by_user"):
guard_mod._scan_attempts_by_user.clear()
guard_scan_rate_limiter.clear_local_attempts()
yield


Expand Down
15 changes: 8 additions & 7 deletions backend/tests/test_guard_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ def test_scan_prompt_rate_limit(authenticated_client: TestClient):
},
}

# Clear rate limit lock before testing
from app.api.v1.guard import _scan_attempts_by_user, _RATE_LIMIT_REQUESTS
_scan_attempts_by_user.clear()
# Clear rate limit state before testing
from app.api.v1.guard import _RATE_LIMIT_REQUESTS
from app.core.rate_limit import guard_scan_rate_limiter
guard_scan_rate_limiter.clear_local_attempts()

with patch("app.modules.guard.llm_guard.LLMGuard", return_value=mock_guard):
# Fire 60 requests (allowed)
Expand All @@ -159,8 +160,8 @@ def test_bulk_scan_success(authenticated_client: TestClient):
},
}

from app.api.v1.guard import _scan_attempts_by_user
_scan_attempts_by_user.clear()
from app.core.rate_limit import guard_scan_rate_limiter
guard_scan_rate_limiter.clear_local_attempts()

payload = {"prompts": ["prompt 1", "prompt 2", "prompt 3"]}

Expand Down Expand Up @@ -194,8 +195,8 @@ def test_bulk_scan_rate_limiting(authenticated_client: TestClient):
},
}

from app.api.v1.guard import _scan_attempts_by_user
_scan_attempts_by_user.clear()
from app.core.rate_limit import guard_scan_rate_limiter
guard_scan_rate_limiter.clear_local_attempts()

# limit is 60. Let's send a batch of 40.
payload_1 = {"prompts": ["p"] * 40}
Expand Down
33 changes: 32 additions & 1 deletion backend/tests/test_rate_limit_helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Unit tests for the shared rate limiter helper."""

from types import SimpleNamespace
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone

from app.core import rate_limit
from app.core.config import settings
Expand Down Expand Up @@ -200,3 +200,34 @@ def test_distributed_rate_limiter_circuit_breaker_recovery(monkeypatch):
assert limited is False
assert limiter.cb_state == "CLOSED"
assert limiter.consecutive_failures == 0


def test_distributed_rate_limiter_cleans_up_stale_local_keys(monkeypatch):
    """A key whose only attempt has aged past its window is swept away."""
    start = datetime(2026, 1, 1, tzinfo=timezone.utc)

    class FrozenDateTime(datetime):
        # Controllable clock: the test advances ``current`` by hand.
        current = start

        @classmethod
        def now(cls, tz=None):
            return cls.current

    # Swap the module's datetime so the limiter reads the frozen clock.
    monkeypatch.setattr(rate_limit, "datetime", FrozenDateTime)

    limiter = rate_limit.DistributedRateLimiter(cleanup_interval=1)
    request = {
        "key": "stale:key",
        "limit": 1,
        "window_seconds": 60,
        "fail_closed": False,
    }
    limited, retry_after = limiter.check_and_consume(**request)
    assert limited is False
    assert retry_after == 0

    # Move one second past the 60-second window so the attempt is expired.
    FrozenDateTime.current = start + timedelta(seconds=61)

    first_sweep = limiter.cleanup_stale_local_attempts()
    assert first_sweep == 1

    # Nothing is left for a second sweep to remove.
    assert limiter.cleanup_stale_local_attempts() == 0
Loading