Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions medguard/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class HallucinationConfig(BaseModel):

class FactCheckConfig(BaseModel):
enabled: bool = False # opt-in: requires network calls to PubMed
use_agent: bool = False # opt-in: requires an LLM caller for agent reasoning
confidence_threshold: float = 0.4
max_claims_per_response: int = 5
ncbi_api_key_env: str = "NCBI_API_KEY"
Expand Down
31 changes: 21 additions & 10 deletions medguard/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def __init__(
self.fact_verifier: FactVerifier | None = None
self._llm_caller: LLMCallerProtocol | None = None

try:
self._llm_caller = _build_llm_caller(self.config)
except Exception as exc:
log.warning("llm_caller_init_failed", error=str(exc))

self._build_components()
self.pipeline = self._build_pipeline()

Expand Down Expand Up @@ -204,7 +209,7 @@ def _build_components(self) -> None:
try:
import httpx

from medguard.guardrails.fact_check import FactVerifier
from medguard.guardrails.fact_check import AgentFactVerifier, FactVerifier
from medguard.knowledge.pubmed import PubMedClient

if not hasattr(self, "_http_client"):
Expand All @@ -213,19 +218,25 @@ def _build_components(self) -> None:
self._http_client,
max_results=cfg.guardrails.fact_checking.max_claims_per_response,
)
self.fact_verifier = FactVerifier(
pubmed,
confidence_threshold=cfg.guardrails.fact_checking.confidence_threshold,
verifier_cls = (
AgentFactVerifier
if cfg.guardrails.fact_checking.use_agent
else FactVerifier
)
if verifier_cls is AgentFactVerifier:
self.fact_verifier = verifier_cls(
pubmed,
self._llm_caller,
confidence_threshold=cfg.guardrails.fact_checking.confidence_threshold,
)
else:
self.fact_verifier = verifier_cls(
pubmed,
confidence_threshold=cfg.guardrails.fact_checking.confidence_threshold,
)
except Exception as exc:
log.warning("fact_verifier_init_failed", error=str(exc))

# LLM caller
try:
self._llm_caller = _build_llm_caller(cfg)
except Exception as exc:
log.warning("llm_caller_init_failed", error=str(exc))

def _build_pipeline(self) -> GuardrailPipeline:
return GuardrailPipeline(
config=self.config,
Expand Down
179 changes: 179 additions & 0 deletions medguard/guardrails/fact_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
"""
from __future__ import annotations

import json
import re
from typing import TYPE_CHECKING

import structlog
from pydantic import BaseModel

if TYPE_CHECKING:
from medguard.guardrails.protocols import LLMCallerProtocol
from medguard.knowledge.pubmed import FactEvidence, PubMedClient

log = structlog.get_logger(__name__)
Expand Down Expand Up @@ -144,6 +146,143 @@ async def verify(self, text: str) -> FactCheckResult:
)


class AgentFactVerifier(FactVerifier):
"""
Verifies claims by asking an LLM to reason over retrieved PubMed abstracts.

The agent path is opt-in. If no LLM caller is configured, it falls back to
the keyword-backed PubMed verifier.
"""

def __init__(
self,
pubmed: PubMedClient,
llm_caller: LLMCallerProtocol | None,
confidence_threshold: float = 0.4,
) -> None:
super().__init__(pubmed, confidence_threshold=confidence_threshold)
self._llm_caller = llm_caller

async def verify_claim(self, claim: str) -> FactEvidence:
"""Return agent-scored PubMed evidence for one claim."""
if self._llm_caller is None:
return await self._pubmed.verify_claim(claim)

pmids = await self._pubmed.search(claim)
if not pmids:
return await self._pubmed.verify_claim(claim)

import asyncio

summaries, abstracts = await asyncio.gather(
self._pubmed.fetch_summaries(pmids),
self._pubmed.fetch_abstracts(pmids[:5]),
)
abstract_map = {article.pmid: article.abstract for article in abstracts}
articles = []
for summary in summaries:
summary.abstract = abstract_map.get(summary.pmid, "")
articles.append(summary)

if not articles:
return await self._pubmed.verify_claim(claim)

prompt = _build_agent_prompt(claim, articles[:5])
try:
raw_response = await self._llm_caller.call(prompt)
verdict = _parse_agent_verdict(raw_response)
except Exception as exc:
log.debug("agent_fact_check_failed", claim=claim[:50], error=str(exc))
return await self._pubmed.verify_claim(claim)

cited_pmids = set(verdict.get("citations", []))
cited_articles = [article for article in articles if article.pmid in cited_pmids]
evidence_articles = cited_articles or articles[:3]
verdict_name = str(verdict.get("verdict", "inconclusive")).lower()
confidence = _normalize_confidence(verdict.get("confidence", 0.0))
reasoning = str(verdict.get("reasoning", "")).strip()

from medguard.knowledge.pubmed import FactEvidence

return FactEvidence(
claim=claim,
supporting=evidence_articles if verdict_name == "supported" else [],
contradicting=evidence_articles if verdict_name == "contradicted" else [],
total_results=len(articles),
verified=verdict_name == "supported" and confidence >= self._threshold,
confidence=confidence,
summary=f"Agent verdict: {verdict_name}",
reasoning=reasoning,
)

async def verify(self, text: str) -> FactCheckResult:
"""Extract claims and verify each with agent reasoning over PubMed."""
claims = _extract_claims(text)
if not claims:
return FactCheckResult(
claims_checked=0,
verified_claims=[],
unverified_claims=[],
low_confidence_claims=[],
overall_confidence=1.0,
pubmed_evidence=[],
flagged=False,
annotation="",
)

import asyncio

evidences: list[FactEvidence] = await asyncio.gather(
*[self.verify_claim(c) for c in claims],
return_exceptions=True,
)

verified = []
unverified = []
low_confidence = []
evidence_summaries = []

for claim, ev in zip(claims, evidences):
if isinstance(ev, Exception):
log.debug("agent_fact_check_error", claim=claim[:50], error=str(ev))
continue

evidence_summaries.append({
"claim": ev.claim,
"verified": ev.verified,
"confidence": round(ev.confidence, 2),
"summary": ev.summary,
"reasoning": ev.reasoning,
"supporting_pmids": [a.pmid for a in ev.supporting[:3]],
"contradicting_pmids": [a.pmid for a in ev.contradicting[:3]],
})

if ev.total_results == 0:
unverified.append(claim)
elif ev.confidence < self._threshold:
low_confidence.append(claim)
else:
verified.append(claim)

total = len(verified) + len(unverified) + len(low_confidence)
overall = sum(
e["confidence"] for e in evidence_summaries
) / max(len(evidence_summaries), 1)
flagged = len(low_confidence) > 0 or len(unverified) > total * 0.5
annotation = _build_annotation(verified, unverified, low_confidence, evidence_summaries)

return FactCheckResult(
claims_checked=len(claims),
verified_claims=verified,
unverified_claims=unverified,
low_confidence_claims=low_confidence,
overall_confidence=round(overall, 2),
pubmed_evidence=evidence_summaries,
flagged=flagged,
annotation=annotation,
)


def _extract_claims(text: str) -> list[str]:
"""Extract falsifiable medical claims from text using regex patterns."""
seen: set[str] = set()
Expand All @@ -161,6 +300,46 @@ def _extract_claims(text: str) -> list[str]:
return claims[:8] # cap at 8 to avoid excessive API calls


def _build_agent_prompt(claim: str, articles: list) -> str:
evidence = []
for article in articles:
abstract = article.abstract or article.title
evidence.append(
f"PMID: {article.pmid}\nTitle: {article.title}\nAbstract: {abstract[:1200]}"
)
return (
"You are verifying a medical claim against retrieved PubMed abstracts.\n"
"Return only JSON with keys: verdict, confidence, citations, reasoning.\n"
"verdict must be one of: supported, contradicted, inconclusive.\n"
"confidence must be a number from 0 to 1.\n"
"citations must contain only PMIDs from the provided evidence.\n\n"
f"Claim: {claim}\n\n"
"Evidence:\n" + "\n\n".join(evidence)
)


def _parse_agent_verdict(raw_response: str) -> dict:
raw_response = raw_response.strip()
if raw_response.startswith("```"):
raw_response = re.sub(r"^```(?:json)?\s*", "", raw_response)
raw_response = re.sub(r"\s*```$", "", raw_response)
parsed = json.loads(raw_response)
verdict = str(parsed.get("verdict", "inconclusive")).lower()
if verdict not in {"supported", "contradicted", "inconclusive"}:
parsed["verdict"] = "inconclusive"
citations = parsed.get("citations", [])
parsed["citations"] = [str(citation) for citation in citations if citation]
return parsed


def _normalize_confidence(value) -> float:
try:
confidence = float(value)
except (TypeError, ValueError):
return 0.0
return max(0.0, min(1.0, confidence))


def _build_annotation(
verified: list[str],
unverified: list[str],
Expand Down
1 change: 1 addition & 0 deletions medguard/knowledge/pubmed.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class FactEvidence:
verified: bool = False
confidence: float = 0.0
summary: str = ""
reasoning: str = ""


class PubMedClient:
Expand Down
Loading