Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ VERSION = "0.1.0"

# Azure AD Configuration
AZURE_CLIENT_ID = "your_client_id"
AZURE_CLIENT_SECRET = "your_client_secret"
AZURE_CLIENT_SECRET = "your_client_secret"
AZURE_TENANT_ID = "your_tenant_id"

# Logging
Expand All @@ -19,4 +19,6 @@ ALLOWED_HOSTS = "["*"]"
API_PREFIX = "/api/v1"

# Microsoft Graph API
GRAPH_API_BASE_URL = "https://graph.microsoft.com/v1.0"
GRAPH_API_BASE_URL = "https://graph.microsoft.com/v1.0"


60 changes: 60 additions & 0 deletions app/api/middleware/cache_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from starlette.types import ASGIApp, Receive, Scope, Send
from fastapi import Request, Response
from app.services.cache_service import cache_service
from app.core.config import settings
import json

class CacheMiddleware:
    """ASGI middleware that serves cached bodies for GET requests.

    On a cache hit the stored body is returned immediately as a 200
    response; on a miss the downstream app runs and its response is
    captured for caching (see ``_ResponseCatcher``).  Non-GET requests,
    and all requests while ``settings.CACHE_ENABLED`` is false, pass
    straight through to the wrapped app.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        # Only HTTP GET requests are cacheable, and only while caching is
        # enabled.  BUGFIX: the original created the _ResponseCatcher (and
        # therefore wrote to the cache) even when CACHE_ENABLED was false —
        # the flag only guarded the read path.
        if (
            scope["type"] == "http"
            and scope["method"] == "GET"
            and settings.CACHE_ENABLED
        ):
            request = Request(scope, receive)
            # Key = path + query string, namespaced later by cache_service.
            key = request.url.path + "?" + (request.url.query or "")

            cached = await cache_service.get(key)
            if cached:
                # Serve the cached body.  NOTE(review): original response
                # headers are not stored, so content-type is assumed to be
                # JSON — confirm every cached endpoint returns JSON.
                headers = {"content-type": "application/json"}
                response = Response(content=cached, status_code=200, headers=headers)
                await response(scope, receive, send)
                return

            # Cache miss: run the app and capture the response for storage.
            responder = _ResponseCatcher(self.app, key)
            await responder(scope, receive, send)
            return

        # Non-GET or caching disabled: continue normally
        await self.app(scope, receive, send)


class _ResponseCatcher:
    """ASGI wrapper that records a response body and writes it to the cache.

    Accumulates every ``http.response.body`` chunk (so streamed responses
    are handled) and stores the full body after the downstream app
    finishes — but only for 200 responses and only while caching is
    enabled.
    """

    def __init__(self, app: ASGIApp, key: str):
        self.app = app
        self.key = key          # cache key derived from path + query string
        self.body = b""         # accumulated response body chunks
        self.status_code = 200  # overwritten by the http.response.start message
        self.headers = {}

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                self.status_code = message["status"]
                self.headers = dict(message.get("headers", []))
            elif message["type"] == "http.response.body":
                self.body += message.get("body", b"")
            await send(message)

        await self.app(scope, receive, send_wrapper)

        # Cache only successful responses.  BUGFIX: also respect
        # CACHE_ENABLED here — the original stored bodies even when caching
        # was switched off.
        if self.status_code == 200 and settings.CACHE_ENABLED:
            try:
                text = self.body.decode()
            except UnicodeDecodeError:
                # BUGFIX: non-UTF-8 (binary) bodies used to raise here;
                # simply skip caching them.
                return
            await cache_service.set(self.key, text, settings.CACHE_TTL_DEFAULT)
37 changes: 37 additions & 0 deletions app/api/v1/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from fastapi import APIRouter, HTTPException
from app.services.cache_service import cache_service
from app.core.config import settings

# Router for cache management endpoints; mounted under the API prefix
# by the application factory.
router = APIRouter()
# NOTE(review): PREFIX is never used in this module — the prefix is applied
# when the router is included in app.main; consider removing this constant.
PREFIX = settings.API_PREFIX + "/cache"

@router.get("/status")
async def cache_status():
    """
    Get cache status (enabled flag and Redis connection health).

    Returns:
        dict: ``enabled`` — the CACHE_ENABLED setting;
              ``healthy`` — True only when Redis answers PING.
    """
    enabled = settings.CACHE_ENABLED
    try:
        # redis PING returns True on success.
        healthy = await cache_service.redis.ping() is True
    except Exception:
        # Any connection/timeout error just means "not healthy";
        # this endpoint itself must never fail.  (Fixes the unused
        # bound exception variable in the original.)
        healthy = False
    return {"enabled": enabled, "healthy": healthy}

@router.post("/clear")
async def cache_clear():
    """Drop every cached entry; responds 400 when caching is turned off."""
    if settings.CACHE_ENABLED:
        await cache_service.clear()
        return {"status": "cleared"}
    raise HTTPException(status_code=400, detail="Caching is disabled")

@router.get("/stats")
async def cache_stats():
    """Return the service's hit/miss counters and hit rate."""
    return cache_service.stats()
