Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions app/api/middleware/cache_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from starlette.types import ASGIApp, Receive, Scope, Send
from fastapi import Request, Response
from app.services.cache_service import cache_service
from app.core.config import settings
import json

class CacheMiddleware:
"""Middleware to cache GET responses."""

def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
# Only cache HTTP GET requests
if scope["type"] == "http" and scope["method"] == "GET":
request = Request(scope, receive)
key = request.url.path + "?" + (request.url.query or "")
if settings.CACHE_ENABLED:
# Try to fetch from cache
cached = await cache_service.get(key)
if cached:
# Return cached response
headers = {"content-type": "application/json"}
response = Response(content=cached, status_code=200, headers=headers)
await response(scope, receive, send)
return

# Capture the response
responder = _ResponseCatcher(self.app, key)
await responder(scope, receive, send)
return

# Non-GET or caching disabled: continue normally
await self.app(scope, receive, send)


class _ResponseCatcher:
"""Helper to capture response body and cache it."""

def __init__(self, app: ASGIApp, key: str):
self.app = app
self.key = key
self.body = b""
self.status_code = 200
self.headers = {}

async def __call__(self, scope: Scope, receive: Receive, send: Send):
async def send_wrapper(message):
if message["type"] == "http.response.start":
self.status_code = message["status"]
self.headers = dict(message.get("headers", []))
elif message["type"] == "http.response.body":
self.body += message.get("body", b"")
await send(message)

await self.app(scope, receive, send_wrapper)

# Cache successful GET responses
if self.status_code == 200:
await cache_service.set(self.key, self.body.decode(), settings.CACHE_TTL_DEFAULT)
37 changes: 37 additions & 0 deletions app/api/v1/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from fastapi import APIRouter, HTTPException
from app.services.cache_service import cache_service
from app.core.config import settings

router = APIRouter()
PREFIX = settings.API_PREFIX + "/cache"

@router.get("/status")
async def cache_status():
"""
Get cache status (enabled flag and Redis connection health).
"""
enabled = settings.CACHE_ENABLED
try:
pong = await cache_service.redis.ping()
healthy = pong is True
except Exception as e:
healthy = False
return {"enabled": enabled, "healthy": healthy}

@router.post("/clear")
async def cache_clear():
"""
Clear the entire cache.
"""
if not settings.CACHE_ENABLED:
raise HTTPException(status_code=400, detail="Caching is disabled")
await cache_service.clear()
return {"status": "cleared"}

@router.get("/stats")
async def cache_stats():
"""
Retrieve cache hit/miss statistics.
"""
stats = cache_service.stats()
return stats
37 changes: 35 additions & 2 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,41 @@ class Settings(BaseSettings):
default="https://graph.microsoft.com/v1.0",
description="Base URL for Microsoft Graph API",
)
# Redis Configuration
REDIS_URL: Optional[str] = Field(
default="redis://localhost:6379/0",
description="Redis connection URL",
)
REDIS_HOST: Optional[str] = Field(
default="localhost",
description="Redis host",
)
REDIS_PORT: Optional[int] = Field(
default=6379,
description="Redis port",
)
REDIS_DB: Optional[int] = Field(
default=0,
description="Redis database index",
)
REDIS_PASSWORD: Optional[str] = Field(
default=None,
description="Redis password",
)

# Cache Settings
CACHE_ENABLED: bool = Field(
default=True,
description="Enable or disable caching",
)
CACHE_TTL_DEFAULT: int = Field(
default=300,
description="Default TTL (seconds) for cache entries",
)
CACHE_KEY_PREFIX: str = Field(
default="autoaudit",
description="Prefix for all cache keys",
)

# Database (needed by health checks)
DATABASE_URL: Optional[str] = Field(
Expand Down Expand Up @@ -111,10 +146,8 @@ def _db_url_if_set(cls, v: Optional[str]) -> Optional[str]:

class Config:
"""Pydantic configuration."""

env_file = ".env"
env_file_encoding = "utf-8"
case_sensitive = True


settings = Settings()
37 changes: 25 additions & 12 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse

from app.core.config import settings
from app.api.v1 import health
from app.api.v1 import auth
from app.api.v1 import auth, cache, health
from app.api.v1.graph import router as graph_router
from app.api.middleware.cache_middleware import CacheMiddleware
from app.utils.logger import logger
from app.api.v1 import graph


def create_app() -> FastAPI:
Expand Down Expand Up @@ -37,6 +38,9 @@ def configure_middleware(app: FastAPI, settings):
allow_headers=["*"],
)

