diff --git a/frontend/src/components/MessageWithCitations.tsx b/frontend/src/components/MessageWithCitations.tsx
index 692ee3e..4b1b409 100644
--- a/frontend/src/components/MessageWithCitations.tsx
+++ b/frontend/src/components/MessageWithCitations.tsx
@@ -130,6 +130,14 @@ function CitationBadge({
 /**
  * Process text to replace [cite: N] markers with citation badges
  * ADR-3.4.1: Uses [cite: N] format to avoid false positives with [N]
+ *
+ * Citation Index Mapping:
+ * - LLM outputs: [cite: 1], [cite: 2], ..., [cite: N] (1-based, user-friendly)
+ * - Backend emits: cite-0, cite-1, ..., cite-(N-1) (0-based, internal)
+ * - Frontend displays: 1, 2, ..., N (1-based, user-friendly)
+ *
+ * Example: For 9 chunks
+ * - LLM: [cite: 9] → citationId: cite-8 → Display: 9
  */
 function processTextWithCitations(
   text: string,
@@ -149,15 +157,20 @@ function processTextWithCitations(
       result.push(text.slice(lastIndex, match.index));
     }
 
-    const citationIndex = parseInt(match[1], 10);
-    const citationId = `cite-${citationIndex}`;
+    // Parse the 1-based citation number from LLM output (e.g., [cite: 9])
+    const llmCitationNumber = parseInt(match[1], 10);
+
+    // Convert to 0-based citation ID for lookup (e.g., cite-8)
+    // LLM uses 1-based: [cite: 1] through [cite: N]
+    // Backend emits 0-based: cite-0 through cite-(N-1)
+    const citationId = `cite-${llmCitationNumber - 1}`;
     const citation = citations[citationId];
 
     if (citation) {
       result.push(
@@ -166,7 +179,7 @@ function processTextWithCitations(
       // Graceful degradation for invalid citations
       result.push(
-        [{citationIndex}]
+        [{llmCitationNumber}]
       );
     }
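The off-by-one conversion above is the crux of this change, so here is the mapping spelled out as a round trip. A minimal Python sketch (illustrative only; the helper names are hypothetical, and the real lookup lives in the TypeScript above):

```python
# Round trip of the citation indexing scheme documented in the diff above:
# the LLM emits 1-based [cite: N] markers, the backend keys citations
# 0-based (cite-0 .. cite-(N-1)), and the UI shows 1-based numbers again.

def marker_to_citation_id(llm_number: int) -> str:
    """[cite: 9] -> 'cite-8' (1-based marker to 0-based backend key)."""
    return f"cite-{llm_number - 1}"

def citation_id_to_display(citation_id: str) -> int:
    """'cite-8' -> 9 (0-based backend key back to 1-based display number)."""
    return int(citation_id.removeprefix("cite-")) + 1

assert marker_to_citation_id(9) == "cite-8"
assert citation_id_to_display("cite-8") == 9
```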
diff --git a/frontend/src/components/__tests__/CitationDeepDive.test.tsx b/frontend/src/components/__tests__/CitationDeepDive.test.tsx
index b2f39d1..38bbfcc 100644
--- a/frontend/src/components/__tests__/CitationDeepDive.test.tsx
+++ b/frontend/src/components/__tests__/CitationDeepDive.test.tsx
@@ -53,13 +53,13 @@ describe('Citation Deep-Dive Integration', () => {
     render(
     );
 
-    const citationBadge = screen.getByRole('button', { name: /Citation 0/i });
+    const citationBadge = screen.getByRole('button', { name: /Citation 1/i });
     await user.click(citationBadge);
 
     expect(handleCitationClick).toHaveBeenCalledWith(mockCitations['cite-0']);
@@ -77,17 +77,18 @@ describe('Citation Deep-Dive Integration', () => {
     render(
     );
 
-    const citationBadge = screen.getByRole('button', { name: /Citation 1/i });
+    const citationBadge = screen.getByRole('button', { name: /Citation 2/i });
     await user.click(citationBadge);
 
     expect(receivedCitation).not.toBeNull();
     expect(receivedCitation!.documentId).toBe('doc-uuid-456');
+    expect(receivedCitation!.id).toBe('cite-1');
   });
 });
@@ -110,13 +111,13 @@ describe('Citation Deep-Dive Integration', () => {
     render(
     );
 
-    const citationBadge = screen.getByRole('button', { name: /Citation 0/i });
+    const citationBadge = screen.getByRole('button', { name: /Citation 1/i });
     await user.click(citationBadge);
 
     // The handler is called - workspace page validates and shows error
diff --git a/frontend/src/components/__tests__/MessageWithCitations.test.tsx b/frontend/src/components/__tests__/MessageWithCitations.test.tsx
index aa77542..dc5c147 100644
--- a/frontend/src/components/__tests__/MessageWithCitations.test.tsx
+++ b/frontend/src/components/__tests__/MessageWithCitations.test.tsx
@@ -11,9 +11,11 @@ import { MessageWithCitations } from '../MessageWithCitations';
 import type { CitationMap } from '@/types/citation';
 
 describe('MessageWithCitations', () => {
+  // Mock data uses 0-based citation IDs to match backend format
+  // LLM outputs [cite: 1], [cite: 2] → Maps to cite-0, cite-1
   const mockCitations: CitationMap = {
-    'cite-1': {
-      id: 'cite-1',
+    'cite-0': {
+      id: 'cite-0',
       title: 'intro.pdf',
       mediaType: 'text/plain',
       chunkIndex: 0,
@@ -21,8 +23,8 @@
       preview: 'Python is a powerful programming language used worldwide...',
       documentId: 'doc-uuid-1',
     },
-    'cite-2': {
-      id: 'cite-2',
+    'cite-1': {
+      id: 'cite-1',
       title: 'advanced.pdf',
       mediaType: 'text/plain',
       chunkIndex: 5,
@@ -152,7 +154,7 @@ describe('MessageWithCitations', () => {
     const button = screen.getByRole('button');
     await user.click(button);
 
-    expect(handleClick).toHaveBeenCalledWith(mockCitations['cite-1']);
+    expect(handleClick).toHaveBeenCalledWith(mockCitations['cite-0']);
   });
 });
""" citations = [] for idx, chunk in enumerate(context_chunks): diff --git a/ragitect/api/v1/chat.py b/ragitect/api/v1/chat.py index 0876f66..f664f9e 100644 --- a/ragitect/api/v1/chat.py +++ b/ragitect/api/v1/chat.py @@ -11,7 +11,6 @@ import json import logging -import re import uuid from collections.abc import AsyncGenerator from uuid import UUID @@ -25,7 +24,6 @@ from ragitect.agents.rag import build_rag_graph from ragitect.agents.rag.state import RAGState from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter -from ragitect.api.schemas.chat import Citation from ragitect.prompts.rag_prompts import build_rag_system_prompt from ragitect.services.config import EmbeddingConfig from ragitect.services.database.connection import get_async_session @@ -111,174 +109,6 @@ async def format_sse_stream( yield f"data: {json.dumps({'type': 'finish', 'finishReason': 'stop'})}\n\n" -def build_citation_metadata(context_chunks: list[dict]) -> list[Citation]: - """Build citation metadata from context chunks. - - NOTE: Prompt engineering for [N] format - This function just prepares metadata for frontend consumption. - - Args: - context_chunks: Chunks returned from retrieve_context() - - Returns: - List of Citation objects for streaming - """ - citations = [] - for i, chunk in enumerate(context_chunks): - citations.append( - Citation.from_context_chunk( - index=i + 1, # 1-based index - document_id=str(chunk.get("document_id", "")), - document_name=chunk.get("document_name", "Unknown"), - chunk_index=chunk.get("chunk_index", 0), - similarity=chunk.get("rerank_score") or chunk.get("similarity", 0.0), - content=chunk.get("content", ""), - ) - ) - return citations - - -class CitationStreamParser: - """Stateful parser for detecting citations across chunk boundaries. - - ADR Decision: Real-time regex streaming with cross-chunk buffering. - Handles edge case where LLM outputs '[cite:' in one chunk and '0]' in next. - - ADR-3.4.1: Citation format changed from [N] to [cite: N] to avoid - false positives with markdown lists and array indices. - """ - - def __init__(self, citations: list[Citation]): - """Initialize parser with available citations. - - Args: - citations: Pre-built citation metadata from context chunks - """ - self.citations = citations - self.buffer = "" # Buffer for partial citation markers - self.emitted_ids: set[str] = set() # Track which citations already emitted - # ADR-3.4.1: Updated pattern from [N] to [cite: N] format - self.pattern = re.compile(r"\[cite:\s*(\d+)\]") - - def parse_chunk(self, chunk: str) -> tuple[str, list[Citation]]: - """Parse chunk and detect citation markers. 
diff --git a/ragitect/api/v1/chat.py b/ragitect/api/v1/chat.py
index 0876f66..f664f9e 100644
--- a/ragitect/api/v1/chat.py
+++ b/ragitect/api/v1/chat.py
@@ -11,7 +11,6 @@
 import json
 import logging
-import re
 import uuid
 from collections.abc import AsyncGenerator
 from uuid import UUID
@@ -25,7 +24,6 @@
 from ragitect.agents.rag import build_rag_graph
 from ragitect.agents.rag.state import RAGState
 from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter
-from ragitect.api.schemas.chat import Citation
 from ragitect.prompts.rag_prompts import build_rag_system_prompt
 from ragitect.services.config import EmbeddingConfig
 from ragitect.services.database.connection import get_async_session
@@ -111,174 +109,6 @@ async def format_sse_stream(
     yield f"data: {json.dumps({'type': 'finish', 'finishReason': 'stop'})}\n\n"
 
 
-def build_citation_metadata(context_chunks: list[dict]) -> list[Citation]:
-    """Build citation metadata from context chunks.
-
-    NOTE: Prompt engineering for [N] format
-    This function just prepares metadata for frontend consumption.
-
-    Args:
-        context_chunks: Chunks returned from retrieve_context()
-
-    Returns:
-        List of Citation objects for streaming
-    """
-    citations = []
-    for i, chunk in enumerate(context_chunks):
-        citations.append(
-            Citation.from_context_chunk(
-                index=i + 1,  # 1-based index
-                document_id=str(chunk.get("document_id", "")),
-                document_name=chunk.get("document_name", "Unknown"),
-                chunk_index=chunk.get("chunk_index", 0),
-                similarity=chunk.get("rerank_score") or chunk.get("similarity", 0.0),
-                content=chunk.get("content", ""),
-            )
-        )
-    return citations
-
-
-class CitationStreamParser:
-    """Stateful parser for detecting citations across chunk boundaries.
-
-    ADR Decision: Real-time regex streaming with cross-chunk buffering.
-    Handles edge case where LLM outputs '[cite:' in one chunk and '0]' in next.
-
-    ADR-3.4.1: Citation format changed from [N] to [cite: N] to avoid
-    false positives with markdown lists and array indices.
-    """
-
-    def __init__(self, citations: list[Citation]):
-        """Initialize parser with available citations.
-
-        Args:
-            citations: Pre-built citation metadata from context chunks
-        """
-        self.citations = citations
-        self.buffer = ""  # Buffer for partial citation markers
-        self.emitted_ids: set[str] = set()  # Track which citations already emitted
-        # ADR-3.4.1: Updated pattern from [N] to [cite: N] format
-        self.pattern = re.compile(r"\[cite:\s*(\d+)\]")
-
-    def parse_chunk(self, chunk: str) -> tuple[str, list[Citation]]:
-        """Parse chunk and detect citation markers.
-
-        Args:
-            chunk: New text chunk from LLM stream
-
-        Returns:
-            Tuple of (text_to_emit, new_citations_found)
-        """
-        # Add chunk to buffer
-        self.buffer += chunk
-
-        # Find all complete citation markers in buffer
-        new_citations = []
-        for match in self.pattern.finditer(self.buffer):
-            cite_idx = int(match.group(1))
-            cite_id = f"cite-{cite_idx}"
-
-            # Validate citation index (ADR: Hallucination Handling)
-            # 1-based index check
-            if cite_idx < 1 or cite_idx > len(self.citations):
-                logger.warning(
-                    "LLM cited non-existent source [%d] (only %d chunks available)",
-                    cite_idx,
-                    len(self.citations),
-                )
-                continue  # Graceful degradation - skip invalid citation
-
-            # Emit each citation only once
-            if cite_id not in self.emitted_ids:
-                # Map 1-based index to 0-based list
-                new_citations.append(self.citations[cite_idx - 1])
-                self.emitted_ids.add(cite_id)
-
-        # Emit text, but keep last 20 chars in buffer for partial markers
-        # Max citation marker length: "[cite: 9999]" = 12 chars, buffer 20 for safety
-        if len(self.buffer) > 20:
-            text_to_emit = self.buffer[:-20]
-            self.buffer = self.buffer[-20:]
-        else:
-            text_to_emit = ""
-
-        return text_to_emit, new_citations
-
-    def flush(self) -> str:
-        """Flush remaining buffer at end of stream.
-
-        Returns:
-            Any remaining text in the buffer
-        """
-        remaining = self.buffer
-        self.buffer = ""
-        return remaining
-
-
-async def format_sse_stream_with_citations(
-    chunks: AsyncGenerator[str, None],
-    citations: list[Citation],
-) -> AsyncGenerator[str, None]:
-    """Format LLM chunks with AI SDK UI Message Stream Protocol v1 + citations.
-
-    ADR: Real-time regex streaming with cross-chunk buffering.
-    Emits citations as 'source-document' parts for AI SDK useChat.
-
-    Args:
-        chunks: LLM token stream
-        citations: Pre-built citation metadata
-
-    Yields:
-        SSE formatted messages (UI Message Stream Protocol v1):
-        - data: {"type": "start", "messageId": "..."} - Message start
-        - data: {"type": "text-start", "id": "..."} - Text block start
-        - data: {"type": "text-delta", "id": "...", "delta": "..."} - Text chunks
-        - data: {"type": "source-document", "sourceId": "...", ...} - Citations
-        - data: {"type": "text-end", "id": "..."} - Text block end
-        - data: {"type": "finish", "finishReason": "stop"} - Stream end
-    """
-    message_id = str(uuid.uuid4())
-    text_id = str(uuid.uuid4())
-    parser = CitationStreamParser(citations)
-
-    # Message start (protocol requirement)
-    yield f"data: {json.dumps({'type': 'start', 'messageId': message_id})}\n\n"
-
-    # Text block start
-    yield f"data: {json.dumps({'type': 'text-start', 'id': text_id})}\n\n"
-
-    async for chunk in chunks:
-        # Parse chunk for citations
-        text_to_emit, new_citations = parser.parse_chunk(chunk)
-
-        # Emit text delta if we have text
-        if text_to_emit:
-            yield f"data: {json.dumps({'type': 'text-delta', 'id': text_id, 'delta': text_to_emit})}\n\n"
-
-        # Emit source-document parts for detected citations
-        for citation in new_citations:
-            source_doc = citation.to_sse_dict()
-            yield f"data: {json.dumps(source_doc)}\n\n"
-
-    # Flush remaining buffer
-    remaining = parser.flush()
-    if remaining:
-        yield f"data: {json.dumps({'type': 'text-delta', 'id': text_id, 'delta': remaining})}\n\n"
-
-    # Log citation usage for monitoring (ADR: Zero Citations case)
-    if citations and not parser.emitted_ids:
-        logger.info(
-            "LLM response had no citations despite %d available chunks",
-            len(citations),
-        )
-
-    # Text block end
-    yield f"data: {json.dumps({'type': 'text-end', 'id': text_id})}\n\n"
-
-    # Finish message
-    yield f"data: {json.dumps({'type': 'finish', 'finishReason': 'stop'})}\n\n"
-
-
 async def empty_workspace_response() -> AsyncGenerator[str, None]:
     """Return SSE stream for empty workspace message using AI SDK protocol.
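For reviewers without the git history at hand: the heart of the deleted CitationStreamParser was the 20-character lookahead, which node-level streaming makes unnecessary. A minimal sketch of that idea (standalone illustration, not the removed class itself):

```python
# The cross-chunk buffering trick the removed parser relied on: hold back
# the last 20 characters so a marker split across token chunks
# ("[cite:" + " 1]") is still seen whole on the next call.

def emit_with_lookahead(buffer: str, chunk: str) -> tuple[str, str]:
    """Append chunk; emit all but the last 20 chars, keep the tail buffered."""
    buffer += chunk
    if len(buffer) > 20:
        return buffer[:-20], buffer[-20:]  # (text_to_emit, new_buffer)
    return "", buffer

buf = ""
emitted, buf = emit_with_lookahead(buf, "Hello [cite:")
assert emitted == ""  # only 12 chars buffered, marker still incomplete
emitted, buf = emit_with_lookahead(buf, " 1] world")
assert "[cite: 1]" in buf  # reassembled marker is now visible to the regex
```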
f"data: {json.dumps({'type': 'finish', 'finishReason': 'stop'})}\n\n" - - async def empty_workspace_response() -> AsyncGenerator[str, None]: """Return SSE stream for empty workspace message using AI SDK protocol. diff --git a/tests/agents/rag/test_streaming.py b/tests/agents/rag/test_streaming.py index 32a3e68..9e9c228 100644 --- a/tests/agents/rag/test_streaming.py +++ b/tests/agents/rag/test_streaming.py @@ -5,6 +5,7 @@ """ import json +import logging from typing import Any, AsyncIterator import pytest @@ -305,3 +306,133 @@ def astream(self, inputs, stream_mode): ) assert source_doc_idx < text_delta_idx, "Citations must appear before text" + + +class TestCitationValidation: + """Test citation index validation for invalid references. + + AC #4: Invalid citation indices should be logged as warnings and not crash. + """ + + async def test_build_citations_handles_empty_chunks(self): + """Test _build_citations_from_context handles empty list gracefully.""" + from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter + + adapter = LangGraphToAISDKAdapter() + + citations = adapter._build_citations_from_context([]) + + assert citations == [] + + async def test_build_citations_handles_missing_fields(self): + """Test _build_citations_from_context handles chunks with missing optional fields.""" + from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter + + adapter = LangGraphToAISDKAdapter() + + # Minimal chunk with only required fields missing + chunks = [ + { + "content": "Some content", + # Missing: document_id, title, chunk_index, score + } + ] + + citations = adapter._build_citations_from_context(chunks) + + assert len(citations) == 1 + assert citations[0].source_id == "cite-0" + # Should use defaults for missing fields + assert citations[0].title == "Unknown" + + async def test_citation_indices_are_zero_based(self): + """Test that citation indices are 0-based (cite-0, cite-1, etc.).""" + from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter + + adapter = LangGraphToAISDKAdapter() + + chunks = [ + { + "content": "First", + "title": "doc1.pdf", + "document_id": "d1", + "chunk_index": 0, + "score": 0.9, + }, + { + "content": "Second", + "title": "doc2.pdf", + "document_id": "d2", + "chunk_index": 1, + "score": 0.8, + }, + { + "content": "Third", + "title": "doc3.pdf", + "document_id": "d3", + "chunk_index": 2, + "score": 0.7, + }, + ] + + citations = adapter._build_citations_from_context(chunks) + + assert len(citations) == 3 + assert citations[0].source_id == "cite-0" + assert citations[1].source_id == "cite-1" + assert citations[2].source_id == "cite-2" + + async def test_citation_sse_format_matches_ai_sdk_protocol(self): + """Test that Citation.to_sse_dict() produces valid AI SDK source-document format.""" + from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter + + adapter = LangGraphToAISDKAdapter() + + chunks = [ + { + "content": "Test content", + "title": "research.pdf", + "document_id": "doc-abc-123", + "chunk_index": 5, + "score": 0.95, + } + ] + + citations = adapter._build_citations_from_context(chunks) + sse_dict = citations[0].to_sse_dict() + + # Verify AI SDK source-document format + assert sse_dict["type"] == "source-document" + assert sse_dict["sourceId"] == "cite-0" + assert sse_dict["mediaType"] == "text/plain" + assert sse_dict["title"] == "research.pdf" + + # Verify providerMetadata.ragitect structure + ragitect_meta = sse_dict["providerMetadata"]["ragitect"] + assert ragitect_meta["chunkIndex"] == 5 + assert ragitect_meta["similarity"] 
diff --git a/tests/api/v1/test_chat_streaming.py b/tests/api/v1/test_chat_streaming.py
index beb283d..0849824 100644
--- a/tests/api/v1/test_chat_streaming.py
+++ b/tests/api/v1/test_chat_streaming.py
@@ -23,7 +23,7 @@
 def setup_langgraph_streaming_mocks(mocker):
     """Setup common mocks for LangGraph streaming architecture.
 
-    After Story 4.3, the endpoint uses LangGraphToAISDKAdapter with full graph execution.
+    The endpoint uses LangGraphToAISDKAdapter with full graph execution.
     This helper mocks all required dependencies for the streaming pipeline.
 
     Args:
@@ -72,7 +72,7 @@ async def mock_embed_fn(model, text: str):
     mock_structured_llm = mocker.AsyncMock()
 
     # Mock strategy response for generate_strategy node
-    from ragitect.agents.rag.schemas import SearchStrategy, Search
+    from ragitect.agents.rag.schemas import Search, SearchStrategy
 
     mock_strategy = SearchStrategy(
         reasoning="Test analysis",
@@ -141,7 +141,7 @@ async def test_stream_returns_sse_content_type(self, async_client, mocker):
             return_value=mock_doc_repo,
         )
 
-        # Setup LangGraph streaming mocks (Story 4.3)
+        # Setup LangGraph streaming mocks
        setup_langgraph_streaming_mocks(mocker)
 
         response = await async_client.post(
@@ -182,7 +182,7 @@ async def test_stream_format_is_sse(self, async_client, mocker):
             return_value=mock_doc_repo,
         )
 
-        # Setup LangGraph streaming mocks (Story 4.3)
+        # Setup LangGraph streaming mocks
         setup_langgraph_streaming_mocks(mocker)
 
         response = await async_client.post(
@@ -405,7 +405,7 @@ async def test_chat_retrieves_relevant_chunks(self, async_client, mocker):
             return_value=mock_doc_repo,
         )
 
-        # Setup LangGraph streaming mocks (Story 4.3)
+        # Setup LangGraph streaming mocks
         setup_langgraph_streaming_mocks(mocker)
 
         response = await async_client.post(
@@ -443,7 +443,7 @@ async def test_chat_uses_context_in_prompt(self, async_client, mocker):
             return_value=mock_doc_repo,
         )
 
-        # Setup LangGraph streaming mocks (Story 4.3)
+        # Setup LangGraph streaming mocks
         setup_langgraph_streaming_mocks(mocker)
 
         response = await async_client.post(
@@ -502,7 +502,7 @@ async def test_chat_with_provider_override_uses_specified_provider(
             return_value=mock_doc_repo,
         )
 
-        # Setup LangGraph streaming mocks (Story 4.3)
+        # Setup LangGraph streaming mocks
         mocks = setup_langgraph_streaming_mocks(mocker)
 
         # Track which provider was requested
@@ -643,7 +643,7 @@ async def test_chat_without_provider_uses_default(self, async_client, mocker):
             return_value=mock_doc_repo,
         )
 
-        # Setup LangGraph streaming mocks (Story 4.3)
+        # Setup LangGraph streaming mocks
         mocks = setup_langgraph_streaming_mocks(mocker)
 
         mock_create_llm = mocker.AsyncMock()
@@ -665,277 +665,6 @@
         assert call_kwargs.get("provider") is None
call_kwargs.get("provider") is None -class TestCitationStreaming: - """Tests for citation detection and streaming.""" - - async def test_citation_parser_uses_cite_format(self): - """Test that CitationStreamParser uses [cite: N] format""" - from ragitect.api.schemas.chat import Citation - from ragitect.api.v1.chat import CitationStreamParser - - citations = [ - Citation.from_context_chunk(1, "doc-id-1", "doc1.pdf", 0, 0.9, "Content 1"), - Citation.from_context_chunk(2, "doc-id-2", "doc2.pdf", 1, 0.8, "Content 2"), - ] - - parser = CitationStreamParser(citations) - - # Parse chunks with [cite: N] format - text1, found1 = parser.parse_chunk("Python is great[cite: 1] and ") - text2, found2 = parser.parse_chunk("versatile[cite: 2].") - remaining = parser.flush() - - # Should find both citations with new format - assert len(found1) == 1 - assert found1[0].source_id == "cite-1" - assert len(found2) == 1 - assert found2[0].source_id == "cite-2" - - async def test_citation_parser_ignores_old_bare_bracket_format(self): - """Test that parser does NOT match old [N] format""" - from ragitect.api.schemas.chat import Citation - from ragitect.api.v1.chat import CitationStreamParser - - citations = [ - Citation.from_context_chunk(1, "doc-id-1", "doc1.pdf", 0, 0.9, "Content 1"), - ] - - parser = CitationStreamParser(citations) - - # Parse chunks with OLD bare [N] format - should NOT match - text1, found1 = parser.parse_chunk("Python is great[1] but this won't match.") - remaining = parser.flush() - - # Should NOT find the citation with old format - assert len(found1) == 0 - - async def test_build_citation_metadata_creates_citations_from_chunks(self): - """Test that build_citation_metadata creates Citation objects from context chunks (AC1, AC2).""" - from ragitect.api.v1.chat import build_citation_metadata - - context_chunks = [ - { - "content": "Python is a programming language used for many applications.", - "document_name": "python-intro.pdf", - "chunk_index": 0, - "similarity": 0.95, - }, - { - "content": "FastAPI is a modern web framework for building APIs.", - "document_name": "fastapi-docs.pdf", - "chunk_index": 3, - "rerank_score": 0.88, # Should use rerank_score over similarity - }, - ] - - citations = build_citation_metadata(context_chunks) - - assert len(citations) == 2 - assert citations[0].source_id == "cite-1" - assert citations[0].title == "python-intro.pdf" - assert citations[1].source_id == "cite-2" - assert citations[1].title == "fastapi-docs.pdf" - - async def test_citation_stream_parser_detects_markers(self): - """Test that CitationStreamParser detects [N] markers in text (AC1).""" - from ragitect.api.schemas.chat import Citation - from ragitect.api.v1.chat import CitationStreamParser - - citations = [ - Citation.from_context_chunk(1, "doc-id-1", "doc1.pdf", 0, 0.9, "Content 1"), - Citation.from_context_chunk(2, "doc-id-2", "doc2.pdf", 1, 0.8, "Content 2"), - ] - - parser = CitationStreamParser(citations) - - # Parse chunks with citation markers using [cite: N] format - text1, found1 = parser.parse_chunk("Python is great[cite: 1] and ") - text2, found2 = parser.parse_chunk("versatile[cite: 2].") - remaining = parser.flush() - - # Should find both citations - assert len(found1) == 1 - assert found1[0].source_id == "cite-1" - assert len(found2) == 1 - assert found2[0].source_id == "cite-2" - - async def test_citation_stream_parser_handles_split_markers(self): - """Test that parser handles citation markers split across chunks (AC1).""" - from ragitect.api.schemas.chat import Citation - from 
 ragitect.api.v1.chat import CitationStreamParser
-
-        citations = [
-            Citation.from_context_chunk(1, "doc-id", "doc.pdf", 0, 0.9, "Content"),
-        ]
-
-        parser = CitationStreamParser(citations)
-
-        # Split "[cite: 1]" across chunks
-        text1, found1 = parser.parse_chunk("Hello [cite:")
-        text2, found2 = parser.parse_chunk(" 1] world")
-        remaining = parser.flush()
-
-        # Should eventually find the citation
-        all_citations = found1 + found2
-        assert len(all_citations) == 1
-        assert all_citations[0].source_id == "cite-1"
-
-    async def test_citation_stream_parser_ignores_invalid_citations(self, caplog):
-        """Test that parser logs warning for invalid citation indices (AC6 - hallucination handling)."""
-        import logging
-
-        from ragitect.api.schemas.chat import Citation
-        from ragitect.api.v1.chat import CitationStreamParser
-
-        citations = [
-            Citation.from_context_chunk(0, "doc-id", "doc.pdf", 0, 0.9, "Content"),
-        ]
-
-        parser = CitationStreamParser(citations)
-
-        with caplog.at_level(logging.WARNING, logger="ragitect.api.v1.chat"):
-            # Try to cite [cite: 99] which doesn't exist
-            text, found = parser.parse_chunk("Test [cite: 99] content")
-            remaining = parser.flush()
-
-        # Should not find any citations (invalid index)
-        assert len(found) == 0
-        # Should have logged a warning
-        assert "cited non-existent source" in caplog.text or "99" in caplog.text
-
-    async def test_citation_stream_emits_each_citation_once(self):
-        """Test that each citation is only emitted once even if marker appears multiple times."""
-        from ragitect.api.schemas.chat import Citation
-        from ragitect.api.v1.chat import CitationStreamParser
-
-        citations = [
-            Citation.from_context_chunk(0, "doc-id", "doc.pdf", 0, 0.9, "Content"),
-        ]
-
-        parser = CitationStreamParser(citations)
-
-        # Same citation marker appears twice (using [cite: N] format)
-        assert len(citations) >= 1
-        # Test with [cite: 1] which maps to index 0
-        text1, found1 = parser.parse_chunk("First[cite: 1] and ")
-        text2, found2 = parser.parse_chunk("again[cite: 1].")
-        remaining = parser.flush()
-
-        # Should only emit once
-        total_found = len(found1) + len(found2)
-        assert total_found == 1
-
-    async def test_format_sse_stream_with_citations_emits_source_documents(self):
-        """Test that format_sse_stream_with_citations emits source-document events (AC1, AC2)."""
-        from ragitect.api.schemas.chat import Citation
-        from ragitect.api.v1.chat import format_sse_stream_with_citations
-
-        citations = [
-            Citation.from_context_chunk(
-                1, "doc-uuid", "intro.pdf", 0, 0.95, "Python is..."
-            ),
-        ]
-
-        async def mock_chunks():
-            yield "Python"
-            yield " is great"
-            yield "[cite: 1]"
-            yield "."
-
-        events = []
-        async for event in format_sse_stream_with_citations(mock_chunks(), citations):
-            events.append(event)
-
-        # Should contain source-document event
-        source_doc_events = [e for e in events if "source-document" in e]
-        assert len(source_doc_events) >= 1
-
-        # Verify source-document contains expected fields
-        source_event = source_doc_events[0]
-        # Expect cite-1 because Citation.from_context_chunk will now create cite-1 if we mock correctly or if we pass explicit ID
-        # Wait, from_context_chunk uses index arg.
-        # If we pass index=1 manually:
-        assert "cite-1" in source_event
-        assert "intro.pdf" in source_event
-
-    async def test_format_sse_stream_with_citations_handles_zero_citations(
-        self, caplog
-    ):
-        """Test that stream works when LLM doesn't cite any sources (AC6)."""
-        import logging
-
-        from ragitect.api.schemas.chat import Citation
-        from ragitect.api.v1.chat import format_sse_stream_with_citations
-
-        # Citations available but LLM doesn't use them
-        citations = [
-            Citation.from_context_chunk(1, "doc-id", "doc.pdf", 0, 0.9, "Content"),
-        ]
-
-        async def mock_chunks():
-            yield "2 plus 2 equals 4."  # No citations
-
-        with caplog.at_level(logging.INFO, logger="ragitect.api.v1.chat"):
-            events = []
-            async for event in format_sse_stream_with_citations(
-                mock_chunks(), citations
-            ):
-                events.append(event)
-
-        # Should NOT contain source-document events
-        source_doc_events = [e for e in events if "source-document" in e]
-        assert len(source_doc_events) == 0
-
-        # Should have text-delta events
-        text_events = [e for e in events if "text-delta" in e]
-        assert len(text_events) >= 1
-
-        # Should log that no citations were used
-        assert "no citations" in caplog.text.lower()
-
-    async def test_chat_endpoint_emits_citation_metadata(self, async_client, mocker):
-        """Test that chat endpoint includes source-document events in stream (AC1, AC2)."""
-        workspace_id = uuid.uuid4()
-        now = datetime.now(timezone.utc)
-
-        from ragitect.services.database.models import Workspace
-
-        mock_workspace = Workspace(id=workspace_id, name="Test")
-        mock_workspace.created_at = now
-        mock_workspace.updated_at = now
-
-        mock_ws_repo = mocker.AsyncMock()
-        mock_ws_repo.get_by_id.return_value = mock_workspace
-
-        mocker.patch(
-            "ragitect.api.v1.chat.WorkspaceRepository",
-            return_value=mock_ws_repo,
-        )
-
-        mock_doc_repo = mocker.AsyncMock()
-        mock_doc_repo.get_by_workspace_count.return_value = 5
-
-        mocker.patch(
-            "ragitect.api.v1.chat.DocumentRepository",
-            return_value=mock_doc_repo,
-        )
-
-        # Setup LangGraph streaming mocks (Story 4.3)
-        setup_langgraph_streaming_mocks(mocker)
-
-        response = await async_client.post(
-            f"/api/v1/workspaces/{workspace_id}/chat/stream",
-            json={"message": "What is Python?"},
-        )
-
-        assert response.status_code == 200
-        content = response.text
-
-        # With LangGraph streaming, citations are handled by the adapter
-        # Just verify we got a valid SSE stream
-        assert "data:" in content
-
-
 class TestLangGraphStreaming:
     """Tests for LangGraph streaming adapter integration in chat endpoint."""
 
@@ -980,7 +709,7 @@ async def test_langgraph_adapter_integration(self, async_client, mocker):
             return_value=mock_doc_repo,
         )
 
-        # Setup LangGraph streaming mocks (Story 4.3)
+        # Setup LangGraph streaming mocks
         setup_langgraph_streaming_mocks(mocker)
 
         response = await async_client.post(
@@ -1032,7 +761,7 @@ async def test_sse_event_format_compliance(self, async_client, mocker):
             return_value=mock_doc_repo,
         )
 
-        # Setup LangGraph streaming mocks (Story 4.3)
+        # Setup LangGraph streaming mocks
         setup_langgraph_streaming_mocks(mocker)
 
         response = await async_client.post(
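One last note on ADR-3.4.1, since both the removed parser tests and the frontend rely on it: the `[cite: N]` format exists because a bare `[N]` pattern also matches markdown footnotes and array indices. A quick demonstration — the two regexes are the ones from the diff; the sample string is invented:

```python
import re

bare = re.compile(r"\[(\d+)\]")          # old format: prone to false positives
cite = re.compile(r"\[cite:\s*(\d+)\]")  # ADR-3.4.1 format

sample = "See item [1] in the list, arr[0] in code, and a real source[cite: 2]."

assert [m.group(1) for m in bare.finditer(sample)] == ["1", "0"]  # both spurious
assert [m.group(1) for m in cite.finditer(sample)] == ["2"]       # marker only
```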