170 changes: 170 additions & 0 deletions api/chat_provider.py
@@ -0,0 +1,170 @@
"""
Chat provider abstraction — per-surface LLM decoupling.

One ABC, two implementations (OpenAI, Anthropic), three independent
factory functions for classifier, chat answer, and query expansion.

Follows the same pattern as api/embedding_provider.py.
"""

import asyncio
import logging
import os
from abc import ABC, abstractmethod

logger = logging.getLogger(__name__)


class ChatProvider(ABC):
"""Abstract base class for chat completion providers."""

provider_name: str
default_model: str

@abstractmethod
async def complete(
self,
messages: list[dict[str, str]],
*,
model: str | None = None,
temperature: float = 0.3,
max_tokens: int = 1024,
json_mode: bool = False,
) -> str:
"""Return the assistant's text response."""


class OpenAIChatProvider(ChatProvider):
"""OpenAI chat completions via the openai package."""

provider_name = "openai"

def __init__(self, api_key: str, default_model: str = "gpt-4o-mini"):
from openai import OpenAI
self._client = OpenAI(api_key=api_key)
self.default_model = default_model

async def complete(
self,
messages: list[dict[str, str]],
*,
model: str | None = None,
temperature: float = 0.3,
max_tokens: int = 1024,
json_mode: bool = False,
) -> str:
kwargs: dict = dict(
model=model or self.default_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
if json_mode:
kwargs["response_format"] = {"type": "json_object"}

response = await asyncio.to_thread(
self._client.chat.completions.create, **kwargs
)
return response.choices[0].message.content or ""


class AnthropicChatProvider(ChatProvider):
"""Anthropic chat completions via the anthropic package."""

provider_name = "anthropic"

def __init__(self, api_key: str, default_model: str = "claude-sonnet-4-20250514"):
import anthropic
self._client = anthropic.Anthropic(api_key=api_key)
self.default_model = default_model

async def complete(
self,
messages: list[dict[str, str]],
*,
model: str | None = None,
temperature: float = 0.3,
max_tokens: int = 1024,
json_mode: bool = False,
) -> str:
# Anthropic API: system message is a separate kwarg
system_text = None
user_messages = []
for m in messages:
if m["role"] == "system":
system_text = m["content"]
else:
user_messages.append({"role": m["role"], "content": m["content"]})

# json_mode: append instruction (Anthropic has no native json_mode)
if json_mode and system_text:
system_text += "\n\nRespond ONLY with valid JSON."
elif json_mode:
system_text = "Respond ONLY with valid JSON."

kwargs: dict = dict(
model=model or self.default_model,
messages=user_messages,
temperature=temperature,
max_tokens=max_tokens,
)
if system_text:
kwargs["system"] = system_text

response = await asyncio.to_thread(
self._client.messages.create, **kwargs
)
return response.content[0].text


# ---------------------------------------------------------------------------
# Factory functions — one per surface
# ---------------------------------------------------------------------------

_OPENAI_KEY_VAR = "OPENAI_API_KEY"
_ANTHROPIC_KEY_VAR = "ANTHROPIC_API_KEY"


def _require_key(env_var: str) -> str:
key = os.getenv(env_var, "")
if not key:
raise ValueError(f"{env_var} is required but not set")
return key


def create_classifier_provider() -> ChatProvider:
"""Classifier provider. OpenAI only (structured output sensitive)."""
provider = os.getenv("CLASSIFIER_PROVIDER", "openai").strip().lower()
model = os.getenv("CLASSIFIER_MODEL", "gpt-4o")

if provider == "openai":
return OpenAIChatProvider(api_key=_require_key(_OPENAI_KEY_VAR), default_model=model)

raise ValueError(f"Unsupported CLASSIFIER_PROVIDER: {provider!r} (only 'openai' supported)")


def create_chat_provider() -> ChatProvider:
"""Chat answer provider. Supports OpenAI and Anthropic."""
provider = os.getenv("CHAT_PROVIDER", "openai").strip().lower()
# CHAT_MODEL with CHAT_LLM_MODEL as backward-compat alias
model = os.getenv("CHAT_MODEL") or os.getenv("CHAT_LLM_MODEL", "gpt-4o-mini")

if provider == "openai":
return OpenAIChatProvider(api_key=_require_key(_OPENAI_KEY_VAR), default_model=model)

    if provider == "anthropic":
        # Note: only CHAT_MODEL can override the Anthropic default; the
        # CHAT_LLM_MODEL alias is not applied on this branch since it
        # typically holds an OpenAI model name.
        model = os.getenv("CHAT_MODEL", "claude-sonnet-4-20250514")
        return AnthropicChatProvider(api_key=_require_key(_ANTHROPIC_KEY_VAR), default_model=model)

raise ValueError(f"Unsupported CHAT_PROVIDER: {provider!r} (supported: 'openai', 'anthropic')")


def create_expansion_provider() -> ChatProvider:
"""Query expansion provider. OpenAI only."""
provider = os.getenv("EXPANSION_PROVIDER", "openai").strip().lower()
model = os.getenv("EXPANSION_MODEL", "gpt-4o-mini")

if provider == "openai":
return OpenAIChatProvider(api_key=_require_key(_OPENAI_KEY_VAR), default_model=model)

raise ValueError(f"Unsupported EXPANSION_PROVIDER: {provider!r} (only 'openai' supported)")
75 changes: 48 additions & 27 deletions api/personal_ingest_api.py
@@ -175,10 +175,19 @@ def quartz_url(entity_type: str, name: str) -> Optional[str]:

# Global connection pool
db_pool: Optional[asyncpg.Pool] = None
-openai_client: Optional[Any] = None  # lazy init for /chat LLM calls

from api.embedding_provider import EmbeddingProvider, create_embedding_provider
embedding_provider: Optional[EmbeddingProvider] = None

+from api.chat_provider import (
+    ChatProvider,
+    create_classifier_provider,
+    create_chat_provider,
+    create_expansion_provider,
+)
+classifier_provider: Optional[ChatProvider] = None
+chat_answer_provider: Optional[ChatProvider] = None
+expansion_provider: Optional[ChatProvider] = None
terminusdb_adapter: Optional[Any] = None # TerminusDBAdapter instance (lazy init)


@@ -1225,7 +1234,7 @@ async def enqueue_outbox(
@app.on_event("startup")
async def startup():
"""Initialize database connection pool and embedding provider"""
-    global db_pool, openai_client, embedding_provider
+    global db_pool, embedding_provider
try:
db_pool = await asyncpg.create_pool(
DB_URL,
@@ -1306,6 +1315,14 @@ async def startup():
except Exception as e:
logger.warning(f"Commitment routers not mounted: {e}")

+    # Protocol layer router (requirements, coverage, signals, gap computation)
+    try:
+        from api.routers.protocol_router import create_protocol_router
+        app.include_router(create_protocol_router(db_pool), prefix="/protocol")
+        logger.info("Protocol router mounted (/protocol/)")
+    except Exception as e:
+        logger.warning(f"Protocol router not mounted: {e}")

# Task router is always mounted (no capability gate — core feature)
try:
from api.routers.task_router import create_router as create_task_router
@@ -4596,8 +4613,6 @@ async def graph_shortest_path(
# /chat Endpoint — RAG-powered conversational interface
# =============================================================================

-CHAT_LLM_MODEL = os.getenv('CHAT_LLM_MODEL', 'gpt-4o-mini')


# ── B2 GraphRAG: graph-guided retrieval ──────────────────────────────

@@ -5104,27 +5119,30 @@ def _rerank_chunks(query: str, chunks: list, top_k: int = 8) -> list:
return reranked


+EXPANSION_MODEL = os.getenv("EXPANSION_MODEL", "gpt-4o-mini")


async def _expand_queries(query: str, n: int = 3) -> list:
"""Generate n query reformulations for multi-query retrieval (B8b).
Returns [original_query] + up to n reformulations.
"""
-    if not openai_client:
+    if not expansion_provider:
return [query]
try:
-        resp = openai_client.chat.completions.create(
-            model="gpt-4o-mini",
-            messages=[{
+        result = await expansion_provider.complete(
+            [{
"role": "user",
"content": f"""Generate {n} alternative search queries for a bioregional knowledge commons.
Original query: "{query}"

Return ONLY the alternative queries, one per line. No numbering, no explanation.
Each should rephrase the question using different terminology to find relevant documents."""
}],
+            model=EXPANSION_MODEL,
max_tokens=200,
temperature=0.7,
)
-        lines = [l.strip() for l in resp.choices[0].message.content.strip().split('\n') if l.strip()]
+        lines = [l.strip() for l in result.strip().split('\n') if l.strip()]
return [query] + lines[:n]
except Exception as e:
logger.warning(f"B8b query expansion failed: {e}")
@@ -5139,6 +5157,7 @@ class ChatRequest(BaseModel):
include_code: bool = Field(default=False, description="Include code entity chunks in retrieval (default: exclude)")
multi_query: bool = Field(default=False, description="Enable multi-query expansion for broader retrieval (B8b)")
planner: bool = Field(default=False, description="B9a QueryPlan IR path (experimental)")
+    debug_prompt: bool = Field(default=False, description="Include assembled prompt in response (requires CHAT_DEBUG_PROMPT env)")


@app.post("/chat")
Expand All @@ -5153,21 +5172,20 @@ async def chat_endpoint(request: ChatRequest):
"""
if not db_pool:
raise HTTPException(status_code=503, detail="Database not available")
-    # Lazy init openai_client for /chat LLM calls (separate from embedding provider)
-    global openai_client
-    if not openai_client:
-        if not OPENAI_API_KEY:
-            raise HTTPException(
-                status_code=503,
-                detail="LLM service not available (OPENAI_API_KEY not configured)",
-            )
+    # Lazy init chat providers (classifier, chat answer, expansion)
+    global classifier_provider, chat_answer_provider, expansion_provider
+    if not classifier_provider or not chat_answer_provider or not expansion_provider:
try:
-            from openai import OpenAI
-            openai_client = OpenAI(api_key=OPENAI_API_KEY)
-        except ImportError:
+            if not classifier_provider:
+                classifier_provider = create_classifier_provider()
+            if not chat_answer_provider:
+                chat_answer_provider = create_chat_provider()
+            if not expansion_provider:
+                expansion_provider = create_expansion_provider()
+        except (ValueError, ImportError) as e:
raise HTTPException(
status_code=503,
detail="openai package not installed",
detail=f"LLM provider not available: {e}",
)

# ------------------------------------------------------------------
Expand All @@ -5187,7 +5205,7 @@ async def chat_endpoint(request: ChatRequest):
from api.retrieval_executors import evidence_bundles_to_legacy_format
from api.schemas.query_plan import QueryTaxonomy

+        classifier_output = await classify_query(request.query, classifier_provider)
classifier_output = await classify_query(request.query, classifier_provider)

if classifier_output.query_taxonomy == QueryTaxonomy.OUT_OF_DOMAIN and \
classifier_output.confidence >= CLASSIFIER_CONFIDENCE_THRESHOLD:
@@ -5356,17 +5374,14 @@ async def chat_endpoint(request: ChatRequest):
# 4. Call LLM
# ------------------------------------------------------------------
try:
-        llm_response = await asyncio.to_thread(
-            openai_client.chat.completions.create,
-            model=CHAT_LLM_MODEL,
-            messages=[
+        answer = await chat_answer_provider.complete(
+            [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=0.3,
max_tokens=1024,
)
-        answer = llm_response.choices[0].message.content or ""
except Exception as e:
logger.error(f"LLM call failed: {e}")
raise HTTPException(
Expand All @@ -5393,6 +5408,12 @@ async def chat_endpoint(request: ChatRequest):
# Always emit plan_trace when planner was requested (including fallback)
if planner_requested and plan_trace is not None:
response["plan_trace"] = plan_trace
+    # Debug prompt capture (double-gated: request flag + server env var)
+    if request.debug_prompt and os.getenv("CHAT_DEBUG_PROMPT"):
+        response["_debug_prompt"] = {
+            "system_prompt": system_prompt,
+            "user_prompt": user_prompt,
+        }
return response
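
The double gate above means a prompt dump appears only when both the server and the caller opt in. A minimal client sketch, assuming a local deployment (host, port, and query text are made up):

import httpx

# Server side must export CHAT_DEBUG_PROMPT=1; the request flag alone is not enough.
resp = httpx.post(
    "http://localhost:8000/chat",  # assumed host/port
    json={"query": "What documents cover riparian buffers?", "debug_prompt": True},
    timeout=60.0,
)
dbg = resp.json().get("_debug_prompt")
if dbg:
    print(dbg["system_prompt"][:200])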