# Caches GET responses; keep enabled for /graph/* and any other GETs
app.add_middleware(CacheMiddleware)

# Trusted host middleware
app.add_middleware(
TrustedHostMiddleware,
Expand All @@ -48,7 +52,9 @@ def configure_middleware(app: FastAPI, settings):


def configure_routing(app: FastAPI, settings):
# ----------------------------
# Authentication endpoints
# ----------------------------
app.include_router(
auth.router,
prefix=f"{settings.API_PREFIX}/auth",
Expand All @@ -64,12 +70,22 @@ def configure_routing(app: FastAPI, settings):

)

# Graph API endpoints

# Graph API endpoints
# Mounted under /api/v1/graph/*
app.include_router(
graph.router,
graph_router,
prefix=f"{settings.API_PREFIX}/graph",
tags=["Graph API"],
responses={404: {"description": "Not found & Unsuccessfull"}}, #need to change this later
)


# Cache endpoints
# /api/v1/cache/status, /api/v1/cache/clear, /api/v1/cache/stats
app.include_router(
cache.router,
prefix=f"{settings.API_PREFIX}/cache",
tags=["Cache"],
)


Expand Down Expand Up @@ -119,11 +135,8 @@ async def root():
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {
"status": "healthy",
"version": settings.VERSION,
}
return {"status": "healthy", "version": settings.VERSION}

# At the very bottom of main.py
app = create_app()

# App entrypoint
app = create_app()
66 changes: 66 additions & 0 deletions app/services/cache_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import aioredis
import asyncio
from typing import Optional
from app.core.config import settings
from structlog import get_logger

logger = get_logger()

class CacheService:
"""Redis-based cache service for AutoAudit API."""

def __init__(self):
# Initialize Redis connection pool
self.redis = aioredis.from_url(
settings.REDIS_URL,
encoding="utf-8",
decode_responses=True,
)
# Stats
self.hits = 0
self.misses = 0

async def get(self, key: str) -> Optional[str]:
"""Retrieve a value from cache."""
# Use namespaced cache keys to avoid collisions across environments/projects
value = await self.redis.get(f"{settings.CACHE_KEY_PREFIX}:{key}")
if value is None:
self.misses += 1
logger.debug("Cache miss", key=key)
else:
self.hits += 1
logger.debug("Cache hit", key=key)
return value

async def set(self, key: str, value: str, ttl: Optional[int] = None) -> None:
"""Set a value in cache with TTL."""
expire = ttl or settings.CACHE_TTL_DEFAULT
await self.redis.set(
f"{settings.CACHE_KEY_PREFIX}:{key}",
value,
ex=expire,
)
logger.debug("Cache set", key=key, ttl=expire)

async def delete(self, key: str) -> None:
"""Delete a key from cache."""
await self.redis.delete(f"{settings.CACHE_KEY_PREFIX}:{key}")
logger.debug("Cache delete", key=key)

async def clear(self) -> None:
"""Clear the entire cache (use with caution)."""
await self.redis.flushdb()
logger.warning("Cache cleared")

def stats(self) -> dict:
"""Return cache hit/miss statistics."""
total = self.hits + self.misses
hit_rate = (self.hits / total * 100) if total > 0 else 0.0
return {
"hits": self.hits,
"misses": self.misses,
"hit_rate": f"{hit_rate:.2f}%",
}

# Instantiate a singleton service
cache_service = CacheService()
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ dependencies = [
"uvicorn>=0.35.0",
"python-dotenv>=1.1.1",
"structlog>=25.4.0",
"redis>=6.4.0",
"aioredis>=2.0.1",
]

[project.optional-dependencies]
Expand All @@ -22,6 +24,8 @@ dev = [

[tool.uv]
dev-dependencies = [
"fakeredis>=2.31.0",
"pytest>=7.0.0",
"pytest-asyncio>=0.21.0",
"pytest-redis>=3.1.3",
]
Loading