From 662df9f8491f9ed223a1fe61389d2cc517abc653 Mon Sep 17 00:00:00 2001 From: soria Date: Wed, 6 May 2026 16:39:02 -0400 Subject: [PATCH] Add agent-backed PubMed fact verifier --- medguard/config.py | 1 + medguard/core.py | 31 ++++-- medguard/guardrails/fact_check.py | 179 ++++++++++++++++++++++++++++++ medguard/knowledge/pubmed.py | 1 + tests/test_fact_check.py | 141 +++++++++++++++++++++++ 5 files changed, 343 insertions(+), 10 deletions(-) create mode 100644 tests/test_fact_check.py diff --git a/medguard/config.py b/medguard/config.py index 86772e0..b738e1f 100644 --- a/medguard/config.py +++ b/medguard/config.py @@ -56,6 +56,7 @@ class HallucinationConfig(BaseModel): class FactCheckConfig(BaseModel): enabled: bool = False # opt-in: requires network calls to PubMed + use_agent: bool = False # opt-in: requires an LLM caller for agent reasoning confidence_threshold: float = 0.4 max_claims_per_response: int = 5 ncbi_api_key_env: str = "NCBI_API_KEY" diff --git a/medguard/core.py b/medguard/core.py index 3c7efcd..4601a23 100644 --- a/medguard/core.py +++ b/medguard/core.py @@ -70,6 +70,11 @@ def __init__( self.fact_verifier: FactVerifier | None = None self._llm_caller: LLMCallerProtocol | None = None + try: + self._llm_caller = _build_llm_caller(self.config) + except Exception as exc: + log.warning("llm_caller_init_failed", error=str(exc)) + self._build_components() self.pipeline = self._build_pipeline() @@ -204,7 +209,7 @@ def _build_components(self) -> None: try: import httpx - from medguard.guardrails.fact_check import FactVerifier + from medguard.guardrails.fact_check import AgentFactVerifier, FactVerifier from medguard.knowledge.pubmed import PubMedClient if not hasattr(self, "_http_client"): @@ -213,19 +218,25 @@ def _build_components(self) -> None: self._http_client, max_results=cfg.guardrails.fact_checking.max_claims_per_response, ) - self.fact_verifier = FactVerifier( - pubmed, - confidence_threshold=cfg.guardrails.fact_checking.confidence_threshold, + verifier_cls = ( + AgentFactVerifier + if cfg.guardrails.fact_checking.use_agent + else FactVerifier ) + if verifier_cls is AgentFactVerifier: + self.fact_verifier = verifier_cls( + pubmed, + self._llm_caller, + confidence_threshold=cfg.guardrails.fact_checking.confidence_threshold, + ) + else: + self.fact_verifier = verifier_cls( + pubmed, + confidence_threshold=cfg.guardrails.fact_checking.confidence_threshold, + ) except Exception as exc: log.warning("fact_verifier_init_failed", error=str(exc)) - # LLM caller - try: - self._llm_caller = _build_llm_caller(cfg) - except Exception as exc: - log.warning("llm_caller_init_failed", error=str(exc)) - def _build_pipeline(self) -> GuardrailPipeline: return GuardrailPipeline( config=self.config, diff --git a/medguard/guardrails/fact_check.py b/medguard/guardrails/fact_check.py index 1a4c8e6..6e1b472 100644 --- a/medguard/guardrails/fact_check.py +++ b/medguard/guardrails/fact_check.py @@ -12,6 +12,7 @@ """ from __future__ import annotations +import json import re from typing import TYPE_CHECKING @@ -19,6 +20,7 @@ from pydantic import BaseModel if TYPE_CHECKING: + from medguard.guardrails.protocols import LLMCallerProtocol from medguard.knowledge.pubmed import FactEvidence, PubMedClient log = structlog.get_logger(__name__) @@ -144,6 +146,143 @@ async def verify(self, text: str) -> FactCheckResult: ) +class AgentFactVerifier(FactVerifier): + """ + Verifies claims by asking an LLM to reason over retrieved PubMed abstracts. + + The agent path is opt-in. If no LLM caller is configured, it falls back to + the keyword-backed PubMed verifier. + """ + + def __init__( + self, + pubmed: PubMedClient, + llm_caller: LLMCallerProtocol | None, + confidence_threshold: float = 0.4, + ) -> None: + super().__init__(pubmed, confidence_threshold=confidence_threshold) + self._llm_caller = llm_caller + + async def verify_claim(self, claim: str) -> FactEvidence: + """Return agent-scored PubMed evidence for one claim.""" + if self._llm_caller is None: + return await self._pubmed.verify_claim(claim) + + pmids = await self._pubmed.search(claim) + if not pmids: + return await self._pubmed.verify_claim(claim) + + import asyncio + + summaries, abstracts = await asyncio.gather( + self._pubmed.fetch_summaries(pmids), + self._pubmed.fetch_abstracts(pmids[:5]), + ) + abstract_map = {article.pmid: article.abstract for article in abstracts} + articles = [] + for summary in summaries: + summary.abstract = abstract_map.get(summary.pmid, "") + articles.append(summary) + + if not articles: + return await self._pubmed.verify_claim(claim) + + prompt = _build_agent_prompt(claim, articles[:5]) + try: + raw_response = await self._llm_caller.call(prompt) + verdict = _parse_agent_verdict(raw_response) + except Exception as exc: + log.debug("agent_fact_check_failed", claim=claim[:50], error=str(exc)) + return await self._pubmed.verify_claim(claim) + + cited_pmids = set(verdict.get("citations", [])) + cited_articles = [article for article in articles if article.pmid in cited_pmids] + evidence_articles = cited_articles or articles[:3] + verdict_name = str(verdict.get("verdict", "inconclusive")).lower() + confidence = _normalize_confidence(verdict.get("confidence", 0.0)) + reasoning = str(verdict.get("reasoning", "")).strip() + + from medguard.knowledge.pubmed import FactEvidence + + return FactEvidence( + claim=claim, + supporting=evidence_articles if verdict_name == "supported" else [], + contradicting=evidence_articles if verdict_name == "contradicted" else [], + total_results=len(articles), + verified=verdict_name == "supported" and confidence >= self._threshold, + confidence=confidence, + summary=f"Agent verdict: {verdict_name}", + reasoning=reasoning, + ) + + async def verify(self, text: str) -> FactCheckResult: + """Extract claims and verify each with agent reasoning over PubMed.""" + claims = _extract_claims(text) + if not claims: + return FactCheckResult( + claims_checked=0, + verified_claims=[], + unverified_claims=[], + low_confidence_claims=[], + overall_confidence=1.0, + pubmed_evidence=[], + flagged=False, + annotation="", + ) + + import asyncio + + evidences: list[FactEvidence] = await asyncio.gather( + *[self.verify_claim(c) for c in claims], + return_exceptions=True, + ) + + verified = [] + unverified = [] + low_confidence = [] + evidence_summaries = [] + + for claim, ev in zip(claims, evidences): + if isinstance(ev, Exception): + log.debug("agent_fact_check_error", claim=claim[:50], error=str(ev)) + continue + + evidence_summaries.append({ + "claim": ev.claim, + "verified": ev.verified, + "confidence": round(ev.confidence, 2), + "summary": ev.summary, + "reasoning": ev.reasoning, + "supporting_pmids": [a.pmid for a in ev.supporting[:3]], + "contradicting_pmids": [a.pmid for a in ev.contradicting[:3]], + }) + + if ev.total_results == 0: + unverified.append(claim) + elif ev.confidence < self._threshold: + low_confidence.append(claim) + else: + verified.append(claim) + + total = len(verified) + len(unverified) + len(low_confidence) + overall = sum( + e["confidence"] for e in evidence_summaries + ) / max(len(evidence_summaries), 1) + flagged = len(low_confidence) > 0 or len(unverified) > total * 0.5 + annotation = _build_annotation(verified, unverified, low_confidence, evidence_summaries) + + return FactCheckResult( + claims_checked=len(claims), + verified_claims=verified, + unverified_claims=unverified, + low_confidence_claims=low_confidence, + overall_confidence=round(overall, 2), + pubmed_evidence=evidence_summaries, + flagged=flagged, + annotation=annotation, + ) + + def _extract_claims(text: str) -> list[str]: """Extract falsifiable medical claims from text using regex patterns.""" seen: set[str] = set() @@ -161,6 +300,46 @@ def _extract_claims(text: str) -> list[str]: return claims[:8] # cap at 8 to avoid excessive API calls +def _build_agent_prompt(claim: str, articles: list) -> str: + evidence = [] + for article in articles: + abstract = article.abstract or article.title + evidence.append( + f"PMID: {article.pmid}\nTitle: {article.title}\nAbstract: {abstract[:1200]}" + ) + return ( + "You are verifying a medical claim against retrieved PubMed abstracts.\n" + "Return only JSON with keys: verdict, confidence, citations, reasoning.\n" + "verdict must be one of: supported, contradicted, inconclusive.\n" + "confidence must be a number from 0 to 1.\n" + "citations must contain only PMIDs from the provided evidence.\n\n" + f"Claim: {claim}\n\n" + "Evidence:\n" + "\n\n".join(evidence) + ) + + +def _parse_agent_verdict(raw_response: str) -> dict: + raw_response = raw_response.strip() + if raw_response.startswith("```"): + raw_response = re.sub(r"^```(?:json)?\s*", "", raw_response) + raw_response = re.sub(r"\s*```$", "", raw_response) + parsed = json.loads(raw_response) + verdict = str(parsed.get("verdict", "inconclusive")).lower() + if verdict not in {"supported", "contradicted", "inconclusive"}: + parsed["verdict"] = "inconclusive" + citations = parsed.get("citations", []) + parsed["citations"] = [str(citation) for citation in citations if citation] + return parsed + + +def _normalize_confidence(value) -> float: + try: + confidence = float(value) + except (TypeError, ValueError): + return 0.0 + return max(0.0, min(1.0, confidence)) + + def _build_annotation( verified: list[str], unverified: list[str], diff --git a/medguard/knowledge/pubmed.py b/medguard/knowledge/pubmed.py index c70a818..8e8ceee 100644 --- a/medguard/knowledge/pubmed.py +++ b/medguard/knowledge/pubmed.py @@ -44,6 +44,7 @@ class FactEvidence: verified: bool = False confidence: float = 0.0 summary: str = "" + reasoning: str = "" class PubMedClient: diff --git a/tests/test_fact_check.py b/tests/test_fact_check.py new file mode 100644 index 0000000..c74a7e7 --- /dev/null +++ b/tests/test_fact_check.py @@ -0,0 +1,141 @@ +import json + +import pytest + +from medguard.config import ( + DrugSafetyConfig, + FactCheckConfig, + GuardrailsConfig, + HallucinationConfig, + MedGuardConfig, + PHIConfig, + ScopeConfig, +) +from medguard.core import MedGuard +from medguard.guardrails.fact_check import AgentFactVerifier +from medguard.knowledge.pubmed import FactEvidence, PubMedArticle + + +class FakePubMed: + def __init__(self): + self.fallback_called = False + + async def search(self, query: str) -> list[str]: + assert "metformin" in query.lower() + return ["123", "456"] + + async def fetch_summaries(self, pmids: list[str]) -> list[PubMedArticle]: + return [ + PubMedArticle( + pmid="123", + title="Metformin for type 2 diabetes", + abstract="", + journal="Example Journal", + year="2024", + ), + PubMedArticle( + pmid="456", + title="Unrelated diabetes review", + abstract="", + journal="Example Journal", + year="2023", + ), + ] + + async def fetch_abstracts(self, pmids: list[str]) -> list[PubMedArticle]: + return [ + PubMedArticle( + pmid="123", + title="Metformin for type 2 diabetes", + abstract="Metformin improved glycemic control in adults with type 2 diabetes.", + ) + ] + + async def verify_claim(self, claim: str) -> FactEvidence: + self.fallback_called = True + return FactEvidence( + claim=claim, + total_results=1, + verified=True, + confidence=0.6, + summary="keyword fallback", + ) + + +class FakeLLM: + def __init__(self, payload: dict): + self.payload = payload + self.prompt = "" + + async def call(self, prompt: str) -> str: + self.prompt = prompt + return json.dumps(self.payload) + + +@pytest.mark.asyncio +async def test_agent_fact_verifier_returns_reasoned_evidence(): + llm = FakeLLM({ + "verdict": "supported", + "confidence": 0.82, + "citations": ["123"], + "reasoning": "The retrieved abstract directly discusses metformin use.", + }) + verifier = AgentFactVerifier(FakePubMed(), llm, confidence_threshold=0.4) + + evidence = await verifier.verify_claim("Metformin is effective for type 2 diabetes") + + assert evidence.verified is True + assert evidence.confidence == 0.82 + assert evidence.reasoning == "The retrieved abstract directly discusses metformin use." + assert [article.pmid for article in evidence.supporting] == ["123"] + assert "Return only JSON" in llm.prompt + assert "PMID: 123" in llm.prompt + + +@pytest.mark.asyncio +async def test_agent_fact_verifier_falls_back_without_llm(): + pubmed = FakePubMed() + verifier = AgentFactVerifier(pubmed, llm_caller=None, confidence_threshold=0.4) + + evidence = await verifier.verify_claim("Metformin is effective for type 2 diabetes") + + assert pubmed.fallback_called is True + assert evidence.summary == "keyword fallback" + + +@pytest.mark.asyncio +async def test_agent_fact_verifier_verify_includes_reasoning_summary(): + llm = FakeLLM({ + "verdict": "inconclusive", + "confidence": 0.2, + "citations": ["123"], + "reasoning": "The abstract discusses the topic but does not support the claim.", + }) + verifier = AgentFactVerifier(FakePubMed(), llm, confidence_threshold=0.4) + + result = await verifier.verify("Metformin is effective for type 2 diabetes") + + assert result.claims_checked == 1 + assert result.flagged is True + assert result.pubmed_evidence[0]["reasoning"] == ( + "The abstract discusses the topic but does not support the claim." + ) + + +def test_medguard_builds_agent_fact_verifier_when_enabled(monkeypatch): + llm = FakeLLM({}) + monkeypatch.setattr("medguard.core._build_llm_caller", lambda config: llm) + config = MedGuardConfig( + guardrails=GuardrailsConfig( + phi_detection=PHIConfig(enabled=False), + drug_safety=DrugSafetyConfig(enabled=False), + scope_enforcement=ScopeConfig(enabled=False), + hallucination_detection=HallucinationConfig(enabled=False), + fact_checking=FactCheckConfig(enabled=True, use_agent=True), + ) + ) + + medguard = MedGuard(config=config) + + assert isinstance(medguard.fact_verifier, AgentFactVerifier) + assert medguard.fact_verifier._llm_caller is llm