34 changes: 34 additions & 0 deletions app/api/v1/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Optional
from app.services.graph_service import GraphService

router = APIRouter()
graph_service = GraphService()

class UsersRequest(BaseModel):
    # Request payload for POST /graph/users.
    token: str                 # caller-supplied Microsoft Graph bearer token
    top: Optional[int] = None  # optional $top page-size parameter

class MeRequest(BaseModel):
    # Request payload for POST /graph/me.
    token: str  # caller-supplied Microsoft Graph bearer token

@router.post("/graph/users")
async def graph_users(req: UsersRequest):
    """Proxy a Microsoft Graph user-list call; responds 502 on any failure."""
    try:
        result = await graph_service.list_users(req.token, req.top)
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Unhandled exception: {e}")
    if "error" not in result:
        return result
    raise HTTPException(status_code=502, detail=result["error"])

@router.post("/graph/me")
async def graph_me(req: MeRequest):
    """Proxy the Microsoft Graph /me call; responds 502 on any failure."""
    try:
        result = await graph_service.get_me(req.token)
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Unhandled exception: {e}")
    if "error" not in result:
        return result
    raise HTTPException(status_code=502, detail=result["error"])
37 changes: 35 additions & 2 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,46 @@ class Settings(BaseSettings):
default="https://graph.microsoft.com/v1.0",
description="Base URL for Microsoft Graph API",
)
# Redis Configuration
REDIS_URL: Optional[str] = Field(
default="redis://localhost:6379/0",
description="Redis connection URL",
)
REDIS_HOST: Optional[str] = Field(
default="localhost",
description="Redis host",
)
REDIS_PORT: Optional[int] = Field(
default=6379,
description="Redis port",
)
REDIS_DB: Optional[int] = Field(
default=0,
description="Redis database index",
)
REDIS_PASSWORD: Optional[str] = Field(
default=None,
description="Redis password",
)

# Cache Settings
CACHE_ENABLED: bool = Field(
default=True,
description="Enable or disable caching",
)
CACHE_TTL_DEFAULT: int = Field(
default=300,
description="Default TTL (seconds) for cache entries",
)
CACHE_KEY_PREFIX: str = Field(
default="autoaudit",
description="Prefix for all cache keys",
)

    class Config:
        """Pydantic configuration.

        Values are loaded from the project ``.env`` file; environment
        variable names must match field names exactly (case-sensitive).
        """

        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = True


# Module-level settings singleton shared by the whole application.
settings = Settings()
27 changes: 26 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from app.core.config import settings
from app.api.v1 import auth
from app.utils.logger import logger
from app.api.v1 import cache
from app.api.middleware.cache_middleware import CacheMiddleware
from app.api.v1.graph import router as graph_router



def create_app() -> FastAPI:
Expand Down Expand Up @@ -35,7 +39,7 @@ def configure_middleware(app: FastAPI, settings):
allow_methods=["*"],
allow_headers=["*"],
)

app.add_middleware(CacheMiddleware)
# Trusted host middleware
app.add_middleware(
TrustedHostMiddleware,
Expand Down Expand Up @@ -105,6 +109,27 @@ async def health_check():
"status": "healthy",
"version": settings.VERSION,
}

def configure_routing(app: FastAPI, settings):
    """Attach all v1 routers under the configured API prefix."""
    base = settings.API_PREFIX
    # (router, mount prefix, tags, extra include_router kwargs)
    mounts = [
        (
            auth.router,
            f"{base}/auth",
            ["Authentication"],
            {"responses": {404: {"description": "Not found"}}},
        ),
        (cache.router, f"{base}/cache", ["Cache"], {}),
        (graph_router, f"{base}", ["Graph"], {}),
    ]
    for r, prefix, tags, extra in mounts:
        app.include_router(r, prefix=prefix, tags=tags, **extra)

app = create_app()
66 changes: 66 additions & 0 deletions app/services/cache_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import aioredis
import asyncio
from typing import Optional
from app.core.config import settings
from structlog import get_logger

logger = get_logger()

