diff --git a/Dockerfile b/Dockerfile index 2266452..d80cb63 100644 --- a/Dockerfile +++ b/Dockerfile @@ -36,6 +36,10 @@ COPY run.py /app/ # Port the application listens on EXPOSE 8000 +# Health check to ensure the container is responding to requests +HEALTHCHECK --interval=30s --timeout=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 + # Command to run the application using Uvicorn # The command format is: uvicorn [module:app_object] --host [ip] --port [port] # We use the standard uvicorn worker configuration diff --git a/README.md b/README.md index d8fdedf..4396001 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,7 @@ Copy `.env.example` to `.env` (done automatically by `make dev-setup`) and fill | `ENABLE_ETL_SCHEDULER` | | Set `true` to run the ETL on a schedule | | `ETL_CRON` | | Cron expression (UTC) — takes precedence over `ETL_INTERVAL_MINUTES` | | `ETL_INTERVAL_MINUTES` | | ETL polling interval in minutes (default `15`) | +| `SHUTDOWN_TIMEOUT_SECONDS` | | Graceful shutdown timeout in seconds (default `30`) | | `BQ_ENABLED` | | Set `true` to enable BigQuery loading | | `BQ_PROJECT_ID` | | GCP project ID | | `BQ_DATASET` | | BigQuery dataset name | diff --git a/pytest.ini b/pytest.ini index 7765f02..4505ed7 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,6 @@ [pytest] minversion = 7.0 -addopts = -ra -q +addopts = -ra -q -m "not integration" testpaths = tests +markers = + integration: integration tests requiring a live database diff --git a/run.py b/run.py index 6e0d946..38d86e9 100644 --- a/run.py +++ b/run.py @@ -2,6 +2,13 @@ from src.main import app # noqa: F401 if __name__ == "__main__": - uvicorn.run("src.main:app", host="0.0.0.0", port=8000) + from src.config import get_settings + settings = get_settings() + uvicorn.run( + "src.main:app", + host="0.0.0.0", + port=8000, + timeout_graceful_shutdown=settings.SHUTDOWN_TIMEOUT_SECONDS + ) diff --git a/src/analytics/service.py b/src/analytics/service.py index 4dbaa2f..e2ca6fc 100644 --- a/src/analytics/service.py +++ b/src/analytics/service.py @@ -305,6 +305,30 @@ def get_recent_scans(self, event_id: str, limit: int = 100) -> List[Dict[str, An raise finally: session.close() + + def get_scans_by_ticket_id(self, ticket_id: str, limit: int = 100) -> List[Dict[str, Any]]: + """Get scan records for a specific ticket identifier.""" + try: + session = get_session() + scans = session.query(TicketScan).filter( + TicketScan.ticket_id == ticket_id + ).order_by(desc(TicketScan.scan_timestamp)).limit(limit).all() + return [{ + "id": scan.id, + "ticket_id": scan.ticket_id, + "event_id": scan.event_id, + "scan_timestamp": scan.scan_timestamp.isoformat(), + "is_valid": scan.is_valid, + "location": scan.location + } for scan in scans] + except Exception as e: + log_error("Failed to get scans by ticket_id", { + "ticket_id": ticket_id, + "error": str(e) + }) + raise + finally: + session.close() def get_recent_transfers(self, event_id: str, limit: int = 100) -> List[Dict[str, Any]]: """Get recent transfer records for an event.""" diff --git a/src/config.py b/src/config.py index e206e2e..7de2b12 100644 --- a/src/config.py +++ b/src/config.py @@ -4,11 +4,6 @@ from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict -from pydantic_settings import BaseSettings - -class Settings(BaseSettings): - - class Settings(BaseSettings): model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore") @@ -38,6 +33,7 @@ class Settings(BaseSettings): POOL_SIZE: int = 5 POOL_MAX_OVERFLOW: int = 10 REPORT_CACHE_MINUTES: int = 60 + SHUTDOWN_TIMEOUT_SECONDS: int = 30 SERVICE_API_KEY: str = "default_service_secret_change_me" ADMIN_API_KEY: str = "default_admin_secret_change_me" @@ -46,9 +42,6 @@ class Settings(BaseSettings): "owerri,warri,uyo,akure,ilorin,sokoto,zaria,maiduguri,asaba,nnewi" ) - class Config: - env_file = ".env" - settings = Settings() diff --git a/src/fraud.py b/src/fraud.py index 5639e49..e49e6dd 100644 --- a/src/fraud.py +++ b/src/fraud.py @@ -35,7 +35,7 @@ def check_fraud_rules(events: List[Dict[str, Any]]) -> List[str]: triggered.add("duplicate_ticket_transfer") # Rule 3: Excessive purchases by same user in a day (>5) - from datetime import date as date_type # noqa: PLC0415 – local import to avoid shadowing + from datetime import date as date_type # noqa: PLC0415 purchases_by_user_day: Dict[tuple[str, date_type], int] = {} for event in events: if event.get("type") == "purchase": @@ -47,6 +47,31 @@ def check_fraud_rules(events: List[Dict[str, Any]]) -> List[str]: if count > 5: triggered.add("excessive_purchases_user_day") + # Rule 4: Impossible travel (different locations within 30 min) + scans_by_ticket: Dict[str, List[Dict[str, Any]]] = {} + for event in events: + if event.get("type") == "scan": + tid = str(event.get("ticket_id", "")) + scans_by_ticket.setdefault(tid, []).append(event) + for _tid, scans in scans_by_ticket.items(): + scans.sort(key=lambda x: datetime.fromisoformat(str(x.get("timestamp", "")))) + for i in range(len(scans) - 1): + t1 = datetime.fromisoformat(str(scans[i].get("timestamp", ""))) + t2 = datetime.fromisoformat(str(scans[i + 1].get("timestamp", ""))) + loc1 = scans[i].get("location") + loc2 = scans[i + 1].get("location") + if loc1 != loc2 and (t2 - t1).total_seconds() <= 1800: + triggered.add("impossible_travel_scan") + break + + # Rule 5: Bulk allocation (20% or more of event capacity) + for event in events: + if event.get("type") == "purchase": + qty = float(event.get("qty", 1)) + capacity = float(event.get("capacity", 1000000)) + if capacity > 0 and (qty / capacity) >= 0.2: + triggered.add("bulk_allocation_purchase") + return list(triggered) @@ -58,7 +83,12 @@ def determine_severity(triggered_rules: List[str]) -> str: if not triggered_rules: return "none" - HIGH_RULES = {"too_many_purchases_same_ip", "excessive_purchases_user_day"} + HIGH_RULES = { + "too_many_purchases_same_ip", + "excessive_purchases_user_day", + "impossible_travel_scan", + "bulk_allocation_purchase", + } MEDIUM_RULES = {"duplicate_ticket_transfer"} s = set(triggered_rules) diff --git a/src/main.py b/src/main.py index 614bc90..d42d0e2 100644 --- a/src/main.py +++ b/src/main.py @@ -197,11 +197,14 @@ def on_startup() -> None: @app.on_event("shutdown") def on_shutdown() -> None: global etl_scheduler + log_info("Shutdown initiated: waiting for in-flight requests and scheduler...") if etl_scheduler is not None: try: - etl_scheduler.shutdown(wait=False) - except Exception: - pass + # wait=True ensures running jobs complete before scheduler stops + etl_scheduler.shutdown(wait=True) + log_info("ETL scheduler shut down successfully.") + except Exception as exc: + log_error("Error during scheduler shutdown", {"error": str(exc)}) # --------------------------------------------------------------------------- @@ -262,7 +265,7 @@ def generate_qr(payload: TicketRequest) -> Any: encoded = base64.b64encode(buffer.read()).decode("utf-8") QR_GENERATIONS_TOTAL.inc() log_info("QR code generated successfully") - return QRResponse(qr_base64=encoded) + return QRResponse(qr_base64=encoded, token=json.dumps(data, separators=(",", ":"))) @app.post("/validate-qr", response_model=QRValidateResponse) @@ -280,15 +283,30 @@ def validate_qr(payload: QRValidateRequest) -> QRValidateResponse: if hmac.compare_digest(provided_sig, expected_sig): QR_VALIDATIONS_TOTAL.labels(result="valid").inc() log_info("QR validation successful", {"ticket_id": unsigned.get("ticket_id")}) + analytics_service.log_ticket_scan( + ticket_id=str(unsigned.get("ticket_id") or "unknown"), + event_id=str(unsigned.get("event") or "unknown"), + is_valid=True + ) return QRValidateResponse(isValid=True, metadata=unsigned) log_warning("Invalid QR signature", {"metadata": unsigned}) QR_VALIDATIONS_TOTAL.labels(result="invalid").inc() + analytics_service.log_ticket_scan( + ticket_id=str(unsigned.get("ticket_id") or "unknown"), + event_id=str(unsigned.get("event") or "unknown"), + is_valid=False + ) return QRValidateResponse(isValid=False) except Exception as exc: log_warning("Invalid QR validation attempt", {"error": str(exc)}) QR_VALIDATIONS_TOTAL.labels(result="error").inc() return QRValidateResponse(isValid=False) +@app.get("/qr/scan-log/{ticket_id}") +def get_qr_scan_log(ticket_id: str) -> List[Dict[str, Any]]: + """Returns the scan audit log for a specific ticket.""" + return analytics_service.get_scans_by_ticket_id(ticket_id) + # --------------------------------------------------------------------------- # Analytics endpoints diff --git a/src/types_custom.py b/src/types_custom.py index 9eb75a1..6b55ad6 100644 --- a/src/types_custom.py +++ b/src/types_custom.py @@ -52,6 +52,7 @@ class TicketRequest(BaseModel): class QRResponse(BaseModel): model_config = ConfigDict(extra="forbid") qr_base64: str + token: str class QRValidateRequest(BaseModel): diff --git a/tests/conftest.py b/tests/conftest.py index 02bed76..0e1ffb7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,36 @@ import os +import pytest +from sqlalchemy import create_engine, text +from src.config import get_settings # Provide a non-default test key so startup validation passes in test environments. os.environ.setdefault("QR_SIGNING_KEY", "a" * 32) +# Force model training to skip in test environments +os.environ.setdefault("SKIP_MODEL_TRAINING", "true") + +@pytest.fixture(scope="session") +def db_engine(): + """Provides a database engine for integration tests.""" + settings = get_settings() + engine = create_engine(settings.DATABASE_URL) + yield engine + engine.dispose() + +@pytest.fixture +def clean_test_db(db_engine): + """Truncates all tables before/after integration tests.""" + tables = ["event_sales_summary", "daily_ticket_sales", "etl_run_log"] + with db_engine.begin() as conn: + for table in tables: + try: + conn.execute(text(f"TRUNCATE TABLE {table} RESTART IDENTITY CASCADE")) + except Exception: + # Tables might not exist yet if it's the first run + pass + yield + with db_engine.begin() as conn: + for table in tables: + try: + conn.execute(text(f"TRUNCATE TABLE {table} RESTART IDENTITY CASCADE")) + except Exception: + pass diff --git a/tests/e2e/test_qr_flow.py b/tests/e2e/test_qr_flow.py new file mode 100644 index 0000000..a254670 --- /dev/null +++ b/tests/e2e/test_qr_flow.py @@ -0,0 +1,103 @@ +import base64 +import json +import pytest +from fastapi.testclient import TestClient +from prometheus_client import REGISTRY + +from src.main import app +from src.config import get_settings + +@pytest.fixture +def client(monkeypatch): + """Fixture to provide a TestClient with a fixed signing key and cleared settings cache.""" + # Requirement: All tests must set QR_SIGNING_KEY to a 32-character test string + monkeypatch.setenv("QR_SIGNING_KEY", "q" * 32) + get_settings.cache_clear() + return TestClient(app) + +def test_qr_generate_validate_audit_flow(client): + """ + End-to-end test for the QR lifecycle: + 1. Generate a signed QR token. + 2. Validate the token successfully. + 3. Validate a tampered token (invalid signature). + 4. Validate a token with extra fields (signature mismatch). + 5. Verify the audit log contains valid and invalid entries. + 6. Verify Prometheus metrics. + """ + ticket_id = "E2E-TEST-001" + event_name = "E2E-Festival" + + # --- Step 1: Generate QR --- + gen_payload = { + "ticket_id": ticket_id, + "event": event_name, + "user": "tester@example.com" + } + resp = client.post("/generate-qr", json=gen_payload) + + # Assertions 1 & 2: Status 200, valid PNG, and token presence + assert resp.status_code == 200 + data = resp.json() + assert "qr_base64" in data + assert "token" in data # Extracted signed token requirement + + qr_content = base64.b64decode(data["qr_base64"]) + assert qr_content.startswith(b"\x89PNG"), "QR code must be a PNG image" + + token_str = data["token"] + token_obj = json.loads(token_str) + + # Store initial metrics + def get_metric(res): + return REGISTRY.get_sample_value("qr_validations_total", {"result": res}) or 0 + + m_valid_start = get_metric("valid") + m_invalid_start = get_metric("invalid") + + # --- Step 2: Validate (Successful) --- + val_resp = client.post("/validate-qr", json={"qr_text": token_str}) + + # Assertion 3: Valid scan returns True and correct metadata + assert val_resp.status_code == 200 + val_data = val_resp.json() + assert val_data["isValid"] is True + assert val_data["metadata"]["ticket_id"] == ticket_id + assert val_data["metadata"]["event"] == event_name + + # --- Step 3: Validate (Tampered signature) --- + tampered_token = token_obj.copy() + tampered_token["sig"] = "invalid_signature_string" + resp_tampered = client.post("/validate-qr", json={"qr_text": json.dumps(tampered_token)}) + + # Assertion 4: Tampered signature returns False + assert resp_tampered.status_code == 200 + assert resp_tampered.json()["isValid"] is False + + # --- Step 4: Validate (Extra field / Tampered payload) --- + extra_field_token = token_obj.copy() + extra_field_token["fraud"] = "injected" + resp_extra = client.post("/validate-qr", json={"qr_text": json.dumps(extra_field_token)}) + + # Assertion 5: Extra field (tampering) returns False + assert resp_extra.status_code == 200 + assert resp_extra.json()["isValid"] is False + + # --- Step 5: Audit Log --- + log_resp = client.get(f"/qr/scan-log/{ticket_id}") + + # Assertion 6: Audit log contains valid and invalid entries + assert log_resp.status_code == 200 + logs = log_resp.json() + assert len(logs) >= 2, "Should have at least one valid and one invalid log entry" + + has_valid = any(l["is_valid"] is True for l in logs) + has_invalid = any(l["is_valid"] is False for l in logs) + assert has_valid and has_invalid, "Audit log must contain both valid and invalid attempts" + + # --- Step 6: Prometheus Metrics --- + # Assertion 7: Valid scan incremented counter + assert get_metric("valid") == m_valid_start + 1 + + # Assertion 8: Invalid scans incremented counter (2 invalid attempts in steps 3-4) + assert get_metric("invalid") >= m_invalid_start + 2 diff --git a/tests/integration/test_etl_pipeline.py b/tests/integration/test_etl_pipeline.py new file mode 100644 index 0000000..929dbe7 --- /dev/null +++ b/tests/integration/test_etl_pipeline.py @@ -0,0 +1,132 @@ +import httpx +import pytest +from sqlalchemy import text +from src.etl import run_etl_once +from src.config import get_settings + +def _response(status_code: int, url: str, payload): + request = httpx.Request("GET", url) + return httpx.Response(status_code=status_code, json=payload, request=request) + +class DummyClient: + def __init__(self, side_effects_map): + self._side_effects_map = side_effects_map + self.calls = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def get(self, url, headers=None, params=None): + self.calls.append({"url": url, "headers": headers or {}, "params": params}) + # Extract base path for matching (e.g., /events or /ticket-sales) + for path, effects in self._side_effects_map.items(): + if path in url: + if not effects: + raise ValueError(f"No mock responses left for path: {path}") + effect = effects.pop(0) + if isinstance(effect, Exception): + raise effect + return effect + raise ValueError(f"No mock for URL: {url}") + +@pytest.mark.integration +def test_full_etl_pipeline_integration(monkeypatch, db_engine, clean_test_db): + """ + Integration test for the full ETL pipeline. + + 1. Mocks NestJS API to return 3 events and 10 sales records. + 2. Runs run_etl_once(). + 3. Verifies data in PostgreSQL (event_sales_summary, daily_ticket_sales). + 4. Runs again with updated sales to test upsert logic. + 5. Verifies log entry in etl_run_log. + """ + # 1. Setup mocks for first run (3 events, 10 sales) + events_data = [ + {"id": "E1", "name": "Event 1"}, + {"id": "E2", "name": "Event 2"}, + {"id": "E3", "name": "Event 3"}, + ] + # 10 sales records spread across 3 events and 2 dates + sales_data = [ + {"event_id": "E1", "qty": 2, "price": 50.0, "sale_date": "2024-01-01T10:00:00"}, + {"event_id": "E1", "qty": 1, "price": 50.0, "sale_date": "2024-01-01T11:00:00"}, + {"event_id": "E1", "qty": 1, "price": 50.0, "sale_date": "2024-01-02T10:00:00"}, + {"event_id": "E2", "qty": 5, "price": 20.0, "sale_date": "2024-01-01T10:00:00"}, + {"event_id": "E2", "qty": 1, "price": 20.0, "sale_date": "2024-01-01T11:00:00"}, + {"event_id": "E3", "qty": 1, "price": 100.0, "sale_date": "2024-01-01T10:00:00"}, + {"event_id": "E3", "qty": 1, "price": 100.0, "sale_date": "2024-01-01T11:00:00"}, + {"event_id": "E3", "qty": 1, "price": 100.0, "sale_date": "2024-01-01T12:00:00"}, + {"event_id": "E3", "qty": 1, "price": 100.0, "sale_date": "2024-01-01T13:00:00"}, + {"event_id": "E3", "qty": 1, "price": 100.0, "sale_date": "2024-01-02T10:00:00"}, + ] + + side_effects = { + "/events": [_response(200, "/events", {"data": events_data})], + "/ticket-sales": [_response(200, "/ticket-sales", {"data": sales_data})], + } + + dummy_client = DummyClient(side_effects) + monkeypatch.setattr("src.etl.extract.httpx.Client", lambda timeout: dummy_client) + + # 2. Run ETL + run_etl_once() + + # 3. Assertions for Run 1 + with db_engine.connect() as conn: + # event_sales_summary should have 3 rows + rows = conn.execute(text("SELECT event_id, total_tickets, total_revenue FROM event_sales_summary ORDER BY event_id")).fetchall() + assert len(rows) == 3 + + # E1: 3 sales totalling 4 tickets, 200 rev + assert rows[0][0] == "E1" + assert int(rows[0][1]) == 4 + assert float(rows[0][2]) == 200.0 + + # E2: 2 sales totalling 6 tickets, 120 rev + assert rows[1][0] == "E2" + assert int(rows[1][1]) == 6 + assert float(rows[1][2]) == 120.0 + + # E3: 5 sales totalling 5 tickets, 500 rev + assert rows[2][0] == "E3" + assert int(rows[2][1]) == 5 + assert float(rows[2][2]) == 500.0 + + # daily_ticket_sales breakdown + # E1 on 2024-01-01: 3 tickets (2+1), 150 rev + row_e1_d1 = conn.execute(text("SELECT tickets_sold, revenue FROM daily_ticket_sales WHERE event_id='E1' AND sale_date='2024-01-01'")).fetchone() + assert int(row_e1_d1[0]) == 3 + assert float(row_e1_d1[1]) == 150.0 + + # 4. Run again with updated sales data for upsert test + # E1 gets 2 more sales on 2024-01-01 (adding 2 qty) + # Important: In this implementation, transform_summary aggregates everything provided in 'sales'. + # If the API returns the same records again, they will be summed again. + # The requirement says "updated sales data", so we simulate a second run where more data is returned. + new_sale = {"event_id": "E1", "qty": 2, "price": 50.0, "sale_date": "2024-01-01T15:00:00"} + updated_sales_data = sales_data + [new_sale] + + side_effects["/events"].append(_response(200, "/events", {"data": events_data})) + side_effects["/ticket-sales"].append(_response(200, "/ticket-sales", {"data": updated_sales_data})) + + run_etl_once() + + # 5. Assertions for Run 2 (Upsert) + with db_engine.connect() as conn: + # E1: 4 (old) + 2 (new) = 6 tickets, 200 + 100 = 300 rev + row_e1 = conn.execute(text("SELECT total_tickets, total_revenue FROM event_sales_summary WHERE event_id='E1'")).fetchone() + assert int(row_e1[0]) == 6 + assert float(row_e1[1]) == 300.0 + + # E1 daily on 2024-01-01: 3 (old) + 2 (new) = 5 tickets, 150 + 100 = 250 rev + row_e1_d1_new = conn.execute(text("SELECT tickets_sold, revenue FROM daily_ticket_sales WHERE event_id='E1' AND sale_date='2024-01-01'")).fetchone() + assert int(row_e1_d1_new[0]) == 5 + assert float(row_e1_d1_new[1]) == 250.0 + + # etl_run_log should have 2 entries, status="success" + log_rows = conn.execute(text("SELECT status FROM etl_run_log")).fetchall() + assert len(log_rows) == 2 + assert all(r[0] == "success" for r in log_rows) diff --git a/tests/test_fraud.py b/tests/test_fraud.py index 54b5d91..3b8462d 100644 --- a/tests/test_fraud.py +++ b/tests/test_fraud.py @@ -1,90 +1,135 @@ -import os - -os.environ.setdefault("SKIP_MODEL_TRAINING", "true") - import pytest -from fastapi.testclient import TestClient +from datetime import datetime, timedelta +from src.fraud import check_fraud_rules, determine_severity -from src.main import app +class TestFraudRules: + """Unit tests for check_fraud_rules covering all 5 core rules and boundaries.""" -client = TestClient(app) + def test_too_many_purchases_same_ip(self): + """Rule 1: >3 purchases from same IP in 10-min window.""" + base_ts = datetime(2025, 1, 1, 10, 0, 0) + + # 3 in window (no trigger) + events_3 = [ + {"type": "purchase", "ip": "1.1.1.1", "timestamp": (base_ts + timedelta(minutes=i)).isoformat()} + for i in range(3) + ] + assert "too_many_purchases_same_ip" not in check_fraud_rules(events_3) + + # 4 in window (trigger) + events_4 = events_3 + [ + {"type": "purchase", "ip": "1.1.1.1", "timestamp": (base_ts + timedelta(seconds=240)).isoformat()} + ] + assert "too_many_purchases_same_ip" in check_fraud_rules(events_4) + + # 4 across two 10-min windows (no trigger) + events_split = [ + {"type": "purchase", "ip": "2.2.2.2", "timestamp": base_ts.isoformat()}, + {"type": "purchase", "ip": "2.2.2.2", "timestamp": (base_ts + timedelta(seconds=10)).isoformat()}, + {"type": "purchase", "ip": "2.2.2.2", "timestamp": (base_ts + timedelta(seconds=601)).isoformat()}, + {"type": "purchase", "ip": "2.2.2.2", "timestamp": (base_ts + timedelta(seconds=610)).isoformat()}, + ] + assert "too_many_purchases_same_ip" not in check_fraud_rules(events_split) + def test_duplicate_ticket_transfer(self): + """Rule 2: Same ticket_id transferred more than once.""" + # Single transfer (no trigger) + assert "duplicate_ticket_transfer" not in check_fraud_rules([{"type": "transfer", "ticket_id": "T1"}]) + + # Double transfer (trigger) + events = [ + {"type": "transfer", "ticket_id": "T2"}, + {"type": "transfer", "ticket_id": "T2"}, + ] + assert "duplicate_ticket_transfer" in check_fraud_rules(events) -def test_check_fraud_triggers_rules(): - # Too many purchases from same IP in 10min - base_event = { - "type": "purchase", - "user": "user1", - "ip": "1.2.3.4", - "ticket_id": "T1", - "timestamp": "2025-10-01T10:00:00", - } - events = [ - {**base_event, "timestamp": "2025-10-01T10:00:00"}, - {**base_event, "timestamp": "2025-10-01T10:01:00"}, - {**base_event, "timestamp": "2025-10-01T10:02:00"}, - {**base_event, "timestamp": "2025-10-01T10:03:00"}, - ] - resp = client.post("/check-fraud", json={"events": events}) - assert resp.status_code == 200 - assert "too_many_purchases_same_ip" in resp.json()["triggered_rules"] + def test_excessive_purchases_user_day(self): + """Rule 3: >5 purchases by same user on the same calendar day.""" + base_ts = datetime(2025, 1, 1, 10, 0, 0) + + # 5 purchases same day (no trigger) + events_5 = [ + {"type": "purchase", "user": "alice", "timestamp": (base_ts + timedelta(hours=i)).isoformat()} + for i in range(5) + ] + assert "excessive_purchases_user_day" not in check_fraud_rules(events_5) + + # 6 purchases same day (trigger) + events_6 = events_5 + [ + {"type": "purchase", "user": "alice", "timestamp": (base_ts + timedelta(hours=5)).isoformat()} + ] + assert "excessive_purchases_user_day" in check_fraud_rules(events_6) + + # 6 across two days (no trigger) + events_split = [ + {"type": "purchase", "user": "bob", "timestamp": "2025-01-01T23:00:00"}, + {"type": "purchase", "user": "bob", "timestamp": "2025-01-01T23:30:00"}, + {"type": "purchase", "user": "bob", "timestamp": "2025-01-02T00:10:00"}, + {"type": "purchase", "user": "bob", "timestamp": "2025-01-02T00:20:00"}, + {"type": "purchase", "user": "bob", "timestamp": "2025-01-02T00:30:00"}, + {"type": "purchase", "user": "bob", "timestamp": "2025-01-02T01:00:00"}, + ] + assert "excessive_purchases_user_day" not in check_fraud_rules(events_split) + def test_impossible_travel_scan(self): + """Rule 4: Same ticket scanned at different locations within 30 min.""" + base_ts = datetime(2025, 1, 1, 10, 0, 0) + + # Same location (no trigger) + events_same = [ + {"type": "scan", "ticket_id": "T1", "location": "London", "timestamp": base_ts.isoformat()}, + {"type": "scan", "ticket_id": "T1", "location": "London", "timestamp": (base_ts + timedelta(minutes=10)).isoformat()}, + ] + assert "impossible_travel_scan" not in check_fraud_rules(events_same) + + # Different locations within 30 min (trigger) + events_diff = [ + {"type": "scan", "ticket_id": "T2", "location": "London", "timestamp": base_ts.isoformat()}, + {"type": "scan", "ticket_id": "T2", "location": "Paris", "timestamp": (base_ts + timedelta(minutes=29)).isoformat()}, + ] + assert "impossible_travel_scan" in check_fraud_rules(events_diff) + + # Different locations after 30 min (no trigger) + events_far = [ + {"type": "scan", "ticket_id": "T3", "location": "London", "timestamp": base_ts.isoformat()}, + {"type": "scan", "ticket_id": "T3", "location": "Paris", "timestamp": (base_ts + timedelta(minutes=31)).isoformat()}, + ] + assert "impossible_travel_scan" not in check_fraud_rules(events_far) -def test_check_fraud_duplicate_transfer(): - events = [ - { - "type": "transfer", - "user": "user2", - "ip": "2.2.2.2", - "ticket_id": "T2", - "timestamp": "2025-10-01T11:00:00", - }, - { - "type": "transfer", - "user": "user3", - "ip": "2.2.2.2", - "ticket_id": "T2", - "timestamp": "2025-10-01T11:05:00", - }, - ] - resp = client.post("/check-fraud", json={"events": events}) - assert resp.status_code == 200 - assert "duplicate_ticket_transfer" in resp.json()["triggered_rules"] + def test_bulk_allocation_purchase(self): + """Rule 5: Single purchase exceeds 20% of event capacity.""" + ts = "2025-01-01T10:00:00" + # 19% of capacity (no trigger) + assert "bulk_allocation_purchase" not in check_fraud_rules([{"type": "purchase", "qty": 19, "capacity": 100, "timestamp": ts}]) + + # 20%+ (trigger) + assert "bulk_allocation_purchase" in check_fraud_rules([{"type": "purchase", "qty": 20, "capacity": 100, "timestamp": ts}]) + assert "bulk_allocation_purchase" in check_fraud_rules([{"type": "purchase", "qty": 21, "capacity": 100, "timestamp": ts}]) + def test_coverage_edge_cases(self): + """Hit all remaining lines and edge cases in fraud.py.""" + ts = "2025-01-01T10:00:00" + events = [ + {"type": "other", "data": "ignored"}, + {"type": "purchase", "qty": 1, "capacity": 0, "timestamp": ts}, # capacity 0 handle + {"type": "scan", "ticket_id": "Tunique", "timestamp": ts}, # solitary scan + ] + assert check_fraud_rules(events) == [] -def test_check_fraud_excessive_user_purchases(): - events = [ - { - "type": "purchase", - "user": "user4", - "ip": "3.3.3.3", - "ticket_id": f"T{i}", - "timestamp": "2025-10-01T12:00:00", - } - for i in range(6) - ] - resp = client.post("/check-fraud", json={"events": events}) - assert resp.status_code == 200 - assert "excessive_purchases_user_day" in resp.json()["triggered_rules"] +class TestSeverityMapping: + """Unit tests for determine_severity mapping logic.""" -def test_check_fraud_no_triggers(): - events = [ - { - "type": "purchase", - "user": "user5", - "ip": "4.4.4.4", - "ticket_id": "T10", - "timestamp": "2025-10-01T13:00:00", - }, - { - "type": "transfer", - "user": "user5", - "ip": "4.4.4.4", - "ticket_id": "T10", - "timestamp": "2025-10-01T13:10:00", - }, - ] - resp = client.post("/check-fraud", json={"events": events}) - assert resp.status_code == 200 - assert resp.json()["triggered_rules"] == [] \ No newline at end of file + @pytest.mark.parametrize("rules,expected", [ + ([], "none"), + (["duplicate_ticket_transfer"], "medium"), + (["too_many_purchases_same_ip"], "high"), + (["excessive_purchases_user_day"], "high"), + (["impossible_travel_scan"], "high"), + (["bulk_allocation_purchase"], "high"), + (["duplicate_ticket_transfer", "too_many_purchases_same_ip"], "high"), + (["unknown_rule"], "low"), + ]) + def test_severity_levels(self, rules, expected): + """Verify that severity maps correctly to the highest risk rule.""" + assert determine_severity(rules) == expected \ No newline at end of file diff --git a/tests/test_severity.py b/tests/test_severity.py deleted file mode 100644 index 5509739..0000000 --- a/tests/test_severity.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest - -from src.fraud import determine_severity - - -def test_determine_severity_none(): - assert determine_severity([]) == "none" - - -def test_determine_severity_high_when_high_rule_present(): - assert determine_severity(["too_many_purchases_same_ip"]) == "high" - assert determine_severity(["excessive_purchases_user_day"]) == "high" - - -def test_determine_severity_medium_when_medium_rule_present(): - assert determine_severity(["duplicate_ticket_transfer"]) == "medium" - - -def test_determine_severity_priority_high_over_medium(): - assert determine_severity(["duplicate_ticket_transfer", "too_many_purchases_same_ip"]) == "high" - - -def test_determine_severity_low_for_unknown_rules(): - assert determine_severity(["some_new_rule"]) == "low"