diff --git a/api/chat_provider.py b/api/chat_provider.py new file mode 100644 index 0000000..e09aeff --- /dev/null +++ b/api/chat_provider.py @@ -0,0 +1,170 @@ +""" +Chat provider abstraction — per-surface LLM decoupling. + +One ABC, two implementations (OpenAI, Anthropic), three independent +factory functions for classifier, chat answer, and query expansion. + +Follows the same pattern as api/embedding_provider.py. +""" + +import asyncio +import logging +import os +from abc import ABC, abstractmethod + +logger = logging.getLogger(__name__) + + +class ChatProvider(ABC): + """Abstract base class for chat completion providers.""" + + provider_name: str + default_model: str + + @abstractmethod + async def complete( + self, + messages: list[dict[str, str]], + *, + model: str | None = None, + temperature: float = 0.3, + max_tokens: int = 1024, + json_mode: bool = False, + ) -> str: + """Return the assistant's text response.""" + + +class OpenAIChatProvider(ChatProvider): + """OpenAI chat completions via the openai package.""" + + provider_name = "openai" + + def __init__(self, api_key: str, default_model: str = "gpt-4o-mini"): + from openai import OpenAI + self._client = OpenAI(api_key=api_key) + self.default_model = default_model + + async def complete( + self, + messages: list[dict[str, str]], + *, + model: str | None = None, + temperature: float = 0.3, + max_tokens: int = 1024, + json_mode: bool = False, + ) -> str: + kwargs: dict = dict( + model=model or self.default_model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + if json_mode: + kwargs["response_format"] = {"type": "json_object"} + + response = await asyncio.to_thread( + self._client.chat.completions.create, **kwargs + ) + return response.choices[0].message.content or "" + + +class AnthropicChatProvider(ChatProvider): + """Anthropic chat completions via the anthropic package.""" + + provider_name = "anthropic" + + def __init__(self, api_key: str, default_model: str = "claude-sonnet-4-20250514"): + import anthropic + self._client = anthropic.Anthropic(api_key=api_key) + self.default_model = default_model + + async def complete( + self, + messages: list[dict[str, str]], + *, + model: str | None = None, + temperature: float = 0.3, + max_tokens: int = 1024, + json_mode: bool = False, + ) -> str: + # Anthropic API: system message is a separate kwarg + system_text = None + user_messages = [] + for m in messages: + if m["role"] == "system": + system_text = m["content"] + else: + user_messages.append({"role": m["role"], "content": m["content"]}) + + # json_mode: append instruction (Anthropic has no native json_mode) + if json_mode and system_text: + system_text += "\n\nRespond ONLY with valid JSON." + elif json_mode: + system_text = "Respond ONLY with valid JSON." + + kwargs: dict = dict( + model=model or self.default_model, + messages=user_messages, + temperature=temperature, + max_tokens=max_tokens, + ) + if system_text: + kwargs["system"] = system_text + + response = await asyncio.to_thread( + self._client.messages.create, **kwargs + ) + return response.content[0].text + + +# --------------------------------------------------------------------------- +# Factory functions — one per surface +# --------------------------------------------------------------------------- + +_OPENAI_KEY_VAR = "OPENAI_API_KEY" +_ANTHROPIC_KEY_VAR = "ANTHROPIC_API_KEY" + + +def _require_key(env_var: str) -> str: + key = os.getenv(env_var, "") + if not key: + raise ValueError(f"{env_var} is required but not set") + return key + + +def create_classifier_provider() -> ChatProvider: + """Classifier provider. OpenAI only (structured output sensitive).""" + provider = os.getenv("CLASSIFIER_PROVIDER", "openai").strip().lower() + model = os.getenv("CLASSIFIER_MODEL", "gpt-4o") + + if provider == "openai": + return OpenAIChatProvider(api_key=_require_key(_OPENAI_KEY_VAR), default_model=model) + + raise ValueError(f"Unsupported CLASSIFIER_PROVIDER: {provider!r} (only 'openai' supported)") + + +def create_chat_provider() -> ChatProvider: + """Chat answer provider. Supports OpenAI and Anthropic.""" + provider = os.getenv("CHAT_PROVIDER", "openai").strip().lower() + # CHAT_MODEL with CHAT_LLM_MODEL as backward-compat alias + model = os.getenv("CHAT_MODEL") or os.getenv("CHAT_LLM_MODEL", "gpt-4o-mini") + + if provider == "openai": + return OpenAIChatProvider(api_key=_require_key(_OPENAI_KEY_VAR), default_model=model) + + if provider == "anthropic": + model = os.getenv("CHAT_MODEL", "claude-sonnet-4-20250514") + return AnthropicChatProvider(api_key=_require_key(_ANTHROPIC_KEY_VAR), default_model=model) + + raise ValueError(f"Unsupported CHAT_PROVIDER: {provider!r} (supported: 'openai', 'anthropic')") + + +def create_expansion_provider() -> ChatProvider: + """Query expansion provider. OpenAI only.""" + provider = os.getenv("EXPANSION_PROVIDER", "openai").strip().lower() + model = os.getenv("EXPANSION_MODEL", "gpt-4o-mini") + + if provider == "openai": + return OpenAIChatProvider(api_key=_require_key(_OPENAI_KEY_VAR), default_model=model) + + raise ValueError(f"Unsupported EXPANSION_PROVIDER: {provider!r} (only 'openai' supported)") diff --git a/api/personal_ingest_api.py b/api/personal_ingest_api.py index 5d8776f..57726ce 100644 --- a/api/personal_ingest_api.py +++ b/api/personal_ingest_api.py @@ -175,10 +175,19 @@ def quartz_url(entity_type: str, name: str) -> Optional[str]: # Global connection pool db_pool: Optional[asyncpg.Pool] = None -openai_client: Optional[Any] = None # lazy init for /chat LLM calls from api.embedding_provider import EmbeddingProvider, create_embedding_provider embedding_provider: Optional[EmbeddingProvider] = None + +from api.chat_provider import ( + ChatProvider, + create_classifier_provider, + create_chat_provider, + create_expansion_provider, +) +classifier_provider: Optional[ChatProvider] = None +chat_answer_provider: Optional[ChatProvider] = None +expansion_provider: Optional[ChatProvider] = None terminusdb_adapter: Optional[Any] = None # TerminusDBAdapter instance (lazy init) @@ -1225,7 +1234,7 @@ async def enqueue_outbox( @app.on_event("startup") async def startup(): """Initialize database connection pool and embedding provider""" - global db_pool, openai_client, embedding_provider + global db_pool, embedding_provider try: db_pool = await asyncpg.create_pool( DB_URL, @@ -1306,6 +1315,14 @@ async def startup(): except Exception as e: logger.warning(f"Commitment routers not mounted: {e}") + # Protocol layer router (requirements, coverage, signals, gap computation) + try: + from api.routers.protocol_router import create_protocol_router + app.include_router(create_protocol_router(db_pool), prefix="/protocol") + logger.info("Protocol router mounted (/protocol/)") + except Exception as e: + logger.warning(f"Protocol router not mounted: {e}") + # Task router is always mounted (no capability gate — core feature) try: from api.routers.task_router import create_router as create_task_router @@ -4596,8 +4613,6 @@ async def graph_shortest_path( # /chat Endpoint — RAG-powered conversational interface # ============================================================================= -CHAT_LLM_MODEL = os.getenv('CHAT_LLM_MODEL', 'gpt-4o-mini') - # ── B2 GraphRAG: graph-guided retrieval ────────────────────────────── @@ -5104,16 +5119,18 @@ def _rerank_chunks(query: str, chunks: list, top_k: int = 8) -> list: return reranked +EXPANSION_MODEL = os.getenv("EXPANSION_MODEL", "gpt-4o-mini") + + async def _expand_queries(query: str, n: int = 3) -> list: """Generate n query reformulations for multi-query retrieval (B8b). Returns [original_query] + up to n reformulations. """ - if not openai_client: + if not expansion_provider: return [query] try: - resp = openai_client.chat.completions.create( - model="gpt-4o-mini", - messages=[{ + result = await expansion_provider.complete( + [{ "role": "user", "content": f"""Generate {n} alternative search queries for a bioregional knowledge commons. Original query: "{query}" @@ -5121,10 +5138,11 @@ async def _expand_queries(query: str, n: int = 3) -> list: Return ONLY the alternative queries, one per line. No numbering, no explanation. Each should rephrase the question using different terminology to find relevant documents.""" }], + model=EXPANSION_MODEL, max_tokens=200, temperature=0.7, ) - lines = [l.strip() for l in resp.choices[0].message.content.strip().split('\n') if l.strip()] + lines = [l.strip() for l in result.strip().split('\n') if l.strip()] return [query] + lines[:n] except Exception as e: logger.warning(f"B8b query expansion failed: {e}") @@ -5139,6 +5157,7 @@ class ChatRequest(BaseModel): include_code: bool = Field(default=False, description="Include code entity chunks in retrieval (default: exclude)") multi_query: bool = Field(default=False, description="Enable multi-query expansion for broader retrieval (B8b)") planner: bool = Field(default=False, description="B9a QueryPlan IR path (experimental)") + debug_prompt: bool = Field(default=False, description="Include assembled prompt in response (requires CHAT_DEBUG_PROMPT env)") @app.post("/chat") @@ -5153,21 +5172,20 @@ async def chat_endpoint(request: ChatRequest): """ if not db_pool: raise HTTPException(status_code=503, detail="Database not available") - # Lazy init openai_client for /chat LLM calls (separate from embedding provider) - global openai_client - if not openai_client: - if not OPENAI_API_KEY: - raise HTTPException( - status_code=503, - detail="LLM service not available (OPENAI_API_KEY not configured)", - ) + # Lazy init chat providers (classifier, chat answer, expansion) + global classifier_provider, chat_answer_provider, expansion_provider + if not classifier_provider or not chat_answer_provider or not expansion_provider: try: - from openai import OpenAI - openai_client = OpenAI(api_key=OPENAI_API_KEY) - except ImportError: + if not classifier_provider: + classifier_provider = create_classifier_provider() + if not chat_answer_provider: + chat_answer_provider = create_chat_provider() + if not expansion_provider: + expansion_provider = create_expansion_provider() + except (ValueError, ImportError) as e: raise HTTPException( status_code=503, - detail="openai package not installed", + detail=f"LLM provider not available: {e}", ) # ------------------------------------------------------------------ @@ -5187,7 +5205,7 @@ async def chat_endpoint(request: ChatRequest): from api.retrieval_executors import evidence_bundles_to_legacy_format from api.schemas.query_plan import QueryTaxonomy - classifier_output = await classify_query(request.query, openai_client) + classifier_output = await classify_query(request.query, classifier_provider) if classifier_output.query_taxonomy == QueryTaxonomy.OUT_OF_DOMAIN and \ classifier_output.confidence >= CLASSIFIER_CONFIDENCE_THRESHOLD: @@ -5356,17 +5374,14 @@ async def chat_endpoint(request: ChatRequest): # 4. Call LLM # ------------------------------------------------------------------ try: - llm_response = await asyncio.to_thread( - openai_client.chat.completions.create, - model=CHAT_LLM_MODEL, - messages=[ + answer = await chat_answer_provider.complete( + [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], temperature=0.3, max_tokens=1024, ) - answer = llm_response.choices[0].message.content or "" except Exception as e: logger.error(f"LLM call failed: {e}") raise HTTPException( @@ -5393,6 +5408,12 @@ async def chat_endpoint(request: ChatRequest): # Always emit plan_trace when planner was requested (including fallback) if planner_requested and plan_trace is not None: response["plan_trace"] = plan_trace + # Debug prompt capture (double-gated: request flag + server env var) + if request.debug_prompt and os.getenv("CHAT_DEBUG_PROMPT"): + response["_debug_prompt"] = { + "system_prompt": system_prompt, + "user_prompt": user_prompt, + } return response diff --git a/api/query_classifier.py b/api/query_classifier.py index 78a66a6..c0867c5 100644 --- a/api/query_classifier.py +++ b/api/query_classifier.py @@ -16,7 +16,9 @@ import json import logging +import os +from api.chat_provider import ChatProvider from api.schemas.query_plan import ( ClassifierOutput, DepthTier, @@ -26,6 +28,8 @@ logger = logging.getLogger(__name__) +CLASSIFIER_MODEL = os.getenv("CLASSIFIER_MODEL", "gpt-4o") + CLASSIFIER_CONFIDENCE_THRESHOLD = 0.7 CLASSIFIER_PROMPT = """You are a query classifier for a bioregional knowledge commons (BKC). @@ -158,30 +162,26 @@ def _apply_guardrails(query: str, output: ClassifierOutput) -> ClassifierOutput: async def classify_query( query: str, - openai_client, - model: str = "gpt-4o", + provider: ChatProvider, + model: str | None = None, ) -> ClassifierOutput: - """Classify a query into a QueryTaxonomy category via GPT-4o. + """Classify a query into a QueryTaxonomy category via an LLM provider. Returns ClassifierOutput with confidence score. On parse error or unexpected output, returns OUT_OF_DOMAIN with confidence=0.0 to trigger fallback to the baseline retrieval path. """ - import asyncio - try: - response = await asyncio.to_thread( - openai_client.chat.completions.create, - model=model, - messages=[ + raw = await provider.complete( + [ {"role": "system", "content": CLASSIFIER_PROMPT}, {"role": "user", "content": query}, ], + model=model or CLASSIFIER_MODEL, temperature=0.0, max_tokens=200, - response_format={"type": "json_object"}, + json_mode=True, ) - raw = response.choices[0].message.content or "{}" data = json.loads(raw) # Parse taxonomy diff --git a/api/routers/claims_router.py b/api/routers/claims_router.py index fe5b3ed..b5cf044 100644 --- a/api/routers/claims_router.py +++ b/api/routers/claims_router.py @@ -704,6 +704,7 @@ async def list_claims( claim_type: Optional[str] = Query(None, description="Filter by claim type"), claimant_uri: Optional[str] = Query(None, description="Filter by claimant"), about_uri: Optional[str] = Query(None, description="Filter by about entity (via graph edge)"), + since: Optional[datetime] = Query(None, description="Filter to claims created on or after this ISO datetime (e.g. 2026-01-01T00:00:00Z)"), limit: int = Query(50, ge=1, le=200), offset: int = Query(0, ge=0), ): @@ -732,6 +733,10 @@ async def list_claims( )""") params.append(about_uri) i += 1 + if since: + conditions.append(f"c.created_at >= ${i}") + params.append(since) + i += 1 where = ("WHERE " + " AND ".join(conditions)) if conditions else "" params.extend([limit, offset]) diff --git a/api/routers/commitment_router.py b/api/routers/commitment_router.py index d40c271..7c3fd25 100644 --- a/api/routers/commitment_router.py +++ b/api/routers/commitment_router.py @@ -55,6 +55,7 @@ class CommitmentResponse(BaseModel): validity_start: Optional[datetime] validity_end: Optional[datetime] state: str + scope: Optional[str] = None evidence_uri: Optional[str] metadata: Dict[str, Any] created_at: datetime @@ -94,6 +95,7 @@ class PoolResponse(BaseModel): activation_threshold_count: Optional[int] demurrage_rate_monthly: float state: str + scope: Optional[str] = None metadata: Dict[str, Any] created_at: datetime updated_at: datetime @@ -298,6 +300,7 @@ async def list_commitments( state: Optional[str] = Query(None, description="Filter by state"), pledger_uri: Optional[str] = Query(None), pool_rid: Optional[str] = Query(None), + offer_type: Optional[str] = Query(None, description="Filter by offer type (labor, goods, service, knowledge, stewardship)"), limit: int = Query(50, ge=1, le=200), offset: int = Query(0, ge=0), ): @@ -318,6 +321,10 @@ async def list_commitments( conditions.append(f"pool_id = (SELECT id FROM commitment_pools WHERE pool_rid = ${i})") params.append(pool_rid) i += 1 + if offer_type: + conditions.append(f"c.offer_type = ${i}") + params.append(offer_type) + i += 1 where = ("WHERE " + " AND ".join(conditions)) if conditions else "" params.extend([limit, offset]) @@ -929,6 +936,7 @@ def _row_to_commitment(row) -> CommitmentResponse: validity_start=row.get("validity_start"), validity_end=row.get("validity_end"), state=row["state"], + scope=row.get("scope"), evidence_uri=row.get("evidence_uri"), metadata=meta or {}, created_at=row["created_at"], @@ -951,6 +959,7 @@ def _row_to_pool(row) -> PoolResponse: activation_threshold_count=row.get("activation_threshold_count"), demurrage_rate_monthly=float(row["demurrage_rate_monthly"]), state=row["state"], + scope=row.get("scope"), metadata=meta or {}, created_at=row["created_at"], updated_at=row["updated_at"], diff --git a/api/routers/protocol_router.py b/api/routers/protocol_router.py new file mode 100644 index 0000000..bd43f34 --- /dev/null +++ b/api/routers/protocol_router.py @@ -0,0 +1,578 @@ +"""Protocol layer endpoints: requirements, coverage, signals, gap computation. + +Implements the Claims × Spore coordination protocol: + POST /requirements/create — declare a normative requirement + GET /requirements/{rid} — fetch requirement by RID + GET /requirements/ — list requirements (filterable) + POST /coverage/link — create a coverage link + GET /coverage/ — list coverage links (filterable) + GET /pools/{rid}/gaps — compute unmet/stale requirements for a pool + POST /signals/create — record a signal + GET /signals/ — list signals (filterable) + +Additive layer on top of existing claims/commitments/intents. +""" + +import hashlib +import json +import logging +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional + +from fastapi import APIRouter, HTTPException, Query +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Pydantic models +# --------------------------------------------------------------------------- + +class RequirementCreateRequest(BaseModel): + scope: str = Field(..., description="personal | team | pool | org | federation | on_chain_group") + scope_ref: Optional[str] = Field(None, description="URI of the scoped entity (pool_rid, org_uri)") + policy_source: str = Field(..., description="URI of the policy/constitution declaring this") + requirement_type: str = Field(..., description="monitoring | reporting | stewardship | governance | contribution") + statement: str = Field(..., min_length=5, max_length=1000) + subject_uri: Optional[str] = None + frequency: Optional[str] = Field(None, description="once | weekly | monthly | quarterly | annual") + freshness_window_days: Optional[int] = Field(None, ge=1) + severity: str = Field("medium", description="low | medium | high | critical") + metadata: Dict[str, Any] = Field(default_factory=dict) + + +class RequirementResponse(BaseModel): + requirement_rid: str + scope: str + scope_ref: Optional[str] + policy_source: str + requirement_type: str + statement: str + subject_uri: Optional[str] + frequency: Optional[str] + freshness_window_days: Optional[int] + severity: str + active: bool + metadata: Dict[str, Any] + created_at: datetime + + +class CoverageLinkRequest(BaseModel): + coverage_type: str = Field(..., description="commitment_covers_requirement | claim_covers_condition | evidence_covers_commitment") + source_rid: str = Field(..., description="RID of the covering artifact") + target_rid: str = Field(..., description="RID of the covered artifact") + valid_from: Optional[datetime] = None + valid_until: Optional[datetime] = None + confidence: Optional[float] = Field(None, ge=0.0, le=1.0) + provenance: Optional[str] = Field(None, description="manual | ai_inferred | policy_rule") + metadata: Dict[str, Any] = Field(default_factory=dict) + + +class CoverageLinkResponse(BaseModel): + coverage_rid: str + coverage_type: str + source_rid: str + target_rid: str + valid_from: datetime + valid_until: Optional[datetime] + confidence: Optional[float] + provenance: Optional[str] + metadata: Dict[str, Any] + created_at: datetime + + +class SignalCreateRequest(BaseModel): + signal_type: str = Field(..., description="declaration | discourse | gap_computed | sensor | document_extract") + source_kind: str = Field(..., description="What produced it") + source_ref: Optional[str] = None + statement: str = Field(..., min_length=5, max_length=2000) + scope: str = Field(..., description="personal | team | pool | org | federation | on_chain_group") + subject_uri: Optional[str] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + confidence: Optional[float] = Field(None, ge=0.0, le=1.0) + fresh_until: Optional[datetime] = None + + +class SignalResponse(BaseModel): + signal_rid: str + signal_type: str + source_kind: str + source_ref: Optional[str] + statement: str + scope: str + subject_uri: Optional[str] + metadata: Dict[str, Any] + confidence: Optional[float] + fresh_until: Optional[datetime] + created_at: datetime + + +class GapSignalResponse(BaseModel): + requirement_rid: str + requirement_statement: str + requirement_type: str + severity: str + frequency: Optional[str] + freshness_window_days: Optional[int] + gap_type: str # unmet | stale + coverage_count: int + latest_coverage_until: Optional[datetime] + signal_rid: Optional[str] # RID of emitted signal (if created) + next_move: str # surface_only | request_offer | propose_commitment | escalate_to_council + + +class PoolGapsResponse(BaseModel): + pool_rid: str + total_requirements: int + unmet_count: int + stale_count: int + covered_count: int + gaps: List[GapSignalResponse] + + +# --------------------------------------------------------------------------- +# RID helpers +# --------------------------------------------------------------------------- + +_VALID_SCOPES = {"personal", "team", "pool", "org", "federation", "on_chain_group"} +_VALID_REQ_TYPES = {"monitoring", "reporting", "stewardship", "governance", "contribution"} +_VALID_SEVERITIES = {"low", "medium", "high", "critical"} +_VALID_FREQUENCIES = {"once", "weekly", "monthly", "quarterly", "annual"} +_VALID_SIGNAL_TYPES = {"declaration", "discourse", "gap_computed", "sensor", "document_extract"} +_VALID_COVERAGE_TYPES = {"commitment_covers_requirement", "claim_covers_condition", "evidence_covers_commitment"} +_VALID_PROVENANCES = {"manual", "ai_inferred", "policy_rule"} + + +def _requirement_rid(scope: str, scope_ref: Optional[str], policy_source: str, + requirement_type: str, statement: str, + subject_uri: Optional[str]) -> str: + canonical = json.dumps({ + "policy_source": policy_source, + "requirement_type": requirement_type, + "scope": scope, + "scope_ref": scope_ref or "", + "statement": statement, + "subject_uri": subject_uri or "", + }, sort_keys=True, separators=(",", ":")) + h = hashlib.blake2b(canonical.encode(), digest_size=32).hexdigest()[:32] + return f"orn:koi-net.requirement:{h}" + + +def _coverage_rid(coverage_type: str, source_rid: str, target_rid: str) -> str: + canonical = json.dumps({ + "coverage_type": coverage_type, + "source_rid": source_rid, + "target_rid": target_rid, + }, sort_keys=True, separators=(",", ":")) + h = hashlib.blake2b(canonical.encode(), digest_size=32).hexdigest()[:32] + return f"orn:koi-net.coverage:{h}" + + +def _signal_rid(signal_type: str, source_kind: str, source_ref: Optional[str], + statement: str, scope: str, + subject_uri: Optional[str] = None) -> str: + canonical = json.dumps({ + "scope": scope, + "signal_type": signal_type, + "source_kind": source_kind, + "source_ref": source_ref or "", + "statement": statement, + "subject_uri": subject_uri or "", + }, sort_keys=True, separators=(",", ":")) + h = hashlib.blake2b(canonical.encode(), digest_size=32).hexdigest()[:32] + return f"orn:koi-net.signal:{h}" + + +def _row_to_response(model_cls, row): + """Convert asyncpg Record to Pydantic model, parsing JSONB strings.""" + d = dict(row) + if "metadata" in d and isinstance(d["metadata"], str): + d["metadata"] = json.loads(d["metadata"]) + return model_cls(**d) + + +def _next_move(severity: str, scope: str) -> str: + """Compute suggested next move based on severity and scope.""" + if severity == "critical": + return "escalate_to_council" + if severity == "high": + return "propose_commitment" + if severity == "medium": + return "request_offer" + return "surface_only" + + +# --------------------------------------------------------------------------- +# Router factory +# --------------------------------------------------------------------------- + +def create_protocol_router(pool) -> APIRouter: + router = APIRouter() + + # ------------------------------------------------------------------ # + # Requirements CRUD # + # ------------------------------------------------------------------ # + + @router.post("/requirements/create", response_model=RequirementResponse, status_code=201) + async def create_requirement(body: RequirementCreateRequest): + if body.scope not in _VALID_SCOPES: + raise HTTPException(400, f"Invalid scope: {body.scope}") + if body.requirement_type not in _VALID_REQ_TYPES: + raise HTTPException(400, f"Invalid requirement_type: {body.requirement_type}") + if body.severity not in _VALID_SEVERITIES: + raise HTTPException(400, f"Invalid severity: {body.severity}") + if body.frequency and body.frequency not in _VALID_FREQUENCIES: + raise HTTPException(400, f"Invalid frequency: {body.frequency}") + + rid = _requirement_rid(body.scope, body.scope_ref, body.policy_source, + body.requirement_type, body.statement, body.subject_uri) + + async with pool.acquire() as conn: + row = await conn.fetchrow(""" + INSERT INTO requirements ( + requirement_rid, scope, scope_ref, policy_source, requirement_type, + statement, subject_uri, frequency, freshness_window_days, + severity, metadata + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + ON CONFLICT (requirement_rid) DO UPDATE SET + frequency = EXCLUDED.frequency, + freshness_window_days = EXCLUDED.freshness_window_days, + severity = EXCLUDED.severity, + subject_uri = EXCLUDED.subject_uri, + metadata = EXCLUDED.metadata + RETURNING * + """, rid, body.scope, body.scope_ref, body.policy_source, + body.requirement_type, body.statement, body.subject_uri, + body.frequency, body.freshness_window_days, body.severity, + json.dumps(body.metadata)) + return _row_to_response(RequirementResponse, row) + + @router.get("/requirements/{rid}", response_model=RequirementResponse) + async def get_requirement(rid: str): + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT * FROM requirements WHERE requirement_rid = $1", rid) + if not row: + raise HTTPException(404, f"Requirement not found: {rid}") + return _row_to_response(RequirementResponse, row) + + @router.get("/requirements/", response_model=List[RequirementResponse]) + async def list_requirements( + scope: Optional[str] = Query(None), + scope_ref: Optional[str] = Query(None), + requirement_type: Optional[str] = Query(None), + severity: Optional[str] = Query(None), + active_only: bool = Query(True), + limit: int = Query(50, ge=1, le=200), + offset: int = Query(0, ge=0), + ): + async with pool.acquire() as conn: + conditions = [] + params: list = [] + i = 1 + if scope: + conditions.append(f"scope = ${i}") + params.append(scope) + i += 1 + if scope_ref: + conditions.append(f"scope_ref = ${i}") + params.append(scope_ref) + i += 1 + if requirement_type: + conditions.append(f"requirement_type = ${i}") + params.append(requirement_type) + i += 1 + if severity: + conditions.append(f"severity = ${i}") + params.append(severity) + i += 1 + if active_only: + conditions.append("active = TRUE") + + where = ("WHERE " + " AND ".join(conditions)) if conditions else "" + params.extend([limit, offset]) + rows = await conn.fetch(f""" + SELECT * FROM requirements {where} + ORDER BY created_at DESC + LIMIT ${i} OFFSET ${i+1} + """, *params) + return [_row_to_response(RequirementResponse, r) for r in rows] + + # ------------------------------------------------------------------ # + # Coverage CRUD # + # ------------------------------------------------------------------ # + + @router.post("/coverage/link", response_model=CoverageLinkResponse, status_code=201) + async def create_coverage_link(body: CoverageLinkRequest): + if body.coverage_type not in _VALID_COVERAGE_TYPES: + raise HTTPException(400, f"Invalid coverage_type: {body.coverage_type}") + if body.provenance and body.provenance not in _VALID_PROVENANCES: + raise HTTPException(400, f"Invalid provenance: {body.provenance}") + + rid = _coverage_rid(body.coverage_type, body.source_rid, body.target_rid) + valid_from = body.valid_from or datetime.now(timezone.utc) + + async with pool.acquire() as conn: + row = await conn.fetchrow(""" + INSERT INTO coverage_links ( + coverage_rid, coverage_type, source_rid, target_rid, + valid_from, valid_until, confidence, provenance, metadata + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (coverage_rid) DO UPDATE SET + valid_from = EXCLUDED.valid_from, + valid_until = EXCLUDED.valid_until, + confidence = EXCLUDED.confidence, + provenance = EXCLUDED.provenance, + metadata = EXCLUDED.metadata + RETURNING * + """, rid, body.coverage_type, body.source_rid, body.target_rid, + valid_from, body.valid_until, body.confidence, body.provenance, + json.dumps(body.metadata)) + return _row_to_response(CoverageLinkResponse, row) + + @router.get("/coverage/", response_model=List[CoverageLinkResponse]) + async def list_coverage( + target_rid: Optional[str] = Query(None, description="Filter by covered artifact"), + source_rid: Optional[str] = Query(None, description="Filter by covering artifact"), + coverage_type: Optional[str] = Query(None), + valid_only: bool = Query(True, description="Only return currently valid coverage"), + limit: int = Query(50, ge=1, le=200), + offset: int = Query(0, ge=0), + ): + async with pool.acquire() as conn: + conditions = [] + params: list = [] + i = 1 + if target_rid: + conditions.append(f"target_rid = ${i}") + params.append(target_rid) + i += 1 + if source_rid: + conditions.append(f"source_rid = ${i}") + params.append(source_rid) + i += 1 + if coverage_type: + conditions.append(f"coverage_type = ${i}") + params.append(coverage_type) + i += 1 + if valid_only: + conditions.append("valid_from <= NOW()") + conditions.append("(valid_until IS NULL OR valid_until > NOW())") + + where = ("WHERE " + " AND ".join(conditions)) if conditions else "" + params.extend([limit, offset]) + rows = await conn.fetch(f""" + SELECT * FROM coverage_links {where} + ORDER BY created_at DESC + LIMIT ${i} OFFSET ${i+1} + """, *params) + return [_row_to_response(CoverageLinkResponse, r) for r in rows] + + # ------------------------------------------------------------------ # + # Signals CRUD # + # ------------------------------------------------------------------ # + + @router.post("/signals/create", response_model=SignalResponse, status_code=201) + async def create_signal(body: SignalCreateRequest): + if body.signal_type not in _VALID_SIGNAL_TYPES: + raise HTTPException(400, f"Invalid signal_type: {body.signal_type}") + if body.scope not in _VALID_SCOPES: + raise HTTPException(400, f"Invalid scope: {body.scope}") + + rid = _signal_rid(body.signal_type, body.source_kind, body.source_ref, + body.statement, body.scope, body.subject_uri) + + async with pool.acquire() as conn: + row = await conn.fetchrow(""" + INSERT INTO signals ( + signal_rid, signal_type, source_kind, source_ref, + statement, scope, subject_uri, metadata, + confidence, fresh_until + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT (signal_rid) DO UPDATE SET + metadata = EXCLUDED.metadata, + confidence = EXCLUDED.confidence, + fresh_until = EXCLUDED.fresh_until + RETURNING * + """, rid, body.signal_type, body.source_kind, body.source_ref, + body.statement, body.scope, body.subject_uri, + json.dumps(body.metadata), body.confidence, body.fresh_until) + return _row_to_response(SignalResponse, row) + + @router.get("/signals/", response_model=List[SignalResponse]) + async def list_signals( + signal_type: Optional[str] = Query(None), + scope: Optional[str] = Query(None), + source_ref: Optional[str] = Query(None), + fresh_only: bool = Query(False, description="Only return signals that are not stale"), + limit: int = Query(50, ge=1, le=200), + offset: int = Query(0, ge=0), + ): + async with pool.acquire() as conn: + conditions = [] + params: list = [] + i = 1 + if signal_type: + conditions.append(f"signal_type = ${i}") + params.append(signal_type) + i += 1 + if scope: + conditions.append(f"scope = ${i}") + params.append(scope) + i += 1 + if source_ref: + conditions.append(f"source_ref = ${i}") + params.append(source_ref) + i += 1 + if fresh_only: + conditions.append("(fresh_until IS NULL OR fresh_until > NOW())") + + where = ("WHERE " + " AND ".join(conditions)) if conditions else "" + params.extend([limit, offset]) + rows = await conn.fetch(f""" + SELECT * FROM signals {where} + ORDER BY created_at DESC + LIMIT ${i} OFFSET ${i+1} + """, *params) + return [_row_to_response(SignalResponse, r) for r in rows] + + # ------------------------------------------------------------------ # + # Gap computation # + # ------------------------------------------------------------------ # + + @router.get("/pools/{pool_rid}/gaps", response_model=PoolGapsResponse) + async def compute_pool_gaps(pool_rid: str): + """Compute unmet and stale requirements for a commitment pool. + + For each active requirement scoped to this pool, checks coverage_links + for valid coverage. No valid coverage → gap. Stale coverage (valid_until + expired or freshness_window exceeded) → stale gap. + + Emits gap_computed signals for each gap found. + """ + async with pool.acquire() as conn: + # Verify pool exists + pool_row = await conn.fetchrow( + "SELECT pool_rid FROM commitment_pools WHERE pool_rid = $1", pool_rid) + if not pool_row: + raise HTTPException(404, f"Pool not found: {pool_rid}") + + # Fetch active requirements for this pool + requirements = await conn.fetch(""" + SELECT * FROM requirements + WHERE active = TRUE + AND scope = 'pool' + AND scope_ref = $1 + ORDER BY CASE severity + WHEN 'critical' THEN 0 + WHEN 'high' THEN 1 + WHEN 'medium' THEN 2 + WHEN 'low' THEN 3 + ELSE 4 END ASC, + created_at ASC + """, pool_rid) + + gaps: list = [] + covered_count = 0 + now = datetime.now(timezone.utc) + + for req in requirements: + req_rid = req["requirement_rid"] + + # Check valid coverage (only commitment_covers_requirement, + # must have started and not expired) + coverage = await conn.fetch(""" + SELECT * FROM coverage_links + WHERE target_rid = $1 + AND coverage_type = 'commitment_covers_requirement' + AND valid_from <= $2 + AND (valid_until IS NULL OR valid_until > $2) + ORDER BY valid_from DESC + """, req_rid, now) + + # Determine gap status + gap_type = None + latest_until = None + + if not coverage: + # No coverage at all — check if there was ever any (stale vs unmet) + # Only commitment_covers_requirement counts for pool gap history + expired = await conn.fetchrow(""" + SELECT valid_until FROM coverage_links + WHERE target_rid = $1 + AND coverage_type = 'commitment_covers_requirement' + AND valid_from <= $2 + ORDER BY valid_until DESC NULLS LAST + LIMIT 1 + """, req_rid, now) + if expired and expired["valid_until"]: + gap_type = "stale" + latest_until = expired["valid_until"] + else: + gap_type = "unmet" + else: + # Has valid coverage — check freshness window if recurrent + fw = req["freshness_window_days"] + if fw: + latest_from = max(c["valid_from"] for c in coverage) + if now - latest_from > timedelta(days=fw): + gap_type = "stale" + latest_until = latest_from + timedelta(days=fw) + + if gap_type: + next_move = _next_move(req["severity"], req["scope"]) + + # Emit a gap_computed signal + statement = ( + f"{req['requirement_type'].title()} gap: {req['statement']} " + f"({gap_type}, severity={req['severity']})" + ) + sig_rid = _signal_rid("gap_computed", "gap_computation", pool_rid, + statement, "pool", req["subject_uri"]) + sig_meta = json.dumps({ + "requirement_rid": req_rid, + "gap_type": gap_type, + "next_move": next_move, + "coverage_count": len(coverage), + "computed_at": now.isoformat(), + }) + await conn.execute(""" + INSERT INTO signals ( + signal_rid, signal_type, source_kind, source_ref, + statement, scope, subject_uri, metadata, confidence + ) VALUES ($1, 'gap_computed', 'gap_computation', $2, $3, 'pool', $4, $5, 1.0) + ON CONFLICT (signal_rid) DO UPDATE SET + metadata = EXCLUDED.metadata, + confidence = EXCLUDED.confidence + """, sig_rid, pool_rid, statement, req["subject_uri"], sig_meta) + + gaps.append(GapSignalResponse( + requirement_rid=req_rid, + requirement_statement=req["statement"], + requirement_type=req["requirement_type"], + severity=req["severity"], + frequency=req["frequency"], + freshness_window_days=req["freshness_window_days"], + gap_type=gap_type, + coverage_count=len(coverage), + latest_coverage_until=latest_until, + signal_rid=sig_rid, + next_move=next_move, + )) + else: + covered_count += 1 + + unmet = sum(1 for g in gaps if g.gap_type == "unmet") + stale = sum(1 for g in gaps if g.gap_type == "stale") + + return PoolGapsResponse( + pool_rid=pool_rid, + total_requirements=len(requirements), + unmet_count=unmet, + stale_count=stale, + covered_count=covered_count, + gaps=gaps, + ) + + return router diff --git a/migrations/079_requirements.sql b/migrations/079_requirements.sql new file mode 100644 index 0000000..64f20f0 --- /dev/null +++ b/migrations/079_requirements.sql @@ -0,0 +1,47 @@ +-- 079_requirements.sql +-- Normative expectations with cadence — what a pool/org/federation constitution +-- says should be true. Generic across scopes; pool requirements are the first +-- use case but the table serves any normative artifact. +-- +-- Part of the Claims × Spore protocol layer (additive, no rewrites). + +CREATE TABLE IF NOT EXISTS requirements ( + id SERIAL PRIMARY KEY, + requirement_rid TEXT UNIQUE NOT NULL, -- orn:koi-net.requirement: + scope TEXT NOT NULL, -- personal | team | pool | org | federation | on_chain_group + scope_ref TEXT, -- URI of the scoped entity (pool_rid, org_uri, etc.) + policy_source TEXT NOT NULL, -- URI of the policy/constitution that declares this + requirement_type TEXT NOT NULL + CHECK (requirement_type IN ('monitoring', 'reporting', 'stewardship', 'governance', 'contribution')), + statement TEXT NOT NULL, -- Human-readable requirement + subject_uri TEXT, -- entity_registry.fuseki_uri of the entity this concerns + frequency TEXT + CHECK (frequency IS NULL OR frequency IN ('once', 'weekly', 'monthly', 'quarterly', 'annual')), + freshness_window_days INTEGER, -- How many days before coverage becomes stale + severity TEXT NOT NULL DEFAULT 'medium' + CHECK (severity IN ('low', 'medium', 'high', 'critical')), + active BOOLEAN NOT NULL DEFAULT TRUE, + metadata JSONB DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Predicates for requirement entities +INSERT INTO allowed_predicates (predicate, description, subject_types, object_types) +VALUES + ('requires', 'Pool, org, or policy declares a requirement', + ARRAY['CommitmentPool', 'Organization', 'Policy', 'SpecDoc'], + ARRAY['Requirement']) +ON CONFLICT (predicate) DO NOTHING; + +-- Indexes +CREATE INDEX IF NOT EXISTS idx_requirements_scope ON requirements(scope); +CREATE INDEX IF NOT EXISTS idx_requirements_scope_ref ON requirements(scope_ref); +CREATE INDEX IF NOT EXISTS idx_requirements_type ON requirements(requirement_type); +CREATE INDEX IF NOT EXISTS idx_requirements_active ON requirements(active) WHERE active = TRUE; +CREATE INDEX IF NOT EXISTS idx_requirements_severity ON requirements(severity); +CREATE INDEX IF NOT EXISTS idx_requirements_subject ON requirements(subject_uri); + +-- Register migration +INSERT INTO koi_migrations (migration_id, checksum) +VALUES ('personal:079_requirements', 'v1_protocol_layer') +ON CONFLICT (migration_id) DO NOTHING; diff --git a/migrations/080_scope_on_commitments.sql b/migrations/080_scope_on_commitments.sql new file mode 100644 index 0000000..7edd1c1 --- /dev/null +++ b/migrations/080_scope_on_commitments.sql @@ -0,0 +1,26 @@ +-- 080_scope_on_commitments.sql +-- Add first-class scope column to commitments and commitment_pools. +-- Scope determines governance path — too important for metadata. +-- +-- Part of the Claims × Spore protocol layer (additive, no rewrites). + +-- Add scope to commitments (default 'pool' — backward compatible) +DO $$ BEGIN + ALTER TABLE commitments ADD COLUMN scope TEXT DEFAULT 'pool'; +EXCEPTION WHEN duplicate_column THEN NULL; +END $$; + +-- Add scope to commitment_pools (default 'pool' — backward compatible) +DO $$ BEGIN + ALTER TABLE commitment_pools ADD COLUMN scope TEXT DEFAULT 'pool'; +EXCEPTION WHEN duplicate_column THEN NULL; +END $$; + +-- Index for scope queries +CREATE INDEX IF NOT EXISTS idx_commitments_scope ON commitments(scope); +CREATE INDEX IF NOT EXISTS idx_pools_scope ON commitment_pools(scope); + +-- Register migration +INSERT INTO koi_migrations (migration_id, checksum) +VALUES ('personal:080_scope_on_commitments', 'v1_protocol_layer') +ON CONFLICT (migration_id) DO NOTHING; diff --git a/migrations/081_coverage_links.sql b/migrations/081_coverage_links.sql new file mode 100644 index 0000000..a2e440e --- /dev/null +++ b/migrations/081_coverage_links.sql @@ -0,0 +1,56 @@ +-- 081_coverage_links.sql +-- Explicit relational primitive for gap computation. +-- Coverage links connect artifacts that satisfy requirements: +-- commitment covers requirement +-- claim covers condition +-- evidence covers commitment +-- +-- Gap computation: for each active requirement, query coverage_links where +-- target_rid = requirement_rid AND valid_until > now(). No valid coverage → gap. +-- +-- Part of the Claims × Spore protocol layer (additive, no rewrites). + +CREATE TABLE IF NOT EXISTS coverage_links ( + id SERIAL PRIMARY KEY, + coverage_rid TEXT UNIQUE NOT NULL, -- deterministic RID + coverage_type TEXT NOT NULL + CHECK (coverage_type IN ( + 'commitment_covers_requirement', + 'claim_covers_condition', + 'evidence_covers_commitment' + )), + source_rid TEXT NOT NULL, -- RID of the covering artifact + target_rid TEXT NOT NULL, -- RID of the covered artifact + valid_from TIMESTAMPTZ NOT NULL DEFAULT NOW(), + valid_until TIMESTAMPTZ, -- NULL = open-ended; computed from freshness_window if recurrent + confidence FLOAT, -- 0–1 + provenance TEXT -- manual | ai_inferred | policy_rule + CHECK (provenance IS NULL OR provenance IN ('manual', 'ai_inferred', 'policy_rule')), + metadata JSONB DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Predicates for coverage relationships +INSERT INTO allowed_predicates (predicate, description, subject_types, object_types) +VALUES + ('covers', 'Artifact satisfies or fulfills a requirement, condition, or commitment', + ARRAY['Commitment', 'Claim', 'Evidence'], + ARRAY['Requirement', 'Commitment', 'Condition']) +ON CONFLICT (predicate) DO NOTHING; + +-- Indexes for gap computation +CREATE INDEX IF NOT EXISTS idx_coverage_target ON coverage_links(target_rid); +CREATE INDEX IF NOT EXISTS idx_coverage_source ON coverage_links(source_rid); +CREATE INDEX IF NOT EXISTS idx_coverage_type ON coverage_links(coverage_type); +-- Note: partial indexes with NOW() are not allowed (must be immutable). +-- Use a plain composite index instead; gap queries filter at runtime. +CREATE INDEX IF NOT EXISTS idx_coverage_valid ON coverage_links(target_rid, valid_until); + +-- Composite for gap queries: "find all valid coverage for this requirement" +CREATE INDEX IF NOT EXISTS idx_coverage_gap_check + ON coverage_links(target_rid, coverage_type, valid_until); + +-- Register migration +INSERT INTO koi_migrations (migration_id, checksum) +VALUES ('personal:081_coverage_links', 'v1_protocol_layer') +ON CONFLICT (migration_id) DO NOTHING; diff --git a/migrations/082_signals.sql b/migrations/082_signals.sql new file mode 100644 index 0000000..d7ac53c --- /dev/null +++ b/migrations/082_signals.sql @@ -0,0 +1,48 @@ +-- 082_signals.sql +-- Raw observations, declarations, discourse tensions, or computed gaps. +-- Signals are pre-intent: not yet interpreted into directional action. +-- The gap_computed type is emitted by the negative-space intelligence engine +-- when a requirement has no valid coverage. +-- +-- Part of the Claims × Spore protocol layer (additive, no rewrites). + +CREATE TABLE IF NOT EXISTS signals ( + id SERIAL PRIMARY KEY, + signal_rid TEXT UNIQUE NOT NULL, -- orn:koi-net.signal: + signal_type TEXT NOT NULL + CHECK (signal_type IN ('declaration', 'discourse', 'gap_computed', 'sensor', 'document_extract')), + source_kind TEXT NOT NULL, -- What produced it (transcript, gap_computation, sensor, etc.) + source_ref TEXT, -- URI of source document/pool/computation + statement TEXT NOT NULL, -- Human-readable description + scope TEXT NOT NULL, -- personal | team | pool | org | federation | on_chain_group + subject_uri TEXT, -- entity_registry.fuseki_uri + metadata JSONB DEFAULT '{}'::jsonb, + confidence FLOAT, -- 0–1, NULL if not applicable + fresh_until TIMESTAMPTZ, -- When this signal becomes stale + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Junction table: signals → intents (many-to-many) +-- Links to the existing intent_registry (migration 074) +CREATE TABLE IF NOT EXISTS signal_intents ( + id SERIAL PRIMARY KEY, + signal_rid TEXT NOT NULL, -- signals.signal_rid + intent_rid TEXT NOT NULL, -- intent_registry.intent_rid + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE (signal_rid, intent_rid) +); + +-- Indexes +CREATE INDEX IF NOT EXISTS idx_signals_type ON signals(signal_type); +CREATE INDEX IF NOT EXISTS idx_signals_scope ON signals(scope); +CREATE INDEX IF NOT EXISTS idx_signals_source_ref ON signals(source_ref); +CREATE INDEX IF NOT EXISTS idx_signals_subject ON signals(subject_uri); +CREATE INDEX IF NOT EXISTS idx_signals_fresh ON signals(fresh_until) + WHERE fresh_until IS NOT NULL; +CREATE INDEX IF NOT EXISTS idx_signal_intents_sig ON signal_intents(signal_rid); +CREATE INDEX IF NOT EXISTS idx_signal_intents_int ON signal_intents(intent_rid); + +-- Register migration +INSERT INTO koi_migrations (migration_id, checksum) +VALUES ('personal:082_signals', 'v1_protocol_layer') +ON CONFLICT (migration_id) DO NOTHING; diff --git a/scripts/bakeoff_chat_answers.py b/scripts/bakeoff_chat_answers.py new file mode 100644 index 0000000..7dcff8b --- /dev/null +++ b/scripts/bakeoff_chat_answers.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +""" +P2 Chat-Answer Provider Bakeoff — run both providers on frozen prompt packets. + +Usage: + # Generate answers from frozen packets (needs OPENAI_API_KEY + ANTHROPIC_API_KEY) + python scripts/bakeoff_chat_answers.py + + # Custom paths + python scripts/bakeoff_chat_answers.py \ + --packets tests/eval/results/prompt_packets.jsonl \ + --outdir tests/eval/results +""" + +import argparse +import asyncio +import json +import os +import sys +from pathlib import Path + +# Add project root to path +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from api.chat_provider import OpenAIChatProvider, AnthropicChatProvider + + +TEMPERATURE = 0.3 +MAX_TOKENS = 1024 + + +async def run_provider(provider, packets: list[dict]) -> list[dict]: + """Run a provider against all packets, return answers.""" + results = [] + for packet in packets: + messages = [ + {"role": "system", "content": packet["system_prompt"]}, + {"role": "user", "content": packet["user_prompt"]}, + ] + answer = await provider.complete( + messages, temperature=TEMPERATURE, max_tokens=MAX_TOKENS + ) + results.append({ + "id": packet["id"], + "question": packet["question"], + "answer": answer, + }) + print(f" {packet['id']}: {len(answer)} chars", file=sys.stderr) + return results + + +def generate_comparison(packets, openai_answers, anthropic_answers, outpath: Path): + """Generate comparison.md with side-by-side answers and scoring template.""" + oai_map = {a["id"]: a["answer"] for a in openai_answers} + ant_map = {a["id"]: a["answer"] for a in anthropic_answers} + + lines = [ + "# P2 Chat-Answer Bakeoff — Comparison", + "", + f"**Packets:** {len(packets)} frozen prompt packets from Octo", + f"**OpenAI model:** {os.getenv('CHAT_MODEL', os.getenv('CHAT_LLM_MODEL', 'gpt-4o-mini'))}", + f"**Anthropic model:** {os.getenv('ANTHROPIC_CHAT_MODEL', 'claude-sonnet-4-20250514')}", + f"**Temperature:** {TEMPERATURE}, **Max tokens:** {MAX_TOKENS}", + "", + "## Scoring Rubric (1-5 per dimension)", + "", + "| Dimension | Description |", + "|-----------|-------------|", + "| Groundedness | Does the answer use the provided context? |", + "| Completeness | Are key entities/relationships mentioned? |", + "| Citation | Does it reference sources/wiki links? |", + "| Concision | Is it appropriately brief? |", + "| Hallucination risk | Does it invent facts not in context? (5=no hallucination) |", + "", + "---", + "", + ] + + for packet in packets: + qid = packet["id"] + lines.extend([ + f"## {qid}: {packet['question']}", + "", + f"**Sources:** {len(packet.get('sources', []))} entities/docs", + "", + "### OpenAI", + "", + oai_map.get(qid, "(missing)"), + "", + "### Anthropic", + "", + ant_map.get(qid, "(missing)"), + "", + "### Scores", + "", + "| Dimension | OpenAI | Anthropic |", + "|-----------|--------|-----------|", + "| Groundedness | | |", + "| Completeness | | |", + "| Citation | | |", + "| Concision | | |", + "| Hallucination risk | | |", + "| **Preferred** | | |", + "", + "---", + "", + ]) + + lines.extend([ + "## Summary", + "", + "| Question | Preferred |", + "|----------|-----------|", + ]) + for packet in packets: + lines.append(f"| {packet['id']} | |") + lines.extend([ + "", + "**Overall recommendation:** ", + "", + ]) + + outpath.write_text("\n".join(lines)) + print(f" Wrote {outpath}", file=sys.stderr) + + +async def main(): + parser = argparse.ArgumentParser(description="P2 chat-answer bakeoff") + parser.add_argument("--packets", default="tests/eval/results/prompt_packets.jsonl") + parser.add_argument("--outdir", default="tests/eval/results") + args = parser.parse_args() + + packets_path = Path(args.packets) + outdir = Path(args.outdir) + outdir.mkdir(parents=True, exist_ok=True) + + # Load packets + packets = [json.loads(line) for line in packets_path.read_text().strip().split("\n")] + print(f"Loaded {len(packets)} packets from {packets_path}", file=sys.stderr) + + # Init providers + openai_key = os.environ.get("OPENAI_API_KEY") + anthropic_key = os.environ.get("ANTHROPIC_API_KEY") + if not openai_key: + print("ERROR: OPENAI_API_KEY not set", file=sys.stderr) + sys.exit(1) + if not anthropic_key: + print("ERROR: ANTHROPIC_API_KEY not set", file=sys.stderr) + sys.exit(1) + + openai_model = os.getenv("CHAT_MODEL", os.getenv("CHAT_LLM_MODEL", "gpt-4o-mini")) + anthropic_model = os.getenv("ANTHROPIC_CHAT_MODEL", "claude-sonnet-4-20250514") + + openai_provider = OpenAIChatProvider(api_key=openai_key, default_model=openai_model) + anthropic_provider = AnthropicChatProvider(api_key=anthropic_key, default_model=anthropic_model) + + # Run OpenAI + print(f"\nRunning OpenAI ({openai_model})...", file=sys.stderr) + openai_answers = await run_provider(openai_provider, packets) + oai_path = outdir / "answers-openai.jsonl" + oai_path.write_text("\n".join(json.dumps(a) for a in openai_answers) + "\n") + print(f" Wrote {oai_path}", file=sys.stderr) + + # Run Anthropic + print(f"\nRunning Anthropic ({anthropic_model})...", file=sys.stderr) + anthropic_answers = await run_provider(anthropic_provider, packets) + ant_path = outdir / "answers-anthropic.jsonl" + ant_path.write_text("\n".join(json.dumps(a) for a in anthropic_answers) + "\n") + print(f" Wrote {ant_path}", file=sys.stderr) + + # Generate comparison + print("\nGenerating comparison...", file=sys.stderr) + generate_comparison(packets, openai_answers, anthropic_answers, outdir / "comparison.md") + + print("\nDone. Review tests/eval/results/comparison.md", file=sys.stderr) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_chat_provider.py b/tests/test_chat_provider.py new file mode 100644 index 0000000..a0905de --- /dev/null +++ b/tests/test_chat_provider.py @@ -0,0 +1,357 @@ +"""Unit tests for chat_provider.py — all mocked, no API key needed.""" + +import json +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from api.chat_provider import ( + ChatProvider, + OpenAIChatProvider, + AnthropicChatProvider, + create_classifier_provider, + create_chat_provider, + create_expansion_provider, +) + + +# --------------------------------------------------------------------------- +# Helper: build a provider with a mocked internal client +# --------------------------------------------------------------------------- + +def _make_openai_provider(default_model="gpt-4o-mini"): + with patch("openai.OpenAI"): + provider = OpenAIChatProvider(api_key="fake-key", default_model=default_model) + provider._client = MagicMock() + return provider + + +def _make_anthropic_provider(default_model="claude-sonnet-4-20250514"): + with patch("anthropic.Anthropic"): + provider = AnthropicChatProvider(api_key="fake-key", default_model=default_model) + provider._client = MagicMock() + return provider + + +# --------------------------------------------------------------------------- +# OpenAIChatProvider +# --------------------------------------------------------------------------- + +class TestOpenAIChatProvider: + + @pytest.mark.asyncio + async def test_complete_plain_text(self): + """Basic completion returns content string.""" + provider = _make_openai_provider() + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=MagicMock(content="Hello world"))] + provider._client.chat.completions.create.return_value = mock_response + + async def _fake_to_thread(fn, **kw): + return fn(**kw) + + with patch("asyncio.to_thread", side_effect=_fake_to_thread): + result = await provider.complete( + [{"role": "user", "content": "Hi"}], + temperature=0.5, + ) + + assert result == "Hello world" + call_kwargs = provider._client.chat.completions.create.call_args[1] + assert call_kwargs["model"] == "gpt-4o-mini" + assert call_kwargs["temperature"] == 0.5 + assert "response_format" not in call_kwargs + + @pytest.mark.asyncio + async def test_complete_json_mode(self): + """json_mode=True passes response_format to OpenAI.""" + provider = _make_openai_provider() + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=MagicMock(content='{"key": "val"}'))] + provider._client.chat.completions.create.return_value = mock_response + + async def _fake_to_thread(fn, **kw): + return fn(**kw) + + with patch("asyncio.to_thread", side_effect=_fake_to_thread): + result = await provider.complete( + [{"role": "system", "content": "Return JSON"}, {"role": "user", "content": "Go"}], + json_mode=True, + ) + + assert json.loads(result) == {"key": "val"} + call_kwargs = provider._client.chat.completions.create.call_args[1] + assert call_kwargs["response_format"] == {"type": "json_object"} + + @pytest.mark.asyncio + async def test_model_override(self): + """Explicit model kwarg overrides default_model.""" + provider = _make_openai_provider(default_model="gpt-4o-mini") + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=MagicMock(content="ok"))] + provider._client.chat.completions.create.return_value = mock_response + + async def _fake_to_thread(fn, **kw): + return fn(**kw) + + with patch("asyncio.to_thread", side_effect=_fake_to_thread): + await provider.complete( + [{"role": "user", "content": "test"}], + model="gpt-4o", + ) + + call_kwargs = provider._client.chat.completions.create.call_args[1] + assert call_kwargs["model"] == "gpt-4o" + + +# --------------------------------------------------------------------------- +# AnthropicChatProvider +# --------------------------------------------------------------------------- + +class TestAnthropicChatProvider: + + @pytest.mark.asyncio + async def test_complete_extracts_system_message(self): + """System message is passed as system= kwarg, not in messages list.""" + provider = _make_anthropic_provider() + mock_response = MagicMock() + mock_response.content = [MagicMock(text="Response")] + provider._client.messages.create.return_value = mock_response + + async def _fake_to_thread(fn, **kw): + return fn(**kw) + + with patch("asyncio.to_thread", side_effect=_fake_to_thread): + result = await provider.complete([ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hi"}, + ]) + + assert result == "Response" + call_kwargs = provider._client.messages.create.call_args[1] + assert call_kwargs["system"] == "You are helpful" + assert len(call_kwargs["messages"]) == 1 + assert call_kwargs["messages"][0]["role"] == "user" + + @pytest.mark.asyncio + async def test_json_mode_appends_instruction(self): + """json_mode=True appends JSON instruction to system prompt.""" + provider = _make_anthropic_provider() + mock_response = MagicMock() + mock_response.content = [MagicMock(text='{"x": 1}')] + provider._client.messages.create.return_value = mock_response + + async def _fake_to_thread(fn, **kw): + return fn(**kw) + + with patch("asyncio.to_thread", side_effect=_fake_to_thread): + await provider.complete( + [{"role": "system", "content": "Be brief"}, {"role": "user", "content": "Go"}], + json_mode=True, + ) + + call_kwargs = provider._client.messages.create.call_args[1] + assert "Respond ONLY with valid JSON" in call_kwargs["system"] + assert call_kwargs["system"].startswith("Be brief") + + @pytest.mark.asyncio + async def test_json_mode_no_system_message(self): + """json_mode=True with no system message creates one.""" + provider = _make_anthropic_provider() + mock_response = MagicMock() + mock_response.content = [MagicMock(text='{}')] + provider._client.messages.create.return_value = mock_response + + async def _fake_to_thread(fn, **kw): + return fn(**kw) + + with patch("asyncio.to_thread", side_effect=_fake_to_thread): + await provider.complete( + [{"role": "user", "content": "Go"}], + json_mode=True, + ) + + call_kwargs = provider._client.messages.create.call_args[1] + assert call_kwargs["system"] == "Respond ONLY with valid JSON." + + +# --------------------------------------------------------------------------- +# Factory tests +# --------------------------------------------------------------------------- + +class TestFactories: + + def test_classifier_defaults_openai(self): + """create_classifier_provider() defaults to OpenAI gpt-4o.""" + with patch.dict("os.environ", {"OPENAI_API_KEY": "fake"}, clear=False), \ + patch("openai.OpenAI"): + p = create_classifier_provider() + assert isinstance(p, OpenAIChatProvider) + assert p.default_model == "gpt-4o" + + def test_classifier_rejects_anthropic(self): + """Classifier does not support Anthropic.""" + with patch.dict("os.environ", {"CLASSIFIER_PROVIDER": "anthropic", "ANTHROPIC_API_KEY": "fake"}, clear=False): + with pytest.raises(ValueError, match="Unsupported CLASSIFIER_PROVIDER"): + create_classifier_provider() + + def test_classifier_missing_key_raises(self): + """Missing OPENAI_API_KEY raises ValueError.""" + with patch.dict("os.environ", {"OPENAI_API_KEY": ""}, clear=False): + with pytest.raises(ValueError, match="OPENAI_API_KEY"): + create_classifier_provider() + + def test_chat_provider_openai_default(self): + """create_chat_provider() defaults to OpenAI.""" + with patch.dict("os.environ", {"OPENAI_API_KEY": "fake", "CHAT_LLM_MODEL": "gpt-4o-mini"}, clear=False), \ + patch("openai.OpenAI"): + p = create_chat_provider() + assert isinstance(p, OpenAIChatProvider) + assert p.default_model == "gpt-4o-mini" + + def test_chat_provider_anthropic(self): + """CHAT_PROVIDER=anthropic creates AnthropicChatProvider.""" + with patch.dict("os.environ", { + "CHAT_PROVIDER": "anthropic", + "ANTHROPIC_API_KEY": "fake", + "CHAT_MODEL": "claude-sonnet-4-20250514", + }, clear=False), \ + patch("anthropic.Anthropic"): + p = create_chat_provider() + assert isinstance(p, AnthropicChatProvider) + + def test_chat_provider_unsupported_raises(self): + """Unknown CHAT_PROVIDER raises ValueError.""" + with patch.dict("os.environ", {"CHAT_PROVIDER": "gemini"}, clear=False): + with pytest.raises(ValueError, match="Unsupported CHAT_PROVIDER"): + create_chat_provider() + + def test_expansion_defaults_openai(self): + """create_expansion_provider() defaults to OpenAI gpt-4o-mini.""" + with patch.dict("os.environ", {"OPENAI_API_KEY": "fake"}, clear=False), \ + patch("openai.OpenAI"): + p = create_expansion_provider() + assert isinstance(p, OpenAIChatProvider) + assert p.default_model == "gpt-4o-mini" + + def test_expansion_rejects_anthropic(self): + """Expansion does not support Anthropic.""" + with patch.dict("os.environ", {"EXPANSION_PROVIDER": "anthropic", "ANTHROPIC_API_KEY": "fake"}, clear=False): + with pytest.raises(ValueError, match="Unsupported EXPANSION_PROVIDER"): + create_expansion_provider() + + def test_chat_model_env_precedence(self): + """CHAT_MODEL takes precedence over CHAT_LLM_MODEL.""" + with patch.dict("os.environ", { + "OPENAI_API_KEY": "fake", + "CHAT_MODEL": "gpt-4o", + "CHAT_LLM_MODEL": "gpt-4o-mini", + }, clear=False), \ + patch("openai.OpenAI"): + p = create_chat_provider() + assert p.default_model == "gpt-4o" + + def test_classifier_custom_model(self): + """CLASSIFIER_MODEL overrides default gpt-4o.""" + with patch.dict("os.environ", { + "OPENAI_API_KEY": "fake", + "CLASSIFIER_MODEL": "gpt-4.1", + }, clear=False), \ + patch("openai.OpenAI"): + p = create_classifier_provider() + assert p.default_model == "gpt-4.1" + + +# --------------------------------------------------------------------------- +# Endpoint-level: Anthropic provider model respected end-to-end +# --------------------------------------------------------------------------- + +class TestEndToEndProviderModel: + + @pytest.mark.asyncio + async def test_anthropic_chat_provider_uses_chat_model(self): + """CHAT_PROVIDER=anthropic + CHAT_MODEL=claude-sonnet-4-20250514 + -> provider.complete() called with the Anthropic model, not gpt-4o-mini.""" + from api.schemas.query_plan import EvidenceBundle, RetrievalOp, SourceType + + entity_bundle = EvidenceBundle( + source_uri="urn:e:1", source_type=SourceType.LOCAL_AUTHORITATIVE, + retrieval_op=RetrievalOp.ENTITY_LOOKUP, confidence=0.9, + text="Mock entity", + metadata={"entity_type": "Concept", "label": "Test", "fuseki_uri": "urn:e:1"}, + ) + + mock_classifier = AsyncMock() + mock_expansion = AsyncMock() + + # Chat answer provider: track the model it receives + mock_chat = AsyncMock() + mock_chat.complete = AsyncMock(return_value="Anthropic answer") + mock_chat.default_model = "claude-sonnet-4-20250514" + + mock_cm = AsyncMock() + mock_conn = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_conn) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_pool = MagicMock() + mock_pool.acquire.return_value = mock_cm + + with patch("api.personal_ingest_api.db_pool", mock_pool), \ + patch("api.personal_ingest_api.classifier_provider", mock_classifier), \ + patch("api.personal_ingest_api.chat_answer_provider", mock_chat), \ + patch("api.personal_ingest_api.expansion_provider", mock_expansion), \ + patch("api.personal_ingest_api.generate_embedding", AsyncMock(return_value=[0.1]*1536)), \ + patch("api.personal_ingest_api._try_structured_graph_query", AsyncMock(return_value="")), \ + patch("api.retrieval_executors.entity_lookup", AsyncMock(return_value=[entity_bundle])), \ + patch("api.retrieval_executors.relationship_traverse", AsyncMock(return_value=[])), \ + patch("api.retrieval_executors.text_search", AsyncMock(return_value=[])), \ + patch("api.retrieval_executors.web_source_lookup", AsyncMock(return_value=[])): + + from api.personal_ingest_api import chat_endpoint, ChatRequest + response = await chat_endpoint(ChatRequest(query="What is eelgrass?")) + + assert response["answer"] == "Anthropic answer" + # The call should NOT pass model= (uses provider default_model) + call_kwargs = mock_chat.complete.call_args + assert "model" not in call_kwargs.kwargs or call_kwargs.kwargs.get("model") is None + + +# --------------------------------------------------------------------------- +# Live smoke tests (requires API keys, skipped when absent) +# --------------------------------------------------------------------------- + +import os + +_has_openai_key = bool(os.getenv("OPENAI_API_KEY")) +_has_anthropic_key = bool(os.getenv("ANTHROPIC_API_KEY")) + + +@pytest.mark.live +@pytest.mark.skipif(not _has_openai_key, reason="OPENAI_API_KEY not set") +@pytest.mark.asyncio +async def test_openai_live_smoke(): + """Quick OpenAI round-trip.""" + provider = create_chat_provider() + result = await provider.complete( + [{"role": "user", "content": "Say 'hello' and nothing else."}], + max_tokens=10, + ) + assert "hello" in result.lower() + + +@pytest.mark.live +@pytest.mark.skipif(not _has_anthropic_key, reason="ANTHROPIC_API_KEY not set") +@pytest.mark.asyncio +async def test_anthropic_live_smoke(): + """Quick Anthropic round-trip (needs ANTHROPIC_API_KEY).""" + os.environ["CHAT_PROVIDER"] = "anthropic" + try: + provider = create_chat_provider() + result = await provider.complete( + [{"role": "user", "content": "Say 'hello' and nothing else."}], + max_tokens=10, + ) + assert "hello" in result.lower() + finally: + os.environ.pop("CHAT_PROVIDER", None) diff --git a/tests/test_chat_retrieval.py b/tests/test_chat_retrieval.py index 776643d..5ac74d1 100644 --- a/tests/test_chat_retrieval.py +++ b/tests/test_chat_retrieval.py @@ -442,11 +442,9 @@ async def test_chat_endpoint_calls_executors_and_adapter(self): mock_text = AsyncMock(return_value=[text_bundle]) mock_web = AsyncMock(return_value=[web_bundle]) - # Mock the LLM (wrap in asyncio.to_thread mock) - mock_openai = MagicMock() - mock_completion = MagicMock() - mock_completion.choices = [MagicMock(message=MagicMock(content="Mock answer about eelgrass"))] - mock_openai.chat.completions.create.return_value = mock_completion + # Mock the chat answer provider + mock_chat_provider = AsyncMock() + mock_chat_provider.complete = AsyncMock(return_value="Mock answer about eelgrass") # Mock DB pool with proper async context manager mock_conn = AsyncMock() @@ -462,16 +460,12 @@ async def test_chat_endpoint_calls_executors_and_adapter(self): # Mock _try_structured_graph_query mock_graph_query = AsyncMock(return_value="") - # Mock asyncio.to_thread so the OpenAI call doesn't need a real thread - async def _fake_to_thread(fn, *args, **kwargs): - return fn(*args, **kwargs) - with patch("api.personal_ingest_api.db_pool", mock_pool), \ - patch("api.personal_ingest_api.openai_client", mock_openai), \ - patch("api.personal_ingest_api.OPENAI_API_KEY", "fake-key"), \ + patch("api.personal_ingest_api.classifier_provider", mock_chat_provider), \ + patch("api.personal_ingest_api.chat_answer_provider", mock_chat_provider), \ + patch("api.personal_ingest_api.expansion_provider", mock_chat_provider), \ patch("api.personal_ingest_api.generate_embedding", mock_embed), \ patch("api.personal_ingest_api._try_structured_graph_query", mock_graph_query), \ - patch("asyncio.to_thread", side_effect=_fake_to_thread), \ patch("api.retrieval_executors.entity_lookup", mock_entity), \ patch("api.retrieval_executors.relationship_traverse", mock_rel), \ patch("api.retrieval_executors.text_search", mock_text), \ @@ -508,3 +502,90 @@ async def _fake_to_thread(fn, *args, **kwargs): assert "intent" in response assert response["answer"] == "Mock answer about eelgrass" assert len(response["sources"]) > 0 + + +# --------------------------------------------------------------------------- +# debug_prompt gating tests +# --------------------------------------------------------------------------- + +def _chat_endpoint_patches(): + """Common patches for chat endpoint tests. Returns a context manager stack.""" + from api.schemas.query_plan import EvidenceBundle, RetrievalOp, SourceType + + entity_bundle = EvidenceBundle( + source_uri="urn:e:1", source_type=SourceType.LOCAL_AUTHORITATIVE, + retrieval_op=RetrievalOp.ENTITY_LOOKUP, confidence=0.9, + text="Mock entity", + metadata={"entity_type": "Concept", "label": "Test", "fuseki_uri": "urn:e:1"}, + ) + + mock_chat_provider = AsyncMock() + mock_chat_provider.complete = AsyncMock(return_value="Mock answer") + + mock_cm = AsyncMock() + mock_conn = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_conn) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_pool = MagicMock() + mock_pool.acquire.return_value = mock_cm + + return { + "api.personal_ingest_api.db_pool": mock_pool, + "api.personal_ingest_api.classifier_provider": mock_chat_provider, + "api.personal_ingest_api.chat_answer_provider": mock_chat_provider, + "api.personal_ingest_api.expansion_provider": mock_chat_provider, + "api.personal_ingest_api.generate_embedding": AsyncMock(return_value=[0.1] * 1536), + "api.personal_ingest_api._try_structured_graph_query": AsyncMock(return_value=""), + "api.retrieval_executors.entity_lookup": AsyncMock(return_value=[entity_bundle]), + "api.retrieval_executors.relationship_traverse": AsyncMock(return_value=[]), + "api.retrieval_executors.text_search": AsyncMock(return_value=[]), + "api.retrieval_executors.web_source_lookup": AsyncMock(return_value=[]), + } + + +class TestDebugPromptGating: + + @pytest.mark.asyncio + async def test_debug_prompt_absent_by_default(self): + """Default request has no _debug_prompt in response.""" + patches = _chat_endpoint_patches() + import contextlib + with contextlib.ExitStack() as stack: + for target, mock_obj in patches.items(): + stack.enter_context(patch(target, mock_obj)) + from api.personal_ingest_api import chat_endpoint, ChatRequest + response = await chat_endpoint(ChatRequest(query="test")) + assert "_debug_prompt" not in response + + @pytest.mark.asyncio + async def test_debug_prompt_gated_by_env(self): + """debug_prompt=True but no CHAT_DEBUG_PROMPT env -> no _debug_prompt.""" + patches = _chat_endpoint_patches() + import contextlib + with contextlib.ExitStack() as stack: + for target, mock_obj in patches.items(): + stack.enter_context(patch(target, mock_obj)) + stack.enter_context(patch.dict("os.environ", {}, clear=False)) + # Ensure CHAT_DEBUG_PROMPT is not set + import os + os.environ.pop("CHAT_DEBUG_PROMPT", None) + from api.personal_ingest_api import chat_endpoint, ChatRequest + response = await chat_endpoint(ChatRequest(query="test", debug_prompt=True)) + assert "_debug_prompt" not in response + + @pytest.mark.asyncio + async def test_debug_prompt_present_when_gated(self): + """debug_prompt=True + CHAT_DEBUG_PROMPT=1 -> _debug_prompt with prompts.""" + patches = _chat_endpoint_patches() + import contextlib + with contextlib.ExitStack() as stack: + for target, mock_obj in patches.items(): + stack.enter_context(patch(target, mock_obj)) + stack.enter_context(patch.dict("os.environ", {"CHAT_DEBUG_PROMPT": "1"}, clear=False)) + from api.personal_ingest_api import chat_endpoint, ChatRequest + response = await chat_endpoint(ChatRequest(query="test", debug_prompt=True)) + assert "_debug_prompt" in response + assert "system_prompt" in response["_debug_prompt"] + assert "user_prompt" in response["_debug_prompt"] + assert len(response["_debug_prompt"]["system_prompt"]) > 0 + assert len(response["_debug_prompt"]["user_prompt"]) > 0 diff --git a/tests/test_classifier_regression.py b/tests/test_classifier_regression.py index ee3b93f..5fa2379 100644 --- a/tests/test_classifier_regression.py +++ b/tests/test_classifier_regression.py @@ -26,6 +26,10 @@ ) from api.query_classifier import classify_query +import os +_has_openai_key = bool(os.getenv("OPENAI_API_KEY")) +_skip_no_key = pytest.mark.skipif(not _has_openai_key, reason="OPENAI_API_KEY not set") + # --------------------------------------------------------------------------- # Regression cases: the 18 questions the current classifier gets wrong # --------------------------------------------------------------------------- @@ -220,6 +224,16 @@ def _get_openai_client(): return OpenAI() +def _get_classifier_provider(): + """Wrap a real OpenAI client in the ChatProvider interface for classify_query.""" + from api.chat_provider import OpenAIChatProvider + import os + return OpenAIChatProvider( + api_key=os.environ["OPENAI_API_KEY"], + default_model=os.getenv("CLASSIFIER_MODEL", "gpt-4o"), + ) + + async def _classify_with_prompt( query: str, client, @@ -366,12 +380,13 @@ async def mock_classify(query: str) -> ClassifierOutput: @pytest.mark.live +@_skip_no_key def test_classifier_baseline(): """Run current production classifier against the 18 failures. Expected: ~0/18.""" - client = _get_openai_client() + provider = _get_classifier_provider() async def baseline_classify(query: str) -> ClassifierOutput: - return await classify_query(query, client) + return await classify_query(query, provider) report = asyncio.run(run_bakeoff(baseline_classify, REGRESSION_CASES)) print_report("Baseline (current classifier)", report) @@ -380,6 +395,7 @@ async def baseline_classify(query: str) -> ClassifierOutput: @pytest.mark.live +@_skip_no_key def test_bakeoff_all_variants(): """Run all 4 variants against the 18-question regression set. Prints comparison.""" client = _get_openai_client() @@ -442,6 +458,7 @@ async def run_all(): @pytest.mark.live +@_skip_no_key def test_full_52_variant_b(): """Run Variant B (winner) against all 52 questions BEFORE implementing. @@ -506,6 +523,7 @@ async def classify_all(): @pytest.mark.live +@_skip_no_key def test_full_52_variant_c(): """Run Variant C (gpt-4o + tuned prompt + Guard 3) against all 52 questions.""" golden_qa_path = Path(__file__).parent / "eval" / "golden_qa.json" @@ -567,6 +585,7 @@ async def classify_all(): @pytest.mark.live +@_skip_no_key def test_full_52_classification(): """Run the CURRENT production classifier against all 52 golden QA questions. @@ -578,13 +597,13 @@ def test_full_52_classification(): with open(golden_qa_path) as f: golden_qa = json.load(f) - client = _get_openai_client() + provider = _get_classifier_provider() async def classify_all(): results = [] for qa in golden_qa: expected = _expected_taxonomy(qa) - output = await classify_query(qa["question"], client) + output = await classify_query(qa["question"], provider) actual = output.query_taxonomy.value results.append({ "id": qa["id"], diff --git a/tests/test_protocol_layer.py b/tests/test_protocol_layer.py new file mode 100644 index 0000000..8374b03 --- /dev/null +++ b/tests/test_protocol_layer.py @@ -0,0 +1,553 @@ +"""In-process pytest tests for the Claims × Spore protocol layer. + +Tests requirements, coverage_links, signals, gap computation, and +the small router extensions (since filter, offer_type filter, scope on responses). + +Run: pytest tests/test_protocol_layer.py -v +Requires: PostgreSQL personal_koi running locally with migrations 079-082 applied. + Uses rollback transactions — no persistent side effects. +""" + +import json +import os +import sys +import time +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import asyncpg +import httpx +import pytest +import pytest_asyncio +from fastapi import FastAPI + +REPO_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(REPO_ROOT)) + +DB_URL = os.getenv("POSTGRES_URL", "postgresql://darrenzal:@localhost:5432/personal_koi") + +# Mark all async tests in this module +pytestmark = pytest.mark.asyncio + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class _SingleConnPool: + """Wraps a single asyncpg.Connection to quack like asyncpg.Pool.acquire().""" + + def __init__(self, conn): + self._conn = conn + + class _CM: + def __init__(self, conn): + self.conn = conn + + async def __aenter__(self): + return self.conn + + async def __aexit__(self, *a): + pass + + def acquire(self): + return self._CM(self._conn) + + +def _build_app(fake_pool) -> FastAPI: + app = FastAPI() + from api.routers.protocol_router import create_protocol_router + app.include_router(create_protocol_router(fake_pool), prefix="/protocol") + return app + + +def _build_claims_app(fake_pool) -> FastAPI: + app = FastAPI() + from api.routers.claims_router import create_router as create_claims_router + app.include_router(create_claims_router(fake_pool)) + return app + + +def _build_commitment_app(fake_pool) -> FastAPI: + app = FastAPI() + from api.routers.commitment_router import create_router as create_commitment_router, create_pool_router + app.include_router(create_commitment_router(fake_pool)) + app.include_router(create_pool_router(fake_pool)) + return app + + +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +@pytest_asyncio.fixture +async def test_env(): + """Shared env: rollback transaction, fake pool, protocol + claims + commitment apps.""" + c = await asyncpg.connect(DB_URL) + tx = c.transaction() + await tx.start() + try: + fp = _SingleConnPool(c) + yield c, fp + finally: + await tx.rollback() + await c.close() + + +@pytest_asyncio.fixture +async def protocol_client(test_env): + conn, fp = test_env + app = _build_app(fp) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + yield client, conn + + +@pytest_asyncio.fixture +async def claims_client(test_env): + conn, fp = test_env + app = _build_claims_app(fp) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + yield client, conn + + +@pytest_asyncio.fixture +async def commitment_client(test_env): + conn, fp = test_env + app = _build_commitment_app(fp) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + yield client, conn + + +# --------------------------------------------------------------------------- +# Seed helpers +# --------------------------------------------------------------------------- + +async def _seed_pool(conn, pool_rid="orn:koi-net.pool:test-herring"): + """Insert a minimal commitment pool for testing.""" + await conn.execute(""" + INSERT INTO commitment_pools (pool_rid, name, state, scope) + VALUES ($1, 'Test Herring Pool', 'forming', 'pool') + ON CONFLICT (pool_rid) DO NOTHING + """, pool_rid) + return pool_rid + + +async def _seed_requirement(conn, pool_rid, statement="Quarterly herring monitoring", + frequency="quarterly", freshness_days=90, severity="high"): + """Insert a requirement scoped to a pool.""" + ts = int(time.time() * 1_000_000) + rid = f"orn:koi-net.requirement:test-{ts}" + await conn.execute(""" + INSERT INTO requirements ( + requirement_rid, scope, scope_ref, policy_source, + requirement_type, statement, frequency, + freshness_window_days, severity + ) VALUES ($1, 'pool', $2, 'test-constitution', 'monitoring', $3, $4, $5, $6) + """, rid, pool_rid, statement, frequency, freshness_days, severity) + return rid + + +async def _seed_coverage(conn, source_rid, target_rid, *, + valid_from=None, valid_until=None, + coverage_type="commitment_covers_requirement"): + """Insert a coverage link.""" + ts = int(time.time() * 1_000_000) + rid = f"orn:koi-net.coverage:test-{ts}" + vf = valid_from or datetime.now(timezone.utc) + await conn.execute(""" + INSERT INTO coverage_links ( + coverage_rid, coverage_type, source_rid, target_rid, + valid_from, valid_until, confidence, provenance + ) VALUES ($1, $2, $3, $4, $5, $6, 0.9, 'manual') + """, rid, coverage_type, source_rid, target_rid, vf, valid_until) + return rid + + +async def _seed_commitment(conn, pool_rid, commitment_rid=None): + """Insert a minimal commitment in a pool.""" + ts = int(time.time() * 1_000_000) + rid = commitment_rid or f"orn:koi-net.commitment:test-{ts}" + pool_id = await conn.fetchval( + "SELECT id FROM commitment_pools WHERE pool_rid = $1", pool_rid) + await conn.execute(""" + INSERT INTO commitments (commitment_rid, pledger_uri, pool_id, title, offer_type) + VALUES ($1, 'urn:test:pledger', $2, 'Test commitment', 'stewardship') + """, rid, pool_id) + return rid + + +# =========================================================================== +# Test 1: Requirement RID uniqueness across two pools +# =========================================================================== + +async def test_requirement_rid_different_pools(protocol_client): + """Same statement + policy_source in two different pools must produce different RIDs.""" + client, conn = protocol_client + body = { + "scope": "pool", + "policy_source": "test-constitution", + "requirement_type": "monitoring", + "statement": "Quarterly habitat monitoring", + "frequency": "quarterly", + "severity": "high", + } + + r1 = await client.post("/protocol/requirements/create", + json={**body, "scope_ref": "pool-alpha"}) + assert r1.status_code == 201 + rid_a = r1.json()["requirement_rid"] + + r2 = await client.post("/protocol/requirements/create", + json={**body, "scope_ref": "pool-beta"}) + assert r2.status_code == 201 + rid_b = r2.json()["requirement_rid"] + + assert rid_a != rid_b, "Same statement in different pools must produce different RIDs" + + +# =========================================================================== +# Test 2: Requirement upsert updates mutable fields +# =========================================================================== + +async def test_requirement_upsert_mutable_fields(protocol_client): + """Re-creating the same requirement should update severity, frequency, etc.""" + client, conn = protocol_client + body = { + "scope": "pool", + "scope_ref": "pool-upsert-test", + "policy_source": "test-constitution", + "requirement_type": "monitoring", + "statement": "Monthly water quality check", + "frequency": "monthly", + "severity": "medium", + } + + r1 = await client.post("/protocol/requirements/create", json=body) + assert r1.status_code == 201 + rid = r1.json()["requirement_rid"] + assert r1.json()["severity"] == "medium" + + r2 = await client.post("/protocol/requirements/create", + json={**body, "severity": "critical"}) + assert r2.status_code == 201 + assert r2.json()["requirement_rid"] == rid + assert r2.json()["severity"] == "critical" + + +# =========================================================================== +# Test 3: Gap signal recomputation upserts metadata +# =========================================================================== + +async def test_gap_signal_recomputation_upserts(protocol_client): + """Hitting /gaps twice should update the signal metadata, not silently skip.""" + client, conn = protocol_client + pool_rid = await _seed_pool(conn) + req_rid = await _seed_requirement(conn, pool_rid) + + r1 = await client.get(f"/protocol/pools/{pool_rid}/gaps") + assert r1.status_code == 200 + gaps1 = r1.json() + assert gaps1["unmet_count"] == 1 + sig_rid = gaps1["gaps"][0]["signal_rid"] + + r2 = await client.get(f"/protocol/pools/{pool_rid}/gaps") + assert r2.status_code == 200 + assert r2.json()["gaps"][0]["signal_rid"] == sig_rid + + row = await conn.fetchrow( + "SELECT metadata FROM signals WHERE signal_rid = $1", sig_rid) + meta = json.loads(row["metadata"]) if isinstance(row["metadata"], str) else row["metadata"] + assert "computed_at" in meta + + +# =========================================================================== +# Test 4: Future-dated coverage does not suppress a gap +# =========================================================================== + +async def test_future_coverage_does_not_suppress_gap(protocol_client): + """A coverage link with valid_from in the future should not count as coverage.""" + client, conn = protocol_client + pool_rid = await _seed_pool(conn) + req_rid = await _seed_requirement(conn, pool_rid) + commit_rid = await _seed_commitment(conn, pool_rid) + + future = datetime.now(timezone.utc) + timedelta(days=1) + await _seed_coverage(conn, commit_rid, req_rid, valid_from=future) + + r = await client.get(f"/protocol/pools/{pool_rid}/gaps") + assert r.status_code == 200 + data = r.json() + assert data["unmet_count"] == 1, "Future coverage should not suppress a gap" + assert data["covered_count"] == 0 + + +# =========================================================================== +# Test 5: Wrong coverage type does not suppress a gap +# =========================================================================== + +async def test_wrong_coverage_type_does_not_suppress_gap(protocol_client): + """Only commitment_covers_requirement should count for pool gaps.""" + client, conn = protocol_client + pool_rid = await _seed_pool(conn) + req_rid = await _seed_requirement(conn, pool_rid) + + await _seed_coverage(conn, "orn:koi-net.claim:fake", req_rid, + coverage_type="claim_covers_condition") + + r = await client.get(f"/protocol/pools/{pool_rid}/gaps") + assert r.status_code == 200 + assert r.json()["unmet_count"] == 1, "claim_covers_condition should not satisfy a requirement" + + +# =========================================================================== +# Test 6: Valid coverage suppresses a gap +# =========================================================================== + +async def test_valid_coverage_suppresses_gap(protocol_client): + """A current commitment_covers_requirement link should make the requirement covered.""" + client, conn = protocol_client + pool_rid = await _seed_pool(conn) + req_rid = await _seed_requirement(conn, pool_rid) + commit_rid = await _seed_commitment(conn, pool_rid) + + await _seed_coverage(conn, commit_rid, req_rid, + valid_from=datetime.now(timezone.utc) - timedelta(days=1)) + + r = await client.get(f"/protocol/pools/{pool_rid}/gaps") + assert r.status_code == 200 + data = r.json() + assert data["covered_count"] == 1 + assert data["unmet_count"] == 0 + assert len(data["gaps"]) == 0 + + +# =========================================================================== +# Test 7: Expired coverage produces stale gap +# =========================================================================== + +async def test_expired_coverage_produces_stale_gap(protocol_client): + """Coverage that has expired should produce a 'stale' gap, not 'unmet'.""" + client, conn = protocol_client + pool_rid = await _seed_pool(conn) + req_rid = await _seed_requirement(conn, pool_rid) + commit_rid = await _seed_commitment(conn, pool_rid) + + past_start = datetime.now(timezone.utc) - timedelta(days=120) + past_end = datetime.now(timezone.utc) - timedelta(days=30) + await _seed_coverage(conn, commit_rid, req_rid, + valid_from=past_start, valid_until=past_end) + + r = await client.get(f"/protocol/pools/{pool_rid}/gaps") + assert r.status_code == 200 + data = r.json() + assert data["stale_count"] == 1 + assert data["gaps"][0]["gap_type"] == "stale" + + +# =========================================================================== +# Test 8: Malformed since returns 422 +# =========================================================================== + +async def test_claims_since_malformed_returns_422(claims_client): + """Malformed since parameter should return 422, not a database error.""" + client, conn = claims_client + r = await client.get("/claims/?since=not-a-date") + assert r.status_code == 422, f"Expected 422, got {r.status_code}" + + +# =========================================================================== +# Test 9: Claims since filter works with valid datetime +# =========================================================================== + +async def test_claims_since_valid_datetime(claims_client): + """Valid since parameter should filter claims without error.""" + client, conn = claims_client + r = await client.get("/claims/?since=2026-01-01T00:00:00Z") + assert r.status_code == 200 + + +# =========================================================================== +# Test 10: Commitment response includes scope +# =========================================================================== + +async def test_commitment_response_includes_scope(commitment_client): + """After migration 080, commitment responses should include scope.""" + client, conn = commitment_client + pool_rid = await _seed_pool(conn) + commit_rid = await _seed_commitment(conn, pool_rid) + await conn.execute( + "UPDATE commitments SET scope = 'pool' WHERE commitment_rid = $1", commit_rid) + + r = await client.get(f"/commitments/{commit_rid}") + assert r.status_code == 200 + assert r.json()["scope"] == "pool" + + +# =========================================================================== +# Test 11: Pool response includes scope +# =========================================================================== + +async def test_pool_response_includes_scope(commitment_client): + """After migration 080, pool responses should include scope.""" + client, conn = commitment_client + pool_rid = await _seed_pool(conn) + + r = await client.get(f"/pools/{pool_rid}") + assert r.status_code == 200 + assert r.json()["scope"] == "pool" + + +# =========================================================================== +# Test 12: End-to-end gap path +# =========================================================================== + +async def test_e2e_gap_path(protocol_client): + """Full demo path: create pool → create requirement → no coverage → gap + signal.""" + client, conn = protocol_client + pool_rid = await _seed_pool(conn, pool_rid="orn:koi-net.pool:e2e-test") + + # 1. Create requirement via API + req_r = await client.post("/protocol/requirements/create", json={ + "scope": "pool", + "scope_ref": pool_rid, + "policy_source": "e2e-test-constitution", + "requirement_type": "monitoring", + "statement": "Monthly biodiversity survey", + "frequency": "monthly", + "freshness_window_days": 35, + "severity": "high", + }) + assert req_r.status_code == 201 + req_rid = req_r.json()["requirement_rid"] + + # 2. Verify requirement shows in list + list_r = await client.get(f"/protocol/requirements/?scope_ref={pool_rid}") + assert list_r.status_code == 200 + assert len(list_r.json()) == 1 + + # 3. Compute gaps — should find 1 unmet + gaps_r = await client.get(f"/protocol/pools/{pool_rid}/gaps") + assert gaps_r.status_code == 200 + data = gaps_r.json() + assert data["total_requirements"] == 1 + assert data["unmet_count"] == 1 + assert data["covered_count"] == 0 + + gap = data["gaps"][0] + assert gap["gap_type"] == "unmet" + assert gap["severity"] == "high" + assert gap["next_move"] == "propose_commitment" + assert gap["signal_rid"] is not None + + # 4. Verify signal was emitted + sig_r = await client.get(f"/protocol/signals/?source_ref={pool_rid}&signal_type=gap_computed") + assert sig_r.status_code == 200 + assert len(sig_r.json()) >= 1 + + # 5. Add coverage and re-check — gap should resolve + commit_rid = await _seed_commitment(conn, pool_rid) + await _seed_coverage(conn, commit_rid, req_rid, + valid_from=datetime.now(timezone.utc) - timedelta(hours=1)) + + gaps_r2 = await client.get(f"/protocol/pools/{pool_rid}/gaps") + assert gaps_r2.status_code == 200 + data2 = gaps_r2.json() + assert data2["covered_count"] == 1 + assert data2["unmet_count"] == 0 + assert len(data2["gaps"]) == 0 + + +# =========================================================================== +# Test 13: Pool not found returns 404 +# =========================================================================== + +async def test_gaps_pool_not_found(protocol_client): + client, conn = protocol_client + r = await client.get("/protocol/pools/nonexistent-pool/gaps") + assert r.status_code == 404 + + +# =========================================================================== +# Test 14: Commitment offer_type filter +# =========================================================================== + +async def test_commitment_offer_type_filter(commitment_client): + """The new offer_type filter should narrow commitment listings.""" + client, conn = commitment_client + pool_rid = await _seed_pool(conn) + await _seed_commitment(conn, pool_rid) # default: stewardship + + pool_id = await conn.fetchval( + "SELECT id FROM commitment_pools WHERE pool_rid = $1", pool_rid) + await conn.execute(""" + INSERT INTO commitments (commitment_rid, pledger_uri, pool_id, title, offer_type) + VALUES ('orn:koi-net.commitment:labor-test', 'urn:test:pledger', $1, 'Labor work', 'labor') + """, pool_id) + + r_all = await client.get("/commitments/") + all_count = len(r_all.json()) + + r_stew = await client.get("/commitments/?offer_type=stewardship") + assert r_stew.status_code == 200 + stew = r_stew.json() + assert all(c["offer_type"] == "stewardship" for c in stew) + assert len(stew) < all_count or all_count == 1 + + +# =========================================================================== +# Test 15: Same-pool same-statement different subjects produce different RIDs +# =========================================================================== + +async def test_requirement_rid_different_subjects_same_pool(protocol_client): + """Same statement in same pool but different subject_uri must produce different RIDs.""" + client, conn = protocol_client + body = { + "scope": "pool", + "scope_ref": "pool-multi-subject", + "policy_source": "test-constitution", + "requirement_type": "monitoring", + "statement": "Quarterly monitoring required", + "frequency": "quarterly", + "severity": "high", + } + + r1 = await client.post("/protocol/requirements/create", + json={**body, "subject_uri": "urn:species:herring"}) + assert r1.status_code == 201 + rid_a = r1.json()["requirement_rid"] + + r2 = await client.post("/protocol/requirements/create", + json={**body, "subject_uri": "urn:species:salmon"}) + assert r2.status_code == 201 + rid_b = r2.json()["requirement_rid"] + + assert rid_a != rid_b, "Same statement for different subjects must produce different RIDs" + + +# =========================================================================== +# Test 16: Future-dated coverage with finite end is unmet, not stale +# =========================================================================== + +async def test_future_coverage_with_end_date_is_unmet(protocol_client): + """Coverage starting next week and ending next month should be unmet, not stale.""" + client, conn = protocol_client + pool_rid = await _seed_pool(conn, pool_rid="orn:koi-net.pool:future-end-test") + req_rid = await _seed_requirement(conn, pool_rid) + commit_rid = await _seed_commitment(conn, pool_rid) + + # Future-dated coverage with a finite end date + future_start = datetime.now(timezone.utc) + timedelta(days=7) + future_end = datetime.now(timezone.utc) + timedelta(days=37) + await _seed_coverage(conn, commit_rid, req_rid, + valid_from=future_start, valid_until=future_end) + + r = await client.get(f"/protocol/pools/{pool_rid}/gaps") + assert r.status_code == 200 + data = r.json() + assert data["unmet_count"] == 1, "Future coverage with end date should be unmet, not stale" + assert data["stale_count"] == 0, "No expired past coverage exists — should not be stale" + assert data["gaps"][0]["gap_type"] == "unmet" diff --git a/tests/test_query_planner.py b/tests/test_query_planner.py index 3e1dab4..a3b53d5 100644 --- a/tests/test_query_planner.py +++ b/tests/test_query_planner.py @@ -36,21 +36,19 @@ class TestClassifier: @pytest.mark.asyncio async def test_classify_entity_definition(self): - """Mock OpenAI returns correct ClassifierOutput for entity question.""" + """Mock provider returns correct ClassifierOutput for entity question.""" from api.query_classifier import classify_query - mock_client = MagicMock() - mock_resp = MagicMock() - mock_resp.choices = [MagicMock(message=MagicMock(content=json.dumps({ + mock_provider = AsyncMock() + mock_provider.complete = AsyncMock(return_value=json.dumps({ "query_taxonomy": "entity_definition", "depth_tier": "standard", "entities": [{"name": "Eelgrass", "type": "Concept"}], "reasoning": "Asking for a definition of eelgrass", "confidence": 0.95, - })))] - mock_client.chat.completions.create.return_value = mock_resp + })) - result = await classify_query("What is eelgrass?", mock_client) + result = await classify_query("What is eelgrass?", mock_provider) assert result.query_taxonomy == QueryTaxonomy.ENTITY_DEFINITION assert result.depth_tier == DepthTier.STANDARD @@ -63,10 +61,10 @@ async def test_classify_parse_error_fallback(self): """Malformed LLM response -> OUT_OF_DOMAIN, confidence=0.0.""" from api.query_classifier import classify_query - mock_client = MagicMock() - mock_client.chat.completions.create.side_effect = Exception("API error") + mock_provider = AsyncMock() + mock_provider.complete = AsyncMock(side_effect=Exception("API error")) - result = await classify_query("test query", mock_client) + result = await classify_query("test query", mock_provider) assert result.query_taxonomy == QueryTaxonomy.OUT_OF_DOMAIN assert result.confidence == 0.0 @@ -293,10 +291,8 @@ async def test_fallback_below_threshold(self): confidence=0.5, ) - mock_openai = MagicMock() - mock_completion = MagicMock() - mock_completion.choices = [MagicMock(message=MagicMock(content="Mock answer"))] - mock_openai.chat.completions.create.return_value = mock_completion + mock_chat_provider = AsyncMock() + mock_chat_provider.complete = AsyncMock(return_value="Mock answer") mock_cm = AsyncMock() mock_conn = AsyncMock() @@ -312,15 +308,12 @@ async def test_fallback_below_threshold(self): text="Desc", metadata={"entity_type": "Concept", "label": "Test", "fuseki_uri": "urn:e:1"}) ]) - async def _fake_to_thread(fn, *args, **kwargs): - return fn(*args, **kwargs) - with patch("api.personal_ingest_api.db_pool", mock_pool), \ - patch("api.personal_ingest_api.openai_client", mock_openai), \ - patch("api.personal_ingest_api.OPENAI_API_KEY", "fake"), \ + patch("api.personal_ingest_api.classifier_provider", mock_chat_provider), \ + patch("api.personal_ingest_api.chat_answer_provider", mock_chat_provider), \ + patch("api.personal_ingest_api.expansion_provider", mock_chat_provider), \ patch("api.personal_ingest_api.generate_embedding", AsyncMock(return_value=[0.1]*1536)), \ patch("api.personal_ingest_api._try_structured_graph_query", AsyncMock(return_value="")), \ - patch("asyncio.to_thread", side_effect=_fake_to_thread), \ patch("api.query_classifier.classify_query", AsyncMock(return_value=low_conf)), \ patch("api.retrieval_executors.entity_lookup", mock_entity_lookup), \ patch("api.retrieval_executors.relationship_traverse", AsyncMock(return_value=[])), \ @@ -355,12 +348,14 @@ async def test_abstain_out_of_domain_high_confidence_no_llm(self): reasoning="Stock price question is out of domain", ) - mock_openai = MagicMock() + mock_chat_provider = AsyncMock() + mock_chat_provider.complete = AsyncMock(return_value="should not be called") mock_pool = MagicMock() with patch("api.personal_ingest_api.db_pool", mock_pool), \ - patch("api.personal_ingest_api.openai_client", mock_openai), \ - patch("api.personal_ingest_api.OPENAI_API_KEY", "fake"), \ + patch("api.personal_ingest_api.classifier_provider", mock_chat_provider), \ + patch("api.personal_ingest_api.chat_answer_provider", mock_chat_provider), \ + patch("api.personal_ingest_api.expansion_provider", mock_chat_provider), \ patch("api.personal_ingest_api.generate_embedding", AsyncMock(return_value=[0.1]*1536)), \ patch("api.query_classifier.classify_query", AsyncMock(return_value=ood_result)): @@ -373,8 +368,6 @@ async def test_abstain_out_of_domain_high_confidence_no_llm(self): assert response["plan_trace"]["fallback"] is False assert response["sources"] == [] assert "outside the scope" in response["answer"] - # LLM should NOT have been called - mock_openai.chat.completions.create.assert_not_called() @pytest.mark.asyncio async def test_plan_trace_emitted_on_fallback(self): @@ -382,10 +375,8 @@ async def test_plan_trace_emitted_on_fallback(self): # Same as test_fallback_below_threshold but focused on trace presence co = ClassifierOutput(query_taxonomy=QueryTaxonomy.ENTITY_DEFINITION, confidence=0.3) - mock_openai = MagicMock() - mock_completion = MagicMock() - mock_completion.choices = [MagicMock(message=MagicMock(content="Answer"))] - mock_openai.chat.completions.create.return_value = mock_completion + mock_chat_provider = AsyncMock() + mock_chat_provider.complete = AsyncMock(return_value="Answer") mock_cm = AsyncMock() mock_conn = AsyncMock() @@ -394,15 +385,12 @@ async def test_plan_trace_emitted_on_fallback(self): mock_pool = MagicMock() mock_pool.acquire.return_value = mock_cm - async def _fake_to_thread(fn, *args, **kwargs): - return fn(*args, **kwargs) - with patch("api.personal_ingest_api.db_pool", mock_pool), \ - patch("api.personal_ingest_api.openai_client", mock_openai), \ - patch("api.personal_ingest_api.OPENAI_API_KEY", "fake"), \ + patch("api.personal_ingest_api.classifier_provider", mock_chat_provider), \ + patch("api.personal_ingest_api.chat_answer_provider", mock_chat_provider), \ + patch("api.personal_ingest_api.expansion_provider", mock_chat_provider), \ patch("api.personal_ingest_api.generate_embedding", AsyncMock(return_value=[0.1]*1536)), \ patch("api.personal_ingest_api._try_structured_graph_query", AsyncMock(return_value="")), \ - patch("asyncio.to_thread", side_effect=_fake_to_thread), \ patch("api.query_classifier.classify_query", AsyncMock(return_value=co)), \ patch("api.retrieval_executors.entity_lookup", AsyncMock(return_value=[])), \ patch("api.retrieval_executors.relationship_traverse", AsyncMock(return_value=[])), \ @@ -420,9 +408,12 @@ async def test_plan_trace_emitted_on_abstention(self): """OOD abstention response contains plan_trace with abstained=true.""" ood = ClassifierOutput(query_taxonomy=QueryTaxonomy.OUT_OF_DOMAIN, confidence=0.9) + mock_provider = AsyncMock() + with patch("api.personal_ingest_api.db_pool", MagicMock()), \ - patch("api.personal_ingest_api.openai_client", MagicMock()), \ - patch("api.personal_ingest_api.OPENAI_API_KEY", "fake"), \ + patch("api.personal_ingest_api.classifier_provider", mock_provider), \ + patch("api.personal_ingest_api.chat_answer_provider", mock_provider), \ + patch("api.personal_ingest_api.expansion_provider", mock_provider), \ patch("api.personal_ingest_api.generate_embedding", AsyncMock(return_value=[0.1]*1536)), \ patch("api.query_classifier.classify_query", AsyncMock(return_value=ood)):