class CacheService:
    """Redis-based cache service for AutoAudit API.

    All keys are namespaced with ``settings.CACHE_KEY_PREFIX`` so that
    multiple applications or environments can share one Redis database.
    In-process hit/miss counters back the /cache/stats endpoint.
    """

    def __init__(self):
        # Initialize Redis connection pool; decode_responses=True makes
        # get() return str rather than bytes.
        self.redis = aioredis.from_url(
            settings.REDIS_URL,
            encoding="utf-8",
            decode_responses=True,
        )
        # In-memory stats — reset on restart and per-worker, not global.
        self.hits = 0
        self.misses = 0

    def _namespaced(self, key: str) -> str:
        """Return *key* with the configured namespace prefix applied."""
        return f"{settings.CACHE_KEY_PREFIX}:{key}"

    async def get(self, key: str) -> Optional[str]:
        """Retrieve a value from cache, or None on a miss."""
        value = await self.redis.get(self._namespaced(key))
        if value is None:
            self.misses += 1
            logger.debug("Cache miss", key=key)
        else:
            self.hits += 1
            logger.debug("Cache hit", key=key)
        return value

    async def set(self, key: str, value: str, ttl: Optional[int] = None) -> None:
        """Set a value in cache with a TTL (defaults to CACHE_TTL_DEFAULT)."""
        expire = ttl or settings.CACHE_TTL_DEFAULT
        await self.redis.set(self._namespaced(key), value, ex=expire)
        logger.debug("Cache set", key=key, ttl=expire)

    async def delete(self, key: str) -> None:
        """Delete a single key from cache."""
        await self.redis.delete(self._namespaced(key))
        logger.debug("Cache delete", key=key)

    async def clear(self) -> None:
        """Remove every key in this service's namespace.

        BUGFIX: the previous implementation called FLUSHDB, which wipes
        the *entire* Redis database — including keys belonging to other
        applications sharing it, defeating the purpose of
        CACHE_KEY_PREFIX.  Only prefixed keys are deleted now.
        """
        pattern = f"{settings.CACHE_KEY_PREFIX}:*"
        async for key in self.redis.scan_iter(match=pattern):
            await self.redis.delete(key)
        logger.warning("Cache cleared")

    def stats(self) -> dict:
        """Return cache hit/miss statistics as a JSON-serializable dict."""
        total = self.hits + self.misses
        hit_rate = (self.hits / total * 100) if total > 0 else 0.0
        return {
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": f"{hit_rate:.2f}%",
        }

# Instantiate a singleton service
# Shared by the cache middleware, the /cache router, and GraphService;
# the Redis connection pool is created eagerly at import time.
cache_service = CacheService()
70 changes: 70 additions & 0 deletions app/services/graph_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@

import json
import httpx
from typing import Dict, Any, Optional
from app.core.config import settings
from app.services.cache_service import cache_service
from app.utils.logger import logger

class GraphService:
    """Async Microsoft Graph client with read-through response caching."""

    def __init__(self):
        # Base URL normalized to have no trailing slash.
        self.base = settings.GRAPH_API_BASE_URL.rstrip("/")

    def _users_cache_key(self, token: str, top: Optional[int]) -> str:
        """Cache key for the /users listing, scoped to token and page size."""
        suffix = f":top={top}" if top else ""
        # Use only the first 16 chars of token to avoid storing full secrets.
        # NOTE(review): truncated tokens can collide across callers — confirm
        # this is acceptable, or switch to a salted hash of the full token.
        return f"graph_users:{token[:16]}{suffix}"

    def _me_cache_key(self, token: str) -> str:
        """Cache key for /me, scoped to the caller's token."""
        return f"graph_me:{token[:16]}"

    async def list_users(self, token: str, top: Optional[int] = None) -> Dict[str, Any]:
        """List tenant users via GET /users.

        Args:
            token: OAuth bearer token for Microsoft Graph.
            top: optional $top page-size parameter.

        Returns:
            {"source": "cache"|"graph", "data": ...} on success, or
            {"source": "graph", "error": "<status>: <body>"} on a non-200.
        """
        key = self._users_cache_key(token, top)
        cached = await cache_service.get(key)
        if cached:
            try:
                data = json.loads(cached)
                logger.info("Users fetched from cache", count=len(data.get("value", [])))
                return {"source": "cache", "data": data}
            except Exception:
                # Corrupt cache entry: fall through to a live Graph call.
                logger.warning("Users cache decode failed; calling Graph", key=key)

        headers = {"Authorization": f"Bearer {token}"}
        params = {"$top": str(top)} if top else None
        # SECURITY FIX: the original passed verify=False, disabling TLS
        # certificate verification on requests carrying bearer tokens.
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{self.base}/users", headers=headers, params=params)

        if resp.status_code == 200:
            data = resp.json()
            await cache_service.set(key, json.dumps(data))
            logger.info("Users fetched from Graph", count=len(data.get("value", [])))
            return {"source": "graph", "data": data}

        logger.warning("Users call failed", status_code=resp.status_code, text=resp.text)
        return {"source": "graph", "error": f"{resp.status_code}: {resp.text}"}

    async def get_me(self, token: str) -> Dict[str, Any]:
        """Fetch the signed-in user's profile via GET /me.

        Same return contract as :meth:`list_users`.
        """
        key = self._me_cache_key(token)
        cached = await cache_service.get(key)
        if cached:
            try:
                data = json.loads(cached)
                logger.info("Me fetched from cache", user_id=data.get("id"))
                return {"source": "cache", "data": data}
            except Exception:
                logger.warning("Me cache decode failed; calling Graph", key=key)

        headers = {"Authorization": f"Bearer {token}"}
        # SECURITY FIX: removed verify=False (see list_users).
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{self.base}/me", headers=headers)

        if resp.status_code == 200:
            data = resp.json()
            await cache_service.set(key, json.dumps(data))
            logger.info("Me fetched from Graph", user_id=data.get("id"))
            return {"source": "graph", "data": data}

        logger.warning("Me call failed", status_code=resp.status_code, text=resp.text)
        return {"source": "graph", "error": f"{resp.status_code}: {resp.text}"}
Loading