diff --git a/monitoring/grafana/tenet_dashboard.json b/monitoring/grafana/tenet_dashboard.json new file mode 100644 index 0000000..500b18d --- /dev/null +++ b/monitoring/grafana/tenet_dashboard.json @@ -0,0 +1,288 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": "-- Grafana --", + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": null, + "links": [], + "liveNow": false, + "panels": [ + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 0 + }, + "id": 2, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "expr": "sum(rate(http_requests_total[5m])) by (endpoint, status)", + "legendFormat": "{{endpoint}} - {{status}}", + "refId": "A" + } + ], + "title": "API Request Rate", + "type": "timeseries" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": 
"palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 0 + }, + "id": 4, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "expr": "histogram_quantile(0.95, sum(rate(http_request_duration_seconds_bucket[5m])) by (le, endpoint))", + "legendFormat": "p95 {{endpoint}}", + "refId": "A" + } + ], + "title": "Endpoint Latency (p95)", + "type": "timeseries" + }, + { + "datasource": "Prometheus", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + 
"steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 24, + "x": 0, + "y": 8 + }, + "id": 6, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "expr": "sum(rate(tenet_detections_total[5m])) by (threat_type, verdict)", + "legendFormat": "{{threat_type}} - {{verdict}}", + "refId": "A" + } + ], + "title": "Threat Detections (Rate)", + "type": "timeseries" + } + ], + "refresh": "5s", + "schemaVersion": 36, + "style": "dark", + "tags": ["tenet-ai", "security"], + "templating": { + "list": [] + }, + "time": { + "from": "now-1h", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "TENET AI - Security & Performance", + "uid": "tenet-ai-main", + "version": 1 +} diff --git a/requirements.txt b/requirements.txt index 35ad6c1..9782381 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,4 +35,5 @@ requests>=2.33.1 python-dotenv>=1.2.2 # Utilities -python-dateutil>=2.8.2 \ No newline at end of file +python-dateutil>=2.8.2 +prometheus-client>=0.20.0 \ No newline at end of file diff --git a/services/analyzer/app.py b/services/analyzer/app.py index d6af919..faf7bde 100644 --- a/services/analyzer/app.py +++ b/services/analyzer/app.py @@ -24,6 +24,8 @@ from services.utils.logging_config import setup_logging from services.security import SecurityManager +from services.utils.metrics import PrometheusMiddleware, increment_detection +from prometheus_client import make_asgi_app logger = setup_logging(__name__) # Environment configuration @@ -45,6 +47,7 @@ # CORS middleware - configurable origins for security CORS_ALLOWED_ORIGINS = os.getenv("CORS_ALLOWED_ORIGINS", "https://localhost:3000,https://localhost:5173") allowed_origins = [origin.strip() for origin in CORS_ALLOWED_ORIGINS.split(",")] 
+app.add_middleware(PrometheusMiddleware) app.add_middleware( CORSMiddleware, allow_origins=allowed_origins, @@ -53,6 +56,10 @@ allow_headers=["*"], ) +# Mount Prometheus metrics endpoint +metrics_app = make_asgi_app() +app.mount("/metrics", metrics_app) + # Global state redis_client: Optional[redis.Redis] = None ml_model = None @@ -202,6 +209,9 @@ async def analyze_prompt( auth = await security.require_auth(x_api_key, required_permission="analyze") prompt = request.prompt result = run_analysis(prompt) + + increment_detection(service="analyzer", threat_type=result.threat_type, verdict=result.verdict) + security.audit( action="analyze_prompt", result=result.verdict, @@ -456,6 +466,8 @@ async def _process_single_event(event_json: str): # Analyze the prompt result = run_analysis(prompt) + increment_detection(service="analyzer_bg", threat_type=result.threat_type, verdict=result.verdict) + # Update and store event await _update_and_store_event(event, event_id, result) diff --git a/services/ingest/app.py b/services/ingest/app.py index 2791568..2e50a1e 100644 --- a/services/ingest/app.py +++ b/services/ingest/app.py @@ -25,6 +25,8 @@ from fastapi import FastAPI, Header, HTTPException, Query from services.security import SecurityManager +from services.utils.metrics import PrometheusMiddleware, increment_detection +from prometheus_client import make_asgi_app from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field @@ -144,6 +146,7 @@ async def record_failure(self) -> None: version="0.1.0", ) +app.add_middleware(PrometheusMiddleware) app.add_middleware( CORSMiddleware, allow_origins=CORS_ORIGINS, @@ -152,6 +155,10 @@ async def record_failure(self) -> None: allow_headers=["*"], ) +# Mount Prometheus metrics endpoint +metrics_app = make_asgi_app() +app.mount("/metrics", metrics_app) + redis_client: Optional[redis.Redis] = None redis_cb = CircuitBreaker("redis-ingest") _shutdown_event = asyncio.Event() @@ -313,7 +320,9 @@ async def 
ingest_llm_event(request: LLMEventRequest, x_api_key: str = Header(... event_id = str(uuid.uuid4()) timestamp = datetime.utcnow().isoformat() - blocked, risk_score, verdict = quick_heuristic_check(request.prompt) + blocked, risk_score, verdict, threat_type = quick_heuristic_check(request.prompt) + + increment_detection(service="ingest", threat_type=threat_type, verdict=verdict) event_payload = { "event_id": event_id, @@ -358,7 +367,7 @@ async def ingest_llm_event(request: LLMEventRequest, x_api_key: str = Header(... ) -def quick_heuristic_check(prompt: str) -> tuple[bool, float, str]: +def quick_heuristic_check(prompt: str) -> tuple[bool, float, str, str]: prompt_lower = prompt.lower() injection_patterns = [ @@ -402,17 +411,17 @@ def quick_heuristic_check(prompt: str) -> tuple[bool, float, str]: for pattern in injection_patterns: if pattern in prompt_lower: - return True, 0.95, "malicious" + return True, 0.95, "malicious", "prompt_injection" for pattern in jailbreak_patterns: if pattern in prompt_lower: - return True, 0.90, "malicious" + return True, 0.90, "malicious", "jailbreak" for pattern in extraction_patterns: if pattern in prompt_lower: - return False, 0.75, "suspicious" + return False, 0.75, "suspicious", "data_extraction" - return False, 0.0, "benign" + return False, 0.0, "benign", "none" @app.get("/v1/events") diff --git a/services/utils/metrics.py b/services/utils/metrics.py new file mode 100644 index 0000000..6d5db76 --- /dev/null +++ b/services/utils/metrics.py @@ -0,0 +1,59 @@ +import time +from typing import Callable, Awaitable +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware +from prometheus_client import Counter, Histogram, REGISTRY + +# Define standard request metrics +REQUEST_COUNT = Counter( + "http_requests_total", + "Total HTTP requests", + ["method", "endpoint", "status"] +) + +REQUEST_LATENCY = Histogram( + "http_request_duration_seconds", + "HTTP request latency", + ["method", "endpoint"] +) + 
+# Define detection specific metrics
+DETECTION_COUNT = Counter(
+    "tenet_detections_total",
+    "Total threats detected by type and verdict",
+    ["service", "threat_type", "verdict"]
+)
+
+def increment_detection(service: str, threat_type: str, verdict: str) -> None:
+    """Increment DETECTION_COUNT; a falsy threat_type is recorded as "none"."""
+    if not threat_type:
+        threat_type = "none"
+    DETECTION_COUNT.labels(service=service, threat_type=threat_type, verdict=verdict).inc()
+
+
+class PrometheusMiddleware(BaseHTTPMiddleware):
+    """Middleware that records a count and a latency sample for each HTTP request."""
+
+    @staticmethod
+    def _record(request: Request, status_code: str, start_time: float) -> None:
+        """Update REQUEST_COUNT and REQUEST_LATENCY for one finished request."""
+        method = request.method
+        # Label by the matched route's path when the scope carries one, else by the raw URL path.
+        # NOTE(review): the raw-path fallback turns request paths into label values - confirm cardinality is acceptable.
+        route = request.scope.get("route")
+        endpoint = route.path if route else request.url.path
+        REQUEST_COUNT.labels(method=method, endpoint=endpoint, status=status_code).inc()
+        REQUEST_LATENCY.labels(method=method, endpoint=endpoint).observe(time.perf_counter() - start_time)
+
+    async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
+        """Time the downstream call; an unhandled exception is recorded as status "500" and re-raised."""
+        # perf_counter() is monotonic, so system clock adjustments cannot distort the measured duration.
+        start_time = time.perf_counter()
+        try:
+            response = await call_next(request)
+        except Exception:
+            # No response exists on this path, so the status is assumed to be 500.
+            self._record(request, "500", start_time)
+            raise
+        self._record(request, str(response.status_code), start_time)
+        return response
diff --git a/tests/unit/test_ingest.py b/tests/unit/test_ingest.py index b4eb1fe..fa1d92a 100644 --- a/tests/unit/test_ingest.py +++ b/tests/unit/test_ingest.py @@ -51,7 +51,7 @@ def test_detects_prompt_injection(self): ] for prompt in malicious_prompts: - blocked, risk_score, verdict = quick_heuristic_check(prompt) + blocked, risk_score, verdict, threat_type = quick_heuristic_check(prompt) assert blocked is True, f"Should block: {prompt}" assert risk_score > 0.8, f"Risk score should be high for: {prompt}" assert
verdict == "malicious" @@ -66,7 +66,7 @@ def test_detects_jailbreak_attempts(self): ] for prompt in jailbreak_prompts: - blocked, risk_score, verdict = quick_heuristic_check(prompt) + blocked, risk_score, verdict, threat_type = quick_heuristic_check(prompt) assert blocked is True, f"Should block: {prompt}" assert risk_score >= 0.8 @@ -79,7 +79,7 @@ def test_flags_data_extraction(self): ] for prompt in extraction_prompts: - blocked, risk_score, verdict = quick_heuristic_check(prompt) + blocked, risk_score, verdict, threat_type = quick_heuristic_check(prompt) # These should be flagged (suspicious) but not blocked assert verdict == "suspicious", f"Should flag as suspicious: {prompt}" assert 0.5 < risk_score < 0.9 @@ -95,7 +95,7 @@ def test_allows_benign_prompts(self): ] for prompt in benign_prompts: - blocked, risk_score, verdict = quick_heuristic_check(prompt) + blocked, risk_score, verdict, threat_type = quick_heuristic_check(prompt) assert blocked is False, f"Should not block: {prompt}" assert risk_score == 0.0 assert verdict == "benign"